Skip to main content

mm1_ask/
ask.rs

1use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
2use std::time::Duration;
3
4use mm1_address::address::Address;
5use mm1_common::errors::error_kind::ErrorKind;
6use mm1_common::errors::error_of::ErrorOf;
7use mm1_common::futures::timeout::FutureTimeoutExt;
8use mm1_common::log::{Instrument, warn};
9use mm1_common::metrics::MeasuredFutureExt;
10use mm1_common::{impl_error_kind, make_metrics};
11use mm1_core::context::{Fork, ForkErrorKind, Messaging, RecvErrorKind, SendErrorKind};
12use mm1_core::envelope::{Envelope, EnvelopeHeader};
13use mm1_proto::Message;
14use mm1_proto_ask::{Request, RequestHeader, Response, ResponseHeader};
15use tracing::Level;
16
/// Counter from which `ask_impl` draws a fresh id for every outgoing request.
///
/// It starts at 1, so (barring wrap-around) 0 is never handed out.
static REQUEST_ID: AtomicU64 = AtomicU64::new(1_u64);
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, derive_more::From)]
20pub enum AskErrorKind {
21    Send(SendErrorKind),
22    Recv(RecvErrorKind),
23    Fork(ForkErrorKind),
24    Timeout,
25    Cast,
26}
27
/// Request/response ("ask") helper for messaging contexts.
///
/// A blanket implementation is provided for every `Messaging + Sized + Send`
/// context.
pub trait Ask: Messaging + Sized {
    /// Sends `request` to `server` and waits up to `timeout` for a
    /// `Response<Rs>`, returning its payload.
    ///
    /// The blanket implementation forks `self` first and performs the whole
    /// exchange on the fork, so the reply is addressed to and received on the
    /// fork rather than on `self`.
    ///
    /// # Errors
    ///
    /// Fails with an [`AskErrorKind`] of `Fork`, `Send`, `Recv`, `Timeout` or
    /// `Cast`.
    fn ask<Rq, Rs>(
        &mut self,
        server: Address,
        request: Rq,
        timeout: Duration,
    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
    where
        Self: Fork,
        Rq: Send,
        Request<Rq>: Message,
        Rs: Message;

    /// Like [`Ask::ask`], but without forking.
    ///
    /// The reply is addressed to `self` and awaited on `self`'s own mailbox.
    /// If a message other than a `Response<Rs>` is received first, the call
    /// fails with `AskErrorKind::Cast`.
    #[doc(hidden)]
    fn ask_nofork<Rq, Rs>(
        &mut self,
        server: Address,
        request: Rq,
        timeout: Duration,
    ) -> impl Future<Output = Result<Rs, ErrorOf<AskErrorKind>>> + Send
    where
        Rq: Send,
        Request<Rq>: Message,
        Rs: Message;
}
53
/// Counterpart of [`Ask`]: answers a request given its [`RequestHeader`].
///
/// A blanket implementation is provided for every `Messaging + Send` context.
pub trait Reply: Messaging + Send {
    /// Sends `response` back to the requester described by `to`.
    ///
    /// The blanket implementation wraps `response` in a `Response` that
    /// carries `to`'s request id. It sends that response, with priority
    /// enabled, to `to`'s `reply_to` address.
    ///
    /// # Errors
    ///
    /// Returns the send error if the response could not be sent.
    fn reply<Rs>(
        &mut self,
        to: RequestHeader,
        response: Rs,
    ) -> impl Future<Output = Result<(), ErrorOf<SendErrorKind>>> + Send
    where
        Rs: Send,
        Response<Rs>: Message;
}
64
65impl<Ctx> Ask for Ctx
66where
67    Ctx: Messaging + Sized + Send,
68{
69    async fn ask_nofork<Rq, Rs>(
70        &mut self,
71        server: Address,
72        request: Rq,
73        timeout: Duration,
74    ) -> Result<Rs, ErrorOf<AskErrorKind>>
75    where
76        Request<Rq>: Message,
77        Response<Rs>: Message,
78    {
79        ask_impl(self, server, request, timeout)
80            .measured(make_metrics!("mm1_ask",
81                "req" => std::any::type_name::<Rq>(),
82                "resp" => std::any::type_name::<Rs>(),
83            ))
84            .instrument(tracing::span!(
85                Level::TRACE,
86                "mm1_ask",
87                req = std::any::type_name::<Rq>(),
88                resp = std::any::type_name::<Rs>(),
89            ))
90            .await
91    }
92
93    async fn ask<Rq, Rs>(
94        &mut self,
95        server: Address,
96        request: Rq,
97        timeout: Duration,
98    ) -> Result<Rs, ErrorOf<AskErrorKind>>
99    where
100        Self: Fork,
101        Rq: Send,
102        Request<Rq>: Message,
103        Rs: Message,
104    {
105        self.fork()
106            .await
107            .map_err(into_ask_error)?
108            .ask_nofork(server, request, timeout)
109            .await
110    }
111}
112
113impl<Ctx> Reply for Ctx
114where
115    Ctx: Messaging + Send,
116{
117    async fn reply<Rs>(
118        &mut self,
119        to: RequestHeader,
120        response: Rs,
121    ) -> Result<(), ErrorOf<SendErrorKind>>
122    where
123        Response<Rs>: Message,
124    {
125        let RequestHeader { id, reply_to } = to;
126        let response_header = ResponseHeader { id };
127        let response_message = Response {
128            header:  response_header,
129            payload: response,
130        };
131        let response_envelope_header = EnvelopeHeader::to_address(reply_to).with_priority(true);
132        let response_envelope = Envelope::new(response_envelope_header, response_message);
133        self.send(response_envelope.into_erased()).await?;
134
135        Ok(())
136    }
137}
138
// NOTE(review): `impl_error_kind!` is imported from `mm1_common`. Presumably it
// implements `ErrorKind` for `AskErrorKind` so the kind can be carried by
// `ErrorOf<AskErrorKind>` (confirm against the macro definition).
impl_error_kind!(AskErrorKind);
140
141fn into_ask_error<K>(e: ErrorOf<K>) -> ErrorOf<AskErrorKind>
142where
143    K: ErrorKind + Into<AskErrorKind>,
144{
145    e.map_kind(Into::into)
146}
147
148async fn ask_impl<Ctx, Rq, Rs>(
149    ctx: &mut Ctx,
150    server: Address,
151    request: Rq,
152    timeout: Duration,
153) -> Result<Rs, ErrorOf<AskErrorKind>>
154where
155    Ctx: Messaging,
156    Request<Rq>: Message,
157    Response<Rs>: Message,
158{
159    let reply_to = ctx.address();
160    let request_header = RequestHeader {
161        id: REQUEST_ID.fetch_add(1, AtomicOrdering::Relaxed),
162        reply_to,
163    };
164    let request_message = Request {
165        header:  request_header,
166        payload: request,
167    };
168    let request_header = EnvelopeHeader::to_address(server);
169    let request_envelope = Envelope::new(request_header, request_message);
170    let () = ctx
171        .send(request_envelope.into_erased())
172        .await
173        .map_err(into_ask_error)?;
174    let response_envelope: Envelope<Response<Rs>> = ctx
175        .recv()
176        .timeout(timeout)
177        .await
178        .map_err(|_elapsed| ErrorOf::new(AskErrorKind::Timeout, "timed out waiting for response"))?
179        .map_err(into_ask_error)?
180        .cast()
181        .map_err(|envelope| {
182            warn!(
183                expected = %std::any::type_name::<Response<Rs>>(),
184                actual = %envelope.message_name(),
185                "invalid cast"
186            );
187            ErrorOf::new(AskErrorKind::Cast, "unexpected response type")
188        })?;
189    let (response_message, _empty_envelope) = response_envelope.take();
190    let Response {
191        header: _,
192        payload: response,
193    } = response_message;
194
195    Ok(response)
196}