db_rpc_client_rs/
lib.rs

1use percent_encoding::utf8_percent_encode;
2use percent_encoding::NON_ALPHANUMERIC;
3use reqwest::header::CONTENT_TYPE;
4use reqwest::Method;
5use rmpv::Value;
6use serde::de::DeserializeOwned;
7use serde::Deserialize;
8use serde::Serialize;
9use std::error::Error;
10use std::fmt::Display;
11
12#[derive(Debug)]
13pub enum DbRpcClientError {
14  Api {
15    status: u16,
16    error: String,
17    error_details: Option<String>,
18  },
19  Unauthorized,
20  Request(reqwest::Error),
21}
22
23impl Display for DbRpcClientError {
24  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25    match self {
26      DbRpcClientError::Api {
27        status,
28        error,
29        error_details,
30      } => write!(f, "API error ({status} - {error}): {error_details:?}"),
31      DbRpcClientError::Unauthorized => write!(f, "unauthorized"),
32      DbRpcClientError::Request(e) => write!(f, "request error: {e}"),
33    }
34  }
35}
36
37impl Error for DbRpcClientError {}
38
39pub type DbRpcClientResult<T> = Result<T, DbRpcClientError>;
40
/// Connection settings for a `DbRpcClient`.
#[derive(Clone, Debug)]
pub struct DbRpcClientCfg {
  /// When set, sent verbatim as the value of the `authorization` request header.
  pub api_key: Option<String>,
  /// Base URL of the service; request paths (which start with `/`) are appended
  /// to it without any separator being inserted.
  pub endpoint: String,
}
46
47#[derive(Clone, Debug)]
48pub struct DbRpcClient {
49  r: reqwest::Client,
50  cfg: DbRpcClientCfg,
51}
52
53impl DbRpcClient {
54  pub fn with_request_client(request_client: reqwest::Client, cfg: DbRpcClientCfg) -> Self {
55    Self {
56      r: request_client,
57      cfg,
58    }
59  }
60
61  pub fn new(cfg: DbRpcClientCfg) -> Self {
62    Self::with_request_client(reqwest::Client::new(), cfg)
63  }
64
65  async fn raw_request<I: Serialize, O: DeserializeOwned>(
66    &self,
67    method: Method,
68    path: impl AsRef<str>,
69    body: Option<&I>,
70  ) -> DbRpcClientResult<O> {
71    let mut req = self
72      .r
73      .request(method, format!("{}{}", self.cfg.endpoint, path.as_ref()))
74      .header("accept", "application/msgpack");
75    if let Some(k) = &self.cfg.api_key {
76      req = req.header("authorization", k);
77    };
78    if let Some(b) = body {
79      let raw = rmp_serde::to_vec_named(b).unwrap();
80      req = req.header("content-type", "application/msgpack").body(raw);
81    };
82    let res = req
83      .send()
84      .await
85      .map_err(|err| DbRpcClientError::Request(err))?;
86    let status = res.status().as_u16();
87    let res_type = res
88      .headers()
89      .get(CONTENT_TYPE)
90      .and_then(|v| v.to_str().ok().map(|v| v.to_string()))
91      .unwrap_or_default();
92    let res_body_raw = res
93      .bytes()
94      .await
95      .map_err(|err| DbRpcClientError::Request(err))?;
96    if status == 401 {
97      return Err(DbRpcClientError::Unauthorized);
98    };
99    #[derive(Deserialize)]
100    struct ApiError {
101      error: String,
102      error_details: Option<String>,
103    }
104    if status < 200 || status > 299 || !res_type.starts_with("application/msgpack") {
105      // The server may be behind some proxy, LB, etc., so we don't know what the body looks like for sure.
106      return Err(match rmp_serde::from_slice::<ApiError>(&res_body_raw) {
107        Ok(api_error) => DbRpcClientError::Api {
108          status,
109          error: api_error.error,
110          error_details: api_error.error_details,
111        },
112        Err(_) => DbRpcClientError::Api {
113          status,
114          // We don't know if the response contains valid UTF-8 text or not.
115          error: String::from_utf8_lossy(&res_body_raw).into_owned(),
116          error_details: None,
117        },
118      });
119    };
120    Ok(rmp_serde::from_slice(&res_body_raw).unwrap())
121  }
122
123  pub fn database(&self, db_name: &str) -> DbRpcDbClient {
124    DbRpcDbClient {
125      c: self.clone(),
126      dbpp: format!("/db/{}", utf8_percent_encode(&db_name, NON_ALPHANUMERIC)),
127    }
128  }
129}
130
131#[derive(Clone, Debug)]
132pub struct DbRpcDbClient {
133  c: DbRpcClient,
134  dbpp: String,
135}
136
137#[derive(Deserialize)]
138pub struct ExecResult {
139  pub affected_rows: u64,
140  pub last_insert_id: Option<u64>,
141}
142
143impl DbRpcDbClient {
144  pub async fn batch(
145    &self,
146    query: impl AsRef<str>,
147    params: Vec<Vec<Value>>,
148  ) -> DbRpcClientResult<Vec<ExecResult>> {
149    #[derive(Serialize)]
150    struct Input<'a> {
151      query: &'a str,
152      params: Vec<Vec<Value>>,
153    }
154    self
155      .c
156      .raw_request(
157        Method::POST,
158        format!("{}/batch", self.dbpp),
159        Some(&Input {
160          query: query.as_ref(),
161          params,
162        }),
163      )
164      .await
165  }
166
167  pub async fn exec(
168    &self,
169    query: impl AsRef<str>,
170    params: Vec<Value>,
171  ) -> DbRpcClientResult<ExecResult> {
172    #[derive(Serialize)]
173    struct Input<'a> {
174      query: &'a str,
175      params: Vec<Value>,
176    }
177    self
178      .c
179      .raw_request(
180        Method::POST,
181        format!("{}/exec", self.dbpp),
182        Some(&Input {
183          query: query.as_ref(),
184          params,
185        }),
186      )
187      .await
188  }
189
190  pub async fn query<R: DeserializeOwned>(
191    &self,
192    query: impl AsRef<str>,
193    params: Vec<Value>,
194  ) -> DbRpcClientResult<Vec<R>> {
195    #[derive(Serialize)]
196    struct Input<'a> {
197      query: &'a str,
198      params: Vec<Value>,
199    }
200    self
201      .c
202      .raw_request(
203        Method::POST,
204        format!("{}/query", self.dbpp),
205        Some(&Input {
206          query: query.as_ref(),
207          params,
208        }),
209      )
210      .await
211  }
212}