// rasi_ext/net/quic/pool.rs

use std::{
    collections::HashMap,
    fmt::Debug,
    io,
    net::{SocketAddr, ToSocketAddrs},
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
};

use quiche::ConnectionId;
use rand::{seq::SliceRandom, thread_rng};
use rasi::syscall::{global_network, Network};

use crate::utils::AsyncSpinMutex;

use super::{Config, QuicConn, QuicConnector, QuicStream};

17enum OpenStream {
18    Stream(QuicStream),
19    Connector(QuicConnector),
20}
21
22struct RawQuicConnPool {
23    config: Config,
24    conns: HashMap<ConnectionId<'static>, QuicConn>,
25}
26
27impl RawQuicConnPool {
28    async fn open_stream(
29        &mut self,
30        server_name: Option<&str>,
31        max_conns: usize,
32        raddrs: &[SocketAddr],
33        syscall: &'static dyn Network,
34    ) -> io::Result<OpenStream> {
35        let mut conns = self.conns.values().collect::<Vec<_>>();
36
37        conns.shuffle(&mut thread_rng());
38
39        let mut closed = vec![];
40
41        for conn in conns {
42            match conn.stream_open(true).await {
43                Ok(stream) => {
44                    return Ok(OpenStream::Stream(stream));
45                }
46                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
47                    continue;
48                }
49                Err(err) => {
50                    log::error!("{}, open stream with error: {}", conn, err);
51                    closed.push(conn.source_id().clone());
52                    continue;
53                }
54            }
55        }
56
57        for id in closed {
58            self.conns.remove(&id);
59        }
60
61        if !(self.conns.len() < max_conns) {
62            return Err(io::Error::new(
63                io::ErrorKind::WouldBlock,
64                format!(
65                    "Quic conn pool, max connections limits is reached({}).",
66                    max_conns
67                ),
68            ));
69        }
70
71        let connector = QuicConnector::new_with(
72            server_name,
73            ["[::]:0".parse().unwrap(), "0.0.0.0:0".parse().unwrap()].as_slice(),
74            raddrs,
75            &mut self.config,
76            syscall,
77        )
78        .await?;
79
80        Ok(OpenStream::Connector(connector))
81    }
82}
83
84/// The quic connection pool implementation.
85#[derive(Clone)]
86pub struct QuicConnPool {
87    server_name: Option<String>,
88    /// Peer addresses.
89    raddrs: Vec<SocketAddr>,
90    /// The maximum number of connections this pool can hold.
91    max_conns: usize,
92    /// mixin [`RawQuicConnPool`]
93    inner: Arc<AsyncSpinMutex<RawQuicConnPool>>,
94    /// syscall instance.
95    syscall: &'static dyn Network,
96}
97
98impl Debug for QuicConnPool {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        write!(
101            f,
102            "QuicConnPool, max_conns={}, raddrs={:?}",
103            self.max_conns, self.raddrs
104        )
105    }
106}
107
108impl QuicConnPool {
109    /// Create new [`QuicConn`] pool with custom [`syscall`](Network) interface.
110    ///
111    /// See [`new`](Self::new) for more information.
112    pub fn new_with<A: ToSocketAddrs>(
113        server_name: Option<&str>,
114        raddrs: A,
115        config: Config,
116        syscall: &'static dyn Network,
117    ) -> io::Result<Self> {
118        Ok(Self {
119            raddrs: raddrs.to_socket_addrs()?.collect::<Vec<_>>(),
120            max_conns: 100,
121            syscall,
122            server_name: server_name.map(str::to_string),
123            inner: Arc::new(AsyncSpinMutex::new(RawQuicConnPool {
124                config,
125                conns: Default::default(),
126            })),
127        })
128    }
129    /// Create new [`QuicConn`] pool with global [`syscall`](Network) interface.
130    ///
131    pub fn new<A: ToSocketAddrs>(
132        server_name: Option<&str>,
133        raddrs: A,
134        config: Config,
135    ) -> io::Result<Self> {
136        Self::new_with(server_name, raddrs, config, global_network())
137    }
138
139    /// Open new outbound stream.
140    ///
141    /// This function will randomly select a connection from the pool
142    /// and open a new outbound stream.
143    ///
144    /// If necessary, a new Quic connection will be created.
145    /// If the `max_conns` is reached, returns the [`WouldBlock`](io::ErrorKind::WouldBlock) error.
146    pub async fn stream_open(&self) -> io::Result<QuicStream> {
147        use crate::utils::AsyncLockable;
148
149        let connector = {
150            let mut inner = self.inner.lock().await;
151
152            match inner
153                .open_stream(
154                    self.server_name.as_ref().map(String::as_str),
155                    self.max_conns,
156                    &self.raddrs,
157                    self.syscall,
158                )
159                .await?
160            {
161                OpenStream::Stream(stream) => return Ok(stream),
162                OpenStream::Connector(connector) => connector,
163            }
164        };
165
166        // performs real connecting process.
167
168        let connection = connector.connect().await?;
169
170        let stream = connection.stream_open(true).await?;
171
172        // relock inner.
173        let mut inner = self.inner.lock().await;
174
175        inner
176            .conns
177            .insert(connection.source_id().clone(), connection);
178
179        Ok(stream)
180    }
181
182    /// Set `max_conns` parameter.
183    ///
184    /// The default value is `100`
185    pub fn set_max_conns(&mut self, value: usize) -> &mut Self {
186        self.max_conns = value;
187        self
188    }
189}