mm1_ask/
ask.rs

1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::impl_error_kind;
9use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
10use mm1_core::envelope::{Envelope, EnvelopeHeader};
11use mm1_proto::Message;
12use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
13
/// Counter that hands out [`RequestHeader`] ids: each `ask` takes the next value, starting at 1.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
17pub enum AskErrorKind {
18    Send(SendErrorKind),
19    Recv(RecvErrorKind),
20    Fork(ForkErrorKind),
21    Timeout,
22    Cast,
23}
24
25pub trait Ask: Messaging + Sized {
26    fn ask<Rq, Rs>(
27        &mut self,
28        server: Address,
29        request: Rq,
30        timeout: Duration,
31    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
32    where
33        Rq: Send,
34        Request<Rq>: Message,
35        Rs: Message;
36
37    fn fork_ask<Rq, Rs>(
38        &mut self,
39        server: Address,
40        request: Rq,
41        timeout: Duration,
42    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
43    where
44        Self: Fork,
45        Rq: Send,
46        Request<Rq>: Message,
47        Rs: Message;
48}
49
50pub trait Reply: Messaging + Send {
51    fn reply<Rs>(
52        &mut self,
53        to: RequestHeader,
54        response: Rs,
55    ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
56    where
57        Rs: Send,
58        Response<Rs>: Message;
59}
60
61impl<Ctx> Ask for Ctx
62where
63    Ctx: Messaging + Sized + Send,
64{
65    async fn ask<Rq, Rs>(
66        &mut self,
67        server: Address,
68        request: Rq,
69        timeout: Duration,
70    ) -> Result<Rs, ErrorOf<AskErrorKind>>
71    where
72        Request<Rq>: Message,
73        Response<Rs>: Message,
74    {
75        let reply_to = self.address();
76        let request_header = RequestHeader {
77            id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
78            reply_to,
79        };
80        let request_message = Request {
81            header:  request_header,
82            payload: request,
83        };
84        let request_header = EnvelopeHeader::to_address(server);
85        let request_envelope = Envelope::new(request_header, request_message);
86        let () = self
87            .send(request_envelope.into_erased())
88            .await
89            .map_err(into_ask_error)?;
90        let response_envelope: Envelope<Response<Rs>> = self
91            .recv()
92            .timeout(timeout)
93            .await
94            .map_err(|_elapsed| {
95                ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
96            })?
97            .map_err(into_ask_error)?
98            .cast()
99            .map_err(|_| ErrorOf::new(AskErrorKind::Cast, "unexpected response type"))?;
100        let (response_message, _empty_envelope) = response_envelope.take();
101        let Response {
102            header: _,
103            payload: response,
104        } = response_message;
105
106        Ok(response)
107    }
108
109    async fn fork_ask<Rq, Rs>(
110        &mut self,
111        server: Address,
112        request: Rq,
113        timeout: Duration,
114    ) -> Result<Rs, ErrorOf<AskErrorKind>>
115    where
116        Self: Fork,
117        Rq: Send,
118        Request<Rq>: Message,
119        Rs: Message,
120    {
121        self.fork()
122            .await
123            .map_err(into_ask_error)?
124            .ask(server, request, timeout)
125            .await
126    }
127}
128
129impl<Ctx> Reply for Ctx
130where
131    Ctx: Messaging + Send,
132{
133    async fn reply<Rs>(
134        &mut self,
135        to: RequestHeader,
136        response: Rs,
137    ) -> Result<(), ErrorOf<SendErrorKind>>
138    where
139        Response<Rs>: Message,
140    {
141        let RequestHeader { id, reply_to } = to;
142        let response_header = ResponseHeader { id };
143        let response_message = Response {
144            header:  response_header,
145            payload: response,
146        };
147        let response_envelope_header = EnvelopeHeader::to_address(reply_to);
148        let response_envelope = Envelope::new(response_envelope_header, response_message);
149        self.send(response_envelope.into_erased()).await?;
150
151        Ok(())
152    }
153}
154
155impl_error_kind!(AskErrorKind);
156
157fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
158where
159    K: ErrorKind + Into<AskErrorKind>,
160{
161    e.map_kind(Into::into)
162}