ort_tcp/
server.rs

1use crate::{muxer, next_or_pending, preface, ReplyCodec, SpecCodec};
2use drain::Watch as Drain;
3use futures::{prelude::*, stream::FuturesUnordered};
4use ort_core::{Error, Ort};
5use std::net::SocketAddr;
6use tokio_util::codec::{FramedRead, FramedWrite};
7use tracing::{debug, debug_span, error, trace, Instrument};
8
/// A TCP server that decodes requests from multiplexed connections and
/// answers each one with the wrapped `inner` service.
#[derive(Debug)]
pub struct Server<O> {
    /// Service that produces a reply for every decoded request; it is cloned
    /// per connection and per request by `serve`.
    inner: O,
    /// Capacity passed to `muxer::spawn_server` for every accepted connection.
    buffer_capacity: usize,
}
13
14impl<O: Ort> Server<O> {
15    pub fn new(inner: O) -> Self {
16        Self {
17            inner,
18            buffer_capacity: 100_000,
19        }
20    }
21
22    pub async fn serve(self, addr: SocketAddr, drain: Drain) -> Result<(), Error> {
23        let mut serving = FuturesUnordered::new();
24        let lis = tokio::net::TcpListener::bind(addr).await?;
25
26        tokio::pin! {
27            let closed = drain.clone().signaled();
28        }
29
30        loop {
31            tokio::select! {
32                shutdown = (&mut closed) => {
33                    debug!("Letting all connections complete before shutdown");
34                    while serving.next().await.is_some() {}
35                    drop(shutdown);
36                    return Ok(());
37                }
38
39                _ = next_or_pending(&mut serving) => {}
40
41                acc = lis.accept() => {
42                    let ((rio, wio), peer) = match acc {
43                        Ok((sock, peer)) => {
44                            debug!(%peer, "Client connected");
45                            (sock.into_split(), peer)
46                        }
47                        Err(error) => {
48                            error!(%error, "Failed to accept connection");
49                            continue;
50                        }
51                    };
52
53                    let span = debug_span!("conn", %peer);
54
55                    let decode = preface::Codec::from(muxer::FramedDecode::from(SpecCodec::default()));
56                    let encode = muxer::FramedEncode::from(ReplyCodec::default());
57                    let (mut rx, muxer) = span.in_scope(|| muxer::spawn_server(
58                        FramedRead::new(rio, decode),
59                        FramedWrite::new(wio, encode),
60                        drain.clone(),
61                        self.buffer_capacity,
62                    ));
63
64                    let srv = self.inner.clone();
65                    let drain = drain.clone();
66
67                    let server = tokio::spawn(async move {
68                        tokio::pin! {
69                            let closed = drain.signaled();
70                        }
71
72                        let mut in_flight = FuturesUnordered::new();
73                        loop {
74                            tokio::select! {
75                                shutdown = (&mut closed) => {
76                                    debug!("Draining inflight requests before shutdown");
77                                    drop(rx);
78                                    while let Some(()) = in_flight.next().await {};
79                                    drop(shutdown);
80                                    return;
81                                }
82
83                                _ = next_or_pending(&mut in_flight) => {
84                                    trace!("Response completed");
85                                }
86
87                                next = rx.recv() => match next {
88                                    None => {
89                                        debug!("Client closed; draining in-flight requests");
90                                        while let Some(()) = in_flight.next().await {};
91                                        return;
92                                    }
93                                    Some((spec, tx)) => {
94                                        let mut srv = srv.clone();
95                                        let h = tokio::spawn(async move {
96                                            let reply = srv.ort(spec).await?;
97                                            let _ = tx.send(reply);
98                                            Ok::<(), Error>(())
99                                        }.instrument(debug_span!("req")));
100                                        in_flight.push(h.map(|res| match res {
101                                            Ok(Ok(())) => {},
102                                            Ok(Err(error)) => error!(%error, "Service failed"),
103                                            Err(error) => error!(%error, "Task failed"),
104                                        }));
105                                    }
106                                }
107                            }
108                        }
109                    }.instrument(span));
110
111                    serving.push(async move {
112                        let (m, r) = tokio::join!(muxer, server);
113                        debug!(?m, ?r, "Connection complete");
114                        let () = r?;
115                        m
116                    })
117                }
118            }
119        }
120    }
121}