1use crate::{
2 backend::Backend,
3 mq::MessageQueue,
4 poller::Poller,
5 poller::{controller::Controller, stream::BackendStream},
6 request::{Request, RequestStream},
7 worker::{self, Worker},
8};
9use futures::{
10 channel::mpsc::{channel, Receiver, Sender},
11 Stream, StreamExt,
12};
13use std::{
14 pin::Pin,
15 sync::Arc,
16 task::{Context, Poll},
17};
18use tower::layer::util::Identity;
19
/// An in-memory storage that pairs a [`Controller`] with a channel-backed
/// [`MemoryWrapper`].
///
/// It acts both as a `Backend` (its requests are consumed as a stream) and as
/// a `MessageQueue` (messages can be enqueued and dequeued directly).
#[derive(Debug)]
pub struct MemoryStorage<T> {
    /// Controller handed to the `BackendStream` when the storage is polled.
    controller: Controller,
    /// Channel wrapper holding the sending and receiving halves of the queue.
    inner: MemoryWrapper<T>,
}
28impl<T> MemoryStorage<T> {
29 pub fn new() -> Self {
31 Self {
32 controller: Controller::new(),
33 inner: MemoryWrapper::new(),
34 }
35 }
36}
37
38impl<T> Default for MemoryStorage<T> {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl<T> Clone for MemoryStorage<T> {
45 fn clone(&self) -> Self {
46 Self {
47 controller: self.controller.clone(),
48 inner: self.inner.clone(),
49 }
50 }
51}
52
/// Wraps both halves of a bounded `futures` mpsc channel carrying
/// `Request<T, ()>` values.
#[derive(Debug)]
pub struct MemoryWrapper<T> {
    /// Sending half, used to push new requests into the channel.
    sender: Sender<Request<T, ()>>,
    /// Receiving half, kept behind an `Arc` and an async mutex so that clones
    /// of the wrapper refer to the same receiver.
    receiver: Arc<futures::lock::Mutex<Receiver<Request<T, ()>>>>,
}
59
60impl<T> Clone for MemoryWrapper<T> {
61 fn clone(&self) -> Self {
62 Self {
63 receiver: self.receiver.clone(),
64 sender: self.sender.clone(),
65 }
66 }
67}
68
69impl<T> MemoryWrapper<T> {
70 pub fn new() -> Self {
72 let (sender, receiver) = channel(100);
73
74 Self {
75 sender,
76 receiver: Arc::new(futures::lock::Mutex::new(receiver)),
77 }
78 }
79}
80
81impl<T> Default for MemoryWrapper<T> {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl<T> Stream for MemoryWrapper<T> {
88 type Item = Request<T, ()>;
89
90 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
91 if let Some(mut receiver) = self.receiver.try_lock() {
92 receiver.poll_next_unpin(cx)
93 } else {
94 Poll::Pending
95 }
96 }
97}
98
99impl<T: Send + 'static + Sync, Res> Backend<Request<T, ()>, Res> for MemoryStorage<T> {
101 type Stream = BackendStream<RequestStream<Request<T, ()>>>;
102
103 type Layer = Identity;
104
105 fn poll<Svc>(self, _worker: &Worker<worker::Context>) -> Poller<Self::Stream> {
106 let stream = self.inner.map(|r| Ok(Some(r))).boxed();
107 Poller {
108 stream: BackendStream::new(stream, self.controller),
109 heartbeat: Box::pin(futures::future::pending()),
110 layer: Identity::new(),
111 _priv: (),
112 }
113 }
114}
115
116impl<Message: Send + 'static + Sync> MessageQueue<Message> for MemoryStorage<Message> {
117 type Error = ();
118 async fn enqueue(&mut self, message: Message) -> Result<(), Self::Error> {
119 self.inner
120 .sender
121 .try_send(Request::new(message))
122 .map_err(|_| ())?;
123 Ok(())
124 }
125
126 async fn dequeue(&mut self) -> Result<Option<Message>, ()> {
127 Ok(self
128 .inner
129 .receiver
130 .lock()
131 .await
132 .next()
133 .await
134 .map(|r| r.args))
135 }
136
137 async fn size(&mut self) -> Result<usize, ()> {
138 Ok(self.inner.receiver.lock().await.size_hint().0)
139 }
140}