Skip to main content

cosq_client/
cosmos.rs

1//! Cosmos DB data plane client
2//!
3//! Executes SQL queries against Cosmos DB containers using the REST API
4//! with AAD token authentication. Handles cross-partition queries by
5//! fetching partition key ranges and fanning out the query.
6
7use serde::Deserialize;
8use serde_json::Value;
9use tracing::debug;
10
11use crate::auth::{AzCliAuth, COSMOS_RESOURCE};
12use crate::error::ClientError;
13
14const API_VERSION: &str = "2018-12-31";
15
16/// Result of a Cosmos DB SQL query
17#[derive(Debug)]
18pub struct QueryResult {
19    pub documents: Vec<Value>,
20    pub request_charge: f64,
21}
22
23/// Cosmos DB REST API response for queries
24#[derive(Debug, Deserialize)]
25struct QueryResponse {
26    #[serde(rename = "Documents")]
27    documents: Vec<Value>,
28}
29
30/// Cosmos DB REST API response for listing databases
31#[derive(Debug, Deserialize)]
32struct DatabaseListResponse {
33    #[serde(rename = "Databases")]
34    databases: Vec<DatabaseEntry>,
35}
36
37#[derive(Debug, Deserialize)]
38struct DatabaseEntry {
39    id: String,
40}
41
42/// Cosmos DB REST API response for listing collections
43#[derive(Debug, Deserialize)]
44struct CollectionListResponse {
45    #[serde(rename = "DocumentCollections")]
46    document_collections: Vec<CollectionEntry>,
47}
48
49#[derive(Debug, Deserialize)]
50struct CollectionEntry {
51    id: String,
52}
53
54/// Partition key range info from the pkranges endpoint
55#[derive(Debug, Deserialize)]
56struct PartitionKeyRangesResponse {
57    #[serde(rename = "PartitionKeyRanges")]
58    partition_key_ranges: Vec<PartitionKeyRange>,
59}
60
61#[derive(Debug, Deserialize)]
62struct PartitionKeyRange {
63    id: String,
64}
65
66/// Client for the Cosmos DB data plane REST API.
67#[derive(Clone)]
68pub struct CosmosClient {
69    http: reqwest::Client,
70    endpoint: String,
71    token: String,
72}
73
74impl CosmosClient {
75    /// Create a new Cosmos client, acquiring a Cosmos DB token via the Azure CLI.
76    pub async fn new(endpoint: &str) -> Result<Self, ClientError> {
77        let token = AzCliAuth::get_token(COSMOS_RESOURCE).await?;
78        let endpoint = endpoint.trim_end_matches('/').to_string();
79        Ok(Self {
80            http: reqwest::Client::new(),
81            endpoint,
82            token,
83        })
84    }
85
86    /// Build the Authorization header value for AAD token auth.
87    fn auth_header(&self) -> String {
88        let sig = urlencoding::encode(&self.token);
89        format!("type%3Daad%26ver%3D1.0%26sig%3D{sig}")
90    }
91
92    /// Build the x-ms-date header value in RFC 1123 format.
93    fn date_header() -> String {
94        chrono::Utc::now()
95            .format("%a, %d %b %Y %H:%M:%S GMT")
96            .to_string()
97    }
98
99    /// List all databases in the Cosmos DB account.
100    pub async fn list_databases(&self) -> Result<Vec<String>, ClientError> {
101        debug!("listing databases");
102        let url = format!("{}/dbs", self.endpoint);
103        let date = Self::date_header();
104
105        let resp = self
106            .http
107            .get(&url)
108            .header("Authorization", self.auth_header())
109            .header("x-ms-date", &date)
110            .header("x-ms-version", API_VERSION)
111            .send()
112            .await?;
113
114        let status = resp.status();
115        if !status.is_success() {
116            let body = resp.text().await.unwrap_or_default();
117            if status.as_u16() == 403 {
118                return Err(ClientError::forbidden(
119                    body,
120                    "You may not have data plane access. Check your Cosmos DB RBAC roles.",
121                ));
122            }
123            return Err(ClientError::api(status.as_u16(), body));
124        }
125
126        let list: DatabaseListResponse = resp.json().await?;
127        let names: Vec<String> = list.databases.into_iter().map(|d| d.id).collect();
128        debug!(count = names.len(), "found databases");
129        Ok(names)
130    }
131
132    /// List all containers in a database.
133    pub async fn list_containers(&self, database: &str) -> Result<Vec<String>, ClientError> {
134        debug!(database, "listing containers");
135        let url = format!("{}/dbs/{}/colls", self.endpoint, database);
136        let date = Self::date_header();
137
138        let resp = self
139            .http
140            .get(&url)
141            .header("Authorization", self.auth_header())
142            .header("x-ms-date", &date)
143            .header("x-ms-version", API_VERSION)
144            .send()
145            .await?;
146
147        let status = resp.status();
148        if !status.is_success() {
149            let body = resp.text().await.unwrap_or_default();
150            return Err(ClientError::api(status.as_u16(), body));
151        }
152
153        let list: CollectionListResponse = resp.json().await?;
154        let names: Vec<String> = list
155            .document_collections
156            .into_iter()
157            .map(|c| c.id)
158            .collect();
159        debug!(count = names.len(), "found containers");
160        Ok(names)
161    }
162
163    /// Get partition key ranges for a container.
164    async fn get_partition_key_ranges(
165        &self,
166        database: &str,
167        container: &str,
168    ) -> Result<Vec<String>, ClientError> {
169        let url = format!(
170            "{}/dbs/{}/colls/{}/pkranges",
171            self.endpoint, database, container
172        );
173        let date = Self::date_header();
174
175        let resp = self
176            .http
177            .get(&url)
178            .header("Authorization", self.auth_header())
179            .header("x-ms-date", &date)
180            .header("x-ms-version", API_VERSION)
181            .send()
182            .await?;
183
184        let status = resp.status();
185        if !status.is_success() {
186            let body = resp.text().await.unwrap_or_default();
187            return Err(ClientError::api(status.as_u16(), body));
188        }
189
190        let ranges: PartitionKeyRangesResponse = resp.json().await?;
191        let ids: Vec<String> = ranges
192            .partition_key_ranges
193            .into_iter()
194            .map(|r| r.id)
195            .collect();
196        debug!(count = ids.len(), "found partition key ranges");
197        Ok(ids)
198    }
199
200    /// Execute a SQL query against a single partition key range, handling pagination.
201    async fn query_partition(
202        &self,
203        url: &str,
204        body: &Value,
205        partition_key_range_id: &str,
206    ) -> Result<(Vec<Value>, f64), ClientError> {
207        let mut documents = Vec::new();
208        let mut total_charge = 0.0_f64;
209        let mut continuation: Option<String> = None;
210
211        loop {
212            let date = Self::date_header();
213            let mut request = self
214                .http
215                .post(url)
216                .header("Authorization", self.auth_header())
217                .header("x-ms-date", &date)
218                .header("x-ms-version", API_VERSION)
219                .header("x-ms-documentdb-isquery", "True")
220                .header("x-ms-documentdb-query-enablecrosspartition", "True")
221                .header(
222                    "x-ms-documentdb-partitionkeyrangeid",
223                    partition_key_range_id,
224                )
225                .header("Content-Type", "application/query+json")
226                .json(body);
227
228            if let Some(ref token) = continuation {
229                request = request.header("x-ms-continuation", token);
230            }
231
232            let resp = request.send().await?;
233            let status = resp.status();
234
235            if !status.is_success() {
236                let body_text = resp.text().await.unwrap_or_default();
237                if status.as_u16() == 403 {
238                    return Err(ClientError::forbidden(
239                        body_text,
240                        "You may not have data plane access. Check your Cosmos DB RBAC roles.",
241                    ));
242                }
243                return Err(ClientError::api(status.as_u16(), body_text));
244            }
245
246            let next_continuation = resp
247                .headers()
248                .get("x-ms-continuation")
249                .and_then(|v| v.to_str().ok())
250                .map(|s| s.to_string());
251
252            let charge: f64 = resp
253                .headers()
254                .get("x-ms-request-charge")
255                .and_then(|v| v.to_str().ok())
256                .and_then(|v| v.parse().ok())
257                .unwrap_or(0.0);
258            total_charge += charge;
259
260            let query_resp: QueryResponse = resp.json().await?;
261            documents.extend(query_resp.documents);
262
263            match next_continuation {
264                Some(token) if !token.is_empty() => {
265                    debug!("continuing with pagination token");
266                    continuation = Some(token);
267                }
268                _ => break,
269            }
270        }
271
272        Ok((documents, total_charge))
273    }
274
275    /// Execute a SQL query against a container, handling cross-partition fanout and pagination.
276    pub async fn query(
277        &self,
278        database: &str,
279        container: &str,
280        sql: &str,
281    ) -> Result<QueryResult, ClientError> {
282        self.query_with_params(database, container, sql, Vec::new())
283            .await
284    }
285
286    /// Execute a parameterized SQL query against a container.
287    ///
288    /// Parameters should be in Cosmos DB format:
289    /// `[{"name": "@param", "value": ...}, ...]`
290    pub async fn query_with_params(
291        &self,
292        database: &str,
293        container: &str,
294        sql: &str,
295        parameters: Vec<Value>,
296    ) -> Result<QueryResult, ClientError> {
297        debug!(database, container, sql, params = ?parameters, "executing query");
298
299        let url = format!(
300            "{}/dbs/{}/colls/{}/docs",
301            self.endpoint, database, container
302        );
303        let body = serde_json::json!({
304            "query": sql,
305            "parameters": parameters
306        });
307
308        // Get partition key ranges and fan out the query
309        let ranges = self.get_partition_key_ranges(database, container).await?;
310        debug!(count = ranges.len(), "querying across partition key ranges");
311
312        let mut all_documents = Vec::new();
313        let mut total_charge = 0.0_f64;
314
315        for range_id in &ranges {
316            let (docs, charge) = self.query_partition(&url, &body, range_id).await?;
317            debug!(
318                range_id,
319                docs = docs.len(),
320                charge,
321                "partition query complete"
322            );
323            all_documents.extend(docs);
324            total_charge += charge;
325        }
326
327        debug!(
328            count = all_documents.len(),
329            request_charge = total_charge,
330            "query complete"
331        );
332
333        Ok(QueryResult {
334            documents: all_documents,
335            request_charge: total_charge,
336        })
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a client directly, without touching the Azure CLI or the network.
    fn client_with_token(token: &str) -> CosmosClient {
        CosmosClient {
            http: reqwest::Client::new(),
            endpoint: "https://test.documents.azure.com".into(),
            token: token.into(),
        }
    }

    #[test]
    fn test_auth_header_format() {
        let header = client_with_token("eyJ0eXAi.test.token").auth_header();
        assert!(header.starts_with("type%3Daad%26ver%3D1.0%26sig%3D"));
        assert!(header.contains("eyJ0eXAi"));
    }

    #[test]
    fn test_auth_header_percent_encodes_token() {
        // Reserved characters in the token must not appear unescaped in the header.
        let header = client_with_token("a+b/c=d").auth_header();
        assert_eq!(header, "type%3Daad%26ver%3D1.0%26sig%3Da%2Bb%2Fc%3Dd");
    }

    #[test]
    fn test_date_header_format() {
        let date = CosmosClient::date_header();
        // RFC 1123, e.g. "Wed, 09 Nov 2023 12:34:56 GMT": fixed width, literal GMT zone.
        assert_eq!(date.len(), 29, "unexpected date header: {date}");
        assert!(date.ends_with(" GMT"));
        assert!(
            chrono::NaiveDateTime::parse_from_str(&date, "%a, %d %b %Y %H:%M:%S GMT").is_ok(),
            "date header does not parse back: {date}"
        );
    }

    #[test]
    fn test_query_response_deserialization() {
        let json = r#"{"Documents": [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}], "_count": 2}"#;
        let resp: QueryResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.documents.len(), 2);
        assert_eq!(resp.documents[0]["id"], "1");
        assert_eq!(resp.documents[1]["name"], "Bob");
    }

    #[test]
    fn test_query_response_empty() {
        let json = r#"{"Documents": [], "_count": 0}"#;
        let resp: QueryResponse = serde_json::from_str(json).unwrap();
        assert!(resp.documents.is_empty());
    }

    #[test]
    fn test_database_list_deserialization() {
        let json = r#"{"Databases": [{"id": "db1", "_rid": "r1"}, {"id": "db2", "_rid": "r2"}]}"#;
        let resp: DatabaseListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.databases.len(), 2);
        assert_eq!(resp.databases[0].id, "db1");
        assert_eq!(resp.databases[1].id, "db2");
    }

    #[test]
    fn test_collection_list_deserialization() {
        let json = r#"{"DocumentCollections": [{"id": "coll1", "_rid": "r1"}, {"id": "coll2", "_rid": "r2"}]}"#;
        let resp: CollectionListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.document_collections.len(), 2);
        assert_eq!(resp.document_collections[0].id, "coll1");
        assert_eq!(resp.document_collections[1].id, "coll2");
    }

    #[test]
    fn test_partition_key_ranges_deserialization() {
        let json =
            r#"{"PartitionKeyRanges": [{"id": "0", "minInclusive": "", "maxExclusive": "FF"}]}"#;
        let resp: PartitionKeyRangesResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.partition_key_ranges.len(), 1);
        assert_eq!(resp.partition_key_ranges[0].id, "0");
    }
}