1use crate::{muxer, next_or_pending, preface, ReplyCodec, SpecCodec};
2use drain::Watch as Drain;
3use futures::{prelude::*, stream::FuturesUnordered};
4use ort_core::{Error, Ort};
5use std::net::SocketAddr;
6use tokio_util::codec::{FramedRead, FramedWrite};
7use tracing::{debug, debug_span, error, trace, Instrument};
8
/// A TCP server that exposes an inner service of type `O` to remote clients.
///
/// The service is cloned for each accepted connection and again for each
/// request received on that connection.
pub struct Server<O> {
    /// Capacity value passed to `muxer::spawn_server` for every accepted
    /// connection.
    buffer_capacity: usize,

    /// The service that produces a reply for each decoded request.
    inner: O,
}
13
14impl<O: Ort> Server<O> {
15 pub fn new(inner: O) -> Self {
16 Self {
17 inner,
18 buffer_capacity: 100_000,
19 }
20 }
21
22 pub async fn serve(self, addr: SocketAddr, drain: Drain) -> Result<(), Error> {
23 let mut serving = FuturesUnordered::new();
24 let lis = tokio::net::TcpListener::bind(addr).await?;
25
26 tokio::pin! {
27 let closed = drain.clone().signaled();
28 }
29
30 loop {
31 tokio::select! {
32 shutdown = (&mut closed) => {
33 debug!("Letting all connections complete before shutdown");
34 while serving.next().await.is_some() {}
35 drop(shutdown);
36 return Ok(());
37 }
38
39 _ = next_or_pending(&mut serving) => {}
40
41 acc = lis.accept() => {
42 let ((rio, wio), peer) = match acc {
43 Ok((sock, peer)) => {
44 debug!(%peer, "Client connected");
45 (sock.into_split(), peer)
46 }
47 Err(error) => {
48 error!(%error, "Failed to accept connection");
49 continue;
50 }
51 };
52
53 let span = debug_span!("conn", %peer);
54
55 let decode = preface::Codec::from(muxer::FramedDecode::from(SpecCodec::default()));
56 let encode = muxer::FramedEncode::from(ReplyCodec::default());
57 let (mut rx, muxer) = span.in_scope(|| muxer::spawn_server(
58 FramedRead::new(rio, decode),
59 FramedWrite::new(wio, encode),
60 drain.clone(),
61 self.buffer_capacity,
62 ));
63
64 let srv = self.inner.clone();
65 let drain = drain.clone();
66
67 let server = tokio::spawn(async move {
68 tokio::pin! {
69 let closed = drain.signaled();
70 }
71
72 let mut in_flight = FuturesUnordered::new();
73 loop {
74 tokio::select! {
75 shutdown = (&mut closed) => {
76 debug!("Draining inflight requests before shutdown");
77 drop(rx);
78 while let Some(()) = in_flight.next().await {};
79 drop(shutdown);
80 return;
81 }
82
83 _ = next_or_pending(&mut in_flight) => {
84 trace!("Response completed");
85 }
86
87 next = rx.recv() => match next {
88 None => {
89 debug!("Client closed; draining in-flight requests");
90 while let Some(()) = in_flight.next().await {};
91 return;
92 }
93 Some((spec, tx)) => {
94 let mut srv = srv.clone();
95 let h = tokio::spawn(async move {
96 let reply = srv.ort(spec).await?;
97 let _ = tx.send(reply);
98 Ok::<(), Error>(())
99 }.instrument(debug_span!("req")));
100 in_flight.push(h.map(|res| match res {
101 Ok(Ok(())) => {},
102 Ok(Err(error)) => error!(%error, "Service failed"),
103 Err(error) => error!(%error, "Task failed"),
104 }));
105 }
106 }
107 }
108 }
109 }.instrument(span));
110
111 serving.push(async move {
112 let (m, r) = tokio::join!(muxer, server);
113 debug!(?m, ?r, "Connection complete");
114 let () = r?;
115 m
116 })
117 }
118 }
119 }
120 }
121}