openigtlink_rust/io/session_manager.rs
1//! Multi-client session management for OpenIGTLink servers
2//!
3//! Provides a high-level abstraction for managing multiple concurrent client
4//! connections with message routing, broadcasting, and handler registration.
5//!
6//! # Features
7//!
8//! - Concurrent client session management
9//! - Message broadcasting to all/selected clients
//! - Message handlers shared by all clients (each call identifies the sending client)
11//! - Automatic disconnection handling
12//! - Thread-safe client registry
13//!
14//! # Example
15//!
16//! ```no_run
17//! use openigtlink_rust::io::SessionManager;
18//! use openigtlink_rust::protocol::types::StatusMessage;
19//! use std::sync::Arc;
20//!
21//! #[tokio::main]
22//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
23//! let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
24//!
25//! // Spawn client acceptor
26//! let mgr = manager.clone();
27//! tokio::spawn(async move {
28//! mgr.accept_clients().await;
29//! });
30//!
31//! // Broadcast status to all clients
32//! let status = StatusMessage::ok("Server ready");
33//! manager.broadcast(&status).await?;
34//!
35//! Ok(())
36//! }
37//! ```
38
39use crate::error::{IgtlError, Result};
40use crate::protocol::header::Header;
41use crate::protocol::message::{IgtlMessage, Message};
42use std::collections::HashMap;
43use std::net::SocketAddr;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47use tokio::net::{TcpListener, TcpStream};
48use tokio::sync::{mpsc, RwLock};
49use tracing::{debug, info, trace, warn};
50
/// Unique identifier for each client session
///
/// Assigned by [`SessionManager`] from a counter that starts at 1 and is
/// incremented once for every accepted connection.
pub type ClientId = u64;
53
/// Client session state
///
/// One entry per registered connection in the [`SessionManager`] registry.
/// `tx` is the only sender for the connection's outbound queue, so removing
/// the entry from the registry drops it and ends the queue feeding the socket.
#[derive(Debug)]
struct ClientSession {
    /// Client ID (same value as the registry key)
    id: ClientId,
    /// Client socket address (peer address reported by `accept`)
    addr: SocketAddr,
    /// Channel to send messages to this client; a per-connection task drains
    /// it and writes each buffer to the socket
    tx: mpsc::UnboundedSender<Vec<u8>>,
    /// Connection start time (when the session was registered); used for uptime
    connected_at: std::time::Instant,
}
66
67impl ClientSession {
68 /// Send a raw message to this client
69 async fn send_raw(&self, data: Vec<u8>) -> Result<()> {
70 self.tx.send(data).map_err(|_| {
71 IgtlError::Io(std::io::Error::new(
72 std::io::ErrorKind::BrokenPipe,
73 "Client disconnected",
74 ))
75 })?;
76 Ok(())
77 }
78
79 /// Get connection duration
80 fn uptime(&self) -> std::time::Duration {
81 self.connected_at.elapsed()
82 }
83}
84
/// Multi-client session manager
///
/// Manages multiple concurrent OpenIGTLink client connections with automatic
/// message routing and broadcasting capabilities.
pub struct SessionManager {
    /// TCP listener for accepting new clients
    listener: TcpListener,
    /// Active client sessions (ClientId -> ClientSession); an entry is added
    /// when a connection is set up and removed when it ends or is disconnected
    clients: Arc<RwLock<HashMap<ClientId, ClientSession>>>,
    /// Source of client IDs; starts at 1 and is incremented per accepted connection
    next_client_id: AtomicU64,
    /// Registered message handlers, invoked in registration order for every
    /// message received from any client; shared with the per-connection reader
    handlers: Arc<RwLock<Vec<Box<dyn MessageHandler>>>>,
}
99
/// Trait for handling incoming messages
///
/// Implement this trait to process messages from clients.
///
/// Handlers are invoked synchronously from the task that reads the client's
/// socket, one after another in registration order, while the manager's
/// handler list is read-locked. A handler that blocks or runs for a long time
/// therefore delays further reads from that client, and
/// [`SessionManager::add_handler`] waits until the dispatch in progress finishes.
pub trait MessageHandler: Send + Sync {
    /// Handle a message from a specific client
    ///
    /// # Arguments
    /// * `client_id` - ID of the client that sent the message
    /// * `type_name` - Message type name (e.g., "TRANSFORM"); `"UNKNOWN"` when
    ///   the header's type name cannot be read as a string
    /// * `data` - Raw message data (header + body)
    fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]);
}
112
113impl SessionManager {
114 /// Create a new session manager bound to the specified address
115 ///
116 /// # Arguments
117 /// * `addr` - Address to bind (e.g., "127.0.0.1:18944")
118 ///
119 /// # Examples
120 ///
121 /// ```no_run
122 /// use openigtlink_rust::io::SessionManager;
123 ///
124 /// #[tokio::main]
125 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
126 /// let manager = SessionManager::new("0.0.0.0:18944").await?;
127 /// Ok(())
128 /// }
129 /// ```
130 pub async fn new(addr: &str) -> Result<Self> {
131 info!(addr = %addr, "Creating SessionManager");
132 let listener = TcpListener::bind(addr).await?;
133 let local_addr = listener.local_addr()?;
134 info!(
135 local_addr = %local_addr,
136 "SessionManager listening for clients"
137 );
138 Ok(SessionManager {
139 listener,
140 clients: Arc::new(RwLock::new(HashMap::new())),
141 next_client_id: AtomicU64::new(1),
142 handlers: Arc::new(RwLock::new(Vec::new())),
143 })
144 }
145
146 /// Get the local address this manager is bound to
147 pub fn local_addr(&self) -> Result<SocketAddr> {
148 Ok(self.listener.local_addr()?)
149 }
150
151 /// Get the number of active client connections
152 ///
153 /// # Examples
154 ///
155 /// ```no_run
156 /// # use openigtlink_rust::io::SessionManager;
157 /// # #[tokio::main]
158 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
159 /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
160 /// println!("Active clients: {}", manager.client_count().await);
161 /// # Ok(())
162 /// # }
163 /// ```
164 pub async fn client_count(&self) -> usize {
165 self.clients.read().await.len()
166 }
167
168 /// Get a list of all active client IDs
169 pub async fn client_ids(&self) -> Vec<ClientId> {
170 self.clients.read().await.keys().copied().collect()
171 }
172
173 /// Get information about a specific client
174 pub async fn client_info(&self, client_id: ClientId) -> Option<ClientInfo> {
175 let clients = self.clients.read().await;
176 clients.get(&client_id).map(|session| ClientInfo {
177 id: session.id,
178 addr: session.addr,
179 uptime: session.uptime(),
180 })
181 }
182
183 /// Register a message handler
184 ///
185 /// Handlers are called in the order they were registered.
186 ///
187 /// # Examples
188 ///
189 /// ```no_run
190 /// use openigtlink_rust::io::{SessionManager, MessageHandler, ClientId};
191 ///
192 /// struct MyHandler;
193 ///
194 /// impl MessageHandler for MyHandler {
195 /// fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]) {
196 /// println!("Client {} sent {}", client_id, type_name);
197 /// }
198 /// }
199 ///
200 /// # #[tokio::main]
201 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
202 /// # let mut manager = SessionManager::new("127.0.0.1:18944").await?;
203 /// manager.add_handler(Box::new(MyHandler)).await;
204 /// # Ok(())
205 /// # }
206 /// ```
207 pub async fn add_handler(&self, handler: Box<dyn MessageHandler>) {
208 debug!("Registering new message handler");
209 self.handlers.write().await.push(handler);
210 let count = self.handlers.read().await.len();
211 info!(handler_count = count, "Message handler registered");
212 }
213
214 /// Accept new client connections in a loop
215 ///
216 /// This method runs forever, accepting new clients and spawning handler tasks.
217 /// It should be run in a separate task.
218 ///
219 /// # Examples
220 ///
221 /// ```no_run
222 /// use openigtlink_rust::io::SessionManager;
223 /// use std::sync::Arc;
224 ///
225 /// #[tokio::main]
226 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
227 /// let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
228 ///
229 /// // Spawn acceptor in background
230 /// let mgr = manager.clone();
231 /// tokio::spawn(async move {
232 /// mgr.accept_clients().await;
233 /// });
234 ///
235 /// // Do other work...
236 /// Ok(())
237 /// }
238 /// ```
239 pub async fn accept_clients(&self) {
240 info!("Starting client accept loop");
241 loop {
242 match self.listener.accept().await {
243 Ok((socket, addr)) => {
244 let client_id = self.next_client_id.fetch_add(1, Ordering::SeqCst);
245 info!(
246 client_id = client_id,
247 addr = %addr,
248 "Client connected"
249 );
250
251 if let Err(e) = self.handle_client(client_id, socket, addr).await {
252 warn!(
253 client_id = client_id,
254 error = %e,
255 "Failed to setup client session"
256 );
257 }
258 }
259 Err(e) => {
260 warn!(error = %e, "Failed to accept client connection");
261 }
262 }
263 }
264 }
265
266 /// Handle a single client connection
267 async fn handle_client(
268 &self,
269 client_id: ClientId,
270 socket: TcpStream,
271 addr: SocketAddr,
272 ) -> Result<()> {
273 debug!(client_id = client_id, "Setting up client session");
274 let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
275
276 // Register client session
277 {
278 let session = ClientSession {
279 id: client_id,
280 addr,
281 tx,
282 connected_at: std::time::Instant::now(),
283 };
284 self.clients.write().await.insert(client_id, session);
285 let count = self.clients.read().await.len();
286 info!(
287 client_id = client_id,
288 total_clients = count,
289 "Client session registered"
290 );
291 }
292
293 // Split socket into read/write halves (consuming ownership)
294 let (mut reader, mut writer) = socket.into_split();
295
296 // Spawn sender task (sends messages to client)
297 let sender_task = tokio::spawn(async move {
298 while let Some(data) = rx.recv().await {
299 if writer.write_all(&data).await.is_err() {
300 break;
301 }
302 if writer.flush().await.is_err() {
303 break;
304 }
305 }
306 });
307
308 // Receiver task (reads messages from client)
309 let handlers = self.handlers.clone();
310
311 let receiver_task = tokio::spawn(async move {
312 trace!(client_id = client_id, "Client receiver task started");
313 loop {
314 // Read header
315 let mut header_buf = vec![0u8; Header::SIZE];
316 if reader.read_exact(&mut header_buf).await.is_err() {
317 trace!(client_id = client_id, "Client disconnected (header read failed)");
318 break;
319 }
320
321 let header = match Header::decode(&header_buf) {
322 Ok(h) => h,
323 Err(e) => {
324 warn!(
325 client_id = client_id,
326 error = %e,
327 "Failed to decode header from client"
328 );
329 break;
330 }
331 };
332
333 let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
334 let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
335
336 debug!(
337 client_id = client_id,
338 msg_type = msg_type,
339 device_name = device_name,
340 body_size = header.body_size,
341 "Received message from client"
342 );
343
344 // Read body
345 let mut body_buf = vec![0u8; header.body_size as usize];
346 if reader.read_exact(&mut body_buf).await.is_err() {
347 warn!(
348 client_id = client_id,
349 msg_type = msg_type,
350 "Client disconnected while reading body"
351 );
352 break;
353 }
354
355 // Reconstruct full message
356 let mut full_msg = header_buf.clone();
357 full_msg.extend_from_slice(&body_buf);
358
359 // Call message handlers
360 let type_name = header.type_name.as_str().unwrap_or("UNKNOWN");
361 let handlers_guard = handlers.read().await;
362 trace!(
363 client_id = client_id,
364 msg_type = type_name,
365 handler_count = handlers_guard.len(),
366 "Dispatching message to handlers"
367 );
368 for handler in handlers_guard.iter() {
369 handler.handle_message(client_id, type_name, &full_msg);
370 }
371 }
372 });
373
374 // Wait for either task to finish (indicates disconnection)
375 tokio::select! {
376 _ = sender_task => {
377 trace!(client_id = client_id, "Sender task finished");
378 },
379 _ = receiver_task => {
380 trace!(client_id = client_id, "Receiver task finished");
381 },
382 }
383
384 // Cleanup: remove client from registry
385 self.clients.write().await.remove(&client_id);
386 let remaining = self.clients.read().await.len();
387 info!(
388 client_id = client_id,
389 remaining_clients = remaining,
390 "Client disconnected"
391 );
392
393 Ok(())
394 }
395
396 /// Broadcast a message to all connected clients
397 ///
398 /// # Examples
399 ///
400 /// ```no_run
401 /// use openigtlink_rust::io::SessionManager;
402 /// use openigtlink_rust::protocol::types::StatusMessage;
403 ///
404 /// # #[tokio::main]
405 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
406 /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
407 /// let status = StatusMessage::ok("System ready");
408 /// manager.broadcast(&status).await?;
409 /// # Ok(())
410 /// # }
411 /// ```
412 pub async fn broadcast<T: Message + Clone>(&self, message: &T) -> Result<()> {
413 let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
414 let data = igtl_msg.encode()?;
415
416 let clients_guard = self.clients.read().await;
417 let client_count = clients_guard.len();
418
419 debug!(
420 msg_type = std::any::type_name::<T>(),
421 client_count = client_count,
422 size = data.len(),
423 "Broadcasting message to all clients"
424 );
425
426 for session in clients_guard.values() {
427 let _ = session.send_raw(data.clone()).await;
428 }
429
430 trace!(
431 client_count = client_count,
432 "Broadcast completed"
433 );
434
435 Ok(())
436 }
437
438 /// Send a message to a specific client
439 ///
440 /// # Examples
441 ///
442 /// ```no_run
443 /// use openigtlink_rust::io::SessionManager;
444 /// use openigtlink_rust::protocol::types::StatusMessage;
445 ///
446 /// # #[tokio::main]
447 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
448 /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
449 /// let status = StatusMessage::ok("Personal message");
450 /// manager.send_to(42, &status).await?;
451 /// # Ok(())
452 /// # }
453 /// ```
454 pub async fn send_to<T: Message + Clone>(&self, client_id: ClientId, message: &T) -> Result<()> {
455 let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
456 let data = igtl_msg.encode()?;
457
458 debug!(
459 client_id = client_id,
460 msg_type = std::any::type_name::<T>(),
461 size = data.len(),
462 "Sending message to client"
463 );
464
465 let clients_guard = self.clients.read().await;
466 if let Some(session) = clients_guard.get(&client_id) {
467 session.send_raw(data).await?;
468 trace!(client_id = client_id, "Message sent successfully");
469 Ok(())
470 } else {
471 warn!(client_id = client_id, "Client not found");
472 Err(IgtlError::Io(std::io::Error::new(
473 std::io::ErrorKind::NotFound,
474 format!("Client {} not found", client_id),
475 )))
476 }
477 }
478
479 /// Disconnect a specific client
480 pub async fn disconnect(&self, client_id: ClientId) -> Result<()> {
481 let mut clients = self.clients.write().await;
482 if clients.remove(&client_id).is_some() {
483 info!(client_id = client_id, "Forcibly disconnected client");
484 Ok(())
485 } else {
486 warn!(client_id = client_id, "Cannot disconnect: client not found");
487 Err(IgtlError::Io(std::io::Error::new(
488 std::io::ErrorKind::NotFound,
489 format!("Client {} not found", client_id),
490 )))
491 }
492 }
493
494 /// Disconnect all clients and shut down
495 pub async fn shutdown(&self) {
496 let mut clients = self.clients.write().await;
497 let count = clients.len();
498 clients.clear();
499 info!(disconnected_clients = count, "SessionManager shutdown complete");
500 }
501}
502
/// Client information snapshot
///
/// Returned by [`SessionManager::client_info`]; the values are copied when that
/// call is made and do not change afterwards.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Client ID
    pub id: ClientId,
    /// Client socket address as reported when the connection was accepted
    pub addr: SocketAddr,
    /// Time elapsed since the session was registered, measured when the
    /// snapshot was taken
    pub uptime: std::time::Duration,
}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::StatusMessage;
    use std::sync::Arc;
    use tokio::net::TcpStream;
    use tokio::time::{sleep, Duration, Instant};

    /// Handler that ignores every message; used to exercise registration.
    struct NoopHandler;

    impl MessageHandler for NoopHandler {
        fn handle_message(&self, _client_id: ClientId, _type_name: &str, _data: &[u8]) {}
    }

    /// Poll `client_count` until it equals `expected` or two seconds elapse.
    ///
    /// Returns whether the expected count was observed.
    async fn wait_for_client_count(manager: &SessionManager, expected: usize) -> bool {
        let deadline = Instant::now() + Duration::from_secs(2);
        loop {
            if manager.client_count().await == expected {
                return true;
            }
            if Instant::now() >= deadline {
                return false;
            }
            sleep(Duration::from_millis(10)).await;
        }
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::new("127.0.0.1:0").await;
        assert!(manager.is_ok());
    }

    #[tokio::test]
    async fn test_local_addr_reports_bound_port() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        let addr = manager.local_addr().unwrap();
        assert!(addr.ip().is_loopback());
        assert_ne!(addr.port(), 0);
    }

    #[tokio::test]
    async fn test_client_count() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        assert_eq!(manager.client_count().await, 0);
        assert!(manager.client_ids().await.is_empty());
    }

    #[tokio::test]
    async fn test_broadcast_no_clients() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        let status = StatusMessage::ok("test");
        let result = manager.broadcast(&status).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_to_nonexistent_client() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        let status = StatusMessage::ok("test");
        let result = manager.send_to(999, &status).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_client_info_and_disconnect() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        assert!(manager.client_info(42).await.is_none());
        assert!(manager.disconnect(42).await.is_err());
    }

    #[tokio::test]
    async fn test_add_handler() {
        let manager = SessionManager::new("127.0.0.1:0").await.unwrap();
        manager.add_handler(Box::new(NoopHandler)).await;
        manager.add_handler(Box::new(NoopHandler)).await;
        assert_eq!(manager.handlers.read().await.len(), 2);
    }

    #[tokio::test]
    async fn test_client_connect_and_disconnect() {
        let manager = Arc::new(SessionManager::new("127.0.0.1:0").await.unwrap());
        let addr = manager.local_addr().unwrap();
        let acceptor = {
            let mgr = Arc::clone(&manager);
            tokio::spawn(async move { mgr.accept_clients().await })
        };

        let client = TcpStream::connect(addr).await.unwrap();
        assert!(wait_for_client_count(&manager, 1).await);

        let ids = manager.client_ids().await;
        assert_eq!(ids.len(), 1);
        let info = manager.client_info(ids[0]).await.unwrap();
        assert_eq!(info.id, ids[0]);
        assert_eq!(info.addr, client.local_addr().unwrap());

        // Closing the socket must remove the session from the registry.
        drop(client);
        assert!(wait_for_client_count(&manager, 0).await);

        acceptor.abort();
    }
}