openigtlink_rust/io/
unified_async_client.rs

1//! Unified async client with optional TLS and reconnection
2//!
3//! This module provides `UnifiedAsyncClient`, a single async client type that elegantly
4//! handles all feature combinations (TLS, reconnection) through internal state management.
5//!
6//! # Design Philosophy
7//!
//! A traditional design would create a separate type for each feature combination:
9//! - `TcpAsync`, `TcpAsyncTls`, `TcpAsyncReconnect`, `TcpAsyncTlsReconnect`...
10//! - This leads to **variant explosion**: 2 features = 4 types, 3 features = 8 types, etc.
11//!
12//! **Our approach**: Single `UnifiedAsyncClient` with optional features:
13//! - Internal `Transport` enum: `Plain(TcpStream)` or `Tls(TlsStream)`
14//! - Optional `reconnect_config: Option<ReconnectConfig>`
15//! - ✅ Scales linearly with features (not exponentially!)
16//! - ✅ Easy to add new features (compression, authentication, etc.)
//! - ✅ Maintains type safety through the builder pattern
18//!
19//! # Architecture
20//!
21//! ```text
22//! UnifiedAsyncClient
23//! ├─ transport: Option<Transport>
24//! │  ├─ Plain(TcpStream)     ← Regular TCP
25//! │  └─ Tls(TlsStream)       ← TLS-encrypted TCP
26//! ├─ reconnect_config: Option<ReconnectConfig>
27//! │  ├─ None                 ← No auto-reconnection
28//! │  └─ Some(config)         ← Auto-reconnect with backoff
29//! ├─ conn_params: ConnectionParams (host, port, TLS config)
30//! └─ verify_crc: bool        ← CRC verification
31//! ```
32//!
33//! # Examples
34//!
35//! ## Plain TCP Connection
36//!
37//! ```no_run
38//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
39//!
40//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
41//! let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! ## TLS-Encrypted Connection
47//!
48//! ```no_run
49//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
50//! use std::sync::Arc;
51//!
52//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
53//! let tls_config = rustls::ClientConfig::builder()
54//!     .with_root_certificates(rustls::RootCertStore::empty())
55//!     .with_no_client_auth();
56//!
57//! let client = UnifiedAsyncClient::connect_with_tls(
58//!     "hospital-server.local",
59//!     18944,
60//!     Arc::new(tls_config)
61//! ).await?;
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! ## With Auto-Reconnection
67//!
68//! ```no_run
69//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
70//! use openigtlink_rust::io::reconnect::ReconnectConfig;
71//!
72//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
73//! let mut client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
74//!
75//! // Enable auto-reconnection
76//! let reconnect_config = ReconnectConfig::with_max_attempts(10);
77//! client = client.with_reconnect(reconnect_config);
78//! # Ok(())
79//! # }
80//! ```
81//!
82//! ## TLS + Auto-Reconnect (Previously Impossible!)
83//!
84//! ```no_run
85//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
86//! use openigtlink_rust::io::reconnect::ReconnectConfig;
87//! use std::sync::Arc;
88//!
89//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
90//! let tls_config = rustls::ClientConfig::builder()
91//!     .with_root_certificates(rustls::RootCertStore::empty())
92//!     .with_no_client_auth();
93//!
94//! let mut client = UnifiedAsyncClient::connect_with_tls(
95//!     "production-server",
96//!     18944,
97//!     Arc::new(tls_config)
98//! ).await?;
99//!
100//! // Add auto-reconnection to TLS client
101//! let reconnect_config = ReconnectConfig::with_max_attempts(100);
102//! client = client.with_reconnect(reconnect_config);
103//! # Ok(())
104//! # }
105//! ```
106//!
107//! # Prefer Using the Builder
108//!
109//! While you can create `UnifiedAsyncClient` directly, it's recommended to use
110//! [`ClientBuilder`](crate::io::builder::ClientBuilder) for better ergonomics and type safety:
111//!
112//! ```no_run
113//! use openigtlink_rust::io::builder::ClientBuilder;
114//! use openigtlink_rust::io::reconnect::ReconnectConfig;
115//! use std::sync::Arc;
116//!
117//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
118//! let tls_config = rustls::ClientConfig::builder()
119//!     .with_root_certificates(rustls::RootCertStore::empty())
120//!     .with_no_client_auth();
121//!
122//! let client = ClientBuilder::new()
123//!     .tcp("production-server:18944")
124//!     .async_mode()
125//!     .with_tls(Arc::new(tls_config))
126//!     .with_reconnect(ReconnectConfig::with_max_attempts(100))
127//!     .verify_crc(true)
128//!     .build()
129//!     .await?;
130//! # Ok(())
131//! # }
132//! ```
133
134use crate::error::{IgtlError, Result};
135use crate::io::reconnect::ReconnectConfig;
136use crate::protocol::header::Header;
137use crate::protocol::message::{IgtlMessage, Message};
138use rustls::pki_types::ServerName;
139use std::sync::Arc;
140use tokio::io::{AsyncReadExt, AsyncWriteExt};
141use tokio::net::TcpStream;
142use tokio::time::sleep;
143use tokio_rustls::client::TlsStream;
144use tokio_rustls::{rustls, TlsConnector};
145use tracing::{debug, info, trace, warn};
146
147/// Transport type for the async client
148enum Transport {
149    Plain(TcpStream),
150    Tls(TlsStream<TcpStream>),
151}
152
153impl Transport {
154    async fn write_all(&mut self, data: &[u8]) -> Result<()> {
155        match self {
156            Transport::Plain(stream) => {
157                stream.write_all(data).await?;
158                Ok(())
159            }
160            Transport::Tls(stream) => {
161                stream.write_all(data).await?;
162                Ok(())
163            }
164        }
165    }
166
167    async fn flush(&mut self) -> Result<()> {
168        match self {
169            Transport::Plain(stream) => {
170                stream.flush().await?;
171                Ok(())
172            }
173            Transport::Tls(stream) => {
174                stream.flush().await?;
175                Ok(())
176            }
177        }
178    }
179
180    async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
181        match self {
182            Transport::Plain(stream) => {
183                stream.read_exact(buf).await?;
184                Ok(())
185            }
186            Transport::Tls(stream) => {
187                stream.read_exact(buf).await?;
188                Ok(())
189            }
190        }
191    }
192}
193
194/// Connection parameters for reconnection
195struct ConnectionParams {
196    addr: String,
197    hostname: Option<String>,
198    port: Option<u16>,
199    tls_config: Option<Arc<rustls::ClientConfig>>,
200}
201
202/// Unified async OpenIGTLink client
203///
204/// Supports optional TLS encryption and automatic reconnection without
205/// combinatorial type explosion.
206///
207/// # Examples
208///
209/// ```no_run
210/// use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
211///
212/// # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
213/// // Plain TCP client
214/// let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
215///
216/// // With TLS
217/// let tls_config = rustls::ClientConfig::builder()
218///     .with_root_certificates(rustls::RootCertStore::empty())
219///     .with_no_client_auth();
220/// let client = UnifiedAsyncClient::connect_with_tls(
221///     "localhost",
222///     18944,
223///     std::sync::Arc::new(tls_config)
224/// ).await?;
225/// # Ok(())
226/// # }
227/// ```
228pub struct UnifiedAsyncClient {
229    transport: Option<Transport>,
230    conn_params: ConnectionParams,
231    reconnect_config: Option<ReconnectConfig>,
232    reconnect_count: usize,
233    verify_crc: bool,
234}
235
236impl UnifiedAsyncClient {
237    /// Connect to a plain TCP server
238    ///
239    /// # Arguments
240    /// * `addr` - Server address (e.g., "127.0.0.1:18944")
241    pub async fn connect(addr: &str) -> Result<Self> {
242        info!(addr = addr, "Connecting to OpenIGTLink server");
243        let stream = TcpStream::connect(addr).await?;
244        let local_addr = stream.local_addr()?;
245        info!(
246            local_addr = %local_addr,
247            remote_addr = addr,
248            "Connected to OpenIGTLink server"
249        );
250
251        Ok(Self {
252            transport: Some(Transport::Plain(stream)),
253            conn_params: ConnectionParams {
254                addr: addr.to_string(),
255                hostname: None,
256                port: None,
257                tls_config: None,
258            },
259            reconnect_config: None,
260            reconnect_count: 0,
261            verify_crc: true,
262        })
263    }
264
265    /// Connect to a TLS-enabled server
266    ///
267    /// # Arguments
268    /// * `hostname` - Server hostname (for SNI)
269    /// * `port` - Server port
270    /// * `tls_config` - TLS client configuration
271    pub async fn connect_with_tls(
272        hostname: &str,
273        port: u16,
274        tls_config: Arc<rustls::ClientConfig>,
275    ) -> Result<Self> {
276        info!(
277            hostname = hostname,
278            port = port,
279            "Connecting to TLS-enabled OpenIGTLink server"
280        );
281
282        let addr = format!("{}:{}", hostname, port);
283        let tcp_stream = TcpStream::connect(&addr).await?;
284        let local_addr = tcp_stream.local_addr()?;
285
286        let server_name = ServerName::try_from(hostname.to_string()).map_err(|e| {
287            IgtlError::Io(std::io::Error::new(
288                std::io::ErrorKind::InvalidInput,
289                format!("Invalid hostname: {}", e),
290            ))
291        })?;
292
293        let connector = TlsConnector::from(tls_config.clone());
294        let tls_stream = connector.connect(server_name, tcp_stream).await.map_err(|e| {
295            warn!(error = %e, "TLS handshake failed");
296            IgtlError::Io(std::io::Error::new(
297                std::io::ErrorKind::ConnectionRefused,
298                format!("TLS handshake failed: {}", e),
299            ))
300        })?;
301
302        info!(
303            local_addr = %local_addr,
304            remote_addr = %addr,
305            "TLS connection established"
306        );
307
308        Ok(Self {
309            transport: Some(Transport::Tls(tls_stream)),
310            conn_params: ConnectionParams {
311                addr,
312                hostname: Some(hostname.to_string()),
313                port: Some(port),
314                tls_config: Some(tls_config),
315            },
316            reconnect_config: None,
317            reconnect_count: 0,
318            verify_crc: true,
319        })
320    }
321
322    /// Enable automatic reconnection
323    ///
324    /// # Arguments
325    /// * `config` - Reconnection configuration
326    pub fn with_reconnect(mut self, config: ReconnectConfig) -> Self {
327        self.reconnect_config = Some(config);
328        self
329    }
330
331    /// Enable or disable CRC verification
332    pub fn set_verify_crc(&mut self, verify: bool) {
333        self.verify_crc = verify;
334    }
335
336    /// Get current CRC verification setting
337    pub fn verify_crc(&self) -> bool {
338        self.verify_crc
339    }
340
341    /// Get reconnection count
342    pub fn reconnect_count(&self) -> usize {
343        self.reconnect_count
344    }
345
346    /// Check if currently connected
347    pub fn is_connected(&self) -> bool {
348        self.transport.is_some()
349    }
350
351    /// Ensure we have a valid connection, reconnecting if necessary
352    async fn ensure_connected(&mut self) -> Result<()> {
353        if self.transport.is_some() {
354            return Ok(());
355        }
356
357        let Some(ref config) = self.reconnect_config else {
358            return Err(IgtlError::Io(std::io::Error::new(
359                std::io::ErrorKind::NotConnected,
360                "Connection lost and reconnection is not enabled",
361            )));
362        };
363
364        let mut attempt = 0;
365
366        loop {
367            if let Some(max) = config.max_attempts {
368                if attempt >= max {
369                    warn!(
370                        attempts = attempt,
371                        max_attempts = max,
372                        "Max reconnection attempts reached"
373                    );
374                    return Err(IgtlError::Io(std::io::Error::new(
375                        std::io::ErrorKind::TimedOut,
376                        "Max reconnection attempts exceeded",
377                    )));
378                }
379            }
380
381            let delay = config.delay_for_attempt(attempt);
382            if attempt > 0 {
383                info!(
384                    attempt = attempt + 1,
385                    delay_ms = delay.as_millis(),
386                    "Reconnecting..."
387                );
388                sleep(delay).await;
389            }
390
391            let result = if let Some(ref tls_config) = self.conn_params.tls_config {
392                // TLS reconnection
393                let hostname = self.conn_params.hostname.as_ref().unwrap();
394                let port = self.conn_params.port.unwrap();
395                Self::connect_with_tls(hostname, port, tls_config.clone()).await
396            } else {
397                // Plain TCP reconnection
398                Self::connect(&self.conn_params.addr).await
399            };
400
401            match result {
402                Ok(new_client) => {
403                    self.transport = new_client.transport;
404                    if attempt > 0 {
405                        self.reconnect_count += 1;
406                        info!(
407                            reconnect_count = self.reconnect_count,
408                            "Reconnection successful"
409                        );
410                    }
411                    return Ok(());
412                }
413                Err(e) => {
414                    warn!(
415                        attempt = attempt + 1,
416                        error = %e,
417                        "Reconnection attempt failed"
418                    );
419                    attempt += 1;
420                }
421            }
422        }
423    }
424
425    /// Send a message
426    pub async fn send<T: Message>(&mut self, msg: &IgtlMessage<T>) -> Result<()> {
427        let data = msg.encode()?;
428        let msg_type = msg.header.type_name.as_str().unwrap_or("UNKNOWN");
429        let device_name = msg.header.device_name.as_str().unwrap_or("UNKNOWN");
430
431        debug!(
432            msg_type = msg_type,
433            device_name = device_name,
434            size = data.len(),
435            "Sending message"
436        );
437
438        loop {
439            if self.reconnect_config.is_some() {
440                self.ensure_connected().await?;
441            }
442
443            if let Some(transport) = &mut self.transport {
444                match transport.write_all(&data).await {
445                    Ok(_) => {
446                        transport.flush().await?;
447                        trace!(
448                            msg_type = msg_type,
449                            bytes_sent = data.len(),
450                            "Message sent successfully"
451                        );
452                        return Ok(());
453                    }
454                    Err(e) => {
455                        if self.reconnect_config.is_some() {
456                            warn!(error = %e, "Send failed, will reconnect");
457                            self.transport = None;
458                            // Loop will retry after reconnection
459                        } else {
460                            return Err(e);
461                        }
462                    }
463                }
464            } else {
465                return Err(IgtlError::Io(std::io::Error::new(
466                    std::io::ErrorKind::NotConnected,
467                    "Not connected",
468                )));
469            }
470        }
471    }
472
473    /// Receive a message
474    pub async fn receive<T: Message>(&mut self) -> Result<IgtlMessage<T>> {
475        loop {
476            if self.reconnect_config.is_some() {
477                self.ensure_connected().await?;
478            }
479
480            if let Some(transport) = &mut self.transport {
481                // Read header
482                let mut header_buf = vec![0u8; Header::SIZE];
483                match transport.read_exact(&mut header_buf).await {
484                    Ok(_) => {}
485                    Err(e) => {
486                        if self.reconnect_config.is_some() {
487                            warn!(error = %e, "Header read failed, will reconnect");
488                            self.transport = None;
489                            continue;
490                        } else {
491                            return Err(e);
492                        }
493                    }
494                }
495
496                let header = Header::decode(&header_buf)?;
497                let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
498                let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
499
500                debug!(
501                    msg_type = msg_type,
502                    device_name = device_name,
503                    body_size = header.body_size,
504                    version = header.version,
505                    "Received message header"
506                );
507
508                // Read body
509                let mut body_buf = vec![0u8; header.body_size as usize];
510                match transport.read_exact(&mut body_buf).await {
511                    Ok(_) => {}
512                    Err(e) => {
513                        if self.reconnect_config.is_some() {
514                            warn!(error = %e, "Body read failed, will reconnect");
515                            self.transport = None;
516                            continue;
517                        } else {
518                            return Err(e);
519                        }
520                    }
521                }
522
523                trace!(
524                    msg_type = msg_type,
525                    bytes_read = body_buf.len(),
526                    "Message body received"
527                );
528
529                // Decode full message
530                let mut full_msg = header_buf;
531                full_msg.extend_from_slice(&body_buf);
532
533                let result = IgtlMessage::decode_with_options(&full_msg, self.verify_crc);
534
535                match &result {
536                    Ok(_) => {
537                        debug!(
538                            msg_type = msg_type,
539                            device_name = device_name,
540                            "Message decoded successfully"
541                        );
542                    }
543                    Err(e) => {
544                        warn!(
545                            msg_type = msg_type,
546                            error = %e,
547                            "Failed to decode message"
548                        );
549                    }
550                }
551
552                return result;
553            } else {
554                return Err(IgtlError::Io(std::io::Error::new(
555                    std::io::ErrorKind::NotConnected,
556                    "Not connected",
557                )));
558            }
559        }
560    }
561}