use crate::cmd::data::{CommandHeader, CommandPacket, Payload};
use crate::cmd::generic::GenericPayload;
use snafu::{ensure, Report, ResultExt, Snafu};
use std::process::Termination;
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::broadcast;
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
pub mod data;
pub mod generic;
/// Client side of the UDP command protocol.
///
/// Owns the connected socket, hands out sequence ids, and fans every received
/// datagram out to in-flight commands through a broadcast channel.
#[derive(Debug)]
pub struct CommandHandler {
    /// Next sequence id to hand out; wraps around on overflow.
    seq_id: AtomicU16,
    /// Connected UDP socket, shared with the background receive task.
    socket: Arc<UdpSocket>,
    /// Every datagram read from `socket` is published here, one `Arc<[u8]>` per datagram.
    broadcast: broadcast::Sender<Arc<[u8]>>,
}
impl CommandHandler {
const TIMEOUT: Duration = Duration::from_millis(1000);
const RETRIES: usize = 10;
pub async fn new() -> Result<Self, Error> {
let socket = UdpSocket::bind("192.168.1.10:50023")
.await
.context(ConnectingSnafu)?;
socket
.connect("192.168.1.11:50123")
.await
.context(ConnectingSnafu)?;
let socket = Arc::new(socket);
let (broadcast, _) = broadcast::channel(16);
let sock = socket.clone();
let bc = broadcast.clone().downgrade();
tokio::spawn(async move {
let result: Report<Error> = (async {
loop {
let mut buff = vec![0; 1800]; let bytes = sock.recv(&mut buff).await.context(ReceiveSnafu)?;
buff.resize(bytes, 0);
let Some(broadcast) = bc.upgrade() else {
return Ok(());
};
let _ = broadcast.send(Arc::from(buff));
}
})
.await
.into();
eprintln!("closed command handler");
result.report();
});
Ok(CommandHandler {
seq_id: AtomicU16::new(0),
socket,
broadcast,
})
}
fn next_seq_id(&self) -> u16 {
self.seq_id.fetch_add(1, Ordering::Relaxed)
}
async fn send_packet<T: Payload>(
&self,
seq_id: u16,
payload: &T,
) -> Result<(), std::io::Error> {
let mut buffer = vec![0u8; payload.packet_size()];
payload.write_packet(seq_id, &mut buffer);
self.socket.send(&buffer).await?;
Ok(())
}
async fn send_ack<T: Payload>(&self, seq_id: u16, payload: &T) -> Result<(), std::io::Error> {
let command = CommandHeader {
packet_type: 3.into(),
query_type: T::QUERY_TYPE.into(),
payload_size: 0.into(),
seq_id: seq_id.into()
};
self.socket.send(command.as_bytes()).await?;
Ok(())
}
async fn recv_packet(
&self,
rcv: &mut broadcast::Receiver<Arc<[u8]>>,
seq_id: u16,
) -> Result<Arc<[u8]>, Error> {
loop {
let packet = rcv.recv().await.expect("broken channel");
let packet_data =
CommandPacket::ref_from_bytes(&packet).map_err(|x| Error::Incomplete {
reason: x.to_string(),
})?;
if packet_data.header.seq_id != seq_id {
continue;
}
ensure!(
packet_data.header.payload_size.get() as usize == packet_data.payload.len(),
PayloadLengthSnafu
);
return Ok(packet);
}
}
pub async fn command<T: Payload>(&self, data: &T) -> Result<T::Response, Error> {
let seq_id = self.next_seq_id();
let mut rcv = self.broadcast.subscribe();
let mut retries = 0;
loop {
self.send_packet(seq_id, data)
.await
.context(SendSnafu)?;
let ack = select! {
res = self.recv_packet(&mut rcv, seq_id) => res?,
_ = tokio::time::sleep(Self::TIMEOUT) => {
retries += 1;
ensure!(retries < Self::RETRIES, TimeoutSnafu);
continue;
}
};
let ack = CommandPacket::ref_from_bytes(&ack).expect("already unpacked");
ensure!(ack.header.packet_type == 1, AckExpectedSnafu);
ensure!(ack.payload.len() == 0, PayloadLengthSnafu);
break;
}
let response = self.recv_packet(&mut rcv, seq_id).await?;
let response = CommandPacket::ref_from_bytes(&response).expect("already unpacked");
ensure!(response.header.packet_type == 2, ResponseExpectedSnafu);
self.send_ack(seq_id, data)
.await
.context(SendSnafu)?;
T::Response::read_from_bytes(&response.payload).map_err(|x| Error::Incomplete {
reason: x.to_string(),
})
}
}
/// Errors reported by [`CommandHandler`].
///
/// Every variant has an explicit `display` attribute, so the rendered message
/// does not depend on the variant's doc comment or name. Underlying I/O errors
/// are exposed through `source` rather than repeated in the message.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Binding the local socket or connecting it to the peer failed.
    #[snafu(display("failed to set up the command socket"))]
    Connecting { source: std::io::Error },
    /// Sending a datagram to the peer failed.
    #[snafu(display("failed to send a packet"))]
    Send { source: std::io::Error },
    /// Receiving a datagram from the socket failed.
    #[snafu(display("failed to receive a packet"))]
    Receive { source: std::io::Error },
    /// Received bytes could not be interpreted as the expected packet or payload.
    #[snafu(display("incomplete packet: {reason}"))]
    Incomplete { reason: String },
    /// A packet arrived where an acknowledgement was expected.
    #[snafu(display("expected an acknowledgement packet"))]
    AckExpected,
    /// A packet arrived where a response was expected.
    #[snafu(display("expected a response packet"))]
    ResponseExpected,
    /// The payload size is not what the header, or the packet type, requires.
    #[snafu(display("unexpected payload length"))]
    PayloadLength,
    /// No matching packet arrived within the allowed number of attempts.
    #[snafu(display("timed out waiting for the peer to acknowledge the command"))]
    Timeout,
}