Skip to main content

elfo_core/
request_table.rs

1use std::{fmt, marker::PhantomData, sync::Arc};
2
3use idr_ebr::EbrGuard;
4use parking_lot::Mutex;
5use slotmap::{new_key_type, Key, SlotMap};
6use smallvec::SmallVec;
7use tokio::sync::Notify;
8
9use crate::{
10    address_book::AddressBook, envelope::Envelope, errors::RequestError, message::AnyMessage,
11    tracing::TraceId, Addr,
12};
13
14// === RequestId ===
15
new_key_type! {
    /// Key of a request stored in a `RequestTable`.
    ///
    /// Can be converted to and from a raw `u64` via `to_ffi` / `from_ffi`.
    pub struct RequestId;
}
19
20impl RequestId {
21    #[doc(hidden)]
22    #[inline]
23    pub fn to_ffi(self) -> u64 {
24        self.data().as_ffi()
25    }
26
27    #[doc(hidden)]
28    #[inline]
29    pub fn from_ffi(id: u64) -> Self {
30        slotmap::KeyData::from_ffi(id).into()
31    }
32
33    #[doc(hidden)]
34    #[inline]
35    pub fn is_null(&self) -> bool {
36        Key::is_null(self)
37    }
38}
39
40// === RequestTable ===
41
/// Table of pending requests made by the actor at `owner`, together with the responses
/// collected for them so far.
pub(crate) struct RequestTable {
    /// Sender address put into every token issued by `new_request`.
    owner: Addr,
    /// Notified whenever some request becomes done; `wait` listens to it.
    notifier: Notify,
    /// Pending requests, keyed by the id stored in their response tokens.
    requests: Mutex<SlotMap<RequestId, RequestData>>,
}

// Compile-time check that the table can be shared between threads.
assert_impl_all!(RequestTable: Sync);
49
/// Responses collected for one request. The inline capacity of 1 covers `any` requests,
/// which keep at most one response.
type Responses = SmallVec<[Result<Envelope, RequestError>; 1]>;

/// State of a single pending request.
#[derive(Default)]
struct RequestData {
    /// Number of responses still expected. Increased when a local token is duplicated,
    /// and zeroed as soon as an `any` request receives `Ok`.
    remainder: usize,
    /// Responses collected so far.
    responses: Responses,
    /// `true` to keep every response (`all`), `false` to keep only the best one (`any`).
    collect_all: bool,
}
58
59impl RequestData {
60    /// Returns `true` if the request is done.
61    fn push(&mut self, response: Result<Envelope, RequestError>) -> bool {
62        // Extra responses (in `any` case).
63        if self.remainder == 0 {
64            // TODO: move to `ResponseToken` to avoid sending extra responses over network.
65            debug_assert!(!self.collect_all);
66            return false;
67        }
68
69        self.remainder -= 1;
70
71        if self.collect_all {
72            self.responses.push(response);
73            return self.remainder == 0;
74        }
75
76        // `Any` request contains at most one related response.
77        debug_assert!(self.responses.len() <= 1);
78
79        let is_ok = response.is_ok();
80
81        if self.responses.is_empty() {
82            self.responses.push(response);
83        }
84        // Priority: `Ok(_)` > `Err(Ignored)` > `Err(Failed)`
85        else if response.is_ok() {
86            debug_assert!(self.responses[0].is_err());
87            self.responses[0] = response;
88        } else if let Err(RequestError::Ignored) = response {
89            debug_assert!(self.responses[0].is_err());
90            self.responses[0] = response;
91        }
92
93        // Received `Ok`, so prevent further responses.
94        if is_ok {
95            self.remainder = 0;
96        }
97
98        self.remainder == 0
99    }
100}
101
102impl RequestTable {
103    pub(crate) fn new(owner: Addr) -> Self {
104        Self {
105            owner,
106            notifier: Notify::new(),
107            requests: Mutex::new(SlotMap::default()),
108        }
109    }
110
111    pub(crate) fn new_request(
112        &self,
113        book: AddressBook,
114        trace_id: TraceId,
115        collect_all: bool,
116    ) -> ResponseToken {
117        let mut requests = self.requests.lock();
118        let request_id = requests.insert(RequestData {
119            remainder: 1,
120            responses: Responses::new(),
121            collect_all,
122        });
123        ResponseToken::new(self.owner, request_id, trace_id, book)
124    }
125
126    pub(crate) fn cancel_request(&self, request_id: RequestId) {
127        let mut requests = self.requests.lock();
128        requests.remove(request_id);
129    }
130
131    pub(crate) async fn wait(&self, request_id: RequestId) -> Responses {
132        loop {
133            let waiting = self.notifier.notified();
134
135            {
136                let mut requests = self.requests.lock();
137                let request = requests.get(request_id).expect("unknown request");
138
139                if request.remainder == 0 {
140                    break requests.remove(request_id).expect("under lock").responses;
141                }
142            }
143
144            waiting.await;
145        }
146    }
147
148    pub(crate) fn resolve(
149        &self,
150        mut token: ResponseToken,
151        response: Result<Envelope, RequestError>,
152    ) {
153        // Do nothing for forgotten tokens.
154        let data = ward!(token.data.take());
155        let mut requests = self.requests.lock();
156
157        // `None` here means the request was with `collect_all = false` and
158        // the response has been recieved already.
159        let request = ward!(requests.get_mut(data.request_id));
160
161        if request.push(response) {
162            // Actors can perform multiple requests in parallel using different
163            // wakers, so we should wake all possible wakers up.
164            self.notifier.notify_waiters();
165        }
166    }
167}
168
169// === ResponseToken ===
170
/// A token for responding to a request.
///
/// Dropping a token that is neither forgotten nor consumed sends an error response to the
/// requesting actor (see the `Drop` impl).
#[must_use]
pub struct ResponseToken<T = AnyMessage> {
    /// `None` if forgotten, or once the data has been moved out (by `resolve` or by a
    /// conversion such as `into_received` / `into_untyped`).
    data: Option<Arc<ResponseTokenData>>,
    /// Set by `into_received`; makes a dropped token report `RequestError::Ignored`
    /// instead of `RequestError::Failed`.
    received: bool,
    /// Carries the response type `T` without storing a value of it.
    marker: PhantomData<T>,
}
178
/// Request data shared (through an `Arc`) by all duplicates of the same token.
struct ResponseTokenData {
    /// Address of the requesting actor; a dropped token reports its error there.
    sender: Addr,
    /// Id of the request in the sender's `RequestTable`.
    request_id: RequestId,
    trace_id: TraceId,
    /// Used to look up the sender by its address.
    book: AddressBook,
}
185
186impl ResponseToken {
187    #[doc(hidden)]
188    #[inline]
189    pub fn new(sender: Addr, request_id: RequestId, trace_id: TraceId, book: AddressBook) -> Self {
190        debug_assert!(!sender.is_null());
191        debug_assert!(!request_id.is_null());
192
193        Self {
194            data: Some(Arc::new(ResponseTokenData {
195                sender,
196                request_id,
197                trace_id,
198                book,
199            })),
200            received: false,
201            marker: PhantomData,
202        }
203    }
204
205    /// # Panics
206    /// If the token is forgotten.
207    #[doc(hidden)]
208    #[inline]
209    pub fn trace_id(&self) -> TraceId {
210        self.data.as_ref().map(|data| data.trace_id).unwrap()
211    }
212
213    /// # Panics
214    /// If the token is forgotten.
215    #[doc(hidden)]
216    #[inline]
217    pub fn sender(&self) -> Addr {
218        self.data.as_ref().map(|data| data.sender).unwrap()
219    }
220
221    /// # Panics
222    /// If the token is forgotten.
223    #[doc(hidden)]
224    #[inline]
225    pub fn request_id(&self) -> RequestId {
226        self.data.as_ref().map(|data| data.request_id).unwrap()
227    }
228
229    /// # Panics
230    /// If the token is forgotten.
231    #[doc(hidden)]
232    #[inline]
233    pub fn is_last(&self) -> bool {
234        self.data.as_ref().map(Arc::strong_count).unwrap() <= 1
235    }
236
237    #[doc(hidden)]
238    #[inline]
239    pub fn into_received<T>(mut self) -> ResponseToken<T> {
240        ResponseToken {
241            data: self.data.take(),
242            received: true,
243            marker: PhantomData,
244        }
245    }
246
247    #[doc(hidden)]
248    #[inline]
249    pub fn duplicate(&self) -> Self {
250        Self {
251            data: self.do_duplicate(),
252            received: self.received,
253            marker: PhantomData,
254        }
255    }
256
257    #[doc(hidden)]
258    #[inline]
259    pub fn forget(mut self) {
260        self.data = None;
261    }
262
263    fn do_duplicate(&self) -> Option<Arc<ResponseTokenData>> {
264        let data = self.data.as_ref()?;
265
266        if data.sender.is_local() {
267            let guard = EbrGuard::new();
268            let object = data.book.get(data.sender, &guard)?;
269            let actor = object.as_actor()?;
270            let mut requests = actor.request_table().requests.lock();
271            requests.get_mut(data.request_id)?.remainder += 1;
272        }
273
274        Some(data.clone())
275    }
276}
277
278impl<R> ResponseToken<R> {
279    #[doc(hidden)]
280    #[inline]
281    pub fn forgotten() -> Self {
282        Self {
283            data: None,
284            received: false,
285            marker: PhantomData,
286        }
287    }
288
289    pub(crate) fn into_untyped(mut self) -> ResponseToken {
290        ResponseToken {
291            data: self.data.take(),
292            received: self.received,
293            marker: PhantomData,
294        }
295    }
296
297    #[doc(hidden)]
298    #[inline]
299    pub fn is_forgotten(&self) -> bool {
300        self.data.is_none()
301    }
302}
303
304impl<T> Drop for ResponseToken<T> {
305    #[inline]
306    fn drop(&mut self) {
307        // Do nothing for forgotten tokens.
308        let data = ward!(self.data.take());
309        let book = data.book.clone();
310        let guard = EbrGuard::new();
311        let object = ward!(book.get(data.sender, &guard));
312        let this = ResponseToken {
313            data: Some(data),
314            received: self.received,
315            marker: PhantomData,
316        };
317        let err = if self.received {
318            RequestError::Ignored
319        } else {
320            RequestError::Failed
321        };
322
323        object.respond(this, Err(err));
324    }
325}
326
327impl<T> fmt::Debug for ResponseToken<T> {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        f.debug_struct("ResponseToken").finish()
330    }
331}
332
// NOTE(review): this module is compiled out by `#[cfg(TODO)]` and no longer matches the
// code above: `RequestTable::new_request` now also takes a `TraceId`, `RequestTable::resolve`
// takes a `ResponseToken` instead of a `RequestId`, `ResponseToken::request_id` is a method
// rather than a field, and `RequestTable::clone_token` is not defined in this file. It has
// to be ported before it can be re-enabled.
#[cfg(test)]
#[cfg(TODO)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use crate::{actor::ActorMeta, assert_msg_eq, envelope::MessageKind, message, scope::Scope};

    #[message]
    #[derive(PartialEq)]
    struct Num(u32);

    /// Builds a `Response` envelope carrying `num` for `request_id`.
    fn envelope(addr: Addr, request_id: RequestId, num: Num) -> Envelope {
        Scope::test(
            addr,
            Arc::new(ActorMeta {
                group: "test".into(),
                key: String::new(),
            }),
        )
        .sync_within(|| {
            Envelope::new(
                num,
                MessageKind::Response {
                    sender: addr,
                    request_id,
                },
            )
            .upcast()
        })
    }

    #[tokio::test]
    async fn one_request_one_response() {
        let addr = Addr::from_bits(1);
        let table = Arc::new(RequestTable::new(addr));
        let book = AddressBook::new();

        for _ in 0..3 {
            let token = table.new_request(book.clone(), true);
            let request_id = token.request_id();

            let table1 = table.clone();
            tokio::spawn(async move {
                table1.resolve(token, Ok(envelope(addr, request_id, Num(42))));
            });

            let mut data = table.wait(request_id).await;

            assert_eq!(data.len(), 1);
            assert_msg_eq!(data.pop().unwrap().unwrap(), Num(42));
        }
    }

    /// Sends `n` responses (all `Ok`, or all `Err(Ignored)` if `ignore`) to one request and
    /// checks what `wait` returns.
    async fn one_request_many_response(collect_all: bool, ignore: bool) {
        let addr = Addr::from_bits(1);
        let table = Arc::new(RequestTable::new(addr));
        let token = table.new_request(AddressBook::new(), collect_all);
        let request_id = token.request_id();

        let n = 5;
        for i in 1..n {
            let table1 = table.clone();
            let token = table.clone_token(&token).unwrap();
            assert_eq!(token.request_id, request_id);
            tokio::spawn(async move {
                if !ignore {
                    table1.resolve(request_id, Ok(envelope(addr, request_id, Num(i))));
                } else {
                    // TODO: test a real `Drop`.
                    table1.resolve(request_id, Err(RequestError::Ignored));
                }
            });
        }

        if !ignore {
            table.resolve(request_id, Ok(envelope(addr, request_id, Num(0))));
        } else {
            // TODO: test a real `Drop`.
            table.resolve(request_id, Err(RequestError::Ignored));
        }

        let mut data = table.wait(request_id).await;

        // `RequestData::push` keeps error responses as well: all of them for `all`
        // requests, the single highest-priority one for `any` requests.
        let expected_len = if collect_all { n as usize } else { 1 };
        assert_eq!(data.len(), expected_len);

        for (i, response) in data.drain(..).enumerate() {
            if ignore {
                assert!(response.is_err());
            } else {
                assert_msg_eq!(response.unwrap(), Num(i as u32));
            }
        }
    }

    #[tokio::test]
    async fn one_request_many_response_all() {
        one_request_many_response(true, false).await;
    }

    #[tokio::test]
    async fn one_request_many_response_all_ignored() {
        one_request_many_response(true, true).await;
    }

    #[tokio::test]
    async fn one_request_many_response_any() {
        one_request_many_response(false, false).await;
    }

    #[tokio::test]
    async fn one_request_many_response_any_ignored() {
        one_request_many_response(false, true).await;
    }

    // TODO: check many requests.
    // TODO: check `Drop`.

    #[tokio::test]
    async fn late_resolve() {
        let addr = Addr::from_bits(1);
        let table = Arc::new(RequestTable::new(addr));
        let book = AddressBook::new();

        let token = table.new_request(book.clone(), false);
        let _token1 = table.clone_token(&token).unwrap();
        let request_id = token.request_id;

        let table1 = table.clone();
        tokio::spawn(async move {
            table1.resolve(request_id, Ok(envelope(addr, request_id, Num(42))));
        });

        let _data = table.wait(request_id).await;
        table.resolve(request_id, Ok(envelope(addr, request_id, Num(43))));
    }
}