1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::log::{Instrument, warn};
9use mm1_common::metrics::MeasuredFutureExt;
10use mm1_common::{impl_error_kind, make_metrics};
11use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
12use mm1_core::envelope::{Envelope, EnvelopeHeader};
13use mm1_proto::Message;
14use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
15use tracing::Level;
16
/// Process-wide counter supplying the `id` of every outgoing `RequestHeader`.
/// Each ask takes the next value; the sequence starts at 1.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);
18
/// Failure categories reported by [`Ask::ask`] and [`Ask::ask_nofork`].
///
/// `PartialOrd`/`Ord` are derived, so the declaration order of the variants is
/// observable. `derive_more::From` lets the wrapped kinds be converted with `Into`.
/// `into_ask_error` relies on that conversion.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
pub enum AskErrorKind {
    /// Sending the request to the server failed.
    Send(SendErrorKind),
    /// Receiving from the context's mailbox failed.
    Recv(RecvErrorKind),
    /// Forking the context failed (only [`Ask::ask`] forks).
    Fork(ForkErrorKind),
    /// No reply was received within the caller-supplied timeout.
    Timeout,
    /// The received message was not a `Response` of the expected payload type.
    Cast,
}
27
/// Request/response exchange over a [`Messaging`] context.
///
/// The request payload is wrapped in a [`Request`]. Its [`RequestHeader`] carries a
/// request id and the `reply_to` address. The reply is expected to be a [`Response`],
/// and its payload is what the caller gets back. A blanket implementation covers every
/// `Messaging + Sized + Send` context.
pub trait Ask: Messaging + Sized {
    /// Sends `request` to `server` and waits up to `timeout` for the response payload.
    ///
    /// The provided implementation forks `self` and performs the exchange on the fork.
    /// The reply is therefore addressed to the fork rather than to this context's
    /// mailbox.
    ///
    /// # Errors
    ///
    /// The error kind is an [`AskErrorKind`]:
    /// - `Fork` if forking fails;
    /// - `Send` or `Recv` for messaging failures;
    /// - `Timeout` if no reply arrives in time;
    /// - `Cast` if the received message is not a `Response<Rs>`.
    fn ask<Rq, Rs>(
        &mut self,
        server: Address,
        request: Rq,
        timeout: Duration,
    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
    where
        Self: Fork,
        Rq: Send,
        Request<Rq>: Message,
        Rs: Message;

    /// Like [`Ask::ask`], but without forking.
    ///
    /// The request names `self`'s own address as `reply_to`, and the reply is read
    /// from `self`'s mailbox.
    ///
    /// Any unrelated message already queued on `self` may be received in place of the
    /// reply. In that case the call fails with `Cast` and that message is consumed.
    #[doc(hidden)]
    fn ask_nofork<Rq, Rs>(
        &mut self,
        server: Address,
        request: Rq,
        timeout: Duration,
    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
    where
        Rq: Send,
        Request<Rq>: Message,
        Rs: Message;
}
53
/// Server-side counterpart to [`Ask`]: answers a request identified by its
/// [`RequestHeader`].
pub trait Reply: Messaging + Send {
    /// Sends `response` back to `to.reply_to`, wrapped in a [`Response`].
    ///
    /// The response's [`ResponseHeader`] carries the id taken from `to`. The provided
    /// implementation marks the reply envelope as priority.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the underlying send.
    fn reply<Rs>(
        &mut self,
        to: RequestHeader,
        response: Rs,
    ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
    where
        Rs: Send,
        Response<Rs>: Message;
}
64
65impl<Ctx> Ask for Ctx
66where
67 Ctx: Messaging + Sized + Send,
68{
69 async fn ask_nofork<Rq, Rs>(
70 &mut self,
71 server: Address,
72 request: Rq,
73 timeout: Duration,
74 ) -> Result<Rs, ErrorOf<AskErrorKind>>
75 where
76 Request<Rq>: Message,
77 Response<Rs>: Message,
78 {
79 ask_impl(self, server, request, timeout)
80 .measured(make_metrics!("mm1_ask",
81 "req" => std::any::type_name::<Rq>(),
82 "resp" => std::any::type_name::<Rs>(),
83 ))
84 .instrument(tracing::span!(
85 Level::TRACE,
86 "mm1_ask",
87 req = std::any::type_name::<Rq>(),
88 resp = std::any::type_name::<Rs>(),
89 ))
90 .await
91 }
92
93 async fn ask<Rq, Rs>(
94 &mut self,
95 server: Address,
96 request: Rq,
97 timeout: Duration,
98 ) -> Result<Rs, ErrorOf<AskErrorKind>>
99 where
100 Self: Fork,
101 Rq: Send,
102 Request<Rq>: Message,
103 Rs: Message,
104 {
105 self.fork()
106 .await
107 .map_err(into_ask_error)?
108 .ask_nofork(server, request, timeout)
109 .await
110 }
111}
112
113impl<Ctx> Reply for Ctx
114where
115 Ctx: Messaging + Send,
116{
117 async fn reply<Rs>(
118 &mut self,
119 to: RequestHeader,
120 response: Rs,
121 ) -> Result<(), ErrorOf<SendErrorKind>>
122 where
123 Response<Rs>: Message,
124 {
125 let RequestHeader { id, reply_to } = to;
126 let response_header = ResponseHeader { id };
127 let response_message = Response {
128 header: response_header,
129 payload: response,
130 };
131 let response_envelope_header = EnvelopeHeader::to_address(reply_to).with_priority(true);
132 let response_envelope = Envelope::new(response_envelope_header, response_message);
133 self.send(response_envelope.into_erased()).await?;
134
135 Ok(())
136 }
137}
138
// Registers `AskErrorKind` with mm1_common's error-kind machinery. NOTE(review):
// presumably this implements `ErrorKind` for it, which `ErrorOf<AskErrorKind>`
// relies on — confirm against the macro definition.
impl_error_kind!(AskErrorKind);
140
141fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
142where
143 K: ErrorKind + Into<AskErrorKind>,
144{
145 e.map_kind(Into::into)
146}
147
148async fn ask_impl<Ctx, Rq, Rs>(
149 ctx: &mut Ctx,
150 server: Address,
151 request: Rq,
152 timeout: Duration,
153) -> Result<Rs, ErrorOf<AskErrorKind>>
154where
155 Ctx: Messaging,
156 Request<Rq>: Message,
157 Response<Rs>: Message,
158{
159 let reply_to = ctx.address();
160 let request_header = RequestHeader {
161 id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
162 reply_to,
163 };
164 let request_message = Request {
165 header: request_header,
166 payload: request,
167 };
168 let request_header = EnvelopeHeader::to_address(server);
169 let request_envelope = Envelope::new(request_header, request_message);
170 let () = ctx
171 .send(request_envelope.into_erased())
172 .await
173 .map_err(into_ask_error)?;
174 let response_envelope: Envelope<Response<Rs>> = ctx
175 .recv()
176 .timeout(timeout)
177 .await
178 .map_err(|_elapsed| ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response"))?
179 .map_err(into_ask_error)?
180 .cast()
181 .map_err(|envelope| {
182 warn!(
183 expected = %std::any::type_name::<Response<Rs>>(),
184 actual = %envelope.message_name(),
185 "invalid cast"
186 );
187 ErrorOf::new(AskErrorKind::Cast, "unexpected response type")
188 })?;
189 let (response_message, _empty_envelope) = response_envelope.take();
190 let Response {
191 header: _,
192 payload: response,
193 } = response_message;
194
195 Ok(response)
196}