mm1_ask/
ask.rs

1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::impl_error_kind;
9use mm1_common::log::warn;
10use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
11use mm1_core::envelope::{Envelope, EnvelopeHeader};
12use mm1_proto::Message;
13use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
14
15static REQUEST_ID: AtomicU64 = AtomicU64::new(1);
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
18pub enum AskErrorKind {
19    Send(SendErrorKind),
20    Recv(RecvErrorKind),
21    Fork(ForkErrorKind),
22    Timeout,
23    Cast,
24}
25
26pub trait Ask: Messaging + Sized {
27    fn ask<Rq, Rs>(
28        &mut self,
29        server: Address,
30        request: Rq,
31        timeout: Duration,
32    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
33    where
34        Self: Fork,
35        Rq: Send,
36        Request<Rq>: Message,
37        Rs: Message;
38
39    #[doc(hidden)]
40    fn ask_nofork<Rq, Rs>(
41        &mut self,
42        server: Address,
43        request: Rq,
44        timeout: Duration,
45    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
46    where
47        Rq: Send,
48        Request<Rq>: Message,
49        Rs: Message;
50}
51
52pub trait Reply: Messaging + Send {
53    fn reply<Rs>(
54        &mut self,
55        to: RequestHeader,
56        response: Rs,
57    ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
58    where
59        Rs: Send,
60        Response<Rs>: Message;
61}
62
63impl<Ctx> Ask for Ctx
64where
65    Ctx: Messaging + Sized + Send,
66{
67    async fn ask_nofork<Rq, Rs>(
68        &mut self,
69        server: Address,
70        request: Rq,
71        timeout: Duration,
72    ) -> Result<Rs, ErrorOf<AskErrorKind>>
73    where
74        Request<Rq>: Message,
75        Response<Rs>: Message,
76    {
77        let reply_to = self.address();
78        let request_header = RequestHeader {
79            id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
80            reply_to,
81        };
82        let request_message = Request {
83            header:  request_header,
84            payload: request,
85        };
86        let request_header = EnvelopeHeader::to_address(server);
87        let request_envelope = Envelope::new(request_header, request_message);
88        let () = self
89            .send(request_envelope.into_erased())
90            .await
91            .map_err(into_ask_error)?;
92        let response_envelope: Envelope<Response<Rs>> = self
93            .recv()
94            .timeout(timeout)
95            .await
96            .map_err(|_elapsed| {
97                ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
98            })?
99            .map_err(into_ask_error)?
100            .cast()
101            .map_err(|envelope| {
102                warn!(
103                    "invalid cast [expected: {}; actual: {}]",
104                    std::any::type_name::<Response<Rs>>(),
105                    envelope.message_name()
106                );
107                ErrorOf::new(AskErrorKind::Cast, "unexpected response type")
108            })?;
109        let (response_message, _empty_envelope) = response_envelope.take();
110        let Response {
111            header: _,
112            payload: response,
113        } = response_message;
114
115        Ok(response)
116    }
117
118    async fn ask<Rq, Rs>(
119        &mut self,
120        server: Address,
121        request: Rq,
122        timeout: Duration,
123    ) -> Result<Rs, ErrorOf<AskErrorKind>>
124    where
125        Self: Fork,
126        Rq: Send,
127        Request<Rq>: Message,
128        Rs: Message,
129    {
130        self.fork()
131            .await
132            .map_err(into_ask_error)?
133            .ask_nofork(server, request, timeout)
134            .await
135    }
136}
137
138impl<Ctx> Reply for Ctx
139where
140    Ctx: Messaging + Send,
141{
142    async fn reply<Rs>(
143        &mut self,
144        to: RequestHeader,
145        response: Rs,
146    ) -> Result<(), ErrorOf<SendErrorKind>>
147    where
148        Response<Rs>: Message,
149    {
150        let RequestHeader { id, reply_to } = to;
151        let response_header = ResponseHeader { id };
152        let response_message = Response {
153            header:  response_header,
154            payload: response,
155        };
156        let response_envelope_header = EnvelopeHeader::to_address(reply_to).with_priority(true);
157        let response_envelope = Envelope::new(response_envelope_header, response_message);
158        self.send(response_envelope.into_erased()).await?;
159
160        Ok(())
161    }
162}
163
164impl_error_kind!(AskErrorKind);
165
166fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
167where
168    K: ErrorKind + Into<AskErrorKind>,
169{
170    e.map_kind(Into::into)
171}