Skip to main content

melodium_distribution/
protocol.rs

1use crate::messages::Message;
2use async_std::io::{timeout, BufReader, BufWriter, Read, Write};
3use async_std::sync::Mutex;
4use core::fmt::Display;
5use core::sync::atomic::AtomicBool;
6use core::time::Duration;
7use futures::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use futures::AsyncWriteExt;
9
10type Result<T> = std::result::Result<T, Error>;
11
12const TIMEOUT: u64 = 20;
13
14#[derive(Debug)]
15pub enum Error {
16    Io(async_std::io::Error),
17    Deserialization(ciborium::de::Error<std::io::Error>),
18    Serialization(ciborium::ser::Error<std::io::Error>),
19}
20
21impl Display for Error {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Error::Io(err) => write!(f, "{err}"),
25            Error::Deserialization(err) => write!(f, "{err}"),
26            Error::Serialization(err) => write!(f, "{err}"),
27        }
28    }
29}
30
31impl std::error::Error for Error {
32    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
33        match self {
34            Error::Io(err) => Some(err),
35            Error::Deserialization(err) => Some(err),
36            Error::Serialization(err) => Some(err),
37        }
38    }
39}
40
41impl From<async_std::io::Error> for Error {
42    fn from(value: async_std::io::Error) -> Self {
43        Error::Io(value)
44    }
45}
46
47#[derive(Debug)]
48pub struct Protocol<R: Read + Write + Unpin + Send> {
49    closed: AtomicBool,
50    reader: Mutex<BufReader<ReadHalf<R>>>,
51    writer: Mutex<BufWriter<WriteHalf<R>>>,
52}
53
54impl<R: Read + Write + Unpin + Send> Protocol<R> {
55    pub fn new(rw: R) -> Self {
56        let (read, write) = rw.split();
57        Self {
58            closed: AtomicBool::new(false),
59            reader: Mutex::new(BufReader::new(read)),
60            writer: Mutex::new(BufWriter::new(write)),
61        }
62    }
63
64    pub async fn close(&self) {
65        if !self.closed.load(core::sync::atomic::Ordering::Relaxed) {
66            let _ = self.send_message(Message::Ended).await;
67            // Currently only writer can be closed (reader rely on timeout)
68            let mut writer = self.writer.lock().await;
69            let _ = writer.close().await;
70            self.closed
71                .store(true, core::sync::atomic::Ordering::Relaxed);
72        }
73    }
74
75    pub async fn recv_message(&self) -> Result<Message> {
76        let mut reader = self.reader.lock().await;
77        let mut expected_size: [u8; 4] = [0; 4];
78        timeout(
79            Duration::from_secs(TIMEOUT),
80            reader.read_exact(&mut expected_size),
81        )
82        .await?;
83        let expected_size = u32::from_be_bytes(expected_size) as usize;
84
85        let mut data = vec![0u8; expected_size];
86        timeout(Duration::from_secs(TIMEOUT), reader.read_exact(&mut data)).await?;
87
88        match ciborium::de::from_reader(data.as_slice()) {
89            Ok(message) => Ok(message),
90            Err(err) => Err(Error::Deserialization(err)),
91        }
92    }
93
94    pub async fn send_message(&self, message: Message) -> Result<()> {
95        if self.closed.load(core::sync::atomic::Ordering::Relaxed) {
96            return Err(Error::Io(std::io::Error::new(
97                std::io::ErrorKind::Other,
98                "closed",
99            )));
100        }
101        let mut writer = self.writer.lock().await;
102
103        let mut data = Vec::new();
104        match ciborium::into_writer(&message, &mut data) {
105            Ok(()) => {
106                timeout(
107                    Duration::from_secs(TIMEOUT),
108                    writer.write_all(&(data.len() as u32).to_be_bytes()),
109                )
110                .await?;
111                timeout(Duration::from_secs(TIMEOUT), writer.write_all(&data)).await?;
112                timeout(Duration::from_secs(TIMEOUT), writer.flush()).await?;
113                Ok(())
114            }
115            Err(err) => Err(Error::Serialization(err)),
116        }
117    }
118}