1use std::{fmt, marker::PhantomData, sync::Arc};
2
3use idr_ebr::EbrGuard;
4use parking_lot::Mutex;
5use slotmap::{new_key_type, Key, SlotMap};
6use smallvec::SmallVec;
7use tokio::sync::Notify;
8
9use crate::{
10 address_book::AddressBook, envelope::Envelope, errors::RequestError, message::AnyMessage,
11 tracing::TraceId, Addr,
12};
13
14new_key_type! {
17 pub struct RequestId;
18}
19
20impl RequestId {
21 #[doc(hidden)]
22 #[inline]
23 pub fn to_ffi(self) -> u64 {
24 self.data().as_ffi()
25 }
26
27 #[doc(hidden)]
28 #[inline]
29 pub fn from_ffi(id: u64) -> Self {
30 slotmap::KeyData::from_ffi(id).into()
31 }
32
33 #[doc(hidden)]
34 #[inline]
35 pub fn is_null(&self) -> bool {
36 Key::is_null(self)
37 }
38}
39
40pub(crate) struct RequestTable {
43 owner: Addr,
44 notifier: Notify,
45 requests: Mutex<SlotMap<RequestId, RequestData>>,
46}
47
48assert_impl_all!(RequestTable: Sync);
49
50type Responses = SmallVec<[Result<Envelope, RequestError>; 1]>;
51
52#[derive(Default)]
53struct RequestData {
54 remainder: usize,
55 responses: Responses,
56 collect_all: bool,
57}
58
59impl RequestData {
60 fn push(&mut self, response: Result<Envelope, RequestError>) -> bool {
62 if self.remainder == 0 {
64 debug_assert!(!self.collect_all);
66 return false;
67 }
68
69 self.remainder -= 1;
70
71 if self.collect_all {
72 self.responses.push(response);
73 return self.remainder == 0;
74 }
75
76 debug_assert!(self.responses.len() <= 1);
78
79 let is_ok = response.is_ok();
80
81 if self.responses.is_empty() {
82 self.responses.push(response);
83 }
84 else if response.is_ok() {
86 debug_assert!(self.responses[0].is_err());
87 self.responses[0] = response;
88 } else if let Err(RequestError::Ignored) = response {
89 debug_assert!(self.responses[0].is_err());
90 self.responses[0] = response;
91 }
92
93 if is_ok {
95 self.remainder = 0;
96 }
97
98 self.remainder == 0
99 }
100}
101
102impl RequestTable {
103 pub(crate) fn new(owner: Addr) -> Self {
104 Self {
105 owner,
106 notifier: Notify::new(),
107 requests: Mutex::new(SlotMap::default()),
108 }
109 }
110
111 pub(crate) fn new_request(
112 &self,
113 book: AddressBook,
114 trace_id: TraceId,
115 collect_all: bool,
116 ) -> ResponseToken {
117 let mut requests = self.requests.lock();
118 let request_id = requests.insert(RequestData {
119 remainder: 1,
120 responses: Responses::new(),
121 collect_all,
122 });
123 ResponseToken::new(self.owner, request_id, trace_id, book)
124 }
125
126 pub(crate) fn cancel_request(&self, request_id: RequestId) {
127 let mut requests = self.requests.lock();
128 requests.remove(request_id);
129 }
130
131 pub(crate) async fn wait(&self, request_id: RequestId) -> Responses {
132 loop {
133 let waiting = self.notifier.notified();
134
135 {
136 let mut requests = self.requests.lock();
137 let request = requests.get(request_id).expect("unknown request");
138
139 if request.remainder == 0 {
140 break requests.remove(request_id).expect("under lock").responses;
141 }
142 }
143
144 waiting.await;
145 }
146 }
147
148 pub(crate) fn resolve(
149 &self,
150 mut token: ResponseToken,
151 response: Result<Envelope, RequestError>,
152 ) {
153 let data = ward!(token.data.take());
155 let mut requests = self.requests.lock();
156
157 let request = ward!(requests.get_mut(data.request_id));
160
161 if request.push(response) {
162 self.notifier.notify_waiters();
165 }
166 }
167}
168
169#[must_use]
172pub struct ResponseToken<T = AnyMessage> {
173 data: Option<Arc<ResponseTokenData>>,
175 received: bool,
176 marker: PhantomData<T>,
177}
178
179struct ResponseTokenData {
180 sender: Addr,
181 request_id: RequestId,
182 trace_id: TraceId,
183 book: AddressBook,
184}
185
186impl ResponseToken {
187 #[doc(hidden)]
188 #[inline]
189 pub fn new(sender: Addr, request_id: RequestId, trace_id: TraceId, book: AddressBook) -> Self {
190 debug_assert!(!sender.is_null());
191 debug_assert!(!request_id.is_null());
192
193 Self {
194 data: Some(Arc::new(ResponseTokenData {
195 sender,
196 request_id,
197 trace_id,
198 book,
199 })),
200 received: false,
201 marker: PhantomData,
202 }
203 }
204
205 #[doc(hidden)]
208 #[inline]
209 pub fn trace_id(&self) -> TraceId {
210 self.data.as_ref().map(|data| data.trace_id).unwrap()
211 }
212
213 #[doc(hidden)]
216 #[inline]
217 pub fn sender(&self) -> Addr {
218 self.data.as_ref().map(|data| data.sender).unwrap()
219 }
220
221 #[doc(hidden)]
224 #[inline]
225 pub fn request_id(&self) -> RequestId {
226 self.data.as_ref().map(|data| data.request_id).unwrap()
227 }
228
229 #[doc(hidden)]
232 #[inline]
233 pub fn is_last(&self) -> bool {
234 self.data.as_ref().map(Arc::strong_count).unwrap() <= 1
235 }
236
237 #[doc(hidden)]
238 #[inline]
239 pub fn into_received<T>(mut self) -> ResponseToken<T> {
240 ResponseToken {
241 data: self.data.take(),
242 received: true,
243 marker: PhantomData,
244 }
245 }
246
247 #[doc(hidden)]
248 #[inline]
249 pub fn duplicate(&self) -> Self {
250 Self {
251 data: self.do_duplicate(),
252 received: self.received,
253 marker: PhantomData,
254 }
255 }
256
257 #[doc(hidden)]
258 #[inline]
259 pub fn forget(mut self) {
260 self.data = None;
261 }
262
263 fn do_duplicate(&self) -> Option<Arc<ResponseTokenData>> {
264 let data = self.data.as_ref()?;
265
266 if data.sender.is_local() {
267 let guard = EbrGuard::new();
268 let object = data.book.get(data.sender, &guard)?;
269 let actor = object.as_actor()?;
270 let mut requests = actor.request_table().requests.lock();
271 requests.get_mut(data.request_id)?.remainder += 1;
272 }
273
274 Some(data.clone())
275 }
276}
277
278impl<R> ResponseToken<R> {
279 #[doc(hidden)]
280 #[inline]
281 pub fn forgotten() -> Self {
282 Self {
283 data: None,
284 received: false,
285 marker: PhantomData,
286 }
287 }
288
289 pub(crate) fn into_untyped(mut self) -> ResponseToken {
290 ResponseToken {
291 data: self.data.take(),
292 received: self.received,
293 marker: PhantomData,
294 }
295 }
296
297 #[doc(hidden)]
298 #[inline]
299 pub fn is_forgotten(&self) -> bool {
300 self.data.is_none()
301 }
302}
303
304impl<T> Drop for ResponseToken<T> {
305 #[inline]
306 fn drop(&mut self) {
307 let data = ward!(self.data.take());
309 let book = data.book.clone();
310 let guard = EbrGuard::new();
311 let object = ward!(book.get(data.sender, &guard));
312 let this = ResponseToken {
313 data: Some(data),
314 received: self.received,
315 marker: PhantomData,
316 };
317 let err = if self.received {
318 RequestError::Ignored
319 } else {
320 RequestError::Failed
321 };
322
323 object.respond(this, Err(err));
324 }
325}
326
327impl<T> fmt::Debug for ResponseToken<T> {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.debug_struct("ResponseToken").finish()
330 }
331}
332
// NOTE(review): this module is compiled out by `cfg(TODO)`. Its calls do not
// match the signatures defined above: `new_request` is called without a trace
// id, `resolve` is called with a request id instead of a token, `clone_token`
// is not defined in this file, and `request_id` is read as a field. Port the
// tests to the current API before removing the gate.
#[cfg(test)]
#[cfg(TODO)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use crate::{actor::ActorMeta, assert_msg_eq, envelope::MessageKind, message, scope::Scope};

    #[message]
    #[derive(PartialEq)]
    struct Num(u32);

    /// Builds a response envelope carrying `num`, created inside a test scope.
    fn envelope(addr: Addr, request_id: RequestId, num: Num) -> Envelope {
        let meta = Arc::new(ActorMeta {
            group: "test".into(),
            key: String::new(),
        });

        Scope::test(addr, meta).sync_within(|| {
            let kind = MessageKind::Response {
                sender: addr,
                request_id,
            };

            Envelope::new(num, kind).upcast()
        })
    }

    #[tokio::test]
    async fn one_request_one_response() {
        let addr = Addr::from_bits(1);
        let table = Arc::new(RequestTable::new(addr));
        let book = AddressBook::new();

        for _ in 0..3 {
            let token = table.new_request(book.clone(), true);
            let request_id = token.request_id();

            let resolver = table.clone();
            tokio::spawn(async move {
                resolver.resolve(token, Ok(envelope(addr, request_id, Num(42))));
            });

            let mut responses = table.wait(request_id).await;

            assert_eq!(responses.len(), 1);
            assert_msg_eq!(responses.pop().unwrap().unwrap(), Num(42));
        }
    }

    /// Shared driver: one request answered by `n` responders, either with
    /// `Ok(Num(i))` or, when `ignore` is set, with `Err(RequestError::Ignored)`.
    async fn one_request_many_response(collect_all: bool, ignore: bool) {
        let addr = Addr::from_bits(1);
        let table = Arc::new(RequestTable::new(addr));
        let token = table.new_request(AddressBook::new(), collect_all);
        let request_id = token.request_id();

        let n = 5;
        for i in 1..n {
            let resolver = table.clone();
            let token = table.clone_token(&token).unwrap();
            assert_eq!(token.request_id, request_id);
            tokio::spawn(async move {
                let response = if ignore {
                    Err(RequestError::Ignored)
                } else {
                    Ok(envelope(addr, request_id, Num(i)))
                };

                resolver.resolve(request_id, response);
            });
        }

        let response = if ignore {
            Err(RequestError::Ignored)
        } else {
            Ok(envelope(addr, request_id, Num(0)))
        };
        table.resolve(request_id, response);

        let mut responses = table.wait(request_id).await;

        let expected_len = match (ignore, collect_all) {
            (true, _) => 0,
            (false, true) => n as usize,
            (false, false) => 1,
        };
        assert_eq!(responses.len(), expected_len);

        for (i, response) in responses.drain(..).enumerate() {
            if ignore {
                assert!(response.is_err());
            } else {
                assert_msg_eq!(response.unwrap(), Num(i as u32));
            }
        }
    }

    #[tokio::test]
    async fn one_request_many_response_all() {
        one_request_many_response(true, false).await;
    }

    // NOTE(review): same arguments as `one_request_many_response_any_ignored`;
    // presumably `(true, true)` was intended here. Confirm before changing.
    #[tokio::test]
    async fn one_request_many_response_all_ignored() {
        one_request_many_response(false, true).await;
    }

    #[tokio::test]
    async fn one_request_many_response_any() {
        one_request_many_response(false, false).await;
    }

    #[tokio::test]
    async fn one_request_many_response_any_ignored() {
        one_request_many_response(false, true).await;
    }

    /// Resolving after the waiter has already taken the result must not panic.
    #[tokio::test]
    async fn late_resolve() {
        let addr = Addr::from_bits(1);
        let table = Arc::new(RequestTable::new(addr));
        let book = AddressBook::new();

        let token = table.new_request(book.clone(), false);
        let _token1 = table.clone_token(&token).unwrap();
        let request_id = token.request_id;

        let resolver = table.clone();
        tokio::spawn(async move {
            resolver.resolve(request_id, Ok(envelope(addr, request_id, Num(42))));
        });

        let _responses = table.wait(request_id).await;
        table.resolve(request_id, Ok(envelope(addr, request_id, Num(43))));
    }
}