1use std::collections::HashMap;
2use std::convert::TryInto;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use serde::Serialize;
7
8use futures::sink::SinkExt;
9use futures::stream::StreamExt;
10
11use tokio::prelude::*;
12use tokio::sync::{mpsc, oneshot, Barrier};
13use tokio::task::JoinHandle;
14use tokio_util::codec::{BytesCodec, FramedRead, FramedWrite};
15
16use amp_serde::{ErrorResponse, OkResponse, Request};
17
18use crate::frame::Response;
19use crate::{Decoder, Error, Frame, RawFrame};
20
/// An incoming request from the peer, delivered to the application through the
/// receiver returned by [`serve`].
///
/// Fields, in order:
/// 0. the command name;
/// 1. the request's `fields`, as an undecoded [`RawFrame`];
/// 2. a [`ReplyTicket`] for answering the request — `Some` only when the peer
///    tagged the request (i.e. expects a reply).
#[derive(Debug)]
pub struct DispatchRequest(pub Bytes, pub RawFrame, pub Option<ReplyTicket>);
23
/// Registration sent from the write loop to the read loop before a tagged
/// request is written, so the read loop can route the matching response.
struct ExpectReply {
    /// Numeric sequence number of the request; it goes on the wire as lowercase hex.
    tag: u64,
    /// Where the read loop delivers the peer's response for `tag`.
    reply: oneshot::Sender<Response>,
    /// Two-party rendezvous: the read loop waits on it after inserting `tag` into
    /// its reply map, and the write loop waits on it before writing the request,
    /// so a fast response can never arrive before its tag is known.
    barrier: Arc<tokio::sync::Barrier>,
}
29
30type _FrameMaker = Box<dyn FnOnce(Option<Bytes>) -> Result<Vec<u8>, amp_serde::Error> + Send>;
31
32struct FrameMaker(_FrameMaker);
33
34impl std::fmt::Debug for FrameMaker {
35 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
36 write!(fmt, "callback")
37 }
38}
39
/// Commands consumed by the write loop.
#[derive(Debug)]
enum WriteCmd {
    /// An already-encoded reply frame, written verbatim.
    Reply(Bytes),
    /// A request to encode and write. With a reply sender, the write loop assigns
    /// a sequence-number tag and registers the sender with the read loop; with
    /// `None`, the request is written untagged (no reply expected).
    Request(FrameMaker, Option<oneshot::Sender<Response>>),
}
45
/// Obligation to answer one tagged request from the peer.
///
/// Consume it with [`ReplyTicket::ok`] or [`ReplyTicket::error`]. If it is
/// dropped while it still holds its tag, its `Drop` impl queues an `UNKNOWN`
/// error reply ("Request dropped without reply") for the peer.
#[derive(Debug)]
pub struct ReplyTicket {
    /// The peer's request tag; `None` once `ok`/`error` have claimed it.
    tag: Option<Bytes>,
    /// Sender into the write loop, used to queue the encoded reply.
    write_handle: mpsc::Sender<WriteCmd>,
}
51
52impl ReplyTicket {
53 pub async fn ok<R: Serialize>(mut self, reply: R) -> Result<(), Error> {
54 let tag = self.tag.take().expect("Tag taken out of sequence");
55
56 let reply = amp_serde::to_bytes(OkResponse { tag, fields: reply })?;
57
58 self.write_handle
59 .send(WriteCmd::Reply(reply.into()))
60 .await?;
61
62 Ok(())
63 }
64
65 pub async fn error(
66 mut self,
67 code: Option<String>,
68 description: Option<String>,
69 ) -> Result<(), Error> {
70 let tag = self.tag.take().expect("Tag taken out of sequence");
71
72 let reply = amp_serde::to_bytes(ErrorResponse {
73 tag,
74 code: code.unwrap_or_else(|| "UNKNOWN".into()),
75 description: description.unwrap_or_else(|| "".into()),
76 })?;
77
78 self.write_handle
79 .send(WriteCmd::Reply(reply.into()))
80 .await?;
81
82 Ok(())
83 }
84}
85
86impl Drop for ReplyTicket {
87 fn drop(&mut self) {
88 if let Some(tag) = self.tag.take() {
89 let mut write_handle = self.write_handle.clone();
90 let reply = amp_serde::to_bytes(ErrorResponse {
91 tag,
92 code: "UNKNOWN".into(),
93 description: "Request dropped without reply".into(),
94 })
95 .unwrap();
96
97 tokio::spawn(async move {
99 write_handle
100 .send(WriteCmd::Reply(reply.into()))
101 .await
102 .expect("error on drop")
103 });
104 }
105 }
106}
107
/// Cloneable handle for issuing requests to the peer over a connection created
/// by [`serve`]; obtained from [`Handle::request_sender`].
#[derive(Clone)]
pub struct RequestSender(mpsc::Sender<WriteCmd>);
110
111impl RequestSender {
112 pub async fn call_remote<Q: Serialize + Send + 'static>(
113 &mut self,
114 command: String,
115 request: Q,
116 ) -> Result<RawFrame, Error> {
117 let (tx, rx) = oneshot::channel();
118
119 let frame = FrameMaker(Box::new(move |tag| {
120 amp_serde::to_bytes(Request {
121 tag,
122 command,
123 fields: request,
124 })
125 }));
126
127 self.0.send(WriteCmd::Request(frame, Some(tx))).await?;
128
129 rx.await?.map_err(|err| Error::Remote {
130 code: err.code,
131 description: err.description,
132 })
133 }
134
135 pub async fn call_remote_noreply<Q: Serialize + Send + 'static>(
136 &mut self,
137 command: String,
138 request: Q,
139 ) -> Result<(), Error> {
140 let frame = FrameMaker(Box::new(move |tag| {
141 amp_serde::to_bytes(Request {
142 tag,
143 command,
144 fields: request,
145 })
146 }));
147
148 self.0.send(WriteCmd::Request(frame, None)).await?;
149
150 Ok(())
151 }
152}
153
/// Owner of a connection started by [`serve`]. It holds the read and write tasks,
/// this handle's own sender into the write loop, and the read-loop shutdown signal.
///
/// Dropping a `Handle` also drops the shutdown sender. The read loop treats any
/// completion of its shutdown receiver (including a dropped sender) as a signal
/// to stop.
pub struct Handle {
    /// Join handle of the write-loop task.
    write_res: JoinHandle<Result<(), Error>>,
    /// Join handle of the read-loop task.
    read_res: JoinHandle<Result<(), Error>>,
    /// This handle's sender into the write loop, cloned to make [`RequestSender`]s;
    /// `None` after `shutdown`/`join`.
    write_loop_handle: Option<mpsc::Sender<WriteCmd>>,
    /// Fired (or dropped) to stop the read loop; `None` once used.
    shutdown: Option<oneshot::Sender<()>>,
}
160
161impl Handle {
162 pub fn shutdown(&mut self) {
163 self.write_loop_handle = None;
164 if let Some(s) = self.shutdown.take() {
165 let _ = s.send(());
166 }
167 }
168
169 pub async fn join(mut self) -> Result<(), Error> {
170 self.write_loop_handle = None;
171 self.write_res.await.unwrap()?;
172 if let Some(s) = self.shutdown.take() {
173 let _ = s.send(());
174 }
175 self.read_res.await.unwrap()?;
176
177 Ok(())
178 }
179
180 pub fn request_sender(&self) -> Option<RequestSender> {
181 self.write_loop_handle.as_ref().cloned().map(RequestSender)
182 }
183}
184
185pub fn serve<R, W>(input: R, output: W) -> (Handle, mpsc::Receiver<DispatchRequest>)
186where
187 R: AsyncRead + Unpin + Send + 'static,
188 W: AsyncWrite + Unpin + Send + 'static,
189{
190 let (write_tx, write_rx) = mpsc::channel::<WriteCmd>(32);
191 let (dispatch_tx, dispatch_rx) = mpsc::channel::<DispatchRequest>(32);
192 let (expect_tx, expect_rx) = mpsc::channel::<ExpectReply>(32);
193 let (shutdown_tx, shutdown_rx) = oneshot::channel();
194
195 let read_res = tokio::spawn(read_loop(
196 input,
197 shutdown_rx,
198 write_tx.clone(),
199 dispatch_tx,
200 expect_rx,
201 ));
202 let write_res = tokio::spawn(write_loop(output, write_rx, expect_tx));
203
204 (
205 Handle {
206 write_res,
207 read_res,
208 write_loop_handle: Some(write_tx),
209 shutdown: Some(shutdown_tx),
210 },
211 dispatch_rx,
212 )
213}
214
/// Outstanding outgoing requests, keyed by numeric tag, waiting for the peer's response.
type ReplyMap = HashMap<u64, oneshot::Sender<Response>>;
216
217async fn read_loop<R>(
218 input: R,
219 mut shutdown: oneshot::Receiver<()>,
220 mut write_tx: mpsc::Sender<WriteCmd>,
221 mut dispatch_tx: mpsc::Sender<DispatchRequest>,
222 mut expect_rx: mpsc::Receiver<ExpectReply>,
223) -> Result<(), Error>
224where
225 R: AsyncRead + Unpin,
226{
227 let codec_in: Decoder<RawFrame> = Decoder::new();
228 let mut input = FramedRead::new(input, codec_in);
229 let mut reply_map = ReplyMap::new();
230
231 loop {
232 tokio::select! {
233 frame = input.next() => {
234 if let Some(frame) = frame {
235 dispatch_frame(frame?, &mut reply_map, &mut write_tx, &mut dispatch_tx).await?;
236 } else {
237 break;
238 }
239 }
240 expect = expect_rx.recv() => {
241 if let Some(expect) = expect {
242 reply_map.insert(expect.tag, expect.reply);
243 expect.barrier.wait().await;
244 }
245 }
246 _ = &mut shutdown => break
247 }
248 }
249
250 Ok(())
251}
252
253async fn dispatch_frame(
254 frame: RawFrame,
255 reply_map: &mut ReplyMap,
256 write_tx: &mut mpsc::Sender<WriteCmd>,
257 dispatch_tx: &mut mpsc::Sender<DispatchRequest>,
258) -> Result<(), Error> {
259 match frame.try_into()? {
260 Frame::Request {
261 tag,
262 command,
263 fields,
264 } => {
265 let ticket = tag.map(|tag| ReplyTicket {
266 tag: Some(tag),
267 write_handle: write_tx.clone(),
268 });
269
270 let _ = dispatch_tx
274 .send(DispatchRequest(command, fields, ticket))
275 .await;
276 }
277
278 Frame::Response { tag, response } => {
279 let reply_tx = std::str::from_utf8(&tag)
280 .ok()
281 .and_then(|tag_str| u64::from_str_radix(tag_str, 16).ok())
282 .and_then(|tag_u64| reply_map.remove(&tag_u64))
283 .ok_or(Error::UnmatchedReply)?;
284
285 reply_tx.send(response).map_err(|_| Error::SendError)?;
286 }
287 }
288
289 Ok(())
290}
291
292async fn write_loop<W>(
293 output: W,
294 mut input: mpsc::Receiver<WriteCmd>,
295 mut expect_tx: mpsc::Sender<ExpectReply>,
296) -> Result<(), Error>
297where
298 W: AsyncWrite + Unpin,
299{
300 let mut output = FramedWrite::new(output, BytesCodec::new());
301 let mut seqno: u64 = 0;
302
303 while let Some(msg) = input.next().await {
304 match msg {
305 WriteCmd::Reply(frame) => {
306 output.send(frame).await?;
307 }
308 WriteCmd::Request(request, reply) => {
309 if let Some(reply) = reply {
310 seqno += 1;
311
312 let barrier = Arc::new(Barrier::new(2));
313
314 let expect = ExpectReply {
315 tag: seqno,
316 reply,
317 barrier: barrier.clone(),
318 };
319
320 expect_tx.send(expect).await?;
321 barrier.wait().await;
322
323 output
324 .send(request.0(Some(format!("{:x}", seqno).into()))?.into())
325 .await?;
326 } else {
327 output.send(request.0(None)?.into()).await?;
328 }
329 }
330 }
331 }
332
333 Ok(())
334}