// openigtlink_rust/io/session_manager.rs

1//! Multi-client session management for OpenIGTLink servers
2//!
3//! Provides a high-level abstraction for managing multiple concurrent client
4//! connections with message routing, broadcasting, and handler registration.
5//!
6//! # Features
7//!
8//! - Concurrent client session management
9//! - Message broadcasting to all/selected clients
//! - Message handlers that receive the ID of the client that sent each message
11//! - Automatic disconnection handling
12//! - Thread-safe client registry
13//!
14//! # Example
15//!
16//! ```no_run
17//! use openigtlink_rust::io::SessionManager;
18//! use openigtlink_rust::protocol::types::StatusMessage;
19//! use std::sync::Arc;
20//!
21//! #[tokio::main]
22//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
23//!     let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
24//!
25//!     // Spawn client acceptor
26//!     let mgr = manager.clone();
27//!     tokio::spawn(async move {
28//!         mgr.accept_clients().await;
29//!     });
30//!
31//!     // Broadcast status to all clients
32//!     let status = StatusMessage::ok("Server ready");
33//!     manager.broadcast(&status).await?;
34//!
35//!     Ok(())
36//! }
37//! ```
38
39use crate::error::{IgtlError, Result};
40use crate::protocol::header::Header;
41use crate::protocol::message::{IgtlMessage, Message};
42use std::collections::HashMap;
43use std::net::SocketAddr;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47use tokio::net::{TcpListener, TcpStream};
48use tokio::sync::{mpsc, RwLock};
49use tracing::{debug, info, trace, warn};
50
/// Identifier the session manager assigns to a client when its connection is accepted.
///
/// Values are handed out in increasing order from a counter that starts at 1.
pub type ClientId = u64;
53
54/// Client session state
55#[derive(Debug)]
56struct ClientSession {
57    /// Client ID
58    id: ClientId,
59    /// Client socket address
60    addr: SocketAddr,
61    /// Channel to send messages to this client
62    tx: mpsc::UnboundedSender<Vec<u8>>,
63    /// Connection start time
64    connected_at: std::time::Instant,
65}
66
67impl ClientSession {
68    /// Send a raw message to this client
69    async fn send_raw(&self, data: Vec<u8>) -> Result<()> {
70        self.tx.send(data).map_err(|_| {
71            IgtlError::Io(std::io::Error::new(
72                std::io::ErrorKind::BrokenPipe,
73                "Client disconnected",
74            ))
75        })?;
76        Ok(())
77    }
78
79    /// Get connection duration
80    fn uptime(&self) -> std::time::Duration {
81        self.connected_at.elapsed()
82    }
83}
84
85/// Multi-client session manager
86///
87/// Manages multiple concurrent OpenIGTLink client connections with automatic
88/// message routing and broadcasting capabilities.
89pub struct SessionManager {
90    /// TCP listener for accepting new clients
91    listener: TcpListener,
92    /// Active client sessions (ClientId -> ClientSession)
93    clients: Arc<RwLock<HashMap<ClientId, ClientSession>>>,
94    /// Client ID counter
95    next_client_id: AtomicU64,
96    /// Message handlers (optional)
97    handlers: Arc<RwLock<Vec<Box<dyn MessageHandler>>>>,
98}
99
100/// Trait for handling incoming messages
101///
102/// Implement this trait to process messages from clients.
103pub trait MessageHandler: Send + Sync {
104    /// Handle a message from a specific client
105    ///
106    /// # Arguments
107    /// * `client_id` - ID of the client that sent the message
108    /// * `type_name` - Message type name (e.g., "TRANSFORM")
109    /// * `data` - Raw message data (header + body)
110    fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]);
111}
112
113impl SessionManager {
114    /// Create a new session manager bound to the specified address
115    ///
116    /// # Arguments
117    /// * `addr` - Address to bind (e.g., "127.0.0.1:18944")
118    ///
119    /// # Examples
120    ///
121    /// ```no_run
122    /// use openigtlink_rust::io::SessionManager;
123    ///
124    /// #[tokio::main]
125    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
126    ///     let manager = SessionManager::new("0.0.0.0:18944").await?;
127    ///     Ok(())
128    /// }
129    /// ```
130    pub async fn new(addr: &str) -> Result<Self> {
131        info!(addr = %addr, "Creating SessionManager");
132        let listener = TcpListener::bind(addr).await?;
133        let local_addr = listener.local_addr()?;
134        info!(
135            local_addr = %local_addr,
136            "SessionManager listening for clients"
137        );
138        Ok(SessionManager {
139            listener,
140            clients: Arc::new(RwLock::new(HashMap::new())),
141            next_client_id: AtomicU64::new(1),
142            handlers: Arc::new(RwLock::new(Vec::new())),
143        })
144    }
145
146    /// Get the local address this manager is bound to
147    pub fn local_addr(&self) -> Result<SocketAddr> {
148        Ok(self.listener.local_addr()?)
149    }
150
151    /// Get the number of active client connections
152    ///
153    /// # Examples
154    ///
155    /// ```no_run
156    /// # use openigtlink_rust::io::SessionManager;
157    /// # #[tokio::main]
158    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
159    /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
160    /// println!("Active clients: {}", manager.client_count().await);
161    /// # Ok(())
162    /// # }
163    /// ```
164    pub async fn client_count(&self) -> usize {
165        self.clients.read().await.len()
166    }
167
168    /// Get a list of all active client IDs
169    pub async fn client_ids(&self) -> Vec<ClientId> {
170        self.clients.read().await.keys().copied().collect()
171    }
172
173    /// Get information about a specific client
174    pub async fn client_info(&self, client_id: ClientId) -> Option<ClientInfo> {
175        let clients = self.clients.read().await;
176        clients.get(&client_id).map(|session| ClientInfo {
177            id: session.id,
178            addr: session.addr,
179            uptime: session.uptime(),
180        })
181    }
182
183    /// Register a message handler
184    ///
185    /// Handlers are called in the order they were registered.
186    ///
187    /// # Examples
188    ///
189    /// ```no_run
190    /// use openigtlink_rust::io::{SessionManager, MessageHandler, ClientId};
191    ///
192    /// struct MyHandler;
193    ///
194    /// impl MessageHandler for MyHandler {
195    ///     fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]) {
196    ///         println!("Client {} sent {}", client_id, type_name);
197    ///     }
198    /// }
199    ///
200    /// # #[tokio::main]
201    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
202    /// # let mut manager = SessionManager::new("127.0.0.1:18944").await?;
203    /// manager.add_handler(Box::new(MyHandler)).await;
204    /// # Ok(())
205    /// # }
206    /// ```
207    pub async fn add_handler(&self, handler: Box<dyn MessageHandler>) {
208        debug!("Registering new message handler");
209        self.handlers.write().await.push(handler);
210        let count = self.handlers.read().await.len();
211        info!(handler_count = count, "Message handler registered");
212    }
213
214    /// Accept new client connections in a loop
215    ///
216    /// This method runs forever, accepting new clients and spawning handler tasks.
217    /// It should be run in a separate task.
218    ///
219    /// # Examples
220    ///
221    /// ```no_run
222    /// use openigtlink_rust::io::SessionManager;
223    /// use std::sync::Arc;
224    ///
225    /// #[tokio::main]
226    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
227    ///     let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
228    ///
229    ///     // Spawn acceptor in background
230    ///     let mgr = manager.clone();
231    ///     tokio::spawn(async move {
232    ///         mgr.accept_clients().await;
233    ///     });
234    ///
235    ///     // Do other work...
236    ///     Ok(())
237    /// }
238    /// ```
239    pub async fn accept_clients(&self) {
240        info!("Starting client accept loop");
241        loop {
242            match self.listener.accept().await {
243                Ok((socket, addr)) => {
244                    let client_id = self.next_client_id.fetch_add(1, Ordering::SeqCst);
245                    info!(
246                        client_id = client_id,
247                        addr = %addr,
248                        "Client connected"
249                    );
250
251                    if let Err(e) = self.handle_client(client_id, socket, addr).await {
252                        warn!(
253                            client_id = client_id,
254                            error = %e,
255                            "Failed to setup client session"
256                        );
257                    }
258                }
259                Err(e) => {
260                    warn!(error = %e, "Failed to accept client connection");
261                }
262            }
263        }
264    }
265
266    /// Handle a single client connection
267    async fn handle_client(
268        &self,
269        client_id: ClientId,
270        socket: TcpStream,
271        addr: SocketAddr,
272    ) -> Result<()> {
273        debug!(client_id = client_id, "Setting up client session");
274        let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
275
276        // Register client session
277        {
278            let session = ClientSession {
279                id: client_id,
280                addr,
281                tx,
282                connected_at: std::time::Instant::now(),
283            };
284            self.clients.write().await.insert(client_id, session);
285            let count = self.clients.read().await.len();
286            info!(
287                client_id = client_id,
288                total_clients = count,
289                "Client session registered"
290            );
291        }
292
293        // Split socket into read/write halves (consuming ownership)
294        let (mut reader, mut writer) = socket.into_split();
295
296        // Spawn sender task (sends messages to client)
297        let sender_task = tokio::spawn(async move {
298            while let Some(data) = rx.recv().await {
299                if writer.write_all(&data).await.is_err() {
300                    break;
301                }
302                if writer.flush().await.is_err() {
303                    break;
304                }
305            }
306        });
307
308        // Receiver task (reads messages from client)
309        let handlers = self.handlers.clone();
310
311        let receiver_task = tokio::spawn(async move {
312            trace!(client_id = client_id, "Client receiver task started");
313            loop {
314                // Read header
315                let mut header_buf = vec![0u8; Header::SIZE];
316                if reader.read_exact(&mut header_buf).await.is_err() {
317                    trace!(
318                        client_id = client_id,
319                        "Client disconnected (header read failed)"
320                    );
321                    break;
322                }
323
324                let header = match Header::decode(&header_buf) {
325                    Ok(h) => h,
326                    Err(e) => {
327                        warn!(
328                            client_id = client_id,
329                            error = %e,
330                            "Failed to decode header from client"
331                        );
332                        break;
333                    }
334                };
335
336                let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
337                let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
338
339                debug!(
340                    client_id = client_id,
341                    msg_type = msg_type,
342                    device_name = device_name,
343                    body_size = header.body_size,
344                    "Received message from client"
345                );
346
347                // Read body
348                let mut body_buf = vec![0u8; header.body_size as usize];
349                if reader.read_exact(&mut body_buf).await.is_err() {
350                    warn!(
351                        client_id = client_id,
352                        msg_type = msg_type,
353                        "Client disconnected while reading body"
354                    );
355                    break;
356                }
357
358                // Reconstruct full message
359                let mut full_msg = header_buf.clone();
360                full_msg.extend_from_slice(&body_buf);
361
362                // Call message handlers
363                let type_name = header.type_name.as_str().unwrap_or("UNKNOWN");
364                let handlers_guard = handlers.read().await;
365                trace!(
366                    client_id = client_id,
367                    msg_type = type_name,
368                    handler_count = handlers_guard.len(),
369                    "Dispatching message to handlers"
370                );
371                for handler in handlers_guard.iter() {
372                    handler.handle_message(client_id, type_name, &full_msg);
373                }
374            }
375        });
376
377        // Wait for either task to finish (indicates disconnection)
378        tokio::select! {
379            _ = sender_task => {
380                trace!(client_id = client_id, "Sender task finished");
381            },
382            _ = receiver_task => {
383                trace!(client_id = client_id, "Receiver task finished");
384            },
385        }
386
387        // Cleanup: remove client from registry
388        self.clients.write().await.remove(&client_id);
389        let remaining = self.clients.read().await.len();
390        info!(
391            client_id = client_id,
392            remaining_clients = remaining,
393            "Client disconnected"
394        );
395
396        Ok(())
397    }
398
399    /// Broadcast a message to all connected clients
400    ///
401    /// # Examples
402    ///
403    /// ```no_run
404    /// use openigtlink_rust::io::SessionManager;
405    /// use openigtlink_rust::protocol::types::StatusMessage;
406    ///
407    /// # #[tokio::main]
408    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
409    /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
410    /// let status = StatusMessage::ok("System ready");
411    /// manager.broadcast(&status).await?;
412    /// # Ok(())
413    /// # }
414    /// ```
415    pub async fn broadcast<T: Message + Clone>(&self, message: &T) -> Result<()> {
416        let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
417        let data = igtl_msg.encode()?;
418
419        let clients_guard = self.clients.read().await;
420        let client_count = clients_guard.len();
421
422        debug!(
423            msg_type = std::any::type_name::<T>(),
424            client_count = client_count,
425            size = data.len(),
426            "Broadcasting message to all clients"
427        );
428
429        for session in clients_guard.values() {
430            let _ = session.send_raw(data.clone()).await;
431        }
432
433        trace!(client_count = client_count, "Broadcast completed");
434
435        Ok(())
436    }
437
438    /// Send a message to a specific client
439    ///
440    /// # Examples
441    ///
442    /// ```no_run
443    /// use openigtlink_rust::io::SessionManager;
444    /// use openigtlink_rust::protocol::types::StatusMessage;
445    ///
446    /// # #[tokio::main]
447    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
448    /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
449    /// let status = StatusMessage::ok("Personal message");
450    /// manager.send_to(42, &status).await?;
451    /// # Ok(())
452    /// # }
453    /// ```
454    pub async fn send_to<T: Message + Clone>(
455        &self,
456        client_id: ClientId,
457        message: &T,
458    ) -> Result<()> {
459        let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
460        let data = igtl_msg.encode()?;
461
462        debug!(
463            client_id = client_id,
464            msg_type = std::any::type_name::<T>(),
465            size = data.len(),
466            "Sending message to client"
467        );
468
469        let clients_guard = self.clients.read().await;
470        if let Some(session) = clients_guard.get(&client_id) {
471            session.send_raw(data).await?;
472            trace!(client_id = client_id, "Message sent successfully");
473            Ok(())
474        } else {
475            warn!(client_id = client_id, "Client not found");
476            Err(IgtlError::Io(std::io::Error::new(
477                std::io::ErrorKind::NotFound,
478                format!("Client {} not found", client_id),
479            )))
480        }
481    }
482
483    /// Disconnect a specific client
484    pub async fn disconnect(&self, client_id: ClientId) -> Result<()> {
485        let mut clients = self.clients.write().await;
486        if clients.remove(&client_id).is_some() {
487            info!(client_id = client_id, "Forcibly disconnected client");
488            Ok(())
489        } else {
490            warn!(client_id = client_id, "Cannot disconnect: client not found");
491            Err(IgtlError::Io(std::io::Error::new(
492                std::io::ErrorKind::NotFound,
493                format!("Client {} not found", client_id),
494            )))
495        }
496    }
497
498    /// Disconnect all clients and shut down
499    pub async fn shutdown(&self) {
500        let mut clients = self.clients.write().await;
501        let count = clients.len();
502        clients.clear();
503        info!(
504            disconnected_clients = count,
505            "SessionManager shutdown complete"
506        );
507    }
508}
509
510/// Client information snapshot
511#[derive(Debug, Clone)]
512pub struct ClientInfo {
513    pub id: ClientId,
514    pub addr: SocketAddr,
515    pub uptime: std::time::Duration,
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::StatusMessage;

    /// Bind a manager to an OS-chosen loopback port.
    async fn loopback_manager() -> SessionManager {
        SessionManager::new("127.0.0.1:0")
            .await
            .expect("binding to an ephemeral loopback port should succeed")
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let created = SessionManager::new("127.0.0.1:0").await;
        assert!(created.is_ok(), "SessionManager::new should bind successfully");
    }

    #[tokio::test]
    async fn test_client_count() {
        let manager = loopback_manager().await;
        assert_eq!(manager.client_count().await, 0, "a fresh manager has no clients");
    }

    #[tokio::test]
    async fn test_broadcast_no_clients() {
        let manager = loopback_manager().await;
        let outcome = manager.broadcast(&StatusMessage::ok("test")).await;
        assert!(outcome.is_ok(), "broadcasting to nobody is not an error");
    }

    #[tokio::test]
    async fn test_send_to_nonexistent_client() {
        let manager = loopback_manager().await;
        let outcome = manager.send_to(999, &StatusMessage::ok("test")).await;
        assert!(outcome.is_err(), "sending to an unknown client must fail");
    }
}