async_resol_vbus/
tcp_server_handshake.rs1use std::future::Future;
2
3use async_std::{net::TcpStream, prelude::*};
4
5use resol_vbus::BlobBuffer;
6
7use crate::error::Result;
8
/// Result produced by a command validator future.
///
/// `Ok` carries the accepted value; `Err` carries the reply line that is
/// written back to the client before the next command line is read.
pub type FutureResult<Value> = std::result::Result<Value, &'static str>;

11#[derive(Debug)]
36pub struct TcpServerHandshake {
37 stream: TcpStream,
38 buf: BlobBuffer,
39}
40
41impl TcpServerHandshake {
42 pub async fn start(stream: TcpStream) -> Result<TcpServerHandshake> {
44 let mut hs = TcpServerHandshake {
45 stream,
46 buf: BlobBuffer::new(),
47 };
48
49 hs.send_reply("+HELLO\r\n").await?;
50
51 Ok(hs)
52 }
53
54 pub fn into_inner(self) -> TcpStream {
56 self.stream
57 }
58
59 async fn send_reply(&mut self, reply: &str) -> Result<()> {
60 self.stream.write_all(reply.as_bytes()).await?;
61 Ok(())
62 }
63
64 async fn receive_line(&mut self) -> Result<String> {
65 let line = loop {
66 if let Some(idx) = self.buf.iter().position(|b| *b == 10) {
67 let string = std::str::from_utf8(&self.buf[0..idx])?.to_string();
68
69 self.buf.consume(idx + 1);
70
71 break string;
72 }
73
74 let mut buf = [0u8; 256];
75 let len = self.stream.read(&mut buf).await?;
76 if len == 0 {
77 return Err("Reached EOF".into());
78 }
79
80 self.buf.extend_from_slice(&buf[0..len]);
81 };
82
83 Ok(line)
84 }
85
86 pub async fn receive_command<V, R, T>(&mut self, validator: V) -> Result<T>
103 where
104 V: Fn(String, Option<String>) -> R,
105 R: Future<Output = FutureResult<T>>,
106 {
107 loop {
108 let line = self.receive_line().await?;
109 let line = line.trim();
110
111 let (command, args) = if let Some(idx) = line.chars().position(|c| c.is_whitespace()) {
112 let command = (&line[0..idx]).to_uppercase();
113 let args = (&line[idx..]).trim().to_string();
114 (command, Some(args))
115 } else {
116 (line.to_uppercase(), None)
117 };
118
119 let (reply, result) = if command == "QUIT" {
120 ("+OK\r\n", Some(Err("Received QUIT command".into())))
121 } else {
122 match validator(command, args).await {
123 Ok(result) => ("+OK\r\n", Some(Ok(result))),
124 Err(reply) => (reply, None),
125 }
126 };
127
128 self.send_reply(reply).await?;
129
130 if let Some(result) = result {
131 break result;
132 }
133 }
134 }
135
136 pub async fn receive_connect_command(&mut self) -> Result<String> {
138 self.receive_connect_command_and_verify_via_tag(|via_tag| async move { Ok(via_tag) })
139 .await
140 }
141
142 pub async fn receive_connect_command_and_verify_via_tag<V, R>(
144 &mut self,
145 validator: V,
146 ) -> Result<String>
147 where
148 V: Fn(String) -> R,
149 R: Future<Output = FutureResult<String>>,
150 {
151 self.receive_command(|command, args| {
152 let result = if command != "CONNECT" {
153 Err("-ERROR Expected CONNECT command\r\n")
154 } else if let Some(via_tag) = args {
155 Ok(validator(via_tag))
156 } else {
157 Err("-ERROR Expected argument\r\n")
158 };
159
160 async move {
161 match result {
162 Ok(future) => future.await,
163 Err(err) => Err(err),
164 }
165 }
166 })
167 .await
168 }
169
170 pub async fn receive_pass_command(&mut self) -> Result<String> {
172 self.receive_pass_command_and_verify_password(|password| async move { Ok(password) })
173 .await
174 }
175
176 pub async fn receive_pass_command_and_verify_password<V, R>(
178 &mut self,
179 validator: V,
180 ) -> Result<String>
181 where
182 V: Fn(String) -> R,
183 R: Future<Output = FutureResult<String>>,
184 {
185 self.receive_command(|command, args| {
186 let result = if command != "PASS" {
187 Err("-ERROR Expected PASS command\r\n")
188 } else if let Some(password) = args {
189 Ok(validator(password))
190 } else {
191 Err("-ERROR Expected argument\r\n")
192 };
193
194 async move {
195 match result {
196 Ok(future) => future.await,
197 Err(err) => Err(err),
198 }
199 }
200 })
201 .await
202 }
203
204 pub async fn receive_channel_command(&mut self) -> Result<u8> {
206 self.receive_channel_command_and_verify_channel(|channel| async move { Ok(channel) })
207 .await
208 }
209
210 pub async fn receive_channel_command_and_verify_channel<V, R>(
212 &mut self,
213 validator: V,
214 ) -> Result<u8>
215 where
216 V: Fn(u8) -> R,
217 R: Future<Output = FutureResult<u8>>,
218 {
219 self.receive_command(|command, args| {
220 let result = if command != "CHANNEL" {
221 Err("-ERROR Expected CHANNEL command\r\n")
222 } else if let Some(channel) = args {
223 if let Ok(channel) = channel.parse() {
224 Ok(validator(channel))
225 } else {
226 Err("-ERROR Expected 8 bit number argument\r\n")
227 }
228 } else {
229 Err("-ERROR Expected argument\r\n")
230 };
231
232 async {
233 match result {
234 Ok(future) => future.await,
235 Err(err) => Err(err),
236 }
237 }
238 })
239 .await
240 }
241
242 pub async fn receive_data_command(mut self) -> Result<TcpStream> {
247 self.receive_command(|command, args| {
248 let result = if command != "DATA" {
249 Err("-ERROR Expected DATA command\r\n")
250 } else if args.is_some() {
251 Err("-ERROR Unexpected argument\r\n")
252 } else {
253 Ok(())
254 };
255
256 async move { result }
257 })
258 .await?;
259
260 Ok(self.stream)
261 }
262}