1use crate::errors;
2use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
3use chrono::Utc;
4use duroxide::providers::ProviderError;
5use hmac::{Hmac, Mac};
6use reqwest::header::{HeaderMap, HeaderValue};
7use sha2::Sha256;
8use std::sync::Arc;
9
/// HMAC-SHA256 MAC used to sign the master-key `Authorization` token of each request.
type HmacSha256 = Hmac<Sha256>;
11
12#[derive(Clone)]
15pub struct CosmosDBClient {
16 inner: Arc<CosmosDBClientInner>,
17}
18
19struct CosmosDBClientInner {
20 http: reqwest::Client,
21 endpoint: String,
22 key_bytes: Vec<u8>,
23 database: String,
24 container: String,
25}
26
/// Raw outcome of a single-document request: status code, `etag` header and body.
///
/// Non-2xx statuses are carried here rather than turned into errors, so callers
/// decide how to treat e.g. 404 or 412. Derives `Clone` to match
/// `BatchOperationResult`.
#[derive(Debug, Clone)]
pub struct CosmosDBResponse {
    /// HTTP status code returned by Cosmos DB.
    pub status: u16,
    /// Value of the response `etag` header, if present and valid ASCII.
    pub etag: Option<String>,
    /// Response body text (empty when the body could not be read).
    pub body: String,
}
34
35impl CosmosDBResponse {
36 pub fn is_success(&self) -> bool {
37 self.status >= 200 && self.status < 300
38 }
39}
40
41impl CosmosDBClient {
42 pub fn new(endpoint: &str, key: &str, database: &str, container: &str) -> Result<Self, String> {
43 let key_bytes = BASE64
44 .decode(key)
45 .map_err(|e| format!("Invalid CosmosDB key: {e}"))?;
46 let http = reqwest::Client::builder()
47 .danger_accept_invalid_certs(true) .pool_max_idle_per_host(20)
49 .build()
50 .map_err(|e| format!("Failed to create HTTP client: {e}"))?;
51
52 Ok(Self {
53 inner: Arc::new(CosmosDBClientInner {
54 http,
55 endpoint: endpoint.trim_end_matches('/').to_string(),
56 key_bytes,
57 database: database.to_string(),
58 container: container.to_string(),
59 }),
60 })
61 }
62
63 pub fn endpoint(&self) -> &str {
64 &self.inner.endpoint
65 }
66
67 pub fn database(&self) -> &str {
68 &self.inner.database
69 }
70
71 pub fn container(&self) -> &str {
72 &self.inner.container
73 }
74
75 fn collection_url(&self) -> String {
76 format!(
77 "{}/dbs/{}/colls/{}",
78 self.inner.endpoint, self.inner.database, self.inner.container
79 )
80 }
81
82 fn doc_url(&self, doc_id: &str) -> String {
83 format!(
84 "{}/dbs/{}/colls/{}/docs/{}",
85 self.inner.endpoint,
86 self.inner.database,
87 self.inner.container,
88 urlencoding::encode(doc_id)
89 )
90 }
91
92 fn auth_header(
94 &self,
95 verb: &str,
96 resource_type: &str,
97 resource_link: &str,
98 date: &str,
99 ) -> String {
100 let payload = format!(
101 "{}\n{}\n{}\n{}\n\n",
102 verb.to_lowercase(),
103 resource_type.to_lowercase(),
104 resource_link,
105 date.to_lowercase()
106 );
107
108 let mut mac =
109 HmacSha256::new_from_slice(&self.inner.key_bytes).expect("HMAC key creation failed");
110 mac.update(payload.as_bytes());
111 let signature = BASE64.encode(mac.finalize().into_bytes());
112
113 let auth = format!("type=master&ver=1.0&sig={signature}");
114 urlencoding::encode(&auth).to_string()
115 }
116
117 fn resource_link_for_collection(&self) -> String {
118 format!("dbs/{}/colls/{}", self.inner.database, self.inner.container)
119 }
120
121 fn resource_link_for_doc(&self, doc_id: &str) -> String {
122 format!(
123 "dbs/{}/colls/{}/docs/{}",
124 self.inner.database, self.inner.container, doc_id
125 )
126 }
127
128 fn common_headers(
129 &self,
130 verb: &str,
131 resource_type: &str,
132 resource_link: &str,
133 partition_key: Option<&str>,
134 ) -> HeaderMap {
135 let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
136 let auth = self.auth_header(verb, resource_type, resource_link, &date);
137
138 let mut headers = HeaderMap::new();
139 headers.insert("x-ms-date", HeaderValue::from_str(&date).unwrap());
140 headers.insert("x-ms-version", HeaderValue::from_static("2020-07-15"));
141 headers.insert("Authorization", HeaderValue::from_str(&auth).unwrap());
142 headers.insert("Content-Type", HeaderValue::from_static("application/json"));
143
144 if let Some(pk) = partition_key {
145 let pk_header = format!("[\"{pk}\"]");
146 headers.insert(
147 "x-ms-documentdb-partitionkey",
148 HeaderValue::from_str(&pk_header).unwrap(),
149 );
150 }
151
152 headers
153 }
154
155 pub async fn ensure_database(&self) -> Result<(), ProviderError> {
161 let url = format!("{}/dbs", self.inner.endpoint);
162 let resource_link = "";
163 let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
164 let auth = self.auth_header("post", "dbs", resource_link, &date);
165
166 let body = serde_json::json!({ "id": self.inner.database });
167
168 let resp = self
169 .inner
170 .http
171 .post(&url)
172 .header("x-ms-date", &date)
173 .header("x-ms-version", "2020-07-15")
174 .header("Authorization", &auth)
175 .header("Content-Type", "application/json")
176 .json(&body)
177 .send()
178 .await
179 .map_err(|e| ProviderError::retryable("ensure_database", e.to_string()))?;
180
181 let status = resp.status().as_u16();
182 if status == 201 || status == 409 {
183 Ok(())
185 } else {
186 let text = resp.text().await.unwrap_or_default();
187 Err(errors::map_cosmosdb_error("ensure_database", status, &text))
188 }
189 }
190
191 pub async fn ensure_container(
193 &self,
194 indexing_policy: Option<serde_json::Value>,
195 ) -> Result<(), ProviderError> {
196 let url = format!("{}/dbs/{}/colls", self.inner.endpoint, self.inner.database);
197 let resource_link = format!("dbs/{}", self.inner.database);
198 let date = Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string();
199 let auth = self.auth_header("post", "colls", &resource_link, &date);
200
201 let mut body = serde_json::json!({
202 "id": self.inner.container,
203 "partitionKey": {
204 "paths": ["/instanceId"],
205 "kind": "Hash",
206 "version": 2
207 }
208 });
209
210 if let Some(policy) = indexing_policy {
211 body["indexingPolicy"] = policy;
212 }
213
214 let resp = self
215 .inner
216 .http
217 .post(&url)
218 .header("x-ms-date", &date)
219 .header("x-ms-version", "2020-07-15")
220 .header("Authorization", &auth)
221 .header("Content-Type", "application/json")
222 .json(&body)
223 .send()
224 .await
225 .map_err(|e| ProviderError::retryable("ensure_container", e.to_string()))?;
226
227 let status = resp.status().as_u16();
228 if status == 201 || status == 409 {
229 Ok(())
230 } else {
231 let text = resp.text().await.unwrap_or_default();
232 Err(errors::map_cosmosdb_error(
233 "ensure_container",
234 status,
235 &text,
236 ))
237 }
238 }
239
240 pub async fn delete_container(&self) -> Result<(), ProviderError> {
242 let url = self.collection_url();
243 let resource_link = self.resource_link_for_collection();
244 let headers = self.common_headers("delete", "colls", &resource_link, None);
245
246 let resp = self
247 .inner
248 .http
249 .delete(&url)
250 .headers(headers)
251 .send()
252 .await
253 .map_err(|e| ProviderError::retryable("delete_container", e.to_string()))?;
254
255 let status = resp.status().as_u16();
256 if status == 204 || status == 404 {
257 Ok(())
258 } else {
259 let text = resp.text().await.unwrap_or_default();
260 Err(errors::map_cosmosdb_error(
261 "delete_container",
262 status,
263 &text,
264 ))
265 }
266 }
267
268 pub async fn create_document(
274 &self,
275 partition_key: &str,
276 document: &serde_json::Value,
277 ) -> Result<CosmosDBResponse, ProviderError> {
278 let url = format!("{}/docs", self.collection_url());
279 let resource_link = self.resource_link_for_collection();
280 let headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
281
282 let resp = self
283 .inner
284 .http
285 .post(&url)
286 .headers(headers)
287 .json(document)
288 .send()
289 .await
290 .map_err(|e| ProviderError::retryable("create_document", e.to_string()))?;
291
292 let status = resp.status().as_u16();
293 let etag = resp
294 .headers()
295 .get("etag")
296 .and_then(|v| v.to_str().ok())
297 .map(|s| s.to_string());
298 let body = resp.text().await.unwrap_or_default();
299
300 Ok(CosmosDBResponse { status, etag, body })
301 }
302
303 pub async fn upsert_document(
305 &self,
306 partition_key: &str,
307 document: &serde_json::Value,
308 ) -> Result<CosmosDBResponse, ProviderError> {
309 let url = format!("{}/docs", self.collection_url());
310 let resource_link = self.resource_link_for_collection();
311 let mut headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
312 headers.insert(
313 "x-ms-documentdb-is-upsert",
314 HeaderValue::from_static("true"),
315 );
316
317 let resp = self
318 .inner
319 .http
320 .post(&url)
321 .headers(headers)
322 .json(document)
323 .send()
324 .await
325 .map_err(|e| ProviderError::retryable("upsert_document", e.to_string()))?;
326
327 let status = resp.status().as_u16();
328 let etag = resp
329 .headers()
330 .get("etag")
331 .and_then(|v| v.to_str().ok())
332 .map(|s| s.to_string());
333 let body = resp.text().await.unwrap_or_default();
334
335 Ok(CosmosDBResponse { status, etag, body })
336 }
337
338 pub async fn read_document(
340 &self,
341 doc_id: &str,
342 partition_key: &str,
343 ) -> Result<CosmosDBResponse, ProviderError> {
344 let url = self.doc_url(doc_id);
345 let resource_link = self.resource_link_for_doc(doc_id);
346 let headers = self.common_headers("get", "docs", &resource_link, Some(partition_key));
347
348 let resp = self
349 .inner
350 .http
351 .get(&url)
352 .headers(headers)
353 .send()
354 .await
355 .map_err(|e| ProviderError::retryable("read_document", e.to_string()))?;
356
357 let status = resp.status().as_u16();
358 let etag = resp
359 .headers()
360 .get("etag")
361 .and_then(|v| v.to_str().ok())
362 .map(|s| s.to_string());
363 let body = resp.text().await.unwrap_or_default();
364
365 Ok(CosmosDBResponse { status, etag, body })
366 }
367
368 pub async fn replace_document(
370 &self,
371 doc_id: &str,
372 partition_key: &str,
373 document: &serde_json::Value,
374 if_match: Option<&str>,
375 ) -> Result<CosmosDBResponse, ProviderError> {
376 let url = self.doc_url(doc_id);
377 let resource_link = self.resource_link_for_doc(doc_id);
378 let mut headers = self.common_headers("put", "docs", &resource_link, Some(partition_key));
379
380 if let Some(etag) = if_match {
381 headers.insert("If-Match", HeaderValue::from_str(etag).unwrap());
382 }
383
384 let resp = self
385 .inner
386 .http
387 .put(&url)
388 .headers(headers)
389 .json(document)
390 .send()
391 .await
392 .map_err(|e| ProviderError::retryable("replace_document", e.to_string()))?;
393
394 let status = resp.status().as_u16();
395 let etag = resp
396 .headers()
397 .get("etag")
398 .and_then(|v| v.to_str().ok())
399 .map(|s| s.to_string());
400 let body = resp.text().await.unwrap_or_default();
401
402 Ok(CosmosDBResponse { status, etag, body })
403 }
404
405 pub async fn delete_document(
407 &self,
408 doc_id: &str,
409 partition_key: &str,
410 ) -> Result<CosmosDBResponse, ProviderError> {
411 let url = self.doc_url(doc_id);
412 let resource_link = self.resource_link_for_doc(doc_id);
413 let headers = self.common_headers("delete", "docs", &resource_link, Some(partition_key));
414
415 let resp = self
416 .inner
417 .http
418 .delete(&url)
419 .headers(headers)
420 .send()
421 .await
422 .map_err(|e| ProviderError::retryable("delete_document", e.to_string()))?;
423
424 let status = resp.status().as_u16();
425 let body = resp.text().await.unwrap_or_default();
426
427 Ok(CosmosDBResponse {
428 status,
429 etag: None,
430 body,
431 })
432 }
433
434 pub async fn query(
440 &self,
441 sql: &str,
442 parameters: Vec<QueryParameter>,
443 partition_key: Option<&str>,
444 ) -> Result<Vec<serde_json::Value>, ProviderError> {
445 let url = format!("{}/docs", self.collection_url());
446 let resource_link = self.resource_link_for_collection();
447 let mut headers = self.common_headers("post", "docs", &resource_link, partition_key);
448 headers.insert("x-ms-documentdb-isquery", HeaderValue::from_static("true"));
449 headers.insert(
450 "Content-Type",
451 HeaderValue::from_static("application/query+json"),
452 );
453 if partition_key.is_none() {
454 headers.insert(
455 "x-ms-documentdb-query-enablecrosspartition",
456 HeaderValue::from_static("true"),
457 );
458 }
459
460 let query_body = serde_json::json!({
461 "query": sql,
462 "parameters": parameters.iter().map(|p| {
463 serde_json::json!({
464 "name": p.name,
465 "value": p.value
466 })
467 }).collect::<Vec<_>>()
468 });
469
470 let mut all_documents = Vec::new();
471 let mut continuation: Option<String> = None;
472
473 loop {
474 let mut req_headers = headers.clone();
475 if let Some(ref token) = continuation {
476 req_headers.insert("x-ms-continuation", HeaderValue::from_str(token).unwrap());
477 }
478
479 let resp = self
480 .inner
481 .http
482 .post(&url)
483 .headers(req_headers)
484 .json(&query_body)
485 .send()
486 .await
487 .map_err(|e| ProviderError::retryable("query", e.to_string()))?;
488
489 let status = resp.status().as_u16();
490 let next_continuation = resp
491 .headers()
492 .get("x-ms-continuation")
493 .and_then(|v| v.to_str().ok())
494 .map(|s| s.to_string());
495
496 let body = resp.text().await.unwrap_or_default();
497
498 if !((200..300).contains(&status)) {
499 return Err(errors::map_cosmosdb_error("query", status, &body));
500 }
501
502 let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
503 ProviderError::permanent("query", format!("Failed to parse query response: {e}"))
504 })?;
505
506 if let Some(docs) = parsed.get("Documents").and_then(|d| d.as_array()) {
507 all_documents.extend(docs.iter().cloned());
508 }
509
510 match next_continuation {
511 Some(token) if !token.is_empty() => {
512 continuation = Some(token);
513 }
514 _ => break,
515 }
516 }
517
518 Ok(all_documents)
519 }
520
521 pub async fn transactional_batch(
527 &self,
528 partition_key: &str,
529 operations: Vec<BatchOperation>,
530 ) -> Result<Vec<BatchOperationResult>, ProviderError> {
531 let url = format!("{}/docs", self.collection_url());
532 let resource_link = self.resource_link_for_collection();
533 let mut headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
534 headers.insert(
535 "x-ms-cosmos-is-batch-request",
536 HeaderValue::from_static("true"),
537 );
538 headers.insert(
540 "x-ms-cosmos-batch-continue-on-error",
541 HeaderValue::from_static("false"),
542 );
543
544 let batch_body: Vec<serde_json::Value> = operations.iter().map(|op| op.to_json()).collect();
545
546 let resp = self
547 .inner
548 .http
549 .post(&url)
550 .headers(headers)
551 .json(&batch_body)
552 .send()
553 .await
554 .map_err(|e| ProviderError::retryable("transactional_batch", e.to_string()))?;
555
556 let status = resp.status().as_u16();
557 let body = resp.text().await.unwrap_or_default();
558
559 if status == 200 || status == 207 {
560 let results: Vec<serde_json::Value> = serde_json::from_str(&body).map_err(|e| {
562 ProviderError::permanent(
563 "transactional_batch",
564 format!("Failed to parse batch response: {e}"),
565 )
566 })?;
567
568 let batch_results: Vec<BatchOperationResult> = results
569 .into_iter()
570 .map(|r| {
571 let op_status =
572 r.get("statusCode").and_then(|s| s.as_u64()).unwrap_or(0) as u16;
573 let etag = r
574 .get("eTag")
575 .and_then(|e| e.as_str())
576 .map(|s| s.to_string());
577 let resource_body = r.get("resourceBody").map(|b| b.to_string());
578 BatchOperationResult {
579 status_code: op_status,
580 etag,
581 resource_body,
582 }
583 })
584 .collect();
585
586 if let Some(failed) = batch_results.iter().find(|r| r.status_code >= 400) {
588 let msg = format!("Batch operation failed with status {}", failed.status_code);
589 if failed.status_code == 409 || failed.status_code == 412 {
590 return Err(ProviderError::retryable("transactional_batch", msg));
591 }
592 return Err(ProviderError::permanent("transactional_batch", msg));
593 }
594
595 Ok(batch_results)
596 } else {
597 Err(errors::map_cosmosdb_error(
598 "transactional_batch",
599 status,
600 &body,
601 ))
602 }
603 }
604}
605
606#[derive(Debug, Clone)]
611pub struct QueryParameter {
612 pub name: String,
613 pub value: serde_json::Value,
614}
615
616impl QueryParameter {
617 pub fn new(name: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
618 Self {
619 name: name.into(),
620 value: value.into(),
621 }
622 }
623}
624
625#[derive(Debug, Clone)]
630pub enum BatchOperation {
631 Create {
632 body: serde_json::Value,
633 },
634 Upsert {
635 body: serde_json::Value,
636 },
637 Replace {
638 id: String,
639 body: serde_json::Value,
640 if_match: Option<String>,
641 },
642 Delete {
643 id: String,
644 },
645 Read {
646 id: String,
647 },
648}
649
650impl BatchOperation {
651 pub fn to_json(&self) -> serde_json::Value {
652 match self {
653 BatchOperation::Create { body } => serde_json::json!({
654 "operationType": "Create",
655 "resourceBody": body
656 }),
657 BatchOperation::Upsert { body } => serde_json::json!({
658 "operationType": "Upsert",
659 "resourceBody": body
660 }),
661 BatchOperation::Replace { id, body, if_match } => {
662 let mut op = serde_json::json!({
663 "operationType": "Replace",
664 "id": id,
665 "resourceBody": body
666 });
667 if let Some(etag) = if_match {
668 op["ifMatch"] = serde_json::json!(etag);
669 }
670 op
671 }
672 BatchOperation::Delete { id } => serde_json::json!({
673 "operationType": "Delete",
674 "id": id
675 }),
676 BatchOperation::Read { id } => serde_json::json!({
677 "operationType": "Read",
678 "id": id
679 }),
680 }
681 }
682}
683
/// Outcome of one operation in a transactional batch response.
#[derive(Debug, Clone)]
pub struct BatchOperationResult {
    /// Per-operation status code (`0` when the response did not provide one).
    pub status_code: u16,
    /// Resulting document etag, if the response included one.
    pub etag: Option<String>,
    /// Returned document, re-serialized as JSON text, if present.
    pub resource_body: Option<String>,
}