// trustless_protocol/codec.rs

1/// Wrap an `AsyncRead` with length-delimited codec framing for reading protocol messages.
2pub fn framed_read<R: tokio::io::AsyncRead>(
3    reader: R,
4) -> tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec> {
5    tokio_util::codec::FramedRead::new(reader, tokio_util::codec::LengthDelimitedCodec::new())
6}
7
8/// Wrap an `AsyncWrite` with length-delimited codec framing for writing protocol messages.
9pub fn framed_write<W: tokio::io::AsyncWrite>(
10    writer: W,
11) -> tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec> {
12    tokio_util::codec::FramedWrite::new(writer, tokio_util::codec::LengthDelimitedCodec::new())
13}
14
15/// Serialize a message as JSON and send it over a framed writer.
16pub async fn send_message<W>(
17    writer: &mut tokio_util::codec::FramedWrite<W, tokio_util::codec::LengthDelimitedCodec>,
18    msg: &impl serde::Serialize,
19) -> Result<(), crate::error::Error>
20where
21    W: tokio::io::AsyncWrite + Unpin,
22{
23    use futures_util::SinkExt as _;
24
25    let json = serde_json::to_vec(msg)?;
26    writer.send(bytes::Bytes::from(json)).await?;
27    Ok(())
28}
29
30/// Read and deserialize a JSON message from a framed reader.
31///
32/// Returns [`Error::ProcessExited`](crate::error::Error::ProcessExited) when the stream reaches EOF.
33pub async fn recv_message<R, M>(
34    reader: &mut tokio_util::codec::FramedRead<R, tokio_util::codec::LengthDelimitedCodec>,
35) -> Result<M, crate::error::Error>
36where
37    R: tokio::io::AsyncRead + Unpin,
38    M: serde::de::DeserializeOwned,
39{
40    use futures_util::StreamExt as _;
41
42    let frame = reader
43        .next()
44        .await
45        .ok_or(crate::error::Error::ProcessExited)??;
46    let msg = serde_json::from_slice(&frame)?;
47    Ok(msg)
48}
49
#[cfg(test)]
mod tests {
    use super::{framed_read, framed_write, recv_message, send_message};
    use crate::error::Error;
    use crate::message::{
        InitializeParams, InitializeResult, Request, Response, SignParams, SuccessResponse,
    };

    /// A request sent client -> server and a response sent back server -> client both
    /// survive the serialize / frame / deframe / deserialize round trip.
    #[tokio::test]
    async fn round_trip_message() {
        let (client, server) = tokio::io::duplex(4096);
        let (server_rx, server_tx) = tokio::io::split(server);
        let (client_rx, client_tx) = tokio::io::split(client);

        // Client -> server direction.
        let mut to_server = framed_write(client_tx);
        let mut from_client = framed_read(server_rx);

        let outgoing = Request::Initialize {
            id: 1,
            params: InitializeParams {},
        };
        send_message(&mut to_server, &outgoing).await.unwrap();

        let incoming: Request = recv_message(&mut from_client).await.unwrap();
        assert_eq!(incoming.id(), 1);
        assert!(matches!(incoming, Request::Initialize { .. }));

        // Server -> client direction: send a response back.
        let mut to_client = framed_write(server_tx);
        let mut from_server = framed_read(client_rx);

        let reply = Response::Success(SuccessResponse::Initialize {
            id: 1,
            result: InitializeResult {
                default: "cert1".to_owned(),
                certificates: vec![],
            },
        });
        send_message(&mut to_client, &reply).await.unwrap();

        let answer: Response = recv_message(&mut from_server).await.unwrap();
        assert_eq!(answer.id(), 1);
        if let Response::Success(SuccessResponse::Initialize { result, .. }) = answer {
            assert_eq!(result.default, "cert1");
        } else {
            panic!("expected Initialize Result");
        }
    }

    /// Reading from a pipe whose other end has been dropped yields `Error::ProcessExited`.
    #[tokio::test]
    async fn eof_returns_process_exited() {
        let (client, server) = tokio::io::duplex(4096);
        // Dropping the peer end closes the pipe, so the reader sees EOF with no frame pending.
        drop(client);
        let mut reader = framed_read(server);
        let outcome: Result<Request, _> = recv_message(&mut reader).await;
        assert!(matches!(outcome, Err(Error::ProcessExited)));
    }

    /// Several frames written back-to-back are read back individually and in order.
    #[tokio::test]
    async fn multiple_messages_in_sequence() {
        let (client, server) = tokio::io::duplex(4096);
        // Only the client -> server direction is exercised; the unused halves stay bound.
        let (server_rx, _server_tx) = tokio::io::split(server);
        let (_client_rx, client_tx) = tokio::io::split(client);

        let mut writer = framed_write(client_tx);
        let mut reader = framed_read(server_rx);

        for id in 1..=5 {
            let request = Request::Sign {
                id,
                params: SignParams {
                    certificate_id: format!("cert{id}"),
                    scheme: "ECDSA_NISTP256_SHA256".to_owned(),
                    blob: vec![id as u8; 16],
                },
            };
            send_message(&mut writer, &request).await.unwrap();
        }

        for expected_id in 1..=5 {
            let received: Request = recv_message(&mut reader).await.unwrap();
            assert_eq!(received.id(), expected_id);
            if let Request::Sign { params, .. } = &received {
                assert_eq!(params.certificate_id, format!("cert{expected_id}"));
                assert_eq!(params.blob, vec![expected_id as u8; 16]);
            } else {
                panic!("expected Sign");
            }
        }
    }
}