1use serde::Deserialize;
8use serde_json::Value;
9use tracing::debug;
10
11use crate::auth::{AzCliAuth, COSMOS_RESOURCE};
12use crate::error::ClientError;
13
14const API_VERSION: &str = "2018-12-31";
15
/// Combined outcome of a query: the collected documents plus the summed request charge.
#[derive(Debug)]
pub struct QueryResult {
    /// Documents gathered from every partition key range and every page, in the order
    /// the responses were received.
    pub documents: Vec<Value>,
    /// Sum of the `x-ms-request-charge` response header over all query requests; a
    /// missing or unparsable header contributes `0.0`.
    pub request_charge: f64,
}
22
23#[derive(Debug, Deserialize)]
25struct QueryResponse {
26 #[serde(rename = "Documents")]
27 documents: Vec<Value>,
28}
29
30#[derive(Debug, Deserialize)]
32struct DatabaseListResponse {
33 #[serde(rename = "Databases")]
34 databases: Vec<DatabaseEntry>,
35}
36
/// One element of the `Databases` array; only its `id` key is deserialized.
#[derive(Debug, Deserialize)]
struct DatabaseEntry {
    /// Value of the entry's `id` key, returned to callers as the database name.
    id: String,
}
41
42#[derive(Debug, Deserialize)]
44struct CollectionListResponse {
45 #[serde(rename = "DocumentCollections")]
46 document_collections: Vec<CollectionEntry>,
47}
48
/// One element of the `DocumentCollections` array; only its `id` key is deserialized.
#[derive(Debug, Deserialize)]
struct CollectionEntry {
    /// Value of the entry's `id` key, returned to callers as the container name.
    id: String,
}
53
54#[derive(Debug, Deserialize)]
56struct PartitionKeyRangesResponse {
57 #[serde(rename = "PartitionKeyRanges")]
58 partition_key_ranges: Vec<PartitionKeyRange>,
59}
60
/// One element of the `PartitionKeyRanges` array; only its `id` key is deserialized.
#[derive(Debug, Deserialize)]
struct PartitionKeyRange {
    /// Value of the entry's `id` key; sent back in the
    /// `x-ms-documentdb-partitionkeyrangeid` header when querying that range.
    id: String,
}
65
/// Client for the Cosmos DB REST operations implemented in this file: listing
/// databases and containers, and running SQL queries.
#[derive(Clone)]
pub struct CosmosClient {
    /// HTTP client used for every request.
    http: reqwest::Client,
    /// Account endpoint with trailing `/` characters removed (done in `new`).
    endpoint: String,
    /// Token obtained once in `new`; sent percent-encoded in the `Authorization` header.
    token: String,
}
73
74impl CosmosClient {
75 pub async fn new(endpoint: &str) -> Result<Self, ClientError> {
77 let token = AzCliAuth::get_token(COSMOS_RESOURCE).await?;
78 let endpoint = endpoint.trim_end_matches('/').to_string();
79 Ok(Self {
80 http: reqwest::Client::new(),
81 endpoint,
82 token,
83 })
84 }
85
86 fn auth_header(&self) -> String {
88 let sig = urlencoding::encode(&self.token);
89 format!("type%3Daad%26ver%3D1.0%26sig%3D{sig}")
90 }
91
92 fn date_header() -> String {
94 chrono::Utc::now()
95 .format("%a, %d %b %Y %H:%M:%S GMT")
96 .to_string()
97 }
98
99 pub async fn list_databases(&self) -> Result<Vec<String>, ClientError> {
101 debug!("listing databases");
102 let url = format!("{}/dbs", self.endpoint);
103 let date = Self::date_header();
104
105 let resp = self
106 .http
107 .get(&url)
108 .header("Authorization", self.auth_header())
109 .header("x-ms-date", &date)
110 .header("x-ms-version", API_VERSION)
111 .send()
112 .await?;
113
114 let status = resp.status();
115 if !status.is_success() {
116 let body = resp.text().await.unwrap_or_default();
117 if status.as_u16() == 403 {
118 return Err(ClientError::forbidden(
119 body,
120 "You may not have data plane access. Check your Cosmos DB RBAC roles.",
121 ));
122 }
123 return Err(ClientError::api(status.as_u16(), body));
124 }
125
126 let list: DatabaseListResponse = resp.json().await?;
127 let names: Vec<String> = list.databases.into_iter().map(|d| d.id).collect();
128 debug!(count = names.len(), "found databases");
129 Ok(names)
130 }
131
132 pub async fn list_containers(&self, database: &str) -> Result<Vec<String>, ClientError> {
134 debug!(database, "listing containers");
135 let url = format!("{}/dbs/{}/colls", self.endpoint, database);
136 let date = Self::date_header();
137
138 let resp = self
139 .http
140 .get(&url)
141 .header("Authorization", self.auth_header())
142 .header("x-ms-date", &date)
143 .header("x-ms-version", API_VERSION)
144 .send()
145 .await?;
146
147 let status = resp.status();
148 if !status.is_success() {
149 let body = resp.text().await.unwrap_or_default();
150 return Err(ClientError::api(status.as_u16(), body));
151 }
152
153 let list: CollectionListResponse = resp.json().await?;
154 let names: Vec<String> = list
155 .document_collections
156 .into_iter()
157 .map(|c| c.id)
158 .collect();
159 debug!(count = names.len(), "found containers");
160 Ok(names)
161 }
162
163 async fn get_partition_key_ranges(
165 &self,
166 database: &str,
167 container: &str,
168 ) -> Result<Vec<String>, ClientError> {
169 let url = format!(
170 "{}/dbs/{}/colls/{}/pkranges",
171 self.endpoint, database, container
172 );
173 let date = Self::date_header();
174
175 let resp = self
176 .http
177 .get(&url)
178 .header("Authorization", self.auth_header())
179 .header("x-ms-date", &date)
180 .header("x-ms-version", API_VERSION)
181 .send()
182 .await?;
183
184 let status = resp.status();
185 if !status.is_success() {
186 let body = resp.text().await.unwrap_or_default();
187 return Err(ClientError::api(status.as_u16(), body));
188 }
189
190 let ranges: PartitionKeyRangesResponse = resp.json().await?;
191 let ids: Vec<String> = ranges
192 .partition_key_ranges
193 .into_iter()
194 .map(|r| r.id)
195 .collect();
196 debug!(count = ids.len(), "found partition key ranges");
197 Ok(ids)
198 }
199
200 async fn query_partition(
202 &self,
203 url: &str,
204 body: &Value,
205 partition_key_range_id: &str,
206 ) -> Result<(Vec<Value>, f64), ClientError> {
207 let mut documents = Vec::new();
208 let mut total_charge = 0.0_f64;
209 let mut continuation: Option<String> = None;
210
211 loop {
212 let date = Self::date_header();
213 let mut request = self
214 .http
215 .post(url)
216 .header("Authorization", self.auth_header())
217 .header("x-ms-date", &date)
218 .header("x-ms-version", API_VERSION)
219 .header("x-ms-documentdb-isquery", "True")
220 .header("x-ms-documentdb-query-enablecrosspartition", "True")
221 .header(
222 "x-ms-documentdb-partitionkeyrangeid",
223 partition_key_range_id,
224 )
225 .header("Content-Type", "application/query+json")
226 .json(body);
227
228 if let Some(ref token) = continuation {
229 request = request.header("x-ms-continuation", token);
230 }
231
232 let resp = request.send().await?;
233 let status = resp.status();
234
235 if !status.is_success() {
236 let body_text = resp.text().await.unwrap_or_default();
237 if status.as_u16() == 403 {
238 return Err(ClientError::forbidden(
239 body_text,
240 "You may not have data plane access. Check your Cosmos DB RBAC roles.",
241 ));
242 }
243 return Err(ClientError::api(status.as_u16(), body_text));
244 }
245
246 let next_continuation = resp
247 .headers()
248 .get("x-ms-continuation")
249 .and_then(|v| v.to_str().ok())
250 .map(|s| s.to_string());
251
252 let charge: f64 = resp
253 .headers()
254 .get("x-ms-request-charge")
255 .and_then(|v| v.to_str().ok())
256 .and_then(|v| v.parse().ok())
257 .unwrap_or(0.0);
258 total_charge += charge;
259
260 let query_resp: QueryResponse = resp.json().await?;
261 documents.extend(query_resp.documents);
262
263 match next_continuation {
264 Some(token) if !token.is_empty() => {
265 debug!("continuing with pagination token");
266 continuation = Some(token);
267 }
268 _ => break,
269 }
270 }
271
272 Ok((documents, total_charge))
273 }
274
275 pub async fn query(
277 &self,
278 database: &str,
279 container: &str,
280 sql: &str,
281 ) -> Result<QueryResult, ClientError> {
282 self.query_with_params(database, container, sql, Vec::new())
283 .await
284 }
285
286 pub async fn query_with_params(
291 &self,
292 database: &str,
293 container: &str,
294 sql: &str,
295 parameters: Vec<Value>,
296 ) -> Result<QueryResult, ClientError> {
297 debug!(database, container, sql, params = ?parameters, "executing query");
298
299 let url = format!(
300 "{}/dbs/{}/colls/{}/docs",
301 self.endpoint, database, container
302 );
303 let body = serde_json::json!({
304 "query": sql,
305 "parameters": parameters
306 });
307
308 let ranges = self.get_partition_key_ranges(database, container).await?;
310 debug!(count = ranges.len(), "querying across partition key ranges");
311
312 let mut all_documents = Vec::new();
313 let mut total_charge = 0.0_f64;
314
315 for range_id in &ranges {
316 let (docs, charge) = self.query_partition(&url, &body, range_id).await?;
317 debug!(
318 range_id,
319 docs = docs.len(),
320 charge,
321 "partition query complete"
322 );
323 all_documents.extend(docs);
324 total_charge += charge;
325 }
326
327 debug!(
328 count = all_documents.len(),
329 request_charge = total_charge,
330 "query complete"
331 );
332
333 Ok(QueryResult {
334 documents: all_documents,
335 request_charge: total_charge,
336 })
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client around a fixed token; no network or Azure CLI access is involved.
    fn client_with_token(token: &str) -> CosmosClient {
        CosmosClient {
            http: reqwest::Client::new(),
            endpoint: "https://test.documents.azure.com".into(),
            token: token.into(),
        }
    }

    #[test]
    fn test_auth_header_format() {
        // The token contains only unreserved characters, so it must pass through unchanged.
        let header = client_with_token("eyJ0eXAi.test.token").auth_header();
        assert_eq!(
            header,
            "type%3Daad%26ver%3D1.0%26sig%3DeyJ0eXAi.test.token"
        );
    }

    #[test]
    fn test_auth_header_encodes_reserved_characters() {
        // `+`, `/` and `=` must be percent-encoded inside the signature.
        let header = client_with_token("a+b/c=").auth_header();
        assert_eq!(header, "type%3Daad%26ver%3D1.0%26sig%3Da%2Bb%2Fc%3D");
    }

    #[test]
    fn test_date_header_format() {
        let date = CosmosClient::date_header();
        assert!(date.ends_with(" GMT"));
        // e.g. "Tue, 01 Jul 2025 12:00:00 GMT"
        assert_eq!(date.len(), 29);
        assert!(
            chrono::NaiveDateTime::parse_from_str(&date, "%a, %d %b %Y %H:%M:%S GMT").is_ok(),
            "unexpected date header: {date}"
        );
    }

    #[test]
    fn test_query_response_deserialization() {
        let json = r#"{"Documents": [{"id": "1", "name": "Alice"}, {"id": "2", "name": "Bob"}], "_count": 2}"#;
        let resp: QueryResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.documents.len(), 2);
        assert_eq!(resp.documents[0]["id"], "1");
        assert_eq!(resp.documents[1]["name"], "Bob");
    }

    #[test]
    fn test_query_response_empty() {
        let json = r#"{"Documents": [], "_count": 0}"#;
        let resp: QueryResponse = serde_json::from_str(json).unwrap();
        assert!(resp.documents.is_empty());
    }

    #[test]
    fn test_query_response_missing_documents_is_error() {
        // The `Documents` key is required; a payload without it must not parse.
        let json = r#"{"_count": 0}"#;
        assert!(serde_json::from_str::<QueryResponse>(json).is_err());
    }

    #[test]
    fn test_database_list_deserialization() {
        let json = r#"{"Databases": [{"id": "db1", "_rid": "r1"}, {"id": "db2", "_rid": "r2"}]}"#;
        let resp: DatabaseListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.databases.len(), 2);
        assert_eq!(resp.databases[0].id, "db1");
        assert_eq!(resp.databases[1].id, "db2");
    }

    #[test]
    fn test_collection_list_deserialization() {
        let json = r#"{"DocumentCollections": [{"id": "coll1", "_rid": "r1"}, {"id": "coll2", "_rid": "r2"}]}"#;
        let resp: CollectionListResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.document_collections.len(), 2);
        assert_eq!(resp.document_collections[0].id, "coll1");
        assert_eq!(resp.document_collections[1].id, "coll2");
    }

    #[test]
    fn test_partition_key_ranges_deserialization() {
        let json =
            r#"{"PartitionKeyRanges": [{"id": "0", "minInclusive": "", "maxExclusive": "FF"}]}"#;
        let resp: PartitionKeyRangesResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.partition_key_ranges.len(), 1);
        assert_eq!(resp.partition_key_ranges[0].id, "0");
    }
}