// openigtlink_rust/io/session_manager.rs

//! Multi-client session management for OpenIGTLink servers
//!
//! Provides a high-level abstraction for managing multiple concurrent client
//! connections with message routing, broadcasting, and handler registration.
//!
//! # Features
//!
//! - Concurrent client session management
//! - Broadcasting to all clients and targeted sends to a single client
//! - Message handlers that receive the sending client's ID with every message
//! - Automatic disconnection handling
//! - Thread-safe client registry
//!
//! # Example
//!
//! ```no_run
//! use openigtlink_rust::io::SessionManager;
//! use openigtlink_rust::protocol::types::StatusMessage;
//! use std::sync::Arc;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
//!
//!     // Spawn client acceptor
//!     let mgr = manager.clone();
//!     tokio::spawn(async move {
//!         mgr.accept_clients().await;
//!     });
//!
//!     // Broadcast status to all clients
//!     let status = StatusMessage::ok("Server ready");
//!     manager.broadcast(&status).await?;
//!
//!     Ok(())
//! }
//! ```

39use crate::error::{IgtlError, Result};
40use crate::protocol::header::Header;
41use crate::protocol::message::{IgtlMessage, Message};
42use std::collections::HashMap;
43use std::net::SocketAddr;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47use tokio::net::{TcpListener, TcpStream};
48use tokio::sync::{mpsc, RwLock};
49use tracing::{debug, info, trace, warn};
50
/// Identifier assigned to a client session when its connection is accepted.
///
/// IDs come from a monotonically increasing counter and are never reused
/// by the same manager.
pub type ClientId = u64;

54/// Client session state
55#[derive(Debug)]
56struct ClientSession {
57    /// Client ID
58    id: ClientId,
59    /// Client socket address
60    addr: SocketAddr,
61    /// Channel to send messages to this client
62    tx: mpsc::UnboundedSender<Vec<u8>>,
63    /// Connection start time
64    connected_at: std::time::Instant,
65}
66
67impl ClientSession {
68    /// Send a raw message to this client
69    async fn send_raw(&self, data: Vec<u8>) -> Result<()> {
70        self.tx.send(data).map_err(|_| {
71            IgtlError::Io(std::io::Error::new(
72                std::io::ErrorKind::BrokenPipe,
73                "Client disconnected",
74            ))
75        })?;
76        Ok(())
77    }
78
79    /// Get connection duration
80    fn uptime(&self) -> std::time::Duration {
81        self.connected_at.elapsed()
82    }
83}
84
85/// Multi-client session manager
86///
87/// Manages multiple concurrent OpenIGTLink client connections with automatic
88/// message routing and broadcasting capabilities.
89pub struct SessionManager {
90    /// TCP listener for accepting new clients
91    listener: TcpListener,
92    /// Active client sessions (ClientId -> ClientSession)
93    clients: Arc<RwLock<HashMap<ClientId, ClientSession>>>,
94    /// Client ID counter
95    next_client_id: AtomicU64,
96    /// Message handlers (optional)
97    handlers: Arc<RwLock<Vec<Box<dyn MessageHandler>>>>,
98}
99
100/// Trait for handling incoming messages
101///
102/// Implement this trait to process messages from clients.
103pub trait MessageHandler: Send + Sync {
104    /// Handle a message from a specific client
105    ///
106    /// # Arguments
107    /// * `client_id` - ID of the client that sent the message
108    /// * `type_name` - Message type name (e.g., "TRANSFORM")
109    /// * `data` - Raw message data (header + body)
110    fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]);
111}
112
113impl SessionManager {
114    /// Create a new session manager bound to the specified address
115    ///
116    /// # Arguments
117    /// * `addr` - Address to bind (e.g., "127.0.0.1:18944")
118    ///
119    /// # Examples
120    ///
121    /// ```no_run
122    /// use openigtlink_rust::io::SessionManager;
123    ///
124    /// #[tokio::main]
125    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
126    ///     let manager = SessionManager::new("0.0.0.0:18944").await?;
127    ///     Ok(())
128    /// }
129    /// ```
130    pub async fn new(addr: &str) -> Result<Self> {
131        info!(addr = %addr, "Creating SessionManager");
132        let listener = TcpListener::bind(addr).await?;
133        let local_addr = listener.local_addr()?;
134        info!(
135            local_addr = %local_addr,
136            "SessionManager listening for clients"
137        );
138        Ok(SessionManager {
139            listener,
140            clients: Arc::new(RwLock::new(HashMap::new())),
141            next_client_id: AtomicU64::new(1),
142            handlers: Arc::new(RwLock::new(Vec::new())),
143        })
144    }
145
146    /// Get the local address this manager is bound to
147    pub fn local_addr(&self) -> Result<SocketAddr> {
148        Ok(self.listener.local_addr()?)
149    }
150
151    /// Get the number of active client connections
152    ///
153    /// # Examples
154    ///
155    /// ```no_run
156    /// # use openigtlink_rust::io::SessionManager;
157    /// # #[tokio::main]
158    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
159    /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
160    /// println!("Active clients: {}", manager.client_count().await);
161    /// # Ok(())
162    /// # }
163    /// ```
164    pub async fn client_count(&self) -> usize {
165        self.clients.read().await.len()
166    }
167
168    /// Get a list of all active client IDs
169    pub async fn client_ids(&self) -> Vec<ClientId> {
170        self.clients.read().await.keys().copied().collect()
171    }
172
173    /// Get information about a specific client
174    pub async fn client_info(&self, client_id: ClientId) -> Option<ClientInfo> {
175        let clients = self.clients.read().await;
176        clients.get(&client_id).map(|session| ClientInfo {
177            id: session.id,
178            addr: session.addr,
179            uptime: session.uptime(),
180        })
181    }
182
183    /// Register a message handler
184    ///
185    /// Handlers are called in the order they were registered.
186    ///
187    /// # Examples
188    ///
189    /// ```no_run
190    /// use openigtlink_rust::io::{SessionManager, MessageHandler, ClientId};
191    ///
192    /// struct MyHandler;
193    ///
194    /// impl MessageHandler for MyHandler {
195    ///     fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]) {
196    ///         println!("Client {} sent {}", client_id, type_name);
197    ///     }
198    /// }
199    ///
200    /// # #[tokio::main]
201    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
202    /// # let mut manager = SessionManager::new("127.0.0.1:18944").await?;
203    /// manager.add_handler(Box::new(MyHandler)).await;
204    /// # Ok(())
205    /// # }
206    /// ```
207    pub async fn add_handler(&self, handler: Box<dyn MessageHandler>) {
208        debug!("Registering new message handler");
209        self.handlers.write().await.push(handler);
210        let count = self.handlers.read().await.len();
211        info!(handler_count = count, "Message handler registered");
212    }
213
214    /// Accept new client connections in a loop
215    ///
216    /// This method runs forever, accepting new clients and spawning handler tasks.
217    /// It should be run in a separate task.
218    ///
219    /// # Examples
220    ///
221    /// ```no_run
222    /// use openigtlink_rust::io::SessionManager;
223    /// use std::sync::Arc;
224    ///
225    /// #[tokio::main]
226    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
227    ///     let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
228    ///
229    ///     // Spawn acceptor in background
230    ///     let mgr = manager.clone();
231    ///     tokio::spawn(async move {
232    ///         mgr.accept_clients().await;
233    ///     });
234    ///
235    ///     // Do other work...
236    ///     Ok(())
237    /// }
238    /// ```
239    pub async fn accept_clients(&self) {
240        info!("Starting client accept loop");
241        loop {
242            match self.listener.accept().await {
243                Ok((socket, addr)) => {
244                    let client_id = self.next_client_id.fetch_add(1, Ordering::SeqCst);
245                    info!(
246                        client_id = client_id,
247                        addr = %addr,
248                        "Client connected"
249                    );
250
251                    if let Err(e) = self.handle_client(client_id, socket, addr).await {
252                        warn!(
253                            client_id = client_id,
254                            error = %e,
255                            "Failed to setup client session"
256                        );
257                    }
258                }
259                Err(e) => {
260                    warn!(error = %e, "Failed to accept client connection");
261                }
262            }
263        }
264    }
265
266    /// Handle a single client connection
267    async fn handle_client(
268        &self,
269        client_id: ClientId,
270        socket: TcpStream,
271        addr: SocketAddr,
272    ) -> Result<()> {
273        debug!(client_id = client_id, "Setting up client session");
274        let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
275
276        // Register client session
277        {
278            let session = ClientSession {
279                id: client_id,
280                addr,
281                tx,
282                connected_at: std::time::Instant::now(),
283            };
284            self.clients.write().await.insert(client_id, session);
285            let count = self.clients.read().await.len();
286            info!(
287                client_id = client_id,
288                total_clients = count,
289                "Client session registered"
290            );
291        }
292
293        // Split socket into read/write halves (consuming ownership)
294        let (mut reader, mut writer) = socket.into_split();
295
296        // Spawn sender task (sends messages to client)
297        let sender_task = tokio::spawn(async move {
298            while let Some(data) = rx.recv().await {
299                if writer.write_all(&data).await.is_err() {
300                    break;
301                }
302                if writer.flush().await.is_err() {
303                    break;
304                }
305            }
306        });
307
308        // Receiver task (reads messages from client)
309        let handlers = self.handlers.clone();
310
311        let receiver_task = tokio::spawn(async move {
312            trace!(client_id = client_id, "Client receiver task started");
313            loop {
314                // Read header
315                let mut header_buf = vec![0u8; Header::SIZE];
316                if reader.read_exact(&mut header_buf).await.is_err() {
317                    trace!(client_id = client_id, "Client disconnected (header read failed)");
318                    break;
319                }
320
321                let header = match Header::decode(&header_buf) {
322                    Ok(h) => h,
323                    Err(e) => {
324                        warn!(
325                            client_id = client_id,
326                            error = %e,
327                            "Failed to decode header from client"
328                        );
329                        break;
330                    }
331                };
332
333                let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
334                let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
335
336                debug!(
337                    client_id = client_id,
338                    msg_type = msg_type,
339                    device_name = device_name,
340                    body_size = header.body_size,
341                    "Received message from client"
342                );
343
344                // Read body
345                let mut body_buf = vec![0u8; header.body_size as usize];
346                if reader.read_exact(&mut body_buf).await.is_err() {
347                    warn!(
348                        client_id = client_id,
349                        msg_type = msg_type,
350                        "Client disconnected while reading body"
351                    );
352                    break;
353                }
354
355                // Reconstruct full message
356                let mut full_msg = header_buf.clone();
357                full_msg.extend_from_slice(&body_buf);
358
359                // Call message handlers
360                let type_name = header.type_name.as_str().unwrap_or("UNKNOWN");
361                let handlers_guard = handlers.read().await;
362                trace!(
363                    client_id = client_id,
364                    msg_type = type_name,
365                    handler_count = handlers_guard.len(),
366                    "Dispatching message to handlers"
367                );
368                for handler in handlers_guard.iter() {
369                    handler.handle_message(client_id, type_name, &full_msg);
370                }
371            }
372        });
373
374        // Wait for either task to finish (indicates disconnection)
375        tokio::select! {
376            _ = sender_task => {
377                trace!(client_id = client_id, "Sender task finished");
378            },
379            _ = receiver_task => {
380                trace!(client_id = client_id, "Receiver task finished");
381            },
382        }
383
384        // Cleanup: remove client from registry
385        self.clients.write().await.remove(&client_id);
386        let remaining = self.clients.read().await.len();
387        info!(
388            client_id = client_id,
389            remaining_clients = remaining,
390            "Client disconnected"
391        );
392
393        Ok(())
394    }
395
396    /// Broadcast a message to all connected clients
397    ///
398    /// # Examples
399    ///
400    /// ```no_run
401    /// use openigtlink_rust::io::SessionManager;
402    /// use openigtlink_rust::protocol::types::StatusMessage;
403    ///
404    /// # #[tokio::main]
405    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
406    /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
407    /// let status = StatusMessage::ok("System ready");
408    /// manager.broadcast(&status).await?;
409    /// # Ok(())
410    /// # }
411    /// ```
412    pub async fn broadcast<T: Message + Clone>(&self, message: &T) -> Result<()> {
413        let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
414        let data = igtl_msg.encode()?;
415
416        let clients_guard = self.clients.read().await;
417        let client_count = clients_guard.len();
418
419        debug!(
420            msg_type = std::any::type_name::<T>(),
421            client_count = client_count,
422            size = data.len(),
423            "Broadcasting message to all clients"
424        );
425
426        for session in clients_guard.values() {
427            let _ = session.send_raw(data.clone()).await;
428        }
429
430        trace!(
431            client_count = client_count,
432            "Broadcast completed"
433        );
434
435        Ok(())
436    }
437
438    /// Send a message to a specific client
439    ///
440    /// # Examples
441    ///
442    /// ```no_run
443    /// use openigtlink_rust::io::SessionManager;
444    /// use openigtlink_rust::protocol::types::StatusMessage;
445    ///
446    /// # #[tokio::main]
447    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
448    /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
449    /// let status = StatusMessage::ok("Personal message");
450    /// manager.send_to(42, &status).await?;
451    /// # Ok(())
452    /// # }
453    /// ```
454    pub async fn send_to<T: Message + Clone>(&self, client_id: ClientId, message: &T) -> Result<()> {
455        let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
456        let data = igtl_msg.encode()?;
457
458        debug!(
459            client_id = client_id,
460            msg_type = std::any::type_name::<T>(),
461            size = data.len(),
462            "Sending message to client"
463        );
464
465        let clients_guard = self.clients.read().await;
466        if let Some(session) = clients_guard.get(&client_id) {
467            session.send_raw(data).await?;
468            trace!(client_id = client_id, "Message sent successfully");
469            Ok(())
470        } else {
471            warn!(client_id = client_id, "Client not found");
472            Err(IgtlError::Io(std::io::Error::new(
473                std::io::ErrorKind::NotFound,
474                format!("Client {} not found", client_id),
475            )))
476        }
477    }
478
479    /// Disconnect a specific client
480    pub async fn disconnect(&self, client_id: ClientId) -> Result<()> {
481        let mut clients = self.clients.write().await;
482        if clients.remove(&client_id).is_some() {
483            info!(client_id = client_id, "Forcibly disconnected client");
484            Ok(())
485        } else {
486            warn!(client_id = client_id, "Cannot disconnect: client not found");
487            Err(IgtlError::Io(std::io::Error::new(
488                std::io::ErrorKind::NotFound,
489                format!("Client {} not found", client_id),
490            )))
491        }
492    }
493
494    /// Disconnect all clients and shut down
495    pub async fn shutdown(&self) {
496        let mut clients = self.clients.write().await;
497        let count = clients.len();
498        clients.clear();
499        info!(disconnected_clients = count, "SessionManager shutdown complete");
500    }
501}
502
503/// Client information snapshot
504#[derive(Debug, Clone)]
505pub struct ClientInfo {
506    pub id: ClientId,
507    pub addr: SocketAddr,
508    pub uptime: std::time::Duration,
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::StatusMessage;

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::new("127.0.0.1:0").await;
        assert!(manager.is_ok());
    }

    #[tokio::test]
    async fn test_local_addr_reports_assigned_port() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        let addr = manager.local_addr().unwrap();
        assert!(addr.ip().is_loopback());
        // Binding to port 0 must yield a concrete OS-assigned port.
        assert_ne!(addr.port(), 0);
    }

    #[tokio::test]
    async fn test_client_count() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        assert_eq!(manager.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_client_ids_and_info_empty() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        assert!(manager.client_ids().await.is_empty());
        assert!(manager.client_info(1).await.is_none());
    }

    #[tokio::test]
    async fn test_broadcast_no_clients() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        let status = StatusMessage::ok("test");
        let result = manager.broadcast(&status).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_to_nonexistent_client() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        let status = StatusMessage::ok("test");
        let result = manager.send_to(999, &status).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_disconnect_nonexistent_client() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        assert!(manager.disconnect(999).await.is_err());
    }

    #[tokio::test]
    async fn test_shutdown_without_clients() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        manager.shutdown().await;
        assert_eq!(manager.client_count().await, 0);
    }
}