openigtlink_rust/io/
unified_async_client.rs

1//! Unified async client with optional TLS and reconnection
2//!
3//! This module provides `UnifiedAsyncClient`, a single async client type that elegantly
4//! handles all feature combinations (TLS, reconnection) through internal state management.
5//!
6//! # Design Philosophy
7//!
//! A traditional approach would create a separate type for each feature combination:
9//! - `TcpAsync`, `TcpAsyncTls`, `TcpAsyncReconnect`, `TcpAsyncTlsReconnect`...
10//! - This leads to **variant explosion**: 2 features = 4 types, 3 features = 8 types, etc.
11//!
12//! **Our approach**: Single `UnifiedAsyncClient` with optional features:
13//! - Internal `Transport` enum: `Plain(TcpStream)` or `Tls(TlsStream)`
14//! - Optional `reconnect_config: Option<ReconnectConfig>`
15//! - ✅ Scales linearly with features (not exponentially!)
16//! - ✅ Easy to add new features (compression, authentication, etc.)
17//! - ✅ Maintains type safety through builder pattern
18//!
19//! # Architecture
20//!
21//! ```text
22//! UnifiedAsyncClient
23//! ├─ transport: Option<Transport>
24//! │  ├─ Plain(TcpStream)     ← Regular TCP
25//! │  └─ Tls(TlsStream)       ← TLS-encrypted TCP
26//! ├─ reconnect_config: Option<ReconnectConfig>
27//! │  ├─ None                 ← No auto-reconnection
28//! │  └─ Some(config)         ← Auto-reconnect with backoff
29//! ├─ conn_params: ConnectionParams (host, port, TLS config)
30//! └─ verify_crc: bool        ← CRC verification
31//! ```
32//!
33//! # Examples
34//!
35//! ## Plain TCP Connection
36//!
37//! ```no_run
38//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
39//!
40//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
41//! let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! ## TLS-Encrypted Connection
47//!
48//! ```no_run
49//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
50//! use std::sync::Arc;
51//!
52//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
53//! let tls_config = rustls::ClientConfig::builder()
54//!     .with_root_certificates(rustls::RootCertStore::empty())
55//!     .with_no_client_auth();
56//!
57//! let client = UnifiedAsyncClient::connect_with_tls(
58//!     "hospital-server.local",
59//!     18944,
60//!     Arc::new(tls_config)
61//! ).await?;
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! ## With Auto-Reconnection
67//!
68//! ```no_run
69//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
70//! use openigtlink_rust::io::reconnect::ReconnectConfig;
71//!
72//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
73//! let mut client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
74//!
75//! // Enable auto-reconnection
76//! let reconnect_config = ReconnectConfig::with_max_attempts(10);
77//! client = client.with_reconnect(reconnect_config);
78//! # Ok(())
79//! # }
80//! ```
81//!
82//! ## TLS + Auto-Reconnect (Previously Impossible!)
83//!
84//! ```no_run
85//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
86//! use openigtlink_rust::io::reconnect::ReconnectConfig;
87//! use std::sync::Arc;
88//!
89//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
90//! let tls_config = rustls::ClientConfig::builder()
91//!     .with_root_certificates(rustls::RootCertStore::empty())
92//!     .with_no_client_auth();
93//!
94//! let mut client = UnifiedAsyncClient::connect_with_tls(
95//!     "production-server",
96//!     18944,
97//!     Arc::new(tls_config)
98//! ).await?;
99//!
100//! // Add auto-reconnection to TLS client
101//! let reconnect_config = ReconnectConfig::with_max_attempts(100);
102//! client = client.with_reconnect(reconnect_config);
103//! # Ok(())
104//! # }
105//! ```
106//!
107//! # Prefer Using the Builder
108//!
109//! While you can create `UnifiedAsyncClient` directly, it's recommended to use
110//! [`ClientBuilder`](crate::io::builder::ClientBuilder) for better ergonomics and type safety:
111//!
112//! ```no_run
113//! use openigtlink_rust::io::builder::ClientBuilder;
114//! use openigtlink_rust::io::reconnect::ReconnectConfig;
115//! use std::sync::Arc;
116//!
117//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
118//! let tls_config = rustls::ClientConfig::builder()
119//!     .with_root_certificates(rustls::RootCertStore::empty())
120//!     .with_no_client_auth();
121//!
122//! let client = ClientBuilder::new()
123//!     .tcp("production-server:18944")
124//!     .async_mode()
125//!     .with_tls(Arc::new(tls_config))
126//!     .with_reconnect(ReconnectConfig::with_max_attempts(100))
127//!     .verify_crc(true)
128//!     .build()
129//!     .await?;
130//! # Ok(())
131//! # }
132//! ```
133
134use crate::error::{IgtlError, Result};
135use crate::io::reconnect::ReconnectConfig;
136use crate::protocol::any_message::AnyMessage;
137use crate::protocol::factory::MessageFactory;
138use crate::protocol::header::Header;
139use crate::protocol::message::{IgtlMessage, Message};
140use rustls::pki_types::ServerName;
141use std::sync::Arc;
142use tokio::io::{AsyncReadExt, AsyncWriteExt};
143use tokio::net::TcpStream;
144use tokio::time::sleep;
145use tokio_rustls::client::TlsStream;
146use tokio_rustls::{rustls, TlsConnector};
147use tracing::{debug, info, trace, warn};
148
149/// Transport type for the async client
150enum Transport {
151    Plain(TcpStream),
152    Tls(Box<TlsStream<TcpStream>>),
153}
154
155impl Transport {
156    async fn write_all(&mut self, data: &[u8]) -> Result<()> {
157        match self {
158            Transport::Plain(stream) => {
159                stream.write_all(data).await?;
160                Ok(())
161            }
162            Transport::Tls(stream) => {
163                stream.write_all(data).await?;
164                Ok(())
165            }
166        }
167    }
168
169    async fn flush(&mut self) -> Result<()> {
170        match self {
171            Transport::Plain(stream) => {
172                stream.flush().await?;
173                Ok(())
174            }
175            Transport::Tls(stream) => {
176                stream.flush().await?;
177                Ok(())
178            }
179        }
180    }
181
182    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
183        match self {
184            Transport::Plain(stream) => {
185                stream.read_exact(buf).await?;
186                Ok(())
187            }
188            Transport::Tls(stream) => {
189                stream.read_exact(buf).await?;
190                Ok(())
191            }
192        }
193    }
194}
195
196/// Connection parameters for reconnection
197struct ConnectionParams {
198    addr: String,
199    hostname: Option<String>,
200    port: Option<u16>,
201    tls_config: Option<Arc<rustls::ClientConfig>>,
202}
203
204/// Unified async OpenIGTLink client
205///
206/// Supports optional TLS encryption and automatic reconnection without
207/// combinatorial type explosion.
208///
209/// # Examples
210///
211/// ```no_run
212/// use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
213///
214/// # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
215/// // Plain TCP client
216/// let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
217///
218/// // With TLS
219/// let tls_config = rustls::ClientConfig::builder()
220///     .with_root_certificates(rustls::RootCertStore::empty())
221///     .with_no_client_auth();
222/// let client = UnifiedAsyncClient::connect_with_tls(
223///     "localhost",
224///     18944,
225///     std::sync::Arc::new(tls_config)
226/// ).await?;
227/// # Ok(())
228/// # }
229/// ```
230pub struct UnifiedAsyncClient {
231    transport: Option<Transport>,
232    conn_params: ConnectionParams,
233    reconnect_config: Option<ReconnectConfig>,
234    reconnect_count: usize,
235    verify_crc: bool,
236}
237
238impl UnifiedAsyncClient {
239    /// Connect to a plain TCP server
240    ///
241    /// # Arguments
242    /// * `addr` - Server address (e.g., "127.0.0.1:18944")
243    pub async fn connect(addr: &str) -> Result<Self> {
244        info!(addr = addr, "Connecting to OpenIGTLink server");
245        let stream = TcpStream::connect(addr).await?;
246        let local_addr = stream.local_addr()?;
247        info!(
248            local_addr = %local_addr,
249            remote_addr = addr,
250            "Connected to OpenIGTLink server"
251        );
252
253        Ok(Self {
254            transport: Some(Transport::Plain(stream)),
255            conn_params: ConnectionParams {
256                addr: addr.to_string(),
257                hostname: None,
258                port: None,
259                tls_config: None,
260            },
261            reconnect_config: None,
262            reconnect_count: 0,
263            verify_crc: true,
264        })
265    }
266
267    /// Connect to a TLS-enabled server
268    ///
269    /// # Arguments
270    /// * `hostname` - Server hostname (for SNI)
271    /// * `port` - Server port
272    /// * `tls_config` - TLS client configuration
273    pub async fn connect_with_tls(
274        hostname: &str,
275        port: u16,
276        tls_config: Arc<rustls::ClientConfig>,
277    ) -> Result<Self> {
278        info!(
279            hostname = hostname,
280            port = port,
281            "Connecting to TLS-enabled OpenIGTLink server"
282        );
283
284        let addr = format!("{}:{}", hostname, port);
285        let tcp_stream = TcpStream::connect(&addr).await?;
286        let local_addr = tcp_stream.local_addr()?;
287
288        let server_name = ServerName::try_from(hostname.to_string()).map_err(|e| {
289            IgtlError::Io(std::io::Error::new(
290                std::io::ErrorKind::InvalidInput,
291                format!("Invalid hostname: {}", e),
292            ))
293        })?;
294
295        let connector = TlsConnector::from(tls_config.clone());
296        let tls_stream = connector
297            .connect(server_name, tcp_stream)
298            .await
299            .map_err(|e| {
300                warn!(error = %e, "TLS handshake failed");
301                IgtlError::Io(std::io::Error::new(
302                    std::io::ErrorKind::ConnectionRefused,
303                    format!("TLS handshake failed: {}", e),
304                ))
305            })?;
306
307        info!(
308            local_addr = %local_addr,
309            remote_addr = %addr,
310            "TLS connection established"
311        );
312
313        Ok(Self {
314            transport: Some(Transport::Tls(Box::new(tls_stream))),
315            conn_params: ConnectionParams {
316                addr,
317                hostname: Some(hostname.to_string()),
318                port: Some(port),
319                tls_config: Some(tls_config),
320            },
321            reconnect_config: None,
322            reconnect_count: 0,
323            verify_crc: true,
324        })
325    }
326
327    /// Enable automatic reconnection
328    ///
329    /// # Arguments
330    /// * `config` - Reconnection configuration
331    pub fn with_reconnect(mut self, config: ReconnectConfig) -> Self {
332        self.reconnect_config = Some(config);
333        self
334    }
335
336    /// Enable or disable CRC verification
337    pub fn set_verify_crc(&mut self, verify: bool) {
338        self.verify_crc = verify;
339    }
340
341    /// Get current CRC verification setting
342    pub fn verify_crc(&self) -> bool {
343        self.verify_crc
344    }
345
346    /// Get reconnection count
347    pub fn reconnect_count(&self) -> usize {
348        self.reconnect_count
349    }
350
351    /// Check if currently connected
352    pub fn is_connected(&self) -> bool {
353        self.transport.is_some()
354    }
355
356    /// Ensure we have a valid connection, reconnecting if necessary
357    async fn ensure_connected(&mut self) -> Result<()> {
358        if self.transport.is_some() {
359            return Ok(());
360        }
361
362        let Some(ref config) = self.reconnect_config else {
363            return Err(IgtlError::Io(std::io::Error::new(
364                std::io::ErrorKind::NotConnected,
365                "Connection lost and reconnection is not enabled",
366            )));
367        };
368
369        let mut attempt = 0;
370
371        loop {
372            if let Some(max) = config.max_attempts {
373                if attempt >= max {
374                    warn!(
375                        attempts = attempt,
376                        max_attempts = max,
377                        "Max reconnection attempts reached"
378                    );
379                    return Err(IgtlError::Io(std::io::Error::new(
380                        std::io::ErrorKind::TimedOut,
381                        "Max reconnection attempts exceeded",
382                    )));
383                }
384            }
385
386            let delay = config.delay_for_attempt(attempt);
387            if attempt > 0 {
388                info!(
389                    attempt = attempt + 1,
390                    delay_ms = delay.as_millis(),
391                    "Reconnecting..."
392                );
393                sleep(delay).await;
394            }
395
396            let result = if let Some(ref tls_config) = self.conn_params.tls_config {
397                // TLS reconnection
398                let hostname = self.conn_params.hostname.as_ref().unwrap();
399                let port = self.conn_params.port.unwrap();
400                Self::connect_with_tls(hostname, port, tls_config.clone()).await
401            } else {
402                // Plain TCP reconnection
403                Self::connect(&self.conn_params.addr).await
404            };
405
406            match result {
407                Ok(new_client) => {
408                    self.transport = new_client.transport;
409                    if attempt > 0 {
410                        self.reconnect_count += 1;
411                        info!(
412                            reconnect_count = self.reconnect_count,
413                            "Reconnection successful"
414                        );
415                    }
416                    return Ok(());
417                }
418                Err(e) => {
419                    warn!(
420                        attempt = attempt + 1,
421                        error = %e,
422                        "Reconnection attempt failed"
423                    );
424                    attempt += 1;
425                }
426            }
427        }
428    }
429
430    /// Send a message
431    pub async fn send<T: Message>(&mut self, msg: &IgtlMessage<T>) -> Result<()> {
432        let data = msg.encode()?;
433        let msg_type = msg.header.type_name.as_str().unwrap_or("UNKNOWN");
434        let device_name = msg.header.device_name.as_str().unwrap_or("UNKNOWN");
435
436        debug!(
437            msg_type = msg_type,
438            device_name = device_name,
439            size = data.len(),
440            "Sending message"
441        );
442
443        loop {
444            if self.reconnect_config.is_some() {
445                self.ensure_connected().await?;
446            }
447
448            if let Some(transport) = &mut self.transport {
449                match transport.write_all(&data).await {
450                    Ok(_) => {
451                        transport.flush().await?;
452                        trace!(
453                            msg_type = msg_type,
454                            bytes_sent = data.len(),
455                            "Message sent successfully"
456                        );
457                        return Ok(());
458                    }
459                    Err(e) => {
460                        if self.reconnect_config.is_some() {
461                            warn!(error = %e, "Send failed, will reconnect");
462                            self.transport = None;
463                            // Loop will retry after reconnection
464                        } else {
465                            return Err(e);
466                        }
467                    }
468                }
469            } else {
470                return Err(IgtlError::Io(std::io::Error::new(
471                    std::io::ErrorKind::NotConnected,
472                    "Not connected",
473                )));
474            }
475        }
476    }
477
478    /// Receive a message
479    pub async fn receive<T: Message>(&mut self) -> Result<IgtlMessage<T>> {
480        loop {
481            if self.reconnect_config.is_some() {
482                self.ensure_connected().await?;
483            }
484
485            if let Some(transport) = &mut self.transport {
486                // Read header
487                let mut header_buf = vec![0u8; Header::SIZE];
488                match transport.read_exact(&mut header_buf).await {
489                    Ok(_) => {}
490                    Err(e) => {
491                        if self.reconnect_config.is_some() {
492                            warn!(error = %e, "Header read failed, will reconnect");
493                            self.transport = None;
494                            continue;
495                        } else {
496                            return Err(e);
497                        }
498                    }
499                }
500
501                let header = Header::decode(&header_buf)?;
502                let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
503                let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
504
505                debug!(
506                    msg_type = msg_type,
507                    device_name = device_name,
508                    body_size = header.body_size,
509                    version = header.version,
510                    "Received message header"
511                );
512
513                // Read body
514                let mut body_buf = vec![0u8; header.body_size as usize];
515                match transport.read_exact(&mut body_buf).await {
516                    Ok(_) => {}
517                    Err(e) => {
518                        if self.reconnect_config.is_some() {
519                            warn!(error = %e, "Body read failed, will reconnect");
520                            self.transport = None;
521                            continue;
522                        } else {
523                            return Err(e);
524                        }
525                    }
526                }
527
528                trace!(
529                    msg_type = msg_type,
530                    bytes_read = body_buf.len(),
531                    "Message body received"
532                );
533
534                // Decode full message
535                let mut full_msg = header_buf;
536                full_msg.extend_from_slice(&body_buf);
537
538                let result = IgtlMessage::decode_with_options(&full_msg, self.verify_crc);
539
540                match &result {
541                    Ok(_) => {
542                        debug!(
543                            msg_type = msg_type,
544                            device_name = device_name,
545                            "Message decoded successfully"
546                        );
547                    }
548                    Err(e) => {
549                        warn!(
550                            msg_type = msg_type,
551                            error = %e,
552                            "Failed to decode message"
553                        );
554                    }
555                }
556
557                return result;
558            } else {
559                return Err(IgtlError::Io(std::io::Error::new(
560                    std::io::ErrorKind::NotConnected,
561                    "Not connected",
562                )));
563            }
564        }
565    }
566
567    /// Receive any message type dynamically without knowing the type in advance
568    ///
569    /// This method reads the message header first, determines the message type,
570    /// and then decodes the appropriate message type dynamically.
571    ///
572    /// # Returns
573    ///
574    /// An `AnyMessage` enum containing the decoded message. If the message type
575    /// is not recognized, it will be returned as `AnyMessage::Unknown` with the
576    /// raw header and body bytes.
577    ///
578    /// # Examples
579    ///
580    /// ```no_run
581    /// use openigtlink_rust::io::builder::ClientBuilder;
582    /// use openigtlink_rust::protocol::AnyMessage;
583    ///
584    /// # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
585    /// let mut client = ClientBuilder::new()
586    ///     .tcp("127.0.0.1:18944")
587    ///     .async_mode()
588    ///     .build()
589    ///     .await?;
590    ///
591    /// loop {
592    ///     let msg = client.receive_any().await?;
593    ///
594    ///     match msg {
595    ///         AnyMessage::Transform(transform_msg) => {
596    ///             println!("Received transform from {}",
597    ///                      transform_msg.header.device_name.as_str()?);
598    ///         }
599    ///         AnyMessage::Status(status_msg) => {
600    ///             println!("Status: {}", status_msg.content.status_string);
601    ///         }
602    ///         AnyMessage::Image(image_msg) => {
603    ///             println!("Received image: {}x{}x{}",
604    ///                      image_msg.content.size[0],
605    ///                      image_msg.content.size[1],
606    ///                      image_msg.content.size[2]);
607    ///         }
608    ///         AnyMessage::Unknown { header, .. } => {
609    ///             println!("Unknown message type: {}",
610    ///                      header.type_name.as_str()?);
611    ///         }
612    ///         _ => {}
613    ///     }
614    /// }
615    /// # Ok(())
616    /// # }
617    /// ```
618    pub async fn receive_any(&mut self) -> Result<AnyMessage> {
619        loop {
620            if self.reconnect_config.is_some() {
621                self.ensure_connected().await?;
622            }
623
624            if let Some(transport) = &mut self.transport {
625                // Read header
626                let mut header_buf = vec![0u8; Header::SIZE];
627                match transport.read_exact(&mut header_buf).await {
628                    Ok(_) => {}
629                    Err(e) => {
630                        if self.reconnect_config.is_some() {
631                            warn!(error = %e, "Header read failed, will reconnect");
632                            self.transport = None;
633                            continue;
634                        } else {
635                            return Err(e);
636                        }
637                    }
638                }
639
640                let header = Header::decode(&header_buf)?;
641                let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
642                let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
643
644                debug!(
645                    msg_type = msg_type,
646                    device_name = device_name,
647                    body_size = header.body_size,
648                    version = header.version,
649                    "Received message header"
650                );
651
652                // Read body
653                let mut body_buf = vec![0u8; header.body_size as usize];
654                match transport.read_exact(&mut body_buf).await {
655                    Ok(_) => {}
656                    Err(e) => {
657                        if self.reconnect_config.is_some() {
658                            warn!(error = %e, "Body read failed, will reconnect");
659                            self.transport = None;
660                            continue;
661                        } else {
662                            return Err(e);
663                        }
664                    }
665                }
666
667                trace!(
668                    msg_type = msg_type,
669                    bytes_read = body_buf.len(),
670                    "Message body received"
671                );
672
673                // Decode using MessageFactory
674                let factory = MessageFactory::new();
675                let result = factory.decode_any(&header, &body_buf, self.verify_crc);
676
677                match &result {
678                    Ok(msg) => {
679                        debug!(
680                            msg_type = msg.message_type(),
681                            device_name = device_name,
682                            "Message decoded successfully"
683                        );
684                    }
685                    Err(e) => {
686                        warn!(
687                            msg_type = msg_type,
688                            error = %e,
689                            "Failed to decode message"
690                        );
691                    }
692                }
693
694                return result;
695            } else {
696                return Err(IgtlError::Io(std::io::Error::new(
697                    std::io::ErrorKind::NotConnected,
698                    "Not connected",
699                )));
700            }
701        }
702    }
703}