abcperf_generic_client/cs/
quic_s2n.rs

1use std::{
2    future::Future,
3    net::{Ipv6Addr, SocketAddr},
4    pin::Pin,
5};
6
7use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use futures::FutureExt;
10use s2n_quic::{
11    client::Connect, provider::tls, stream::BidirectionalStream, Client, Connection, Server,
12};
13use serde::{Deserialize, Serialize};
14use tokio::{io::AsyncReadExt, select, sync::oneshot};
15use tracing::{error, info, info_span, Instrument};
16
17use crate::cs::{CSConfig, CSTrait, CSTraitClient, CSTraitClientConnection, CSTraitServer};
18
19#[derive(Debug, Clone, Copy, Default)]
20pub struct QuicS2N;
21
22impl CSTrait for QuicS2N {
23    type Config = QuicS2NConfig;
24
25    fn configure(self, ca: Vec<u8>, priv_key: Vec<u8>, cert: Vec<u8>) -> Self::Config {
26        QuicS2NConfig {
27            tls_client_config: tls::default::Client::builder()
28                .with_certificate(ca)
29                .unwrap()
30                .build()
31                .unwrap(),
32            tls_server_config: tls::default::Server::builder()
33                .with_certificate(cert, priv_key)
34                .unwrap()
35                .build()
36                .unwrap(),
37        }
38    }
39}
40
41#[derive(Clone)]
42pub struct QuicS2NConfig {
43    tls_client_config: tls::default::Client,
44    tls_server_config: tls::default::Server,
45}
46
47impl CSConfig for QuicS2NConfig {
48    type Client = QuicS2NClient;
49
50    fn client(&self) -> Self::Client {
51        let client = Client::builder()
52            .with_tls(self.tls_client_config.clone())
53            .unwrap()
54            .with_io((Ipv6Addr::UNSPECIFIED, 0))
55            .unwrap()
56            .start()
57            .unwrap();
58
59        QuicS2NClient { client }
60    }
61
62    type Server = QuicS2NServer;
63
64    fn server(&self, local_socket: SocketAddr) -> Self::Server {
65        QuicS2NServer {
66            local_socket,
67            tls_server_config: self.tls_server_config.clone(),
68        }
69    }
70}
71
72#[derive(Debug)]
73pub struct QuicS2NClient {
74    client: Client,
75}
76
77#[async_trait]
78impl CSTraitClient for QuicS2NClient {
79    type Connection = QuicS2NClientConnection;
80
81    async fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Self::Connection> {
82        let connection = self
83            .client
84            .connect(Connect::new(addr).with_server_name(server_name))
85            .await
86            .map_err(|e| anyhow!("failed to connect: {}", e))?;
87
88        Ok(QuicS2NClientConnection { connection })
89    }
90
91    fn local_addr(&self) -> SocketAddr {
92        self.client.local_addr().unwrap()
93    }
94}
95
96#[derive(Debug)]
97pub struct QuicS2NClientConnection {
98    connection: Connection,
99}
100
101#[async_trait]
102impl CSTraitClientConnection for QuicS2NClientConnection {
103    async fn request<Request: Serialize + Send + Sync, Response: for<'a> Deserialize<'a>>(
104        &mut self,
105        request: Request,
106    ) -> Result<Response> {
107        let mut stream = self
108            .connection
109            .open_bidirectional_stream()
110            .await
111            .map_err(|e| anyhow!("failed to open stream: {}", e))?;
112        stream
113            .send(bincode::serialize::<Request>(&request)?.into())
114            .await?;
115        stream.finish()?;
116
117        let mut resp = Vec::new();
118        stream
119            .read_to_end(&mut resp)
120            .await
121            .map_err(|e| anyhow!("failed to read response: {}", e))?;
122
123        Ok(bincode::deserialize::<Response>(&resp)?)
124    }
125}
126
127pub struct QuicS2NServer {
128    local_socket: SocketAddr,
129    tls_server_config: tls::default::Server,
130}
131
132impl CSTraitServer for QuicS2NServer {
133    fn start<
134        S: Clone + Send + Sync + 'static,
135        Request: for<'a> Deserialize<'a> + Send + 'static,
136        Response: Serialize + Send + 'static,
137        Fut: Future<Output = Option<Response>> + Send + 'static,
138        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
139    >(
140        self,
141        shared: S,
142        on_request: F,
143        exit: oneshot::Receiver<()>,
144    ) -> (SocketAddr, Pin<Box<dyn Future<Output = ()> + Send>>) {
145        let endpoint = Server::builder()
146            .with_tls(self.tls_server_config)
147            .unwrap()
148            .with_io(self.local_socket)
149            .unwrap()
150            .start()
151            .unwrap();
152
153        let local_addr = endpoint.local_addr().unwrap();
154
155        info!("listening on {}", local_addr);
156
157        let join_handle = tokio::spawn(Self::run::<S, Request, Response, Fut, F>(
158            endpoint, shared, on_request, exit,
159        ));
160
161        (local_addr, Box::pin(join_handle.map(|_| ())))
162    }
163}
164
165impl QuicS2NServer {
166    async fn run<
167        S: Clone + Send + Sync + 'static,
168        Request: for<'a> Deserialize<'a> + Send + 'static,
169        Response: Serialize + Send + 'static,
170        Fut: Future<Output = Option<Response>> + Send + 'static,
171        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
172    >(
173        mut endpoint: Server,
174        shared: S,
175        on_request: F,
176        mut exit: oneshot::Receiver<()>,
177    ) {
178        loop {
179            select! {
180                biased;
181                _ = &mut exit => break,
182                Some(conn) = endpoint.accept() => {
183                    info!("connection incoming");
184                    let fut = Self::handle_connection::<S, Request, Response, Fut, F>(
185                        conn,
186                        shared.clone(),
187                        on_request.clone(),
188                    );
189                    tokio::spawn(async move {
190                        if let Err(e) = fut.await {
191                            error!("connection failed: {reason}", reason = e.to_string())
192                        }
193                    });
194                },
195                else => break
196            }
197        }
198    }
199
200    async fn handle_connection<
201        S: Clone + Send + 'static,
202        Request: for<'a> Deserialize<'a> + Send + 'static,
203        Response: Serialize + Send + 'static,
204        Fut: Future<Output = Option<Response>> + Send + 'static,
205        F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
206    >(
207        mut connection: Connection,
208        shared: S,
209        on_request: F,
210    ) -> Result<()> {
211        let span = info_span!(
212            "connection",
213            remote = %connection.remote_addr().unwrap(),
214        );
215        async {
216            info!("established");
217
218            // Each stream initiated by the client constitutes a new request.
219            loop {
220                let stream = connection.accept_bidirectional_stream().await;
221                let stream = match stream {
222                    Err(e) => {
223                        return Err(e);
224                    }
225                    Ok(Some(s)) => s,
226                    Ok(None) => todo!(),
227                };
228                let fut = Self::handle_request::<S, Request, Response, Fut, F>(
229                    stream,
230                    shared.clone(),
231                    on_request.clone(),
232                );
233                tokio::spawn(
234                    async move {
235                        if let Err(e) = fut.await {
236                            error!("failed: {reason}", reason = e.to_string());
237                        }
238                    }
239                    .instrument(info_span!("request")),
240                );
241            }
242        }
243        .instrument(span)
244        .await?;
245        Ok(())
246    }
247
248    async fn handle_request<
249        S,
250        Request: for<'a> Deserialize<'a>,
251        Response: Serialize,
252        Fut: Future<Output = Option<Response>>,
253        F: Fn(S, Request) -> Fut,
254    >(
255        mut stream: BidirectionalStream,
256        shared: S,
257        on_request: F,
258    ) -> Result<()> {
259        let mut req = Vec::new();
260        stream
261            .read_to_end(&mut req)
262            .await
263            .map_err(|e| anyhow!("failed reading request: {}", e))?;
264        let req = req;
265
266        let req =
267            bincode::deserialize::<Request>(&req).expect("request was not of expected format");
268
269        let response = on_request(shared, req).await;
270
271        if let Some(response) = response {
272            let response = bincode::serialize::<Response>(&response)?;
273
274            stream
275                .send(response.into())
276                .await
277                .map_err(|e| anyhow!("failed to send response: {}", e))?;
278        }
279
280        stream.finish()?;
281
282        info!("complete");
283
284        Ok(())
285    }
286}