amp_async/server.rs

1use std::collections::HashMap;
2use std::convert::TryInto;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use serde::Serialize;
7
8use futures::sink::SinkExt;
9use futures::stream::StreamExt;
10
11use tokio::prelude::*;
12use tokio::sync::{mpsc, oneshot, Barrier};
13use tokio::task::JoinHandle;
14use tokio_util::codec::{BytesCodec, FramedRead, FramedWrite};
15
16use amp_serde::{ErrorResponse, OkResponse, Request};
17
18use crate::frame::Response;
19use crate::{Decoder, Error, Frame, RawFrame};
20
21#[derive(Debug)]
22pub struct DispatchRequest(pub Bytes, pub RawFrame, pub Option<ReplyTicket>);
23
24struct ExpectReply {
25    tag: u64,
26    reply: oneshot::Sender<Response>,
27    barrier: Arc<tokio::sync::Barrier>,
28}
29
30type _FrameMaker = Box<dyn FnOnce(Option<Bytes>) -> Result<Vec<u8>, amp_serde::Error> + Send>;
31
32struct FrameMaker(_FrameMaker);
33
34impl std::fmt::Debug for FrameMaker {
35    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
36        write!(fmt, "callback")
37    }
38}
39
40#[derive(Debug)]
41enum WriteCmd {
42    Reply(Bytes),
43    Request(FrameMaker, Option<oneshot::Sender<Response>>),
44}
45
46#[derive(Debug)]
47pub struct ReplyTicket {
48    tag: Option<Bytes>,
49    write_handle: mpsc::Sender<WriteCmd>,
50}
51
52impl ReplyTicket {
53    pub async fn ok<R: Serialize>(mut self, reply: R) -> Result<(), Error> {
54        let tag = self.tag.take().expect("Tag taken out of sequence");
55
56        let reply = amp_serde::to_bytes(OkResponse { tag, fields: reply })?;
57
58        self.write_handle
59            .send(WriteCmd::Reply(reply.into()))
60            .await?;
61
62        Ok(())
63    }
64
65    pub async fn error(
66        mut self,
67        code: Option<String>,
68        description: Option<String>,
69    ) -> Result<(), Error> {
70        let tag = self.tag.take().expect("Tag taken out of sequence");
71
72        let reply = amp_serde::to_bytes(ErrorResponse {
73            tag,
74            code: code.unwrap_or_else(|| "UNKNOWN".into()),
75            description: description.unwrap_or_else(|| "".into()),
76        })?;
77
78        self.write_handle
79            .send(WriteCmd::Reply(reply.into()))
80            .await?;
81
82        Ok(())
83    }
84}
85
86impl Drop for ReplyTicket {
87    fn drop(&mut self) {
88        if let Some(tag) = self.tag.take() {
89            let mut write_handle = self.write_handle.clone();
90            let reply = amp_serde::to_bytes(ErrorResponse {
91                tag,
92                code: "UNKNOWN".into(),
93                description: "Request dropped without reply".into(),
94            })
95            .unwrap();
96
97            // Can't wait for poll_drop
98            tokio::spawn(async move {
99                write_handle
100                    .send(WriteCmd::Reply(reply.into()))
101                    .await
102                    .expect("error on drop")
103            });
104        }
105    }
106}
107
108#[derive(Clone)]
109pub struct RequestSender(mpsc::Sender<WriteCmd>);
110
111impl RequestSender {
112    pub async fn call_remote<Q: Serialize + Send + 'static>(
113        &mut self,
114        command: String,
115        request: Q,
116    ) -> Result<RawFrame, Error> {
117        let (tx, rx) = oneshot::channel();
118
119        let frame = FrameMaker(Box::new(move |tag| {
120            amp_serde::to_bytes(Request {
121                tag,
122                command,
123                fields: request,
124            })
125        }));
126
127        self.0.send(WriteCmd::Request(frame, Some(tx))).await?;
128
129        rx.await?.map_err(|err| Error::Remote {
130            code: err.code,
131            description: err.description,
132        })
133    }
134
135    pub async fn call_remote_noreply<Q: Serialize + Send + 'static>(
136        &mut self,
137        command: String,
138        request: Q,
139    ) -> Result<(), Error> {
140        let frame = FrameMaker(Box::new(move |tag| {
141            amp_serde::to_bytes(Request {
142                tag,
143                command,
144                fields: request,
145            })
146        }));
147
148        self.0.send(WriteCmd::Request(frame, None)).await?;
149
150        Ok(())
151    }
152}
153
154pub struct Handle {
155    write_res: JoinHandle<Result<(), Error>>,
156    read_res: JoinHandle<Result<(), Error>>,
157    write_loop_handle: Option<mpsc::Sender<WriteCmd>>,
158    shutdown: Option<oneshot::Sender<()>>,
159}
160
161impl Handle {
162    pub fn shutdown(&mut self) {
163        self.write_loop_handle = None;
164        if let Some(s) = self.shutdown.take() {
165            let _ = s.send(());
166        }
167    }
168
169    pub async fn join(mut self) -> Result<(), Error> {
170        self.write_loop_handle = None;
171        self.write_res.await.unwrap()?;
172        if let Some(s) = self.shutdown.take() {
173            let _ = s.send(());
174        }
175        self.read_res.await.unwrap()?;
176
177        Ok(())
178    }
179
180    pub fn request_sender(&self) -> Option<RequestSender> {
181        self.write_loop_handle.as_ref().cloned().map(RequestSender)
182    }
183}
184
185pub fn serve<R, W>(input: R, output: W) -> (Handle, mpsc::Receiver<DispatchRequest>)
186where
187    R: AsyncRead + Unpin + Send + 'static,
188    W: AsyncWrite + Unpin + Send + 'static,
189{
190    let (write_tx, write_rx) = mpsc::channel::<WriteCmd>(32);
191    let (dispatch_tx, dispatch_rx) = mpsc::channel::<DispatchRequest>(32);
192    let (expect_tx, expect_rx) = mpsc::channel::<ExpectReply>(32);
193    let (shutdown_tx, shutdown_rx) = oneshot::channel();
194
195    let read_res = tokio::spawn(read_loop(
196        input,
197        shutdown_rx,
198        write_tx.clone(),
199        dispatch_tx,
200        expect_rx,
201    ));
202    let write_res = tokio::spawn(write_loop(output, write_rx, expect_tx));
203
204    (
205        Handle {
206            write_res,
207            read_res,
208            write_loop_handle: Some(write_tx),
209            shutdown: Some(shutdown_tx),
210        },
211        dispatch_rx,
212    )
213}
214
215type ReplyMap = HashMap<u64, oneshot::Sender<Response>>;
216
217async fn read_loop<R>(
218    input: R,
219    mut shutdown: oneshot::Receiver<()>,
220    mut write_tx: mpsc::Sender<WriteCmd>,
221    mut dispatch_tx: mpsc::Sender<DispatchRequest>,
222    mut expect_rx: mpsc::Receiver<ExpectReply>,
223) -> Result<(), Error>
224where
225    R: AsyncRead + Unpin,
226{
227    let codec_in: Decoder<RawFrame> = Decoder::new();
228    let mut input = FramedRead::new(input, codec_in);
229    let mut reply_map = ReplyMap::new();
230
231    loop {
232        tokio::select! {
233            frame = input.next() => {
234                if let Some(frame) = frame {
235                    dispatch_frame(frame?, &mut reply_map, &mut write_tx, &mut dispatch_tx).await?;
236                } else {
237                    break;
238                }
239            }
240            expect = expect_rx.recv() => {
241                if let Some(expect) = expect {
242                    reply_map.insert(expect.tag, expect.reply);
243                    expect.barrier.wait().await;
244                }
245            }
246            _ = &mut shutdown => break
247        }
248    }
249
250    Ok(())
251}
252
253async fn dispatch_frame(
254    frame: RawFrame,
255    reply_map: &mut ReplyMap,
256    write_tx: &mut mpsc::Sender<WriteCmd>,
257    dispatch_tx: &mut mpsc::Sender<DispatchRequest>,
258) -> Result<(), Error> {
259    match frame.try_into()? {
260        Frame::Request {
261            tag,
262            command,
263            fields,
264        } => {
265            let ticket = tag.map(|tag| ReplyTicket {
266                tag: Some(tag),
267                write_handle: write_tx.clone(),
268            });
269
270            // The application may close its dispatch channel. All
271            // incoming requests will generate a "Request dropped
272            // without reply" error.
273            let _ = dispatch_tx
274                .send(DispatchRequest(command, fields, ticket))
275                .await;
276        }
277
278        Frame::Response { tag, response } => {
279            let reply_tx = std::str::from_utf8(&tag)
280                .ok()
281                .and_then(|tag_str| u64::from_str_radix(tag_str, 16).ok())
282                .and_then(|tag_u64| reply_map.remove(&tag_u64))
283                .ok_or(Error::UnmatchedReply)?;
284
285            reply_tx.send(response).map_err(|_| Error::SendError)?;
286        }
287    }
288
289    Ok(())
290}
291
292async fn write_loop<W>(
293    output: W,
294    mut input: mpsc::Receiver<WriteCmd>,
295    mut expect_tx: mpsc::Sender<ExpectReply>,
296) -> Result<(), Error>
297where
298    W: AsyncWrite + Unpin,
299{
300    let mut output = FramedWrite::new(output, BytesCodec::new());
301    let mut seqno: u64 = 0;
302
303    while let Some(msg) = input.next().await {
304        match msg {
305            WriteCmd::Reply(frame) => {
306                output.send(frame).await?;
307            }
308            WriteCmd::Request(request, reply) => {
309                if let Some(reply) = reply {
310                    seqno += 1;
311
312                    let barrier = Arc::new(Barrier::new(2));
313
314                    let expect = ExpectReply {
315                        tag: seqno,
316                        reply,
317                        barrier: barrier.clone(),
318                    };
319
320                    expect_tx.send(expect).await?;
321                    barrier.wait().await;
322
323                    output
324                        .send(request.0(Some(format!("{:x}", seqno).into()))?.into())
325                        .await?;
326                } else {
327                    output.send(request.0(None)?.into()).await?;
328                }
329            }
330        }
331    }
332
333    Ok(())
334}