1use std::time::Duration;
2
3use mm1_address::address::Address;
4use mm1_common::errors::error_kind::ErrorKind;
5use mm1_common::errors::error_of::ErrorOf;
6use mm1_common::futures::timeout::FutureTimeoutExt;
7use mm1_common::impl_error_kind;
8use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
9use mm1_core::envelope::{Envelope, EnvelopeHeader};
10use mm1_core::prim::Message;
11use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
14pub enum AskErrorKind {
15 Send(SendErrorKind),
16 Recv(RecvErrorKind),
17 Fork(ForkErrorKind),
18 Timeout,
19 Cast,
20}
21
22pub trait Ask: Messaging + Sized {
23 fn ask<Rq, Rs>(
24 &mut self,
25 server: Address,
26 request: Rq,
27 timeout: Duration,
28 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
29 where
30 Rq: Send,
31 Request<Rq>: Message,
32 Rs: Message;
33
34 fn fork_ask<Rq, Rs>(
35 &mut self,
36 server: Address,
37 request: Rq,
38 timeout: Duration,
39 ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
40 where
41 Self: Fork,
42 Rq: Send,
43 Request<Rq>: Message,
44 Rs: Message;
45}
46
47pub trait Reply: Messaging + Send {
48 fn reply<Id, Rs>(
49 &mut self,
50 to: RequestHeader<Id>,
51 response: Rs,
52 ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
53 where
54 Id: Send,
55 Rs: Send,
56 Response<Rs, Id>: Message;
57}
58
59impl<Ctx> Ask for Ctx
60where
61 Ctx: Messaging + Sized + Send,
62{
63 async fn ask<Rq, Rs>(
64 &mut self,
65 server: Address,
66 request: Rq,
67 timeout: Duration,
68 ) -> Result<Rs, ErrorOf<AskErrorKind>>
69 where
70 Request<Rq>: Message,
71 Rs: Message,
72 {
73 let reply_to = self.address();
74 let request_header = RequestHeader { id: (), reply_to };
75 let request_message = Request {
76 header: request_header,
77 payload: request,
78 };
79 let request_header = EnvelopeHeader::to_address(server);
80 let request_envelope = Envelope::new(request_header, request_message);
81 let () = self
82 .send(request_envelope.into_erased())
83 .await
84 .map_err(into_ask_error)?;
85 let response_envelope = self
86 .recv()
87 .timeout(timeout)
88 .await
89 .map_err(|_elapsed| {
90 ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response")
91 })?
92 .map_err(into_ask_error)?
93 .cast::<Response<Rs>>()
94 .map_err(|_| ErrorOf::new(AskErrorKind::Cast, "unexpected response type"))?;
95 let (response_message, _empty_envelope) = response_envelope.take();
96 let Response {
97 header: _,
98 payload: response,
99 } = response_message;
100
101 Ok(response)
102 }
103
104 async fn fork_ask<Rq, Rs>(
105 &mut self,
106 server: Address,
107 request: Rq,
108 timeout: Duration,
109 ) -> Result<Rs, ErrorOf<AskErrorKind>>
110 where
111 Self: Fork,
112 Rq: Send,
113 Request<Rq>: Message,
114 Rs: Message,
115 {
116 self.fork()
117 .await
118 .map_err(into_ask_error)?
119 .ask(server, request, timeout)
120 .await
121 }
122}
123
124impl<Ctx> Reply for Ctx
125where
126 Ctx: Messaging + Send,
127{
128 async fn reply<Id, Rs>(
129 &mut self,
130 to: RequestHeader<Id>,
131 response: Rs,
132 ) -> Result<(), ErrorOf<SendErrorKind>>
133 where
134 Response<Rs, Id>: Message,
135 {
136 let RequestHeader { id, reply_to } = to;
137 let response_header = ResponseHeader { id };
138 let response_message = Response {
139 header: response_header,
140 payload: response,
141 };
142 let response_envelope_header = EnvelopeHeader::to_address(reply_to);
143 let response_envelope = Envelope::new(response_envelope_header, response_message);
144 self.send(response_envelope.into_erased()).await?;
145
146 Ok(())
147 }
148}
149
// NOTE(review): presumably implements mm1's `ErrorKind` trait for
// `AskErrorKind`, so that it can be used as the kind of an `ErrorOf` (as in
// `ErrorOf::new(AskErrorKind::Timeout, ..)` above). Confirm against the
// `impl_error_kind!` macro definition.
impl_error_kind!(AskErrorKind);
151
152fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
153where
154 K: ErrorKind + Into<AskErrorKind>,
155{
156 e.map_kind(Into::into)
157}