ort_tcp/
lib.rs

1#![deny(warnings, rust_2018_idioms)]
2
3pub mod client;
4pub mod muxer;
5pub mod preface;
6pub mod server;
7
8use bytes::{Buf, BufMut, BytesMut};
9use ort_core::{Reply, Spec};
10use tokio::{io, time};
11use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};
12
/// Codec that reads and writes [`Spec`] values as fixed-size frames.
///
/// The codec keeps no state between calls; its only field is the unit value.
#[derive(Default)]
struct SpecCodec(());
15
/// Codec that reads and writes [`Reply`] values, delegating all framing to the
/// wrapped [`LengthDelimitedCodec`].
struct ReplyCodec(LengthDelimitedCodec);
17
18// === impl SpecCodec ===
19
20impl Decoder for SpecCodec {
21    type Item = Spec;
22    type Error = io::Error;
23
24    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Spec>> {
25        if src.len() < 4 + 4 {
26            return Ok(None);
27        }
28        let ms = src.get_u32();
29        let sz = src.get_u32();
30        Ok(Some(Spec {
31            latency: time::Duration::from_millis(ms as u64),
32            response_size: sz as usize,
33        }))
34    }
35}
36
37impl Encoder<Spec> for SpecCodec {
38    type Error = io::Error;
39
40    fn encode(&mut self, spec: Spec, dst: &mut BytesMut) -> io::Result<()> {
41        dst.reserve(4 + 4);
42        dst.put_u32(spec.latency.as_millis() as u32);
43        dst.put_u32(spec.response_size as u32);
44        Ok(())
45    }
46}
47
48// === impl ReplyCodec ===
49
50impl Default for ReplyCodec {
51    fn default() -> Self {
52        let frames = LengthDelimitedCodec::builder()
53            .max_frame_length(std::u32::MAX as usize)
54            .length_field_length(4)
55            .new_codec();
56        Self(frames)
57    }
58}
59
60impl Decoder for ReplyCodec {
61    type Item = Reply;
62    type Error = io::Error;
63
64    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Reply>, io::Error> {
65        match self.0.decode(src)? {
66            None => Ok(None),
67            Some(buf) => Ok(Some(Reply { data: buf.freeze() })),
68        }
69    }
70}
71
72impl Encoder<Reply> for ReplyCodec {
73    type Error = io::Error;
74
75    fn encode(&mut self, Reply { data }: Reply, dst: &mut BytesMut) -> io::Result<()> {
76        self.0.encode(data, dst)?;
77        Ok(())
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Decodes exactly one item from `buf`, panicking if the decoder fails or
    /// reports that no complete frame is available.
    fn decode_one<D>(decoder: &mut D, buf: &mut BytesMut) -> D::Item
    where
        D: Decoder,
        D::Error: std::fmt::Debug,
    {
        let maybe_item = decoder.decode(buf).expect("must decode");
        maybe_item.expect("must decode")
    }

    /// Two specs encoded back to back must decode in the same order.
    #[tokio::test]
    async fn roundtrip_spec() {
        let first = Spec {
            latency: time::Duration::from_millis(1),
            response_size: 3,
        };
        let second = Spec {
            latency: time::Duration::from_millis(2),
            response_size: 4,
        };

        let mut wire = BytesMut::with_capacity(100);

        let mut encoder = SpecCodec::default();
        encoder.encode(first, &mut wire).expect("must encode");
        encoder.encode(second, &mut wire).expect("must encode");

        let mut decoder = SpecCodec::default();
        assert_eq!(decode_one(&mut decoder, &mut wire), first);
        assert_eq!(decode_one(&mut decoder, &mut wire), second);
    }

    /// Two replies encoded back to back must decode in the same order.
    #[tokio::test]
    async fn roundtrip_reply() {
        let first = Reply {
            data: Bytes::from_static(b"abcdef"),
        };
        let second = Reply {
            data: Bytes::from_static(b"ghijkl"),
        };

        let mut wire = BytesMut::with_capacity(100);

        let mut encoder = ReplyCodec::default();
        encoder
            .encode(first.clone(), &mut wire)
            .expect("must encode");
        encoder
            .encode(second.clone(), &mut wire)
            .expect("must encode");

        let mut decoder = ReplyCodec::default();
        assert_eq!(decode_one(&mut decoder, &mut wire), first);
        assert_eq!(decode_one(&mut decoder, &mut wire), second);
    }
}
148
149async fn next_or_pending<T, S: futures::Stream<Item = T> + Unpin>(p: &mut S) -> T {
150    use futures::StreamExt;
151    match p.next().await {
152        Some(p) => p,
153        None => futures::future::pending().await,
154    }
155}