1use std::{
2 collections::HashMap,
3 io::Cursor,
4 pin::Pin,
5 sync::Arc,
6 task::{Context, Poll},
7};
8
9use async_trait::async_trait;
10use bytes::{Buf, BytesMut};
11use std::future::Future;
12use thiserror::Error;
13use tokio::{
14 io::{AsyncReadExt, AsyncWriteExt},
15 net::{
16 tcp::{OwnedReadHalf, OwnedWriteHalf},
17 TcpStream,
18 },
19 sync::{broadcast, oneshot, Mutex},
20};
21
/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
23
/// Errors surfaced by the connection and proxy types in this module.
#[derive(Error, Debug)]
pub enum Error {
    /// An I/O error raised while reading from or writing to a socket.
    /// Displays the underlying error's message unchanged.
    #[error("{0}")]
    Io(#[from] std::io::Error),
    /// The peer or connection is not available, e.g. no write connection is
    /// registered for a key, or the stream ended while a partial packet was
    /// still buffered.
    #[error("channel disconnected")]
    Disconnect,
}
31
32#[derive(Error, Debug)]
33pub enum PacketError {
34 #[error("")]
35 Incomplete,
36 #[error("{0}")]
37 Other(String),
38}
39
/// A packet that carries a string identifier and can be converted to and from
/// raw bytes.
///
/// The identifier is what the proxy uses to pair a received packet with a
/// request that is waiting for a response.
#[async_trait]
pub trait IdentifierPacket: Sized + Send + 'static {
    /// Returns the identifier of this packet.
    fn id(&self) -> &str;
    /// Tries to decode one packet from `src`.
    ///
    /// Returns `Ok(Some(packet))` when a packet was decoded; the reader then
    /// discards `src.position()` bytes from its buffer, so implementations must
    /// leave the cursor just past the consumed bytes. Returns `Ok(None)` when
    /// no packet could be decoded yet; the reader then reads more data from
    /// the socket and calls this again with the enlarged buffer.
    async fn from_slice(src: &mut Cursor<&[u8]>) -> Result<Option<Self>>;
    /// Consumes the packet and returns the bytes to write to the socket.
    fn into_bytes(self) -> Vec<u8>;
}
47
/// Cloneable handle to the write half of a TCP connection.
///
/// Clones share the same underlying write half through an `Arc<Mutex<_>>`.
#[derive(Debug, Clone)]
pub struct WriteConnection {
    // Shared, mutex-guarded write half of the TCP stream.
    inner: Arc<Mutex<OwnedWriteHalf>>,
}
52
53impl WriteConnection {
54 pub fn new(write: OwnedWriteHalf) -> Self {
55 Self {
56 inner: Arc::new(Mutex::new(write)),
57 }
58 }
59
60 pub async fn write<P: IdentifierPacket>(&self, packet: P) -> Result<()> {
61 let bytes: Vec<u8> = packet.into_bytes();
62 self.inner.lock().await.write_all(bytes.as_slice()).await?;
63 Ok(())
64 }
65}
66
/// Handler for packets that are not responses to a pending request.
///
/// A `ReadConnection` calls [`PacketProcessor::process`] for every received
/// packet whose identifier has no registered response waiter.
#[async_trait]
pub trait PacketProcessor<P>: Sync + Send + Clone + 'static
where
    P: IdentifierPacket,
{
    /// Handles one received packet. The read loop awaits this call before it
    /// reads the next packet.
    async fn process(&self, packet: P);
}
74
/// Read side of a TCP connection: decodes packets from the socket and routes
/// each one either to a waiting requester or to the packet processor.
pub struct ReadConnection<P, PP>
where
    P: IdentifierPacket,
    PP: PacketProcessor<P>,
{
    // Label for this connection, included in the "receive bytes" diagnostic.
    identifer: String,
    // Read half of the TCP stream.
    inner: OwnedReadHalf,
    // Pending response waiters, keyed by packet identifier.
    response_notify: Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>,
    // Receives every packet that has no pending waiter.
    processor: PP,
    // Bytes read from the socket that have not been decoded into a packet yet.
    buffer: BytesMut,
}
86
87impl<P, PP> ReadConnection<P, PP>
88where
89 P: IdentifierPacket,
90 PP: PacketProcessor<P>,
91{
92 pub fn new(
93 identifer: &str,
94 connection: OwnedReadHalf,
95 processor: PP,
96 notify: Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>,
97 capacity: usize,
98 ) -> Self {
99 Self {
100 identifer: identifer.to_string(),
101 inner: connection,
102 response_notify: notify,
103 processor,
104 buffer: BytesMut::with_capacity(capacity),
105 }
106 }
107
108 pub async fn run(&mut self, mut shutdown: broadcast::Receiver<()>) -> Result<()> {
109 loop {
110 let maybe_packet = tokio::select! {
111 res = self.read() => res?,
112 _ = shutdown.recv() => return Ok(())
113 };
114
115 let packet = match maybe_packet {
116 Some(packet) => packet,
117 None => return Ok(()),
118 };
119
120 let id = packet.id();
121 let notify = self.response_notify.lock().await.remove(id);
122 match notify {
123 Some(notify) => {
124 if notify.send(packet).is_err() {
125 logs::warn!("notify channel has closed, can not notify sender");
127 }
128 }
129 None => self.processor.process(packet).await,
130 }
131 }
132 }
133
134 async fn read(&mut self) -> Result<Option<P>> {
135 loop {
136 let mut buf = Cursor::new(&self.buffer[..]);
137 if let Some(packet) = <P>::from_slice(&mut buf).await? {
138 let len = buf.position() as usize;
139 self.buffer.advance(len);
140 return Ok(Some(packet));
141 }
142
143 let size = self.inner.read_buf(&mut self.buffer).await?;
144 if size == 0 {
145 if self.buffer.is_empty() {
146 return Ok(None);
147 } else {
148 return Err(Error::Disconnect);
149 }
150 }
151 println!("server[{}] receive bytes: {}", self.identifer, size);
152 }
153 }
154}
155
/// Pending response waiters for all connections: the outer map is keyed by
/// connection key, the inner map by packet identifier.
type NotifyMap<P> = Arc<Mutex<HashMap<String, Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>>>>;
157#[derive(Clone)]
158pub struct BiProxy<P>
159where
160 P: IdentifierPacket,
161{
162 read_buffer_size: usize,
163 write_conns: Arc<Mutex<HashMap<String, WriteConnection>>>,
164 notify: NotifyMap<P>,
165 shutdown: broadcast::Sender<()>,
166 _a: std::marker::PhantomData<P>,
167}
168
169impl<P> BiProxy<P>
170where
171 P: IdentifierPacket,
172{
173 pub fn new(read_buffer_size: usize) -> Self {
174 Self {
175 read_buffer_size,
176 write_conns: Arc::new(Mutex::new(HashMap::new())),
177 notify: Arc::new(Mutex::new(HashMap::new())),
178 shutdown: broadcast::channel(1).0,
179 _a: std::marker::PhantomData::default(),
180 }
181 }
182
183 pub async fn accept_tcpstream<PP: PacketProcessor<P> + Clone>(
184 &self,
185 key: &str,
186 stream: TcpStream,
187 processor: PP,
188 ) {
189 let (read, write) = stream.into_split();
190 self.write_conns
191 .lock()
192 .await
193 .insert(key.to_string(), WriteConnection::new(write));
194
195 let notify_map = Arc::new(Mutex::new(HashMap::new()));
196 self.notify
197 .lock()
198 .await
199 .insert(key.to_string(), notify_map.clone());
200 self.spawn_read_task(key, notify_map, read, processor).await;
201 }
202
203 pub async fn send_oneway(&self, key: &str, data: P) -> Result<()> {
204 self.get_connection(key)
205 .await
206 .ok_or(Error::Disconnect)?
207 .write(data)
208 .await
209 }
210
211 pub async fn send(&self, key: &str, data: P) -> Result<P> {
212 let id = data.id().to_string();
213 let (sender, receiver) = oneshot::channel();
214 self.register_request_notify(key, id.as_str(), sender).await;
215
216 if let Err(error) = self.send_oneway(key, data).await {
217 self.deregister_request_notify(key, id.as_str()).await;
218 return Err(error);
219 }
220
221 let res = Response::new(receiver).await;
222 Ok(res)
223 }
224
225 async fn spawn_read_task<PP: PacketProcessor<P> + Clone>(
226 &self,
227 key: &str,
228 notify: Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>,
229 conn: OwnedReadHalf,
230 processor: PP,
231 ) {
232 let mut read_conn =
233 ReadConnection::new(key, conn, processor, notify, self.read_buffer_size);
234 let shutdown = self.shutdown.subscribe();
235 tokio::spawn(async move {
236 if let Err(error) = read_conn.run(shutdown).await {
237 logs::error!("connection reset by peer, closed: {}", error);
238 }
239 });
240 }
241
242 async fn deregister_request_notify(&self, key: &str, packet_id: &str) {
243 let notify_map = self.notify.lock().await;
244 if let Some(notify_map) = notify_map.get(key) {
245 let _ = notify_map.lock().await.remove(packet_id);
247 }
248 }
249
250 async fn register_request_notify(
251 &self,
252 key: &str,
253 packet_id: &str,
254 sender: oneshot::Sender<P>,
255 ) {
256 let mut notify_map = self.notify.lock().await;
257
258 if !notify_map.contains_key(key) {
259 notify_map.insert(key.to_string(), Arc::new(Mutex::new(HashMap::new())));
260 }
261
262 let inner_notifys = match notify_map.get(key) {
263 Some(read_notify) => read_notify,
264 None => unreachable!(),
265 };
266
267 inner_notifys
268 .lock()
269 .await
270 .insert(packet_id.to_string(), sender);
271 }
272
273 async fn get_connection(&self, key: &str) -> Option<WriteConnection> {
274 self.write_conns.lock().await.get(key).cloned()
275 }
276}
277
278struct Response<P> {
279 notify: oneshot::Receiver<P>,
280}
281
282impl<P> Response<P> {
283 fn new(notify: oneshot::Receiver<P>) -> Self {
284 Self { notify }
285 }
286}
287
288impl<P> Future for Response<P> {
289 type Output = P;
290 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
291 match Pin::new(&mut self.notify).poll(cx) {
292 Poll::Ready(Ok(value)) => Poll::Ready(value),
293 Poll::Pending => Poll::Pending,
294 Poll::Ready(Err(_)) => unreachable!(),
295 }
296 }
297}