abcperf_generic_client/cs/
quic_s2n.rs1use std::{
2 future::Future,
3 net::{Ipv6Addr, SocketAddr},
4 pin::Pin,
5};
6
7use anyhow::{anyhow, Result};
8use async_trait::async_trait;
9use futures::FutureExt;
10use s2n_quic::{
11 client::Connect, provider::tls, stream::BidirectionalStream, Client, Connection, Server,
12};
13use serde::{Deserialize, Serialize};
14use tokio::{io::AsyncReadExt, select, sync::oneshot};
15use tracing::{error, info, info_span, Instrument};
16
17use crate::cs::{CSConfig, CSTrait, CSTraitClient, CSTraitClientConnection, CSTraitServer};
18
/// Zero-sized selector for the s2n-quic backed transport.
///
/// The type holds no data; its [`CSTrait`] implementation turns it into a
/// [`QuicS2NConfig`] once certificates and keys are supplied.
#[derive(Clone, Copy, Debug, Default)]
pub struct QuicS2N;
21
22impl CSTrait for QuicS2N {
23 type Config = QuicS2NConfig;
24
25 fn configure(self, ca: Vec<u8>, priv_key: Vec<u8>, cert: Vec<u8>) -> Self::Config {
26 QuicS2NConfig {
27 tls_client_config: tls::default::Client::builder()
28 .with_certificate(ca)
29 .unwrap()
30 .build()
31 .unwrap(),
32 tls_server_config: tls::default::Server::builder()
33 .with_certificate(cert, priv_key)
34 .unwrap()
35 .build()
36 .unwrap(),
37 }
38 }
39}
40
41#[derive(Clone)]
42pub struct QuicS2NConfig {
43 tls_client_config: tls::default::Client,
44 tls_server_config: tls::default::Server,
45}
46
47impl CSConfig for QuicS2NConfig {
48 type Client = QuicS2NClient;
49
50 fn client(&self) -> Self::Client {
51 let client = Client::builder()
52 .with_tls(self.tls_client_config.clone())
53 .unwrap()
54 .with_io((Ipv6Addr::UNSPECIFIED, 0))
55 .unwrap()
56 .start()
57 .unwrap();
58
59 QuicS2NClient { client }
60 }
61
62 type Server = QuicS2NServer;
63
64 fn server(&self, local_socket: SocketAddr) -> Self::Server {
65 QuicS2NServer {
66 local_socket,
67 tls_server_config: self.tls_server_config.clone(),
68 }
69 }
70}
71
72#[derive(Debug)]
73pub struct QuicS2NClient {
74 client: Client,
75}
76
77#[async_trait]
78impl CSTraitClient for QuicS2NClient {
79 type Connection = QuicS2NClientConnection;
80
81 async fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Self::Connection> {
82 let connection = self
83 .client
84 .connect(Connect::new(addr).with_server_name(server_name))
85 .await
86 .map_err(|e| anyhow!("failed to connect: {}", e))?;
87
88 Ok(QuicS2NClientConnection { connection })
89 }
90
91 fn local_addr(&self) -> SocketAddr {
92 self.client.local_addr().unwrap()
93 }
94}
95
96#[derive(Debug)]
97pub struct QuicS2NClientConnection {
98 connection: Connection,
99}
100
101#[async_trait]
102impl CSTraitClientConnection for QuicS2NClientConnection {
103 async fn request<Request: Serialize + Send + Sync, Response: for<'a> Deserialize<'a>>(
104 &mut self,
105 request: Request,
106 ) -> Result<Response> {
107 let mut stream = self
108 .connection
109 .open_bidirectional_stream()
110 .await
111 .map_err(|e| anyhow!("failed to open stream: {}", e))?;
112 stream
113 .send(bincode::serialize::<Request>(&request)?.into())
114 .await?;
115 stream.finish()?;
116
117 let mut resp = Vec::new();
118 stream
119 .read_to_end(&mut resp)
120 .await
121 .map_err(|e| anyhow!("failed to read response: {}", e))?;
122
123 Ok(bincode::deserialize::<Response>(&resp)?)
124 }
125}
126
127pub struct QuicS2NServer {
128 local_socket: SocketAddr,
129 tls_server_config: tls::default::Server,
130}
131
132impl CSTraitServer for QuicS2NServer {
133 fn start<
134 S: Clone + Send + Sync + 'static,
135 Request: for<'a> Deserialize<'a> + Send + 'static,
136 Response: Serialize + Send + 'static,
137 Fut: Future<Output = Option<Response>> + Send + 'static,
138 F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
139 >(
140 self,
141 shared: S,
142 on_request: F,
143 exit: oneshot::Receiver<()>,
144 ) -> (SocketAddr, Pin<Box<dyn Future<Output = ()> + Send>>) {
145 let endpoint = Server::builder()
146 .with_tls(self.tls_server_config)
147 .unwrap()
148 .with_io(self.local_socket)
149 .unwrap()
150 .start()
151 .unwrap();
152
153 let local_addr = endpoint.local_addr().unwrap();
154
155 info!("listening on {}", local_addr);
156
157 let join_handle = tokio::spawn(Self::run::<S, Request, Response, Fut, F>(
158 endpoint, shared, on_request, exit,
159 ));
160
161 (local_addr, Box::pin(join_handle.map(|_| ())))
162 }
163}
164
165impl QuicS2NServer {
166 async fn run<
167 S: Clone + Send + Sync + 'static,
168 Request: for<'a> Deserialize<'a> + Send + 'static,
169 Response: Serialize + Send + 'static,
170 Fut: Future<Output = Option<Response>> + Send + 'static,
171 F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
172 >(
173 mut endpoint: Server,
174 shared: S,
175 on_request: F,
176 mut exit: oneshot::Receiver<()>,
177 ) {
178 loop {
179 select! {
180 biased;
181 _ = &mut exit => break,
182 Some(conn) = endpoint.accept() => {
183 info!("connection incoming");
184 let fut = Self::handle_connection::<S, Request, Response, Fut, F>(
185 conn,
186 shared.clone(),
187 on_request.clone(),
188 );
189 tokio::spawn(async move {
190 if let Err(e) = fut.await {
191 error!("connection failed: {reason}", reason = e.to_string())
192 }
193 });
194 },
195 else => break
196 }
197 }
198 }
199
200 async fn handle_connection<
201 S: Clone + Send + 'static,
202 Request: for<'a> Deserialize<'a> + Send + 'static,
203 Response: Serialize + Send + 'static,
204 Fut: Future<Output = Option<Response>> + Send + 'static,
205 F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
206 >(
207 mut connection: Connection,
208 shared: S,
209 on_request: F,
210 ) -> Result<()> {
211 let span = info_span!(
212 "connection",
213 remote = %connection.remote_addr().unwrap(),
214 );
215 async {
216 info!("established");
217
218 loop {
220 let stream = connection.accept_bidirectional_stream().await;
221 let stream = match stream {
222 Err(e) => {
223 return Err(e);
224 }
225 Ok(Some(s)) => s,
226 Ok(None) => todo!(),
227 };
228 let fut = Self::handle_request::<S, Request, Response, Fut, F>(
229 stream,
230 shared.clone(),
231 on_request.clone(),
232 );
233 tokio::spawn(
234 async move {
235 if let Err(e) = fut.await {
236 error!("failed: {reason}", reason = e.to_string());
237 }
238 }
239 .instrument(info_span!("request")),
240 );
241 }
242 }
243 .instrument(span)
244 .await?;
245 Ok(())
246 }
247
248 async fn handle_request<
249 S,
250 Request: for<'a> Deserialize<'a>,
251 Response: Serialize,
252 Fut: Future<Output = Option<Response>>,
253 F: Fn(S, Request) -> Fut,
254 >(
255 mut stream: BidirectionalStream,
256 shared: S,
257 on_request: F,
258 ) -> Result<()> {
259 let mut req = Vec::new();
260 stream
261 .read_to_end(&mut req)
262 .await
263 .map_err(|e| anyhow!("failed reading request: {}", e))?;
264 let req = req;
265
266 let req =
267 bincode::deserialize::<Request>(&req).expect("request was not of expected format");
268
269 let response = on_request(shared, req).await;
270
271 if let Some(response) = response {
272 let response = bincode::serialize::<Response>(&response)?;
273
274 stream
275 .send(response.into())
276 .await
277 .map_err(|e| anyhow!("failed to send response: {}", e))?;
278 }
279
280 stream.finish()?;
281
282 info!("complete");
283
284 Ok(())
285 }
286}