1use std::{fmt, time::Duration};
2
3use reqwest::{header, StatusCode};
4use tokio::time::sleep;
5
6use crate::{
7 decode::{build_execute_statement, decode_exec_result, decode_query_result},
8 wire::{self, PipelineRequest, Request},
9 BunnyDbError, ClientOptions, ExecResult, Params, QueryResult, Result, Statement,
10 StatementOutcome,
11};
12
13#[derive(Clone)]
14pub struct BunnyDbClient {
16 http: reqwest::Client,
17 pipeline_url: String,
18 token: String,
19 options: ClientOptions,
20}
21
22impl fmt::Debug for BunnyDbClient {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_struct("BunnyDbClient")
25 .field("pipeline_url", &self.pipeline_url)
26 .field("token", &"<redacted>")
27 .field("options", &self.options)
28 .finish()
29 }
30}
31
32impl BunnyDbClient {
33 pub fn new(pipeline_url: impl Into<String>, token: impl Into<String>) -> Self {
38 Self::new_raw_auth(pipeline_url, token)
39 }
40
41 pub fn new_raw_auth(pipeline_url: impl Into<String>, authorization: impl Into<String>) -> Self {
45 Self {
46 http: reqwest::Client::new(),
47 pipeline_url: pipeline_url.into(),
48 token: authorization.into(),
49 options: ClientOptions::default(),
50 }
51 }
52
53 pub fn new_bearer(pipeline_url: impl Into<String>, token: impl AsRef<str>) -> Self {
57 let authorization = normalize_bearer_authorization(token.as_ref());
58 Self::new_raw_auth(pipeline_url, authorization)
59 }
60
61 pub fn with_options(mut self, opts: ClientOptions) -> Self {
63 self.options = opts;
64 self
65 }
66
67 pub async fn query<P: Into<Params>>(&self, sql: &str, params: P) -> Result<QueryResult> {
69 let result = self.run_single(sql, params.into(), true).await?;
70 decode_query_result(result)
71 }
72
73 pub async fn execute<P: Into<Params>>(&self, sql: &str, params: P) -> Result<ExecResult> {
75 let result = self.run_single(sql, params.into(), false).await?;
76 decode_exec_result(result)
77 }
78
79 pub async fn batch<I>(&self, statements: I) -> Result<Vec<StatementOutcome>>
84 where
85 I: IntoIterator<Item = Statement>,
86 {
87 let statements: Vec<Statement> = statements.into_iter().collect();
88 let mut requests = Vec::with_capacity(statements.len() + 1);
89 let mut wants_rows = Vec::with_capacity(statements.len());
90
91 for statement in statements {
92 let stmt =
93 build_execute_statement(&statement.sql, statement.params, statement.want_rows)?;
94 requests.push(Request::Execute { stmt });
95 wants_rows.push(statement.want_rows);
96 }
97
98 requests.push(Request::Close {});
99 let payload = PipelineRequest { requests };
100 let response = self.send_pipeline_with_retry(&payload).await?;
101
102 let expected = wants_rows.len() + 1;
103 if response.results.len() != expected {
104 return Err(BunnyDbError::Decode(format!(
105 "result count mismatch: expected {expected}, got {}",
106 response.results.len()
107 )));
108 }
109
110 let mut results = response.results.into_iter();
111 let mut outcomes = Vec::with_capacity(wants_rows.len());
112
113 for (index, want_rows) in wants_rows.into_iter().enumerate() {
114 let result = results.next().ok_or_else(|| {
115 BunnyDbError::Decode(format!("missing execute result at index {index}"))
116 })?;
117 outcomes.push(Self::decode_statement_outcome(result, index, want_rows)?);
118 }
119
120 let close_index = outcomes.len();
121 let close = results.next().ok_or_else(|| {
122 BunnyDbError::Decode(format!("missing close result at index {close_index}"))
123 })?;
124 Self::ensure_close_success(close, close_index)?;
125
126 Ok(outcomes)
127 }
128
129 async fn run_single(
130 &self,
131 sql: &str,
132 params: Params,
133 want_rows: bool,
134 ) -> Result<wire::ExecuteResult> {
135 let execute_stmt = build_execute_statement(sql, params, want_rows)?;
136 let payload = PipelineRequest {
137 requests: vec![Request::Execute { stmt: execute_stmt }, Request::Close {}],
138 };
139 let response = self.send_pipeline_with_retry(&payload).await?;
140
141 if response.results.len() != 2 {
142 return Err(BunnyDbError::Decode(format!(
143 "result count mismatch: expected 2, got {}",
144 response.results.len()
145 )));
146 }
147
148 let mut iter = response.results.into_iter();
149 let execute = iter
150 .next()
151 .ok_or_else(|| BunnyDbError::Decode("missing execute result".to_owned()))?;
152 let close = iter
153 .next()
154 .ok_or_else(|| BunnyDbError::Decode("missing close result".to_owned()))?;
155
156 let execute_result = Self::into_execute_result(execute, 0)?;
157 Self::ensure_close_success(close, 1)?;
158 Ok(execute_result)
159 }
160
161 async fn send_pipeline_with_retry(
162 &self,
163 payload: &PipelineRequest,
164 ) -> Result<wire::PipelineResponse> {
165 let mut attempt = 0usize;
166 loop {
167 let response = self
168 .http
169 .post(&self.pipeline_url)
170 .header(header::AUTHORIZATION, &self.token)
171 .header(header::CONTENT_TYPE, "application/json")
172 .timeout(Duration::from_millis(self.options.timeout_ms))
173 .json(payload)
174 .send()
175 .await;
176
177 match response {
178 Ok(response) => {
179 let status = response.status();
180 let body = response.text().await.map_err(BunnyDbError::Transport)?;
181
182 if !status.is_success() {
183 if self.should_retry_status(status) && attempt < self.options.max_retries {
184 self.wait_before_retry(attempt).await;
185 attempt += 1;
186 continue;
187 }
188
189 return Err(BunnyDbError::Http {
190 status: status.as_u16(),
191 body,
192 });
193 }
194
195 return serde_json::from_str::<wire::PipelineResponse>(&body).map_err(|err| {
196 BunnyDbError::Decode(format!(
197 "invalid pipeline response JSON: {err}; body: {body}"
198 ))
199 });
200 }
201 Err(err) => {
202 if self.should_retry_transport(&err) && attempt < self.options.max_retries {
203 self.wait_before_retry(attempt).await;
204 attempt += 1;
205 continue;
206 }
207 return Err(BunnyDbError::Transport(err));
208 }
209 }
210 }
211 }
212
213 fn decode_statement_outcome(
214 result: wire::PipelineResult,
215 request_index: usize,
216 want_rows: bool,
217 ) -> Result<StatementOutcome> {
218 match result.kind.as_str() {
219 "ok" => {
220 let execute_result = Self::into_execute_result(result, request_index)?;
221 if want_rows {
222 Ok(StatementOutcome::Query(decode_query_result(
223 execute_result,
224 )?))
225 } else {
226 Ok(StatementOutcome::Exec(decode_exec_result(execute_result)?))
227 }
228 }
229 "error" => {
230 let error = result.error.ok_or_else(|| {
231 BunnyDbError::Decode(format!(
232 "missing error payload for request {request_index}"
233 ))
234 })?;
235 Ok(StatementOutcome::SqlError {
236 request_index,
237 message: error.message,
238 code: error.code,
239 })
240 }
241 other => Err(BunnyDbError::Decode(format!(
242 "unknown pipeline result type '{other}' at request {request_index}"
243 ))),
244 }
245 }
246
247 fn into_execute_result(
248 result: wire::PipelineResult,
249 request_index: usize,
250 ) -> Result<wire::ExecuteResult> {
251 match result.kind.as_str() {
252 "ok" => {
253 let response = result.response.ok_or_else(|| {
254 BunnyDbError::Decode(format!(
255 "missing response payload for request {request_index}"
256 ))
257 })?;
258 if response.kind != "execute" {
259 return Err(BunnyDbError::Decode(format!(
260 "expected execute response at request {request_index}, got '{}'",
261 response.kind
262 )));
263 }
264 response.result.ok_or_else(|| {
265 BunnyDbError::Decode(format!(
266 "missing execute result payload at request {request_index}"
267 ))
268 })
269 }
270 "error" => {
271 let error = result.error.ok_or_else(|| {
272 BunnyDbError::Decode(format!(
273 "missing error payload for request {request_index}"
274 ))
275 })?;
276 Err(BunnyDbError::Pipeline {
277 request_index,
278 message: error.message,
279 code: error.code,
280 })
281 }
282 other => Err(BunnyDbError::Decode(format!(
283 "unknown pipeline result type '{other}' at request {request_index}"
284 ))),
285 }
286 }
287
288 fn ensure_close_success(result: wire::PipelineResult, request_index: usize) -> Result<()> {
289 match result.kind.as_str() {
290 "ok" => {
291 let response = result.response.ok_or_else(|| {
292 BunnyDbError::Decode(format!(
293 "missing close response payload for request {request_index}"
294 ))
295 })?;
296 if response.kind != "close" {
297 return Err(BunnyDbError::Decode(format!(
298 "expected close response at request {request_index}, got '{}'",
299 response.kind
300 )));
301 }
302 Ok(())
303 }
304 "error" => {
305 let error = result.error.ok_or_else(|| {
306 BunnyDbError::Decode(format!(
307 "missing error payload for close request {request_index}"
308 ))
309 })?;
310 Err(BunnyDbError::Pipeline {
311 request_index,
312 message: error.message,
313 code: error.code,
314 })
315 }
316 other => Err(BunnyDbError::Decode(format!(
317 "unknown pipeline result type '{other}' at request {request_index}"
318 ))),
319 }
320 }
321
322 fn should_retry_status(&self, status: StatusCode) -> bool {
323 matches!(
324 status,
325 StatusCode::TOO_MANY_REQUESTS
326 | StatusCode::INTERNAL_SERVER_ERROR
327 | StatusCode::BAD_GATEWAY
328 | StatusCode::SERVICE_UNAVAILABLE
329 | StatusCode::GATEWAY_TIMEOUT
330 )
331 }
332
333 fn should_retry_transport(&self, err: &reqwest::Error) -> bool {
334 err.is_timeout() || err.is_connect() || err.is_request() || err.is_body()
335 }
336
337 async fn wait_before_retry(&self, attempt: usize) {
338 let exp = attempt.min(16) as u32;
339 let multiplier = 1u64 << exp;
340 let delay_ms = self.options.retry_backoff_ms.saturating_mul(multiplier);
341 #[cfg(feature = "tracing")]
342 tracing::debug!("retrying pipeline request after {} ms", delay_ms);
343 sleep(Duration::from_millis(delay_ms)).await;
344 }
345}
346
/// Turns a user-supplied token into a bearer `Authorization` header value.
///
/// Surrounding whitespace is trimmed. If the token already begins with `bearer `
/// (ASCII case-insensitive), it is returned as-is with its original casing. Otherwise
/// `Bearer ` is prepended. A token whose seventh byte falls inside a multi-byte
/// character cannot start with the scheme, so it gets the prefix.
fn normalize_bearer_authorization(token: &str) -> String {
    const SCHEME: &str = "bearer ";

    let token = token.trim();
    let already_prefixed = token
        .get(..SCHEME.len())
        .map_or(false, |head| head.eq_ignore_ascii_case(SCHEME));

    match already_prefixed {
        true => token.to_owned(),
        false => format!("Bearer {token}"),
    }
}
356
#[cfg(test)]
mod tests {
    use super::{normalize_bearer_authorization, BunnyDbClient};

    /// A bare token gains the `Bearer ` scheme prefix.
    #[test]
    fn normalize_bearer_adds_prefix_when_missing() {
        let normalized = normalize_bearer_authorization("abc123");
        assert_eq!(normalized, "Bearer abc123");
    }

    /// An existing scheme is recognised case-insensitively and left untouched.
    #[test]
    fn normalize_bearer_keeps_existing_prefix() {
        let normalized = normalize_bearer_authorization("bEaReR abc123");
        assert_eq!(normalized, "bEaReR abc123");
    }

    /// `Debug` output shows the placeholder and never the raw authorization value.
    #[test]
    fn debug_redacts_authorization_value() {
        let secret = "secret-token";
        let client = BunnyDbClient::new_raw_auth("https://db/v2/pipeline", secret);
        let rendered = format!("{client:?}");
        assert!(rendered.contains("<redacted>"));
        assert!(!rendered.contains(secret));
    }
}