async_resol_vbus/
tcp_client_handshake.rs

1use async_std::{net::TcpStream, prelude::*};
2
3use resol_vbus::BlobBuffer;
4
5use crate::error::Result;
6
7/// Handles the client-side of the [VBus-over-TCP][1] handshake.
8///
9/// [1]: http://danielwippermann.github.io/resol-vbus/vbus-over-tcp.html
10///
11/// # Examples
12///
13/// ```no_run
14/// # fn main() -> async_resol_vbus::Result<()> { async_std::task::block_on(async {
15/// #
16/// use async_std::net::{SocketAddr, TcpStream};
17///
18/// use async_resol_vbus::TcpClientHandshake;
19///
20/// let address = "192.168.5.217:7053".parse::<SocketAddr>()?;
21/// let stream = TcpStream::connect(address).await?;
22/// let mut hs = TcpClientHandshake::start(stream).await?;
23/// hs.send_pass_command("vbus").await?;
24/// let stream = hs.send_data_command().await?;
25/// // ...
26/// #
27/// # Ok(()) }) }
28/// ```
29#[derive(Debug)]
30pub struct TcpClientHandshake {
31    stream: TcpStream,
32    buf: BlobBuffer,
33}
34
35impl TcpClientHandshake {
36    /// Start the handshake by waiting for the initial greeting reply from the service.
37    pub async fn start(stream: TcpStream) -> Result<TcpClientHandshake> {
38        let mut hs = TcpClientHandshake {
39            stream,
40            buf: BlobBuffer::new(),
41        };
42
43        hs.read_reply().await?;
44
45        Ok(hs)
46    }
47
48    /// Consume `self` and return the underlying `TcpStream`.
49    pub fn into_inner(self) -> TcpStream {
50        self.stream
51    }
52
53    async fn read_reply(&mut self) -> Result<()> {
54        let first_byte = loop {
55            if let Some(idx) = self.buf.iter().position(|b| *b == 10) {
56                let first_byte = self.buf[0];
57                self.buf.consume(idx + 1);
58
59                break first_byte;
60            }
61
62            let mut buf = [0u8; 256];
63            let len = self.stream.read(&mut buf).await?;
64            if len == 0 {
65                return Err("Reached EOF".into());
66            }
67
68            self.buf.extend_from_slice(&buf[0..len]);
69        };
70
71        if first_byte == b'+' {
72            Ok(())
73        } else if first_byte == b'-' {
74            Err("Negative reply".into())
75        } else {
76            Err("Unexpected reply".into())
77        }
78    }
79
80    async fn send_command(&mut self, cmd: &str, args: Option<&str>) -> Result<()> {
81        let cmd = match args {
82            Some(args) => format!("{} {}\r\n", cmd, args),
83            None => format!("{}\r\n", cmd),
84        };
85
86        self.stream.write_all(cmd.as_bytes()).await?;
87
88        self.read_reply().await
89    }
90
91    /// Send the `CONNECT` command and wait for the reply.
92    pub async fn send_connect_command(&mut self, via_tag: &str) -> Result<()> {
93        self.send_command("CONNECT", Some(via_tag)).await
94    }
95
96    /// Send the `PASS` command and wait for the reply.
97    pub async fn send_pass_command(&mut self, password: &str) -> Result<()> {
98        self.send_command("PASS", Some(password)).await
99    }
100
101    /// Send the `CHANNEL` command and wait for the reply.
102    pub async fn send_channel_command(&mut self, channel: u8) -> Result<()> {
103        self.send_command("CHANNEL", Some(&format!("{}", channel)))
104            .await
105    }
106
107    /// Send the `DATA` command and wait for the reply.
108    ///
109    /// This function returns the underlying `TcpStream` since the handshake is complete
110    /// after sending this command.
111    pub async fn send_data_command(mut self) -> Result<TcpStream> {
112        self.send_command("DATA", None).await?;
113        Ok(self.stream)
114    }
115
116    /// Send the `QUIT` command and wait for the reply.
117    pub async fn send_quit_command(mut self) -> Result<()> {
118        self.send_command("QUIT", None).await?;
119        Ok(())
120    }
121}
122
#[cfg(test)]
mod tests {
    use async_std::net::{SocketAddr, TcpListener, TcpStream};

    use crate::tcp_server_handshake::TcpServerHandshake;

    use super::*;

    /// Runs a complete handshake (`CONNECT`, `PASS`, `CHANNEL`, `DATA`) between a
    /// `TcpServerHandshake` and a `TcpClientHandshake` over a loopback connection.
    #[test]
    fn test() -> Result<()> {
        async_std::task::block_on(async {
            // Port 0 lets the OS choose a free port; ask the listener which one it picked.
            let bind_addr = "127.0.0.1:0".parse::<SocketAddr>()?;
            let listener = TcpListener::bind(&bind_addr).await?;
            let server_addr = listener.local_addr()?;

            let server_task = async_std::task::spawn::<_, Result<()>>(async move {
                let (socket, _) = listener.accept().await?;

                let mut handshake = TcpServerHandshake::start(socket).await?;
                handshake.receive_connect_command().await?;
                handshake.receive_pass_command().await?;
                handshake.receive_channel_command().await?;
                drop(handshake.receive_data_command().await?);

                Ok(())
            });

            let client_task = async_std::task::spawn::<_, Result<()>>(async move {
                let socket = TcpStream::connect(server_addr).await?;

                let mut handshake = TcpClientHandshake::start(socket).await?;
                handshake.send_connect_command("via_tag").await?;
                handshake.send_pass_command("password").await?;
                handshake.send_channel_command(1).await?;
                drop(handshake.send_data_command().await?);

                Ok(())
            });

            server_task.await?;
            client_task.await
        })
    }
}