openigtlink_rust/io/unified_async_client.rs
1//! Unified async client with optional TLS and reconnection
2//!
//! This module provides `UnifiedAsyncClient`, a single async client type that
//! handles every feature combination (TLS, reconnection) through internal state.
5//!
6//! # Design Philosophy
7//!
//! A traditional approach would create a separate type for each feature combination:
9//! - `TcpAsync`, `TcpAsyncTls`, `TcpAsyncReconnect`, `TcpAsyncTlsReconnect`...
10//! - This leads to **variant explosion**: 2 features = 4 types, 3 features = 8 types, etc.
11//!
12//! **Our approach**: Single `UnifiedAsyncClient` with optional features:
13//! - Internal `Transport` enum: `Plain(TcpStream)` or `Tls(TlsStream)`
14//! - Optional `reconnect_config: Option<ReconnectConfig>`
15//! - ✅ Scales linearly with features (not exponentially!)
16//! - ✅ Easy to add new features (compression, authentication, etc.)
17//! - ✅ Maintains type safety through builder pattern
18//!
19//! # Architecture
20//!
21//! ```text
22//! UnifiedAsyncClient
23//! ├─ transport: Option<Transport>
24//! │ ├─ Plain(TcpStream) ← Regular TCP
25//! │ └─ Tls(TlsStream) ← TLS-encrypted TCP
26//! ├─ reconnect_config: Option<ReconnectConfig>
27//! │ ├─ None ← No auto-reconnection
28//! │ └─ Some(config) ← Auto-reconnect with backoff
29//! ├─ conn_params: ConnectionParams (host, port, TLS config)
30//! └─ verify_crc: bool ← CRC verification
31//! ```
32//!
33//! # Examples
34//!
35//! ## Plain TCP Connection
36//!
37//! ```no_run
38//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
39//!
40//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
41//! let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! ## TLS-Encrypted Connection
47//!
48//! ```no_run
49//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
50//! use std::sync::Arc;
51//!
52//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
53//! let tls_config = rustls::ClientConfig::builder()
54//! .with_root_certificates(rustls::RootCertStore::empty())
55//! .with_no_client_auth();
56//!
57//! let client = UnifiedAsyncClient::connect_with_tls(
58//! "hospital-server.local",
59//! 18944,
60//! Arc::new(tls_config)
61//! ).await?;
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! ## With Auto-Reconnection
67//!
68//! ```no_run
69//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
70//! use openigtlink_rust::io::reconnect::ReconnectConfig;
71//!
72//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
73//! let mut client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
74//!
75//! // Enable auto-reconnection
76//! let reconnect_config = ReconnectConfig::with_max_attempts(10);
77//! client = client.with_reconnect(reconnect_config);
78//! # Ok(())
79//! # }
80//! ```
81//!
//! ## TLS + Auto-Reconnect
83//!
84//! ```no_run
85//! use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
86//! use openigtlink_rust::io::reconnect::ReconnectConfig;
87//! use std::sync::Arc;
88//!
89//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
90//! let tls_config = rustls::ClientConfig::builder()
91//! .with_root_certificates(rustls::RootCertStore::empty())
92//! .with_no_client_auth();
93//!
94//! let mut client = UnifiedAsyncClient::connect_with_tls(
95//! "production-server",
96//! 18944,
97//! Arc::new(tls_config)
98//! ).await?;
99//!
100//! // Add auto-reconnection to TLS client
101//! let reconnect_config = ReconnectConfig::with_max_attempts(100);
102//! client = client.with_reconnect(reconnect_config);
103//! # Ok(())
104//! # }
105//! ```
106//!
107//! # Prefer Using the Builder
108//!
109//! While you can create `UnifiedAsyncClient` directly, it's recommended to use
110//! [`ClientBuilder`](crate::io::builder::ClientBuilder) for better ergonomics and type safety:
111//!
112//! ```no_run
113//! use openigtlink_rust::io::builder::ClientBuilder;
114//! use openigtlink_rust::io::reconnect::ReconnectConfig;
115//! use std::sync::Arc;
116//!
117//! # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
118//! let tls_config = rustls::ClientConfig::builder()
119//! .with_root_certificates(rustls::RootCertStore::empty())
120//! .with_no_client_auth();
121//!
122//! let client = ClientBuilder::new()
123//! .tcp("production-server:18944")
124//! .async_mode()
125//! .with_tls(Arc::new(tls_config))
126//! .with_reconnect(ReconnectConfig::with_max_attempts(100))
127//! .verify_crc(true)
128//! .build()
129//! .await?;
130//! # Ok(())
131//! # }
132//! ```
133
134use crate::error::{IgtlError, Result};
135use crate::io::reconnect::ReconnectConfig;
136use crate::protocol::header::Header;
137use crate::protocol::message::{IgtlMessage, Message};
138use rustls::pki_types::ServerName;
139use std::sync::Arc;
140use tokio::io::{AsyncReadExt, AsyncWriteExt};
141use tokio::net::TcpStream;
142use tokio::time::sleep;
143use tokio_rustls::client::TlsStream;
144use tokio_rustls::{rustls, TlsConnector};
145use tracing::{debug, info, trace, warn};
146
147/// Transport type for the async client
148enum Transport {
149 Plain(TcpStream),
150 Tls(TlsStream<TcpStream>),
151}
152
153impl Transport {
154 async fn write_all(&mut self, data: &[u8]) -> Result<()> {
155 match self {
156 Transport::Plain(stream) => {
157 stream.write_all(data).await?;
158 Ok(())
159 }
160 Transport::Tls(stream) => {
161 stream.write_all(data).await?;
162 Ok(())
163 }
164 }
165 }
166
167 async fn flush(&mut self) -> Result<()> {
168 match self {
169 Transport::Plain(stream) => {
170 stream.flush().await?;
171 Ok(())
172 }
173 Transport::Tls(stream) => {
174 stream.flush().await?;
175 Ok(())
176 }
177 }
178 }
179
180 async fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
181 match self {
182 Transport::Plain(stream) => {
183 stream.read_exact(buf).await?;
184 Ok(())
185 }
186 Transport::Tls(stream) => {
187 stream.read_exact(buf).await?;
188 Ok(())
189 }
190 }
191 }
192}
193
194/// Connection parameters for reconnection
195struct ConnectionParams {
196 addr: String,
197 hostname: Option<String>,
198 port: Option<u16>,
199 tls_config: Option<Arc<rustls::ClientConfig>>,
200}
201
202/// Unified async OpenIGTLink client
203///
204/// Supports optional TLS encryption and automatic reconnection without
205/// combinatorial type explosion.
206///
207/// # Examples
208///
209/// ```no_run
210/// use openigtlink_rust::io::unified_async_client::UnifiedAsyncClient;
211///
212/// # async fn example() -> Result<(), openigtlink_rust::error::IgtlError> {
213/// // Plain TCP client
214/// let client = UnifiedAsyncClient::connect("127.0.0.1:18944").await?;
215///
216/// // With TLS
217/// let tls_config = rustls::ClientConfig::builder()
218/// .with_root_certificates(rustls::RootCertStore::empty())
219/// .with_no_client_auth();
220/// let client = UnifiedAsyncClient::connect_with_tls(
221/// "localhost",
222/// 18944,
223/// std::sync::Arc::new(tls_config)
224/// ).await?;
225/// # Ok(())
226/// # }
227/// ```
228pub struct UnifiedAsyncClient {
229 transport: Option<Transport>,
230 conn_params: ConnectionParams,
231 reconnect_config: Option<ReconnectConfig>,
232 reconnect_count: usize,
233 verify_crc: bool,
234}
235
236impl UnifiedAsyncClient {
237 /// Connect to a plain TCP server
238 ///
239 /// # Arguments
240 /// * `addr` - Server address (e.g., "127.0.0.1:18944")
241 pub async fn connect(addr: &str) -> Result<Self> {
242 info!(addr = addr, "Connecting to OpenIGTLink server");
243 let stream = TcpStream::connect(addr).await?;
244 let local_addr = stream.local_addr()?;
245 info!(
246 local_addr = %local_addr,
247 remote_addr = addr,
248 "Connected to OpenIGTLink server"
249 );
250
251 Ok(Self {
252 transport: Some(Transport::Plain(stream)),
253 conn_params: ConnectionParams {
254 addr: addr.to_string(),
255 hostname: None,
256 port: None,
257 tls_config: None,
258 },
259 reconnect_config: None,
260 reconnect_count: 0,
261 verify_crc: true,
262 })
263 }
264
265 /// Connect to a TLS-enabled server
266 ///
267 /// # Arguments
268 /// * `hostname` - Server hostname (for SNI)
269 /// * `port` - Server port
270 /// * `tls_config` - TLS client configuration
271 pub async fn connect_with_tls(
272 hostname: &str,
273 port: u16,
274 tls_config: Arc<rustls::ClientConfig>,
275 ) -> Result<Self> {
276 info!(
277 hostname = hostname,
278 port = port,
279 "Connecting to TLS-enabled OpenIGTLink server"
280 );
281
282 let addr = format!("{}:{}", hostname, port);
283 let tcp_stream = TcpStream::connect(&addr).await?;
284 let local_addr = tcp_stream.local_addr()?;
285
286 let server_name = ServerName::try_from(hostname.to_string()).map_err(|e| {
287 IgtlError::Io(std::io::Error::new(
288 std::io::ErrorKind::InvalidInput,
289 format!("Invalid hostname: {}", e),
290 ))
291 })?;
292
293 let connector = TlsConnector::from(tls_config.clone());
294 let tls_stream = connector.connect(server_name, tcp_stream).await.map_err(|e| {
295 warn!(error = %e, "TLS handshake failed");
296 IgtlError::Io(std::io::Error::new(
297 std::io::ErrorKind::ConnectionRefused,
298 format!("TLS handshake failed: {}", e),
299 ))
300 })?;
301
302 info!(
303 local_addr = %local_addr,
304 remote_addr = %addr,
305 "TLS connection established"
306 );
307
308 Ok(Self {
309 transport: Some(Transport::Tls(tls_stream)),
310 conn_params: ConnectionParams {
311 addr,
312 hostname: Some(hostname.to_string()),
313 port: Some(port),
314 tls_config: Some(tls_config),
315 },
316 reconnect_config: None,
317 reconnect_count: 0,
318 verify_crc: true,
319 })
320 }
321
322 /// Enable automatic reconnection
323 ///
324 /// # Arguments
325 /// * `config` - Reconnection configuration
326 pub fn with_reconnect(mut self, config: ReconnectConfig) -> Self {
327 self.reconnect_config = Some(config);
328 self
329 }
330
331 /// Enable or disable CRC verification
332 pub fn set_verify_crc(&mut self, verify: bool) {
333 self.verify_crc = verify;
334 }
335
336 /// Get current CRC verification setting
337 pub fn verify_crc(&self) -> bool {
338 self.verify_crc
339 }
340
341 /// Get reconnection count
342 pub fn reconnect_count(&self) -> usize {
343 self.reconnect_count
344 }
345
346 /// Check if currently connected
347 pub fn is_connected(&self) -> bool {
348 self.transport.is_some()
349 }
350
351 /// Ensure we have a valid connection, reconnecting if necessary
352 async fn ensure_connected(&mut self) -> Result<()> {
353 if self.transport.is_some() {
354 return Ok(());
355 }
356
357 let Some(ref config) = self.reconnect_config else {
358 return Err(IgtlError::Io(std::io::Error::new(
359 std::io::ErrorKind::NotConnected,
360 "Connection lost and reconnection is not enabled",
361 )));
362 };
363
364 let mut attempt = 0;
365
366 loop {
367 if let Some(max) = config.max_attempts {
368 if attempt >= max {
369 warn!(
370 attempts = attempt,
371 max_attempts = max,
372 "Max reconnection attempts reached"
373 );
374 return Err(IgtlError::Io(std::io::Error::new(
375 std::io::ErrorKind::TimedOut,
376 "Max reconnection attempts exceeded",
377 )));
378 }
379 }
380
381 let delay = config.delay_for_attempt(attempt);
382 if attempt > 0 {
383 info!(
384 attempt = attempt + 1,
385 delay_ms = delay.as_millis(),
386 "Reconnecting..."
387 );
388 sleep(delay).await;
389 }
390
391 let result = if let Some(ref tls_config) = self.conn_params.tls_config {
392 // TLS reconnection
393 let hostname = self.conn_params.hostname.as_ref().unwrap();
394 let port = self.conn_params.port.unwrap();
395 Self::connect_with_tls(hostname, port, tls_config.clone()).await
396 } else {
397 // Plain TCP reconnection
398 Self::connect(&self.conn_params.addr).await
399 };
400
401 match result {
402 Ok(new_client) => {
403 self.transport = new_client.transport;
404 if attempt > 0 {
405 self.reconnect_count += 1;
406 info!(
407 reconnect_count = self.reconnect_count,
408 "Reconnection successful"
409 );
410 }
411 return Ok(());
412 }
413 Err(e) => {
414 warn!(
415 attempt = attempt + 1,
416 error = %e,
417 "Reconnection attempt failed"
418 );
419 attempt += 1;
420 }
421 }
422 }
423 }
424
425 /// Send a message
426 pub async fn send<T: Message>(&mut self, msg: &IgtlMessage<T>) -> Result<()> {
427 let data = msg.encode()?;
428 let msg_type = msg.header.type_name.as_str().unwrap_or("UNKNOWN");
429 let device_name = msg.header.device_name.as_str().unwrap_or("UNKNOWN");
430
431 debug!(
432 msg_type = msg_type,
433 device_name = device_name,
434 size = data.len(),
435 "Sending message"
436 );
437
438 loop {
439 if self.reconnect_config.is_some() {
440 self.ensure_connected().await?;
441 }
442
443 if let Some(transport) = &mut self.transport {
444 match transport.write_all(&data).await {
445 Ok(_) => {
446 transport.flush().await?;
447 trace!(
448 msg_type = msg_type,
449 bytes_sent = data.len(),
450 "Message sent successfully"
451 );
452 return Ok(());
453 }
454 Err(e) => {
455 if self.reconnect_config.is_some() {
456 warn!(error = %e, "Send failed, will reconnect");
457 self.transport = None;
458 // Loop will retry after reconnection
459 } else {
460 return Err(e);
461 }
462 }
463 }
464 } else {
465 return Err(IgtlError::Io(std::io::Error::new(
466 std::io::ErrorKind::NotConnected,
467 "Not connected",
468 )));
469 }
470 }
471 }
472
473 /// Receive a message
474 pub async fn receive<T: Message>(&mut self) -> Result<IgtlMessage<T>> {
475 loop {
476 if self.reconnect_config.is_some() {
477 self.ensure_connected().await?;
478 }
479
480 if let Some(transport) = &mut self.transport {
481 // Read header
482 let mut header_buf = vec![0u8; Header::SIZE];
483 match transport.read_exact(&mut header_buf).await {
484 Ok(_) => {}
485 Err(e) => {
486 if self.reconnect_config.is_some() {
487 warn!(error = %e, "Header read failed, will reconnect");
488 self.transport = None;
489 continue;
490 } else {
491 return Err(e);
492 }
493 }
494 }
495
496 let header = Header::decode(&header_buf)?;
497 let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
498 let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
499
500 debug!(
501 msg_type = msg_type,
502 device_name = device_name,
503 body_size = header.body_size,
504 version = header.version,
505 "Received message header"
506 );
507
508 // Read body
509 let mut body_buf = vec![0u8; header.body_size as usize];
510 match transport.read_exact(&mut body_buf).await {
511 Ok(_) => {}
512 Err(e) => {
513 if self.reconnect_config.is_some() {
514 warn!(error = %e, "Body read failed, will reconnect");
515 self.transport = None;
516 continue;
517 } else {
518 return Err(e);
519 }
520 }
521 }
522
523 trace!(
524 msg_type = msg_type,
525 bytes_read = body_buf.len(),
526 "Message body received"
527 );
528
529 // Decode full message
530 let mut full_msg = header_buf;
531 full_msg.extend_from_slice(&body_buf);
532
533 let result = IgtlMessage::decode_with_options(&full_msg, self.verify_crc);
534
535 match &result {
536 Ok(_) => {
537 debug!(
538 msg_type = msg_type,
539 device_name = device_name,
540 "Message decoded successfully"
541 );
542 }
543 Err(e) => {
544 warn!(
545 msg_type = msg_type,
546 error = %e,
547 "Failed to decode message"
548 );
549 }
550 }
551
552 return result;
553 } else {
554 return Err(IgtlError::Io(std::io::Error::new(
555 std::io::ErrorKind::NotConnected,
556 "Not connected",
557 )));
558 }
559 }
560 }
561}