kade_proto/pkg/
connection.rs

1use super::frame::{self, Frame};
2use bytes::{Buf, BytesMut};
3use std::io::{self, Cursor};
4use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
5use tokio::net::TcpStream;
6
7#[derive(Debug)]
8pub struct Connection {
9    stream: BufWriter<TcpStream>,
10    buffer: BytesMut,
11}
12
13impl Connection {
14    pub fn new(socket: TcpStream) -> Connection {
15        Connection {
16            stream: BufWriter::new(socket),
17            buffer: BytesMut::with_capacity(8 * 1024),
18        }
19    }
20
21    pub async fn read_frame(&mut self) -> crate::Result<Option<Frame>> {
22        loop {
23            if let Some(frame) = self.parse_frame()? {
24                return Ok(Some(frame));
25            }
26
27            if 0 == self.stream.read_buf(&mut self.buffer).await? {
28                if self.buffer.is_empty() {
29                    return Ok(None);
30                } else {
31                    return Err("connection reset by peer".into());
32                }
33            }
34        }
35    }
36
37    fn parse_frame(&mut self) -> crate::Result<Option<Frame>> {
38        use frame::Error::Incomplete;
39
40        let mut buf = Cursor::new(&self.buffer[..]);
41
42        match Frame::check(&mut buf) {
43            Ok(_) => {
44                let len = buf.position() as usize;
45                buf.set_position(0);
46
47                let frame = Frame::parse(&mut buf)?;
48                self.buffer.advance(len);
49
50                Ok(Some(frame))
51            }
52
53            Err(Incomplete) => Ok(None),
54            Err(e) => Err(e.into()),
55        }
56    }
57
58    pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> {
59        match frame {
60            Frame::Array(val) => {
61                self.stream.write_u8(b'*').await?;
62                self.write_decimal(val.len() as u64).await?;
63
64                for entry in &**val {
65                    self.write_value(entry).await?;
66                }
67            }
68            _ => self.write_value(frame).await?,
69        }
70
71        self.stream.flush().await
72    }
73
74    async fn write_value(&mut self, frame: &Frame) -> io::Result<()> {
75        match frame {
76            Frame::Simple(val) => {
77                self.stream.write_u8(b'+').await?;
78                self.stream.write_all(val.as_bytes()).await?;
79                self.stream.write_all(b"\r\n").await?;
80            }
81            Frame::Error(val) => {
82                self.stream.write_u8(b'-').await?;
83                self.stream.write_all(val.as_bytes()).await?;
84                self.stream.write_all(b"\r\n").await?;
85            }
86            Frame::Integer(val) => {
87                self.stream.write_u8(b':').await?;
88                self.write_decimal(*val).await?;
89            }
90            Frame::Null => {
91                self.stream.write_all(b"$-1\r\n").await?;
92            }
93            Frame::Bulk(val) => {
94                let len = val.len();
95
96                self.stream.write_u8(b'$').await?;
97                self.write_decimal(len as u64).await?;
98                self.stream.write_all(val).await?;
99                self.stream.write_all(b"\r\n").await?;
100            }
101            Frame::Array(_val) => unreachable!(),
102        }
103
104        Ok(())
105    }
106
107    async fn write_decimal(&mut self, val: u64) -> io::Result<()> {
108        use std::io::Write;
109
110        let mut buf = [0u8; 20];
111        let mut buf = Cursor::new(&mut buf[..]);
112        write!(&mut buf, "{}", val)?;
113
114        let pos = buf.position() as usize;
115        self.stream.write_all(&buf.get_ref()[..pos]).await?;
116        self.stream.write_all(b"\r\n").await?;
117
118        Ok(())
119    }
120}