openigtlink_rust/io/unified_async_client.rs
1//! Unified async client with optional TLS and reconnection
2//!
3//! This module provides `UnifiedAsyncClient`, a single async client type that elegantly
4//! handles all feature combinations (TLS, reconnection) through internal state management.
5//!
6//! # Design Philosophy
7//!
8//! Traditional approach would create separate types for each feature combination:
9//! - `TcpAsync`, `TcpAsyncTls`, `TcpAsyncReconnect`, `TcpAsyncTlsReconnect`...
10//! - This leads to **variant explosion**: 2 features = 4 types, 3 features = 8 types, etc.
11//!
12//! **Our approach**: Single `UnifiedAsyncClient` with optional features:
13//! - Internal `Transport` enum: `Plain(TcpStream)` or `Tls(TlsStream)`
14//! - Optional `reconnect_config: Option<ReconnectConfig>`
15//! - ✅ Scales linearly with features (not exponentially!)
16//! - ✅ Easy to add new features (compression, authentication, etc.)
17//! - ✅ Maintains type safety through builder pattern
18//!
19//! # Architecture
20//!
21//! ```text
22//! UnifiedAsyncClient
23//! ├─ transport: Option<Transport>
24//! │ ├─ Plain(TcpStream) ← Regular TCP
25//! │ └─ Tls(TlsStream) ← TLS-encrypted TCP
26//! ├─ reconnect_config: Option<ReconnectConfig>
27//! │ ├─ None ← No auto-reconnection
28//! │ └─ Some(config) ← Auto-reconnect with backoff
29//! ├─ conn_params: ConnectionParams (host, port, TLS config)
30//! └─ verify_crc: bool ← CRC verification
31//! ```
32//!
33//! # Examples
34//!
35//! ## Plain TCP Connection
36//!
37//! ```no_run
38//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
39//!
40//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
41//! let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! ## TLS-Encrypted Connection
47//!
48//! ```no_run
49//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
50//! use std::sync::Arc;
51//!
52//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
53//! let tls_config = rustls::ClientConfig::builder()
54//! .with_root_certificates(rustls::RootCertStore::empty())
55//! .with_no_client_auth();
56//!
57//! let client = UnifiedAsyncClient::connect_with_tls(
58//! "hospital-server.local",
59//! 18944,
60//! Arc::new(tls_config)
61//! ).await?;
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! ## With Auto-Reconnection
67//!
68//! ```no_run
69//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
70//! use openigtlink_rust::io::reconnect::ReconnectConfig;
71//!
72//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
73//! let mut client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
74//!
75//! // Enable auto-reconnection
76//! let reconnect_config = ReconnectConfig::with_max_attempts(10);
77//! client = client.with_reconnect(reconnect_config);
78//! # Ok(())
79//! # }
80//! ```
81//!
82//! ## TLS + Auto-Reconnect (Previously Impossible!)
83//!
84//! ```no_run
85//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
86//! use openigtlink_rust::io::reconnect::ReconnectConfig;
87//! use std::sync::Arc;
88//!
89//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
90//! let tls_config = rustls::ClientConfig::builder()
91//! .with_root_certificates(rustls::RootCertStore::empty())
92//! .with_no_client_auth();
93//!
94//! let mut client = UnifiedAsyncClient::connect_with_tls(
95//! "production-server",
96//! 18944,
97//! Arc::new(tls_config)
98//! ).await?;
99//!
100//! // Add auto-reconnection to TLS client
101//! let reconnect_config = ReconnectConfig::with_max_attempts(100);
102//! client = client.with_reconnect(reconnect_config);
103//! # Ok(())
104//! # }
105//! ```
106//!
107//! # Prefer Using the Builder
108//!
109//! While you can create `UnifiedAsyncClient` directly, it's recommended to use
110//! [`ClientBuilder`](crate::io::builder::ClientBuilder) for better ergonomics and type safety:
111//!
112//! ```no_run
113//! use openigtlink_rust::io::builder::ClientBuilder;
114//! use openigtlink_rust::io::reconnect::ReconnectConfig;
115//! use std::sync::Arc;
116//!
117//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
118//! let tls_config = rustls::ClientConfig::builder()
119//! .with_root_certificates(rustls::RootCertStore::empty())
120//! .with_no_client_auth();
121//!
122//! let client = ClientBuilder::new()
123//! .tcp("production-server:18944")
124//! .async_mode()
125//! .with_tls(Arc::new(tls_config))
126//! .with_reconnect(ReconnectConfig::with_max_attempts(100))
127//! .verify_crc(true)
128//! .build()
129//! .await?;
130//! # Ok(())
131//! # }
132//! ```
133
134use crate::error::{IgtlError, Result};
135use crate::io::reconnect::ReconnectConfig;
136use crate::protocol::any_message::AnyMessage;
137use crate::protocol::factory::MessageFactory;
138use crate::protocol::header::Header;
139use crate::protocol::message::{IgtlMessage, Message};
140use rustls::pki_types::ServerName;
141use std::sync::Arc;
142use tokio::io::{AsyncReadExt, AsyncWriteExt};
143use tokio::net::TcpStream;
144use tokio::time::sleep;
145use tokio_rustls::client::TlsStream;
146use tokio_rustls::{rustls, TlsConnector};
147use tracing::{debug, info, trace, warn};
148
149/// Transport type for the async client
150enum Transport {
151 Plain(TcpStream),
152 Tls(Box<TlsStream<TcpStream>>),
153}
154
155impl Transport {
156 async fn write_all(&mut self, data: &[u8]) -> Result<()> {
157 match self {
158 Transport::Plain(stream) => {
159 stream.write_all(data).await?;
160 Ok(())
161 }
162 Transport::Tls(stream) => {
163 stream.write_all(data).await?;
164 Ok(())
165 }
166 }
167 }
168
169 async fn flush(&mut self) -> Result<()> {
170 match self {
171 Transport::Plain(stream) => {
172 stream.flush().await?;
173 Ok(())
174 }
175 Transport::Tls(stream) => {
176 stream.flush().await?;
177 Ok(())
178 }
179 }
180 }
181
182 async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
183 match self {
184 Transport::Plain(stream) => {
185 stream.read_exact(buf).await?;
186 Ok(())
187 }
188 Transport::Tls(stream) => {
189 stream.read_exact(buf).await?;
190 Ok(())
191 }
192 }
193 }
194}
195
196/// Connection parameters for reconnection
197struct ConnectionParams {
198 addr: String,
199 hostname: Option<String>,
200 port: Option<u16>,
201 tls_config: Option<Arc<rustls::ClientConfig>>,
202}
203
204/// Unified async OpenIGTLink client
205///
206/// Supports optional TLS encryption and automatic reconnection without
207/// combinatorial type explosion.
208///
209/// # Examples
210///
211/// ```no_run
212/// use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
213///
214/// # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
215/// // Plain TCP client
216/// let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
217///
218/// // With TLS
219/// let tls_config = rustls::ClientConfig::builder()
220/// .with_root_certificates(rustls::RootCertStore::empty())
221/// .with_no_client_auth();
222/// let client = UnifiedAsyncClient::connect_with_tls(
223/// "localhost",
224/// 18944,
225/// std::sync::Arc::new(tls_config)
226/// ).await?;
227/// # Ok(())
228/// # }
229/// ```
230pub struct UnifiedAsyncClient {
231 transport: Option<Transport>,
232 conn_params: ConnectionParams,
233 reconnect_config: Option<ReconnectConfig>,
234 reconnect_count: usize,
235 verify_crc: bool,
236}
237
238impl UnifiedAsyncClient {
239 /// Connect to a plain TCP server
240 ///
241 /// # Arguments
242 /// * `addr` - Server address (e.g., "127.0.0.1:18944")
243 pub async fn connect(addr: &str) -> Result<Self> {
244 info!(addr = addr, "Connecting to OpenIGTLink server");
245 let stream = TcpStream::connect(addr).await?;
246 let local_addr = stream.local_addr()?;
247 info!(
248 local_addr = %local_addr,
249 remote_addr = addr,
250 "Connected to OpenIGTLink server"
251 );
252
253 Ok(Self {
254 transport: Some(Transport::Plain(stream)),
255 conn_params: ConnectionParams {
256 addr: addr.to_string(),
257 hostname: None,
258 port: None,
259 tls_config: None,
260 },
261 reconnect_config: None,
262 reconnect_count: 0,
263 verify_crc: true,
264 })
265 }
266
267 /// Connect to a TLS-enabled server
268 ///
269 /// # Arguments
270 /// * `hostname` - Server hostname (for SNI)
271 /// * `port` - Server port
272 /// * `tls_config` - TLS client configuration
273 pub async fn connect_with_tls(
274 hostname: &str,
275 port: u16,
276 tls_config: Arc<rustls::ClientConfig>,
277 ) -> Result<Self> {
278 info!(
279 hostname = hostname,
280 port = port,
281 "Connecting to TLS-enabled OpenIGTLink server"
282 );
283
284 let addr = format!("{}:{}", hostname, port);
285 let tcp_stream = TcpStream::connect(&addr).await?;
286 let local_addr = tcp_stream.local_addr()?;
287
288 let server_name = ServerName::try_from(hostname.to_string()).map_err(|e| {
289 IgtlError::Io(std::io::Error::new(
290 std::io::ErrorKind::InvalidInput,
291 format!("Invalid hostname: {}", e),
292 ))
293 })?;
294
295 let connector = TlsConnector::from(tls_config.clone());
296 let tls_stream = connector
297 .connect(server_name, tcp_stream)
298 .await
299 .map_err(|e| {
300 warn!(error = %e, "TLS handshake failed");
301 IgtlError::Io(std::io::Error::new(
302 std::io::ErrorKind::ConnectionRefused,
303 format!("TLS handshake failed: {}", e),
304 ))
305 })?;
306
307 info!(
308 local_addr = %local_addr,
309 remote_addr = %addr,
310 "TLS connection established"
311 );
312
313 Ok(Self {
314 transport: Some(Transport::Tls(Box::new(tls_stream))),
315 conn_params: ConnectionParams {
316 addr,
317 hostname: Some(hostname.to_string()),
318 port: Some(port),
319 tls_config: Some(tls_config),
320 },
321 reconnect_config: None,
322 reconnect_count: 0,
323 verify_crc: true,
324 })
325 }
326
327 /// Enable automatic reconnection
328 ///
329 /// # Arguments
330 /// * `config` - Reconnection configuration
331 pub fn with_reconnect(mut self, config: ReconnectConfig) -> Self {
332 self.reconnect_config = Some(config);
333 self
334 }
335
336 /// Enable or disable CRC verification
337 pub fn set_verify_crc(&mut self, verify: bool) {
338 self.verify_crc = verify;
339 }
340
341 /// Get current CRC verification setting
342 pub fn verify_crc(&self) -> bool {
343 self.verify_crc
344 }
345
346 /// Get reconnection count
347 pub fn reconnect_count(&self) -> usize {
348 self.reconnect_count
349 }
350
351 /// Check if currently connected
352 pub fn is_connected(&self) -> bool {
353 self.transport.is_some()
354 }
355
356 /// Ensure we have a valid connection, reconnecting if necessary
357 async fn ensure_connected(&mut self) -> Result<()> {
358 if self.transport.is_some() {
359 return Ok(());
360 }
361
362 let Some(ref config) = self.reconnect_config else {
363 return Err(IgtlError::Io(std::io::Error::new(
364 std::io::ErrorKind::NotConnected,
365 "Connection lost and reconnection is not enabled",
366 )));
367 };
368
369 let mut attempt = 0;
370
371 loop {
372 if let Some(max) = config.max_attempts {
373 if attempt >= max {
374 warn!(
375 attempts = attempt,
376 max_attempts = max,
377 "Max reconnection attempts reached"
378 );
379 return Err(IgtlError::Io(std::io::Error::new(
380 std::io::ErrorKind::TimedOut,
381 "Max reconnection attempts exceeded",
382 )));
383 }
384 }
385
386 let delay = config.delay_for_attempt(attempt);
387 if attempt > 0 {
388 info!(
389 attempt = attempt + 1,
390 delay_ms = delay.as_millis(),
391 "Reconnecting..."
392 );
393 sleep(delay).await;
394 }
395
396 let result = if let Some(ref tls_config) = self.conn_params.tls_config {
397 // TLS reconnection
398 let hostname = self.conn_params.hostname.as_ref().unwrap();
399 let port = self.conn_params.port.unwrap();
400 Self::connect_with_tls(hostname, port, tls_config.clone()).await
401 } else {
402 // Plain TCP reconnection
403 Self::connect(&self.conn_params.addr).await
404 };
405
406 match result {
407 Ok(new_client) => {
408 self.transport = new_client.transport;
409 if attempt > 0 {
410 self.reconnect_count += 1;
411 info!(
412 reconnect_count = self.reconnect_count,
413 "Reconnection successful"
414 );
415 }
416 return Ok(());
417 }
418 Err(e) => {
419 warn!(
420 attempt = attempt + 1,
421 error = %e,
422 "Reconnection attempt failed"
423 );
424 attempt += 1;
425 }
426 }
427 }
428 }
429
430 /// Send a message
431 pub async fn send<T: Message>(&mut self, msg: &IgtlMessage<T>) -> Result<()> {
432 let data = msg.encode()?;
433 let msg_type = msg.header.type_name.as_str().unwrap_or("UNKNOWN");
434 let device_name = msg.header.device_name.as_str().unwrap_or("UNKNOWN");
435
436 debug!(
437 msg_type = msg_type,
438 device_name = device_name,
439 size = data.len(),
440 "Sending message"
441 );
442
443 loop {
444 if self.reconnect_config.is_some() {
445 self.ensure_connected().await?;
446 }
447
448 if let Some(transport) = &mut self.transport {
449 match transport.write_all(&data).await {
450 Ok(_) => {
451 transport.flush().await?;
452 trace!(
453 msg_type = msg_type,
454 bytes_sent = data.len(),
455 "Message sent successfully"
456 );
457 return Ok(());
458 }
459 Err(e) => {
460 if self.reconnect_config.is_some() {
461 warn!(error = %e, "Send failed, will reconnect");
462 self.transport = None;
463 // Loop will retry after reconnection
464 } else {
465 return Err(e);
466 }
467 }
468 }
469 } else {
470 return Err(IgtlError::Io(std::io::Error::new(
471 std::io::ErrorKind::NotConnected,
472 "Not connected",
473 )));
474 }
475 }
476 }
477
478 /// Receive a message
479 pub async fn receive<T: Message>(&mut self) -> Result<IgtlMessage<T>> {
480 loop {
481 if self.reconnect_config.is_some() {
482 self.ensure_connected().await?;
483 }
484
485 if let Some(transport) = &mut self.transport {
486 // Read header
487 let mut header_buf = vec![0u8; Header::SIZE];
488 match transport.read_exact(&mut header_buf).await {
489 Ok(_) => {}
490 Err(e) => {
491 if self.reconnect_config.is_some() {
492 warn!(error = %e, "Header read failed, will reconnect");
493 self.transport = None;
494 continue;
495 } else {
496 return Err(e);
497 }
498 }
499 }
500
501 let header = Header::decode(&header_buf)?;
502 let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
503 let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
504
505 debug!(
506 msg_type = msg_type,
507 device_name = device_name,
508 body_size = header.body_size,
509 version = header.version,
510 "Received message header"
511 );
512
513 // Read body
514 let mut body_buf = vec![0u8; header.body_size as usize];
515 match transport.read_exact(&mut body_buf).await {
516 Ok(_) => {}
517 Err(e) => {
518 if self.reconnect_config.is_some() {
519 warn!(error = %e, "Body read failed, will reconnect");
520 self.transport = None;
521 continue;
522 } else {
523 return Err(e);
524 }
525 }
526 }
527
528 trace!(
529 msg_type = msg_type,
530 bytes_read = body_buf.len(),
531 "Message body received"
532 );
533
534 // Decode full message
535 let mut full_msg = header_buf;
536 full_msg.extend_from_slice(&body_buf);
537
538 let result = IgtlMessage::decode_with_options(&full_msg, self.verify_crc);
539
540 match &result {
541 Ok(_) => {
542 debug!(
543 msg_type = msg_type,
544 device_name = device_name,
545 "Message decoded successfully"
546 );
547 }
548 Err(e) => {
549 warn!(
550 msg_type = msg_type,
551 error = %e,
552 "Failed to decode message"
553 );
554 }
555 }
556
557 return result;
558 } else {
559 return Err(IgtlError::Io(std::io::Error::new(
560 std::io::ErrorKind::NotConnected,
561 "Not connected",
562 )));
563 }
564 }
565 }
566
567 /// Receive any message type dynamically without knowing the type in advance
568 ///
569 /// This method reads the message header first, determines the message type,
570 /// and then decodes the appropriate message type dynamically.
571 ///
572 /// # Returns
573 ///
574 /// An `AnyMessage` enum containing the decoded message. If the message type
575 /// is not recognized, it will be returned as `AnyMessage::Unknown` with the
576 /// raw header and body bytes.
577 ///
578 /// # Examples
579 ///
580 /// ```no_run
581 /// use openigtlink_rust::io::builder::ClientBuilder;
582 /// use openigtlink_rust::protocol::AnyMessage;
583 ///
584 /// # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
585 /// let mut client = ClientBuilder::new()
586 /// .tcp("127.0.0.1:18944")
587 /// .async_mode()
588 /// .build()
589 /// .await?;
590 ///
591 /// loop {
592 /// let msg = client.receive_any().await?;
593 ///
594 /// match msg {
595 /// AnyMessage::Transform(transform_msg) => {
596 /// println!("Received transform from {}",
597 /// transform_msg.header.device_name.as_str()?);
598 /// }
599 /// AnyMessage::Status(status_msg) => {
600 /// println!("Status: {}", status_msg.content.status_string);
601 /// }
602 /// AnyMessage::Image(image_msg) => {
603 /// println!("Received image: {}x{}x{}",
604 /// image_msg.content.size[0],
605 /// image_msg.content.size[1],
606 /// image_msg.content.size[2]);
607 /// }
608 /// AnyMessage::Unknown { header, .. } => {
609 /// println!("Unknown message type: {}",
610 /// header.type_name.as_str()?);
611 /// }
612 /// _ => {}
613 /// }
614 /// }
615 /// # Ok(())
616 /// # }
617 /// ```
618 pub async fn receive_any(&mut self) -> Result<AnyMessage> {
619 loop {
620 if self.reconnect_config.is_some() {
621 self.ensure_connected().await?;
622 }
623
624 if let Some(transport) = &mut self.transport {
625 // Read header
626 let mut header_buf = vec![0u8; Header::SIZE];
627 match transport.read_exact(&mut header_buf).await {
628 Ok(_) => {}
629 Err(e) => {
630 if self.reconnect_config.is_some() {
631 warn!(error = %e, "Header read failed, will reconnect");
632 self.transport = None;
633 continue;
634 } else {
635 return Err(e);
636 }
637 }
638 }
639
640 let header = Header::decode(&header_buf)?;
641 let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
642 let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
643
644 debug!(
645 msg_type = msg_type,
646 device_name = device_name,
647 body_size = header.body_size,
648 version = header.version,
649 "Received message header"
650 );
651
652 // Read body
653 let mut body_buf = vec![0u8; header.body_size as usize];
654 match transport.read_exact(&mut body_buf).await {
655 Ok(_) => {}
656 Err(e) => {
657 if self.reconnect_config.is_some() {
658 warn!(error = %e, "Body read failed, will reconnect");
659 self.transport = None;
660 continue;
661 } else {
662 return Err(e);
663 }
664 }
665 }
666
667 trace!(
668 msg_type = msg_type,
669 bytes_read = body_buf.len(),
670 "Message body received"
671 );
672
673 // Decode using MessageFactory
674 let factory = MessageFactory::new();
675 let result = factory.decode_any(&header, &body_buf, self.verify_crc);
676
677 match &result {
678 Ok(msg) => {
679 debug!(
680 msg_type = msg.message_type(),
681 device_name = device_name,
682 "Message decoded successfully"
683 );
684 }
685 Err(e) => {
686 warn!(
687 msg_type = msg_type,
688 error = %e,
689 "Failed to decode message"
690 );
691 }
692 }
693
694 return result;
695 } else {
696 return Err(IgtlError::Io(std::io::Error::new(
697 std::io::ErrorKind::NotConnected,
698 "Not connected",
699 )));
700 }
701 }
702 }
703}