// abcperf_generic_client/cs/quic_quinn.rs
use std::{
    future::Future,
    net::{Ipv6Addr, SocketAddr},
    pin::Pin,
    sync::Arc,
    time::Duration,
};

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::FutureExt;
use quinn::{
    ClientConfig, Connecting, Connection, Endpoint, RecvStream, SendStream, ServerConfig,
    TransportConfig, VarInt,
};
use rustls::{Certificate, PrivateKey, RootCertStore};
use serde::{Deserialize, Serialize};
use tokio::{select, sync::oneshot};
use tracing::{error, info, info_span, Instrument};

use crate::cs::{CSConfig, CSTrait, CSTraitClient, CSTraitClientConnection, CSTraitServer};

/// Size limit, in bytes, passed to `read_to_end` whenever a whole request or
/// response body is read from a QUIC stream (128 MiB).
const READ_TO_END_LIMIT: usize = 128 << 20;

/// Zero-sized marker that selects the QUIC client/server transport built on `quinn`.
///
/// Calling `configure` on it (via [`CSTrait`]) yields a [`QuicQuinnConfig`].
#[derive(Default, Debug, Copy, Clone)]
pub struct QuicQuinn;

28impl CSTrait for QuicQuinn {
29 type Config = QuicQuinnConfig;
30
31 fn configure(self, ca: Vec<u8>, priv_key: Vec<u8>, cert: Vec<u8>) -> Self::Config {
32 let mut roots = RootCertStore::empty();
33 roots.add(&Certificate(ca)).unwrap();
34 let mut client_config = ClientConfig::with_root_certificates(roots);
35
36 let mut transport_config = TransportConfig::default();
37 transport_config.max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
38 transport_config.keep_alive_interval(Some(Duration::from_secs(10)));
39 client_config.transport_config(Arc::new(transport_config));
40
41 let priv_key = PrivateKey(priv_key);
42 let cert_chain = vec![Certificate(cert)];
43 let mut server_config = ServerConfig::with_single_cert(cert_chain, priv_key).unwrap();
44 Arc::get_mut(&mut server_config.transport)
45 .expect("config was just created")
46 .max_idle_timeout(Some(VarInt::from_u32(30_000).into()));
47
48 QuicQuinnConfig {
49 client_config,
50 server_config,
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
56pub struct QuicQuinnConfig {
57 client_config: ClientConfig,
58 server_config: ServerConfig,
59}
60
61impl CSConfig for QuicQuinnConfig {
62 type Client = QuicQuinnClient;
63
64 fn client(&self) -> Self::Client {
65 let mut client = Endpoint::client((Ipv6Addr::UNSPECIFIED, 0).into()).unwrap();
66 client.set_default_client_config(self.client_config.clone());
67
68 QuicQuinnClient { client }
69 }
70
71 type Server = QuicQuinnServer;
72
73 fn server(&self, local_socket: SocketAddr) -> Self::Server {
74 QuicQuinnServer {
75 local_socket,
76 client_config: self.client_config.clone(),
77 server_config: self.server_config.clone(),
78 }
79 }
80}
81
82#[derive(Debug)]
83pub struct QuicQuinnClient {
84 client: Endpoint,
85}
86
87#[async_trait]
88impl CSTraitClient for QuicQuinnClient {
89 type Connection = QuicQuinnClientConnection;
90
91 async fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Self::Connection> {
92 let connection = self
93 .client
94 .connect(addr, server_name)?
95 .await
96 .map_err(|e| anyhow!("failed to connect: {}", e))?;
97
98 Ok(QuicQuinnClientConnection { connection })
99 }
100
101 fn local_addr(&self) -> SocketAddr {
102 self.client.local_addr().unwrap()
103 }
104}
105
106#[derive(Debug)]
107pub struct QuicQuinnClientConnection {
108 connection: Connection,
109}
110
111#[async_trait]
112impl CSTraitClientConnection for QuicQuinnClientConnection {
113 async fn request<Request: Serialize + Send + Sync, Response: for<'a> Deserialize<'a>>(
114 &mut self,
115 request: Request,
116 ) -> Result<Response> {
117 let (mut send, mut recv) = self
118 .connection
119 .open_bi()
120 .await
121 .map_err(|e| anyhow!("failed to open stream: {}", e))?;
122 send.write_all(&bincode::serialize::<Request>(&request)?)
123 .await
124 .map_err(|e| anyhow!("failed to send request: {}", e))?;
125 send.finish().await.map_err(|e| match e {
126 quinn::WriteError::ConnectionLost(e) => {
127 anyhow!("failed to shutdown stream: connection lost: {}", e)
128 }
129 quinn::WriteError::UnknownStream => todo!(),
130 _ => anyhow!("failed to shutdown stream: {}", e),
131 })?;
132
133 let resp = recv
134 .read_to_end(READ_TO_END_LIMIT)
135 .await
136 .map_err(|e| anyhow!("failed to read response: {}", e))?;
137
138 Ok(bincode::deserialize::<Response>(&resp)?)
139 }
140}
141
142#[derive(Debug)]
143pub struct QuicQuinnServer {
144 local_socket: SocketAddr,
145 client_config: ClientConfig,
146 server_config: ServerConfig,
147}
148
149impl CSTraitServer for QuicQuinnServer {
150 fn start<
151 S: Clone + Send + Sync + 'static,
152 Request: for<'a> Deserialize<'a> + Send + 'static,
153 Response: Serialize + Send + 'static,
154 Fut: Future<Output = Option<Response>> + Send + 'static,
155 F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
156 >(
157 self,
158 shared: S,
159 on_request: F,
160 exit: oneshot::Receiver<()>,
161 ) -> (SocketAddr, Pin<Box<dyn Future<Output = ()> + Send>>) {
162 let mut endpoint = Endpoint::server(self.server_config, self.local_socket).unwrap();
163 endpoint.set_default_client_config(self.client_config);
164
165 let local_addr = endpoint.local_addr().unwrap();
166
167 info!("listening on {}", local_addr);
168
169 let join_handle = tokio::spawn(Self::run::<S, Request, Response, Fut, F>(
170 endpoint, shared, on_request, exit,
171 ));
172
173 (local_addr, Box::pin(join_handle.map(|_| ())))
174 }
175}
176
177impl QuicQuinnServer {
178 async fn run<
179 S: Clone + Send + Sync + 'static,
180 Request: for<'a> Deserialize<'a> + Send + 'static,
181 Response: Serialize + Send + 'static,
182 Fut: Future<Output = Option<Response>> + Send + 'static,
183 F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
184 >(
185 endpoint: Endpoint,
186 shared: S,
187 on_request: F,
188 mut exit: oneshot::Receiver<()>,
189 ) {
190 loop {
191 select! {
192 biased;
193 _ = &mut exit => break,
194 Some(conn) = endpoint.accept() => {
195 info!("connection incoming");
196 let fut = Self::handle_connection::<S, Request, Response, Fut, F>(
197 conn,
198 shared.clone(),
199 on_request.clone(),
200 );
201 tokio::spawn(async move {
202 if let Err(e) = fut.await {
203 error!("connection failed: {}", e.to_string())
204 }
205 });
206 },
207 else => break
208 }
209 }
210 }
211
212 async fn handle_connection<
213 S: Clone + Send + 'static,
214 Request: for<'a> Deserialize<'a> + Send + 'static,
215 Response: Serialize + Send + 'static,
216 Fut: Future<Output = Option<Response>> + Send + 'static,
217 F: Fn(S, Request) -> Fut + Send + Sync + Clone + 'static,
218 >(
219 connection: Connecting,
220
221 shared: S,
222 on_request: F,
223 ) -> Result<()> {
224 let connection = connection.await?;
225 let span = info_span!(
226 "connection",
227 remote = %connection.remote_address(),
228 );
229 async {
230 info!("established");
231
232 loop {
234 let stream = connection.accept_bi().await;
235 let stream = match stream {
236 Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
237 info!("connection closed");
238 return Ok(());
239 }
240
241 Err(e) => {
242 return Err(e);
243 }
244 Ok(s) => s,
245 };
246 let fut = Self::handle_request::<S, Request, Response, Fut, F>(
247 stream,
248 shared.clone(),
249 on_request.clone(),
250 );
251 tokio::spawn(
252 async move {
253 if let Err(e) = fut.await {
254 error!("failed: {reason}", reason = e.to_string());
255 }
256 }
257 .instrument(info_span!("request")),
258 );
259 }
260 }
261 .instrument(span)
262 .await?;
263 Ok(())
264 }
265
266 async fn handle_request<
267 S,
268 Request: for<'a> Deserialize<'a>,
269 Response: Serialize,
270 Fut: Future<Output = Option<Response>>,
271 F: Fn(S, Request) -> Fut,
272 >(
273 (mut send_quinn, mut recv_quinn): (SendStream, RecvStream),
274 shared: S,
275 on_request: F,
276 ) -> Result<()> {
277 let req = recv_quinn
278 .read_to_end(READ_TO_END_LIMIT)
279 .await
280 .map_err(|e| anyhow!("failed reading request: {}", e))?;
281
282 let req =
283 bincode::deserialize::<Request>(&req).expect("request was not of expected format");
284
285 let response = on_request(shared, req).await;
286
287 if let Some(response) = response {
288 let response = bincode::serialize::<Response>(&response)?;
289
290 send_quinn
291 .write_all(&response)
292 .await
293 .map_err(|e| anyhow!("failed to send response: {}", e))?;
294 }
295
296 send_quinn.finish().await.map_err(|e| match e {
297 quinn::WriteError::ConnectionLost(e) => {
298 anyhow!("failed to shutdown stream: connection lost: {}", e)
299 }
300 quinn::WriteError::UnknownStream => todo!(),
301 _ => anyhow!("failed to shutdown stream: {}", e),
302 })?;
303
304 info!("complete");
305
306 Ok(())
307 }
308}