rasi_ext/net/quic/
pool.rs1use std::{
2 collections::HashMap,
3 fmt::Debug,
4 io,
5 net::{SocketAddr, ToSocketAddrs},
6 sync::Arc,
7};
8
9use quiche::ConnectionId;
10use rand::{seq::SliceRandom, thread_rng};
11use rasi::syscall::{global_network, Network};
12
13use crate::utils::AsyncSpinMutex;
14
15use super::{Config, QuicConn, QuicConnector, QuicStream};
16
17enum OpenStream {
18 Stream(QuicStream),
19 Connector(QuicConnector),
20}
21
22struct RawQuicConnPool {
23 config: Config,
24 conns: HashMap<ConnectionId<'static>, QuicConn>,
25}
26
27impl RawQuicConnPool {
28 async fn open_stream(
29 &mut self,
30 server_name: Option<&str>,
31 max_conns: usize,
32 raddrs: &[SocketAddr],
33 syscall: &'static dyn Network,
34 ) -> io::Result<OpenStream> {
35 let mut conns = self.conns.values().collect::<Vec<_>>();
36
37 conns.shuffle(&mut thread_rng());
38
39 let mut closed = vec![];
40
41 for conn in conns {
42 match conn.stream_open(true).await {
43 Ok(stream) => {
44 return Ok(OpenStream::Stream(stream));
45 }
46 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
47 continue;
48 }
49 Err(err) => {
50 log::error!("{}, open stream with error: {}", conn, err);
51 closed.push(conn.source_id().clone());
52 continue;
53 }
54 }
55 }
56
57 for id in closed {
58 self.conns.remove(&id);
59 }
60
61 if !(self.conns.len() < max_conns) {
62 return Err(io::Error::new(
63 io::ErrorKind::WouldBlock,
64 format!(
65 "Quic conn pool, max connections limits is reached({}).",
66 max_conns
67 ),
68 ));
69 }
70
71 let connector = QuicConnector::new_with(
72 server_name,
73 ["[::]:0".parse().unwrap(), "0.0.0.0:0".parse().unwrap()].as_slice(),
74 raddrs,
75 &mut self.config,
76 syscall,
77 )
78 .await?;
79
80 Ok(OpenStream::Connector(connector))
81 }
82}
83
84#[derive(Clone)]
86pub struct QuicConnPool {
87 server_name: Option<String>,
88 raddrs: Vec<SocketAddr>,
90 max_conns: usize,
92 inner: Arc<AsyncSpinMutex<RawQuicConnPool>>,
94 syscall: &'static dyn Network,
96}
97
98impl Debug for QuicConnPool {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 write!(
101 f,
102 "QuicConnPool, max_conns={}, raddrs={:?}",
103 self.max_conns, self.raddrs
104 )
105 }
106}
107
108impl QuicConnPool {
109 pub fn new_with<A: ToSocketAddrs>(
113 server_name: Option<&str>,
114 raddrs: A,
115 config: Config,
116 syscall: &'static dyn Network,
117 ) -> io::Result<Self> {
118 Ok(Self {
119 raddrs: raddrs.to_socket_addrs()?.collect::<Vec<_>>(),
120 max_conns: 100,
121 syscall,
122 server_name: server_name.map(str::to_string),
123 inner: Arc::new(AsyncSpinMutex::new(RawQuicConnPool {
124 config,
125 conns: Default::default(),
126 })),
127 })
128 }
129 pub fn new<A: ToSocketAddrs>(
132 server_name: Option<&str>,
133 raddrs: A,
134 config: Config,
135 ) -> io::Result<Self> {
136 Self::new_with(server_name, raddrs, config, global_network())
137 }
138
139 pub async fn stream_open(&self) -> io::Result<QuicStream> {
147 use crate::utils::AsyncLockable;
148
149 let connector = {
150 let mut inner = self.inner.lock().await;
151
152 match inner
153 .open_stream(
154 self.server_name.as_ref().map(String::as_str),
155 self.max_conns,
156 &self.raddrs,
157 self.syscall,
158 )
159 .await?
160 {
161 OpenStream::Stream(stream) => return Ok(stream),
162 OpenStream::Connector(connector) => connector,
163 }
164 };
165
166 let connection = connector.connect().await?;
169
170 let stream = connection.stream_open(true).await?;
171
172 let mut inner = self.inner.lock().await;
174
175 inner
176 .conns
177 .insert(connection.source_id().clone(), connection);
178
179 Ok(stream)
180 }
181
182 pub fn set_max_conns(&mut self, value: usize) -> &mut Self {
186 self.max_conns = value;
187 self
188 }
189}