1use crate::TunnelId;
2use crate::errors::{CmdErrorCode, CmdResult, cmd_err, into_cmd_err};
3use crate::peer_id::PeerId;
4use bucky_raw_codec::{RawDecode, RawEncode};
5use callback_result::SingleCallbackWaiter;
6use futures_lite::ready;
7use num::{FromPrimitive, ToPrimitive};
8use sfo_split::RHalf;
9use std::collections::HashMap;
10use std::fmt::Debug;
11use std::hash::Hash;
12use std::ops::DerefMut;
13use std::pin::Pin;
14use std::sync::{Arc, Mutex};
15use std::task::{Context, Poll};
16use std::{fmt, io};
17use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
18
/// Header of a command packet, serialized with the derived `RawEncode` / `RawDecode`.
///
/// NOTE(review): the derived codec presumably encodes fields in declaration order,
/// so reordering them would change the wire format — confirm against
/// bucky_raw_codec before touching the layout.
#[derive(RawEncode, RawDecode)]
pub struct CmdHeader<LEN, CMD> {
    // Declared packet length; its integer width is chosen by `LEN`. Which bytes it
    // covers (body only vs. header + body) is not visible here — TODO confirm.
    pkg_len: LEN,
    // Protocol version byte.
    version: u8,
    // Command code; `CmdHandlerMap` looks handlers up by this type.
    cmd_code: CMD,
    // `true` when this packet is a response rather than a request.
    is_resp: bool,
    // Optional sequence number; presumably pairs a response with its request.
    seq: Option<u32>,
}
27
28impl<
29 LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
30 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
31> CmdHeader<LEN, CMD>
32{
33 pub fn new(version: u8, is_resp: bool, seq: Option<u32>, cmd_code: CMD, pkg_len: LEN) -> Self {
34 Self {
35 pkg_len,
36 version,
37 seq,
38 cmd_code,
39 is_resp,
40 }
41 }
42
43 pub fn pkg_len(&self) -> LEN {
44 self.pkg_len
45 }
46
47 pub fn version(&self) -> u8 {
48 self.version
49 }
50
51 pub fn seq(&self) -> Option<u32> {
52 self.seq
53 }
54
55 pub fn is_resp(&self) -> bool {
56 self.is_resp
57 }
58
59 pub fn cmd_code(&self) -> CMD {
60 self.cmd_code
61 }
62
63 pub fn set_pkg_len(&mut self, pkg_len: LEN) {
64 self.pkg_len = pkg_len;
65 }
66}
67
68#[async_trait::async_trait]
69pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
70 async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
71}
72
73pub(crate) struct CmdBodyRead<
74 R: AsyncRead + Send + 'static + Unpin,
75 W: AsyncWrite + Send + 'static + Unpin,
76> {
77 recv: Option<RHalf<R, W>>,
78 len: usize,
79 offset: usize,
80 waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
81}
82
83impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
84 CmdBodyRead<R, W>
85{
86 pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
87 Self {
88 recv: Some(recv),
89 len,
90 offset: 0,
91 waiter: Arc::new(SingleCallbackWaiter::new()),
92 }
93 }
94
95 pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
96 self.waiter.clone()
97 }
98}
99
100#[async_trait::async_trait]
101impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll
102 for CmdBodyRead<R, W>
103{
104 async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
105 if self.offset == self.len {
106 return Ok(Vec::new());
107 }
108 let mut buf = vec![0u8; self.len - self.offset];
109 let ret = self
110 .recv
111 .as_mut()
112 .unwrap()
113 .read_exact(&mut buf)
114 .await
115 .map_err(into_cmd_err!(CmdErrorCode::IoError));
116 if ret.is_ok() {
117 self.offset = self.len;
118 self.waiter
119 .set_result_with_cache(Ok(self.recv.take().unwrap()));
120 Ok(buf)
121 } else {
122 self.recv.take();
123 self.waiter
124 .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
125 Err(ret.err().unwrap())
126 }
127 }
128}
129
130impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop
131 for CmdBodyRead<R, W>
132{
133 fn drop(&mut self) {
134 if self.recv.is_none() || (self.len == self.offset && self.len != 0) {
135 return;
136 }
137 let mut recv = self.recv.take().unwrap();
138 let len = self.len - self.offset;
139 let waiter = self.waiter.clone();
140 if len == 0 {
141 waiter.set_result_with_cache(Ok(recv));
142 return;
143 }
144
145 tokio::spawn(async move {
146 let mut buf = vec![0u8; len];
147 if let Err(e) = recv.read_exact(&mut buf).await {
148 waiter.set_result_with_cache(Err(cmd_err!(
149 CmdErrorCode::IoError,
150 "read body error {}",
151 e
152 )));
153 } else {
154 waiter.set_result_with_cache(Ok(recv));
155 }
156 });
157 }
158}
159
160impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>
161 tokio::io::AsyncRead for CmdBodyRead<R, W>
162{
163 fn poll_read(
164 self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &mut ReadBuf<'_>,
167 ) -> Poll<std::io::Result<()>> {
168 let this = Pin::into_inner(self);
169 let len = this.len - this.offset;
170 if len == 0 {
171 return Poll::Ready(Ok(()));
172 }
173 let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
174 let read_len = std::cmp::min(len, buf.remaining());
175 let mut read_buf = ReadBuf::new(buf.initialize_unfilled_to(read_len));
176 let fut = recv.poll_read(cx, &mut read_buf);
177 match fut {
178 Poll::Ready(Ok(())) => {
179 let len = read_buf.filled().len();
180 drop(read_buf);
181 this.offset += len;
182 buf.advance(len);
183 if this.offset == this.len {
184 this.waiter
185 .set_result_with_cache(Ok(this.recv.take().unwrap()));
186 }
187 Poll::Ready(Ok(()))
188 }
189 Poll::Ready(Err(e)) => {
190 this.recv.take();
191 this.waiter
192 .set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
193 Poll::Ready(Err(e))
194 }
195 Poll::Pending => Poll::Pending,
196 }
197 }
198}
199
/// Handler for a command code, stored in `CmdHandlerMap`.
///
/// `handle` receives the remote peer id, the id of the tunnel the command
/// arrived on, the decoded header and the request body.
/// NOTE(review): returning `Ok(Some(body))` presumably makes the dispatcher send
/// `body` back as the response, and `Ok(None)` sends nothing — confirm against
/// the dispatcher, which is not visible here.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where
    LEN: RawEncode
        + for<'a> RawDecode<'a>
        + Copy
        + Send
        + Sync
        + 'static
        + FromPrimitive
        + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static,
{
    async fn handle(
        &self,
        peer_id: PeerId,
        tunnel_id: TunnelId,
        header: CmdHeader<LEN, CMD>,
        body: CmdBody,
    ) -> CmdResult<Option<CmdBody>>;
}
221
222pub(crate) struct CmdHandlerMap<LEN, CMD> {
223 map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
224}
225
226impl<LEN, CMD> CmdHandlerMap<LEN, CMD>
227where
228 LEN: RawEncode
229 + for<'a> RawDecode<'a>
230 + Copy
231 + Send
232 + Sync
233 + 'static
234 + FromPrimitive
235 + ToPrimitive,
236 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash,
237{
238 pub fn new() -> Self {
239 Self {
240 map: Mutex::new(HashMap::new()),
241 }
242 }
243
244 pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
245 self.map.lock().unwrap().insert(cmd, Arc::new(handler));
246 }
247
248 pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
249 self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
250 }
251}
pin_project_lite::pin_project! {
/// Command body: a boxed buffered reader plus its declared length.
pub struct CmdBody {
    // Source of the body bytes.
    #[pin]
    reader: Box<dyn AsyncBufRead + Unpin + Send + 'static>,
    // Declared total size of the body in bytes.
    length: u64,
    // Bytes consumed so far; reads stop once this reaches `length`.
    bytes_read: u64,
    }
}
260
261impl CmdBody {
262 pub fn empty() -> Self {
263 Self {
264 reader: Box::new(tokio::io::empty()),
265 length: 0,
266 bytes_read: 0,
267 }
268 }
269
270 pub fn from_reader(reader: impl AsyncBufRead + Unpin + Send + 'static, length: u64) -> Self {
271 Self {
272 reader: Box::new(reader),
273 length,
274 bytes_read: 0,
275 }
276 }
277
278 pub fn into_reader(self) -> Box<dyn AsyncBufRead + Unpin + Send + 'static> {
279 self.reader
280 }
281
282 pub async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
283 let mut buf = Vec::with_capacity(1024);
284 self.read_to_end(&mut buf)
285 .await
286 .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
287 Ok(buf)
288 }
289
290 pub fn from_bytes(bytes: Vec<u8>) -> Self {
291 Self {
292 length: bytes.len() as u64,
293 reader: Box::new(io::Cursor::new(bytes)),
294 bytes_read: 0,
295 }
296 }
297
298 pub async fn into_bytes(mut self) -> CmdResult<Vec<u8>> {
299 let mut buf = Vec::with_capacity(1024);
300 self.read_to_end(&mut buf)
301 .await
302 .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to end failed"))?;
303 Ok(buf)
304 }
305
306 pub fn from_string(s: String) -> Self {
307 Self {
308 length: s.len() as u64,
309 reader: Box::new(io::Cursor::new(s.into_bytes())),
310 bytes_read: 0,
311 }
312 }
313
314 pub async fn into_string(mut self) -> CmdResult<String> {
315 let mut result = String::with_capacity(self.len() as usize);
316 self.read_to_string(&mut result)
317 .await
318 .map_err(into_cmd_err!(CmdErrorCode::Failed, "read to string failed"))?;
319 Ok(result)
320 }
321
322 pub async fn from_path<P>(path: P) -> io::Result<Self>
323 where
324 P: AsRef<std::path::Path>,
325 {
326 let path = path.as_ref();
327 let file = tokio::fs::File::open(path).await?;
328 Self::from_file(file).await
329 }
330
331 pub async fn from_file(file: tokio::fs::File) -> io::Result<Self> {
332 let len = file.metadata().await?.len();
333
334 Ok(Self {
335 length: len,
336 reader: Box::new(tokio::io::BufReader::new(file)),
337 bytes_read: 0,
338 })
339 }
340
341 pub fn len(&self) -> u64 {
342 self.length
343 }
344
345 pub fn is_empty(&self) -> bool {
347 self.length == 0
348 }
349
350 pub fn chain(self, other: CmdBody) -> Self {
351 let length = (self.length - self.bytes_read)
352 .checked_add(other.length - other.bytes_read)
353 .unwrap_or(0);
354 Self {
355 length,
356 reader: Box::new(tokio::io::AsyncReadExt::chain(self, other)),
357 bytes_read: 0,
358 }
359 }
360}
361
362impl Debug for CmdBody {
363 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
364 f.debug_struct("CmdResponse")
365 .field("reader", &"<hidden>")
366 .field("length", &self.length)
367 .field("bytes_read", &self.bytes_read)
368 .finish()
369 }
370}
371
372impl From<String> for CmdBody {
373 fn from(s: String) -> Self {
374 Self::from_string(s)
375 }
376}
377
378impl<'a> From<&'a str> for CmdBody {
379 fn from(s: &'a str) -> Self {
380 Self::from_string(s.to_owned())
381 }
382}
383
384impl From<Vec<u8>> for CmdBody {
385 fn from(b: Vec<u8>) -> Self {
386 Self::from_bytes(b)
387 }
388}
389
390impl<'a> From<&'a [u8]> for CmdBody {
391 fn from(b: &'a [u8]) -> Self {
392 Self::from_bytes(b.to_owned())
393 }
394}
395
396impl AsyncRead for CmdBody {
397 #[allow(rustdoc::missing_doc_code_examples)]
398 fn poll_read(
399 mut self: Pin<&mut Self>,
400 cx: &mut Context<'_>,
401 buf: &mut ReadBuf<'_>,
402 ) -> Poll<io::Result<()>> {
403 let buf = if self.length == self.bytes_read {
404 return Poll::Ready(Ok(()));
405 } else {
406 buf
407 };
408
409 ready!(Pin::new(&mut self.reader).poll_read(cx, buf))?;
410 self.bytes_read += buf.filled().len() as u64;
411 Poll::Ready(Ok(()))
412 }
413}
414
415impl AsyncBufRead for CmdBody {
416 #[allow(rustdoc::missing_doc_code_examples)]
417 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&'_ [u8]>> {
418 self.project().reader.poll_fill_buf(cx)
419 }
420
421 fn consume(mut self: Pin<&mut Self>, amt: usize) {
422 Pin::new(&mut self.reader).consume(amt)
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::{CmdBody, CmdBodyRead, CmdBodyReadAll, CmdHeader};
    use crate::{CmdTunnel, CmdTunnelRead, CmdTunnelWrite, PeerId};
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{
        AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadBuf, split,
    };

    /// Tunnel read side backed by one half of an in-memory duplex pipe.
    struct TestRead {
        read: tokio::io::ReadHalf<DuplexStream>,
    }

    impl AsyncRead for TestRead {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            let inner = &mut self.get_mut().read;
            Pin::new(inner).poll_read(cx, buf)
        }
    }

    impl CmdTunnelRead<()> for TestRead {
        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![1; 32])
        }
    }

    /// Tunnel write side backed by the other half of the duplex pipe.
    struct TestWrite {
        write: tokio::io::WriteHalf<DuplexStream>,
    }

    impl AsyncWrite for TestWrite {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            let inner = &mut self.get_mut().write;
            Pin::new(inner).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            let inner = &mut self.get_mut().write;
            Pin::new(inner).poll_flush(cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<std::io::Result<()>> {
            let inner = &mut self.get_mut().write;
            Pin::new(inner).poll_shutdown(cx)
        }
    }

    impl CmdTunnelWrite<()> for TestWrite {
        fn get_remote_peer_id(&self) -> PeerId {
            PeerId::from(vec![2; 32])
        }
    }

    #[tokio::test]
    async fn cmd_body_bytes_round_trip() {
        let decoded = CmdBody::from_bytes(b"hello-body".to_vec())
            .into_bytes()
            .await
            .unwrap();
        assert_eq!(decoded, b"hello-body");
    }

    #[tokio::test]
    async fn cmd_body_string_round_trip() {
        let text = CmdBody::from_string("hello-string".to_owned())
            .into_string()
            .await
            .unwrap();
        assert_eq!(text, "hello-string");
    }

    #[tokio::test]
    async fn cmd_body_chain_respects_consumed_prefix() {
        let mut head = CmdBody::from_bytes(b"abc".to_vec());
        let mut first_byte = [0u8; 1];
        head.read_exact(&mut first_byte).await.unwrap();
        assert_eq!(&first_byte, b"a");

        // Only the unread "bc" of `head` should precede the second body.
        let joined = head
            .chain(CmdBody::from_bytes(b"XYZ".to_vec()))
            .into_string()
            .await
            .unwrap();
        assert_eq!(joined, "bcXYZ");
    }

    #[test]
    fn cmd_body_empty_and_len() {
        let nothing = CmdBody::empty();
        assert!(nothing.is_empty());
        assert_eq!(nothing.len(), 0);

        let four = CmdBody::from_bytes(vec![1, 2, 3, 4]);
        assert!(!four.is_empty());
        assert_eq!(four.len(), 4);
    }

    #[tokio::test]
    async fn cmd_body_into_reader_and_read_all() {
        let mut whole = CmdBody::from_string("reader-body".to_owned());
        let collected = whole.read_all().await.unwrap();
        assert_eq!(collected, b"reader-body");

        let mut raw = CmdBody::from_string("reader-body2".to_owned()).into_reader();
        let mut sink = Vec::new();
        raw.read_to_end(&mut sink).await.unwrap();
        assert_eq!(sink, b"reader-body2");
    }

    #[test]
    fn cmd_header_set_pkg_len() {
        let mut hdr = CmdHeader::<u16, u8>::new(1, false, Some(7), 0x11, 3);
        assert_eq!(hdr.pkg_len(), 3);
        hdr.set_pkg_len(9);
        assert_eq!(hdr.pkg_len(), 9);
    }

    #[tokio::test]
    async fn cmd_body_read_all_success_and_empty_after_read() {
        let (local, remote) = tokio::io::duplex(128);
        let (local_read, local_write) = split(local);
        let (_remote_read, mut remote_write) = split(remote);
        remote_write.write_all(b"abcdef").await.unwrap();
        remote_write.flush().await.unwrap();

        let tunnel = CmdTunnel::new(TestRead { read: local_read }, TestWrite { write: local_write });
        let (reader, _writer) = tunnel.split();
        let mut body = CmdBodyRead::new(reader, 6);

        let everything = body.read_all().await.unwrap();
        assert_eq!(everything, b"abcdef");
        // A second drain of an exhausted body yields nothing.
        let leftover = body.read_all().await.unwrap();
        assert!(leftover.is_empty());
    }

    #[tokio::test]
    async fn cmd_body_read_all_error_when_source_short() {
        let (local, remote) = tokio::io::duplex(128);
        let (local_read, local_write) = split(local);
        let (_remote_read, mut remote_write) = split(remote);
        // Only 2 of the 5 declared bytes arrive before the writer shuts down.
        remote_write.write_all(b"ab").await.unwrap();
        remote_write.shutdown().await.unwrap();

        let tunnel = CmdTunnel::new(TestRead { read: local_read }, TestWrite { write: local_write });
        let (reader, _writer) = tunnel.split();
        let mut body = CmdBodyRead::new(reader, 5);
        assert!(body.read_all().await.is_err());
    }
}