1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#![deny(warnings, rust_2018_idioms)]

pub mod client;
pub mod muxer;
pub mod preface;
pub mod server;

use bytes::{Buf, BufMut, BytesMut};
use ort_core::{Reply, Spec};
use tokio::{io, time};
use tokio_util::codec::{Decoder, Encoder, LengthDelimitedCodec};

/// Stateless codec for the fixed-size request frame carrying a [`Spec`]: two 32-bit fields,
/// latency in milliseconds followed by the response size.
///
/// The unit payload keeps the tuple-struct constructor `SpecCodec(())` available to callers.
#[derive(Clone, Copy, Debug, Default)]
struct SpecCodec(());

/// Codec for reply frames: a length-delimited payload, framed by the wrapped
/// [`LengthDelimitedCodec`] (header size and frame limit are set in the `Default` impl).
#[derive(Debug)]
struct ReplyCodec(LengthDelimitedCodec);

// === impl SpecCodec ===

impl Decoder for SpecCodec {
    type Item = Spec;
    type Error = io::Error;

    /// Decodes one fixed 8-byte spec frame: a `u32` latency in milliseconds followed by a `u32`
    /// response size, each read with `Buf::get_u32`.
    ///
    /// Returns `Ok(None)` without consuming anything until the whole frame is buffered.
    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<Spec>> {
        const FRAME_LEN: usize = 4 + 4;

        if src.len() >= FRAME_LEN {
            // Field order on the wire: latency first, then response size.
            let latency_ms = u64::from(src.get_u32());
            let response_size = src.get_u32() as usize;
            Ok(Some(Spec {
                latency: time::Duration::from_millis(latency_ms),
                response_size,
            }))
        } else {
            Ok(None)
        }
    }
}

impl Encoder<Spec> for SpecCodec {
    type Error = io::Error;

    fn encode(&mut self, spec: Spec, dst: &mut BytesMut) -> io::Result<()> {
        dst.reserve(4 + 4);
        dst.put_u32(spec.latency.as_millis() as u32);
        dst.put_u32(spec.response_size as u32);
        Ok(())
    }
}

// === impl ReplyCodec ===

impl Default for ReplyCodec {
    fn default() -> Self {
        let frames = LengthDelimitedCodec::builder()
            .max_frame_length(std::u32::MAX as usize)
            .length_field_length(4)
            .new_codec();
        Self(frames)
    }
}

impl Decoder for ReplyCodec {
    type Item = Reply;
    type Error = io::Error;

    /// Takes the next complete length-delimited frame from `src`, if one is buffered, and wraps
    /// its payload in a [`Reply`]. Returns `Ok(None)` while the frame is still incomplete.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Reply>, io::Error> {
        let frame = self.0.decode(src)?;
        Ok(frame.map(|payload| Reply {
            data: payload.freeze(),
        }))
    }
}

impl Encoder<Reply> for ReplyCodec {
    type Error = io::Error;

    /// Writes the reply's payload to `dst` as one length-delimited frame, delegating header
    /// construction and frame-size checks to the wrapped codec.
    fn encode(&mut self, reply: Reply, dst: &mut BytesMut) -> io::Result<()> {
        self.0.encode(reply.data, dst)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// Specs written back-to-back into one buffer must come out of a fresh decoder intact and
    /// in the order they were written.
    #[tokio::test]
    async fn roundtrip_spec() {
        let specs = [
            Spec {
                latency: time::Duration::from_millis(1),
                response_size: 3,
            },
            Spec {
                latency: time::Duration::from_millis(2),
                response_size: 4,
            },
        ];

        let mut wire = BytesMut::with_capacity(100);
        let mut encoder = SpecCodec::default();
        for spec in specs.iter() {
            encoder.encode(*spec, &mut wire).expect("must encode");
        }

        let mut decoder = SpecCodec::default();
        for expected in specs.iter() {
            let decoded = decoder
                .decode(&mut wire)
                .expect("must decode")
                .expect("must decode");
            assert_eq!(&decoded, expected);
        }
    }

    /// Replies written back-to-back into one buffer must come out of a fresh decoder intact and
    /// in the order they were written.
    #[tokio::test]
    async fn roundtrip_reply() {
        let replies = [
            Reply {
                data: Bytes::from_static(b"abcdef"),
            },
            Reply {
                data: Bytes::from_static(b"ghijkl"),
            },
        ];

        let mut wire = BytesMut::with_capacity(100);
        let mut encoder = ReplyCodec::default();
        for reply in replies.iter() {
            encoder.encode(reply.clone(), &mut wire).expect("must encode");
        }

        let mut decoder = ReplyCodec::default();
        for expected in replies.iter() {
            let decoded = decoder
                .decode(&mut wire)
                .expect("must decode")
                .expect("must decode");
            assert_eq!(&decoded, expected);
        }
    }
}

/// Resolves to the next item of stream `p`. If the stream has ended, the returned future never
/// completes, so an exhausted stream never yields a terminal value.
async fn next_or_pending<T, S: futures::Stream<Item = T> + Unpin>(p: &mut S) -> T {
    use futures::StreamExt;

    if let Some(item) = p.next().await {
        return item;
    }
    // Stream is exhausted: park forever instead of producing a value.
    futures::future::pending().await
}