1use crate::{
2 backend::Backend,
3 codec::NoopCodec,
4 mq::MessageQueue,
5 poller::{controller::Controller, stream::BackendStream, Poller},
6 request::{Request, RequestStream},
7 worker::{self, Worker},
8};
9use futures::{
10 channel::mpsc::{channel, Receiver, Sender},
11 Stream, StreamExt,
12};
13use std::{
14 pin::Pin,
15 sync::Arc,
16 task::{Context, Poll},
17};
18use tower::layer::util::Identity;
19
20#[derive(Debug)]
21pub struct MemoryStorage<T> {
23 controller: Controller,
25 inner: MemoryWrapper<T>,
27}
28impl<T> MemoryStorage<T> {
29 pub fn new() -> Self {
31 Self {
32 controller: Controller::new(),
33 inner: MemoryWrapper::new(),
34 }
35 }
36}
37
38impl<T> Default for MemoryStorage<T> {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl<T> Clone for MemoryStorage<T> {
45 fn clone(&self) -> Self {
46 Self {
47 controller: self.controller.clone(),
48 inner: self.inner.clone(),
49 }
50 }
51}
52
53#[derive(Debug)]
55pub struct MemoryWrapper<T> {
56 sender: Sender<Request<T, ()>>,
57 receiver: Arc<futures::lock::Mutex<Receiver<Request<T, ()>>>>,
58}
59
60impl<T> Clone for MemoryWrapper<T> {
61 fn clone(&self) -> Self {
62 Self {
63 receiver: self.receiver.clone(),
64 sender: self.sender.clone(),
65 }
66 }
67}
68
69impl<T> MemoryWrapper<T> {
70 pub fn new() -> Self {
72 let (sender, receiver) = channel(100);
73
74 Self {
75 sender,
76 receiver: Arc::new(futures::lock::Mutex::new(receiver)),
77 }
78 }
79}
80
81impl<T> Default for MemoryWrapper<T> {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl<T> Stream for MemoryWrapper<T> {
88 type Item = Request<T, ()>;
89
90 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91 if let Some(mut receiver) = self.receiver.try_lock() {
92 receiver.poll_next_unpin(cx)
93 } else {
94 Poll::Pending
95 }
96 }
97}
98
99impl<T: Send + 'static + Sync> Backend<Request<T, ()>> for MemoryStorage<T> {
101 type Stream = BackendStream<RequestStream<Request<T, ()>>>;
102
103 type Layer = Identity;
104
105 type Codec = NoopCodec<Request<T, ()>>;
106
107 fn poll(self, _worker: &Worker<worker::Context>) -> Poller<Self::Stream> {
108 let stream = self.inner.map(|r| Ok(Some(r))).boxed();
109 Poller {
110 stream: BackendStream::new(stream, self.controller),
111 heartbeat: Box::pin(futures::future::pending()),
112 layer: Identity::new(),
113 _priv: (),
114 }
115 }
116}
117
118impl<Message: Send + 'static + Sync> MessageQueue<Message> for MemoryStorage<Message> {
119 type Context = ();
120 type Error = ();
121 type Compact = Message;
122
123 async fn enqueue_request(
124 &mut self,
125 req: Request<Message, Self::Context>,
126 ) -> Result<(), Self::Error> {
127 self.inner.sender.try_send(req).map_err(|_| ())?;
128 Ok(())
129 }
130
131 async fn enqueue_raw_request(
132 &mut self,
133 _req: Request<Self::Compact, Self::Context>,
134 ) -> Result<(), Self::Error> {
135 unreachable!("Cannot push a generic message")
136 }
137
138 async fn dequeue_request(&mut self) -> Result<Option<Request<Message, Self::Context>>, ()> {
139 Ok(self.inner.receiver.lock().await.next().await)
140 }
141
142 async fn size(&mut self) -> Result<usize, ()> {
143 Ok(self.inner.receiver.lock().await.size_hint().0)
144 }
145}