openigtlink_rust/io/session_manager.rs
1//! Multi-client session management for OpenIGTLink servers
2//!
3//! Provides a high-level abstraction for managing multiple concurrent client
4//! connections with message routing, broadcasting, and handler registration.
5//!
6//! # Features
7//!
8//! - Concurrent client session management
9//! - Message broadcasting to all/selected clients
//! - Message handlers that receive every incoming message tagged with the sender's client ID
11//! - Automatic disconnection handling
12//! - Thread-safe client registry
13//!
14//! # Example
15//!
16//! ```no_run
17//! use openigtlink_rust::io::SessionManager;
18//! use openigtlink_rust::protocol::types::StatusMessage;
19//! use std::sync::Arc;
20//!
21//! #[tokio::main]
22//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
23//! let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
24//!
25//! // Spawn client acceptor
26//! let mgr = manager.clone();
27//! tokio::spawn(async move {
28//! mgr.accept_clients().await;
29//! });
30//!
31//! // Broadcast status to all clients
32//! let status = StatusMessage::ok("Server ready");
33//! manager.broadcast(&status).await?;
34//!
35//! Ok(())
36//! }
37//! ```
38
39use crate::error::{IgtlError, Result};
40use crate::protocol::header::Header;
41use crate::protocol::message::{IgtlMessage, Message};
42use std::collections::HashMap;
43use std::net::SocketAddr;
44use std::sync::atomic::{AtomicU64, Ordering};
45use std::sync::Arc;
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47use tokio::net::{TcpListener, TcpStream};
48use tokio::sync::{mpsc, RwLock};
49use tracing::{debug, info, trace, warn};
50
51/// Unique identifier for each client session
52pub type ClientId = u64;
53
54/// Client session state
55#[derive(Debug)]
56struct ClientSession {
57 /// Client ID
58 id: ClientId,
59 /// Client socket address
60 addr: SocketAddr,
61 /// Channel to send messages to this client
62 tx: mpsc::UnboundedSender<Vec<u8>>,
63 /// Connection start time
64 connected_at: std::time::Instant,
65}
66
67impl ClientSession {
68 /// Send a raw message to this client
69 async fn send_raw(&self, data: Vec<u8>) -> Result<()> {
70 self.tx.send(data).map_err(|_| {
71 IgtlError::Io(std::io::Error::new(
72 std::io::ErrorKind::BrokenPipe,
73 "Client disconnected",
74 ))
75 })?;
76 Ok(())
77 }
78
79 /// Get connection duration
80 fn uptime(&self) -> std::time::Duration {
81 self.connected_at.elapsed()
82 }
83}
84
85/// Multi-client session manager
86///
87/// Manages multiple concurrent OpenIGTLink client connections with automatic
88/// message routing and broadcasting capabilities.
89pub struct SessionManager {
90 /// TCP listener for accepting new clients
91 listener: TcpListener,
92 /// Active client sessions (ClientId -> ClientSession)
93 clients: Arc<RwLock<HashMap<ClientId, ClientSession>>>,
94 /// Client ID counter
95 next_client_id: AtomicU64,
96 /// Message handlers (optional)
97 handlers: Arc<RwLock<Vec<Box<dyn MessageHandler>>>>,
98}
99
100/// Trait for handling incoming messages
101///
102/// Implement this trait to process messages from clients.
103pub trait MessageHandler: Send + Sync {
104 /// Handle a message from a specific client
105 ///
106 /// # Arguments
107 /// * `client_id` - ID of the client that sent the message
108 /// * `type_name` - Message type name (e.g., "TRANSFORM")
109 /// * `data` - Raw message data (header + body)
110 fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]);
111}
112
113impl SessionManager {
114 /// Create a new session manager bound to the specified address
115 ///
116 /// # Arguments
117 /// * `addr` - Address to bind (e.g., "127.0.0.1:18944")
118 ///
119 /// # Examples
120 ///
121 /// ```no_run
122 /// use openigtlink_rust::io::SessionManager;
123 ///
124 /// #[tokio::main]
125 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
126 /// let manager = SessionManager::new("0.0.0.0:18944").await?;
127 /// Ok(())
128 /// }
129 /// ```
130 pub async fn new(addr: &str) -> Result<Self> {
131 info!(addr = %addr, "Creating SessionManager");
132 let listener = TcpListener::bind(addr).await?;
133 let local_addr = listener.local_addr()?;
134 info!(
135 local_addr = %local_addr,
136 "SessionManager listening for clients"
137 );
138 Ok(SessionManager {
139 listener,
140 clients: Arc::new(RwLock::new(HashMap::new())),
141 next_client_id: AtomicU64::new(1),
142 handlers: Arc::new(RwLock::new(Vec::new())),
143 })
144 }
145
146 /// Get the local address this manager is bound to
147 pub fn local_addr(&self) -> Result<SocketAddr> {
148 Ok(self.listener.local_addr()?)
149 }
150
151 /// Get the number of active client connections
152 ///
153 /// # Examples
154 ///
155 /// ```no_run
156 /// # use openigtlink_rust::io::SessionManager;
157 /// # #[tokio::main]
158 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
159 /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
160 /// println!("Active clients: {}", manager.client_count().await);
161 /// # Ok(())
162 /// # }
163 /// ```
164 pub async fn client_count(&self) -> usize {
165 self.clients.read().await.len()
166 }
167
168 /// Get a list of all active client IDs
169 pub async fn client_ids(&self) -> Vec<ClientId> {
170 self.clients.read().await.keys().copied().collect()
171 }
172
173 /// Get information about a specific client
174 pub async fn client_info(&self, client_id: ClientId) -> Option<ClientInfo> {
175 let clients = self.clients.read().await;
176 clients.get(&client_id).map(|session| ClientInfo {
177 id: session.id,
178 addr: session.addr,
179 uptime: session.uptime(),
180 })
181 }
182
183 /// Register a message handler
184 ///
185 /// Handlers are called in the order they were registered.
186 ///
187 /// # Examples
188 ///
189 /// ```no_run
190 /// use openigtlink_rust::io::{SessionManager, MessageHandler, ClientId};
191 ///
192 /// struct MyHandler;
193 ///
194 /// impl MessageHandler for MyHandler {
195 /// fn handle_message(&self, client_id: ClientId, type_name: &str, data: &[u8]) {
196 /// println!("Client {} sent {}", client_id, type_name);
197 /// }
198 /// }
199 ///
200 /// # #[tokio::main]
201 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
202 /// # let mut manager = SessionManager::new("127.0.0.1:18944").await?;
203 /// manager.add_handler(Box::new(MyHandler)).await;
204 /// # Ok(())
205 /// # }
206 /// ```
207 pub async fn add_handler(&self, handler: Box<dyn MessageHandler>) {
208 debug!("Registering new message handler");
209 self.handlers.write().await.push(handler);
210 let count = self.handlers.read().await.len();
211 info!(handler_count = count, "Message handler registered");
212 }
213
214 /// Accept new client connections in a loop
215 ///
216 /// This method runs forever, accepting new clients and spawning handler tasks.
217 /// It should be run in a separate task.
218 ///
219 /// # Examples
220 ///
221 /// ```no_run
222 /// use openigtlink_rust::io::SessionManager;
223 /// use std::sync::Arc;
224 ///
225 /// #[tokio::main]
226 /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
227 /// let manager = Arc::new(SessionManager::new("127.0.0.1:18944").await?);
228 ///
229 /// // Spawn acceptor in background
230 /// let mgr = manager.clone();
231 /// tokio::spawn(async move {
232 /// mgr.accept_clients().await;
233 /// });
234 ///
235 /// // Do other work...
236 /// Ok(())
237 /// }
238 /// ```
239 pub async fn accept_clients(&self) {
240 info!("Starting client accept loop");
241 loop {
242 match self.listener.accept().await {
243 Ok((socket, addr)) => {
244 let client_id = self.next_client_id.fetch_add(1, Ordering::SeqCst);
245 info!(
246 client_id = client_id,
247 addr = %addr,
248 "Client connected"
249 );
250
251 if let Err(e) = self.handle_client(client_id, socket, addr).await {
252 warn!(
253 client_id = client_id,
254 error = %e,
255 "Failed to setup client session"
256 );
257 }
258 }
259 Err(e) => {
260 warn!(error = %e, "Failed to accept client connection");
261 }
262 }
263 }
264 }
265
266 /// Handle a single client connection
267 async fn handle_client(
268 &self,
269 client_id: ClientId,
270 socket: TcpStream,
271 addr: SocketAddr,
272 ) -> Result<()> {
273 debug!(client_id = client_id, "Setting up client session");
274 let (tx, mut rx) = mpsc::unbounded_channel::<Vec<u8>>();
275
276 // Register client session
277 {
278 let session = ClientSession {
279 id: client_id,
280 addr,
281 tx,
282 connected_at: std::time::Instant::now(),
283 };
284 self.clients.write().await.insert(client_id, session);
285 let count = self.clients.read().await.len();
286 info!(
287 client_id = client_id,
288 total_clients = count,
289 "Client session registered"
290 );
291 }
292
293 // Split socket into read/write halves (consuming ownership)
294 let (mut reader, mut writer) = socket.into_split();
295
296 // Spawn sender task (sends messages to client)
297 let sender_task = tokio::spawn(async move {
298 while let Some(data) = rx.recv().await {
299 if writer.write_all(&data).await.is_err() {
300 break;
301 }
302 if writer.flush().await.is_err() {
303 break;
304 }
305 }
306 });
307
308 // Receiver task (reads messages from client)
309 let handlers = self.handlers.clone();
310
311 let receiver_task = tokio::spawn(async move {
312 trace!(client_id = client_id, "Client receiver task started");
313 loop {
314 // Read header
315 let mut header_buf = vec![0u8; Header::SIZE];
316 if reader.read_exact(&mut header_buf).await.is_err() {
317 trace!(
318 client_id = client_id,
319 "Client disconnected (header read failed)"
320 );
321 break;
322 }
323
324 let header = match Header::decode(&header_buf) {
325 Ok(h) => h,
326 Err(e) => {
327 warn!(
328 client_id = client_id,
329 error = %e,
330 "Failed to decode header from client"
331 );
332 break;
333 }
334 };
335
336 let msg_type = header.type_name.as_str().unwrap_or("UNKNOWN");
337 let device_name = header.device_name.as_str().unwrap_or("UNKNOWN");
338
339 debug!(
340 client_id = client_id,
341 msg_type = msg_type,
342 device_name = device_name,
343 body_size = header.body_size,
344 "Received message from client"
345 );
346
347 // Read body
348 let mut body_buf = vec![0u8; header.body_size as usize];
349 if reader.read_exact(&mut body_buf).await.is_err() {
350 warn!(
351 client_id = client_id,
352 msg_type = msg_type,
353 "Client disconnected while reading body"
354 );
355 break;
356 }
357
358 // Reconstruct full message
359 let mut full_msg = header_buf.clone();
360 full_msg.extend_from_slice(&body_buf);
361
362 // Call message handlers
363 let type_name = header.type_name.as_str().unwrap_or("UNKNOWN");
364 let handlers_guard = handlers.read().await;
365 trace!(
366 client_id = client_id,
367 msg_type = type_name,
368 handler_count = handlers_guard.len(),
369 "Dispatching message to handlers"
370 );
371 for handler in handlers_guard.iter() {
372 handler.handle_message(client_id, type_name, &full_msg);
373 }
374 }
375 });
376
377 // Wait for either task to finish (indicates disconnection)
378 tokio::select! {
379 _ = sender_task => {
380 trace!(client_id = client_id, "Sender task finished");
381 },
382 _ = receiver_task => {
383 trace!(client_id = client_id, "Receiver task finished");
384 },
385 }
386
387 // Cleanup: remove client from registry
388 self.clients.write().await.remove(&client_id);
389 let remaining = self.clients.read().await.len();
390 info!(
391 client_id = client_id,
392 remaining_clients = remaining,
393 "Client disconnected"
394 );
395
396 Ok(())
397 }
398
399 /// Broadcast a message to all connected clients
400 ///
401 /// # Examples
402 ///
403 /// ```no_run
404 /// use openigtlink_rust::io::SessionManager;
405 /// use openigtlink_rust::protocol::types::StatusMessage;
406 ///
407 /// # #[tokio::main]
408 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
409 /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
410 /// let status = StatusMessage::ok("System ready");
411 /// manager.broadcast(&status).await?;
412 /// # Ok(())
413 /// # }
414 /// ```
415 pub async fn broadcast<T: Message + Clone>(&self, message: &T) -> Result<()> {
416 let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
417 let data = igtl_msg.encode()?;
418
419 let clients_guard = self.clients.read().await;
420 let client_count = clients_guard.len();
421
422 debug!(
423 msg_type = std::any::type_name::<T>(),
424 client_count = client_count,
425 size = data.len(),
426 "Broadcasting message to all clients"
427 );
428
429 for session in clients_guard.values() {
430 let _ = session.send_raw(data.clone()).await;
431 }
432
433 trace!(client_count = client_count, "Broadcast completed");
434
435 Ok(())
436 }
437
438 /// Send a message to a specific client
439 ///
440 /// # Examples
441 ///
442 /// ```no_run
443 /// use openigtlink_rust::io::SessionManager;
444 /// use openigtlink_rust::protocol::types::StatusMessage;
445 ///
446 /// # #[tokio::main]
447 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
448 /// # let manager = SessionManager::new("127.0.0.1:18944").await?;
449 /// let status = StatusMessage::ok("Personal message");
450 /// manager.send_to(42, &status).await?;
451 /// # Ok(())
452 /// # }
453 /// ```
454 pub async fn send_to<T: Message + Clone>(
455 &self,
456 client_id: ClientId,
457 message: &T,
458 ) -> Result<()> {
459 let igtl_msg = IgtlMessage::new(message.clone(), "Server")?;
460 let data = igtl_msg.encode()?;
461
462 debug!(
463 client_id = client_id,
464 msg_type = std::any::type_name::<T>(),
465 size = data.len(),
466 "Sending message to client"
467 );
468
469 let clients_guard = self.clients.read().await;
470 if let Some(session) = clients_guard.get(&client_id) {
471 session.send_raw(data).await?;
472 trace!(client_id = client_id, "Message sent successfully");
473 Ok(())
474 } else {
475 warn!(client_id = client_id, "Client not found");
476 Err(IgtlError::Io(std::io::Error::new(
477 std::io::ErrorKind::NotFound,
478 format!("Client {} not found", client_id),
479 )))
480 }
481 }
482
483 /// Disconnect a specific client
484 pub async fn disconnect(&self, client_id: ClientId) -> Result<()> {
485 let mut clients = self.clients.write().await;
486 if clients.remove(&client_id).is_some() {
487 info!(client_id = client_id, "Forcibly disconnected client");
488 Ok(())
489 } else {
490 warn!(client_id = client_id, "Cannot disconnect: client not found");
491 Err(IgtlError::Io(std::io::Error::new(
492 std::io::ErrorKind::NotFound,
493 format!("Client {} not found", client_id),
494 )))
495 }
496 }
497
498 /// Disconnect all clients and shut down
499 pub async fn shutdown(&self) {
500 let mut clients = self.clients.write().await;
501 let count = clients.len();
502 clients.clear();
503 info!(
504 disconnected_clients = count,
505 "SessionManager shutdown complete"
506 );
507 }
508}
509
510/// Client information snapshot
511#[derive(Debug, Clone)]
512pub struct ClientInfo {
513 pub id: ClientId,
514 pub addr: SocketAddr,
515 pub uptime: std::time::Duration,
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::StatusMessage;

    /// Build a manager on an OS-assigned port so tests never collide.
    async fn ephemeral_manager() -> SessionManager {
        SessionManager::new("127.0.0.1:0").await.unwrap()
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let created = SessionManager::new("127.0.0.1:0").await;
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_client_count() {
        let manager = ephemeral_manager().await;
        assert_eq!(manager.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_broadcast_no_clients() {
        let manager = ephemeral_manager().await;
        let outcome = manager.broadcast(&StatusMessage::ok("test")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_send_to_nonexistent_client() {
        let manager = ephemeral_manager().await;
        let outcome = manager.send_to(999, &StatusMessage::ok("test")).await;
        assert!(outcome.is_err());
    }
}