use crate::errors;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use chrono::Utc;
use duroxide::providers::ProviderError;
use hmac::{Hmac, Mac};
use reqwest::header::{HeaderMap, HeaderValue};
use sha2::Sha256;
use std::sync::Arc;
/// HMAC-SHA256 keyed with the decoded master key; used to sign the `Authorization` token of
/// each request.
type HmacSha256 = Hmac<Sha256>;
/// Client for a single Cosmos DB container, issuing REST calls over `reqwest`.
///
/// Cloning is cheap: every clone shares the same HTTP client and configuration through an
/// [`Arc`].
#[derive(Clone)]
pub struct CosmosDBClient {
    inner: Arc<CosmosDBClientInner>,
}
/// Immutable state shared by all clones of a [`CosmosDBClient`].
struct CosmosDBClientInner {
    /// Pooled HTTP client used for every request.
    http: reqwest::Client,
    /// Account endpoint URL; the constructor strips any trailing `/`.
    endpoint: String,
    /// Master key decoded from base64; used as the HMAC signing key.
    key_bytes: Vec<u8>,
    /// Database id used in request URLs and signed resource links.
    database: String,
    /// Container (collection) id used in request URLs and signed resource links.
    container: String,
}
/// Raw outcome of a single-document request: HTTP status, `etag` header and body text.
#[derive(Debug)]
pub struct CosmosDBResponse {
    /// HTTP status code returned by the service.
    pub status: u16,
    /// Value of the `etag` response header, when present.
    pub etag: Option<String>,
    /// Response body as text.
    pub body: String,
}

impl CosmosDBResponse {
    /// Returns `true` when [`status`](Self::status) is in the 2xx range.
    pub fn is_success(&self) -> bool {
        matches!(self.status, 200..=299)
    }
}
impl CosmosDBClient {
    /// Value of the `x-ms-version` header sent with every request.
    const API_VERSION: &'static str = "2020-07-15";

    /// Status Cosmos DB reports for operations of an atomic batch that were rolled back only
    /// because another operation in the same batch failed ("Failed Dependency").
    const FAILED_DEPENDENCY: u16 = 424;

    /// Creates a client for `container` in `database` at `endpoint`, signing requests with the
    /// base64-encoded master `key`.
    ///
    /// TLS certificates are verified for every endpoint except loopback hosts (`localhost`,
    /// `127.0.0.1`, `::1`). Those accept self-signed certificates, e.g. a local emulator. Use
    /// [`CosmosDBClient::new_with_options`] to choose explicitly.
    ///
    /// # Errors
    /// Returns a message if `key` is not valid base64 or the HTTP client cannot be built.
    pub fn new(endpoint: &str, key: &str, database: &str, container: &str) -> Result<Self, String> {
        let accept_invalid_certs = Self::is_loopback_endpoint(endpoint);
        Self::new_with_options(endpoint, key, database, container, accept_invalid_certs)
    }

    /// Like [`CosmosDBClient::new`], but lets the caller decide whether invalid TLS
    /// certificates are accepted.
    ///
    /// Only pass `accept_invalid_certs = true` for trusted development endpoints: it disables
    /// certificate validation for every request made by this client.
    ///
    /// # Errors
    /// Returns a message if `key` is not valid base64 or the HTTP client cannot be built.
    pub fn new_with_options(
        endpoint: &str,
        key: &str,
        database: &str,
        container: &str,
        accept_invalid_certs: bool,
    ) -> Result<Self, String> {
        let key_bytes = BASE64
            .decode(key)
            .map_err(|e| format!("Invalid CosmosDB key: {e}"))?;
        let http = reqwest::Client::builder()
            .danger_accept_invalid_certs(accept_invalid_certs)
            .pool_max_idle_per_host(20)
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {e}"))?;
        Ok(Self {
            inner: Arc::new(CosmosDBClientInner {
                http,
                endpoint: endpoint.trim_end_matches('/').to_string(),
                key_bytes,
                database: database.to_string(),
                container: container.to_string(),
            }),
        })
    }

    /// Returns `true` when `endpoint` parses as a URL whose host is `localhost` or a loopback
    /// IP address.
    fn is_loopback_endpoint(endpoint: &str) -> bool {
        let url = match reqwest::Url::parse(endpoint) {
            Ok(url) => url,
            Err(_) => return false,
        };
        match url.host_str() {
            Some(host) if host.eq_ignore_ascii_case("localhost") => true,
            // IPv6 hosts are reported in brackets, e.g. "[::1]".
            Some(host) => host
                .trim_start_matches('[')
                .trim_end_matches(']')
                .parse::<std::net::IpAddr>()
                .map_or(false, |ip| ip.is_loopback()),
            None => false,
        }
    }

    /// Account endpoint URL, without a trailing `/`.
    pub fn endpoint(&self) -> &str {
        &self.inner.endpoint
    }

    /// Database id this client targets.
    pub fn database(&self) -> &str {
        &self.inner.database
    }

    /// Container id this client targets.
    pub fn container(&self) -> &str {
        &self.inner.container
    }

    /// Absolute URL of the container resource.
    fn collection_url(&self) -> String {
        format!(
            "{}/dbs/{}/colls/{}",
            self.inner.endpoint, self.inner.database, self.inner.container
        )
    }

    /// Absolute URL of document `doc_id`; the id is percent-encoded for the URL path.
    fn doc_url(&self, doc_id: &str) -> String {
        format!(
            "{}/dbs/{}/colls/{}/docs/{}",
            self.inner.endpoint,
            self.inner.database,
            self.inner.container,
            urlencoding::encode(doc_id)
        )
    }

    /// Builds the URL-encoded master-key `Authorization` token.
    ///
    /// The signed payload is `verb`, `resource_type`, `resource_link` and `date`, one per line
    /// followed by an empty line. Verb, type and date are lowercased; `resource_link` keeps its
    /// case.
    fn auth_header(
        &self,
        verb: &str,
        resource_type: &str,
        resource_link: &str,
        date: &str,
    ) -> String {
        let payload = format!(
            "{}\n{}\n{}\n{}\n\n",
            verb.to_lowercase(),
            resource_type.to_lowercase(),
            resource_link,
            date.to_lowercase()
        );
        // HMAC accepts keys of any length, so this cannot fail.
        let mut mac =
            HmacSha256::new_from_slice(&self.inner.key_bytes).expect("HMAC key creation failed");
        mac.update(payload.as_bytes());
        let signature = BASE64.encode(mac.finalize().into_bytes());
        let auth = format!("type=master&ver=1.0&sig={signature}");
        urlencoding::encode(&auth).to_string()
    }

    /// Resource link of the container, as used in the signature payload.
    fn resource_link_for_collection(&self) -> String {
        format!("dbs/{}/colls/{}", self.inner.database, self.inner.container)
    }

    /// Resource link of document `doc_id`, as used in the signature payload.
    // NOTE(review): the signed link uses the raw id while `doc_url` percent-encodes it; confirm
    // this matches the Cosmos auth rules for ids containing reserved characters.
    fn resource_link_for_doc(&self, doc_id: &str) -> String {
        format!(
            "dbs/{}/colls/{}/docs/{}",
            self.inner.database, self.inner.container, doc_id
        )
    }

    /// Current UTC time in the RFC 1123 form used for `x-ms-date` and the signature.
    fn rfc1123_now() -> String {
        Utc::now().format("%a, %d %b %Y %H:%M:%S GMT").to_string()
    }

    /// Encodes `pk` as the single-element JSON array expected in the partition-key header.
    ///
    /// `"` and `\` are escaped, and every character outside printable ASCII becomes a
    /// `\uXXXX` escape (UTF-16, surrogate pairs for astral characters). The result is therefore
    /// always valid JSON and always a legal HTTP header value.
    fn partition_key_header(pk: &str) -> String {
        use std::fmt::Write as _;
        let mut out = String::with_capacity(pk.len() + 4);
        out.push_str("[\"");
        for ch in pk.chars() {
            match ch {
                '"' => out.push_str("\\\""),
                '\\' => out.push_str("\\\\"),
                ' '..='~' => out.push(ch),
                _ => {
                    let mut units = [0u16; 2];
                    for unit in ch.encode_utf16(&mut units) {
                        // Writing into a String cannot fail.
                        let _ = write!(out, "\\u{unit:04x}");
                    }
                }
            }
        }
        out.push_str("\"]");
        out
    }

    /// Builds the headers shared by every request.
    ///
    /// These are the date, API version, signed `Authorization` token, a JSON content type, and
    /// (when given) the partition key.
    fn common_headers(
        &self,
        verb: &str,
        resource_type: &str,
        resource_link: &str,
        partition_key: Option<&str>,
    ) -> HeaderMap {
        let date = Self::rfc1123_now();
        let auth = self.auth_header(verb, resource_type, resource_link, &date);
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-ms-date",
            HeaderValue::from_str(&date).expect("RFC 1123 date is printable ASCII"),
        );
        headers.insert("x-ms-version", HeaderValue::from_static(Self::API_VERSION));
        headers.insert(
            "Authorization",
            HeaderValue::from_str(&auth).expect("URL-encoded token is printable ASCII"),
        );
        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
        if let Some(pk) = partition_key {
            let value = HeaderValue::from_str(&Self::partition_key_header(pk))
                .expect("partition key header is escaped to printable ASCII");
            headers.insert("x-ms-documentdb-partitionkey", value);
        }
        headers
    }

    /// Sends `request`, reporting transport failures as retryable errors tagged with `op`.
    async fn send(
        op: &'static str,
        request: reqwest::RequestBuilder,
    ) -> Result<reqwest::Response, ProviderError> {
        request
            .send()
            .await
            .map_err(|e| ProviderError::retryable(op, e.to_string()))
    }

    /// Collects status, `etag` header and body text from `resp`.
    async fn read_response(resp: reqwest::Response) -> CosmosDBResponse {
        let status = resp.status().as_u16();
        let etag = resp
            .headers()
            .get("etag")
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned);
        // Best-effort: status and etag already describe the outcome, so a failed body read
        // yields an empty body instead of hiding a possibly committed write.
        let body = resp.text().await.unwrap_or_default();
        CosmosDBResponse { status, etag, body }
    }

    /// POSTs `body` to `url` to create a resource, treating "already exists" (409) as success.
    async fn create_if_absent(
        &self,
        op: &'static str,
        url: &str,
        resource_type: &str,
        resource_link: &str,
        body: &serde_json::Value,
    ) -> Result<(), ProviderError> {
        let headers = self.common_headers("post", resource_type, resource_link, None);
        let resp = Self::send(op, self.inner.http.post(url).headers(headers).json(body)).await?;
        let status = resp.status().as_u16();
        if matches!(status, 201 | 409) {
            return Ok(());
        }
        let text = resp.text().await.unwrap_or_default();
        Err(errors::map_cosmosdb_error(op, status, &text))
    }

    /// Creates the configured database if it does not exist yet.
    ///
    /// # Errors
    /// Transport failures are retryable. Any status other than 201/409 is mapped by
    /// `errors::map_cosmosdb_error`.
    pub async fn ensure_database(&self) -> Result<(), ProviderError> {
        let url = format!("{}/dbs", self.inner.endpoint);
        let body = serde_json::json!({ "id": self.inner.database });
        // Creating a database is signed against the account root, i.e. an empty resource link.
        self.create_if_absent("ensure_database", &url, "dbs", "", &body)
            .await
    }

    /// Creates the configured container if it does not exist yet.
    ///
    /// The container is hash-partitioned on `/instanceId` (partition key version 2). When
    /// `indexing_policy` is given, it is sent as the container's `indexingPolicy`.
    ///
    /// # Errors
    /// Transport failures are retryable. Any status other than 201/409 is mapped by
    /// `errors::map_cosmosdb_error`.
    pub async fn ensure_container(
        &self,
        indexing_policy: Option<serde_json::Value>,
    ) -> Result<(), ProviderError> {
        let url = format!("{}/dbs/{}/colls", self.inner.endpoint, self.inner.database);
        let resource_link = format!("dbs/{}", self.inner.database);
        let mut body = serde_json::json!({
            "id": self.inner.container,
            "partitionKey": {
                "paths": ["/instanceId"],
                "kind": "Hash",
                "version": 2
            }
        });
        if let Some(policy) = indexing_policy {
            body["indexingPolicy"] = policy;
        }
        self.create_if_absent("ensure_container", &url, "colls", &resource_link, &body)
            .await
    }

    /// Deletes the configured container; a missing container (404) counts as success.
    ///
    /// # Errors
    /// Transport failures are retryable. Any status other than 204/404 is mapped by
    /// `errors::map_cosmosdb_error`.
    pub async fn delete_container(&self) -> Result<(), ProviderError> {
        let url = self.collection_url();
        let resource_link = self.resource_link_for_collection();
        let headers = self.common_headers("delete", "colls", &resource_link, None);
        let resp = Self::send(
            "delete_container",
            self.inner.http.delete(&url).headers(headers),
        )
        .await?;
        let status = resp.status().as_u16();
        if matches!(status, 204 | 404) {
            return Ok(());
        }
        let text = resp.text().await.unwrap_or_default();
        Err(errors::map_cosmosdb_error(
            "delete_container",
            status,
            &text,
        ))
    }

    /// Creates `document` in partition `partition_key`.
    ///
    /// Service-side failures are not errors: inspect the returned status.
    ///
    /// # Errors
    /// Only transport failures, reported as retryable.
    pub async fn create_document(
        &self,
        partition_key: &str,
        document: &serde_json::Value,
    ) -> Result<CosmosDBResponse, ProviderError> {
        let url = format!("{}/docs", self.collection_url());
        let resource_link = self.resource_link_for_collection();
        let headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
        let resp = Self::send(
            "create_document",
            self.inner.http.post(&url).headers(headers).json(document),
        )
        .await?;
        Ok(Self::read_response(resp).await)
    }

    /// Creates or replaces `document` in partition `partition_key`.
    ///
    /// Service-side failures are not errors: inspect the returned status.
    ///
    /// # Errors
    /// Only transport failures, reported as retryable.
    pub async fn upsert_document(
        &self,
        partition_key: &str,
        document: &serde_json::Value,
    ) -> Result<CosmosDBResponse, ProviderError> {
        let url = format!("{}/docs", self.collection_url());
        let resource_link = self.resource_link_for_collection();
        let mut headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
        headers.insert(
            "x-ms-documentdb-is-upsert",
            HeaderValue::from_static("true"),
        );
        let resp = Self::send(
            "upsert_document",
            self.inner.http.post(&url).headers(headers).json(document),
        )
        .await?;
        Ok(Self::read_response(resp).await)
    }

    /// Reads document `doc_id` from partition `partition_key`.
    ///
    /// A missing document shows up as a non-2xx status in the response.
    ///
    /// # Errors
    /// Only transport failures, reported as retryable.
    pub async fn read_document(
        &self,
        doc_id: &str,
        partition_key: &str,
    ) -> Result<CosmosDBResponse, ProviderError> {
        let url = self.doc_url(doc_id);
        let resource_link = self.resource_link_for_doc(doc_id);
        let headers = self.common_headers("get", "docs", &resource_link, Some(partition_key));
        let resp = Self::send("read_document", self.inner.http.get(&url).headers(headers)).await?;
        Ok(Self::read_response(resp).await)
    }

    /// Replaces document `doc_id` with `document`.
    ///
    /// When `if_match` is given, it is sent as the `If-Match` header.
    ///
    /// # Errors
    /// Transport failures are retryable. An `if_match` value that is not a legal header value
    /// is a permanent error. Service-side failures are reported through the returned status.
    pub async fn replace_document(
        &self,
        doc_id: &str,
        partition_key: &str,
        document: &serde_json::Value,
        if_match: Option<&str>,
    ) -> Result<CosmosDBResponse, ProviderError> {
        let url = self.doc_url(doc_id);
        let resource_link = self.resource_link_for_doc(doc_id);
        let mut headers = self.common_headers("put", "docs", &resource_link, Some(partition_key));
        if let Some(etag) = if_match {
            let value = HeaderValue::from_str(etag).map_err(|e| {
                ProviderError::permanent("replace_document", format!("Invalid If-Match value: {e}"))
            })?;
            headers.insert("If-Match", value);
        }
        let resp = Self::send(
            "replace_document",
            self.inner.http.put(&url).headers(headers).json(document),
        )
        .await?;
        Ok(Self::read_response(resp).await)
    }

    /// Deletes document `doc_id` from partition `partition_key`.
    ///
    /// The returned `etag` is always `None`.
    ///
    /// # Errors
    /// Only transport failures, reported as retryable.
    pub async fn delete_document(
        &self,
        doc_id: &str,
        partition_key: &str,
    ) -> Result<CosmosDBResponse, ProviderError> {
        let url = self.doc_url(doc_id);
        let resource_link = self.resource_link_for_doc(doc_id);
        let headers = self.common_headers("delete", "docs", &resource_link, Some(partition_key));
        let resp = Self::send(
            "delete_document",
            self.inner.http.delete(&url).headers(headers),
        )
        .await?;
        Ok(CosmosDBResponse {
            etag: None,
            ..Self::read_response(resp).await
        })
    }

    /// Runs the parameterized SQL query `sql` and returns the documents from every page.
    ///
    /// With `partition_key` the query targets that partition; without it, cross-partition
    /// execution is enabled. Pages are followed via `x-ms-continuation` until none is returned.
    ///
    /// # Errors
    /// Transport or body-read failures are retryable. A non-2xx status is mapped by
    /// `errors::map_cosmosdb_error`. An unparsable page is a permanent error.
    pub async fn query(
        &self,
        sql: &str,
        parameters: Vec<QueryParameter>,
        partition_key: Option<&str>,
    ) -> Result<Vec<serde_json::Value>, ProviderError> {
        let url = format!("{}/docs", self.collection_url());
        let resource_link = self.resource_link_for_collection();
        let query_body = serde_json::json!({
            "query": sql,
            "parameters": parameters.iter().map(|p| {
                serde_json::json!({
                    "name": p.name,
                    "value": p.value
                })
            }).collect::<Vec<_>>()
        });
        let mut all_documents = Vec::new();
        let mut continuation: Option<HeaderValue> = None;
        loop {
            // Sign each page separately so a long pagination never sends a stale x-ms-date.
            let mut headers = self.common_headers("post", "docs", &resource_link, partition_key);
            headers.insert("x-ms-documentdb-isquery", HeaderValue::from_static("true"));
            headers.insert(
                "Content-Type",
                HeaderValue::from_static("application/query+json"),
            );
            if partition_key.is_none() {
                headers.insert(
                    "x-ms-documentdb-query-enablecrosspartition",
                    HeaderValue::from_static("true"),
                );
            }
            if let Some(token) = &continuation {
                headers.insert("x-ms-continuation", token.clone());
            }
            let resp = Self::send(
                "query",
                self.inner.http.post(&url).headers(headers).json(&query_body),
            )
            .await?;
            let status = resp.status().as_u16();
            let next_continuation = resp
                .headers()
                .get("x-ms-continuation")
                .filter(|v| !v.is_empty())
                .cloned();
            if !(200..300).contains(&status) {
                let text = resp.text().await.unwrap_or_default();
                return Err(errors::map_cosmosdb_error("query", status, &text));
            }
            // A failed read of a successful page is a transport problem, not a bad response.
            let body = resp
                .text()
                .await
                .map_err(|e| ProviderError::retryable("query", e.to_string()))?;
            let parsed: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
                ProviderError::permanent("query", format!("Failed to parse query response: {e}"))
            })?;
            if let serde_json::Value::Object(mut page) = parsed {
                if let Some(serde_json::Value::Array(docs)) = page.remove("Documents") {
                    all_documents.extend(docs);
                }
            }
            match next_continuation {
                Some(token) => continuation = Some(token),
                None => break,
            }
        }
        Ok(all_documents)
    }

    /// Executes `operations` as one atomic (all-or-nothing) batch in partition `partition_key`.
    ///
    /// # Errors
    /// Transport or body-read failures are retryable. A status other than 200/207 is mapped by
    /// `errors::map_cosmosdb_error`, and an unparsable response is permanent. If any operation
    /// failed, its status decides the error kind: 409/412 are retryable, all others permanent.
    /// Operations that only report 424 (rolled back because of another failure) are not
    /// chosen as the cause when a more specific failure exists.
    pub async fn transactional_batch(
        &self,
        partition_key: &str,
        operations: Vec<BatchOperation>,
    ) -> Result<Vec<BatchOperationResult>, ProviderError> {
        let url = format!("{}/docs", self.collection_url());
        let resource_link = self.resource_link_for_collection();
        let mut headers = self.common_headers("post", "docs", &resource_link, Some(partition_key));
        headers.insert(
            "x-ms-cosmos-is-batch-request",
            HeaderValue::from_static("true"),
        );
        // Requests all-or-nothing semantics, as the official Cosmos SDKs do for transactional
        // batches; without it, operations before a failing one are not rolled back.
        headers.insert(
            "x-ms-cosmos-batch-atomic",
            HeaderValue::from_static("true"),
        );
        headers.insert(
            "x-ms-cosmos-batch-continue-on-error",
            HeaderValue::from_static("false"),
        );
        let batch_body: Vec<serde_json::Value> =
            operations.iter().map(BatchOperation::to_json).collect();
        let resp = Self::send(
            "transactional_batch",
            self.inner.http.post(&url).headers(headers).json(&batch_body),
        )
        .await?;
        let status = resp.status().as_u16();
        if !matches!(status, 200 | 207) {
            let text = resp.text().await.unwrap_or_default();
            return Err(errors::map_cosmosdb_error(
                "transactional_batch",
                status,
                &text,
            ));
        }
        let body = resp
            .text()
            .await
            .map_err(|e| ProviderError::retryable("transactional_batch", e.to_string()))?;
        let results: Vec<serde_json::Value> = serde_json::from_str(&body).map_err(|e| {
            ProviderError::permanent(
                "transactional_batch",
                format!("Failed to parse batch response: {e}"),
            )
        })?;
        let batch_results: Vec<BatchOperationResult> =
            results.iter().map(Self::parse_batch_result).collect();
        // Prefer the operation that actually failed over siblings that merely report 424.
        let root_cause = batch_results
            .iter()
            .find(|r| r.status_code >= 400 && r.status_code != Self::FAILED_DEPENDENCY)
            .or_else(|| batch_results.iter().find(|r| r.status_code >= 400));
        if let Some(failed) = root_cause {
            let msg = format!("Batch operation failed with status {}", failed.status_code);
            return Err(if matches!(failed.status_code, 409 | 412) {
                ProviderError::retryable("transactional_batch", msg)
            } else {
                ProviderError::permanent("transactional_batch", msg)
            });
        }
        Ok(batch_results)
    }

    /// Extracts one operation's outcome from a batch response entry.
    ///
    /// A missing or out-of-range `statusCode` becomes 0.
    fn parse_batch_result(result: &serde_json::Value) -> BatchOperationResult {
        BatchOperationResult {
            status_code: result
                .get("statusCode")
                .and_then(serde_json::Value::as_u64)
                .and_then(|s| u16::try_from(s).ok())
                .unwrap_or(0),
            etag: result
                .get("eTag")
                .and_then(serde_json::Value::as_str)
                .map(str::to_owned),
            resource_body: result.get("resourceBody").map(|b| b.to_string()),
        }
    }
}
/// A named parameter sent in the `parameters` array of a SQL query request.
#[derive(Debug, Clone)]
pub struct QueryParameter {
    /// Parameter name as referenced in the query text.
    pub name: String,
    /// JSON value bound to the parameter.
    pub value: serde_json::Value,
}

impl QueryParameter {
    /// Builds a parameter from anything convertible into a name and a JSON value.
    pub fn new(name: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        let owned_name: String = name.into();
        let json_value: serde_json::Value = value.into();
        QueryParameter {
            name: owned_name,
            value: json_value,
        }
    }
}
#[derive(Debug, Clone)]
pub enum BatchOperation {
Create {
body: serde_json::Value,
},
Upsert {
body: serde_json::Value,
},
Replace {
id: String,
body: serde_json::Value,
if_match: Option<String>,
},
Delete {
id: String,
},
Read {
id: String,
},
}
impl BatchOperation {
pub fn to_json(&self) -> serde_json::Value {
match self {
BatchOperation::Create { body } => serde_json::json!({
"operationType": "Create",
"resourceBody": body
}),
BatchOperation::Upsert { body } => serde_json::json!({
"operationType": "Upsert",
"resourceBody": body
}),
BatchOperation::Replace { id, body, if_match } => {
let mut op = serde_json::json!({
"operationType": "Replace",
"id": id,
"resourceBody": body
});
if let Some(etag) = if_match {
op["ifMatch"] = serde_json::json!(etag);
}
op
}
BatchOperation::Delete { id } => serde_json::json!({
"operationType": "Delete",
"id": id
}),
BatchOperation::Read { id } => serde_json::json!({
"operationType": "Read",
"id": id
}),
}
}
}
/// Per-operation outcome parsed from a transactional batch response.
#[derive(Debug, Clone)]
pub struct BatchOperationResult {
    /// The entry's `statusCode` (defaults to 0 when absent or not an unsigned integer).
    pub status_code: u16,
    /// The entry's `eTag`, if present as a string.
    pub etag: Option<String>,
    /// The entry's `resourceBody` re-serialized as JSON text, if present.
    pub resource_body: Option<String>,
}