Skip to main content

corro_client/
lib.rs

1pub mod sub;
2
3use corro_api_types::{
4    ChangeId, ExecResponse, ExecResult, SqliteValue, Statement, QUERY_HASH_HEADER, QUERY_ID_HEADER,
5};
6use hickory_resolver::net::NetError as ResolveError;
7use serde::de::DeserializeOwned;
8use std::{
9    fmt::Write as _,
10    net::SocketAddr,
11    ops::Deref,
12    path::Path,
13    sync::Arc,
14    time::{self, Duration, Instant},
15};
16use sub::{QueryStream, SubscriptionStream, UpdatesStream};
17use tokio::{
18    sync::{RwLock, RwLockReadGuard},
19    time::timeout,
20};
21use tracing::{debug, info};
22use uuid::Uuid;
23
/// Connect timeout handed to the `reqwest` client builder.
const HTTP2_CONNECT_TIMEOUT: Duration = Duration::from_secs(3);
/// HTTP/2 keep-alive ping interval; the keep-alive timeout is set to half of
/// this value when the client is built.
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(10);
/// Upper bound on a single DNS lookup performed while picking an address.
const DNS_RESOLVE_TIMEOUT: Duration = Duration::from_secs(3);

/// Hickory resolver parameterized with the Tokio runtime provider.
type Resolver = hickory_resolver::Resolver<hickory_resolver::net::runtime::TokioRuntimeProvider>;
29
/// Single-address Corrosion HTTP API client.
///
/// Every request is sent to `http://{api_addr}/v1/...` through the wrapped
/// `reqwest::Client`.
#[derive(Clone)]
pub struct CorrosionApiClient {
    // Address of the Corrosion node's HTTP API.
    api_addr: SocketAddr,
    // HTTP client used for every request.
    api_client: reqwest::Client,
}
37
38impl CorrosionApiClient {
39    pub fn new(api_addr: SocketAddr) -> Result<Self, reqwest::Error> {
40        Ok(Self {
41            api_addr,
42            api_client: reqwest::ClientBuilder::new()
43                .http2_prior_knowledge()
44                .connect_timeout(HTTP2_CONNECT_TIMEOUT)
45                .http2_keep_alive_interval(Some(HTTP2_KEEP_ALIVE_INTERVAL))
46                .http2_keep_alive_timeout(HTTP2_KEEP_ALIVE_INTERVAL / 2)
47                .build()?,
48        })
49    }
50
51    /// Execute a single query against a Corrosion node, deserializing each row into `T`.
52    /// Optionally accepts a timeout for the request.
53    ///
54    /// Calls the `/v1/queries` endpoint (<https://superfly.github.io/corrosion/api/queries.html>).
55    pub async fn query_typed<T: DeserializeOwned + Unpin>(
56        &self,
57        statement: &Statement,
58        timeout: Option<u64>,
59    ) -> Result<QueryStream<T>, Error> {
60        let mut uri = format!("http://{}/v1/queries", self.api_addr);
61
62        if let Some(t) = timeout {
63            write!(&mut uri, "?timeout={t}").unwrap();
64        }
65
66        let res = self
67            .api_client
68            .post(uri)
69            .header(http::header::CONTENT_TYPE, "application/json")
70            .header(http::header::ACCEPT, "application/json")
71            .body(serde_json::to_vec(statement)?)
72            .send()
73            .await?;
74
75        if !res.status().is_success() {
76            let status = res.status();
77            match res.bytes().await {
78                Ok(b) => match serde_json::from_slice(&b) {
79                    Ok(res) => match res {
80                        ExecResult::Error { error } => return Err(Error::ResponseError(error)),
81                        res => return Err(Error::UnexpectedResult(res)),
82                    },
83                    Err(error) => {
84                        debug!(
85                            %error,
86                            "could not deserialize response body, sending generic error..."
87                        );
88                        return Err(Error::UnexpectedStatusCode(status));
89                    }
90                },
91                Err(error) => {
92                    debug!(
93                        %error,
94                        "could not aggregate response body bytes, sending generic error..."
95                    );
96                    return Err(Error::UnexpectedStatusCode(status));
97                }
98            }
99        }
100
101        Ok(QueryStream::new(res.into()))
102    }
103
    /// Same as [`Self::query_typed`], but returns each row as a
    /// `Vec<SqliteValue>`.
    ///
    /// Arguments and errors are forwarded to / from [`Self::query_typed`]
    /// unchanged.
    pub async fn query(
        &self,
        statement: &Statement,
        timeout: Option<u64>,
    ) -> Result<QueryStream<Vec<SqliteValue>>, Error> {
        self.query_typed(statement, timeout).await
    }
113
114    /// Create a new subscription and stream query updates, deserializing rows into T.
115    /// * `skip_rows` — when `true`, the initial rows are skipped and only changes are streamed.
116    /// * `from` — when set, resume the subscription past the given `ChangeId` instead of producing a fresh snapshot.
117    ///
118    /// Calls the `/v1/subscriptions` endpoint (<https://superfly.github.io/corrosion/api/subscriptions.html>).
119    pub async fn subscribe_typed<T: DeserializeOwned + Unpin>(
120        &self,
121        statement: &Statement,
122        skip_rows: bool,
123        from: Option<ChangeId>,
124    ) -> Result<SubscriptionStream<T>, Error> {
125        let mut uri = format!(
126            "http://{}/v1/subscriptions?skip_rows={skip_rows}",
127            self.api_addr
128        );
129
130        if let Some(change_id) = from {
131            write!(&mut uri, "&from={change_id}").unwrap();
132        }
133
134        let res = self
135            .api_client
136            .post(uri)
137            .header(http::header::CONTENT_TYPE, "application/json")
138            .header(http::header::ACCEPT, "application/json")
139            .body(serde_json::to_vec(statement)?)
140            .send()
141            .await?;
142
143        if !res.status().is_success() {
144            return Err(Error::UnexpectedStatusCode(res.status()));
145        }
146
147        let id = res
148            .headers()
149            .get(QUERY_ID_HEADER)
150            .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
151            .ok_or(Error::ExpectedQueryId)?;
152        let hash = res
153            .headers()
154            .get(QUERY_HASH_HEADER)
155            .and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
156
157        Ok(SubscriptionStream::new(
158            id,
159            hash,
160            self.api_client.clone(),
161            self.api_addr,
162            res.into(),
163            from,
164        ))
165    }
166
    /// Same as [`Self::subscribe_typed`], but returns each row as a
    /// `Vec<SqliteValue>`.
    ///
    /// Arguments and errors are forwarded to / from [`Self::subscribe_typed`]
    /// unchanged.
    pub async fn subscribe(
        &self,
        statement: &Statement,
        skip_rows: bool,
        from: Option<ChangeId>,
    ) -> Result<SubscriptionStream<Vec<SqliteValue>>, Error> {
        self.subscribe_typed(statement, skip_rows, from).await
    }
177
178    /// Reconnect to an existing subscription identified by its `Uuid`.
179    pub async fn subscription_typed<T: DeserializeOwned + Unpin>(
180        &self,
181        id: Uuid,
182        skip_rows: bool,
183        from: Option<ChangeId>,
184    ) -> Result<SubscriptionStream<T>, Error> {
185        let mut uri = format!(
186            "http://{}/v1/subscriptions/{id}?skip_rows={skip_rows}",
187            self.api_addr
188        );
189
190        if let Some(change_id) = from {
191            write!(&mut uri, "&from={change_id}").unwrap();
192        }
193
194        let res = self
195            .api_client
196            .get(uri)
197            .header(http::header::ACCEPT, "application/json")
198            .send()
199            .await?;
200
201        if !res.status().is_success() {
202            return Err(Error::UnexpectedStatusCode(res.status()));
203        }
204
205        let hash = res
206            .headers()
207            .get(QUERY_HASH_HEADER)
208            .and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
209
210        Ok(SubscriptionStream::new(
211            id,
212            hash,
213            self.api_client.clone(),
214            self.api_addr,
215            res.into(),
216            from,
217        ))
218    }
219
    /// Same as [`Self::subscription_typed`], but returns each row as a
    /// `Vec<SqliteValue>`.
    ///
    /// Arguments and errors are forwarded to / from
    /// [`Self::subscription_typed`] unchanged.
    pub async fn subscription(
        &self,
        id: Uuid,
        skip_rows: bool,
        from: Option<ChangeId>,
    ) -> Result<SubscriptionStream<Vec<SqliteValue>>, Error> {
        self.subscription_typed(id, skip_rows, from).await
    }
230
231    /// Subscribe to row-level changes on a single table.
232    ///
233    /// Calls the `/v1/updates/{table}` endpoint (<https://superfly.github.io/corrosion/api/updates.html>).
234    pub async fn updates_typed<T: DeserializeOwned + Unpin>(
235        &self,
236        table: &str,
237    ) -> Result<UpdatesStream<T>, Error> {
238        let res = self
239            .api_client
240            .post(format!("http://{}/v1/updates/{table}", self.api_addr))
241            .header(http::header::CONTENT_TYPE, "application/json")
242            .header(http::header::ACCEPT, "application/json")
243            .send()
244            .await?;
245
246        if !res.status().is_success() {
247            return Err(Error::UnexpectedStatusCode(res.status()));
248        }
249
250        let id = res
251            .headers()
252            .get(QUERY_ID_HEADER)
253            .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
254            .ok_or(Error::ExpectedQueryId)?;
255
256        Ok(UpdatesStream::new(id, res.into()))
257    }
258
    /// Same as [`Self::updates_typed`], but returns each row as a
    /// `Vec<SqliteValue>`.
    ///
    /// `table` and any error are forwarded to / from [`Self::updates_typed`]
    /// unchanged.
    pub async fn updates(&self, table: &str) -> Result<UpdatesStream<Vec<SqliteValue>>, Error> {
        self.updates_typed(table).await
    }
264
265    /// Execute one or more SQL statements in a single transaction.
266    ///
267    /// Calls the `/v1/transactions` endpoint (<https://superfly.github.io/corrosion/api/transactions.html>).
268    pub async fn execute(
269        &self,
270        statements: &[Statement],
271        timeout: Option<u64>,
272    ) -> Result<ExecResponse, Error> {
273        let uri = if let Some(timeout) = timeout {
274            format!("http://{}/v1/transactions?timeout={timeout}", self.api_addr)
275        } else {
276            format!("http://{}/v1/transactions", self.api_addr)
277        };
278        // println!("uri: {:?}", uri);
279        let res = self
280            .api_client
281            .post(uri)
282            .header(http::header::CONTENT_TYPE, "application/json")
283            .header(http::header::ACCEPT, "application/json")
284            .body(serde_json::to_vec(statements)?)
285            .send()
286            .await?;
287
288        let status = res.status();
289        if !status.is_success() {
290            match res.bytes().await {
291                Ok(b) => match serde_json::from_slice(&b) {
292                    Ok(ExecResponse { results, .. }) => {
293                        if let Some(ExecResult::Error { error }) = results
294                            .into_iter()
295                            .find(|r| matches!(r, ExecResult::Error { .. }))
296                        {
297                            return Err(Error::ResponseError(error));
298                        }
299                        return Err(Error::UnexpectedStatusCode(status));
300                    }
301                    Err(error) => {
302                        debug!(
303                            %error,
304                            "could not deserialize response body, sending generic error..."
305                        );
306                        return Err(Error::UnexpectedStatusCode(status));
307                    }
308                },
309                Err(error) => {
310                    debug!(
311                        %error,
312                        "could not aggregate response body bytes, sending generic error..."
313                    );
314                    return Err(Error::UnexpectedStatusCode(status));
315                }
316            }
317        }
318
319        Ok(serde_json::from_slice(&res.bytes().await?)?)
320    }
321}
322
/// Convenience client that combines a [`CorrosionApiClient`] with a local
/// SQLite connection pool.
///
/// The API client's methods are reachable directly on this type through its
/// `Deref` implementation.
#[derive(Clone)]
pub struct CorrosionClient {
    // HTTP API client; also the `Deref` target of this type.
    api_client: CorrosionApiClient,
    // SQLite connection pool exposed through `pool()`.
    pool: sqlite_pool::RusqlitePool,
}
330
331impl CorrosionClient {
332    pub fn new<P: AsRef<Path>>(api_addr: SocketAddr, db_path: P) -> Result<Self, reqwest::Error> {
333        Ok(Self {
334            api_client: CorrosionApiClient::new(api_addr)?,
335            pool: sqlite_pool::Config::new(db_path.as_ref())
336                .max_size(5)
337                .create_pool()
338                .expect("could not build pool, this can't fail because we specified a runtime"),
339        })
340    }
341
342    pub fn with_sqlite_pool(
343        api_addr: SocketAddr,
344        pool: sqlite_pool::RusqlitePool,
345    ) -> Result<Self, reqwest::Error> {
346        Ok(Self {
347            api_client: CorrosionApiClient::new(api_addr)?,
348            pool,
349        })
350    }
351
352    /// Borrow the SQLite connection pool used for direct reads.
353    pub fn pool(&self) -> &sqlite_pool::RusqlitePool {
354        &self.pool
355    }
356}
357
/// Exposes the wrapped [`CorrosionApiClient`] so its methods can be called
/// directly on a [`CorrosionClient`].
impl Deref for CorrosionClient {
    type Target = CorrosionApiClient;

    fn deref(&self) -> &Self::Target {
        &self.api_client
    }
}
365
/// Client to connect to a pool of Corrosion nodes.
///
/// Selects the first address from the list and tries to connect to it.
/// On I/O errors the client falls back to the next address; once a request succeeds the client
/// "sticks" to that peer until it has been failing continuously for the
/// configured `stickiness_timeout`.
///
/// Cloning is cheap: all clones share the same state behind an `Arc<RwLock<_>>`.
#[derive(Clone)]
pub struct CorrosionPooledClient {
    // Shared peer-selection state.
    inner: Arc<RwLock<PooledClientInner>>,
}
376
/// Peer-selection state shared by all clones of a [`CorrosionPooledClient`].
struct PooledClientInner {
    // Source of the next address to try.
    picker: AddrPicker,

    // For how long to stick with a chosen server after it starts failing.
    stickiness_timeout: time::Duration,
    // Currently chosen client; `None` until one is picked and again after a
    // rotation.
    client: Option<CorrosionApiClient>,
    // Whether or not the chosen client has made a successful request.
    had_success: bool,
    // Time when the first fail occurred after at least one successful call.
    first_fail_at: Option<Instant>,
    // Current client generation, incremented after each client change.
    generation: u64,
}
391
392impl CorrosionPooledClient {
393    /// Build a new pooled client.
394    ///
395    /// * `addrs` — ordered list of agent addresses. Each entry can be either
396    ///   a `host:port` string (resolved through `resolver`) or an already
397    ///   resolved `SocketAddr` formatted as a string. Entries are tried in
398    ///   the supplied order; once a peer has been successful the client
399    ///   prefers it for as long as it stays healthy.
400    /// * `stickiness_timeout` — how long the client keeps retrying a
401    ///   previously successful peer after it starts failing before rotating
402    ///   to the next address.
403    /// * `resolver` — DNS resolver used to translate hostnames to addresses.
404    pub fn new(addrs: Vec<String>, stickiness_timeout: time::Duration, resolver: Resolver) -> Self {
405        Self {
406            inner: Arc::new(RwLock::new(PooledClientInner {
407                picker: AddrPicker::new(addrs, resolver),
408
409                stickiness_timeout,
410                client: None,
411                had_success: false,
412                first_fail_at: None,
413                generation: 0,
414            })),
415        }
416    }
417
418    /// Run a one-shot query against the currently selected peer.
419    ///
420    /// Equivalent to [`CorrosionApiClient::query_typed`]
421    pub async fn query_typed<T: DeserializeOwned + Unpin>(
422        &self,
423        statement: &Statement,
424        timeout: Option<u64>,
425    ) -> Result<QueryStream<T>, Error> {
426        let (response, generation) = {
427            let (client, generation) = self.get_client().await?;
428            let response = client.query_typed(statement, timeout).await;
429
430            (response, generation)
431        };
432
433        if matches!(response, Err(Error::Reqwest(_))) {
434            // We only care about I/O related errors
435            self.handle_error(generation).await;
436        } else {
437            // The rest are considered a success
438            self.handle_success(generation).await;
439        }
440
441        response
442    }
443
444    /// Open a new subscription against the currently selected peer.
445    ///
446    pub async fn subscribe_typed<T: DeserializeOwned + Unpin>(
447        &self,
448        statement: &Statement,
449        skip_rows: bool,
450        from: Option<ChangeId>,
451    ) -> Result<SubscriptionStream<T>, Error> {
452        let (response, generation) = {
453            let (client, generation) = self.get_client().await?;
454            let response = client.subscribe_typed(statement, skip_rows, from).await;
455
456            (response, generation)
457        };
458
459        if matches!(response, Err(Error::Reqwest(_))) {
460            // We only care about I/O related errors
461            self.handle_error(generation).await;
462        } else {
463            // The rest are considered a success
464            self.handle_success(generation).await;
465        }
466
467        response
468    }
469
470    /// Reconnect to an existing subscription by id against the currently
471    /// selected peer. See [`CorrosionApiClient::subscription_typed`].
472    pub async fn subscription_typed<T: DeserializeOwned + Unpin>(
473        &self,
474        id: Uuid,
475        skip_rows: bool,
476        from: Option<ChangeId>,
477    ) -> Result<SubscriptionStream<T>, Error> {
478        let (response, generation) = {
479            let (client, generation) = self.get_client().await?;
480            let response = client.subscription_typed(id, skip_rows, from).await;
481
482            (response, generation)
483        };
484
485        if matches!(response, Err(Error::Reqwest(_))) {
486            // We only care about I/O related errors
487            self.handle_error(generation).await;
488        } else {
489            // The rest are considered a success
490            self.handle_success(generation).await;
491        }
492
493        response
494    }
495
496    async fn get_client(&self) -> Result<(RwLockReadGuard<'_, CorrosionApiClient>, u64), Error> {
497        let mut inner = self.inner.write().await;
498        let generation = inner.generation;
499
500        if inner.client.is_none() {
501            let addr = inner.picker.next().await?;
502            info!(
503                "next Corrosion server to attempt: {}, generation: {}",
504                addr, generation
505            );
506            inner.client = Some(CorrosionApiClient::new(addr)?)
507        }
508
509        Ok((
510            RwLockReadGuard::map(inner.downgrade(), |inner| inner.client.as_ref().unwrap()),
511            generation,
512        ))
513    }
514
515    async fn handle_success(&self, generation: u64) {
516        let mut inner = self.inner.write().await;
517
518        // Even though the call was successul, another failed call has advanced the client already.
519        if inner.generation != generation {
520            return;
521        }
522
523        // Mark that this client was able to perform a successful call to Corrosion
524        // and we should stick with it.
525        inner.had_success = true;
526        // And reset the time of the first fail.
527        inner.first_fail_at = None;
528    }
529
530    async fn handle_error(&self, generation: u64) {
531        let mut inner = self.inner.write().await;
532
533        // Somebody else has already handled an error with this client
534        if generation != inner.generation {
535            return;
536        }
537
538        match inner.first_fail_at {
539            // First fail after success
540            None if inner.had_success => {
541                inner.first_fail_at = Some(Instant::now());
542            }
543
544            // Still within stickiness timeout, try the same server again
545            Some(first) if Instant::now().duration_since(first) < inner.stickiness_timeout => {}
546
547            // Otherwise, pick a new server for the next attempt
548            _ => {
549                // If we had a successful call before, try to fallback to the first server, so the first
550                // one is always preferred by all the clients.
551                // Otherwise, continue iterating over the rest of the servers.
552                if inner.had_success {
553                    inner.picker.reset()
554                }
555
556                inner.client = None;
557                inner.first_fail_at = None;
558                inner.had_success = false;
559                inner.generation += 1;
560            }
561        }
562    }
563}
564
/// Walks the configured address list, resolving one entry at a time and
/// handing out its resolved socket addresses one by one.
struct AddrPicker {
    // Resolver used to resolve the addresses
    resolver: Resolver,
    // List of addresses/hostname to try in order
    addrs: Vec<String>,
    // Index into `addrs` of the next address/hostname to try
    next_addr: usize,

    // List of addresses returned by the last resolve attempt
    last_resolved_addrs: Option<Vec<SocketAddr>>,
    // Index into `last_resolved_addrs` of the next address to return
    next_resolved_addr: usize,
}
578
579impl AddrPicker {
    /// Create a picker positioned at the first entry of `addrs`, with nothing
    /// resolved yet.
    fn new(addrs: Vec<String>, resolver: Resolver) -> AddrPicker {
        Self {
            resolver,
            addrs,
            next_addr: 0,

            last_resolved_addrs: None,
            next_resolved_addr: 0,
        }
    }
590
591    async fn next(&mut self) -> Result<SocketAddr, Error> {
592        // Either we don't have any address or we tried them all, resolve again.
593        if self.next_resolved_addr
594            >= self
595                .last_resolved_addrs
596                .as_ref()
597                .map(|v| v.len())
598                .unwrap_or_default()
599        {
600            let host_port = self
601                .addrs
602                .get(self.next_addr)
603                .ok_or(ResolveError::from("No addresses available"))?;
604            self.next_addr = (self.next_addr + 1) % self.addrs.len();
605
606            let mut addrs = if let Ok(addr) = host_port.parse() {
607                vec![addr]
608            } else {
609                // split host port
610                let (host, port) = host_port
611                    .rsplit_once(':')
612                    .and_then(|(host, port)| Some((host, port.parse().ok()?)))
613                    .ok_or(ResolveError::from("Invalid Corrosion server address"))?;
614
615                timeout(DNS_RESOLVE_TIMEOUT, self.resolver.lookup_ip(host))
616                    .await
617                    .map_err(|_| ResolveError::Timeout)??
618                    .iter()
619                    .map(|addr| (addr, port).into())
620                    .collect::<Vec<_>>()
621            };
622            // Sort so all the nodes try the addresses in the same order
623            addrs.sort();
624
625            debug!("got the following Corrosion servers: {:?}", addrs);
626
627            self.last_resolved_addrs = Some(addrs);
628            self.next_resolved_addr = 0;
629        }
630
631        if let Some(addr) = self
632            .last_resolved_addrs
633            .as_ref()
634            .and_then(|a| a.get(self.next_resolved_addr).copied())
635        {
636            self.next_resolved_addr += 1;
637
638            Ok(addr)
639        } else {
640            Err(ResolveError::from("DNS didn't return any addresses").into())
641        }
642    }
643
    /// Start over from the first configured address and discard the addresses
    /// from the last resolution, forcing the next call to `next` to resolve
    /// again.
    fn reset(&mut self) {
        self.next_addr = 0;
        self.last_resolved_addrs = None;
        self.next_resolved_addr = 0;
    }
649}
650
/// Errors returned by [`CorrosionApiClient`] and [`CorrosionPooledClient`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Address selection or DNS resolution failed.
    #[error(transparent)]
    Dns(#[from] ResolveError),
    /// Error reported by `reqwest`; the pooled client treats this variant as
    /// an I/O failure of the current peer.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// Wraps an `http::uri::InvalidUri`.
    #[error(transparent)]
    InvalidUri(#[from] http::uri::InvalidUri),
    /// Wraps an `http::Error`.
    #[error(transparent)]
    Http(#[from] http::Error),
    /// JSON serialization or deserialization failed.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),

    /// The server answered with a non-success status and no usable error
    /// could be extracted from the body.
    #[error("received unexpected response code: {0}")]
    UnexpectedStatusCode(http::StatusCode),

    /// Error message reported by the server in an `ExecResult::Error`.
    #[error("{0}")]
    ResponseError(String),

    /// A failed query returned a body that parsed as an `ExecResult` other
    /// than `ExecResult::Error`.
    #[error("unexpected result: {0:?}")]
    UnexpectedResult(ExecResult),

    /// The query-id header was missing from the response or could not be
    /// parsed.
    #[error("could not retrieve subscription id from headers")]
    ExpectedQueryId,
}
677
678#[cfg(test)]
679mod tests {
680    use crate::{CorrosionPooledClient, Error};
681    use corro_api_types::{SqliteValue, QUERY_ID_HEADER};
682    use hickory_resolver::Resolver;
683    use hyper::{header::HeaderValue, service::service_fn, Request, Response};
684    use std::{
685        convert::Infallible,
686        net::SocketAddr,
687        sync::{
688            atomic::{AtomicBool, Ordering},
689            Arc,
690        },
691        time::Duration,
692    };
693    use tokio::{net::TcpListener, pin, sync::broadcast};
694    use uuid::Uuid;
695
    /// Response body that never yields any data.
    struct Empty<D>(std::marker::PhantomData<D>);

    impl Empty<bytes::Bytes> {
        /// Create an empty body whose data type is `bytes::Bytes`.
        fn new() -> Self {
            Self(std::marker::PhantomData)
        }
    }

    impl<D: bytes::Buf> http_body::Body for Empty<D> {
        type Data = D;
        type Error = std::convert::Infallible;

        // Always reports the end of the stream immediately.
        fn poll_frame(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            std::task::Poll::Ready(None)
        }
        fn is_end_stream(&self) -> bool {
            true
        }

        // The body is known to be exactly zero bytes long.
        fn size_hint(&self) -> http_body::SizeHint {
            http_body::SizeHint::with_exact(0)
        }
    }
722
    /// Minimal HTTP/2 test server that answers every request with an empty
    /// body and its own id in the query-id header.
    struct Server {
        // Id returned to clients in the `QUERY_ID_HEADER` response header.
        id: Uuid,
        // Local address the listener is bound to.
        addr: SocketAddr,
        // When set, newly accepted connections are dropped immediately.
        refuse: Arc<AtomicBool>,
        // Signals every live connection task to shut its connection down.
        drop_conn_tx: broadcast::Sender<()>,
    }
729
    impl Server {
        /// Bind a listener on an OS-assigned loopback port and spawn the
        /// accept loop in the background.
        async fn new(id: Uuid) -> Self {
            let refuse = Arc::new(AtomicBool::new(false));
            let (drop_conn_tx, drop_conn_rx) = broadcast::channel::<()>(1);
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();

            tokio::spawn({
                let refuse = refuse.clone();

                async move {
                    loop {
                        let (stream, _) = listener.accept().await.unwrap();
                        // While refusing, accept and drop the connection right away.
                        if refuse.load(Ordering::Relaxed) {
                            drop(stream);
                            continue;
                        }

                        let io = hyper_util::rt::TokioIo::new(stream);

                        // Each connection task gets its own receiver for the
                        // shutdown signal.
                        let mut drop_conn_rx = drop_conn_rx.resubscribe();
                        tokio::spawn(async move {
                            let conn = hyper::server::conn::http2::Builder::new(
                                hyper_util::rt::TokioExecutor::new(),
                            )
                            .serve_connection(
                                io,
                                // Every request gets an empty response carrying
                                // this server's id.
                                service_fn(move |_: Request<hyper::body::Incoming>| async move {
                                    let mut res = Response::new(Empty::new());
                                    res.headers_mut().insert(
                                        QUERY_ID_HEADER,
                                        HeaderValue::from_str(&id.to_string()).unwrap(),
                                    );
                                    Ok::<_, Infallible>(res)
                                }),
                            );
                            pin!(conn);

                            // Serve until the connection ends or a shutdown
                            // signal arrives, whichever comes first.
                            tokio::select! {
                                _ = conn.as_mut() => (),
                                _ = drop_conn_rx.recv() => {
                                    conn.as_mut().graceful_shutdown()
                                },
                            }
                        });
                    }
                }
            });

            Server {
                id,
                addr,
                refuse,
                drop_conn_tx,
            }
        }

        /// Toggle whether newly accepted connections are dropped immediately.
        fn refuse_new_conns(&self, refuse: bool) {
            self.refuse.store(refuse, Ordering::Relaxed)
        }

        /// Signal all live connection tasks to shut down; the send result is
        /// deliberately ignored.
        fn kill_existing_conns(&self) {
            _ = self.drop_conn_tx.send(())
        }
    }
795
796    async fn gen_servers(num: usize) -> (Vec<Server>, Vec<String>) {
797        let mut servers = Vec::new();
798
799        for _ in 0..num {
800            servers.push(Server::new(Uuid::new_v4()).await);
801        }
802
803        // sort the way the client is supposed to try them
804        servers.sort_by(|a, b| a.addr.partial_cmp(&b.addr).unwrap());
805        let addrs = servers.iter().map(|s| s.addr.to_string()).collect();
806
807        (servers, addrs)
808    }
809
    /// With a single server: a killed connection makes the next call fail,
    /// and the call after that succeeds against the same server.
    #[tokio::test]
    async fn test_single_address() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(1).await;

        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);

        // Drop the connection, next attempt should error
        servers[0].kill_existing_conns();

        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));

        // But the new one should succeed
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
    }
838
    /// With three servers and a negligible stickiness timeout: the client
    /// moves on to the second server when the first refuses connections, and
    /// goes back to the first one after the second has failed twice.
    #[tokio::test]
    async fn test_multiple_addresses() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(3).await;

        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);

        // Refuse connections on the first server
        servers[0].refuse_new_conns(true);

        // First one should error
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));

        // Second one should succeed
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[1].id);

        // Abort the second server, the client should fall back to the first one after the first two attempts
        servers[1].kill_existing_conns();
        servers[1].refuse_new_conns(true);
        servers[0].refuse_new_conns(false);

        // First and second one should error
        for _ in 0..2 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }

        // The next one should succeed
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
    }
883
884    #[tokio::test]
885    async fn test_multiple_addresses_sticky() {
886        let statement = "".into();
887        let (servers, addresses) = gen_servers(3).await;
888
889        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
890        let client = CorrosionPooledClient::new(addresses, Duration::from_millis(50), resolver);
891
892        // Refuse connections on the first server
893        servers[0].refuse_new_conns(true);
894
895        // First one should error
896        let res = client
897            .subscribe_typed::<SqliteValue>(&statement, false, None)
898            .await;
899        assert!(matches!(res, Result::Err(Error::Reqwest(_))));
900
901        // Second one should succeed
902        let sub = client
903            .subscribe_typed::<SqliteValue>(&statement, false, None)
904            .await
905            .unwrap();
906        assert_eq!(sub.id(), servers[1].id);
907
908        // Abort the second server, the client should continue trying it until the timeout expires
909        servers[1].kill_existing_conns();
910        servers[1].refuse_new_conns(true);
911        servers[0].refuse_new_conns(false);
912
913        let mut attempts = 0;
914        loop {
915            let res = client
916                .subscribe_typed::<SqliteValue>(&statement, false, None)
917                .await;
918
919            match res {
920                Ok(sub) => {
921                    assert_eq!(sub.id(), servers[0].id);
922                    break;
923                }
924                Err(_) => attempts += 1,
925            };
926        }
927        assert!(attempts > 2);
928    }
929
930    #[tokio::test]
931    async fn test_more_servers() {
932        let statement = "".into();
933        let (pool1_servers, pool1_addresses) = gen_servers(2).await;
934        let (pool2_servers, pool2_addresses) = gen_servers(2).await;
935
936        let mut addresses = pool1_addresses;
937        addresses.extend_from_slice(&pool2_addresses);
938
939        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
940        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);
941
942        // Refuse connections on all servers
943        for i in 0..2 {
944            pool1_servers[i].refuse_new_conns(true);
945            pool2_servers[i].refuse_new_conns(true);
946        }
947
948        // Try to connect multiple times, all should fail
949        for _ in 0..15 {
950            let res = client
951                .subscribe_typed::<SqliteValue>(&statement, false, None)
952                .await;
953            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
954        }
955
956        // Accept connections on first server in the backup pool
957        pool2_servers[0].refuse_new_conns(false);
958        for i in 0..4 {
959            let res = client
960                .subscribe_typed::<SqliteValue>(&statement, false, None)
961                .await;
962            match res {
963                Result::Err(_) => (),
964                Ok(sub) => {
965                    assert_eq!(sub.id(), pool2_servers[0].id);
966                    break;
967                }
968            }
969            assert!(i != 3);
970        }
971
972        // Kill the connection, it should fallback to the first pool
973        pool2_servers[0].kill_existing_conns();
974        pool2_servers[0].refuse_new_conns(true);
975        pool1_servers[0].refuse_new_conns(false);
976        pool1_servers[1].refuse_new_conns(false);
977
978        // First and second one should error
979        for _ in 0..2 {
980            let res = client
981                .subscribe_typed::<SqliteValue>(&statement, false, None)
982                .await;
983            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
984        }
985
986        // Thirst one should succeed
987        let sub = client
988            .subscribe_typed::<SqliteValue>(&statement, false, None)
989            .await
990            .unwrap();
991        assert_eq!(sub.id(), pool1_servers[0].id);
992    }
993}