melodium_distribution/
protocol.rs1use crate::messages::Message;
2use async_std::io::{timeout, BufReader, BufWriter, Read, Write};
3use async_std::sync::Mutex;
4use core::fmt::Display;
5use core::sync::atomic::AtomicBool;
6use core::time::Duration;
7use futures::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use futures::AsyncWriteExt;
9
/// Result alias used throughout this module, with [`Error`] as the error type.
type Result<T> = std::result::Result<T, Error>;

/// Timeout, in seconds, applied separately to each read, write and flush
/// performed on the underlying stream.
const TIMEOUT: u64 = 20;
13
/// Errors produced while exchanging framed messages over a [`Protocol`].
#[derive(Debug)]
pub enum Error {
    /// Failure of the underlying stream. Timed-out reads/writes and attempts
    /// to send after the protocol was closed are also reported here.
    Io(async_std::io::Error),
    /// A received payload could not be decoded (via `ciborium`) into a message.
    Deserialization(ciborium::de::Error<std::io::Error>),
    /// A message could not be encoded (via `ciborium`) before sending.
    Serialization(ciborium::ser::Error<std::io::Error>),
}
20
21impl Display for Error {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Error::Io(err) => write!(f, "{err}"),
25 Error::Deserialization(err) => write!(f, "{err}"),
26 Error::Serialization(err) => write!(f, "{err}"),
27 }
28 }
29}
30
31impl std::error::Error for Error {
32 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
33 match self {
34 Error::Io(err) => Some(err),
35 Error::Deserialization(err) => Some(err),
36 Error::Serialization(err) => Some(err),
37 }
38 }
39}
40
41impl From<async_std::io::Error> for Error {
42 fn from(value: async_std::io::Error) -> Self {
43 Error::Io(value)
44 }
45}
46
/// Message framing over a bidirectional async byte stream.
///
/// Each frame is a 4-byte big-endian length followed by that many bytes of
/// `ciborium`-encoded message data. The stream is split into a read half and
/// a write half, each behind its own lock.
#[derive(Debug)]
pub struct Protocol<R: Read + Write + Unpin + Send> {
    /// Set by `close`; while set, `send_message` refuses to send.
    closed: AtomicBool,
    /// Buffered read half, locked for the duration of one received frame.
    reader: Mutex<BufReader<ReadHalf<R>>>,
    /// Buffered write half, locked for the duration of one sent frame.
    writer: Mutex<BufWriter<WriteHalf<R>>>,
}
53
54impl<R: Read + Write + Unpin + Send> Protocol<R> {
55 pub fn new(rw: R) -> Self {
56 let (read, write) = rw.split();
57 Self {
58 closed: AtomicBool::new(false),
59 reader: Mutex::new(BufReader::new(read)),
60 writer: Mutex::new(BufWriter::new(write)),
61 }
62 }
63
64 pub async fn close(&self) {
65 if !self.closed.load(core::sync::atomic::Ordering::Relaxed) {
66 let _ = self.send_message(Message::Ended).await;
67 let mut writer = self.writer.lock().await;
69 let _ = writer.close().await;
70 self.closed
71 .store(true, core::sync::atomic::Ordering::Relaxed);
72 }
73 }
74
75 pub async fn recv_message(&self) -> Result<Message> {
76 let mut reader = self.reader.lock().await;
77 let mut expected_size: [u8; 4] = [0; 4];
78 timeout(
79 Duration::from_secs(TIMEOUT),
80 reader.read_exact(&mut expected_size),
81 )
82 .await?;
83 let expected_size = u32::from_be_bytes(expected_size) as usize;
84
85 let mut data = vec![0u8; expected_size];
86 timeout(Duration::from_secs(TIMEOUT), reader.read_exact(&mut data)).await?;
87
88 match ciborium::de::from_reader(data.as_slice()) {
89 Ok(message) => Ok(message),
90 Err(err) => Err(Error::Deserialization(err)),
91 }
92 }
93
94 pub async fn send_message(&self, message: Message) -> Result<()> {
95 if self.closed.load(core::sync::atomic::Ordering::Relaxed) {
96 return Err(Error::Io(std::io::Error::new(
97 std::io::ErrorKind::Other,
98 "closed",
99 )));
100 }
101 let mut writer = self.writer.lock().await;
102
103 let mut data = Vec::new();
104 match ciborium::into_writer(&message, &mut data) {
105 Ok(()) => {
106 timeout(
107 Duration::from_secs(TIMEOUT),
108 writer.write_all(&(data.len() as u32).to_be_bytes()),
109 )
110 .await?;
111 timeout(Duration::from_secs(TIMEOUT), writer.write_all(&data)).await?;
112 timeout(Duration::from_secs(TIMEOUT), writer.flush()).await?;
113 Ok(())
114 }
115 Err(err) => Err(Error::Serialization(err)),
116 }
117 }
118}