Skip to main content

rqlite_rs/
client.rs

1use std::sync::{Arc, RwLock};
2
3use crate::{
4    batch::BatchResult,
5    config::{self, RqliteClientConfig, RqliteClientConfigBuilder},
6    error::{ClientBuilderError, RequestError},
7    fallback::{FallbackCount, FallbackStrategy},
8    node::{Node, NodeResponse, RemoveNodeRequest},
9    query::{self, QueryArgs, RqliteQuery},
10    query_result::QueryResult,
11    request::{RequestOptions, RqliteQueryParam, RqliteQueryParams},
12    response::{RqliteResponseRaw, RqliteResult},
13    select::RqliteSelectResults,
14};
15use base64::{engine::general_purpose, Engine};
16use reqwest::header;
17use rqlite_rs_core::Row;
18
19/// A client for interacting with a rqlite cluster.
20pub struct RqliteClient {
21    client: reqwest::Client,
22    hosts: Arc<RwLock<Vec<String>>>,
23    config: RqliteClientConfig,
24}
25
26/// A builder for creating a [`RqliteClient`].
27#[derive(Default)]
28pub struct RqliteClientBuilder {
29    /// This uses a `HashSet` to ensure that no duplicate hosts are added.
30    hosts: Vec<String>,
31    /// The configration for the client.
32    config: RqliteClientConfigBuilder,
33    // The base64 encoded credentials used to make authorized requests to the Rqlite cluster
34    basic_auth: Option<String>,
35}
36
37impl RqliteClientBuilder {
38    /// Creates a new [`RqliteClientBuilder`].
39    #[must_use]
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Adds basic auth credentials
45    #[must_use]
46    pub fn auth(mut self, user: &str, password: &str) -> Self {
47        self.basic_auth = Some(general_purpose::STANDARD.encode(format!("{user}:{password}")));
48        self
49    }
50
51    /// Adds a known host to the builder.
52    /// It is important not to add the scheme to the host.
53    /// The scheme is set using the `scheme` method.
54    /// The host should be in the format `hostname:port`.
55    /// For example, `localhost:4001`.
56    #[must_use]
57    #[expect(
58        clippy::needless_pass_by_value,
59        reason = "impl ToString is idiomatic for builder patterns"
60    )]
61    pub fn known_host(mut self, host: impl ToString) -> Self {
62        let host_str = host.to_string();
63
64        if self.hosts.iter().any(|h| h == &host_str) {
65            tracing::warn!("Host {host_str} already exists");
66        } else {
67            self.hosts.push(host_str);
68        }
69        self
70    }
71
72    /// Adds a default query parameter to the builder.
73    /// The `blob_array` parameter is added by default if the `fast-blob` feature is not enabled. (see <https://rqlite.io/docs/api/api/#blob-data>)
74    #[must_use]
75    pub fn default_query_params(mut self, params: Vec<RqliteQueryParam>) -> Self {
76        self.config = self.config.default_query_params(params);
77        self
78    }
79
80    /// Sets the fallback count for the client.
81    /// The fallback count is the number of times the client will try to switch to another host if the current host fails.
82    #[must_use]
83    pub fn fallback_count(mut self, count: FallbackCount) -> Self {
84        self.config = self.config.fallback_count(count);
85        self
86    }
87
88    /// Sets the fallback strategy for the client.
89    /// The fallback strategy is the strategy used to switch to another host if the current host fails.
90    /// The default strategy is `RoundRobin`.
91    #[must_use]
92    pub fn fallback_strategy(mut self, strategy: impl FallbackStrategy + 'static) -> Self {
93        self.config = self.config.fallback_strategy(strategy);
94        self
95    }
96
97    /// Sets the fallback persistence for the client.
98    /// If set to `true`, which is the default, the client will keep using the last host that was successful.
99    /// If set to `false`, the client will always try the first host in the list.
100    #[must_use]
101    pub fn fallback_persistence(mut self, persist: bool) -> Self {
102        self.config = self.config.fallback_persistence(persist);
103        self
104    }
105
106    /// Sets the scheme for the client.
107    #[must_use]
108    pub fn scheme(mut self, scheme: config::Scheme) -> Self {
109        self.config = self.config.scheme(scheme);
110        self
111    }
112
113    /// Builds the [`RqliteClient`] with the provided hosts.
114    ///
115    /// # Errors
116    ///
117    /// This function will return an error if:
118    /// - No hosts were provided
119    /// - Failed to create HTTP client
120    /// - Invalid authorization header
121    pub fn build(self) -> Result<RqliteClient, ClientBuilderError> {
122        if self.hosts.is_empty() {
123            return Err(ClientBuilderError::NoHostsProvided);
124        }
125
126        let hosts = self.hosts.into_iter().collect::<Vec<String>>();
127
128        let mut headers = header::HeaderMap::new();
129        headers.insert(
130            header::CONTENT_TYPE,
131            header::HeaderValue::from_static("application/json"),
132        );
133
134        if let Some(credentials) = self.basic_auth {
135            let basic_auth_fmt = format!("Basic {credentials}");
136            headers.insert(
137                header::AUTHORIZATION,
138                header::HeaderValue::from_str(basic_auth_fmt.as_str())?,
139            );
140        }
141
142        let mut client = reqwest::ClientBuilder::new()
143            .timeout(std::time::Duration::from_secs(5))
144            .default_headers(headers);
145
146        if matches!(self.config.scheme, Some(config::Scheme::Https)) {
147            client = client.https_only(true);
148        }
149
150        Ok(RqliteClient {
151            client: client.build()?,
152            hosts: Arc::new(RwLock::new(hosts)),
153            config: self.config.build(),
154        })
155    }
156}
157
158impl RqliteClient {
159    async fn try_request(
160        &self,
161        mut options: RequestOptions,
162    ) -> Result<reqwest::Response, RequestError> {
163        let (mut host, host_count) = {
164            let hosts = self
165                .hosts
166                .read()
167                .map_err(|_poisoned| RequestError::LockPoisoned)?;
168            let first_host = hosts.first().ok_or(RequestError::NoAvailableHosts)?;
169            (first_host.clone(), hosts.len())
170        };
171
172        let retry_count = self.config.fallback_count.count(host_count);
173
174        if let Some(default_params) = &self.config.default_query_params {
175            options.merge_default_query_params(default_params);
176        }
177
178        for _ in 0..retry_count {
179            tracing::debug!("Trying host: {host}");
180            let req = options.to_reqwest_request(&self.client, host.as_str(), &self.config.scheme);
181
182            match req.send().await {
183                Ok(res) if res.status().is_success() => return Ok(res),
184                Ok(res) => match res.status() {
185                    reqwest::StatusCode::UNAUTHORIZED => {
186                        return Err(RequestError::Unauthorized);
187                    }
188                    status => {
189                        return Err(RequestError::ReqwestError {
190                            body: res.text().await?,
191                            status,
192                        });
193                    }
194                },
195                Err(e) => self.handle_request_error(&e, &mut host)?,
196            }
197        }
198
199        Err(RequestError::NoAvailableHosts)
200    }
201
202    /// Handles the error returned by the request.
203    /// If the error is a connection error or a timeout, it will try to switch to another host.
204    /// If the error is not a connection error or a timeout, it will return an error.
205    fn handle_request_error(
206        &self,
207        e: &reqwest::Error,
208        host: &mut String, // warum wird host gepasst? wird nicht alles über self.hosts gemacht?
209    ) -> Result<(), RequestError> {
210        if e.is_connect() || e.is_timeout() {
211            let previous_host = host.clone();
212            let mut writable_hosts = self
213                .hosts
214                .write()
215                .map_err(|_poisoned| RequestError::LockPoisoned)?;
216
217            let new_host = self
218                .config
219                .fallback_strategy
220                .write()
221                .map_err(|_poisoned| RequestError::LockPoisoned)?
222                .fallback(&mut writable_hosts, host, self.config.fallback_persistence)
223                .ok_or(RequestError::NoAvailableHosts)?;
224
225            host.clone_from(new_host);
226            tracing::info!("Connection to {} failed, trying {}", previous_host, *host);
227            Ok(())
228        } else {
229            Err(RequestError::SwitchoverWrongError(e.to_string()))
230        }
231    }
232
233    async fn exec_query<T>(&self, q: query::RqliteQuery) -> Result<RqliteResult<T>, RequestError>
234    where
235        T: serde::de::DeserializeOwned + Clone,
236    {
237        let res = self
238            .try_request(RequestOptions {
239                endpoint: q.endpoint(),
240                body: Some(
241                    q.into_json()
242                        .map_err(RequestError::FailedParseRequestBody)?,
243                ),
244                ..Default::default()
245            })
246            .await?;
247
248        let body = res.text().await?;
249
250        let response = serde_json::from_str::<RqliteResponseRaw<T>>(&body)
251            .map_err(RequestError::FailedParseResponseBody)?;
252
253        response
254            .results
255            .into_iter()
256            .next()
257            .ok_or(RequestError::NoRowsReturned)
258    }
259
260    // To be implemented for different types of queries such as batch or qeued queries
261    //async fn exec_many<T>(
262    //    &self,
263    //    qs: Vec<query::RqliteQuery>,
264    //    params: impl Into<Option<Vec<RequestQueryParam>>>,
265    //) -> anyhow::Result<Vec<RqliteResult<T>>>
266    //where
267    //    T: serde::de::DeserializeOwned + Clone,
268    //{
269    //    let args = QueryArgs::from(qs);
270    //    let body = serde_json::to_string(&args)?;
271    //
272    //    let res = self.try_request("request", body, None).await?;
273    //
274    //    let body = res.text().await?;
275    //
276    //    let response = serde_json::from_str::<RqliteResponseRaw<T>>(&body)?;
277    //
278    //    Ok(response.results)
279    //}
280
281    /// Executes a query that returns results.
282    /// Returns a vector of [`Row`]s if the query was successful, otherwise an error.
283    ///
284    /// # Errors
285    ///
286    /// This function will return an error if:
287    /// - The query could not be converted to a `RqliteQuery`
288    /// - The request to the rqlite server failed
289    /// - The response could not be parsed
290    /// - The database returned an error
291    pub async fn fetch<Q>(&self, q: Q) -> Result<Vec<Row>, RequestError>
292    where
293        Q: TryInto<RqliteQuery>,
294        RequestError: From<Q::Error>,
295    {
296        let result = self
297            .exec_query::<RqliteSelectResults>(q.try_into()?)
298            .await?;
299
300        match result {
301            RqliteResult::Success(qr) => Ok(qr.rows()),
302            RqliteResult::Error(qe) => Err(RequestError::DatabaseError(qe.error)),
303        }
304    }
305
306    /// Executes a query that does not return any results.
307    /// Returns the [`QueryResult`] if the query was successful, otherwise an error.
308    /// Is primarily used for `INSERT`, `UPDATE`, `DELETE` and `CREATE` queries.
309    ///
310    /// # Errors
311    ///
312    /// This function will return an error if:
313    /// - The query could not be converted to a `RqliteQuery`
314    /// - The request to the rqlite server failed
315    /// - The response could not be parsed
316    /// - The database returned an error
317    pub async fn exec<Q>(&self, q: Q) -> Result<QueryResult, RequestError>
318    where
319        Q: TryInto<RqliteQuery>,
320        RequestError: From<Q::Error>,
321    {
322        let query_result = self.exec_query::<QueryResult>(q.try_into()?).await?;
323
324        match query_result {
325            RqliteResult::Success(qr) => Ok(qr),
326            RqliteResult::Error(qe) => Err(RequestError::DatabaseError(qe.error)),
327        }
328    }
329
330    /// Executes a batch of queries.
331    /// It allows sending multiple queries in a single request.
332    /// This can be more efficient and reduces round-trips to the database.
333    /// Returns a vector of [`RqliteResult`]s.
334    /// Each result contains the result of the corresponding query in the batch.
335    /// If a query fails, the corresponding result will contain an error.
336    ///
337    /// For more information on batch queries, see the [rqlite documentation](https://rqlite.io/docs/api/bulk-api/).
338    ///
339    /// # Errors
340    ///
341    /// This function will return an error if:
342    /// - The query could not be converted to a `RqliteQuery`
343    /// - The request to the rqlite server failed
344    /// - The response could not be parsed
345    /// - The database returned an error
346    pub async fn batch<Q>(&self, qs: Vec<Q>) -> Result<Vec<RqliteResult<BatchResult>>, RequestError>
347    where
348        Q: TryInto<RqliteQuery>,
349        RequestError: From<Q::Error>,
350    {
351        let queries = qs
352            .into_iter()
353            .map(std::convert::TryInto::try_into)
354            .collect::<Result<Vec<RqliteQuery>, _>>()?;
355
356        let batch = QueryArgs::from(queries);
357        let body = serde_json::to_string(&batch).map_err(RequestError::FailedParseRequestBody)?;
358
359        let res = self
360            .try_request(RequestOptions {
361                endpoint: "db/request".to_string(),
362                body: Some(body),
363                ..Default::default()
364            })
365            .await?;
366
367        let body = res.text().await?;
368
369        let results = serde_json::from_str::<RqliteResponseRaw<BatchResult>>(&body)
370            .map_err(RequestError::FailedParseResponseBody)?
371            .results;
372
373        Ok(results)
374    }
375
376    /// Executes a transaction.
377    /// A transaction is a set of queries that are executed as a single unit.
378    /// If any of the queries fail, the entire transaction is rolled back.
379    /// Returns a vector of [`RqliteResult`]s.
380    ///
381    /// For more information on transactions, see the [rqlite documentation](https://rqlite.io/docs/api/api/#transactions).
382    ///
383    /// # Errors
384    ///
385    /// This function will return an error if:
386    /// - The query could not be converted to a `RqliteQuery`
387    /// - The request to the rqlite server failed
388    /// - The response could not be parsed
389    /// - The database returned an error
390    /// - The transaction could not be executed
391    pub async fn transaction<Q>(
392        &self,
393        qs: Vec<Q>,
394    ) -> Result<Vec<RqliteResult<QueryResult>>, RequestError>
395    where
396        Q: TryInto<RqliteQuery>,
397        RequestError: From<Q::Error>,
398    {
399        let queries = qs
400            .into_iter()
401            .map(std::convert::TryInto::try_into)
402            .collect::<Result<Vec<RqliteQuery>, _>>()?;
403
404        let batch = QueryArgs::from(queries);
405        let body = serde_json::to_string(&batch).map_err(RequestError::FailedParseRequestBody)?;
406
407        let res = self
408            .try_request(RequestOptions {
409                endpoint: "db/execute".to_string(),
410                body: Some(body),
411                params: Some(
412                    RqliteQueryParams::new()
413                        .transaction()
414                        .into_request_query_params(),
415                ),
416                ..Default::default()
417            })
418            .await?;
419
420        let body = res.text().await?;
421
422        let results = serde_json::from_str::<RqliteResponseRaw<QueryResult>>(&body)
423            .map_err(RequestError::FailedParseResponseBody)?
424            .results;
425
426        Ok(results)
427    }
428
429    /// Asynchronously executes multiple queries.
430    /// This results in much higher write performance.
431    ///
432    /// For more information on queued queries, see the [rqlite documentation](https://rqlite.io/docs/api/queued-writes/).
433    ///
434    /// # Errors
435    ///
436    /// This function will return an error if:
437    /// - The query could not be converted to a `RqliteQuery`
438    /// - The request to the rqlite server failed
439    /// - The response could not be parsed
440    /// - The database returned an error
441    pub async fn queue<Q>(&self, qs: Vec<Q>) -> Result<(), RequestError>
442    where
443        Q: TryInto<RqliteQuery>,
444        RequestError: From<Q::Error>,
445    {
446        let queries = qs
447            .into_iter()
448            .map(std::convert::TryInto::try_into)
449            .collect::<Result<Vec<RqliteQuery>, _>>()?;
450
451        let batch = QueryArgs::from(queries);
452        let body = serde_json::to_string(&batch).map_err(RequestError::FailedParseRequestBody)?;
453
454        self.try_request(RequestOptions {
455            endpoint: "db/execute".to_string(),
456            body: Some(body),
457            params: Some(RqliteQueryParams::new().queue().into_request_query_params()),
458            ..Default::default()
459        })
460        .await?;
461
462        Ok(())
463    }
464
465    /// Checks if the rqlite cluster is ready.
466    /// Returns `true` if the cluster is ready, otherwise `false`.
467    pub async fn ready(&self) -> bool {
468        self.try_request(RequestOptions {
469            endpoint: "readyz".to_string(),
470            method: reqwest::Method::GET,
471            ..Default::default()
472        })
473        .await
474        .is_ok_and(|res| res.status() == reqwest::StatusCode::OK)
475    }
476
477    /// Retrieves the nodes in the rqlite cluster.
478    /// Returns a vector of [`Node`]s.
479    ///
480    /// # Errors
481    ///
482    /// This function will return an error if:
483    /// - The request to the rqlite server failed
484    /// - The response could not be parsed
485    pub async fn nodes(&self) -> Result<Vec<Node>, RequestError> {
486        let res = self
487            .try_request(RequestOptions {
488                endpoint: "nodes".to_string(),
489                params: Some(
490                    RqliteQueryParams::new()
491                        .ver("2".to_string())
492                        .into_request_query_params(),
493                ),
494                method: reqwest::Method::GET,
495                ..Default::default()
496            })
497            .await?;
498
499        let body = res.text().await?;
500
501        let response = serde_json::from_str::<NodeResponse>(&body)
502            .map_err(RequestError::FailedParseResponseBody)?;
503
504        Ok(response.nodes)
505    }
506
507    /// Retrieves current the leader of the rqlite cluster.
508    /// Returns a [`Node`] if a leader is found, otherwise `None`.
509    ///
510    /// # Errors
511    ///
512    /// This function will return an error if:
513    /// - The request to the rqlite server failed
514    /// - The response could not be parsed
515    pub async fn leader(&self) -> Result<Option<Node>, RequestError> {
516        let nodes = self.nodes().await?;
517
518        Ok(nodes.into_iter().find(|n| n.leader))
519    }
520
521    /// Removes a node from the rqlite cluster.
522    ///
523    /// # Errors
524    ///
525    /// This function will return an error if:
526    /// - The request body cannot be serialized
527    /// - The request to the rqlite server failed
528    /// - The response indicates a failure
529    /// - The response body cannot be read
530    pub async fn remove_node(&self, id: &str) -> Result<(), RequestError> {
531        let body = serde_json::to_string(&RemoveNodeRequest { id: id.to_string() })
532            .map_err(RequestError::FailedParseRequestBody)?;
533
534        let res = self
535            .try_request(RequestOptions {
536                endpoint: "remove".to_string(),
537                body: Some(body),
538                method: reqwest::Method::DELETE,
539                ..Default::default()
540            })
541            .await?;
542
543        if res.status().is_success() {
544            Ok(())
545        } else {
546            Err(RequestError::DatabaseError(format!(
547                "Failed to remove node: {}",
548                res.text()
549                    .await
550                    .map_err(RequestError::FailedReadingResponse)?
551            )))
552        }
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    // Hosts are given as `hostname:port` without a scheme, as `known_host` documents.

    #[test]
    fn unit_rqlite_client_builder_success() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .scheme(config::Scheme::Http)
            .build();

        assert!(client.is_ok());
    }

    #[test]
    fn unit_rqlite_client_builder_no_hosts() {
        let client = RqliteClientBuilder::new().build();

        assert!(matches!(client, Err(ClientBuilderError::NoHostsProvided)));
    }

    #[test]
    fn unit_rqlite_client_builder_duplicate_host() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .known_host("localhost:4001")
            .build()
            .unwrap();

        // The duplicate must be dropped rather than stored twice.
        let hosts = client.hosts.read().unwrap();
        assert_eq!(*hosts, vec!["localhost:4001".to_string()]);
    }

    #[test]
    fn unit_rqlite_client_builder_keeps_host_order() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .known_host("localhost:4002")
            .build()
            .unwrap();

        let hosts = client.hosts.read().unwrap();
        assert_eq!(
            *hosts,
            vec!["localhost:4001".to_string(), "localhost:4002".to_string()]
        );
    }

    #[test]
    fn unit_rqlite_client_builder_https() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .scheme(config::Scheme::Https)
            .build();

        let config = client.unwrap().config;

        assert!(matches!(config.scheme, config::Scheme::Https));
    }

    #[test]
    fn unit_rqlite_client_builder_auth() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .auth("user", "password")
            .build();

        assert!(client.is_ok());
    }

    #[test]
    fn unit_rqlite_client_builder_default_query_params() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .default_query_params(vec![RqliteQueryParam::Ver("3".to_string())])
            .build();

        let config = client.unwrap().config;

        #[cfg(feature = "fast-blob")]
        assert_eq!(config.default_query_params.unwrap().0.len(), 1);
        #[cfg(not(feature = "fast-blob"))]
        assert_eq!(config.default_query_params.unwrap().0.len(), 2);
    }

    #[test]
    fn unit_rqlite_client_builder_default_scheme() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .build();

        let config = client.unwrap().config;

        assert!(matches!(config.scheme, config::Scheme::Http));
    }

    // Fallback related tests
    #[test]
    fn unit_rqlite_client_builder_fallback_strategy() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .fallback_strategy(crate::fallback::Priority::new(vec![
                "localhost:4005".to_string(),
                "localhost:4003".to_string(),
                "localhost:4001".to_string(),
            ]))
            .build()
            .unwrap();

        assert!(client
            .config
            .fallback_strategy
            .write()
            .unwrap()
            .fallback(
                &mut vec!["localhost:4001".to_string(), "localhost:4002".to_string()],
                "localhost:4001",
                false
            )
            .is_some());
    }

    #[test]
    fn unit_rqlite_client_builder_fallback_count() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .fallback_count(FallbackCount::Count(3))
            .build()
            .unwrap();

        assert_eq!(client.config.fallback_count.count(4), 3);
    }

    #[test]
    fn unit_rqlite_client_builder_fallback_persistence() {
        let client = RqliteClientBuilder::new()
            .known_host("localhost:4001")
            .fallback_persistence(false)
            .build()
            .unwrap();

        assert!(!client.config.fallback_persistence);
    }
}