Skip to main content

duroxide_cdb/
client.rs

1use crate::errors;
2use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
3use chrono::Utc;
4use duroxide::providers::ProviderError;
5use hmac::{Hmac, Mac};
6use reqwest::header::{HeaderMap, HeaderValue};
7use sha2::Sha256;
8use std::sync::Arc;
9
10type HmacSha256 = Hmac<Sha256>;
11
12/// Low-level CosmosDB REST API client.
13/// Wraps reqwest and handles auth headers, serialization, and error mapping.
14#[derive(Clone)]
15pub struct CosmosDBClient {
16    inner: Arc<CosmosDBClientInner>,
17}
18
19struct CosmosDBClientInner {
20    http: reqwest::Client,
21    endpoint: String,
22    key_bytes: Vec<u8>,
23    database: String,
24    container: String,
25}
26
/// Response from a CosmosDB REST call.
///
/// The document methods return non-2xx statuses in this struct rather than as errors, so the
/// caller decides how to interpret them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CosmosDBResponse {
    /// HTTP status code returned by the service.
    pub status: u16,
    /// Value of the `etag` response header, when present and representable as a string.
    pub etag: Option<String>,
    /// Raw response body text.
    pub body: String,
}

impl CosmosDBResponse {
    /// Returns `true` for any 2xx status code.
    pub fn is_success(&self) -> bool {
        (200..300).contains(&self.status)
    }
}
40
41impl CosmosDBClient {
42    pub fn new(endpoint: &str, key: &str, database: &str, container: &str) -> Result<Self, String> {
43        let key_bytes = BASE64
44            .decode(key)
45            .map_err(|e| format!("Invalid CosmosDB key: {e}"))?;
46        let http = reqwest::Client::builder()
47            .danger_accept_invalid_certs(true) // For local emulator
48            .pool_max_idle_per_host(20)
49            .build()
50            .map_err(|e| format!("Failed to create HTTP client: {e}"))?;
51
52        Ok(Self {
53            inner: Arc::new(CosmosDBClientInner {
54                http,
55                endpoint: endpoint.trim_end_matches('/').to_string(),
56                key_bytes,
57                database: database.to_string(),
58                container: container.to_string(),
59            }),
60        })
61    }
62
63    pub fn endpoint(&self) -> &str {
64        &self.inner.endpoint
65    }
66
67    pub fn database(&self) -> &str {
68        &self.inner.database
69    }
70
71    pub fn container(&self) -> &str {
72        &self.inner.container
73    }
74
75    fn collection_url(&self) -> String {
76        format!(
77            "{}/dbs/{}/colls/{}",
78            self.inner.endpoint, self.inner.database, self.inner.container
79        )
80    }
81
82    fn doc_url(&self, doc_id: &str) -> String {
83        format!(
84            "{}/dbs/{}/colls/{}/docs/{}",
85            self.inner.endpoint,
86            self.inner.database,
87            self.inner.container,
88            urlencoding::encode(doc_id)
89        )
90    }
91
92    /// Generate the CosmosDB authorization header.
93    fn auth_header(
94        &self,
95        verb: &str,
96        resource_type: &str,
97        resource_link: &str,
98        date: &str,
99    ) -> String {
100        let payload = format!(
101            "{}\n{}\n{}\n{}\n\n",
102            verb.to_lowercase(),
103            resource_type.to_lowercase(),
104            resource_link,
105            date.to_lowercase()
106        );
107
108        let mut mac =
109            HmacSha256::new_from_slice(&self.inner.key_bytes).expect("HMAC key creation failed");
110        mac.update(payload.as_bytes());
111        let signature = BASE64.encode(mac.finalize().into_bytes());
112
113        let auth = format!("type=master&ver=1.0&sig={signature}");
114        urlencoding::encode(&auth).to_string()
115    }
116
117    fn resource_link_for_collection(&self) -> String {
118        format!("dbs/{}/colls/{}", self.inner.database, self.inner.container)
119    }
120
121    fn resource_link_for_doc(&self, doc_id: &str) -> String {
122        format!(
123            "dbs/{}/colls/{}/docs/{}",
124            self.inner.database, self.inner.container, doc_id
125        )
126    }
127
128    fn common_headers(
129        &self,
130        verb: &str,
131        resource_type: &str,
132        resource_link: &str,
133        partition_key: Option<&str>,
134    ) -> HeaderMap {
135        let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
136        let auth = self.auth_header(verb, resource_type, resource_link, &date);
137
138        let mut headers = HeaderMap::new();
139        headers.insert("x-ms-date", HeaderValue::from_str(&date).unwrap());
140        headers.insert("x-ms-version", HeaderValue::from_static("2020-07-15"));
141        headers.insert("Authorization", HeaderValue::from_str(&auth).unwrap());
142        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
143
144        if let Some(pk) = partition_key {
145            let pk_header = format!("[\"{pk}\"]");
146            headers.insert(
147                "x-ms-documentdb-partitionkey",
148                HeaderValue::from_str(&pk_header).unwrap(),
149            );
150        }
151
152        headers
153    }
154
155    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
156    // Database / Container management
157    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
158
159    /// Create database if it doesn't exist.
160    pub async fn ensure_database(&self) -> Result<(), ProviderError> {
161        let url = format!("{}/dbs", self.inner.endpoint);
162        let resource_link = "";
163        let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
164        let auth = self.auth_header("post", "dbs", resource_link, &date);
165
166        let body = serde_json::json!({ "id": self.inner.database });
167
168        let resp = self
169            .inner
170            .http
171            .post(&url)
172            .header("x-ms-date", &date)
173            .header("x-ms-version", "2020-07-15")
174            .header("Authorization", &auth)
175            .header("Content-Type", "application/json")
176            .json(&body)
177            .send()
178            .await
179            .map_err(|e| ProviderError::retryable("ensure_database", e.to_string()))?;
180
181        let status = resp.status().as_u16();
182        if status == 201 || status == 409 {
183            // 201 = created, 409 = already exists
184            Ok(())
185        } else {
186            let text = resp.text().await.unwrap_or_default();
187            Err(errors::map_cosmosdb_error("ensure_database", status, &text))
188        }
189    }
190
191    /// Create container if it doesn't exist.
192    pub async fn ensure_container(
193        &self,
194        indexing_policy: Option<serde_json::Value>,
195    ) -> Result<(), ProviderError> {
196        let url = format!("{}/dbs/{}/colls", self.inner.endpoint, self.inner.database);
197        let resource_link = format!("dbs/{}", self.inner.database);
198        let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
199        let auth = self.auth_header("post", "colls", &resource_link, &date);
200
201        let mut body = serde_json::json!({
202            "id": self.inner.container,
203            "partitionKey": {
204                "paths": ["/instanceId"],
205                "kind": "Hash",
206                "version": 2
207            }
208        });
209
210        if let Some(policy) = indexing_policy {
211            body["indexingPolicy"] = policy;
212        }
213
214        let resp = self
215            .inner
216            .http
217            .post(&url)
218            .header("x-ms-date", &date)
219            .header("x-ms-version", "2020-07-15")
220            .header("Authorization", &auth)
221            .header("Content-Type", "application/json")
222            .json(&body)
223            .send()
224            .await
225            .map_err(|e| ProviderError::retryable("ensure_container", e.to_string()))?;
226
227        let status = resp.status().as_u16();
228        if status == 201 || status == 409 {
229            Ok(())
230        } else {
231            let text = resp.text().await.unwrap_or_default();
232            Err(errors::map_cosmosdb_error(
233                "ensure_container",
234                status,
235                &text,
236            ))
237        }
238    }
239
240    /// Delete container.
241    pub async fn delete_container(&self) -> Result<(), ProviderError> {
242        let url = self.collection_url();
243        let resource_link = self.resource_link_for_collection();
244        let headers = self.common_headers("delete", "colls", &resource_link, None);
245
246        let resp = self
247            .inner
248            .http
249            .delete(&url)
250            .headers(headers)
251            .send()
252            .await
253            .map_err(|e| ProviderError::retryable("delete_container", e.to_string()))?;
254
255        let status = resp.status().as_u16();
256        if status == 204 || status == 404 {
257            Ok(())
258        } else {
259            let text = resp.text().await.unwrap_or_default();
260            Err(errors::map_cosmosdb_error(
261                "delete_container",
262                status,
263                &text,
264            ))
265        }
266    }
267
268    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
269    // Document CRUD
270    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
271
272    /// Create a document. Returns the response with the created document.
273    pub async fn create_document(
274        &self,
275        partition_key: &str,
276        document: &serde_json::Value,
277    ) -> Result<CosmosDBResponse, ProviderError> {
278        let url = format!("{}/docs", self.collection_url());
279        let resource_link = self.resource_link_for_collection();
280        let headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
281
282        let resp = self
283            .inner
284            .http
285            .post(&url)
286            .headers(headers)
287            .json(document)
288            .send()
289            .await
290            .map_err(|e| ProviderError::retryable("create_document", e.to_string()))?;
291
292        let status = resp.status().as_u16();
293        let etag = resp
294            .headers()
295            .get("etag")
296            .and_then(|v| v.to_str().ok())
297            .map(|s| s.to_string());
298        let body = resp.text().await.unwrap_or_default();
299
300        Ok(CosmosDBResponse { status, etag, body })
301    }
302
303    /// Upsert a document (create or replace).
304    pub async fn upsert_document(
305        &self,
306        partition_key: &str,
307        document: &serde_json::Value,
308    ) -> Result<CosmosDBResponse, ProviderError> {
309        let url = format!("{}/docs", self.collection_url());
310        let resource_link = self.resource_link_for_collection();
311        let mut headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
312        headers.insert(
313            "x-ms-documentdb-is-upsert",
314            HeaderValue::from_static("true"),
315        );
316
317        let resp = self
318            .inner
319            .http
320            .post(&url)
321            .headers(headers)
322            .json(document)
323            .send()
324            .await
325            .map_err(|e| ProviderError::retryable("upsert_document", e.to_string()))?;
326
327        let status = resp.status().as_u16();
328        let etag = resp
329            .headers()
330            .get("etag")
331            .and_then(|v| v.to_str().ok())
332            .map(|s| s.to_string());
333        let body = resp.text().await.unwrap_or_default();
334
335        Ok(CosmosDBResponse { status, etag, body })
336    }
337
338    /// Read a document by ID.
339    pub async fn read_document(
340        &self,
341        doc_id: &str,
342        partition_key: &str,
343    ) -> Result<CosmosDBResponse, ProviderError> {
344        let url = self.doc_url(doc_id);
345        let resource_link = self.resource_link_for_doc(doc_id);
346        let headers = self.common_headers("get", "docs", &resource_link, Some(partition_key));
347
348        let resp = self
349            .inner
350            .http
351            .get(&url)
352            .headers(headers)
353            .send()
354            .await
355            .map_err(|e| ProviderError::retryable("read_document", e.to_string()))?;
356
357        let status = resp.status().as_u16();
358        let etag = resp
359            .headers()
360            .get("etag")
361            .and_then(|v| v.to_str().ok())
362            .map(|s| s.to_string());
363        let body = resp.text().await.unwrap_or_default();
364
365        Ok(CosmosDBResponse { status, etag, body })
366    }
367
368    /// Replace a document by ID with optional ETag condition.
369    pub async fn replace_document(
370        &self,
371        doc_id: &str,
372        partition_key: &str,
373        document: &serde_json::Value,
374        if_match: Option<&str>,
375    ) -> Result<CosmosDBResponse, ProviderError> {
376        let url = self.doc_url(doc_id);
377        let resource_link = self.resource_link_for_doc(doc_id);
378        let mut headers = self.common_headers("put", "docs", &resource_link, Some(partition_key));
379
380        if let Some(etag) = if_match {
381            headers.insert("If-Match", HeaderValue::from_str(etag).unwrap());
382        }
383
384        let resp = self
385            .inner
386            .http
387            .put(&url)
388            .headers(headers)
389            .json(document)
390            .send()
391            .await
392            .map_err(|e| ProviderError::retryable("replace_document", e.to_string()))?;
393
394        let status = resp.status().as_u16();
395        let etag = resp
396            .headers()
397            .get("etag")
398            .and_then(|v| v.to_str().ok())
399            .map(|s| s.to_string());
400        let body = resp.text().await.unwrap_or_default();
401
402        Ok(CosmosDBResponse { status, etag, body })
403    }
404
405    /// Delete a document by ID.
406    pub async fn delete_document(
407        &self,
408        doc_id: &str,
409        partition_key: &str,
410    ) -> Result<CosmosDBResponse, ProviderError> {
411        let url = self.doc_url(doc_id);
412        let resource_link = self.resource_link_for_doc(doc_id);
413        let headers = self.common_headers("delete", "docs", &resource_link, Some(partition_key));
414
415        let resp = self
416            .inner
417            .http
418            .delete(&url)
419            .headers(headers)
420            .send()
421            .await
422            .map_err(|e| ProviderError::retryable("delete_document", e.to_string()))?;
423
424        let status = resp.status().as_u16();
425        let body = resp.text().await.unwrap_or_default();
426
427        Ok(CosmosDBResponse {
428            status,
429            etag: None,
430            body,
431        })
432    }
433
434    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
435    // Query
436    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
437
438    /// Execute a SQL query. If partition_key is None, this is a cross-partition query.
439    pub async fn query(
440        &self,
441        sql: &str,
442        parameters: Vec<QueryParameter>,
443        partition_key: Option<&str>,
444    ) -> Result<Vec<serde_json::Value>, ProviderError> {
445        let url = format!("{}/docs", self.collection_url());
446        let resource_link = self.resource_link_for_collection();
447        let mut headers = self.common_headers("post", "docs", &resource_link, partition_key);
448        headers.insert("x-ms-documentdb-isquery", HeaderValue::from_static("true"));
449        headers.insert(
450            "Content-Type",
451            HeaderValue::from_static("application/query+json"),
452        );
453        if partition_key.is_none() {
454            headers.insert(
455                "x-ms-documentdb-query-enablecrosspartition",
456                HeaderValue::from_static("true"),
457            );
458        }
459
460        let query_body = serde_json::json!({
461            "query": sql,
462            "parameters": parameters.iter().map(|p| {
463                serde_json::json!({
464                    "name": p.name,
465                    "value": p.value
466                })
467            }).collect::<Vec<_>>()
468        });
469
470        let mut all_documents = Vec::new();
471        let mut continuation: Option<String> = None;
472
473        loop {
474            let mut req_headers = headers.clone();
475            if let Some(ref token) = continuation {
476                req_headers.insert("x-ms-continuation", HeaderValue::from_str(token).unwrap());
477            }
478
479            let resp = self
480                .inner
481                .http
482                .post(&url)
483                .headers(req_headers)
484                .json(&query_body)
485                .send()
486                .await
487                .map_err(|e| ProviderError::retryable("query", e.to_string()))?;
488
489            let status = resp.status().as_u16();
490            let next_continuation = resp
491                .headers()
492                .get("x-ms-continuation")
493                .and_then(|v| v.to_str().ok())
494                .map(|s| s.to_string());
495
496            let body = resp.text().await.unwrap_or_default();
497
498            if !((200..300).contains(&status)) {
499                return Err(errors::map_cosmosdb_error("query", status, &body));
500            }
501
502            let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
503                ProviderError::permanent("query", format!("Failed to parse query response: {e}"))
504            })?;
505
506            if let Some(docs) = parsed.get("Documents").and_then(|d| d.as_array()) {
507                all_documents.extend(docs.iter().cloned());
508            }
509
510            match next_continuation {
511                Some(token) if !token.is_empty() => {
512                    continuation = Some(token);
513                }
514                _ => break,
515            }
516        }
517
518        Ok(all_documents)
519    }
520
521    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
522    // Transactional Batch
523    // ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
524
525    /// Execute a transactional batch of operations within a single partition.
526    pub async fn transactional_batch(
527        &self,
528        partition_key: &str,
529        operations: Vec<BatchOperation>,
530    ) -> Result<Vec<BatchOperationResult>, ProviderError> {
531        let url = format!("{}/docs", self.collection_url());
532        let resource_link = self.resource_link_for_collection();
533        let mut headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
534        headers.insert(
535            "x-ms-cosmos-is-batch-request",
536            HeaderValue::from_static("true"),
537        );
538        // Batch atomicity: if any operation fails, all roll back
539        headers.insert(
540            "x-ms-cosmos-batch-continue-on-error",
541            HeaderValue::from_static("false"),
542        );
543
544        let batch_body: Vec<serde_json::Value> = operations.iter().map(|op| op.to_json()).collect();
545
546        let resp = self
547            .inner
548            .http
549            .post(&url)
550            .headers(headers)
551            .json(&batch_body)
552            .send()
553            .await
554            .map_err(|e| ProviderError::retryable("transactional_batch", e.to_string()))?;
555
556        let status = resp.status().as_u16();
557        let body = resp.text().await.unwrap_or_default();
558
559        if status == 200 || status == 207 {
560            // Parse batch results
561            let results: Vec<serde_json::Value> = serde_json::from_str(&body).map_err(|e| {
562                ProviderError::permanent(
563                    "transactional_batch",
564                    format!("Failed to parse batch response: {e}"),
565                )
566            })?;
567
568            let batch_results: Vec<BatchOperationResult> = results
569                .into_iter()
570                .map(|r| {
571                    let op_status =
572                        r.get("statusCode").and_then(|s| s.as_u64()).unwrap_or(0) as u16;
573                    let etag = r
574                        .get("eTag")
575                        .and_then(|e| e.as_str())
576                        .map(|s| s.to_string());
577                    let resource_body = r.get("resourceBody").map(|b| b.to_string());
578                    BatchOperationResult {
579                        status_code: op_status,
580                        etag,
581                        resource_body,
582                    }
583                })
584                .collect();
585
586            // Check if any operation in the batch failed
587            if let Some(failed) = batch_results.iter().find(|r| r.status_code >= 400) {
588                let msg = format!("Batch operation failed with status {}", failed.status_code);
589                if failed.status_code == 409 || failed.status_code == 412 {
590                    return Err(ProviderError::retryable("transactional_batch", msg));
591                }
592                return Err(ProviderError::permanent("transactional_batch", msg));
593            }
594
595            Ok(batch_results)
596        } else {
597            Err(errors::map_cosmosdb_error(
598                "transactional_batch",
599                status,
600                &body,
601            ))
602        }
603    }
604}
605
606// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
607// Query parameter
608// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
609
610#[derive(Debug, Clone)]
611pub struct QueryParameter {
612    pub name: String,
613    pub value: serde_json::Value,
614}
615
616impl QueryParameter {
617    pub fn new(name: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
618        Self {
619            name: name.into(),
620            value: value.into(),
621        }
622    }
623}
624
625// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
626// Batch operation types
627// ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
628
629#[derive(Debug, Clone)]
630pub enum BatchOperation {
631    Create {
632        body: serde_json::Value,
633    },
634    Upsert {
635        body: serde_json::Value,
636    },
637    Replace {
638        id: String,
639        body: serde_json::Value,
640        if_match: Option<String>,
641    },
642    Delete {
643        id: String,
644    },
645    Read {
646        id: String,
647    },
648}
649
650impl BatchOperation {
651    pub fn to_json(&self) -> serde_json::Value {
652        match self {
653            BatchOperation::Create { body } => serde_json::json!({
654                "operationType": "Create",
655                "resourceBody": body
656            }),
657            BatchOperation::Upsert { body } => serde_json::json!({
658                "operationType": "Upsert",
659                "resourceBody": body
660            }),
661            BatchOperation::Replace { id, body, if_match } => {
662                let mut op = serde_json::json!({
663                    "operationType": "Replace",
664                    "id": id,
665                    "resourceBody": body
666                });
667                if let Some(etag) = if_match {
668                    op["ifMatch"] = serde_json::json!(etag);
669                }
670                op
671            }
672            BatchOperation::Delete { id } => serde_json::json!({
673                "operationType": "Delete",
674                "id": id
675            }),
676            BatchOperation::Read { id } => serde_json::json!({
677                "operationType": "Read",
678                "id": id
679            }),
680        }
681    }
682}
683
/// Outcome of a single operation inside a transactional batch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchOperationResult {
    /// Per-operation status code reported by the service.
    pub status_code: u16,
    /// `eTag` of the affected document, when the service returned one.
    pub etag: Option<String>,
    /// `resourceBody` re-serialized as JSON text, when present.
    pub resource_body: Option<String>,
}

impl BatchOperationResult {
    /// Returns `true` for any 2xx per-operation status code.
    pub fn is_success(&self) -> bool {
        (200..300).contains(&self.status_code)
    }
}