1use percent_encoding::utf8_percent_encode;
2use percent_encoding::NON_ALPHANUMERIC;
3use reqwest::header::CONTENT_TYPE;
4use reqwest::Method;
5use rmpv::Value;
6use serde::de::DeserializeOwned;
7use serde::Deserialize;
8use serde::Serialize;
9use std::error::Error;
10use std::fmt::Display;
11
12#[derive(Debug)]
13pub enum DbRpcClientError {
14 Api {
15 status: u16,
16 error: String,
17 error_details: Option<String>,
18 },
19 Unauthorized,
20 Request(reqwest::Error),
21}
22
23impl Display for DbRpcClientError {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 DbRpcClientError::Api {
27 status,
28 error,
29 error_details,
30 } => write!(f, "API error ({status} - {error}): {error_details:?}"),
31 DbRpcClientError::Unauthorized => write!(f, "unauthorized"),
32 DbRpcClientError::Request(e) => write!(f, "request error: {e}"),
33 }
34 }
35}
36
37impl Error for DbRpcClientError {}
38
39pub type DbRpcClientResult<T> = Result<T, DbRpcClientError>;
40
/// Connection settings for [`DbRpcClient`].
#[derive(Clone)]
pub struct DbRpcClientCfg {
    /// Sent verbatim as the `Authorization` header on every request, when set.
    pub api_key: Option<String>,
    /// Base URL that request paths (such as `/db/<name>/query`) are appended to.
    pub endpoint: String,
}

// Hand-written instead of derived so the API key never ends up in logs or panic messages;
// only its presence is shown.
impl std::fmt::Debug for DbRpcClientCfg {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DbRpcClientCfg")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("endpoint", &self.endpoint)
            .finish()
    }
}
46
47#[derive(Clone, Debug)]
48pub struct DbRpcClient {
49 r: reqwest::Client,
50 cfg: DbRpcClientCfg,
51}
52
53impl DbRpcClient {
54 pub fn with_request_client(request_client: reqwest::Client, cfg: DbRpcClientCfg) -> Self {
55 Self {
56 r: request_client,
57 cfg,
58 }
59 }
60
61 pub fn new(cfg: DbRpcClientCfg) -> Self {
62 Self::with_request_client(reqwest::Client::new(), cfg)
63 }
64
65 async fn raw_request<I: Serialize, O: DeserializeOwned>(
66 &self,
67 method: Method,
68 path: impl AsRef<str>,
69 body: Option<&I>,
70 ) -> DbRpcClientResult<O> {
71 let mut req = self
72 .r
73 .request(method, format!("{}{}", self.cfg.endpoint, path.as_ref()))
74 .header("accept", "application/msgpack");
75 if let Some(k) = &self.cfg.api_key {
76 req = req.header("authorization", k);
77 };
78 if let Some(b) = body {
79 let raw = rmp_serde::to_vec_named(b).unwrap();
80 req = req.header("content-type", "application/msgpack").body(raw);
81 };
82 let res = req
83 .send()
84 .await
85 .map_err(|err| DbRpcClientError::Request(err))?;
86 let status = res.status().as_u16();
87 let res_type = res
88 .headers()
89 .get(CONTENT_TYPE)
90 .and_then(|v| v.to_str().ok().map(|v| v.to_string()))
91 .unwrap_or_default();
92 let res_body_raw = res
93 .bytes()
94 .await
95 .map_err(|err| DbRpcClientError::Request(err))?;
96 if status == 401 {
97 return Err(DbRpcClientError::Unauthorized);
98 };
99 #[derive(Deserialize)]
100 struct ApiError {
101 error: String,
102 error_details: Option<String>,
103 }
104 if status < 200 || status > 299 || !res_type.starts_with("application/msgpack") {
105 return Err(match rmp_serde::from_slice::<ApiError>(&res_body_raw) {
107 Ok(api_error) => DbRpcClientError::Api {
108 status,
109 error: api_error.error,
110 error_details: api_error.error_details,
111 },
112 Err(_) => DbRpcClientError::Api {
113 status,
114 error: String::from_utf8_lossy(&res_body_raw).into_owned(),
116 error_details: None,
117 },
118 });
119 };
120 Ok(rmp_serde::from_slice(&res_body_raw).unwrap())
121 }
122
123 pub fn database(&self, db_name: &str) -> DbRpcDbClient {
124 DbRpcDbClient {
125 c: self.clone(),
126 dbpp: format!("/db/{}", utf8_percent_encode(&db_name, NON_ALPHANUMERIC)),
127 }
128 }
129}
130
131#[derive(Clone, Debug)]
132pub struct DbRpcDbClient {
133 c: DbRpcClient,
134 dbpp: String,
135}
136
137#[derive(Deserialize)]
138pub struct ExecResult {
139 pub affected_rows: u64,
140 pub last_insert_id: Option<u64>,
141}
142
143impl DbRpcDbClient {
144 pub async fn batch(
145 &self,
146 query: impl AsRef<str>,
147 params: Vec<Vec<Value>>,
148 ) -> DbRpcClientResult<Vec<ExecResult>> {
149 #[derive(Serialize)]
150 struct Input<'a> {
151 query: &'a str,
152 params: Vec<Vec<Value>>,
153 }
154 self
155 .c
156 .raw_request(
157 Method::POST,
158 format!("{}/batch", self.dbpp),
159 Some(&Input {
160 query: query.as_ref(),
161 params,
162 }),
163 )
164 .await
165 }
166
167 pub async fn exec(
168 &self,
169 query: impl AsRef<str>,
170 params: Vec<Value>,
171 ) -> DbRpcClientResult<ExecResult> {
172 #[derive(Serialize)]
173 struct Input<'a> {
174 query: &'a str,
175 params: Vec<Value>,
176 }
177 self
178 .c
179 .raw_request(
180 Method::POST,
181 format!("{}/exec", self.dbpp),
182 Some(&Input {
183 query: query.as_ref(),
184 params,
185 }),
186 )
187 .await
188 }
189
190 pub async fn query<R: DeserializeOwned>(
191 &self,
192 query: impl AsRef<str>,
193 params: Vec<Value>,
194 ) -> DbRpcClientResult<Vec<R>> {
195 #[derive(Serialize)]
196 struct Input<'a> {
197 query: &'a str,
198 params: Vec<Value>,
199 }
200 self
201 .c
202 .raw_request(
203 Method::POST,
204 format!("{}/query", self.dbpp),
205 Some(&Input {
206 query: query.as_ref(),
207 params,
208 }),
209 )
210 .await
211 }
212}