1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::impl_error_kind;
9use mm1_common::log::warn;
10use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
11use mm1_core::envelope::{Envelope, EnvelopeHeader};
12use mm1_proto::Message;
13use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
14
/// Process-wide counter that hands out request identifiers for [`Ask::ask`].
///
/// Each `ask` call takes the current value with `fetch_add(1, ..)`, so ids start
/// at `1`. `fetch_add` wraps on overflow rather than panicking.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
16
/// Ways an [`Ask::ask`] / [`Ask::fork_ask`] call can fail.
///
/// `derive_more::From` supplies the `From` conversions from the wrapped context
/// error kinds that `into_ask_error` relies on (via its `Into<AskErrorKind>` bound).
/// NOTE: variant order matters, because it defines the derived `PartialOrd`/`Ord`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
pub enum AskErrorKind {
    /// Sending the request to the server failed.
    Send(SendErrorKind),
    /// Receiving a message while waiting for the response failed.
    Recv(RecvErrorKind),
    /// Forking the context failed (only produced by `fork_ask`).
    Fork(ForkErrorKind),
    /// No response arrived within the caller-supplied timeout.
    Timeout,
    /// A message was received, but it was not a `Response` of the expected type.
    Cast,
}
25
26pub trait Ask: Messaging + Sized {
27 fn ask<Rq, Rs>(
28 &mut self,
29 server: Address,
30 request: Rq,
31 timeout: Duration,
32 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
33 where
34 Rq: Send,
35 Request<Rq>: Message,
36 Rs: Message;
37
38 fn fork_ask<Rq, Rs>(
39 &mut self,
40 server: Address,
41 request: Rq,
42 timeout: Duration,
43 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
44 where
45 Self: Fork,
46 Rq: Send,
47 Request<Rq>: Message,
48 Rs: Message;
49}
50
51pub trait Reply: Messaging + Send {
52 fn reply<Rs>(
53 &mut self,
54 to: RequestHeader,
55 response: Rs,
56 ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
57 where
58 Rs: Send,
59 Response<Rs>: Message;
60}
61
62impl<Ctx> Ask for Ctx
63where
64 Ctx: Messaging + Sized + Send,
65{
66 async fn ask<Rq, Rs>(
67 &mut self,
68 server: Address,
69 request: Rq,
70 timeout: Duration,
71 ) -> Result<Rs, ErrorOf<AskErrorKind>>
72 where
73 Request<Rq>: Message,
74 Response<Rs>: Message,
75 {
76 let reply_to = self.address();
77 let request_header = RequestHeader {
78 id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
79 reply_to,
80 };
81 let request_message = Request {
82 header: request_header,
83 payload: request,
84 };
85 let request_header = EnvelopeHeader::to_address(server);
86 let request_envelope = Envelope::new(request_header, request_message);
87 let () = self
88 .send(request_envelope.into_erased())
89 .await
90 .map_err(into_ask_error)?;
91 let response_envelope: Envelope<Response<Rs>> = self
92 .recv()
93 .timeout(timeout)
94 .await
95 .map_err(|_elapsed| {
96 ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
97 })?
98 .map_err(into_ask_error)?
99 .cast()
100 .map_err(|envelope| {
101 warn!(
102 "invalid cast [expected: {}; actual: {}]",
103 std::any::type_name::<Response<Rs>>(),
104 envelope.message_name()
105 );
106 ErrorOf::new(AskErrorKind::Cast, "unexpected response type")
107 })?;
108 let (response_message, _empty_envelope) = response_envelope.take();
109 let Response {
110 header: _,
111 payload: response,
112 } = response_message;
113
114 Ok(response)
115 }
116
117 async fn fork_ask<Rq, Rs>(
118 &mut self,
119 server: Address,
120 request: Rq,
121 timeout: Duration,
122 ) -> Result<Rs, ErrorOf<AskErrorKind>>
123 where
124 Self: Fork,
125 Rq: Send,
126 Request<Rq>: Message,
127 Rs: Message,
128 {
129 self.fork()
130 .await
131 .map_err(into_ask_error)?
132 .ask(server, request, timeout)
133 .await
134 }
135}
136
137impl<Ctx> Reply for Ctx
138where
139 Ctx: Messaging + Send,
140{
141 async fn reply<Rs>(
142 &mut self,
143 to: RequestHeader,
144 response: Rs,
145 ) -> Result<(), ErrorOf<SendErrorKind>>
146 where
147 Response<Rs>: Message,
148 {
149 let RequestHeader { id, reply_to } = to;
150 let response_header = ResponseHeader { id };
151 let response_message = Response {
152 header: response_header,
153 payload: response,
154 };
155 let response_envelope_header = EnvelopeHeader::to_address(reply_to);
156 let response_envelope = Envelope::new(response_envelope_header, response_message);
157 self.send(response_envelope.into_erased()).await?;
158
159 Ok(())
160 }
161}
162
// Presumably implements `mm1_common`'s `ErrorKind` for `AskErrorKind`, which
// `ErrorOf<AskErrorKind>` needs. Confirm against the `impl_error_kind!` definition.
impl_error_kind!(AskErrorKind);
164
165fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
166where
167 K: ErrorKind + Into<AskErrorKind>,
168{
169 e.map_kind(Into::into)
170}