1use std::{
2 collections::HashMap,
3 sync::{atomic::AtomicUsize, Arc},
4 task::{Context, Poll},
5};
6
7use futures_util::{FutureExt, Sink, SinkExt, Stream, StreamExt};
8use narrowlink_types::{
9 agent::{EventInBound as AgentEventInBound, EventOutBound as AgentEventOutBound},
10 client::{EventInBound as ClientEventInBound, EventOutBound as ClientEventOutBound},
11};
12use tokio::sync::{mpsc, oneshot};
13
14use crate::{error::NetworkError, UniversalStream};
15
16pub struct NarrowEvent<T, U> {
17 req_sender: mpsc::UnboundedSender<Option<(usize, T, oneshot::Sender<U>)>>,
18 req_receiver: mpsc::UnboundedReceiver<Option<(usize, T, oneshot::Sender<U>)>>,
19 sender: mpsc::UnboundedSender<T>,
20 receiver: mpsc::UnboundedReceiver<T>,
21 inner_stream: Box<dyn UniversalStream<String, NetworkError>>,
22 last_req_id: Arc<AtomicUsize>,
23 requests: HashMap<usize, oneshot::Sender<U>>,
24}
25
26#[derive(Debug, Clone)]
27pub struct NarrowEventRequest<T, U> {
28 req_id: Arc<AtomicUsize>,
29 sender: mpsc::UnboundedSender<Option<(usize, T, oneshot::Sender<U>)>>,
30}
31
32impl<'a, T, U> NarrowEventRequest<T, U>
33where
34 T: RequestManager,
35{
36 pub async fn request(&self, mut req: T) -> Result<U, NetworkError> {
37 let req_id = self.req_id.load(std::sync::atomic::Ordering::SeqCst);
38 self.req_id
39 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
40 req.set_id(req_id);
41 let (cur_req_sender, cur_req_receiver) = oneshot::channel();
42
43 let _ = self.sender.send(Some((req_id, req, cur_req_sender))); cur_req_receiver
46 .await
47 .map_err(|_| NetworkError::RequestCanceled)
48 }
49 pub async fn shutdown(&self) {
50 let _ = self.sender.send(None);
51 }
52}
53
54impl<T, U> NarrowEvent<T, U> {
55 pub fn new(stream: impl UniversalStream<String, NetworkError>) -> Self {
56 let (req_sender, req_receiver) = mpsc::unbounded_channel();
57 let (sender, receiver) = mpsc::unbounded_channel();
58 Self {
59 req_sender,
60 req_receiver,
61 sender,
62 receiver,
63 inner_stream: Box::new(stream),
64 last_req_id: Arc::new(AtomicUsize::new(0)),
65 requests: HashMap::new(),
66 }
67 }
68 pub fn get_sender(&self) -> mpsc::UnboundedSender<T> {
69 self.sender.clone()
70 }
71 pub fn get_request(&self) -> NarrowEventRequest<T, U> {
72 NarrowEventRequest {
73 req_id: self.last_req_id.clone(),
74 sender: self.req_sender.clone(),
75 }
76 }
77}
78impl<T, U> Unpin for NarrowEvent<T, U> {}
79
80impl<T, U> Stream for NarrowEvent<T, U>
81where
82 U: ResponseManager + for<'a> serde::de::Deserialize<'a>,
83 T: serde::ser::Serialize,
84{
85 type Item = Result<U, NetworkError>;
86
87 fn poll_next(
88 mut self: std::pin::Pin<&mut Self>,
89 cx: &mut Context<'_>,
90 ) -> Poll<Option<Self::Item>> {
91 loop {
92 match self.inner_stream.poll_next_unpin(cx)? {
93 Poll::Ready(Some(item)) => {
94 let item = serde_json::from_str::<U>(&item)?;
95 if let Some(msg_response) =
96 item.get_id().and_then(|id| self.requests.remove(&id))
97 {
98 let _ = msg_response.send(item);
99 continue;
100 }
101 return Poll::Ready(Some(Ok(item)));
102 }
103 Poll::Ready(None) => return Poll::Ready(None),
104 Poll::Pending => match self.req_receiver.poll_recv(cx) {
105 Poll::Ready(Some(Some((req_id, msg, msg_response)))) => {
106 match self
107 .inner_stream
108 .send(serde_json::to_string(&msg)?)
109 .poll_unpin(cx)
110 {
111 Poll::Ready(Ok(())) => {
112 self.requests.insert(req_id, msg_response);
113 continue;
114 }
115 Poll::Ready(Err(_)) => return Poll::Ready(None),
116 Poll::Pending => return Poll::Pending,
117 }
118 }
119 Poll::Ready(None) | Poll::Ready(Some(None)) => return Poll::Ready(None),
120 Poll::Pending => match self.receiver.poll_recv(cx) {
121 Poll::Ready(Some(msg)) => match self
122 .inner_stream
123 .send(serde_json::to_string(&msg)?)
124 .poll_unpin(cx)
125 {
126 Poll::Ready(Ok(())) => {
127 continue;
128 }
129 Poll::Ready(Err(_)) => return Poll::Ready(None),
130 Poll::Pending => return Poll::Pending,
131 },
132 Poll::Ready(None) => return Poll::Ready(None),
133 Poll::Pending => return Poll::Pending,
134 },
135 },
136 }
137 }
138 }
139}
140
141impl<T, U> Sink<T> for NarrowEvent<T, U>
142where
143 T: serde::ser::Serialize,
144{
145 type Error = NetworkError;
146
147 fn poll_ready(
148 mut self: std::pin::Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 ) -> Poll<Result<(), Self::Error>> {
151 self.inner_stream.poll_ready_unpin(cx)
152 }
153
154 fn start_send(mut self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
155 self.inner_stream
156 .start_send_unpin(serde_json::to_string(&item)?)
157 }
158
159 fn poll_flush(
160 mut self: std::pin::Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 ) -> Poll<Result<(), Self::Error>> {
163 self.inner_stream.poll_flush_unpin(cx)
164 }
165
166 fn poll_close(
167 mut self: std::pin::Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 ) -> Poll<Result<(), Self::Error>> {
170 self.inner_stream.poll_close_unpin(cx)
171 }
172}
173
174impl<S, E> From<Box<dyn UniversalStream<String, NetworkError>>> for NarrowEvent<S, E> {
175 fn from(stream: Box<dyn UniversalStream<String, NetworkError>>) -> Self {
176 let (req_sender, req_receiver) = mpsc::unbounded_channel();
177 let (sender, receiver) = mpsc::unbounded_channel();
178 Self {
179 req_sender,
180 req_receiver,
181 sender,
182 receiver,
183 inner_stream: stream,
184 last_req_id: Arc::new(AtomicUsize::new(0)),
185 requests: HashMap::new(),
186 }
187 }
188}
189
/// Outgoing message that can be tagged with a request id.
pub trait RequestManager {
    /// Writes `new_id` into the message if it has a request-id slot.
    ///
    /// Returns the id that was replaced, or `None` when the message has no id slot (the
    /// implementations in this file leave such messages untouched).
    fn set_id(&mut self, new_id: usize) -> Option<usize>;
}

/// Incoming message that may answer an earlier request.
pub trait ResponseManager {
    /// Returns the id of the request this message answers, or `None` if it answers none.
    fn get_id(&self) -> Option<usize>;
}
197
198impl RequestManager for AgentEventOutBound {
199 fn set_id(&mut self, id: usize) -> Option<usize> {
200 if let AgentEventOutBound::Request(_id, _) = self {
201 let old_id = *_id;
202 *_id = id;
203 Some(old_id)
204 } else {
205 None
206 }
207 }
208}
209
210impl RequestManager for AgentEventInBound {
211 fn set_id(&mut self, _id: usize) -> Option<usize> {
212 None
213 }
214}
215
216impl ResponseManager for AgentEventOutBound {
217 fn get_id(&self) -> Option<usize> {
218 None
219 }
220}
221
222impl ResponseManager for AgentEventInBound {
223 fn get_id(&self) -> Option<usize> {
224 if let AgentEventInBound::Response(id, _) = self {
225 Some(*id)
226 } else {
227 None
228 }
229 }
230}
231
232impl RequestManager for ClientEventOutBound {
233 fn set_id(&mut self, id: usize) -> Option<usize> {
234 #[allow(irrefutable_let_patterns)]
235 if let ClientEventOutBound::Request(_id, _) = self {
236 let old_id = *_id;
237 *_id = id;
238 Some(old_id)
239 } else {
240 None
241 }
242 }
243}
244
245impl RequestManager for ClientEventInBound {
246 fn set_id(&mut self, _id: usize) -> Option<usize> {
247 None
248 }
249}
250impl ResponseManager for ClientEventInBound {
251 fn get_id(&self) -> Option<usize> {
252 #[allow(irrefutable_let_patterns)]
253 if let ClientEventInBound::Response(id, _) = self {
254 Some(*id)
255 } else {
256 None
257 }
258 }
259}
260
261impl ResponseManager for ClientEventOutBound {
262 fn get_id(&self) -> Option<usize> {
263 None
264 }
265}