concurrent_biproxy/
lib.rs

1use std::{
2    collections::HashMap,
3    io::Cursor,
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7};
8
9use async_trait::async_trait;
10use bytes::{Buf, BytesMut};
11use std::future::Future;
12use thiserror::Error;
13use tokio::{
14    io::{AsyncReadExt, AsyncWriteExt},
15    net::{
16        tcp::{OwnedReadHalf, OwnedWriteHalf},
17        TcpStream,
18    },
19    sync::{broadcast, oneshot, Mutex},
20};
21
22pub type Result<T> = std::result::Result<T, Error>;
23
24#[derive(Error, Debug)]
25pub enum Error {
26    #[error("{0}")]
27    Io(#[from] std::io::Error),
28    #[error("channel disconnected")]
29    Disconnect,
30}
31
32#[derive(Error, Debug)]
33pub enum PacketError {
34    #[error("")]
35    Incomplete,
36    #[error("{0}")]
37    Other(String),
38}
39
40#[async_trait]
41pub trait IdentifierPacket: Sized + Send + 'static {
42    fn id(&self) -> &str;
43    /// return Ok(None) if need more bytes
44    async fn from_slice(src: &mut Cursor<&[u8]>) -> Result<Option<Self>>;
45    fn into_bytes(self) -> Vec<u8>;
46}
47
48#[derive(Debug, Clone)]
49pub struct WriteConnection {
50    inner: Arc<Mutex<OwnedWriteHalf>>,
51}
52
53impl WriteConnection {
54    pub fn new(write: OwnedWriteHalf) -> Self {
55        Self {
56            inner: Arc::new(Mutex::new(write)),
57        }
58    }
59
60    pub async fn write<P: IdentifierPacket>(&self, packet: P) -> Result<()> {
61        let bytes: Vec<u8> = packet.into_bytes();
62        self.inner.lock().await.write_all(bytes.as_slice()).await?;
63        Ok(())
64    }
65}
66
67#[async_trait]
68pub trait PacketProcessor<P>: Sync + Send + Clone + 'static
69where
70    P: IdentifierPacket,
71{
72    async fn process(&self, packet: P);
73}
74
75pub struct ReadConnection<P, PP>
76where
77    P: IdentifierPacket,
78    PP: PacketProcessor<P>,
79{
80    identifer: String,
81    inner: OwnedReadHalf,
82    response_notify: Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>,
83    processor: PP,
84    buffer: BytesMut,
85}
86
87impl<P, PP> ReadConnection<P, PP>
88where
89    P: IdentifierPacket,
90    PP: PacketProcessor<P>,
91{
92    pub fn new(
93        identifer: &str,
94        connection: OwnedReadHalf,
95        processor: PP,
96        notify: Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>,
97        capacity: usize,
98    ) -> Self {
99        Self {
100            identifer: identifer.to_string(),
101            inner: connection,
102            response_notify: notify,
103            processor,
104            buffer: BytesMut::with_capacity(capacity),
105        }
106    }
107
108    pub async fn run(&mut self, mut shutdown: broadcast::Receiver<()>) -> Result<()> {
109        loop {
110            let maybe_packet = tokio::select! {
111                res = self.read() => res?,
112                _ = shutdown.recv() => return Ok(())
113            };
114
115            let packet = match maybe_packet {
116                Some(packet) => packet,
117                None => return Ok(()),
118            };
119
120            let id = packet.id();
121            let notify = self.response_notify.lock().await.remove(id);
122            match notify {
123                Some(notify) => {
124                    if notify.send(packet).is_err() {
125                        // TODO maybe we should log the message's content
126                        logs::warn!("notify channel has closed, can not notify sender");
127                    }
128                }
129                None => self.processor.process(packet).await,
130            }
131        }
132    }
133
134    async fn read(&mut self) -> Result<Option<P>> {
135        loop {
136            let mut buf = Cursor::new(&self.buffer[..]);
137            if let Some(packet) = <P>::from_slice(&mut buf).await? {
138                let len = buf.position() as usize;
139                self.buffer.advance(len);
140                return Ok(Some(packet));
141            }
142
143            let size = self.inner.read_buf(&mut self.buffer).await?;
144            if size == 0 {
145                if self.buffer.is_empty() {
146                    return Ok(None);
147                } else {
148                    return Err(Error::Disconnect);
149                }
150            }
151            println!("server[{}] receive bytes: {}", self.identifer, size);
152        }
153    }
154}
155
156type NotifyMap<P> = Arc<Mutex<HashMap<String, Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>>>>;
157#[derive(Clone)]
158pub struct BiProxy<P>
159where
160    P: IdentifierPacket,
161{
162    read_buffer_size: usize,
163    write_conns: Arc<Mutex<HashMap<String, WriteConnection>>>,
164    notify: NotifyMap<P>,
165    shutdown: broadcast::Sender<()>,
166    _a: std::marker::PhantomData<P>,
167}
168
169impl<P> BiProxy<P>
170where
171    P: IdentifierPacket,
172{
173    pub fn new(read_buffer_size: usize) -> Self {
174        Self {
175            read_buffer_size,
176            write_conns: Arc::new(Mutex::new(HashMap::new())),
177            notify: Arc::new(Mutex::new(HashMap::new())),
178            shutdown: broadcast::channel(1).0,
179            _a: std::marker::PhantomData::default(),
180        }
181    }
182
183    pub async fn accept_tcpstream<PP: PacketProcessor<P> + Clone>(
184        &self,
185        key: &str,
186        stream: TcpStream,
187        processor: PP,
188    ) {
189        let (read, write) = stream.into_split();
190        self.write_conns
191            .lock()
192            .await
193            .insert(key.to_string(), WriteConnection::new(write));
194
195        let notify_map = Arc::new(Mutex::new(HashMap::new()));
196        self.notify
197            .lock()
198            .await
199            .insert(key.to_string(), notify_map.clone());
200        self.spawn_read_task(key, notify_map, read, processor).await;
201    }
202
203    pub async fn send_oneway(&self, key: &str, data: P) -> Result<()> {
204        self.get_connection(key)
205            .await
206            .ok_or(Error::Disconnect)?
207            .write(data)
208            .await
209    }
210
211    pub async fn send(&self, key: &str, data: P) -> Result<P> {
212        let id = data.id().to_string();
213        let (sender, receiver) = oneshot::channel();
214        self.register_request_notify(key, id.as_str(), sender).await;
215
216        if let Err(error) = self.send_oneway(key, data).await {
217            self.deregister_request_notify(key, id.as_str()).await;
218            return Err(error);
219        }
220
221        let res = Response::new(receiver).await;
222        Ok(res)
223    }
224
225    async fn spawn_read_task<PP: PacketProcessor<P> + Clone>(
226        &self,
227        key: &str,
228        notify: Arc<Mutex<HashMap<String, oneshot::Sender<P>>>>,
229        conn: OwnedReadHalf,
230        processor: PP,
231    ) {
232        let mut read_conn =
233            ReadConnection::new(key, conn, processor, notify, self.read_buffer_size);
234        let shutdown = self.shutdown.subscribe();
235        tokio::spawn(async move {
236            if let Err(error) = read_conn.run(shutdown).await {
237                logs::error!("connection reset by peer, closed: {}", error);
238            }
239        });
240    }
241
242    async fn deregister_request_notify(&self, key: &str, packet_id: &str) {
243        let notify_map = self.notify.lock().await;
244        if let Some(notify_map) = notify_map.get(key) {
245            // ignore return
246            let _ = notify_map.lock().await.remove(packet_id);
247        }
248    }
249
250    async fn register_request_notify(
251        &self,
252        key: &str,
253        packet_id: &str,
254        sender: oneshot::Sender<P>,
255    ) {
256        let mut notify_map = self.notify.lock().await;
257
258        if !notify_map.contains_key(key) {
259            notify_map.insert(key.to_string(), Arc::new(Mutex::new(HashMap::new())));
260        }
261
262        let inner_notifys = match notify_map.get(key) {
263            Some(read_notify) => read_notify,
264            None => unreachable!(),
265        };
266
267        inner_notifys
268            .lock()
269            .await
270            .insert(packet_id.to_string(), sender);
271    }
272
273    async fn get_connection(&self, key: &str) -> Option<WriteConnection> {
274        self.write_conns.lock().await.get(key).cloned()
275    }
276}
277
278struct Response<P> {
279    notify: oneshot::Receiver<P>,
280}
281
282impl<P> Response<P> {
283    fn new(notify: oneshot::Receiver<P>) -> Self {
284        Self { notify }
285    }
286}
287
288impl<P> Future for Response<P> {
289    type Output = P;
290    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
291        match Pin::new(&mut self.notify).poll(cx) {
292            Poll::Ready(Ok(value)) => Poll::Ready(value),
293            Poll::Pending => Poll::Pending,
294            Poll::Ready(Err(_)) => unreachable!(),
295        }
296    }
297}