Skip to main content

bunnydb_http/
client.rs

1use std::{fmt, time::Duration};
2
3use reqwest::{header, StatusCode};
4use tokio::time::sleep;
5
6use crate::{
7    decode::{build_execute_statement, decode_exec_result, decode_query_result},
8    wire::{self, PipelineRequest, Request},
9    BunnyDbError, ClientOptions, ExecResult, Params, QueryResult, Result, Statement,
10    StatementOutcome,
11};
12
13#[derive(Clone)]
14/// HTTP client for Bunny.net Database SQL pipeline endpoint.
15pub struct BunnyDbClient {
16    http: reqwest::Client,
17    pipeline_url: String,
18    token: String,
19    options: ClientOptions,
20}
21
22impl fmt::Debug for BunnyDbClient {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("BunnyDbClient")
25            .field("pipeline_url", &self.pipeline_url)
26            .field("token", &"<redacted>")
27            .field("options", &self.options)
28            .finish()
29    }
30}
31
32impl BunnyDbClient {
33    /// Creates a client with a raw authorization header value.
34    ///
35    /// This is backward-compatible with previous versions where `token`
36    /// was passed directly as `Authorization: <value>`.
37    pub fn new(pipeline_url: impl Into<String>, token: impl Into<String>) -> Self {
38        Self::new_raw_auth(pipeline_url, token)
39    }
40
41    /// Creates a client with a full raw authorization value.
42    ///
43    /// Example: `"Bearer <token>"` or any custom scheme.
44    pub fn new_raw_auth(pipeline_url: impl Into<String>, authorization: impl Into<String>) -> Self {
45        Self {
46            http: reqwest::Client::new(),
47            pipeline_url: pipeline_url.into(),
48            token: authorization.into(),
49            options: ClientOptions::default(),
50        }
51    }
52
53    /// Creates a client from a bearer token.
54    ///
55    /// If the token is missing the `Bearer ` prefix, it is added automatically.
56    pub fn new_bearer(pipeline_url: impl Into<String>, token: impl AsRef<str>) -> Self {
57        let authorization = normalize_bearer_authorization(token.as_ref());
58        Self::new_raw_auth(pipeline_url, authorization)
59    }
60
61    /// Applies client options such as timeout and retry behavior.
62    pub fn with_options(mut self, opts: ClientOptions) -> Self {
63        self.options = opts;
64        self
65    }
66
67    /// Executes a query statement and returns rows.
68    pub async fn query<P: Into<Params>>(&self, sql: &str, params: P) -> Result<QueryResult> {
69        let result = self.run_single(sql, params.into(), true).await?;
70        decode_query_result(result)
71    }
72
73    /// Executes a statement and returns execution metadata.
74    pub async fn execute<P: Into<Params>>(&self, sql: &str, params: P) -> Result<ExecResult> {
75        let result = self.run_single(sql, params.into(), false).await?;
76        decode_exec_result(result)
77    }
78
79    /// Sends multiple statements in one pipeline request.
80    ///
81    /// SQL errors at statement level are returned as
82    /// [`StatementOutcome::SqlError`] instead of failing the entire batch.
83    pub async fn batch<I>(&self, statements: I) -> Result<Vec<StatementOutcome>>
84    where
85        I: IntoIterator<Item = Statement>,
86    {
87        let statements: Vec<Statement> = statements.into_iter().collect();
88        let mut requests = Vec::with_capacity(statements.len() + 1);
89        let mut wants_rows = Vec::with_capacity(statements.len());
90
91        for statement in statements {
92            let stmt =
93                build_execute_statement(&statement.sql, statement.params, statement.want_rows)?;
94            requests.push(Request::Execute { stmt });
95            wants_rows.push(statement.want_rows);
96        }
97
98        requests.push(Request::Close {});
99        let payload = PipelineRequest { requests };
100        let response = self.send_pipeline_with_retry(&payload).await?;
101
102        let expected = wants_rows.len() + 1;
103        if response.results.len() != expected {
104            return Err(BunnyDbError::Decode(format!(
105                "result count mismatch: expected {expected}, got {}",
106                response.results.len()
107            )));
108        }
109
110        let mut results = response.results.into_iter();
111        let mut outcomes = Vec::with_capacity(wants_rows.len());
112
113        for (index, want_rows) in wants_rows.into_iter().enumerate() {
114            let result = results.next().ok_or_else(|| {
115                BunnyDbError::Decode(format!("missing execute result at index {index}"))
116            })?;
117            outcomes.push(Self::decode_statement_outcome(result, index, want_rows)?);
118        }
119
120        let close_index = outcomes.len();
121        let close = results.next().ok_or_else(|| {
122            BunnyDbError::Decode(format!("missing close result at index {close_index}"))
123        })?;
124        Self::ensure_close_success(close, close_index)?;
125
126        Ok(outcomes)
127    }
128
129    async fn run_single(
130        &self,
131        sql: &str,
132        params: Params,
133        want_rows: bool,
134    ) -> Result<wire::ExecuteResult> {
135        let execute_stmt = build_execute_statement(sql, params, want_rows)?;
136        let payload = PipelineRequest {
137            requests: vec![Request::Execute { stmt: execute_stmt }, Request::Close {}],
138        };
139        let response = self.send_pipeline_with_retry(&payload).await?;
140
141        if response.results.len() != 2 {
142            return Err(BunnyDbError::Decode(format!(
143                "result count mismatch: expected 2, got {}",
144                response.results.len()
145            )));
146        }
147
148        let mut iter = response.results.into_iter();
149        let execute = iter
150            .next()
151            .ok_or_else(|| BunnyDbError::Decode("missing execute result".to_owned()))?;
152        let close = iter
153            .next()
154            .ok_or_else(|| BunnyDbError::Decode("missing close result".to_owned()))?;
155
156        let execute_result = Self::into_execute_result(execute, 0)?;
157        Self::ensure_close_success(close, 1)?;
158        Ok(execute_result)
159    }
160
161    async fn send_pipeline_with_retry(
162        &self,
163        payload: &PipelineRequest,
164    ) -> Result<wire::PipelineResponse> {
165        let mut attempt = 0usize;
166        loop {
167            let response = self
168                .http
169                .post(&self.pipeline_url)
170                .header(header::AUTHORIZATION, &self.token)
171                .header(header::CONTENT_TYPE, "application/json")
172                .timeout(Duration::from_millis(self.options.timeout_ms))
173                .json(payload)
174                .send()
175                .await;
176
177            match response {
178                Ok(response) => {
179                    let status = response.status();
180                    let body = response.text().await.map_err(BunnyDbError::Transport)?;
181
182                    if !status.is_success() {
183                        if self.should_retry_status(status) && attempt < self.options.max_retries {
184                            self.wait_before_retry(attempt).await;
185                            attempt += 1;
186                            continue;
187                        }
188
189                        return Err(BunnyDbError::Http {
190                            status: status.as_u16(),
191                            body,
192                        });
193                    }
194
195                    return serde_json::from_str::<wire::PipelineResponse>(&body).map_err(|err| {
196                        BunnyDbError::Decode(format!(
197                            "invalid pipeline response JSON: {err}; body: {body}"
198                        ))
199                    });
200                }
201                Err(err) => {
202                    if self.should_retry_transport(&err) && attempt < self.options.max_retries {
203                        self.wait_before_retry(attempt).await;
204                        attempt += 1;
205                        continue;
206                    }
207                    return Err(BunnyDbError::Transport(err));
208                }
209            }
210        }
211    }
212
213    fn decode_statement_outcome(
214        result: wire::PipelineResult,
215        request_index: usize,
216        want_rows: bool,
217    ) -> Result<StatementOutcome> {
218        match result.kind.as_str() {
219            "ok" => {
220                let execute_result = Self::into_execute_result(result, request_index)?;
221                if want_rows {
222                    Ok(StatementOutcome::Query(decode_query_result(
223                        execute_result,
224                    )?))
225                } else {
226                    Ok(StatementOutcome::Exec(decode_exec_result(execute_result)?))
227                }
228            }
229            "error" => {
230                let error = result.error.ok_or_else(|| {
231                    BunnyDbError::Decode(format!(
232                        "missing error payload for request {request_index}"
233                    ))
234                })?;
235                Ok(StatementOutcome::SqlError {
236                    request_index,
237                    message: error.message,
238                    code: error.code,
239                })
240            }
241            other => Err(BunnyDbError::Decode(format!(
242                "unknown pipeline result type '{other}' at request {request_index}"
243            ))),
244        }
245    }
246
247    fn into_execute_result(
248        result: wire::PipelineResult,
249        request_index: usize,
250    ) -> Result<wire::ExecuteResult> {
251        match result.kind.as_str() {
252            "ok" => {
253                let response = result.response.ok_or_else(|| {
254                    BunnyDbError::Decode(format!(
255                        "missing response payload for request {request_index}"
256                    ))
257                })?;
258                if response.kind != "execute" {
259                    return Err(BunnyDbError::Decode(format!(
260                        "expected execute response at request {request_index}, got '{}'",
261                        response.kind
262                    )));
263                }
264                response.result.ok_or_else(|| {
265                    BunnyDbError::Decode(format!(
266                        "missing execute result payload at request {request_index}"
267                    ))
268                })
269            }
270            "error" => {
271                let error = result.error.ok_or_else(|| {
272                    BunnyDbError::Decode(format!(
273                        "missing error payload for request {request_index}"
274                    ))
275                })?;
276                Err(BunnyDbError::Pipeline {
277                    request_index,
278                    message: error.message,
279                    code: error.code,
280                })
281            }
282            other => Err(BunnyDbError::Decode(format!(
283                "unknown pipeline result type '{other}' at request {request_index}"
284            ))),
285        }
286    }
287
288    fn ensure_close_success(result: wire::PipelineResult, request_index: usize) -> Result<()> {
289        match result.kind.as_str() {
290            "ok" => {
291                let response = result.response.ok_or_else(|| {
292                    BunnyDbError::Decode(format!(
293                        "missing close response payload for request {request_index}"
294                    ))
295                })?;
296                if response.kind != "close" {
297                    return Err(BunnyDbError::Decode(format!(
298                        "expected close response at request {request_index}, got '{}'",
299                        response.kind
300                    )));
301                }
302                Ok(())
303            }
304            "error" => {
305                let error = result.error.ok_or_else(|| {
306                    BunnyDbError::Decode(format!(
307                        "missing error payload for close request {request_index}"
308                    ))
309                })?;
310                Err(BunnyDbError::Pipeline {
311                    request_index,
312                    message: error.message,
313                    code: error.code,
314                })
315            }
316            other => Err(BunnyDbError::Decode(format!(
317                "unknown pipeline result type '{other}' at request {request_index}"
318            ))),
319        }
320    }
321
322    fn should_retry_status(&self, status: StatusCode) -> bool {
323        matches!(
324            status,
325            StatusCode::TOO_MANY_REQUESTS
326                | StatusCode::INTERNAL_SERVER_ERROR
327                | StatusCode::BAD_GATEWAY
328                | StatusCode::SERVICE_UNAVAILABLE
329                | StatusCode::GATEWAY_TIMEOUT
330        )
331    }
332
333    fn should_retry_transport(&self, err: &reqwest::Error) -> bool {
334        err.is_timeout() || err.is_connect() || err.is_request() || err.is_body()
335    }
336
337    async fn wait_before_retry(&self, attempt: usize) {
338        let exp = attempt.min(16) as u32;
339        let multiplier = 1u64 << exp;
340        let delay_ms = self.options.retry_backoff_ms.saturating_mul(multiplier);
341        #[cfg(feature = "tracing")]
342        tracing::debug!("retrying pipeline request after {} ms", delay_ms);
343        sleep(Duration::from_millis(delay_ms)).await;
344    }
345}
346
/// Builds a bearer `Authorization` value from a user-supplied token.
///
/// Surrounding whitespace is trimmed. If the trimmed input already starts
/// with `bearer ` (ASCII case-insensitive, single space included), it is
/// returned unchanged. Otherwise `Bearer ` is prepended.
fn normalize_bearer_authorization(token: &str) -> String {
    const SCHEME: &str = "bearer ";
    let token = token.trim();
    // `get` returns `None` when the input is shorter than the scheme or the
    // cut would split a multibyte character; both mean "no scheme present".
    let has_scheme = matches!(
        token.get(..SCHEME.len()),
        Some(head) if head.eq_ignore_ascii_case(SCHEME)
    );
    if has_scheme {
        token.to_owned()
    } else {
        format!("Bearer {token}")
    }
}
356
#[cfg(test)]
mod tests {
    use super::{normalize_bearer_authorization, BunnyDbClient};

    #[test]
    fn normalize_bearer_adds_prefix_when_missing() {
        assert_eq!(normalize_bearer_authorization("abc123"), "Bearer abc123");
    }

    #[test]
    fn normalize_bearer_keeps_existing_prefix() {
        assert_eq!(
            normalize_bearer_authorization("bEaReR abc123"),
            "bEaReR abc123"
        );
    }

    // Tokens are often copied from dashboards or env files with stray
    // whitespace; it must not end up inside the header value.
    #[test]
    fn normalize_bearer_trims_surrounding_whitespace() {
        assert_eq!(normalize_bearer_authorization("  abc123\n"), "Bearer abc123");
        assert_eq!(
            normalize_bearer_authorization("\tBearer abc123 "),
            "Bearer abc123"
        );
    }

    #[test]
    fn debug_redacts_authorization_value() {
        let client = BunnyDbClient::new_raw_auth("https://db/v2/pipeline", "secret-token");
        let debug = format!("{client:?}");
        assert!(debug.contains("<redacted>"));
        assert!(!debug.contains("secret-token"));
    }

    // The bearer constructor stores a derived value; make sure that path is
    // redacted too.
    #[test]
    fn debug_redacts_bearer_token() {
        let client = BunnyDbClient::new_bearer("https://db/v2/pipeline", "secret-token");
        let debug = format!("{client:?}");
        assert!(debug.contains("<redacted>"));
        assert!(!debug.contains("secret-token"));
    }
}