1use std::collections::HashMap;
2use std::hash::Hash;
3use std::ops::DerefMut;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use bucky_raw_codec::{RawDecode, RawEncode, RawFixedBytes};
8use callback_result::{SingleCallbackWaiter};
9use num::{FromPrimitive, ToPrimitive};
10use sfo_split::RHalf;
11use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
12use crate::errors::{cmd_err, into_cmd_err, CmdErrorCode, CmdResult};
13use crate::peer_id::PeerId;
14use crate::{TunnelId};
15
16#[derive(RawEncode, RawDecode)]
17pub struct CmdHeader<LEN, CMD> {
18 pkg_len: LEN,
19 version: u8,
20 cmd_code: CMD,
21}
22
23impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
24 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static> CmdHeader<LEN, CMD> {
25 pub fn new(version: u8, cmd_code: CMD, pkg_len: LEN) -> Self {
26 Self {
27 pkg_len,
28 version,
29 cmd_code,
30 }
31 }
32
33 pub fn pkg_len(&self) -> LEN {
34 self.pkg_len
35 }
36
37 pub fn version(&self) -> u8 {
38 self.version
39 }
40
41 pub fn cmd_code(&self) -> CMD {
42 self.cmd_code
43 }
44
45 pub fn set_pkg_len(&mut self, pkg_len: LEN) {
46 self.pkg_len = pkg_len;
47 }
48}
49
50impl<LEN: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes,
51 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + RawFixedBytes> RawFixedBytes for CmdHeader<LEN, CMD> {
52 fn raw_bytes() -> Option<usize> {
53 Some(LEN::raw_bytes().unwrap() + u8::raw_bytes().unwrap() + CMD::raw_bytes().unwrap())
54 }
55}
56
57#[async_trait::async_trait]
58pub trait CmdBodyReadAll: tokio::io::AsyncRead + Send + 'static {
59 async fn read_all(&mut self) -> CmdResult<Vec<u8>>;
60}
61pub type CmdBodyRead = Box<dyn CmdBodyReadAll>;
62
63pub(crate) struct CmdBodyReadImpl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> {
64 recv: Option<RHalf<R, W>>,
65 len: usize,
66 offset: usize,
67 waiter: Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>>,
68}
69
70impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadImpl<R, W> {
71 pub fn new(recv: RHalf<R, W>, len: usize) -> Self {
72 Self {
73 recv: Some(recv),
74 len,
75 offset: 0,
76 waiter: Arc::new(SingleCallbackWaiter::new()),
77 }
78 }
79
80
81 pub(crate) fn get_waiter(&self) -> Arc<SingleCallbackWaiter<CmdResult<RHalf<R, W>>>> {
82 self.waiter.clone()
83 }
84}
85
86#[async_trait::async_trait]
87impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> CmdBodyReadAll for CmdBodyReadImpl<R, W> {
88 async fn read_all(&mut self) -> CmdResult<Vec<u8>> {
89 if self.offset == self.len {
90 return Ok(Vec::new());
91 }
92 let mut buf = vec![0u8; self.len - self.offset];
93 let ret = self.recv.as_mut().unwrap().read_exact(&mut buf).await.map_err(into_cmd_err!(CmdErrorCode::IoError));
94 if ret.is_ok() {
95 self.offset = self.len;
96 self.waiter.set_result_with_cache(Ok(self.recv.take().unwrap()));
97 Ok(buf)
98 } else {
99 self.recv.take();
100 self.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
101 Err(ret.err().unwrap())
102 }
103 }
104}
105
106impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> Drop for CmdBodyReadImpl<R, W> {
107 fn drop(&mut self) {
108 if self.recv.is_none() || self.len == self.offset {
109 return;
110 }
111 let mut recv = self.recv.take().unwrap();
112 let len = self.len - self.offset;
113 let waiter = self.waiter.clone();
114 if len == 0 {
115 waiter.set_result_with_cache(Ok(recv));
116 return;
117 }
118
119 tokio::spawn(async move {
120 let mut buf = vec![0u8; len];
121 if let Err(e) = recv.read_exact(&mut buf).await {
122 waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error {}", e)));
123 } else {
124 waiter.set_result_with_cache(Ok(recv));
125 }
126 });
127 }
128}
129
130impl<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin> tokio::io::AsyncRead for CmdBodyReadImpl<R, W> {
131 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
132 let this = Pin::into_inner(self);
133 let len = this.len - this.offset;
134 if len == 0 {
135 return Poll::Ready(Ok(()));
136 }
137 let buf = buf.initialize_unfilled();
138 let mut buf = ReadBuf::new(&mut buf[..len]);
139 let recv = Pin::new(this.recv.as_mut().unwrap().deref_mut());
140 let fut = recv.poll_read(cx, &mut buf);
141 match fut {
142 Poll::Ready(Ok(())) => {
143 this.offset += buf.filled().len();
144 if this.offset == this.len {
145 this.waiter.set_result_with_cache(Ok(this.recv.take().unwrap()));
146 }
147 Poll::Ready(Ok(()))
148 }
149 Poll::Ready(Err(e)) => {
150 this.recv.take();
151 this.waiter.set_result_with_cache(Err(cmd_err!(CmdErrorCode::IoError, "read body error")));
152 Poll::Ready(Err(e))
153 },
154 Poll::Pending => Poll::Pending,
155 }
156 }
157}
158
159
/// Handler invoked for a received command.
///
/// Implementations are stored as `Arc<dyn CmdHandler<LEN, CMD>>` in
/// `CmdHandlerMap`, keyed by command code.
/// NOTE(review): `#[callback_trait::callback_trait]` presumably also lets
/// closures/async fns be used as handlers — confirm against that crate.
#[callback_trait::callback_trait]
pub trait CmdHandler<LEN, CMD>: Send + Sync + 'static
where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
    CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static {
    /// Handles one command from `peer_id` received on `tunnel_id`.
    ///
    /// `header` is the decoded command header; `body` streams the command body.
    /// NOTE(review): when `body` is a `CmdBodyReadImpl`, bytes left unread are
    /// drained when it is dropped — confirm that is the only implementation.
    async fn handle(&self, peer_id: PeerId, tunnel_id: TunnelId, header: CmdHeader<LEN, CMD>, body: CmdBodyRead) -> CmdResult<()>;
}
166
167pub(crate) struct CmdHandlerMap<LEN, CMD> {
168 map: Mutex<HashMap<CMD, Arc<dyn CmdHandler<LEN, CMD>>>>,
169}
170
171impl <LEN, CMD> CmdHandlerMap<LEN, CMD>
172where LEN: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + FromPrimitive + ToPrimitive,
173 CMD: RawEncode + for<'a> RawDecode<'a> + Copy + Send + Sync + 'static + Eq + Hash {
174 pub fn new() -> Self {
175 Self {
176 map: Mutex::new(HashMap::new()),
177 }
178 }
179
180 pub fn insert(&self, cmd: CMD, handler: impl CmdHandler<LEN, CMD>) {
181 self.map.lock().unwrap().insert(cmd, Arc::new(handler));
182 }
183
184 pub fn get(&self, cmd: CMD) -> Option<Arc<dyn CmdHandler<LEN, CMD>>> {
185 self.map.lock().unwrap().get(&cmd).map(|v| v.clone())
186 }
187}