// async_resol_vbus/tcp_server_handshake.rs

1use std::future::Future;
2
3use async_std::{net::TcpStream, prelude::*};
4
5use resol_vbus::BlobBuffer;
6
7use crate::error::Result;
8
/// Outcome produced by a command validator future.
///
/// `Ok(value)` accepts the command and `value` becomes the result of the
/// waiting `receive_*` call. `Err(reply)` rejects the command: `reply` is the
/// raw line written back to the client before the next command is awaited.
pub type FutureResult<T> = core::result::Result<T, &'static str>;

11/// Handles the server-side of the [VBus-over-TCP][1] handshake.
12///
13/// [1]: http://danielwippermann.github.io/resol-vbus/vbus-over-tcp.html
14///
15/// # Examples
16///
17/// ```no_run
18/// # fn main() -> async_resol_vbus::Result<()> { async_std::task::block_on(async {
19/// #
20/// use async_std::net::{SocketAddr, TcpListener, TcpStream};
21///
22/// use async_resol_vbus::TcpServerHandshake;
23///
24/// let address = "0.0.0.0:7053".parse::<SocketAddr>()?;
25/// let listener = TcpListener::bind(address).await?;
26/// let (stream, _) = listener.accept().await?;
27/// let mut hs = TcpServerHandshake::start(stream).await?;
28/// let password = hs.receive_pass_command().await?;
29/// // ...
30/// let stream = hs.receive_data_command().await?;
31/// // ...
32/// #
33/// # Ok(()) }) }
34/// ```
35#[derive(Debug)]
36pub struct TcpServerHandshake {
37    stream: TcpStream,
38    buf: BlobBuffer,
39}
40
41impl TcpServerHandshake {
42    /// Start the VBus-over-TCP handshake as the server side.
43    pub async fn start(stream: TcpStream) -> Result<TcpServerHandshake> {
44        let mut hs = TcpServerHandshake {
45            stream,
46            buf: BlobBuffer::new(),
47        };
48
49        hs.send_reply("+HELLO\r\n").await?;
50
51        Ok(hs)
52    }
53
54    /// Consume `self` and return the underlying `TcpStream`.
55    pub fn into_inner(self) -> TcpStream {
56        self.stream
57    }
58
59    async fn send_reply(&mut self, reply: &str) -> Result<()> {
60        self.stream.write_all(reply.as_bytes()).await?;
61        Ok(())
62    }
63
64    async fn receive_line(&mut self) -> Result<String> {
65        let line = loop {
66            if let Some(idx) = self.buf.iter().position(|b| *b == 10) {
67                let string = std::str::from_utf8(&self.buf[0..idx])?.to_string();
68
69                self.buf.consume(idx + 1);
70
71                break string;
72            }
73
74            let mut buf = [0u8; 256];
75            let len = self.stream.read(&mut buf).await?;
76            if len == 0 {
77                return Err("Reached EOF".into());
78            }
79
80            self.buf.extend_from_slice(&buf[0..len]);
81        };
82
83        Ok(line)
84    }
85
86    /// Receive a command and verify it and its provided arguments. The
87    /// command reception is repeated as long as the verification fails.
88    ///
89    /// The preferred way to receive commands documented in the VBus-over-TCP
90    /// specification is through the `receive_xxx_command` and
91    /// `receive_xxx_command_and_verify_yyy` methods which use the
92    /// `receive_command` method internally.
93    ///
94    /// This method takes a validator function that is called with the
95    /// received command and its optional arguments. The validator
96    /// returns a `Future` that can resolve into an
97    /// `std::result::Result<T, &'static str>`. It can either be:
98    /// - `Ok(value)` if the validation succeeded. The `value` is used
99    ///   to resolve the `receive_command` `Future`.
100    /// - `Err(reply)` if the validation failed. The `reply` is send
101    ///   back to the client and the command reception is repeated.
102    pub async fn receive_command<V, R, T>(&mut self, validator: V) -> Result<T>
103    where
104        V: Fn(String, Option<String>) -> R,
105        R: Future<Output = FutureResult<T>>,
106    {
107        loop {
108            let line = self.receive_line().await?;
109            let line = line.trim();
110
111            let (command, args) = if let Some(idx) = line.chars().position(|c| c.is_whitespace()) {
112                let command = (&line[0..idx]).to_uppercase();
113                let args = (&line[idx..]).trim().to_string();
114                (command, Some(args))
115            } else {
116                (line.to_uppercase(), None)
117            };
118
119            let (reply, result) = if command == "QUIT" {
120                ("+OK\r\n", Some(Err("Received QUIT command".into())))
121            } else {
122                match validator(command, args).await {
123                    Ok(result) => ("+OK\r\n", Some(Ok(result))),
124                    Err(reply) => (reply, None),
125                }
126            };
127
128            self.send_reply(reply).await?;
129
130            if let Some(result) = result {
131                break result;
132            }
133        }
134    }
135
136    /// Wait for a `CONNECT <via_tag>` command. The via tag argument is returned.
137    pub async fn receive_connect_command(&mut self) -> Result<String> {
138        self.receive_connect_command_and_verify_via_tag(|via_tag| async move { Ok(via_tag) })
139            .await
140    }
141
142    /// Wait for a `CONNECT <via_tag>` command.
143    pub async fn receive_connect_command_and_verify_via_tag<V, R>(
144        &mut self,
145        validator: V,
146    ) -> Result<String>
147    where
148        V: Fn(String) -> R,
149        R: Future<Output = FutureResult<String>>,
150    {
151        self.receive_command(|command, args| {
152            let result = if command != "CONNECT" {
153                Err("-ERROR Expected CONNECT command\r\n")
154            } else if let Some(via_tag) = args {
155                Ok(validator(via_tag))
156            } else {
157                Err("-ERROR Expected argument\r\n")
158            };
159
160            async move {
161                match result {
162                    Ok(future) => future.await,
163                    Err(err) => Err(err),
164                }
165            }
166        })
167        .await
168    }
169
170    /// Wait for a `PASS <password>` command.
171    pub async fn receive_pass_command(&mut self) -> Result<String> {
172        self.receive_pass_command_and_verify_password(|password| async move { Ok(password) })
173            .await
174    }
175
176    /// Wait for a `PASS <password>` command and validate the provided password.
177    pub async fn receive_pass_command_and_verify_password<V, R>(
178        &mut self,
179        validator: V,
180    ) -> Result<String>
181    where
182        V: Fn(String) -> R,
183        R: Future<Output = FutureResult<String>>,
184    {
185        self.receive_command(|command, args| {
186            let result = if command != "PASS" {
187                Err("-ERROR Expected PASS command\r\n")
188            } else if let Some(password) = args {
189                Ok(validator(password))
190            } else {
191                Err("-ERROR Expected argument\r\n")
192            };
193
194            async move {
195                match result {
196                    Ok(future) => future.await,
197                    Err(err) => Err(err),
198                }
199            }
200        })
201        .await
202    }
203
204    /// Wait for a `CHANNEL <channel>` command.
205    pub async fn receive_channel_command(&mut self) -> Result<u8> {
206        self.receive_channel_command_and_verify_channel(|channel| async move { Ok(channel) })
207            .await
208    }
209
210    /// Wait for `CHANNEL <channel>` command and validate the provided channel
211    pub async fn receive_channel_command_and_verify_channel<V, R>(
212        &mut self,
213        validator: V,
214    ) -> Result<u8>
215    where
216        V: Fn(u8) -> R,
217        R: Future<Output = FutureResult<u8>>,
218    {
219        self.receive_command(|command, args| {
220            let result = if command != "CHANNEL" {
221                Err("-ERROR Expected CHANNEL command\r\n")
222            } else if let Some(channel) = args {
223                if let Ok(channel) = channel.parse() {
224                    Ok(validator(channel))
225                } else {
226                    Err("-ERROR Expected 8 bit number argument\r\n")
227                }
228            } else {
229                Err("-ERROR Expected argument\r\n")
230            };
231
232            async {
233                match result {
234                    Ok(future) => future.await,
235                    Err(err) => Err(err),
236                }
237            }
238        })
239        .await
240    }
241
242    /// Wait for a `DATA` command.
243    ///
244    /// This function returns the underlying `TcpStream` since the handshake is complete
245    /// after sending this command.
246    pub async fn receive_data_command(mut self) -> Result<TcpStream> {
247        self.receive_command(|command, args| {
248            let result = if command != "DATA" {
249                Err("-ERROR Expected DATA command\r\n")
250            } else if args.is_some() {
251                Err("-ERROR Unexpected argument\r\n")
252            } else {
253                Ok(())
254            };
255
256            async move { result }
257        })
258        .await?;
259
260        Ok(self.stream)
261    }
262}