use openigtlink_rust::error::Result;
use openigtlink_rust::protocol::header::Header;
use openigtlink_rust::protocol::message::IgtlMessage;
use openigtlink_rust::protocol::types::{StatusMessage, TransformMessage};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::{interval, Duration};
/// Process-wide counter used to hand out client ids.
///
/// `run_server` bumps it once per accepted connection and uses the
/// post-increment value as the id, so the first client is `#1`.
static CLIENT_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Program entry point: drives the server on the Tokio runtime and exits
/// with status 1 if the server returns an error.
#[tokio::main]
async fn main() {
    match run_server().await {
        Ok(()) => {}
        Err(err) => {
            eprintln!("[ERROR] Server failed: {}", err);
            std::process::exit(1);
        }
    }
}
/// Binds the listener on `127.0.0.1:18944` and serves clients forever.
///
/// Every accepted connection gets its own Tokio task running
/// `handle_client`. A background task prints the number of active clients
/// every five seconds whenever that number is non-zero.
///
/// # Errors
///
/// Returns an error only if binding the listening socket fails. Errors from
/// `accept` are logged and the loop keeps running, so one aborted connection
/// attempt or a temporary descriptor shortage does not stop the server for
/// the clients that are already connected.
async fn run_server() -> Result<()> {
    /// Releases one slot of the active-client count when dropped.
    ///
    /// Tying the decrement to `Drop` keeps the count accurate even when the
    /// handler task unwinds from a panic instead of returning.
    struct ActiveClientGuard(Arc<AtomicUsize>);

    impl Drop for ActiveClientGuard {
        fn drop(&mut self) {
            self.0.fetch_sub(1, Ordering::Relaxed);
        }
    }

    let addr = "127.0.0.1:18944";
    let listener = TcpListener::bind(addr).await?;

    println!("=== Async OpenIGTLink Server ===\n");
    println!("[INFO] Configuration:");
    println!(" Address: {}", addr);
    println!(" Runtime: Tokio (async/await)");
    println!(" Concurrency: Unlimited clients");
    println!(" I/O Mode: Non-blocking\n");

    let active_clients = Arc::new(AtomicUsize::new(0));

    // Periodic stats reporter; stays silent while no client is connected.
    let stats_clients = Arc::clone(&active_clients);
    tokio::spawn(async move {
        let mut ticker = interval(Duration::from_secs(5));
        loop {
            ticker.tick().await;
            let count = stats_clients.load(Ordering::Relaxed);
            if count > 0 {
                println!("[STATS] Active clients: {}", count);
            }
        }
    });

    println!("[INFO] Server ready, waiting for connections...\n");

    loop {
        let (socket, addr) = match listener.accept().await {
            Ok(connection) => connection,
            Err(e) => {
                eprintln!("[ERROR] Failed to accept connection: {}", e);
                // Short pause so a persistent failure (for example running
                // out of file descriptors) does not become a busy loop.
                tokio::time::sleep(Duration::from_millis(100)).await;
                continue;
            }
        };

        let client_id = CLIENT_COUNTER.fetch_add(1, Ordering::Relaxed) + 1;
        println!("[CONNECT] Client #{} connected from {}", client_id, addr);

        active_clients.fetch_add(1, Ordering::Relaxed);
        let guard = ActiveClientGuard(Arc::clone(&active_clients));

        tokio::spawn(async move {
            let result = handle_client(socket, client_id).await;
            // Release the slot before logging, as on the normal path; if the
            // handler panicked instead, unwinding drops the guard for us.
            drop(guard);
            match result {
                Ok(_) => println!("[DISCONNECT] Client #{} disconnected gracefully", client_id),
                Err(e) => eprintln!("[ERROR] Client #{} error: {}", client_id, e),
            }
        });
    }
}
async fn handle_client(mut socket: TcpStream, client_id: usize) -> Result<()> {
println!(" [#{}] Handler started", client_id);
let mut message_count = 0;
loop {
let mut header_buf = vec![0u8; Header::SIZE];
match socket.read_exact(&mut header_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
break;
}
Err(e) => return Err(e.into()),
}
let header = Header::decode(&header_buf)?;
let mut body_buf = vec![0u8; header.body_size as usize];
socket.read_exact(&mut body_buf).await?;
let mut full_msg = header_buf;
full_msg.extend_from_slice(&body_buf);
message_count += 1;
match header.type_name.as_str()? {
"TRANSFORM" => {
if let Ok(msg) = IgtlMessage::<TransformMessage>::decode(&full_msg) {
println!(
" [#{}] TRANSFORM received (device: {})",
client_id,
msg.header.device_name.as_str().unwrap_or("unknown")
);
send_ack(&mut socket, client_id, message_count).await?;
}
}
"STATUS" => {
if let Ok(msg) = IgtlMessage::<StatusMessage>::decode(&full_msg) {
println!(
" [#{}] STATUS received: {}",
client_id,
msg.content.status_string
);
}
}
msg_type => {
println!(" [#{}] {} message received", client_id, msg_type);
}
}
}
println!(
" [#{}] Processed {} messages total",
client_id, message_count
);
Ok(())
}
/// Sends a `STATUS` acknowledgement back to the client.
///
/// The status text names the acknowledged message number (`msg_num`) and the
/// handler (`client_id`); the message is sent under the device name
/// `"AsyncServer"`.
///
/// # Errors
///
/// Returns an error if the message cannot be built or encoded, or if writing
/// to or flushing the socket fails.
async fn send_ack(
    socket: &mut TcpStream,
    client_id: usize,
    msg_num: usize,
) -> Result<()> {
    let text = format!(
        "Message #{} processed by async handler #{}",
        msg_num, client_id
    );
    let ack = IgtlMessage::new(StatusMessage::ok(&text), "AsyncServer")?;
    let bytes = ack.encode()?;

    socket.write_all(&bytes).await?;
    socket.flush().await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpStream;

    /// Binding to an OS-assigned port on loopback succeeds.
    #[tokio::test]
    async fn test_server_bind() {
        assert!(TcpListener::bind("127.0.0.1:0").await.is_ok());
    }

    /// Five clients can connect while each connection is served by its own
    /// handler task.
    #[tokio::test]
    async fn test_concurrent_clients() {
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server.local_addr().unwrap();

        // Accept loop: one handler task per connection, stops on accept error.
        tokio::spawn(async move {
            loop {
                match server.accept().await {
                    Ok((stream, _)) => {
                        tokio::spawn(async move {
                            let _ = handle_client(stream, 999).await;
                        });
                    }
                    Err(_) => break,
                }
            }
        });

        tokio::time::sleep(Duration::from_millis(10)).await;

        let pending: Vec<_> = (0..5)
            .map(|i| {
                tokio::spawn(async move {
                    let attempt = TcpStream::connect(server_addr).await;
                    assert!(attempt.is_ok(), "Client {} failed to connect", i);
                    attempt.unwrap()
                })
            })
            .collect();

        let mut connected = Vec::new();
        for task in pending {
            connected.push(task.await.unwrap());
        }
        assert_eq!(connected.len(), 5);

        for mut stream in connected {
            let _ = stream.shutdown().await;
        }
    }

    /// A TRANSFORM message sent to the handler produces a reply that is at
    /// least one full header long.
    #[tokio::test]
    async fn test_message_handling() {
        let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server.local_addr().unwrap();

        tokio::spawn(async move {
            let (stream, _) = server.accept().await.unwrap();
            let _ = handle_client(stream, 1).await;
        });

        tokio::time::sleep(Duration::from_millis(10)).await;

        let mut conn = TcpStream::connect(server_addr).await.unwrap();
        let request = IgtlMessage::new(TransformMessage::identity(), "TestDevice").unwrap();
        let payload = request.encode().unwrap();
        conn.write_all(&payload).await.unwrap();
        conn.flush().await.unwrap();

        tokio::time::sleep(Duration::from_millis(50)).await;

        let mut reply_header = vec![0u8; Header::SIZE];
        let outcome = conn.read_exact(&mut reply_header).await;
        assert!(outcome.is_ok());
    }

    /// Consecutive increments of the global client counter yield strictly
    /// increasing previous values.
    #[test]
    fn test_client_counter() {
        let first = CLIENT_COUNTER.fetch_add(1, Ordering::Relaxed);
        let second = CLIENT_COUNTER.fetch_add(1, Ordering::Relaxed);
        assert!(second > first);
    }
}