abcperf_generic_client/cs/quic_quinn.rs

1use std::{
2    future::Future,
3    net::{Ipv6Addr, SocketAddr},
4    pin::Pin,
5    sync::Arc,
6    time::Duration,
7};
8
9use anyhow::{anyhow, Result};
10use async_trait::async_trait;
11use futures::FutureExt;
12use quinn::{
13    ClientConfig, Connecting, Connection, Endpoint, RecvStream, SendStream, ServerConfig,
14    TransportConfig, VarInt,
15};
16use rustls::{Certificate, PrivateKey, RootCertStore};
17use serde::{Deserialize, Serialize};
18use tokio::{select, sync::oneshot};
19use tracing::{error, info, info_span, Instrument};
20
21use crate::cs::{CSConfig, CSTrait, CSTraitClient, CSTraitClientConnection, CSTraitServer};
22
/// Size limit, in bytes, passed to `read_to_end` when reading a whole request
/// or response stream (128 MiB).
const READ_TO_END_LIMIT: usize = 128 << 20;
24
25#[derive(Debug, Clone, Copy, Default)]
26pub struct QuicQuinn;
27
28impl CSTrait for QuicQuinn {
29    type Config = QuicQuinnConfig;
30
31    fn configure(self, ca: Vec<u8>, priv_key: Vec<u8>, cert: Vec<u8>) -> Self::Config {
32        let mut roots = RootCertStore::empty();
33        roots.add(&Certificate(ca)).unwrap();
34        let mut client_config = ClientConfig::with_root_certificates(roots);
35
36        let mut transport_config = TransportConfig::default();
37        transport_config.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
38        transport_config.keep_alive_interval(Some(Duration::from_secs(10)));
39        client_config.transport_config(Arc::new(transport_config));
40
41        let priv_key = PrivateKey(priv_key);
42        let cert_chain = vec![Certificate(cert)];
43        let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key).unwrap();
44        Arc::get_mut(&mut server_config.transport)
45            .expect("config was just created")
46            .max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
47
48        QuicQuinnConfig {
49            client_config,
50            server_config,
51        }
52    }
53}
54
55#[derive(Debug, Clone)]
56pub struct QuicQuinnConfig {
57    client_config: ClientConfig,
58    server_config: ServerConfig,
59}
60
61impl CSConfig for QuicQuinnConfig {
62    type Client = QuicQuinnClient;
63
64    fn client(&self) -> Self::Client {
65        let mut client = Endpoint::client((Ipv6Addr::UNSPECIFIED, 0).into()).unwrap();
66        client.set_default_client_config(self.client_config.clone());
67
68        QuicQuinnClient { client }
69    }
70
71    type Server = QuicQuinnServer;
72
73    fn server(&self, local_socket: SocketAddr) -> Self::Server {
74        QuicQuinnServer {
75            local_socket,
76            client_config: self.client_config.clone(),
77            server_config: self.server_config.clone(),
78        }
79    }
80}
81
82#[derive(Debug)]
83pub struct QuicQuinnClient {
84    client: Endpoint,
85}
86
87#[async_trait]
88impl CSTraitClient for QuicQuinnClient {
89    type Connection = QuicQuinnClientConnection;
90
91    async fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Self::Connection> {
92        let connection = self
93            .client
94            .connect(addr, server_name)?
95            .await
96            .map_err(|e| anyhow!("failed to connect: {}", e))?;
97
98        Ok(QuicQuinnClientConnection { connection })
99    }
100
101    fn local_addr(&self) -> SocketAddr {
102        self.client.local_addr().unwrap()
103    }
104}
105
106#[derive(Debug)]
107pub struct QuicQuinnClientConnection {
108    connection: Connection,
109}
110
111#[async_trait]
112impl CSTraitClientConnection for QuicQuinnClientConnection {
113    async fn request<Request: Serialize + Send + Sync, Response: for<'a> Deserialize<'a>>(
114        &mut self,
115        request: Request,
116    ) -> Result<Response> {
117        let (mut send, mut recv) = self
118            .connection
119            .open_bi()
120            .await
121            .map_err(|e| anyhow!("failed to open stream: {}", e))?;
122        send.write_all(&bincode::serialize::<Request>(&request)?)
123            .await
124            .map_err(|e| anyhow!("failed to send request: {}", e))?;
125        send.finish().await.map_err(|e| match e {
126            quinn::WriteError::ConnectionLost(e) => {
127                anyhow!("failed to shutdown stream: connection lost: {}", e)
128            }
129            quinn::WriteError::UnknownStream => todo!(),
130            _ => anyhow!("failed to shutdown stream: {}", e),
131        })?;
132
133        let resp = recv
134            .read_to_end(READ_TO_END_LIMIT)
135            .await
136            .map_err(|e| anyhow!("failed to read response: {}", e))?;
137
138        Ok(bincode::deserialize::<Response>(&resp)?)
139    }
140}
141
142#[derive(Debug)]
143pub struct QuicQuinnServer {
144    local_socket: SocketAddr,
145    client_config: ClientConfig,
146    server_config: ServerConfig,
147}
148
149impl CSTraitServer for QuicQuinnServer {
150    fn start<
151        S: Clone + Send + Sync + 'static,
152        Request: for<'a> Deserialize<'a> + Send + 'static,
153        Response: Serialize + Send + 'static,
154        Fut: Future<Output = Option<Response>> + Send + 'static,
155        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
156    >(
157        self,
158        shared: S,
159        on_request: F,
160        exit: oneshot::Receiver<()>,
161    ) -> (SocketAddr, Pin<Box<dyn Future<Output = ()> + Send>>) {
162        let mut endpoint = Endpoint::server(self.server_config, self.local_socket).unwrap();
163        endpoint.set_default_client_config(self.client_config);
164
165        let local_addr = endpoint.local_addr().unwrap();
166
167        info!("listening on {}", local_addr);
168
169        let join_handle = tokio::spawn(Self::run::<S, Request, Response, Fut, F>(
170            endpoint, shared, on_request, exit,
171        ));
172
173        (local_addr, Box::pin(join_handle.map(|_| ())))
174    }
175}
176
177impl QuicQuinnServer {
178    async fn run<
179        S: Clone + Send + Sync + 'static,
180        Request: for<'a> Deserialize<'a> + Send + 'static,
181        Response: Serialize + Send + 'static,
182        Fut: Future<Output = Option<Response>> + Send + 'static,
183        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
184    >(
185        endpoint: Endpoint,
186        shared: S,
187        on_request: F,
188        mut exit: oneshot::Receiver<()>,
189    ) {
190        loop {
191            select! {
192                biased;
193                _ = &mut exit => break,
194                Some(conn) = endpoint.accept() => {
195                    info!("connection incoming");
196                    let fut = Self::handle_connection::<S, Request, Response, Fut, F>(
197                        conn,
198                        shared.clone(),
199                        on_request.clone(),
200                    );
201                    tokio::spawn(async move {
202                        if let Err(e) = fut.await {
203                            error!("connection failed: {}", e.to_string())
204                        }
205                    });
206                },
207                else => break
208            }
209        }
210    }
211
212    async fn handle_connection<
213        S: Clone + Send + 'static,
214        Request: for<'a> Deserialize<'a> + Send + 'static,
215        Response: Serialize + Send + 'static,
216        Fut: Future<Output = Option<Response>> + Send + 'static,
217        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
218    >(
219        connection: Connecting,
220
221        shared: S,
222        on_request: F,
223    ) -> Result<()> {
224        let connection = connection.await?;
225        let span = info_span!(
226            "connection",
227            remote = %connection.remote_address(),
228        );
229        async {
230            info!("established");
231
232            // Each stream initiated by the client constitutes a new request.
233            loop {
234                let stream = connection.accept_bi().await;
235                let stream = match stream {
236                    Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
237                        info!("connection closed");
238                        return Ok(());
239                    }
240
241                    Err(e) => {
242                        return Err(e);
243                    }
244                    Ok(s) => s,
245                };
246                let fut = Self::handle_request::<S, Request, Response, Fut, F>(
247                    stream,
248                    shared.clone(),
249                    on_request.clone(),
250                );
251                tokio::spawn(
252                    async move {
253                        if let Err(e) = fut.await {
254                            error!("failed: {reason}", reason = e.to_string());
255                        }
256                    }
257                    .instrument(info_span!("request")),
258                );
259            }
260        }
261        .instrument(span)
262        .await?;
263        Ok(())
264    }
265
266    async fn handle_request<
267        S,
268        Request: for<'a> Deserialize<'a>,
269        Response: Serialize,
270        Fut: Future<Output = Option<Response>>,
271        F: Fn(S, Request) -> Fut,
272    >(
273        (mut send_quinn, mut recv_quinn): (SendStream, RecvStream),
274        shared: S,
275        on_request: F,
276    ) -> Result<()> {
277        let req = recv_quinn
278            .read_to_end(READ_TO_END_LIMIT)
279            .await
280            .map_err(|e| anyhow!("failed reading request: {}", e))?;
281
282        let req =
283            bincode::deserialize::<Request>(&req).expect("request was not of expected format");
284
285        let response = on_request(shared, req).await;
286
287        if let Some(response) = response {
288            let response = bincode::serialize::<Response>(&response)?;
289
290            send_quinn
291                .write_all(&response)
292                .await
293                .map_err(|e| anyhow!("failed to send response: {}", e))?;
294        }
295
296        send_quinn.finish().await.map_err(|e| match e {
297            quinn::WriteError::ConnectionLost(e) => {
298                anyhow!("failed to shutdown stream: connection lost: {}", e)
299            }
300            quinn::WriteError::UnknownStream => todo!(),
301            _ => anyhow!("failed to shutdown stream: {}", e),
302        })?;
303
304        info!("complete");
305
306        Ok(())
307    }
308}