use std::net::{SocketAddr, UdpSocket};
use std::time::Duration;
use crate::error::{IgtlError, Result};
use crate::protocol::message::{IgtlMessage, Message};
/// Largest payload, in bytes, that fits in a single UDP datagram over IPv4:
/// the 65,535-byte IP total-length limit minus the 20-byte IPv4 header and
/// the 8-byte UDP header.
pub const MAX_UDP_DATAGRAM_SIZE: usize = 65_535 - 20 - 8;
/// Client-side UDP endpoint that exchanges [`IgtlMessage`]s, one message per
/// datagram.
///
/// Wraps a bound [`UdpSocket`]. Derives `Debug` so the type can be logged and
/// embedded in other `Debug` types.
#[derive(Debug)]
pub struct UdpClient {
    /// Bound socket used for every send and receive.
    socket: UdpSocket,
}
impl UdpClient {
    /// Creates a client whose socket is bound to `local_addr`.
    ///
    /// `local_addr` is passed straight to [`UdpSocket::bind`], e.g.
    /// `"127.0.0.1:0"`.
    ///
    /// # Errors
    ///
    /// Returns the bind failure converted into the crate error type.
    pub fn bind(local_addr: &str) -> Result<Self> {
        UdpSocket::bind(local_addr)
            .map(|socket| Self { socket })
            .map_err(IgtlError::from)
    }

    /// Encodes `msg` and sends it to `target` as a single datagram.
    ///
    /// # Errors
    ///
    /// * Propagates any error from encoding the message.
    /// * Returns [`IgtlError::BodyTooLarge`] when the encoded message is longer
    ///   than [`MAX_UDP_DATAGRAM_SIZE`]; nothing is sent in that case.
    /// * Returns the socket error if the send itself fails.
    pub fn send_to<T: Message>(&self, msg: &IgtlMessage<T>, target: &str) -> Result<()> {
        let datagram = msg.encode()?;
        let size = datagram.len();

        if size <= MAX_UDP_DATAGRAM_SIZE {
            // The byte count returned by the socket is not needed by callers.
            self.socket
                .send_to(&datagram, target)
                .map(drop)
                .map_err(IgtlError::from)
        } else {
            Err(IgtlError::BodyTooLarge {
                size,
                max: MAX_UDP_DATAGRAM_SIZE,
            })
        }
    }

    /// Receives one datagram and decodes it as an `IgtlMessage<T>`.
    ///
    /// A fresh buffer of [`MAX_UDP_DATAGRAM_SIZE`] bytes is allocated for each
    /// call; only the bytes actually received are handed to the decoder.
    ///
    /// Returns the decoded message together with the sender's address.
    ///
    /// # Errors
    ///
    /// Returns the socket error if receiving fails, or the decoder's error if
    /// the payload cannot be decoded.
    pub fn receive_from<T: Message>(&self) -> Result<(IgtlMessage<T>, SocketAddr)> {
        let mut buffer = vec![0u8; MAX_UDP_DATAGRAM_SIZE];
        let (received, sender) = self.socket.recv_from(&mut buffer)?;
        let payload = &buffer[..received];
        Ok((IgtlMessage::decode(payload)?, sender))
    }

    /// Sets the read timeout of the underlying socket; `None` removes it.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the timeout cannot be applied.
    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> Result<()> {
        self.socket
            .set_read_timeout(timeout)
            .map_err(IgtlError::from)
    }

    /// Sets the write timeout of the underlying socket; `None` removes it.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the timeout cannot be applied.
    pub fn set_write_timeout(&self, timeout: Option<Duration>) -> Result<()> {
        self.socket
            .set_write_timeout(timeout)
            .map_err(IgtlError::from)
    }

    /// Returns the local address the socket is bound to.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the address cannot be queried.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.socket.local_addr().map_err(IgtlError::from)
    }
}
/// Server-side UDP endpoint that receives [`IgtlMessage`]s and can reply to
/// the address a datagram came from.
///
/// Wraps a bound [`UdpSocket`]. Derives `Debug` so the type can be logged and
/// embedded in other `Debug` types.
#[derive(Debug)]
pub struct UdpServer {
    /// Bound socket used for every receive and send.
    socket: UdpSocket,
}
impl UdpServer {
    /// Creates a server whose socket is bound to `addr`.
    ///
    /// `addr` is passed straight to [`UdpSocket::bind`].
    ///
    /// # Errors
    ///
    /// Returns the bind failure converted into the crate error type.
    pub fn bind(addr: &str) -> Result<Self> {
        UdpSocket::bind(addr)
            .map(|socket| Self { socket })
            .map_err(IgtlError::from)
    }

    /// Receives one datagram and decodes it as an `IgtlMessage<T>`.
    ///
    /// A fresh buffer of [`MAX_UDP_DATAGRAM_SIZE`] bytes is allocated for each
    /// call; only the bytes actually received are handed to the decoder.
    ///
    /// Returns the decoded message together with the sender's address, which
    /// can be passed to [`UdpServer::send_to`] to reply.
    ///
    /// # Errors
    ///
    /// Returns the socket error if receiving fails, or the decoder's error if
    /// the payload cannot be decoded.
    pub fn receive<T: Message>(&self) -> Result<(IgtlMessage<T>, SocketAddr)> {
        let mut buffer = vec![0u8; MAX_UDP_DATAGRAM_SIZE];
        let (received, sender) = self.socket.recv_from(&mut buffer)?;
        let payload = &buffer[..received];
        Ok((IgtlMessage::decode(payload)?, sender))
    }

    /// Encodes `msg` and sends it to `target` as a single datagram.
    ///
    /// # Errors
    ///
    /// * Propagates any error from encoding the message.
    /// * Returns [`IgtlError::BodyTooLarge`] when the encoded message is longer
    ///   than [`MAX_UDP_DATAGRAM_SIZE`]; nothing is sent in that case.
    /// * Returns the socket error if the send itself fails.
    pub fn send_to<T: Message>(&self, msg: &IgtlMessage<T>, target: SocketAddr) -> Result<()> {
        let datagram = msg.encode()?;
        let size = datagram.len();

        if size <= MAX_UDP_DATAGRAM_SIZE {
            // The byte count returned by the socket is not needed by callers.
            self.socket
                .send_to(&datagram, target)
                .map(drop)
                .map_err(IgtlError::from)
        } else {
            Err(IgtlError::BodyTooLarge {
                size,
                max: MAX_UDP_DATAGRAM_SIZE,
            })
        }
    }

    /// Sets the read timeout of the underlying socket; `None` removes it.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the timeout cannot be applied.
    pub fn set_read_timeout(&self, timeout: Option<Duration>) -> Result<()> {
        self.socket
            .set_read_timeout(timeout)
            .map_err(IgtlError::from)
    }

    /// Returns the local address the socket is bound to.
    ///
    /// # Errors
    ///
    /// Returns the socket error if the address cannot be queried.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.socket.local_addr().map_err(IgtlError::from)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::TransformMessage;

    /// Loopback address with port 0, shared by every test that binds a socket.
    const LOOPBACK_ANY_PORT: &str = "127.0.0.1:0";

    /// The datagram limit keeps its expected value.
    #[test]
    fn test_max_datagram_size() {
        assert_eq!(MAX_UDP_DATAGRAM_SIZE, 65507);
    }

    /// A client can bind to the loopback interface.
    #[test]
    fn test_client_bind() {
        assert!(UdpClient::bind(LOOPBACK_ANY_PORT).is_ok());
    }

    /// A server can bind to the loopback interface.
    #[test]
    fn test_server_bind() {
        assert!(UdpServer::bind(LOOPBACK_ANY_PORT).is_ok());
    }

    /// The reported local address is on loopback with a non-zero port.
    #[test]
    fn test_local_addr() {
        let bound = UdpClient::bind(LOOPBACK_ANY_PORT)
            .unwrap()
            .local_addr()
            .unwrap();
        assert_eq!(bound.ip().to_string(), "127.0.0.1");
        assert_ne!(bound.port(), 0);
    }

    /// A message sent by the client arrives at the server with the same device
    /// name, and the server sees the client's address as the sender.
    #[test]
    fn test_send_receive() {
        let server = UdpServer::bind(LOOPBACK_ANY_PORT).unwrap();
        let client = UdpClient::bind(LOOPBACK_ANY_PORT).unwrap();
        let target = server.local_addr().unwrap().to_string();

        let outgoing = IgtlMessage::new(TransformMessage::identity(), "TestDevice").unwrap();
        client.send_to(&outgoing, &target).unwrap();

        let (incoming, peer) = server.receive::<TransformMessage>().unwrap();
        assert_eq!(incoming.header.device_name.as_str().unwrap(), "TestDevice");
        assert_eq!(peer, client.local_addr().unwrap());
    }

    /// Receiving with a read timeout and no incoming traffic yields an error.
    #[test]
    fn test_timeout() {
        let client = UdpClient::bind(LOOPBACK_ANY_PORT).unwrap();
        let timeout = Duration::from_millis(100);
        client.set_read_timeout(Some(timeout)).unwrap();

        assert!(client.receive_from::<TransformMessage>().is_err());
    }

    /// Checks the size limit at compile time and that a client can still bind.
    ///
    /// NOTE(review): despite its name, this test never sends an oversized
    /// message, so the `BodyTooLarge` path is not exercised here.
    #[test]
    fn test_message_too_large() {
        const _: () = assert!(MAX_UDP_DATAGRAM_SIZE < 65536);
        let _client = UdpClient::bind(LOOPBACK_ANY_PORT).unwrap();
    }
}