narrowlink_network/
event.rs

1use std::{
2    collections::HashMap,
3    sync::{atomic::AtomicUsize, Arc},
4    task::{Context, Poll},
5};
6
7use futures_util::{FutureExt, Sink, SinkExt, Stream, StreamExt};
8use narrowlink_types::{
9    agent::{EventInBound as AgentEventInBound, EventOutBound as AgentEventOutBound},
10    client::{EventInBound as ClientEventInBound, EventOutBound as ClientEventOutBound},
11};
12use tokio::sync::{mpsc, oneshot};
13
14use crate::{error::NetworkError, UniversalStream};
15
16pub struct NarrowEvent<T, U> {
17    req_sender: mpsc::UnboundedSender<Option<(usize, T, oneshot::Sender<U>)>>,
18    req_receiver: mpsc::UnboundedReceiver<Option<(usize, T, oneshot::Sender<U>)>>,
19    sender: mpsc::UnboundedSender<T>,
20    receiver: mpsc::UnboundedReceiver<T>,
21    inner_stream: Box<dyn UniversalStream<String, NetworkError>>,
22    last_req_id: Arc<AtomicUsize>,
23    requests: HashMap<usize, oneshot::Sender<U>>,
24}
25
26#[derive(Debug, Clone)]
27pub struct NarrowEventRequest<T, U> {
28    req_id: Arc<AtomicUsize>,
29    sender: mpsc::UnboundedSender<Option<(usize, T, oneshot::Sender<U>)>>,
30}
31
32impl<'a, T, U> NarrowEventRequest<T, U>
33where
34    T: RequestManager,
35{
36    pub async fn request(&self, mut req: T) -> Result<U, NetworkError> {
37        let req_id = self.req_id.load(std::sync::atomic::Ordering::SeqCst);
38        self.req_id
39            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
40        req.set_id(req_id);
41        let (cur_req_sender, cur_req_receiver) = oneshot::channel();
42
43        let _ = self.sender.send(Some((req_id, req, cur_req_sender))); //todo error
44
45        cur_req_receiver
46            .await
47            .map_err(|_| NetworkError::RequestCanceled)
48    }
49    pub async fn shutdown(&self) {
50        let _ = self.sender.send(None);
51    }
52}
53
54impl<T, U> NarrowEvent<T, U> {
55    pub fn new(stream: impl UniversalStream<String, NetworkError>) -> Self {
56        let (req_sender, req_receiver) = mpsc::unbounded_channel();
57        let (sender, receiver) = mpsc::unbounded_channel();
58        Self {
59            req_sender,
60            req_receiver,
61            sender,
62            receiver,
63            inner_stream: Box::new(stream),
64            last_req_id: Arc::new(AtomicUsize::new(0)),
65            requests: HashMap::new(),
66        }
67    }
68    pub fn get_sender(&self) -> mpsc::UnboundedSender<T> {
69        self.sender.clone()
70    }
71    pub fn get_request(&self) -> NarrowEventRequest<T, U> {
72        NarrowEventRequest {
73            req_id: self.last_req_id.clone(),
74            sender: self.req_sender.clone(),
75        }
76    }
77}
78impl<T, U> Unpin for NarrowEvent<T, U> {}
79
80impl<T, U> Stream for NarrowEvent<T, U>
81where
82    U: ResponseManager + for<'a> serde::de::Deserialize<'a>,
83    T: serde::ser::Serialize,
84{
85    type Item = Result<U, NetworkError>;
86
87    fn poll_next(
88        mut self: std::pin::Pin<&mut Self>,
89        cx: &mut Context<'_>,
90    ) -> Poll<Option<Self::Item>> {
91        loop {
92            match self.inner_stream.poll_next_unpin(cx)? {
93                Poll::Ready(Some(item)) => {
94                    let item = serde_json::from_str::<U>(&item)?;
95                    if let Some(msg_response) =
96                        item.get_id().and_then(|id| self.requests.remove(&id))
97                    {
98                        let _ = msg_response.send(item);
99                        continue;
100                    }
101                    return Poll::Ready(Some(Ok(item)));
102                }
103                Poll::Ready(None) => return Poll::Ready(None),
104                Poll::Pending => match self.req_receiver.poll_recv(cx) {
105                    Poll::Ready(Some(Some((req_id, msg, msg_response)))) => {
106                        match self
107                            .inner_stream
108                            .send(serde_json::to_string(&msg)?)
109                            .poll_unpin(cx)
110                        {
111                            Poll::Ready(Ok(())) => {
112                                self.requests.insert(req_id, msg_response);
113                                continue;
114                            }
115                            Poll::Ready(Err(_)) => return Poll::Ready(None),
116                            Poll::Pending => return Poll::Pending,
117                        }
118                    }
119                    Poll::Ready(None) | Poll::Ready(Some(None)) => return Poll::Ready(None),
120                    Poll::Pending => match self.receiver.poll_recv(cx) {
121                        Poll::Ready(Some(msg)) => match self
122                            .inner_stream
123                            .send(serde_json::to_string(&msg)?)
124                            .poll_unpin(cx)
125                        {
126                            Poll::Ready(Ok(())) => {
127                                continue;
128                            }
129                            Poll::Ready(Err(_)) => return Poll::Ready(None),
130                            Poll::Pending => return Poll::Pending,
131                        },
132                        Poll::Ready(None) => return Poll::Ready(None),
133                        Poll::Pending => return Poll::Pending,
134                    },
135                },
136            }
137        }
138    }
139}
140
141impl<T, U> Sink<T> for NarrowEvent<T, U>
142where
143    T: serde::ser::Serialize,
144{
145    type Error = NetworkError;
146
147    fn poll_ready(
148        mut self: std::pin::Pin<&mut Self>,
149        cx: &mut Context<'_>,
150    ) -> Poll<Result<(), Self::Error>> {
151        self.inner_stream.poll_ready_unpin(cx)
152    }
153
154    fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
155        self.inner_stream
156            .start_send_unpin(serde_json::to_string(&item)?)
157    }
158
159    fn poll_flush(
160        mut self: std::pin::Pin<&mut Self>,
161        cx: &mut Context<'_>,
162    ) -> Poll<Result<(), Self::Error>> {
163        self.inner_stream.poll_flush_unpin(cx)
164    }
165
166    fn poll_close(
167        mut self: std::pin::Pin<&mut Self>,
168        cx: &mut Context<'_>,
169    ) -> Poll<Result<(), Self::Error>> {
170        self.inner_stream.poll_close_unpin(cx)
171    }
172}
173
174impl<S, E> From<Box<dyn UniversalStream<String, NetworkError>>> for NarrowEvent<S, E> {
175    fn from(stream: Box<dyn UniversalStream<String, NetworkError>>) -> Self {
176        let (req_sender, req_receiver) = mpsc::unbounded_channel();
177        let (sender, receiver) = mpsc::unbounded_channel();
178        Self {
179            req_sender,
180            req_receiver,
181            sender,
182            receiver,
183            inner_stream: stream,
184            last_req_id: Arc::new(AtomicUsize::new(0)),
185            requests: HashMap::new(),
186        }
187    }
188}
189
/// Lets an outgoing message carry the correlation id chosen by
/// [`NarrowEventRequest::request`].
pub trait RequestManager {
    /// Writes `new_id` into the message if it has an id slot.
    ///
    /// Returns the id that was replaced, or `None` when nothing was written.
    fn set_id(&mut self, new_id: usize) -> Option<usize>;
}
193
/// Exposes the correlation id of an incoming message, so [`NarrowEvent`] can
/// hand it to the request that is waiting for it.
pub trait ResponseManager {
    /// Returns the id of the request this message answers.
    ///
    /// Returns `None` for messages that answer no request; the stream yields
    /// those as ordinary items instead.
    fn get_id(&self) -> Option<usize>;
}
197
198impl RequestManager for AgentEventOutBound {
199    fn set_id(&mut self, id: usize) -> Option<usize> {
200        if let AgentEventOutBound::Request(_id, _) = self {
201            let old_id = *_id;
202            *_id = id;
203            Some(old_id)
204        } else {
205            None
206        }
207    }
208}
209
210impl RequestManager for AgentEventInBound {
211    fn set_id(&mut self, _id: usize) -> Option<usize> {
212        None
213    }
214}
215
216impl ResponseManager for AgentEventOutBound {
217    fn get_id(&self) -> Option<usize> {
218        None
219    }
220}
221
222impl ResponseManager for AgentEventInBound {
223    fn get_id(&self) -> Option<usize> {
224        if let AgentEventInBound::Response(id, _) = self {
225            Some(*id)
226        } else {
227            None
228        }
229    }
230}
231
232impl RequestManager for ClientEventOutBound {
233    fn set_id(&mut self, id: usize) -> Option<usize> {
234        #[allow(irrefutable_let_patterns)]
235        if let ClientEventOutBound::Request(_id, _) = self {
236            let old_id = *_id;
237            *_id = id;
238            Some(old_id)
239        } else {
240            None
241        }
242    }
243}
244
245impl RequestManager for ClientEventInBound {
246    fn set_id(&mut self, _id: usize) -> Option<usize> {
247        None
248    }
249}
250impl ResponseManager for ClientEventInBound {
251    fn get_id(&self) -> Option<usize> {
252        #[allow(irrefutable_let_patterns)]
253        if let ClientEventInBound::Response(id, _) = self {
254            Some(*id)
255        } else {
256            None
257        }
258    }
259}
260
261impl ResponseManager for ClientEventOutBound {
262    fn get_id(&self) -> Option<usize> {
263        None
264    }
265}