use std::sync::Arc;
use async_trait::async_trait;
use azure_data_cosmos::prelude::{
AuthorizationToken, CollectionClient, CosmosClient as AzCosmosClient, CosmosClientBuilder,
GetDocumentResponse, Param, Query, QueryCrossPartition,
};
use azure_identity::DefaultAzureCredentialBuilder;
use futures::StreamExt as _;
use serde_json::Value;
use crate::cosmos::{
CosmosBackend, CosmosError, QueryStream, document::CosmosDocument,
query_stream::QueryStreamInner,
};
/// Azure Cosmos DB backend bound to a single database.
///
/// Holds the SDK client for one Cosmos account together with the name of the
/// database that every container lookup is resolved against.
#[derive(Clone)]
pub struct CosmosClient {
    /// SDK client for the Cosmos account.
    inner: AzCosmosClient,
    /// Database that all container operations target.
    database_name: String,
}

impl CosmosClient {
    /// Builds a client for the account at `account_endpoint`, scoped to
    /// `database_name`.
    ///
    /// Authentication uses `DefaultAzureCredential`. Trailing `/` characters
    /// are stripped from the endpoint before it is passed to the SDK as a
    /// custom cloud location.
    ///
    /// # Errors
    ///
    /// Returns [`CosmosError::Backend`] if the default credential cannot be
    /// constructed.
    pub async fn new(account_endpoint: &str, database_name: &str) -> Result<Self, CosmosError> {
        let credential = match DefaultAzureCredentialBuilder::default().build() {
            Ok(credential) => credential,
            Err(e) => {
                return Err(CosmosError::Backend(format!(
                    "failed to build DefaultAzureCredential: {e}"
                )));
            }
        };

        let uri = account_endpoint.trim_end_matches('/').to_owned();
        let auth_token = AuthorizationToken::from_token_credential(Arc::new(credential));
        let location = azure_data_cosmos::prelude::CloudLocation::Custom { uri, auth_token };

        Ok(Self {
            inner: CosmosClientBuilder::with_location(location).build(),
            database_name: database_name.to_owned(),
        })
    }

    /// Returns an SDK handle for `container` inside this client's database.
    fn collection(&self, container: &str) -> CollectionClient {
        let database = self.inner.database_client(self.database_name.clone());
        database.collection_client(container.to_owned())
    }
}
#[async_trait]
impl CosmosBackend for CosmosClient {
    /// Writes `doc` into `container`, replacing any existing document with the
    /// same `id` in the same partition (the request is sent with
    /// `is_upsert(true)`).
    ///
    /// # Errors
    ///
    /// Propagates the error from [`CosmosDocument::extract_id`] or
    /// [`CosmosDocument::extract_partition_key`] when `doc` lacks either field.
    /// Returns [`CosmosError::Backend`] when the SDK rejects or fails the request.
    async fn upsert(&self, container: &str, doc: Value) -> Result<(), CosmosError> {
        // The id value itself is not needed here. The call only rejects
        // documents without one before any request is made.
        CosmosDocument::extract_id(&doc)?;
        // Owned copy, because `doc` is moved into the request builder below.
        let pk = CosmosDocument::extract_partition_key(&doc)?.to_string();
        let collection = self.collection(container);
        collection
            .create_document(doc)
            .is_upsert(true)
            .partition_key(&pk)
            .map_err(sdk_err)?
            .into_future()
            .await
            .map_err(sdk_err)?;
        Ok(())
    }

    /// Point-reads the document `id` stored under `partition_key` in `container`.
    ///
    /// Returns `Ok(None)` when the SDK reports the document as not found.
    ///
    /// # Errors
    ///
    /// Returns [`CosmosError::Backend`] if the document client cannot be built
    /// or the read request fails.
    async fn get(
        &self,
        container: &str,
        id: &str,
        partition_key: &str,
    ) -> Result<Option<Value>, CosmosError> {
        let collection = self.collection(container);
        let document_client = collection
            .document_client(id, &partition_key)
            .map_err(sdk_err)?;
        let response: GetDocumentResponse<Value> = document_client
            .get_document()
            .into_future()
            .await
            .map_err(sdk_err)?;
        match response {
            GetDocumentResponse::Found(found) => Ok(Some(found.document.document)),
            GetDocumentResponse::NotFound(_) => Ok(None),
        }
    }

    /// Runs the SQL query `sql` against `container` across all partitions and
    /// returns a paged result stream.
    ///
    /// `params` are `(name, value)` pairs bound as SQL parameters (for example
    /// `("@id", json!("..."))` for `WHERE c.id = @id`). Binding values this way
    /// keeps them out of the query text. No page is fetched here. Pages are
    /// pulled through the returned [`QueryStream`].
    ///
    /// # Errors
    ///
    /// This implementation does not fail while building the stream. Errors
    /// surface when pages are fetched.
    async fn query(
        &self,
        container: &str,
        sql: &str,
        params: Vec<(String, Value)>,
    ) -> Result<QueryStream, CosmosError> {
        let sdk_params: Vec<Param> = params
            .into_iter()
            .map(|(name, value)| Param::new(name, value))
            .collect();
        let query = Query::with_params(sql.to_string(), sdk_params);
        let collection = self.collection(container);
        let pageable = collection
            .query_documents(query)
            .query_cross_partition(QueryCrossPartition::Yes)
            .into_stream::<Value>();
        Ok(QueryStream::new(Box::new(SdkQueryStream {
            pageable,
            continuation: None,
        })))
    }
}
/// Paged stream type produced by `query_documents(..).into_stream::<Value>()`.
type SdkPageable = azure_data_cosmos::prelude::QueryDocuments<Value>;
/// [`QueryStreamInner`] adapter over the SDK's paged query stream.
///
/// Records the continuation token reported with the most recently fetched
/// page so it can be read back through `continuation_token()`.
struct SdkQueryStream {
    /// Underlying SDK page stream.
    pageable: SdkPageable,
    /// Continuation token from the last fetched page. It is `None` before the
    /// first page, after the stream is exhausted, or when the last page
    /// carried no token.
    continuation: Option<String>,
}
// SAFETY: NOTE(review): no justification for this manual `Send` is visible in
// this file. It is presumably here because the `#[async_trait]`-generated
// future for `next_page(&mut self)` must be `Send`, which requires
// `SdkQueryStream: Send`. If the SDK's pageable stream is already `Send` on
// every supported target, this impl is redundant. It should then be removed so
// the compiler checks the property instead of it being asserted. Confirm
// against the azure_core `Pageable` definition.
unsafe impl Send for SdkQueryStream {}
#[async_trait]
impl QueryStreamInner for SdkQueryStream {
async fn next_page(&mut self) -> Result<Option<Vec<Value>>, CosmosError> {
match self.pageable.next().await {
None => {
self.continuation = None;
Ok(None)
}
Some(Err(e)) => Err(sdk_err(e)),
Some(Ok(response)) => {
use azure_core::headers::Header;
self.continuation = response
.continuation_token
.as_ref()
.map(|c| c.value().as_str().to_owned());
let docs: Vec<Value> = response
.results
.into_iter()
.map(|(doc, _attrs)| doc)
.collect();
Ok(Some(docs))
}
}
}
fn continuation_token(&self) -> Option<&str> {
self.continuation.as_deref()
}
}
fn sdk_err(e: impl std::fmt::Display) -> CosmosError {
CosmosError::Backend(e.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn extracts_id_and_partition_key() {
        let doc = json!({
            "id": "prod::my-jira::DO-1",
            "_partition_key": "prod",
            "title": "Fix the bug"
        });
        let id = CosmosDocument::extract_id(&doc).unwrap();
        let pk = CosmosDocument::extract_partition_key(&doc).unwrap();
        assert_eq!(id, "prod::my-jira::DO-1");
        assert_eq!(pk, "prod");
    }

    #[test]
    fn extract_id_missing_returns_validation_error() {
        let doc = json!({ "_partition_key": "prod" });
        let result = CosmosDocument::extract_id(&doc);
        assert!(matches!(result, Err(CosmosError::Validation(_))));
    }

    #[test]
    fn extract_partition_key_missing_returns_validation_error() {
        let doc = json!({ "id": "some-id" });
        let result = CosmosDocument::extract_partition_key(&doc);
        assert!(matches!(result, Err(CosmosError::Validation(_))));
    }

    #[test]
    fn sdk_err_wraps_message_in_backend_variant() {
        let err = sdk_err("connection refused");
        assert!(matches!(err, CosmosError::Backend(_)));
        let CosmosError::Backend(msg) = err else {
            panic!("wrong variant");
        };
        assert!(msg.contains("connection refused"));
    }

    #[test]
    fn sdk_err_wraps_any_display_type() {
        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timed out connecting");
        let cosmos_err = sdk_err(io_err);
        assert!(matches!(cosmos_err, CosmosError::Backend(_)));
        let CosmosError::Backend(msg) = cosmos_err else {
            panic!("wrong variant");
        };
        assert!(msg.contains("timed out connecting"));
    }

    /// End-to-end upsert -> get -> query -> overwrite round trip against a real
    /// Cosmos account (database `quelch-test`, container `e2e-test`).
    #[tokio::test]
    #[ignore = "requires Azure Cosmos; set QUELCH_COSMOS_E2E_ENDPOINT to enable"]
    async fn e2e_upsert_get_query_round_trip() {
        // Even when run explicitly with `--ignored`, skip if no endpoint is
        // configured.
        let Ok(endpoint) = std::env::var("QUELCH_COSMOS_E2E_ENDPOINT") else {
            eprintln!("skipping e2e test: QUELCH_COSMOS_E2E_ENDPOINT is not set");
            return;
        };
        let client = CosmosClient::new(&endpoint, "quelch-test")
            .await
            .expect("CosmosClient::new should succeed with valid credentials");
        let container = "e2e-test";

        let doc = json!({
            "id": "e2e-test-doc-1",
            "_partition_key": "e2e",
            "title": "E2E round-trip document",
            "value": 42
        });
        client
            .upsert(container, doc)
            .await
            .expect("upsert should succeed");

        let fetched = client
            .get(container, "e2e-test-doc-1", "e2e")
            .await
            .expect("get should succeed")
            .expect("document should exist after upsert");
        assert_eq!(fetched["title"], "E2E round-trip document");
        assert_eq!(fetched["value"], 42);

        let mut stream = client
            .query(
                container,
                "SELECT * FROM c WHERE c.id = @id",
                vec![("@id".into(), json!("e2e-test-doc-1"))],
            )
            .await
            .expect("query should succeed");
        // Cosmos may spread results over several pages, and a page can be empty
        // even when later pages hold matches. Drain the whole stream instead of
        // asserting on the first page alone.
        let mut results = Vec::new();
        while let Some(page) = stream
            .next_page()
            .await
            .expect("next_page should succeed")
        {
            results.extend(page);
        }
        assert_eq!(results.len(), 1);
        assert_eq!(results[0]["id"], "e2e-test-doc-1");
        assert_eq!(results[0]["value"], 42);

        let missing = client
            .get(container, "definitely-does-not-exist", "e2e")
            .await
            .expect("get of missing doc should return Ok(None)");
        assert!(missing.is_none());

        let updated = json!({
            "id": "e2e-test-doc-1",
            "_partition_key": "e2e",
            "title": "Updated document",
            "value": 99
        });
        client
            .upsert(container, updated)
            .await
            .expect("upsert overwrite should succeed");
        let refetched = client
            .get(container, "e2e-test-doc-1", "e2e")
            .await
            .expect("get after overwrite should succeed")
            .expect("document should still exist");
        assert_eq!(refetched["value"], 99);
    }
}