1use std::sync::{Arc, RwLock};
2
3use crate::{
4 batch::BatchResult,
5 config::{self, RqliteClientConfig, RqliteClientConfigBuilder},
6 error::{ClientBuilderError, RequestError},
7 fallback::{FallbackCount, FallbackStrategy},
8 node::{Node, NodeResponse, RemoveNodeRequest},
9 query::{self, QueryArgs, RqliteQuery},
10 query_result::QueryResult,
11 request::{RequestOptions, RqliteQueryParam, RqliteQueryParams},
12 response::{RqliteResponseRaw, RqliteResult},
13 select::RqliteSelectResults,
14};
15use base64::{engine::general_purpose, Engine};
16use reqwest::header;
17use rqlite_rs_core::Row;
18
19pub struct RqliteClient {
21 client: reqwest::Client,
22 hosts: Arc<RwLock<Vec<String>>>,
23 config: RqliteClientConfig,
24}
25
26#[derive(Default)]
28pub struct RqliteClientBuilder {
29 hosts: Vec<String>,
31 config: RqliteClientConfigBuilder,
33 basic_auth: Option<String>,
35}
36
37impl RqliteClientBuilder {
38 #[must_use]
40 pub fn new() -> Self {
41 Self::default()
42 }
43
44 #[must_use]
46 pub fn auth(mut self, user: &str, password: &str) -> Self {
47 self.basic_auth = Some(general_purpose::STANDARD.encode(format!("{user}:{password}")));
48 self
49 }
50
51 #[must_use]
57 #[expect(
58 clippy::needless_pass_by_value,
59 reason = "impl ToString is idiomatic for builder patterns"
60 )]
61 pub fn known_host(mut self, host: impl ToString) -> Self {
62 let host_str = host.to_string();
63
64 if self.hosts.iter().any(|h| h == &host_str) {
65 tracing::warn!("Host {host_str} already exists");
66 } else {
67 self.hosts.push(host_str);
68 }
69 self
70 }
71
72 #[must_use]
75 pub fn default_query_params(mut self, params: Vec<RqliteQueryParam>) -> Self {
76 self.config = self.config.default_query_params(params);
77 self
78 }
79
80 #[must_use]
83 pub fn fallback_count(mut self, count: FallbackCount) -> Self {
84 self.config = self.config.fallback_count(count);
85 self
86 }
87
88 #[must_use]
92 pub fn fallback_strategy(mut self, strategy: impl FallbackStrategy + 'static) -> Self {
93 self.config = self.config.fallback_strategy(strategy);
94 self
95 }
96
97 #[must_use]
101 pub fn fallback_persistence(mut self, persist: bool) -> Self {
102 self.config = self.config.fallback_persistence(persist);
103 self
104 }
105
106 #[must_use]
108 pub fn scheme(mut self, scheme: config::Scheme) -> Self {
109 self.config = self.config.scheme(scheme);
110 self
111 }
112
113 pub fn build(self) -> Result<RqliteClient, ClientBuilderError> {
122 if self.hosts.is_empty() {
123 return Err(ClientBuilderError::NoHostsProvided);
124 }
125
126 let hosts = self.hosts.into_iter().collect::<Vec<String>>();
127
128 let mut headers = header::HeaderMap::new();
129 headers.insert(
130 header::CONTENT_TYPE,
131 header::HeaderValue::from_static("application/json"),
132 );
133
134 if let Some(credentials) = self.basic_auth {
135 let basic_auth_fmt = format!("Basic {credentials}");
136 headers.insert(
137 header::AUTHORIZATION,
138 header::HeaderValue::from_str(basic_auth_fmt.as_str())?,
139 );
140 }
141
142 let mut client = reqwest::ClientBuilder::new()
143 .timeout(std::time::Duration::from_secs(5))
144 .default_headers(headers);
145
146 if matches!(self.config.scheme, Some(config::Scheme::Https)) {
147 client = client.https_only(true);
148 }
149
150 Ok(RqliteClient {
151 client: client.build()?,
152 hosts: Arc::new(RwLock::new(hosts)),
153 config: self.config.build(),
154 })
155 }
156}
157
158impl RqliteClient {
159 async fn try_request(
160 &self,
161 mut options: RequestOptions,
162 ) -> Result<reqwest::Response, RequestError> {
163 let (mut host, host_count) = {
164 let hosts = self
165 .hosts
166 .read()
167 .map_err(|_poisoned| RequestError::LockPoisoned)?;
168 let first_host = hosts.first().ok_or(RequestError::NoAvailableHosts)?;
169 (first_host.clone(), hosts.len())
170 };
171
172 let retry_count = self.config.fallback_count.count(host_count);
173
174 if let Some(default_params) = &self.config.default_query_params {
175 options.merge_default_query_params(default_params);
176 }
177
178 for _ in 0..retry_count {
179 tracing::debug!("Trying host: {host}");
180 let req = options.to_reqwest_request(&self.client, host.as_str(), &self.config.scheme);
181
182 match req.send().await {
183 Ok(res) if res.status().is_success() => return Ok(res),
184 Ok(res) => match res.status() {
185 reqwest::StatusCode::UNAUTHORIZED => {
186 return Err(RequestError::Unauthorized);
187 }
188 status => {
189 return Err(RequestError::ReqwestError {
190 body: res.text().await?,
191 status,
192 });
193 }
194 },
195 Err(e) => self.handle_request_error(&e, &mut host)?,
196 }
197 }
198
199 Err(RequestError::NoAvailableHosts)
200 }
201
202 fn handle_request_error(
206 &self,
207 e: &reqwest::Error,
208 host: &mut String, ) -> Result<(), RequestError> {
210 if e.is_connect() || e.is_timeout() {
211 let previous_host = host.clone();
212 let mut writable_hosts = self
213 .hosts
214 .write()
215 .map_err(|_poisoned| RequestError::LockPoisoned)?;
216
217 let new_host = self
218 .config
219 .fallback_strategy
220 .write()
221 .map_err(|_poisoned| RequestError::LockPoisoned)?
222 .fallback(&mut writable_hosts, host, self.config.fallback_persistence)
223 .ok_or(RequestError::NoAvailableHosts)?;
224
225 host.clone_from(new_host);
226 tracing::info!("Connection to {} failed, trying {}", previous_host, *host);
227 Ok(())
228 } else {
229 Err(RequestError::SwitchoverWrongError(e.to_string()))
230 }
231 }
232
233 async fn exec_query<T>(&self, q: query::RqliteQuery) -> Result<RqliteResult<T>, RequestError>
234 where
235 T: serde::de::DeserializeOwned + Clone,
236 {
237 let res = self
238 .try_request(RequestOptions {
239 endpoint: q.endpoint(),
240 body: Some(
241 q.into_json()
242 .map_err(RequestError::FailedParseRequestBody)?,
243 ),
244 ..Default::default()
245 })
246 .await?;
247
248 let body = res.text().await?;
249
250 let response = serde_json::from_str::<RqliteResponseRaw<T>>(&body)
251 .map_err(RequestError::FailedParseResponseBody)?;
252
253 response
254 .results
255 .into_iter()
256 .next()
257 .ok_or(RequestError::NoRowsReturned)
258 }
259
260 pub async fn fetch<Q>(&self, q: Q) -> Result<Vec<Row>, RequestError>
292 where
293 Q: TryInto<RqliteQuery>,
294 RequestError: From<Q::Error>,
295 {
296 let result = self
297 .exec_query::<RqliteSelectResults>(q.try_into()?)
298 .await?;
299
300 match result {
301 RqliteResult::Success(qr) => Ok(qr.rows()),
302 RqliteResult::Error(qe) => Err(RequestError::DatabaseError(qe.error)),
303 }
304 }
305
306 pub async fn exec<Q>(&self, q: Q) -> Result<QueryResult, RequestError>
318 where
319 Q: TryInto<RqliteQuery>,
320 RequestError: From<Q::Error>,
321 {
322 let query_result = self.exec_query::<QueryResult>(q.try_into()?).await?;
323
324 match query_result {
325 RqliteResult::Success(qr) => Ok(qr),
326 RqliteResult::Error(qe) => Err(RequestError::DatabaseError(qe.error)),
327 }
328 }
329
330 pub async fn batch<Q>(&self, qs: Vec<Q>) -> Result<Vec<RqliteResult<BatchResult>>, RequestError>
347 where
348 Q: TryInto<RqliteQuery>,
349 RequestError: From<Q::Error>,
350 {
351 let queries = qs
352 .into_iter()
353 .map(std::convert::TryInto::try_into)
354 .collect::<Result<Vec<RqliteQuery>, _>>()?;
355
356 let batch = QueryArgs::from(queries);
357 let body = serde_json::to_string(&batch).map_err(RequestError::FailedParseRequestBody)?;
358
359 let res = self
360 .try_request(RequestOptions {
361 endpoint: "db/request".to_string(),
362 body: Some(body),
363 ..Default::default()
364 })
365 .await?;
366
367 let body = res.text().await?;
368
369 let results = serde_json::from_str::<RqliteResponseRaw<BatchResult>>(&body)
370 .map_err(RequestError::FailedParseResponseBody)?
371 .results;
372
373 Ok(results)
374 }
375
376 pub async fn transaction<Q>(
392 &self,
393 qs: Vec<Q>,
394 ) -> Result<Vec<RqliteResult<QueryResult>>, RequestError>
395 where
396 Q: TryInto<RqliteQuery>,
397 RequestError: From<Q::Error>,
398 {
399 let queries = qs
400 .into_iter()
401 .map(std::convert::TryInto::try_into)
402 .collect::<Result<Vec<RqliteQuery>, _>>()?;
403
404 let batch = QueryArgs::from(queries);
405 let body = serde_json::to_string(&batch).map_err(RequestError::FailedParseRequestBody)?;
406
407 let res = self
408 .try_request(RequestOptions {
409 endpoint: "db/execute".to_string(),
410 body: Some(body),
411 params: Some(
412 RqliteQueryParams::new()
413 .transaction()
414 .into_request_query_params(),
415 ),
416 ..Default::default()
417 })
418 .await?;
419
420 let body = res.text().await?;
421
422 let results = serde_json::from_str::<RqliteResponseRaw<QueryResult>>(&body)
423 .map_err(RequestError::FailedParseResponseBody)?
424 .results;
425
426 Ok(results)
427 }
428
429 pub async fn queue<Q>(&self, qs: Vec<Q>) -> Result<(), RequestError>
442 where
443 Q: TryInto<RqliteQuery>,
444 RequestError: From<Q::Error>,
445 {
446 let queries = qs
447 .into_iter()
448 .map(std::convert::TryInto::try_into)
449 .collect::<Result<Vec<RqliteQuery>, _>>()?;
450
451 let batch = QueryArgs::from(queries);
452 let body = serde_json::to_string(&batch).map_err(RequestError::FailedParseRequestBody)?;
453
454 self.try_request(RequestOptions {
455 endpoint: "db/execute".to_string(),
456 body: Some(body),
457 params: Some(RqliteQueryParams::new().queue().into_request_query_params()),
458 ..Default::default()
459 })
460 .await?;
461
462 Ok(())
463 }
464
465 pub async fn ready(&self) -> bool {
468 self.try_request(RequestOptions {
469 endpoint: "readyz".to_string(),
470 method: reqwest::Method::GET,
471 ..Default::default()
472 })
473 .await
474 .is_ok_and(|res| res.status() == reqwest::StatusCode::OK)
475 }
476
477 pub async fn nodes(&self) -> Result<Vec<Node>, RequestError> {
486 let res = self
487 .try_request(RequestOptions {
488 endpoint: "nodes".to_string(),
489 params: Some(
490 RqliteQueryParams::new()
491 .ver("2".to_string())
492 .into_request_query_params(),
493 ),
494 method: reqwest::Method::GET,
495 ..Default::default()
496 })
497 .await?;
498
499 let body = res.text().await?;
500
501 let response = serde_json::from_str::<NodeResponse>(&body)
502 .map_err(RequestError::FailedParseResponseBody)?;
503
504 Ok(response.nodes)
505 }
506
507 pub async fn leader(&self) -> Result<Option<Node>, RequestError> {
516 let nodes = self.nodes().await?;
517
518 Ok(nodes.into_iter().find(|n| n.leader))
519 }
520
521 pub async fn remove_node(&self, id: &str) -> Result<(), RequestError> {
531 let body = serde_json::to_string(&RemoveNodeRequest { id: id.to_string() })
532 .map_err(RequestError::FailedParseRequestBody)?;
533
534 let res = self
535 .try_request(RequestOptions {
536 endpoint: "remove".to_string(),
537 body: Some(body),
538 method: reqwest::Method::DELETE,
539 ..Default::default()
540 })
541 .await?;
542
543 if res.status().is_success() {
544 Ok(())
545 } else {
546 Err(RequestError::DatabaseError(format!(
547 "Failed to remove node: {}",
548 res.text()
549 .await
550 .map_err(RequestError::FailedReadingResponse)?
551 )))
552 }
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unit_rqlite_client_builder_success() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .scheme(config::Scheme::Http)
            .build();

        assert!(client.is_ok());
    }

    #[test]
    fn unit_rqlite_client_builder_no_hosts() {
        let client = RqliteClientBuilder::new().build();

        assert!(matches!(client, Err(ClientBuilderError::NoHostsProvided)));
    }

    #[test]
    fn unit_rqlite_client_builder_duplicate_host() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .known_host("http://localhost:4001")
            .build()
            .unwrap();

        // The second, identical host must have been dropped.
        assert_eq!(
            *client.hosts.read().unwrap(),
            vec!["http://localhost:4001".to_string()]
        );
    }

    #[test]
    fn unit_rqlite_client_builder_host_order() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .known_host("localhost:4002")
            .build()
            .unwrap();

        assert_eq!(
            *client.hosts.read().unwrap(),
            vec!["localhost:4001".to_string(), "localhost:4002".to_string()]
        );
    }

    #[test]
    fn unit_rqlite_client_builder_https() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .scheme(config::Scheme::Https)
            .build();

        let config = client.unwrap().config;

        assert!(matches!(config.scheme, config::Scheme::Https));
    }

    #[test]
    fn unit_rqlite_client_builder_auth() {
        let builder = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .auth("user", "password");

        // base64("user:password")
        assert_eq!(builder.basic_auth.as_deref(), Some("dXNlcjpwYXNzd29yZA=="));
        assert!(builder.build().is_ok());
    }

    #[test]
    fn unit_rqlite_client_builder_default_query_params() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .default_query_params(vec![RqliteQueryParam::Ver("3".to_string())])
            .build();

        let config = client.unwrap().config;

        #[cfg(feature = "fast-blob")]
        assert_eq!(config.default_query_params.unwrap().0.len(), 1);
        #[cfg(not(feature = "fast-blob"))]
        assert_eq!(config.default_query_params.unwrap().0.len(), 2);
    }

    #[test]
    fn unit_rqlite_client_builder_default_scheme() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .build();

        let config = client.unwrap().config;

        assert!(matches!(config.scheme, config::Scheme::Http));
    }

    #[test]
    fn unit_rqlite_client_builder_fallback_strategy() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .fallback_strategy(crate::fallback::Priority::new(vec![
                "localhost:4005".to_string(),
                "localhost:4003".to_string(),
                "localhost:4001".to_string(),
            ]))
            .build()
            .unwrap();

        assert!(client
            .config
            .fallback_strategy
            .write()
            .unwrap()
            .fallback(
                &mut vec!["localhost:4001".to_string(), "localhost:4002".to_string()],
                "localhost:4001",
                false
            )
            .is_some());
    }

    #[test]
    fn unit_rqlite_client_builder_fallback_count() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .fallback_count(FallbackCount::Count(3))
            .build()
            .unwrap();

        assert_eq!(client.config.fallback_count.count(4), 3);
    }

    #[test]
    fn unit_rqlite_client_builder_fallback_persistence() {
        let client = RqliteClientBuilder::new()
            .known_host("http://localhost:4001")
            .fallback_persistence(false)
            .build()
            .unwrap();

        assert!(!client.config.fallback_persistence);
    }
}