1use super::handler;
4use crate::stateful::db::AttachableResolver;
5use commonware_actor::mailbox::{Overflow, Policy, Sender};
6use commonware_cryptography::{Digest, Hasher};
7use commonware_storage::{merkle::Family, qmdb::sync::compact};
8use commonware_utils::{channel::oneshot, sync::AsyncRwLock};
9use std::{collections::VecDeque, future::Future, sync::Arc};
10
11struct CancelGuard<DB, F: Family, Op, D: Digest> {
12 sender: Sender<Message<DB, F, Op, D>>,
13 request: Option<handler::Request<F, D>>,
14}
15
16impl<DB, F: Family, Op, D: Digest> CancelGuard<DB, F, Op, D> {
17 const fn new(sender: Sender<Message<DB, F, Op, D>>, request: handler::Request<F, D>) -> Self {
18 Self {
19 sender,
20 request: Some(request),
21 }
22 }
23
24 const fn disarm(&mut self) {
25 self.request = None;
26 }
27}
28
29impl<DB, F: Family, Op, D: Digest> Drop for CancelGuard<DB, F, Op, D> {
30 fn drop(&mut self) {
31 let Some(request) = self.request.take() else {
32 return;
33 };
34 let _ = self.sender.enqueue(Message::CancelState { request });
35 }
36}
37
38#[derive(Debug, thiserror::Error)]
40#[error("response dropped before completion")]
41pub struct ResponseDropped;
42
43pub(super) enum Message<DB, F: Family, Op, D: Digest> {
44 AttachDatabase(Arc<AsyncRwLock<DB>>),
45 GetState {
46 request: handler::Request<F, D>,
47 response: oneshot::Sender<Result<compact::FetchResult<F, Op, D>, ResponseDropped>>,
48 },
49 CancelState {
50 request: handler::Request<F, D>,
51 },
52}
53
54impl<DB, F: Family, Op, D: Digest> Message<DB, F, Op, D> {
55 fn response_closed(&self) -> bool {
56 match self {
57 Self::AttachDatabase(_) | Self::CancelState { .. } => false,
58 Self::GetState { response, .. } => response.is_closed(),
59 }
60 }
61}
62
63pub(super) struct Pending<DB, F: Family, Op, D: Digest> {
64 database: Option<Arc<AsyncRwLock<DB>>>,
65 messages: VecDeque<Message<DB, F, Op, D>>,
66}
67
68impl<DB, F: Family, Op, D: Digest> Default for Pending<DB, F, Op, D> {
69 fn default() -> Self {
70 Self {
71 database: None,
72 messages: VecDeque::new(),
73 }
74 }
75}
76
77impl<DB, F: Family, Op, D: Digest> Overflow<Message<DB, F, Op, D>> for Pending<DB, F, Op, D> {
78 fn is_empty(&self) -> bool {
79 self.database.is_none() && self.messages.is_empty()
80 }
81
82 fn drain<P>(&mut self, mut push: P)
83 where
84 P: FnMut(Message<DB, F, Op, D>) -> Option<Message<DB, F, Op, D>>,
85 {
86 if let Some(database) = self.database.take() {
87 if let Some(Message::AttachDatabase(database)) = push(Message::AttachDatabase(database))
88 {
89 self.database = Some(database);
90 return;
91 }
92 }
93
94 while let Some(message) = self.messages.pop_front() {
95 if message.response_closed() {
96 continue;
97 }
98
99 if let Some(message) = push(message) {
100 self.messages.push_front(message);
101 break;
102 }
103 }
104 }
105}
106
107impl<DB, F: Family, Op, D: Digest> Policy for Message<DB, F, Op, D> {
108 type Overflow = Pending<DB, F, Op, D>;
109
110 fn handle(overflow: &mut Self::Overflow, message: Self) {
111 if message.response_closed() {
112 return;
113 }
114
115 match message {
116 Self::AttachDatabase(database) => {
117 overflow.database = Some(database);
118 }
119 message => overflow.messages.push_back(message),
120 }
121 }
122}
123
124pub struct Mailbox<DB, F: Family, Op, H: Hasher> {
126 sender: Sender<Message<DB, F, Op, H::Digest>>,
127}
128
129impl<DB, F: Family, Op, H: Hasher> Clone for Mailbox<DB, F, Op, H> {
130 fn clone(&self) -> Self {
131 Self {
132 sender: self.sender.clone(),
133 }
134 }
135}
136
137impl<DB, F: Family, Op, H: Hasher> Mailbox<DB, F, Op, H> {
138 pub(super) const fn new(sender: Sender<Message<DB, F, Op, H::Digest>>) -> Self {
139 Self { sender }
140 }
141}
142
143impl<DB: Send + Sync, F: Family, Op: Send, H: Hasher> Mailbox<DB, F, Op, H> {
144 pub fn attach_database(&self, db: Arc<AsyncRwLock<DB>>) {
145 let _ = self.sender.enqueue(Message::AttachDatabase(db));
146 }
147}
148
149impl<DB, F, Op, H> compact::Resolver for Mailbox<DB, F, Op, H>
150where
151 DB: Send + Sync + 'static,
152 F: Family,
153 Op: Send + Sync + Clone + 'static,
154 H: Hasher,
155{
156 type Digest = H::Digest;
157 type Error = ResponseDropped;
158 type Family = F;
159 type Op = Op;
160
161 async fn get_compact_state(
162 &self,
163 target: compact::Target<Self::Family, Self::Digest>,
164 ) -> Result<compact::FetchResult<Self::Family, Self::Op, Self::Digest>, Self::Error> {
165 let request = handler::Request::from_target(target);
166 let (response, receiver) = oneshot::channel();
167 let _ = self.sender.enqueue(Message::GetState {
168 request: request.clone(),
169 response,
170 });
171 let mut cancel = CancelGuard::new(self.sender.clone(), request);
172 let result = receiver.await;
173 cancel.disarm();
174 result.map_err(|_| ResponseDropped)?
175 }
176}
177
178impl<DB, F, Op, H> AttachableResolver<DB> for Mailbox<DB, F, Op, H>
179where
180 DB: Send + Sync + 'static,
181 F: Family,
182 Op: Send + Sync + Clone + 'static,
183 H: Hasher,
184{
185 fn attach_database(&self, db: Arc<AsyncRwLock<DB>>) -> impl Future<Output = ()> + Send {
186 Self::attach_database(self, db);
187 std::future::ready(())
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::sha256::Sha256;
    use commonware_runtime::{deterministic, Runner as _};
    use commonware_storage::{mmr, qmdb::sync::compact::Resolver as _};
    use commonware_utils::NZUsize;
    use futures::future::poll_fn;
    use std::task::Poll;

    /// `get_compact_state` must enqueue a `GetState` for the requested target
    /// and surface `ResponseDropped` when the responder is dropped unanswered.
    #[test]
    fn get_compact_state_sends_request() {
        deterministic::Runner::default().start(|context| async move {
            let (sender, mut receiver) = commonware_actor::mailbox::new(context, NZUsize!(4));
            let mailbox = Mailbox::<(), mmr::Family, u64, Sha256>::new(sender);
            let target = compact::Target {
                root: [1u8; 32].into(),
                leaf_count: mmr::Location::new(7),
            };

            let get = mailbox.get_compact_state(target.clone());
            let observe = async move {
                // Name the variant that actually arrived so a failure is not
                // misreported as an attach message.
                match receiver.recv().await.expect("request should be queued") {
                    Message::GetState { request, response } => {
                        assert_eq!(request.to_target(), target);
                        drop(response);
                    }
                    Message::AttachDatabase(_) => panic!("unexpected attach message"),
                    Message::CancelState { .. } => panic!("unexpected cancel message"),
                }
            };

            let (result, _) = futures::join!(get, observe);
            assert!(matches!(result, Err(ResponseDropped)));
        });
    }

    /// Dropping a pending `get_compact_state` future must enqueue a
    /// `CancelState` for the same target after the original `GetState`.
    #[test]
    fn dropped_request_sends_cancel_message() {
        deterministic::Runner::default().start(|context| async move {
            let (sender, mut receiver) = commonware_actor::mailbox::new(context, NZUsize!(4));
            let mailbox = Mailbox::<(), mmr::Family, u64, Sha256>::new(sender);
            let target = compact::Target {
                root: [2u8; 32].into(),
                leaf_count: mmr::Location::new(9),
            };

            // Poll once so the request is enqueued and the future is parked on
            // the response, then drop it to trigger the cancel guard.
            let mut get = Box::pin(mailbox.get_compact_state(target.clone()));
            poll_fn(|cx| {
                assert!(matches!(get.as_mut().poll(cx), Poll::Pending));
                Poll::Ready(())
            })
            .await;
            drop(get);

            match receiver.recv().await.expect("request should be queued") {
                Message::GetState { request, response } => {
                    assert_eq!(request.to_target(), target);
                    drop(response);
                }
                Message::AttachDatabase(_) => panic!("unexpected attach message"),
                Message::CancelState { .. } => panic!("cancel arrived before request"),
            }

            match receiver.recv().await.expect("cancel should be queued") {
                Message::CancelState { request } => {
                    assert_eq!(request.to_target(), target);
                }
                Message::AttachDatabase(_) => panic!("unexpected attach message"),
                Message::GetState { .. } => panic!("unexpected duplicate request"),
            }
        });
    }
}