mm1_ask/
ask.rs

1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::impl_error_kind;
9use mm1_common::log::warn;
10use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
11use mm1_core::envelope::{Envelope, EnvelopeHeader};
12use mm1_proto::Message;
13use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
14
/// Process-wide counter used to stamp each outgoing request with a correlation id.
///
/// It starts at `1` and is advanced once for every `ask` call.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
18pub enum AskErrorKind {
19    Send(SendErrorKind),
20    Recv(RecvErrorKind),
21    Fork(ForkErrorKind),
22    Timeout,
23    Cast,
24}
25
26pub trait Ask: Messaging + Sized {
27    fn ask<Rq, Rs>(
28        &mut self,
29        server: Address,
30        request: Rq,
31        timeout: Duration,
32    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
33    where
34        Rq: Send,
35        Request<Rq>: Message,
36        Rs: Message;
37
38    fn fork_ask<Rq, Rs>(
39        &mut self,
40        server: Address,
41        request: Rq,
42        timeout: Duration,
43    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
44    where
45        Self: Fork,
46        Rq: Send,
47        Request<Rq>: Message,
48        Rs: Message;
49}
50
51pub trait Reply: Messaging + Send {
52    fn reply<Rs>(
53        &mut self,
54        to: RequestHeader,
55        response: Rs,
56    ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
57    where
58        Rs: Send,
59        Response<Rs>: Message;
60}
61
62impl<Ctx> Ask for Ctx
63where
64    Ctx: Messaging + Sized + Send,
65{
66    async fn ask<Rq, Rs>(
67        &mut self,
68        server: Address,
69        request: Rq,
70        timeout: Duration,
71    ) -> Result<Rs, ErrorOf<AskErrorKind>>
72    where
73        Request<Rq>: Message,
74        Response<Rs>: Message,
75    {
76        let reply_to = self.address();
77        let request_header = RequestHeader {
78            id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
79            reply_to,
80        };
81        let request_message = Request {
82            header:  request_header,
83            payload: request,
84        };
85        let request_header = EnvelopeHeader::to_address(server);
86        let request_envelope = Envelope::new(request_header, request_message);
87        let () = self
88            .send(request_envelope.into_erased())
89            .await
90            .map_err(into_ask_error)?;
91        let response_envelope: Envelope<Response<Rs>> = self
92            .recv()
93            .timeout(timeout)
94            .await
95            .map_err(|_elapsed| {
96                ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
97            })?
98            .map_err(into_ask_error)?
99            .cast()
100            .map_err(|envelope| {
101                warn!(
102                    "invalid cast [expected: {}; actual: {}]",
103                    std::any::type_name::<Response<Rs>>(),
104                    envelope.message_name()
105                );
106                ErrorOf::new(AskErrorKind::Cast, "unexpected response type")
107            })?;
108        let (response_message, _empty_envelope) = response_envelope.take();
109        let Response {
110            header: _,
111            payload: response,
112        } = response_message;
113
114        Ok(response)
115    }
116
117    async fn fork_ask<Rq, Rs>(
118        &mut self,
119        server: Address,
120        request: Rq,
121        timeout: Duration,
122    ) -> Result<Rs, ErrorOf<AskErrorKind>>
123    where
124        Self: Fork,
125        Rq: Send,
126        Request<Rq>: Message,
127        Rs: Message,
128    {
129        self.fork()
130            .await
131            .map_err(into_ask_error)?
132            .ask(server, request, timeout)
133            .await
134    }
135}
136
137impl<Ctx> Reply for Ctx
138where
139    Ctx: Messaging + Send,
140{
141    async fn reply<Rs>(
142        &mut self,
143        to: RequestHeader,
144        response: Rs,
145    ) -> Result<(), ErrorOf<SendErrorKind>>
146    where
147        Response<Rs>: Message,
148    {
149        let RequestHeader { id, reply_to } = to;
150        let response_header = ResponseHeader { id };
151        let response_message = Response {
152            header:  response_header,
153            payload: response,
154        };
155        let response_envelope_header = EnvelopeHeader::to_address(reply_to);
156        let response_envelope = Envelope::new(response_envelope_header, response_message);
157        self.send(response_envelope.into_erased()).await?;
158
159        Ok(())
160    }
161}
162
163impl_error_kind!(AskErrorKind);
164
165fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
166where
167    K: ErrorKind + Into<AskErrorKind>,
168{
169    e.map_kind(Into::into)
170}