1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::impl_error_kind;
9use mm1_common::log::warn;
10use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
11use mm1_core::envelope::{Envelope, EnvelopeHeader};
12use mm1_proto::Message;
13use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
14
/// Process-wide counter that hands out `RequestHeader::id` values for outgoing asks.
///
/// Starts at 1 and is advanced by one per request. Only uniqueness matters, which is why
/// callers bump it with a relaxed `fetch_add`.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
18pub enum AskErrorKind {
19 Send(SendErrorKind),
20 Recv(RecvErrorKind),
21 Fork(ForkErrorKind),
22 Timeout,
23 Cast,
24}
25
26pub trait Ask: Messaging + Sized {
27 fn ask<Rq, Rs>(
28 &mut self,
29 server: Address,
30 request: Rq,
31 timeout: Duration,
32 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
33 where
34 Self: Fork,
35 Rq: Send,
36 Request<Rq>: Message,
37 Rs: Message;
38
39 #[doc(hidden)]
40 fn ask_nofork<Rq, Rs>(
41 &mut self,
42 server: Address,
43 request: Rq,
44 timeout: Duration,
45 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
46 where
47 Rq: Send,
48 Request<Rq>: Message,
49 Rs: Message;
50}
51
52pub trait Reply: Messaging + Send {
53 fn reply<Rs>(
54 &mut self,
55 to: RequestHeader,
56 response: Rs,
57 ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
58 where
59 Rs: Send,
60 Response<Rs>: Message;
61}
62
63impl<Ctx> Ask for Ctx
64where
65 Ctx: Messaging + Sized + Send,
66{
67 async fn ask_nofork<Rq, Rs>(
68 &mut self,
69 server: Address,
70 request: Rq,
71 timeout: Duration,
72 ) -> Result<Rs, ErrorOf<AskErrorKind>>
73 where
74 Request<Rq>: Message,
75 Response<Rs>: Message,
76 {
77 let reply_to = self.address();
78 let request_header = RequestHeader {
79 id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
80 reply_to,
81 };
82 let request_message = Request {
83 header: request_header,
84 payload: request,
85 };
86 let request_header = EnvelopeHeader::to_address(server);
87 let request_envelope = Envelope::new(request_header, request_message);
88 let () = self
89 .send(request_envelope.into_erased())
90 .await
91 .map_err(into_ask_error)?;
92 let response_envelope: Envelope<Response<Rs>> = self
93 .recv()
94 .timeout(timeout)
95 .await
96 .map_err(|_elapsed| {
97 ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
98 })?
99 .map_err(into_ask_error)?
100 .cast()
101 .map_err(|envelope| {
102 warn!(
103 "invalid cast [expected: {}; actual: {}]",
104 std::any::type_name::<Response<Rs>>(),
105 envelope.message_name()
106 );
107 ErrorOf::new(AskErrorKind::Cast, "unexpected response type")
108 })?;
109 let (response_message, _empty_envelope) = response_envelope.take();
110 let Response {
111 header: _,
112 payload: response,
113 } = response_message;
114
115 Ok(response)
116 }
117
118 async fn ask<Rq, Rs>(
119 &mut self,
120 server: Address,
121 request: Rq,
122 timeout: Duration,
123 ) -> Result<Rs, ErrorOf<AskErrorKind>>
124 where
125 Self: Fork,
126 Rq: Send,
127 Request<Rq>: Message,
128 Rs: Message,
129 {
130 self.fork()
131 .await
132 .map_err(into_ask_error)?
133 .ask_nofork(server, request, timeout)
134 .await
135 }
136}
137
138impl<Ctx> Reply for Ctx
139where
140 Ctx: Messaging + Send,
141{
142 async fn reply<Rs>(
143 &mut self,
144 to: RequestHeader,
145 response: Rs,
146 ) -> Result<(), ErrorOf<SendErrorKind>>
147 where
148 Response<Rs>: Message,
149 {
150 let RequestHeader { id, reply_to } = to;
151 let response_header = ResponseHeader { id };
152 let response_message = Response {
153 header: response_header,
154 payload: response,
155 };
156 let response_envelope_header = EnvelopeHeader::to_address(reply_to).with_priority(true);
157 let response_envelope = Envelope::new(response_envelope_header, response_message);
158 self.send(response_envelope.into_erased()).await?;
159
160 Ok(())
161 }
162}
163
// Runs the `mm1_common::impl_error_kind!` helper for `AskErrorKind`.
// NOTE(review): presumably this provides the `ErrorKind` implementation that lets
// `ErrorOf<AskErrorKind>` be used as the ask error type. Confirm in mm1_common.
impl_error_kind!(AskErrorKind);
165
166fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
167where
168 K: ErrorKind + Into<AskErrorKind>,
169{
170 e.map_kind(Into::into)
171}