1use std::cell::RefCell;
26use std::collections::VecDeque;
27use std::future::Future;
28use std::pin::Pin;
29use std::rc::Rc;
30use std::task::{Context, Poll, Waker};
31
32use serde::de::DeserializeOwned;
33
34use crate::{Endpoint, MessageCodec, NetworkAddress, UID};
35use moonpool_sim::sometimes_assert;
36
37use super::endpoint_map::MessageReceiver;
38
39pub struct NetNotifiedQueue<T, C: MessageCodec> {
56 inner: RefCell<NetNotifiedQueueInner<T>>,
58
59 endpoint: Endpoint,
61
62 codec: C,
64}
65
/// Interior state of a `NetNotifiedQueue`, mutated through its `RefCell`.
struct NetNotifiedQueueInner<T> {
    /// Decoded messages awaiting delivery, oldest at the front.
    queue: VecDeque<T>,

    /// Wakers of receivers that polled while the queue was empty.
    wakers: Vec<Waker>,

    /// Set once the queue has been closed; receivers then drain and get `None`.
    closed: bool,

    /// Number of messages successfully enqueued.
    messages_received: u64,
    /// Number of payloads discarded because they failed to decode.
    messages_dropped: u64,
}
81
82impl<T> Default for NetNotifiedQueueInner<T> {
83 fn default() -> Self {
84 Self {
85 queue: VecDeque::new(),
86 wakers: Vec::new(),
87 closed: false,
88 messages_received: 0,
89 messages_dropped: 0,
90 }
91 }
92}
93
94impl<T, C: MessageCodec> NetNotifiedQueue<T, C> {
95 pub fn new(endpoint: Endpoint, codec: C) -> Self {
97 Self {
98 inner: RefCell::new(NetNotifiedQueueInner::default()),
99 endpoint,
100 codec,
101 }
102 }
103
104 pub fn with_address(address: NetworkAddress, codec: C) -> Self {
108 let token = UID::new(0, rand_simple_id());
111 Self::new(Endpoint::new(address, token), codec)
112 }
113
114 pub fn endpoint(&self) -> &Endpoint {
118 &self.endpoint
119 }
120
121 pub fn try_recv(&self) -> Option<T> {
125 self.inner.borrow_mut().queue.pop_front()
126 }
127
128 pub fn is_empty(&self) -> bool {
130 self.inner.borrow().queue.is_empty()
131 }
132
133 pub fn len(&self) -> usize {
135 self.inner.borrow().queue.len()
136 }
137
138 pub fn messages_received(&self) -> u64 {
140 self.inner.borrow().messages_received
141 }
142
143 pub fn messages_dropped(&self) -> u64 {
145 self.inner.borrow().messages_dropped
146 }
147
148 pub fn close(&self) {
153 let mut inner = self.inner.borrow_mut();
154 inner.closed = true;
155 for waker in inner.wakers.drain(..) {
157 waker.wake();
158 }
159 }
160
161 pub fn is_closed(&self) -> bool {
163 self.inner.borrow().closed
164 }
165
166 #[cfg(test)]
168 fn push(&self, message: T) {
169 let mut inner = self.inner.borrow_mut();
170 inner.queue.push_back(message);
171 inner.messages_received += 1;
172 for waker in inner.wakers.drain(..) {
174 waker.wake();
175 }
176 }
177}
178
179impl<T: DeserializeOwned, C: MessageCodec> NetNotifiedQueue<T, C> {
180 pub fn recv(&self) -> RecvFuture<'_, T, C> {
184 RecvFuture { queue: self }
185 }
186}
187
188impl<T: DeserializeOwned + 'static, C: MessageCodec> MessageReceiver for NetNotifiedQueue<T, C> {
189 fn receive(&self, payload: &[u8]) {
190 match self.codec.decode::<T>(payload) {
192 Ok(message) => {
193 sometimes_assert!(
194 deserialization_success,
195 true,
196 "Message deserialized successfully"
197 );
198 let mut inner = self.inner.borrow_mut();
199 inner.queue.push_back(message);
200 inner.messages_received += 1;
201
202 let had_waiters = !inner.wakers.is_empty();
204 for waker in inner.wakers.drain(..) {
205 waker.wake();
206 }
207 if had_waiters {
208 sometimes_assert!(waker_notified, true, "Wakers notified on new message");
209 }
210 }
211 Err(e) => {
212 sometimes_assert!(
213 deserialization_failed,
214 true,
215 "Message deserialization failed"
216 );
217 tracing::warn!(
219 endpoint = %self.endpoint.token,
220 error = %e,
221 "failed to deserialize message"
222 );
223 self.inner.borrow_mut().messages_dropped += 1;
224 }
225 }
226 }
227}
228
229pub struct RecvFuture<'a, T, C: MessageCodec> {
231 queue: &'a NetNotifiedQueue<T, C>,
232}
233
234impl<T, C: MessageCodec> Future for RecvFuture<'_, T, C> {
235 type Output = Option<T>;
236
237 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
238 let mut inner = self.queue.inner.borrow_mut();
239
240 if let Some(message) = inner.queue.pop_front() {
242 sometimes_assert!(message_available, true, "Message available immediately");
243 return Poll::Ready(Some(message));
244 }
245
246 if inner.closed {
248 sometimes_assert!(queue_closed_empty, true, "Queue closed and empty");
249 return Poll::Ready(None);
250 }
251
252 sometimes_assert!(recv_pending, true, "Recv waiting for message");
254 inner.wakers.push(cx.waker().clone());
255 Poll::Pending
256 }
257}
258
259pub struct SharedNetNotifiedQueue<T: DeserializeOwned + 'static, C: MessageCodec>(
261 pub Rc<NetNotifiedQueue<T, C>>,
262);
263
264impl<T: DeserializeOwned + 'static, C: MessageCodec> MessageReceiver
265 for SharedNetNotifiedQueue<T, C>
266{
267 fn receive(&self, payload: &[u8]) {
268 self.0.receive(payload)
269 }
270}
271
272impl<T: DeserializeOwned + 'static, C: MessageCodec> SharedNetNotifiedQueue<T, C> {
273 pub fn new(endpoint: Endpoint, codec: C) -> Self {
275 Self(Rc::new(NetNotifiedQueue::new(endpoint, codec)))
276 }
277
278 pub fn inner(&self) -> &NetNotifiedQueue<T, C> {
280 &self.0
281 }
282
283 pub fn as_receiver(&self) -> Rc<NetNotifiedQueue<T, C>> {
285 Rc::clone(&self.0)
286 }
287}
288
/// Produces an id for freshly created endpoint tokens.
///
/// Despite the name this is not random. It returns a process-wide counter
/// that starts at 1 and increases by one on every call.
fn rand_simple_id() -> u64 {
    static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
    NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
296
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::Wake;

    use super::*;
    use crate::JsonCodec;

    fn test_endpoint() -> Endpoint {
        let addr = NetworkAddress::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4500);
        Endpoint::new(addr, UID::new(1, 1))
    }

    /// Waker that counts how often it is woken, so tests can drive
    /// `RecvFuture` by hand and observe notifications.
    struct CountingWaker(AtomicUsize);

    impl CountingWaker {
        fn new() -> Arc<Self> {
            Arc::new(Self(AtomicUsize::new(0)))
        }

        fn count(&self) -> usize {
            self.0.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_new_queue_is_empty() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.messages_received(), 0);
    }

    #[test]
    fn test_push_and_try_recv() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        queue.push("hello".to_string());
        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 1);

        let msg = queue.try_recv();
        assert_eq!(msg, Some("hello".to_string()));
        assert!(queue.is_empty());
    }

    #[test]
    fn test_receive_deserializes() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        let payload = b"\"hello world\"";
        queue.receive(payload);

        assert_eq!(queue.len(), 1);
        assert_eq!(queue.messages_received(), 1);
        assert_eq!(queue.try_recv(), Some("hello world".to_string()));
    }

    #[test]
    fn test_receive_invalid_json_drops() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        let payload = b"not valid json";
        queue.receive(payload);

        assert!(queue.is_empty());
        assert_eq!(queue.messages_received(), 0);
        assert_eq!(queue.messages_dropped(), 1);
    }

    #[test]
    fn test_fifo_ordering() {
        let queue: NetNotifiedQueue<i32, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        queue.push(1);
        queue.push(2);
        queue.push(3);

        assert_eq!(queue.try_recv(), Some(1));
        assert_eq!(queue.try_recv(), Some(2));
        assert_eq!(queue.try_recv(), Some(3));
        assert_eq!(queue.try_recv(), None);
    }

    #[test]
    fn test_close_queue() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        assert!(!queue.is_closed());
        queue.close();
        assert!(queue.is_closed());
    }

    #[test]
    fn test_endpoint_accessor() {
        let endpoint = test_endpoint();
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(endpoint.clone(), JsonCodec);

        assert_eq!(queue.endpoint().token, endpoint.token);
    }

    #[derive(Debug, PartialEq, serde::Deserialize)]
    struct TestMessage {
        id: u32,
        content: String,
    }

    #[test]
    fn test_receive_complex_type() {
        let queue: NetNotifiedQueue<TestMessage, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        let payload = br#"{"id": 42, "content": "hello"}"#;
        queue.receive(payload);

        let msg = queue.try_recv();
        assert_eq!(
            msg,
            Some(TestMessage {
                id: 42,
                content: "hello".to_string()
            })
        );
    }

    #[test]
    fn test_shared_queue() {
        let shared: SharedNetNotifiedQueue<String, JsonCodec> =
            SharedNetNotifiedQueue::new(test_endpoint(), JsonCodec);

        shared.receive(b"\"shared message\"");

        assert_eq!(
            shared.inner().try_recv(),
            Some("shared message".to_string())
        );
    }

    #[test]
    fn test_receive_wakes_pending_recv() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);
        let counter = CountingWaker::new();
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);

        let mut fut = queue.recv();
        assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());
        assert_eq!(counter.count(), 0);

        queue.receive(b"\"wake up\"");
        assert_eq!(counter.count(), 1);
        assert_eq!(
            Pin::new(&mut fut).poll(&mut cx),
            Poll::Ready(Some("wake up".to_string()))
        );
    }

    #[test]
    fn test_close_wakes_pending_recv() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);
        let counter = CountingWaker::new();
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);

        let mut fut = queue.recv();
        assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());

        queue.close();
        assert_eq!(counter.count(), 1);
        assert_eq!(Pin::new(&mut fut).poll(&mut cx), Poll::Ready(None));
    }

    #[tokio::test]
    async fn test_recv_async() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        queue.push("async hello".to_string());

        let result = queue.recv().await;
        assert_eq!(result, Some("async hello".to_string()));
    }

    #[tokio::test]
    async fn test_recv_closed_empty() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        queue.close();

        let result = queue.recv().await;
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_recv_drains_buffer_after_close() {
        let queue: NetNotifiedQueue<String, JsonCodec> =
            NetNotifiedQueue::new(test_endpoint(), JsonCodec);

        queue.push("pending".to_string());
        queue.close();

        assert_eq!(queue.recv().await, Some("pending".to_string()));
        assert_eq!(queue.recv().await, None);
    }
}