1pub mod sub;
2
3use corro_api_types::{
4 ChangeId, ExecResponse, ExecResult, SqliteValue, Statement, QUERY_HASH_HEADER, QUERY_ID_HEADER,
5};
6use hickory_resolver::net::NetError as ResolveError;
7use serde::de::DeserializeOwned;
8use std::{
9 fmt::Write as _,
10 net::SocketAddr,
11 ops::Deref,
12 path::Path,
13 sync::Arc,
14 time::{self, Duration, Instant},
15};
16use sub::{QueryStream, SubscriptionStream, UpdatesStream};
17use tokio::{
18 sync::{RwLock, RwLockReadGuard},
19 time::timeout,
20};
21use tracing::{debug, info};
22use uuid::Uuid;
23
/// Connect timeout applied to the HTTP client built by `CorrosionApiClient::new`.
const HTTP2_CONNECT_TIMEOUT: Duration = Duration::from_secs(3);
/// HTTP/2 keep-alive ping interval; half of this value is used as the keep-alive timeout.
const HTTP2_KEEP_ALIVE_INTERVAL: Duration = Duration::from_secs(10);
/// Upper bound on a single DNS lookup performed by `AddrPicker::next`.
const DNS_RESOLVE_TIMEOUT: Duration = Duration::from_secs(3);

/// Hickory DNS resolver parameterised with the Tokio runtime provider.
type Resolver = hickory_resolver::Resolver<hickory_resolver::net::runtime::TokioRuntimeProvider>;
29
/// HTTP client for the API exposed by a single server, identified by its socket address.
#[derive(Clone)]
pub struct CorrosionApiClient {
    /// Address that every request URL is built from.
    api_addr: SocketAddr,
    /// Underlying `reqwest` client used to send all requests.
    api_client: reqwest::Client,
}
37
38impl CorrosionApiClient {
39 pub fn new(api_addr: SocketAddr) -> Result<Self, reqwest::Error> {
40 Ok(Self {
41 api_addr,
42 api_client: reqwest::ClientBuilder::new()
43 .http2_prior_knowledge()
44 .connect_timeout(HTTP2_CONNECT_TIMEOUT)
45 .http2_keep_alive_interval(Some(HTTP2_KEEP_ALIVE_INTERVAL))
46 .http2_keep_alive_timeout(HTTP2_KEEP_ALIVE_INTERVAL / 2)
47 .build()?,
48 })
49 }
50
51 pub async fn query_typed<T: DeserializeOwned + Unpin>(
56 &self,
57 statement: &Statement,
58 timeout: Option<u64>,
59 ) -> Result<QueryStream<T>, Error> {
60 let mut uri = format!("http://{}/v1/queries", self.api_addr);
61
62 if let Some(t) = timeout {
63 write!(&mut uri, "?timeout={t}").unwrap();
64 }
65
66 let res = self
67 .api_client
68 .post(uri)
69 .header(http::header::CONTENT_TYPE, "application/json")
70 .header(http::header::ACCEPT, "application/json")
71 .body(serde_json::to_vec(statement)?)
72 .send()
73 .await?;
74
75 if !res.status().is_success() {
76 let status = res.status();
77 match res.bytes().await {
78 Ok(b) => match serde_json::from_slice(&b) {
79 Ok(res) => match res {
80 ExecResult::Error { error } => return Err(Error::ResponseError(error)),
81 res => return Err(Error::UnexpectedResult(res)),
82 },
83 Err(error) => {
84 debug!(
85 %error,
86 "could not deserialize response body, sending generic error..."
87 );
88 return Err(Error::UnexpectedStatusCode(status));
89 }
90 },
91 Err(error) => {
92 debug!(
93 %error,
94 "could not aggregate response body bytes, sending generic error..."
95 );
96 return Err(Error::UnexpectedStatusCode(status));
97 }
98 }
99 }
100
101 Ok(QueryStream::new(res.into()))
102 }
103
104 pub async fn query(
107 &self,
108 statement: &Statement,
109 timeout: Option<u64>,
110 ) -> Result<QueryStream<Vec<SqliteValue>>, Error> {
111 self.query_typed(statement, timeout).await
112 }
113
114 pub async fn subscribe_typed<T: DeserializeOwned + Unpin>(
120 &self,
121 statement: &Statement,
122 skip_rows: bool,
123 from: Option<ChangeId>,
124 ) -> Result<SubscriptionStream<T>, Error> {
125 let mut uri = format!(
126 "http://{}/v1/subscriptions?skip_rows={skip_rows}",
127 self.api_addr
128 );
129
130 if let Some(change_id) = from {
131 write!(&mut uri, "&from={change_id}").unwrap();
132 }
133
134 let res = self
135 .api_client
136 .post(uri)
137 .header(http::header::CONTENT_TYPE, "application/json")
138 .header(http::header::ACCEPT, "application/json")
139 .body(serde_json::to_vec(statement)?)
140 .send()
141 .await?;
142
143 if !res.status().is_success() {
144 return Err(Error::UnexpectedStatusCode(res.status()));
145 }
146
147 let id = res
148 .headers()
149 .get(QUERY_ID_HEADER)
150 .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
151 .ok_or(Error::ExpectedQueryId)?;
152 let hash = res
153 .headers()
154 .get(QUERY_HASH_HEADER)
155 .and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
156
157 Ok(SubscriptionStream::new(
158 id,
159 hash,
160 self.api_client.clone(),
161 self.api_addr,
162 res.into(),
163 from,
164 ))
165 }
166
167 pub async fn subscribe(
170 &self,
171 statement: &Statement,
172 skip_rows: bool,
173 from: Option<ChangeId>,
174 ) -> Result<SubscriptionStream<Vec<SqliteValue>>, Error> {
175 self.subscribe_typed(statement, skip_rows, from).await
176 }
177
178 pub async fn subscription_typed<T: DeserializeOwned + Unpin>(
180 &self,
181 id: Uuid,
182 skip_rows: bool,
183 from: Option<ChangeId>,
184 ) -> Result<SubscriptionStream<T>, Error> {
185 let mut uri = format!(
186 "http://{}/v1/subscriptions/{id}?skip_rows={skip_rows}",
187 self.api_addr
188 );
189
190 if let Some(change_id) = from {
191 write!(&mut uri, "&from={change_id}").unwrap();
192 }
193
194 let res = self
195 .api_client
196 .get(uri)
197 .header(http::header::ACCEPT, "application/json")
198 .send()
199 .await?;
200
201 if !res.status().is_success() {
202 return Err(Error::UnexpectedStatusCode(res.status()));
203 }
204
205 let hash = res
206 .headers()
207 .get(QUERY_HASH_HEADER)
208 .and_then(|v| v.to_str().map(ToOwned::to_owned).ok());
209
210 Ok(SubscriptionStream::new(
211 id,
212 hash,
213 self.api_client.clone(),
214 self.api_addr,
215 res.into(),
216 from,
217 ))
218 }
219
220 pub async fn subscription(
223 &self,
224 id: Uuid,
225 skip_rows: bool,
226 from: Option<ChangeId>,
227 ) -> Result<SubscriptionStream<Vec<SqliteValue>>, Error> {
228 self.subscription_typed(id, skip_rows, from).await
229 }
230
231 pub async fn updates_typed<T: DeserializeOwned + Unpin>(
235 &self,
236 table: &str,
237 ) -> Result<UpdatesStream<T>, Error> {
238 let res = self
239 .api_client
240 .post(format!("http://{}/v1/updates/{table}", self.api_addr))
241 .header(http::header::CONTENT_TYPE, "application/json")
242 .header(http::header::ACCEPT, "application/json")
243 .send()
244 .await?;
245
246 if !res.status().is_success() {
247 return Err(Error::UnexpectedStatusCode(res.status()));
248 }
249
250 let id = res
251 .headers()
252 .get(QUERY_ID_HEADER)
253 .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok()))
254 .ok_or(Error::ExpectedQueryId)?;
255
256 Ok(UpdatesStream::new(id, res.into()))
257 }
258
259 pub async fn updates(&self, table: &str) -> Result<UpdatesStream<Vec<SqliteValue>>, Error> {
262 self.updates_typed(table).await
263 }
264
265 pub async fn execute(
269 &self,
270 statements: &[Statement],
271 timeout: Option<u64>,
272 ) -> Result<ExecResponse, Error> {
273 let uri = if let Some(timeout) = timeout {
274 format!("http://{}/v1/transactions?timeout={timeout}", self.api_addr)
275 } else {
276 format!("http://{}/v1/transactions", self.api_addr)
277 };
278 let res = self
280 .api_client
281 .post(uri)
282 .header(http::header::CONTENT_TYPE, "application/json")
283 .header(http::header::ACCEPT, "application/json")
284 .body(serde_json::to_vec(statements)?)
285 .send()
286 .await?;
287
288 let status = res.status();
289 if !status.is_success() {
290 match res.bytes().await {
291 Ok(b) => match serde_json::from_slice(&b) {
292 Ok(ExecResponse { results, .. }) => {
293 if let Some(ExecResult::Error { error }) = results
294 .into_iter()
295 .find(|r| matches!(r, ExecResult::Error { .. }))
296 {
297 return Err(Error::ResponseError(error));
298 }
299 return Err(Error::UnexpectedStatusCode(status));
300 }
301 Err(error) => {
302 debug!(
303 %error,
304 "could not deserialize response body, sending generic error..."
305 );
306 return Err(Error::UnexpectedStatusCode(status));
307 }
308 },
309 Err(error) => {
310 debug!(
311 %error,
312 "could not aggregate response body bytes, sending generic error..."
313 );
314 return Err(Error::UnexpectedStatusCode(status));
315 }
316 }
317 }
318
319 Ok(serde_json::from_slice(&res.bytes().await?)?)
320 }
321}
322
/// An API client bundled with a SQLite connection pool.
///
/// Dereferences to [`CorrosionApiClient`], so the API methods can be called on it directly.
#[derive(Clone)]
pub struct CorrosionClient {
    /// Client used for all HTTP API calls.
    api_client: CorrosionApiClient,
    /// SQLite connection pool, exposed through `pool()`.
    pool: sqlite_pool::RusqlitePool,
}
330
331impl CorrosionClient {
332 pub fn new<P: AsRef<Path>>(api_addr: SocketAddr, db_path: P) -> Result<Self, reqwest::Error> {
333 Ok(Self {
334 api_client: CorrosionApiClient::new(api_addr)?,
335 pool: sqlite_pool::Config::new(db_path.as_ref())
336 .max_size(5)
337 .create_pool()
338 .expect("could not build pool, this can't fail because we specified a runtime"),
339 })
340 }
341
342 pub fn with_sqlite_pool(
343 api_addr: SocketAddr,
344 pool: sqlite_pool::RusqlitePool,
345 ) -> Result<Self, reqwest::Error> {
346 Ok(Self {
347 api_client: CorrosionApiClient::new(api_addr)?,
348 pool,
349 })
350 }
351
352 pub fn pool(&self) -> &sqlite_pool::RusqlitePool {
354 &self.pool
355 }
356}
357
358impl Deref for CorrosionClient {
359 type Target = CorrosionApiClient;
360
361 fn deref(&self) -> &Self::Target {
362 &self.api_client
363 }
364}
365
/// API client that works against a list of server addresses and moves on to another
/// address after transport failures.
///
/// Clones share the same underlying state.
#[derive(Clone)]
pub struct CorrosionPooledClient {
    /// Shared, lock-protected client state.
    inner: Arc<RwLock<PooledClientInner>>,
}
376
/// Mutable state behind [`CorrosionPooledClient`].
struct PooledClientInner {
    /// Supplies the address for the next client to create.
    picker: AddrPicker,

    /// How long after the first recorded failure further failures are tolerated
    /// before the current client is discarded.
    stickiness_timeout: time::Duration,
    /// Client currently in use; `None` until one is created on demand, and again
    /// after it has been discarded.
    client: Option<CorrosionApiClient>,
    /// Whether the current client has produced a result other than a `reqwest` error.
    had_success: bool,
    /// When the first `reqwest` error since the last success was recorded.
    first_fail_at: Option<Instant>,
    /// Incremented each time the client is discarded; results from requests that
    /// started under an older generation are ignored.
    generation: u64,
}
391
392impl CorrosionPooledClient {
393 pub fn new(addrs: Vec<String>, stickiness_timeout: time::Duration, resolver: Resolver) -> Self {
405 Self {
406 inner: Arc::new(RwLock::new(PooledClientInner {
407 picker: AddrPicker::new(addrs, resolver),
408
409 stickiness_timeout,
410 client: None,
411 had_success: false,
412 first_fail_at: None,
413 generation: 0,
414 })),
415 }
416 }
417
418 pub async fn query_typed<T: DeserializeOwned + Unpin>(
422 &self,
423 statement: &Statement,
424 timeout: Option<u64>,
425 ) -> Result<QueryStream<T>, Error> {
426 let (response, generation) = {
427 let (client, generation) = self.get_client().await?;
428 let response = client.query_typed(statement, timeout).await;
429
430 (response, generation)
431 };
432
433 if matches!(response, Err(Error::Reqwest(_))) {
434 self.handle_error(generation).await;
436 } else {
437 self.handle_success(generation).await;
439 }
440
441 response
442 }
443
444 pub async fn subscribe_typed<T: DeserializeOwned + Unpin>(
447 &self,
448 statement: &Statement,
449 skip_rows: bool,
450 from: Option<ChangeId>,
451 ) -> Result<SubscriptionStream<T>, Error> {
452 let (response, generation) = {
453 let (client, generation) = self.get_client().await?;
454 let response = client.subscribe_typed(statement, skip_rows, from).await;
455
456 (response, generation)
457 };
458
459 if matches!(response, Err(Error::Reqwest(_))) {
460 self.handle_error(generation).await;
462 } else {
463 self.handle_success(generation).await;
465 }
466
467 response
468 }
469
470 pub async fn subscription_typed<T: DeserializeOwned + Unpin>(
473 &self,
474 id: Uuid,
475 skip_rows: bool,
476 from: Option<ChangeId>,
477 ) -> Result<SubscriptionStream<T>, Error> {
478 let (response, generation) = {
479 let (client, generation) = self.get_client().await?;
480 let response = client.subscription_typed(id, skip_rows, from).await;
481
482 (response, generation)
483 };
484
485 if matches!(response, Err(Error::Reqwest(_))) {
486 self.handle_error(generation).await;
488 } else {
489 self.handle_success(generation).await;
491 }
492
493 response
494 }
495
496 async fn get_client(&self) -> Result<(RwLockReadGuard<'_, CorrosionApiClient>, u64), Error> {
497 let mut inner = self.inner.write().await;
498 let generation = inner.generation;
499
500 if inner.client.is_none() {
501 let addr = inner.picker.next().await?;
502 info!(
503 "next Corrosion server to attempt: {}, generation: {}",
504 addr, generation
505 );
506 inner.client = Some(CorrosionApiClient::new(addr)?)
507 }
508
509 Ok((
510 RwLockReadGuard::map(inner.downgrade(), |inner| inner.client.as_ref().unwrap()),
511 generation,
512 ))
513 }
514
515 async fn handle_success(&self, generation: u64) {
516 let mut inner = self.inner.write().await;
517
518 if inner.generation != generation {
520 return;
521 }
522
523 inner.had_success = true;
526 inner.first_fail_at = None;
528 }
529
530 async fn handle_error(&self, generation: u64) {
531 let mut inner = self.inner.write().await;
532
533 if generation != inner.generation {
535 return;
536 }
537
538 match inner.first_fail_at {
539 None if inner.had_success => {
541 inner.first_fail_at = Some(Instant::now());
542 }
543
544 Some(first) if Instant::now().duration_since(first) < inner.stickiness_timeout => {}
546
547 _ => {
549 if inner.had_success {
553 inner.picker.reset()
554 }
555
556 inner.client = None;
557 inner.first_fail_at = None;
558 inner.had_success = false;
559 inner.generation += 1;
560 }
561 }
562 }
563}
564
/// Walks a list of `host:port` entries and hands out one socket address at a time.
struct AddrPicker {
    /// Resolver used for entries that are not literal socket addresses.
    resolver: Resolver,
    /// Configured entries, each either a literal socket address or `host:port`.
    addrs: Vec<String>,
    /// Index into `addrs` of the entry to use next.
    next_addr: usize,

    /// Addresses produced by the most recently processed entry, if any.
    last_resolved_addrs: Option<Vec<SocketAddr>>,
    /// Index into `last_resolved_addrs` of the address to hand out next.
    next_resolved_addr: usize,
}
578
579impl AddrPicker {
580 fn new(addrs: Vec<String>, resolver: Resolver) -> AddrPicker {
581 Self {
582 resolver,
583 addrs,
584 next_addr: 0,
585
586 last_resolved_addrs: None,
587 next_resolved_addr: 0,
588 }
589 }
590
591 async fn next(&mut self) -> Result<SocketAddr, Error> {
592 if self.next_resolved_addr
594 >= self
595 .last_resolved_addrs
596 .as_ref()
597 .map(|v| v.len())
598 .unwrap_or_default()
599 {
600 let host_port = self
601 .addrs
602 .get(self.next_addr)
603 .ok_or(ResolveError::from("No addresses available"))?;
604 self.next_addr = (self.next_addr + 1) % self.addrs.len();
605
606 let mut addrs = if let Ok(addr) = host_port.parse() {
607 vec![addr]
608 } else {
609 let (host, port) = host_port
611 .rsplit_once(':')
612 .and_then(|(host, port)| Some((host, port.parse().ok()?)))
613 .ok_or(ResolveError::from("Invalid Corrosion server address"))?;
614
615 timeout(DNS_RESOLVE_TIMEOUT, self.resolver.lookup_ip(host))
616 .await
617 .map_err(|_| ResolveError::Timeout)??
618 .iter()
619 .map(|addr| (addr, port).into())
620 .collect::<Vec<_>>()
621 };
622 addrs.sort();
624
625 debug!("got the following Corrosion servers: {:?}", addrs);
626
627 self.last_resolved_addrs = Some(addrs);
628 self.next_resolved_addr = 0;
629 }
630
631 if let Some(addr) = self
632 .last_resolved_addrs
633 .as_ref()
634 .and_then(|a| a.get(self.next_resolved_addr).copied())
635 {
636 self.next_resolved_addr += 1;
637
638 Ok(addr)
639 } else {
640 Err(ResolveError::from("DNS didn't return any addresses").into())
641 }
642 }
643
644 fn reset(&mut self) {
645 self.next_addr = 0;
646 self.last_resolved_addrs = None;
647 self.next_resolved_addr = 0;
648 }
649}
650
/// Errors returned by the clients in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Address selection or DNS resolution failed.
    #[error(transparent)]
    Dns(#[from] ResolveError),
    /// Error reported by `reqwest`. The pooled client treats this variant, and only
    /// this variant, as a failure of the current server.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// Converted from `http::uri::InvalidUri`.
    #[error(transparent)]
    InvalidUri(#[from] http::uri::InvalidUri),
    /// Converted from `http::Error`.
    #[error(transparent)]
    Http(#[from] http::Error),
    /// JSON serialization or deserialization failed.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),

    /// The server answered with a non-success status and no usable error body.
    #[error("received unexpected response code: {0}")]
    UnexpectedStatusCode(http::StatusCode),

    /// Error message taken from the response body of a non-success response.
    #[error("{0}")]
    ResponseError(String),

    /// A non-success query response decoded to a result that is not an error.
    #[error("unexpected result: {0:?}")]
    UnexpectedResult(ExecResult),

    /// The query id header was missing from the response or could not be parsed.
    #[error("could not retrieve subscription id from headers")]
    ExpectedQueryId,
}
677
678#[cfg(test)]
679mod tests {
680 use crate::{CorrosionPooledClient, Error};
681 use corro_api_types::{SqliteValue, QUERY_ID_HEADER};
682 use hickory_resolver::Resolver;
683 use hyper::{header::HeaderValue, service::service_fn, Request, Response};
684 use std::{
685 convert::Infallible,
686 net::SocketAddr,
687 sync::{
688 atomic::{AtomicBool, Ordering},
689 Arc,
690 },
691 time::Duration,
692 };
693 use tokio::{net::TcpListener, pin, sync::broadcast};
694 use uuid::Uuid;
695
    /// Response body for the test server that never yields any data.
    struct Empty<D>(std::marker::PhantomData<D>);

    impl Empty<bytes::Bytes> {
        /// Creates an empty body whose data type is `bytes::Bytes`.
        fn new() -> Self {
            Self(std::marker::PhantomData)
        }
    }

    impl<D: bytes::Buf> http_body::Body for Empty<D> {
        type Data = D;
        type Error = std::convert::Infallible;

        /// Immediately reports the end of the body without producing a frame.
        fn poll_frame(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            std::task::Poll::Ready(None)
        }
        /// The body is always at its end.
        fn is_end_stream(&self) -> bool {
            true
        }

        /// The body is exactly zero bytes long.
        fn size_hint(&self) -> http_body::SizeHint {
            http_body::SizeHint::with_exact(0)
        }
    }
722
    /// Minimal HTTP/2 test server that answers every request with an empty body
    /// and its own id in the query id header.
    struct Server {
        /// Id returned in the `QUERY_ID_HEADER` of every response.
        id: Uuid,
        /// Local address the server is listening on.
        addr: SocketAddr,
        /// When set, newly accepted connections are dropped immediately.
        refuse: Arc<AtomicBool>,
        /// Sending on this channel asks every open connection to shut down.
        drop_conn_tx: broadcast::Sender<()>,
    }

    impl Server {
        /// Binds a listener on `127.0.0.1:0` and spawns the accept loop in the
        /// background.
        async fn new(id: Uuid) -> Self {
            let refuse = Arc::new(AtomicBool::new(false));
            let (drop_conn_tx, drop_conn_rx) = broadcast::channel::<()>(1);
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            // Read back the address actually bound, since port 0 was requested.
            let addr = listener.local_addr().unwrap();

            tokio::spawn({
                let refuse = refuse.clone();

                async move {
                    loop {
                        let (stream, _) = listener.accept().await.unwrap();
                        // While refusing, accept and immediately drop the stream.
                        if refuse.load(Ordering::Relaxed) {
                            drop(stream);
                            continue;
                        }

                        let io = hyper_util::rt::TokioIo::new(stream);

                        // Each connection task gets its own receiver for the shutdown signal.
                        let mut drop_conn_rx = drop_conn_rx.resubscribe();
                        tokio::spawn(async move {
                            let conn = hyper::server::conn::http2::Builder::new(
                                hyper_util::rt::TokioExecutor::new(),
                            )
                            .serve_connection(
                                io,
                                service_fn(move |_: Request<hyper::body::Incoming>| async move {
                                    // Every request gets an empty response tagged with this server's id.
                                    let mut res = Response::new(Empty::new());
                                    res.headers_mut().insert(
                                        QUERY_ID_HEADER,
                                        HeaderValue::from_str(&id.to_string()).unwrap(),
                                    );
                                    Ok::<_, Infallible>(res)
                                }),
                            );
                            pin!(conn);

                            // Serve until the connection finishes or a shutdown is requested;
                            // the task ends right after either branch completes.
                            tokio::select! {
                                _ = conn.as_mut() => (),
                                _ = drop_conn_rx.recv() => {
                                    conn.as_mut().graceful_shutdown()
                                },
                            }
                        });
                    }
                }
            });

            Server {
                id,
                addr,
                refuse,
                drop_conn_tx,
            }
        }

        /// Turns dropping of newly accepted connections on or off.
        fn refuse_new_conns(&self, refuse: bool) {
            self.refuse.store(refuse, Ordering::Relaxed)
        }

        /// Signals all connection tasks to shut down; the send result is ignored.
        fn kill_existing_conns(&self) {
            _ = self.drop_conn_tx.send(())
        }
    }
795
796 async fn gen_servers(num: usize) -> (Vec<Server>, Vec<String>) {
797 let mut servers = Vec::new();
798
799 for _ in 0..num {
800 servers.push(Server::new(Uuid::new_v4()).await);
801 }
802
803 servers.sort_by(|a, b| a.addr.partial_cmp(&b.addr).unwrap());
805 let addrs = servers.iter().map(|s| s.addr.to_string()).collect();
806
807 (servers, addrs)
808 }
809
    /// With a single server, the client keeps using it across a dropped connection.
    #[tokio::test]
    async fn test_single_address() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(1).await;

        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);
        // The first request reaches the only server.
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);

        // Ask the server to shut down its open connections.
        servers[0].kill_existing_conns();

        // The next request is expected to fail with a transport error.
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));

        // The request after that is expected to succeed against the same server.
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
    }
838
    /// With a negligible stickiness timeout, the client moves to the next server
    /// after a failure and restarts from the first address once a previously
    /// working server keeps failing.
    #[tokio::test]
    async fn test_multiple_addresses() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(3).await;

        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);

        // The first server in the list refuses connections.
        servers[0].refuse_new_conns(true);

        // The first attempt goes to the refusing server and fails.
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));

        // The second attempt is expected to land on the second server.
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[1].id);

        // Take the second server away and bring the first one back.
        servers[1].kill_existing_conns();
        servers[1].refuse_new_conns(true);
        servers[0].refuse_new_conns(false);

        // Two failures are expected before the client gives up on the second server.
        for _ in 0..2 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }

        // After that, the client is expected to be back on the first server.
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[0].id);
    }
883
    /// With a 50 ms stickiness timeout, the client stays on a previously working
    /// server for more than two failed attempts before switching.
    #[tokio::test]
    async fn test_multiple_addresses_sticky() {
        let statement = "".into();
        let (servers, addresses) = gen_servers(3).await;

        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_millis(50), resolver);

        // The first server in the list refuses connections.
        servers[0].refuse_new_conns(true);

        // The first attempt goes to the refusing server and fails.
        let res = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await;
        assert!(matches!(res, Result::Err(Error::Reqwest(_))));

        // The second attempt is expected to land on the second server.
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), servers[1].id);

        // Take the second server away and bring the first one back.
        servers[1].kill_existing_conns();
        servers[1].refuse_new_conns(true);
        servers[0].refuse_new_conns(false);

        // Retry until a request succeeds, counting the failures along the way.
        let mut attempts = 0;
        loop {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;

            match res {
                Ok(sub) => {
                    assert_eq!(sub.id(), servers[0].id);
                    break;
                }
                Err(_) => attempts += 1,
            };
        }
        // More than two failures show the client stayed on the failing server for a while.
        assert!(attempts > 2);
    }
929
    /// With four addresses from two separately generated server groups, the client
    /// keeps cycling while everything is down, finds the one server that comes
    /// back, and later returns to the start of the list.
    #[tokio::test]
    async fn test_more_servers() {
        let statement = "".into();
        let (pool1_servers, pool1_addresses) = gen_servers(2).await;
        let (pool2_servers, pool2_addresses) = gen_servers(2).await;

        // Address list: both servers of the first group, then both of the second.
        let mut addresses = pool1_addresses;
        addresses.extend_from_slice(&pool2_addresses);

        let resolver = Resolver::builder_tokio().unwrap().build().unwrap();
        let client = CorrosionPooledClient::new(addresses, Duration::from_nanos(1), resolver);

        // Every server refuses connections.
        for i in 0..2 {
            pool1_servers[i].refuse_new_conns(true);
            pool2_servers[i].refuse_new_conns(true);
        }

        // All attempts fail while no server accepts connections.
        for _ in 0..15 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }

        // Bring back one server of the second group; it must be reached within four attempts.
        pool2_servers[0].refuse_new_conns(false);
        for i in 0..4 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            match res {
                Result::Err(_) => (),
                Ok(sub) => {
                    assert_eq!(sub.id(), pool2_servers[0].id);
                    break;
                }
            }
            // Reached only after a failed attempt; fails the test if that was the fourth one.
            assert!(i != 3);
        }

        // Take that server away again and bring back the whole first group.
        pool2_servers[0].kill_existing_conns();
        pool2_servers[0].refuse_new_conns(true);
        pool1_servers[0].refuse_new_conns(false);
        pool1_servers[1].refuse_new_conns(false);

        // Two failures are expected before the client gives up on the current server.
        for _ in 0..2 {
            let res = client
                .subscribe_typed::<SqliteValue>(&statement, false, None)
                .await;
            assert!(matches!(res, Result::Err(Error::Reqwest(_))));
        }

        // After that, the client is expected to be on the first server of the list.
        let sub = client
            .subscribe_typed::<SqliteValue>(&statement, false, None)
            .await
            .unwrap();
        assert_eq!(sub.id(), pool1_servers[0].id);
    }
993}