1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::impl_error_kind;
9use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
10use mm1_core::envelope::{Envelope, EnvelopeHeader};
11use mm1_proto::Message;
12use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
13
/// Counter that supplies the `id` placed in each outgoing [`RequestHeader`].
///
/// It starts at 1 and is advanced with a relaxed `fetch_add(1)` every time a request is
/// built in this module, so each request created through this static gets a distinct id
/// (until the `u64` wraps).
static REQUEST_ID: AtomicU64 = AtomicU64::new(1);
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
17pub enum AskErrorKind {
18 Send(SendErrorKind),
19 Recv(RecvErrorKind),
20 Fork(ForkErrorKind),
21 Timeout,
22 Cast,
23}
24
25pub trait Ask: Messaging + Sized {
26 fn ask<Rq, Rs>(
27 &mut self,
28 server: Address,
29 request: Rq,
30 timeout: Duration,
31 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
32 where
33 Rq: Send,
34 Request<Rq>: Message,
35 Rs: Message;
36
37 fn fork_ask<Rq, Rs>(
38 &mut self,
39 server: Address,
40 request: Rq,
41 timeout: Duration,
42 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
43 where
44 Self: Fork,
45 Rq: Send,
46 Request<Rq>: Message,
47 Rs: Message;
48}
49
50pub trait Reply: Messaging + Send {
51 fn reply<Rs>(
52 &mut self,
53 to: RequestHeader,
54 response: Rs,
55 ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
56 where
57 Rs: Send,
58 Response<Rs>: Message;
59}
60
61impl<Ctx> Ask for Ctx
62where
63 Ctx: Messaging + Sized + Send,
64{
65 async fn ask<Rq, Rs>(
66 &mut self,
67 server: Address,
68 request: Rq,
69 timeout: Duration,
70 ) -> Result<Rs, ErrorOf<AskErrorKind>>
71 where
72 Request<Rq>: Message,
73 Response<Rs>: Message,
74 {
75 let reply_to = self.address();
76 let request_header = RequestHeader {
77 id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
78 reply_to,
79 };
80 let request_message = Request {
81 header: request_header,
82 payload: request,
83 };
84 let request_header = EnvelopeHeader::to_address(server);
85 let request_envelope = Envelope::new(request_header, request_message);
86 let () = self
87 .send(request_envelope.into_erased())
88 .await
89 .map_err(into_ask_error)?;
90 let response_envelope: Envelope<Response<Rs>> = self
91 .recv()
92 .timeout(timeout)
93 .await
94 .map_err(|_elapsed| {
95 ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
96 })?
97 .map_err(into_ask_error)?
98 .cast()
99 .map_err(|_| ErrorOf::new(AskErrorKind::Cast, "unexpected response type"))?;
100 let (response_message, _empty_envelope) = response_envelope.take();
101 let Response {
102 header: _,
103 payload: response,
104 } = response_message;
105
106 Ok(response)
107 }
108
109 async fn fork_ask<Rq, Rs>(
110 &mut self,
111 server: Address,
112 request: Rq,
113 timeout: Duration,
114 ) -> Result<Rs, ErrorOf<AskErrorKind>>
115 where
116 Self: Fork,
117 Rq: Send,
118 Request<Rq>: Message,
119 Rs: Message,
120 {
121 self.fork()
122 .await
123 .map_err(into_ask_error)?
124 .ask(server, request, timeout)
125 .await
126 }
127}
128
129impl<Ctx> Reply for Ctx
130where
131 Ctx: Messaging + Send,
132{
133 async fn reply<Rs>(
134 &mut self,
135 to: RequestHeader,
136 response: Rs,
137 ) -> Result<(), ErrorOf<SendErrorKind>>
138 where
139 Response<Rs>: Message,
140 {
141 let RequestHeader { id, reply_to } = to;
142 let response_header = ResponseHeader { id };
143 let response_message = Response {
144 header: response_header,
145 payload: response,
146 };
147 let response_envelope_header = EnvelopeHeader::to_address(reply_to);
148 let response_envelope = Envelope::new(response_envelope_header, response_message);
149 self.send(response_envelope.into_erased()).await?;
150
151 Ok(())
152 }
153}
154
// Invokes `mm1_common::impl_error_kind!` for `AskErrorKind`.
// NOTE(review): presumably this generates the `ErrorKind` implementation that lets
// `AskErrorKind` be used inside `ErrorOf<_>` — confirm against the macro definition.
impl_error_kind!(AskErrorKind);
156
157fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
158where
159 K: ErrorKind + Into<AskErrorKind>,
160{
161 e.map_kind(Into::into)
162}