1use super::handler;
4use crate::stateful::db::AttachableResolver;
5use commonware_actor::mailbox::{Overflow, Policy, Sender};
6use commonware_codec::Read;
7use commonware_cryptography::Digest;
8use commonware_macros::select;
9use commonware_storage::{
10 merkle::{Family, Location},
11 qmdb::sync::resolver::{FetchResult, Resolver as SyncResolver},
12};
13use commonware_utils::{channel::oneshot, sync::AsyncRwLock};
14use futures::FutureExt as _;
15use std::{collections::VecDeque, future::Future, num::NonZeroU64, sync::Arc};
16
17#[derive(Debug, thiserror::Error)]
19#[error("response dropped before completion")]
20pub struct ResponseDropped;
21
22pub(super) enum Message<DB, F: Family, Op, D: Digest> {
24 AttachDatabase(Arc<AsyncRwLock<DB>>),
26 GetOperations {
28 request: handler::Request<F>,
29 response: oneshot::Sender<Result<FetchResult<F, Op, D>, ResponseDropped>>,
30 },
31 CancelOperations { request: handler::Request<F> },
33}
34
35impl<DB, F: Family, Op, D: Digest> Message<DB, F, Op, D> {
36 fn response_closed(&self) -> bool {
37 match self {
38 Self::AttachDatabase(_) | Self::CancelOperations { .. } => false,
39 Self::GetOperations { response, .. } => response.is_closed(),
40 }
41 }
42}
43
44pub(super) struct Pending<DB, F: Family, Op, D: Digest> {
45 database: Option<Arc<AsyncRwLock<DB>>>,
46 messages: VecDeque<Message<DB, F, Op, D>>,
47}
48
49impl<DB, F: Family, Op, D: Digest> Default for Pending<DB, F, Op, D> {
50 fn default() -> Self {
51 Self {
52 database: None,
53 messages: VecDeque::new(),
54 }
55 }
56}
57
58impl<DB, F: Family, Op, D: Digest> Overflow<Message<DB, F, Op, D>> for Pending<DB, F, Op, D> {
59 fn is_empty(&self) -> bool {
60 self.database.is_none() && self.messages.is_empty()
61 }
62
63 fn drain<P>(&mut self, mut push: P)
64 where
65 P: FnMut(Message<DB, F, Op, D>) -> Option<Message<DB, F, Op, D>>,
66 {
67 if let Some(database) = self.database.take() {
68 if let Some(Message::AttachDatabase(database)) = push(Message::AttachDatabase(database))
69 {
70 self.database = Some(database);
71 return;
72 }
73 }
74
75 while let Some(message) = self.messages.pop_front() {
76 if message.response_closed() {
77 continue;
78 }
79
80 if let Some(message) = push(message) {
81 self.messages.push_front(message);
82 break;
83 }
84 }
85 }
86}
87
88impl<DB, F: Family, Op, D: Digest> Policy for Message<DB, F, Op, D> {
89 type Overflow = Pending<DB, F, Op, D>;
90
91 fn handle(overflow: &mut Self::Overflow, message: Self) {
92 if message.response_closed() {
93 return;
94 }
95
96 match message {
97 Self::AttachDatabase(database) => {
98 overflow.database = Some(database);
99 }
100 message => overflow.messages.push_back(message),
101 }
102 }
103}
104
105pub struct Mailbox<DB, F: Family, Op, D: Digest> {
107 sender: Sender<Message<DB, F, Op, D>>,
108}
109
110impl<DB, F: Family, Op, D: Digest> Clone for Mailbox<DB, F, Op, D> {
111 fn clone(&self) -> Self {
112 Self {
113 sender: self.sender.clone(),
114 }
115 }
116}
117
118impl<DB, F: Family, Op, D: Digest> Mailbox<DB, F, Op, D> {
119 pub(super) const fn new(sender: Sender<Message<DB, F, Op, D>>) -> Self {
120 Self { sender }
121 }
122}
123
124impl<DB: Send + Sync, F: Family, Op: Send, D: Digest> Mailbox<DB, F, Op, D> {
125 pub fn attach_database(&self, db: Arc<AsyncRwLock<DB>>) {
126 let _ = self.sender.enqueue(Message::AttachDatabase(db));
127 }
128}
129
130impl<DB, F, Op, D> SyncResolver for Mailbox<DB, F, Op, D>
131where
132 F: Family,
133 Op: Read<Cfg = ()> + Send + Sync + Clone + 'static,
134 D: Digest,
135 DB: Send + Sync + 'static,
136{
137 type Family = F;
138 type Digest = D;
139 type Op = Op;
140 type Error = ResponseDropped;
141
142 async fn get_operations(
143 &self,
144 op_count: Location<F>,
145 start_loc: Location<F>,
146 max_ops: NonZeroU64,
147 include_pinned_nodes: bool,
148 cancel_rx: oneshot::Receiver<()>,
149 ) -> Result<FetchResult<Self::Family, Self::Op, Self::Digest>, Self::Error> {
150 let request = handler::Request {
151 op_count,
152 start_loc,
153 max_ops,
154 include_pinned_nodes,
155 };
156
157 futures::pin_mut!(cancel_rx);
158 let (response_tx, response_rx) = oneshot::channel();
159 let _ = self.sender.enqueue(Message::GetOperations {
160 request: request.clone(),
161 response: response_tx,
162 });
163 futures::pin_mut!(response_rx);
164
165 select! {
166 response = response_rx.as_mut() => response.map_err(|_| ResponseDropped)?,
167 _ = cancel_rx.as_mut() => {
168 if let Some(response) = response_rx.as_mut().now_or_never() {
169 return response.map_err(|_| ResponseDropped)?;
170 }
171 let _ = self.sender.enqueue(Message::CancelOperations { request });
172 Err(ResponseDropped)
173 },
174 }
175 }
176}
177
178impl<DB, F, Op, D> AttachableResolver<DB> for Mailbox<DB, F, Op, D>
179where
180 F: Family,
181 Op: Read<Cfg = ()> + Send + Sync + Clone + 'static,
182 D: Digest,
183 DB: Send + Sync + 'static,
184{
185 fn attach_database(&self, db: Arc<AsyncRwLock<DB>>) -> impl Future<Output = ()> + Send {
186 Self::attach_database(self, db);
187 std::future::ready(())
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::sha256;
    use commonware_runtime::{deterministic, Runner as _};
    use commonware_storage::mmr;
    use commonware_utils::{NZUsize, NZU64};

    /// Cancelling an in-flight `get_operations` must:
    /// - enqueue a `CancelOperations` after the original `GetOperations`;
    /// - give the cancellation the same request parameters;
    /// - make the call resolve to `ResponseDropped`.
    #[test]
    fn get_operations_cancellation_sends_cancel_message() {
        deterministic::Runner::default().start(|context| async move {
            let (sender, mut receiver) = commonware_actor::mailbox::new(context, NZUsize!(4));
            let mailbox = Mailbox::<(), mmr::Family, u64, sha256::Digest>::new(sender);

            let op_count = mmr::Location::new(10);
            let start_loc = mmr::Location::new(3);
            let max_ops = NZU64!(2);
            // (op_count, start_loc, max_ops, include_pinned_nodes) both messages must carry.
            let expected = (op_count, start_loc, max_ops, false);

            let (cancel_tx, cancel_rx) = oneshot::channel();
            let get = mailbox.get_operations(op_count, start_loc, max_ops, false, cancel_rx);

            let observe = async move {
                // Step 1: the fetch request reaches the actor. Hold on to its
                // response sender so the fetch cannot complete on its own.
                let first = receiver.recv().await.expect("request should be queued");
                let response = match first {
                    Message::GetOperations { request, response } => {
                        let seen = (
                            request.op_count,
                            request.start_loc,
                            request.max_ops,
                            request.include_pinned_nodes,
                        );
                        assert_eq!(seen, expected);
                        response
                    }
                    Message::AttachDatabase(_) => panic!("unexpected attach message"),
                    Message::CancelOperations { .. } => panic!("cancel should come after request"),
                };

                // Step 2: trigger cancellation by dropping the cancel sender.
                drop(cancel_tx);

                // Step 3: a matching cancellation must follow.
                let second = receiver.recv().await.expect("cancel should be queued");
                match second {
                    Message::CancelOperations { request } => {
                        let seen = (
                            request.op_count,
                            request.start_loc,
                            request.max_ops,
                            request.include_pinned_nodes,
                        );
                        assert_eq!(seen, expected);
                    }
                    Message::AttachDatabase(_) => panic!("unexpected attach message"),
                    Message::GetOperations { .. } => panic!("unexpected duplicate request"),
                }

                drop(response);
            };

            let (result, _) = futures::join!(get, observe);
            assert!(matches!(result, Err(ResponseDropped)));
        });
    }
}