cosmos_utils/
lib.rs

1use azure_core::{error::Error as AzureError, prelude::IfMatchCondition, HttpClient, StatusCode};
2use azure_cosmos::prelude::*;
3use azure_storage::prelude::*;
4use azure_storage_blobs::prelude::*;
5use futures::StreamExt;
6use lazy_static::lazy_static;
7use once_cell::sync::OnceCell;
8use serde::{de::DeserializeOwned, Serialize};
9use std::sync::Arc;
10use tokio::time::{sleep, Duration};
11use uuid::Uuid;
12use warp::{filters::multipart::FormData, Buf as OtherBuf};
13
14/// Re-export `CosmosEntity` which must be implemented by every model.
15pub use azure_cosmos::CosmosEntity;
16
17mod cosmos_saga;
18pub use cosmos_saga::{
19    CosmosSaga, ErasedCosmosEntity, ErasedCosmosEntityClone, UpcastErasedSerdeSerialize,
20};
21
/// Static Cosmos DB (and optional blob storage) configuration.
///
/// It must be installed with `set_state` before any of the cosmos functions in this module are
/// called; they read it and return an internal error when it has not been set.
///
/// `Debug` is implemented by hand so that the master keys are redacted: a derived
/// implementation would print them verbatim into logs and panic messages.
pub struct CosmosState {
    pub cosmos_account: String,
    pub cosmos_database: String,
    pub cosmos_master_key: String,
    pub storage_account: Option<String>,
    pub storage_master_key: Option<String>,
    pub image_storage_container: Option<String>,
}

impl std::fmt::Debug for CosmosState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("CosmosState")
            .field("cosmos_account", &self.cosmos_account)
            .field("cosmos_database", &self.cosmos_database)
            .field("cosmos_master_key", &REDACTED)
            .field("storage_account", &self.storage_account)
            // Keep the Some/None distinction visible without exposing the key itself.
            .field(
                "storage_master_key",
                &self.storage_master_key.as_ref().map(|_| REDACTED),
            )
            .field("image_storage_container", &self.image_storage_container)
            .finish()
    }
}
33
34static COSMOS_STATE: OnceCell<CosmosState> = OnceCell::new();
35
36lazy_static! {
37    static ref HTTP_CLIENT: Arc<Box<dyn HttpClient>> = Arc::new(Box::new(reqwest::Client::new()));
38}
39
40pub fn set_state(state: CosmosState) {
41    COSMOS_STATE.set(state).unwrap();
42}
43
44type CosmosError = CosmosErrorStruct;
45
46#[derive(Debug)]
47pub struct CosmosErrorStruct {
48    pub err: String,
49    pub kind: CosmosErrorKind,
50}
51
52impl std::fmt::Display for CosmosErrorStruct {
53    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
54        fmt.write_fmt(format_args!("kind: {}, err: {}", self.kind, self.err))?;
55        Ok(())
56    }
57}
58
59#[derive(Debug)]
60pub enum CosmosErrorKind {
61    PreconditionFailed,
62    NotFound,
63    BadRequest,
64    InternalError,
65    Conflict,
66    BlobError,
67    ModificationError(warp::Rejection),
68}
69
70impl std::fmt::Display for CosmosErrorKind {
71    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
72        match self {
73            CosmosErrorKind::PreconditionFailed => {
74                fmt.write_str("PreconditionFailed")?;
75            }
76            CosmosErrorKind::NotFound => {
77                fmt.write_str("NotFound")?;
78            }
79            CosmosErrorKind::BadRequest => {
80                fmt.write_str("BadRequest")?;
81            }
82            CosmosErrorKind::InternalError => {
83                fmt.write_str("InternalError")?;
84            }
85            CosmosErrorKind::Conflict => {
86                fmt.write_str("Conflict")?;
87            }
88            CosmosErrorKind::BlobError => {
89                fmt.write_str("BlobError")?;
90            }
91            CosmosErrorKind::ModificationError(rej) => {
92                fmt.write_fmt(format_args!("ModificationError({:?})", rej))?;
93            }
94        }
95        Ok(())
96    }
97}
98
99/// Marker trait for all types that implement `CosmosEntity` and `Serialize` and `Clone`.
100/// Necessary in order to fullfille the requirements for Cosmos DB objects.
101/// This trait should be implemented automatically for any type that needs it.
102pub trait CosmosObject: CosmosEntity + Serialize + Clone + Send + Sync + 'static {}
103impl<T> CosmosObject for T where T: CosmosEntity + Serialize + Clone + Send + Sync + 'static {}
104
105impl warp::reject::Reject for CosmosErrorStruct {}
106
107/// Utility function that returns a closure that converts whatever error into a Reject error.
108/// Usage is: function_call_which_returns_non_reject_result().map_err(into_cosmos_error("custom error message"))?;
109fn into_cosmos_error<M: ToString>(msg: M) -> impl FnOnce(AzureError) -> CosmosError {
110    move |e: AzureError| new_cosmos_error(msg.to_string(), e)
111}
112
113/// Creates a new cosmos error from a given error
114fn new_cosmos_error<M: ToString>(msg: M, err: AzureError) -> CosmosError {
115    let kind = match err.as_http_error() {
116        Some(err) => match err.status() {
117            StatusCode::PreconditionFailed => CosmosErrorKind::PreconditionFailed,
118            StatusCode::NotFound => CosmosErrorKind::NotFound,
119            StatusCode::BadRequest => CosmosErrorKind::BadRequest,
120            StatusCode::Conflict => CosmosErrorKind::Conflict,
121            _ => CosmosErrorKind::InternalError,
122        },
123        None => CosmosErrorKind::InternalError,
124    };
125    // We get more informative error messages if we use the debug printing rather than display
126    // printing.
127    let msg = format!("{} : {:?}", msg.to_string(), err);
128    CosmosErrorStruct { kind, err: msg }
129}
130
131fn new_cosmos_error_internal<M: ToString>(msg: M) -> CosmosError {
132    CosmosErrorStruct {
133        kind: CosmosErrorKind::InternalError,
134        err: msg.to_string(),
135    }
136}
137
138/// Creates a new cosmos error from a given error and the error kind
139fn new_cosmos_error_kind<E: ToString>(err: E, kind: CosmosErrorKind) -> CosmosError {
140    let err = err.to_string();
141    CosmosErrorStruct { kind, err }
142}
143
/// Intermediate type that lets a `[&str; 1]` (and similar one-element arrays) convert
/// seamlessly into a single partition key.
///
/// Many functions used to take a slice of partition keys, but the Cosmos DB interface changed
/// to require a single partition key. Accepting `impl Into<PartitionKeyIntermediate>` keeps the
/// old call sites compiling without rewriting all of their signatures.
///
/// The `()`-based conversions produce an empty key.
pub struct PartitionKeyIntermediate(String);

impl From<[&str; 1]> for PartitionKeyIntermediate {
    fn from([key]: [&str; 1]) -> Self {
        Self(key.to_string())
    }
}

impl From<[&String; 1]> for PartitionKeyIntermediate {
    fn from([key]: [&String; 1]) -> Self {
        Self(key.to_string())
    }
}

impl From<[&&String; 1]> for PartitionKeyIntermediate {
    fn from([key]: [&&String; 1]) -> Self {
        Self(key.to_string())
    }
}

impl From<[&(); 1]> for PartitionKeyIntermediate {
    fn from(_: [&(); 1]) -> Self {
        Self(String::new())
    }
}

impl From<&[(); 1]> for PartitionKeyIntermediate {
    fn from(_: &[(); 1]) -> Self {
        Self(String::new())
    }
}

impl From<[(); 1]> for PartitionKeyIntermediate {
    fn from(_: [(); 1]) -> Self {
        Self(String::new())
    }
}

impl From<String> for PartitionKeyIntermediate {
    fn from(key: String) -> Self {
        Self(key)
    }
}

impl From<PartitionKeyIntermediate> for String {
    fn from(key: PartitionKeyIntermediate) -> String {
        key.0
    }
}
198
199async fn insert_internal<D: CosmosObject, P: Into<PartitionKeyIntermediate>, C: ToString>(
200    collection_name: C,
201    pk: P,
202    document: &D,
203    etag: Option<&String>,
204    upsert: bool,
205) -> Result<String, CosmosError> {
206    let state = COSMOS_STATE
207        .get()
208        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
209    let collection_name = collection_name.to_string();
210    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
211        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
212
213    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
214    let database_client = client.database_client(&state.cosmos_database);
215    let collection_client = database_client.collection_client(collection_name);
216
217    let pk: String = pk.into().into();
218
219    let c = collection_client
220        .create_document(document.clone())
221        .is_upsert(upsert)
222        .partition_key(&pk)
223        .map_err(|e| new_cosmos_error_internal(format!("Could not create document: {}", e)))?;
224    let c = match etag {
225        Some(etag) => c.if_match_condition(IfMatchCondition::Match(etag.to_string())),
226        None => c,
227    };
228
229    let resp = retry_loop(MAX_RETRY_LOOPS, || async {
230        match c.clone().into_future().await {
231            Ok(t) => Ok(t),
232            Err(err) => Err(RetryLoopError::Permanent(new_cosmos_error(
233                "Cosmos db error",
234                err,
235            ))),
236        }
237    })
238    .await?;
239    let etag = resp.etag;
240    Ok(etag)
241}
242
243/// Insert a document into the cosmos database and returning an etag from the response if
244/// successful
245pub async fn insert<D: CosmosObject, P: Into<PartitionKeyIntermediate>, C: ToString>(
246    collection_name: C,
247    pk: P,
248    document: &D,
249    etag: Option<&String>,
250) -> Result<String, CosmosError> {
251    insert_internal(collection_name, pk, document, etag, false).await
252}
253
254/// Upsert a document into the cosmos database and returning an etag from the response if
255/// successful
256pub async fn upsert<
257    D: CosmosEntity + Serialize + Clone + Send + Sync + 'static,
258    P: Into<PartitionKeyIntermediate>,
259    C: ToString,
260>(
261    collection_name: C,
262    pk: P,
263    document: &D,
264    etag: Option<&String>,
265) -> Result<String, CosmosError> {
266    insert_internal(collection_name, pk, document, etag, true).await
267}
268
269/// Returns a specific document from the cosmos DB together with a corresponding etag
270pub async fn get<
271    D: DeserializeOwned + Send,
272    P: Into<PartitionKeyIntermediate>,
273    C: ToString,
274    S: ToString,
275>(
276    collection_name: C,
277    pk: P,
278    document_id: S,
279) -> Result<(D, String), CosmosError> {
280    let state = COSMOS_STATE
281        .get()
282        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
283    let collection_name = collection_name.to_string();
284    let document_id = document_id.to_string();
285    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
286        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
287
288    let pk: String = pk.into().into();
289    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
290    let database_client = client.database_client(&state.cosmos_database);
291    let collection_client = database_client.collection_client(collection_name);
292    let document_client = collection_client
293        .clone()
294        .document_client(document_id, &pk)
295        .map_err(|_| new_cosmos_error_internal("Could not create document client"))?;
296    let resp = match document_client
297        .get_document()
298        .into_future()
299        .await
300        .map_err(into_cosmos_error("Could not get document"))?
301    {
302        GetDocumentResponse::Found(resp) => resp,
303        GetDocumentResponse::NotFound(resp) => {
304            return Err(new_cosmos_error_kind(
305                format!("Document not found: {:?}", resp),
306                CosmosErrorKind::NotFound,
307            ));
308        }
309    };
310    let doc: D = resp.document.document;
311    let etag = resp.etag;
312    Ok((doc, etag))
313}
314
315/// Modifies a document in cosmos by applying `transform` async closure on the existing document and then
316/// inserting the returned document and returning both the transformed document, the old
317/// document and the etag if successful.
318/// If the transform closure fails then no insertion is performed and the error it fails with
319/// is returned
320pub async fn modify_async_get_old<
321    D: CosmosObject + DeserializeOwned + Clone,
322    P: Into<PartitionKeyIntermediate>,
323    F: Fn(D) -> Fut,
324    C: ToString,
325    S: ToString,
326    Fut: futures::Future<Output = Result<D, warp::Rejection>>,
327>(
328    collection_name: C,
329    pk: P,
330    document_id: S,
331    transform: F,
332) -> Result<(D, D, String), CosmosError> {
333    let state = COSMOS_STATE
334        .get()
335        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
336    let collection_name = collection_name.to_string();
337    let document_id = document_id.to_string();
338    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
339        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
340
341    let pk: String = pk.into().into();
342    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
343    let database_client = client.database_client(&state.cosmos_database);
344    let collection_client = database_client.collection_client(collection_name);
345    let document_client = collection_client
346        .clone()
347        .document_client(document_id, &pk)
348        .map_err(into_cosmos_error("Could not get document client"))?;
349    let (doc, old_doc, etag) = retry_loop(MAX_RETRY_LOOPS, || async {
350        let resp = match document_client
351            .clone()
352            .get_document()
353            .into_future()
354            .await
355            .map_err(|e| RetryLoopError::Permanent(new_cosmos_error("Could not get document", e)))?
356        {
357            GetDocumentResponse::Found(resp) => resp,
358            GetDocumentResponse::NotFound(resp) => {
359                return Err(RetryLoopError::Permanent(new_cosmos_error_kind(
360                    format!("Document not found: {:?}", resp),
361                    CosmosErrorKind::NotFound,
362                )));
363            }
364        };
365        let doc: D = resp.document.document;
366        let old_doc = doc.clone();
367
368        // Perform changes to the document
369        let doc = transform(doc).await.map_err(|e| {
370            RetryLoopError::Permanent(new_cosmos_error_kind(
371                format!("Modification not possible: {:?}", e),
372                CosmosErrorKind::InternalError,
373            ))
374        })?;
375        let c = collection_client
376            .create_document(doc.clone())
377            .is_upsert(true)
378            .if_match_condition(IfMatchCondition::Match(resp.etag))
379            .partition_key(&pk)
380            .map_err(|e| {
381                RetryLoopError::Permanent(new_cosmos_error(
382                    format!("Could not create document after modification: {}", e),
383                    e,
384                ))
385            })?;
386
387        match c.into_future().await {
388            Ok(resp) => Result::Ok::<_, RetryLoopError<CosmosError>>((doc, old_doc, resp.etag)),
389            Err(err) => {
390                let err = new_cosmos_error("Cosmos db error", err);
391                //NOTE: 412 - PreconditionFailed means the document has been edited between read and write so it
392                //means we need to retry the entire read/write block
393                match err.kind {
394                    CosmosErrorKind::PreconditionFailed => Err(RetryLoopError::Transient(err)),
395                    _ => Err(RetryLoopError::Permanent(err)),
396                }
397            }
398        }
399    })
400    .await?;
401    Ok((doc, old_doc, etag))
402}
403
404/// Modifies a document in cosmos by applying `transform` async closure on the existing document and then
405/// inserting the returned document and returning the transformed document if successful
406/// If the transform closure fails then no insertion is performed and the error it fails with
407/// is returned
408pub async fn modify_async<
409    D: CosmosObject + DeserializeOwned,
410    P: Into<PartitionKeyIntermediate>,
411    F: Fn(D) -> Fut,
412    C: ToString,
413    S: ToString,
414    Fut: futures::Future<Output = Result<D, warp::Rejection>>,
415>(
416    collection_name: C,
417    pk: P,
418    document_id: S,
419    transform: F,
420) -> Result<D, CosmosError> {
421    let state = COSMOS_STATE
422        .get()
423        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
424    let collection_name = collection_name.to_string();
425    let document_id = document_id.to_string();
426    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
427        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
428
429    let pk: String = pk.into().into();
430    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
431    let database_client = client.database_client(&state.cosmos_database);
432    let collection_client = database_client.collection_client(collection_name);
433    let document_client = collection_client
434        .clone()
435        .document_client(document_id, &pk)
436        .map_err(|e| {
437            new_cosmos_error_internal(format!("Could not create document client: {}", e))
438        })?;
439    let doc = retry_loop(MAX_RETRY_LOOPS, || async {
440        let resp = match document_client
441            .get_document()
442            .into_future()
443            .await
444            .map_err(|e| RetryLoopError::Permanent(new_cosmos_error("Could not get document", e)))?
445        {
446            GetDocumentResponse::Found(resp) => resp,
447            GetDocumentResponse::NotFound(resp) => {
448                return Err(RetryLoopError::Permanent(new_cosmos_error_kind(
449                    format!("Document not found: {:?}", resp),
450                    CosmosErrorKind::NotFound,
451                )));
452            }
453        };
454        let doc: D = resp.document.document;
455
456        // Perform changes to the document
457        let doc = transform(doc).await.map_err(|e| {
458            RetryLoopError::Permanent(new_cosmos_error_kind(
459                format!("Modification not possible: {:?}", e),
460                CosmosErrorKind::InternalError,
461            ))
462        })?;
463        let c = collection_client
464            .create_document(doc.clone())
465            .is_upsert(true)
466            .if_match_condition(IfMatchCondition::Match(resp.etag))
467            .partition_key(&pk)
468            .map_err(|e| {
469                RetryLoopError::Permanent(new_cosmos_error(
470                    format!("Could not create document after modification: {}", e),
471                    e,
472                ))
473            })?;
474
475        match c.into_future().await {
476            Ok(_) => Result::Ok::<_, RetryLoopError<CosmosError>>(doc),
477            Err(err) => {
478                let err = new_cosmos_error("Cosmos db error", err);
479                //NOTE: 412 means the document has been edited between read and write so it
480                //means we need to retry the entire read/write block
481                match err.kind {
482                    CosmosErrorKind::PreconditionFailed => Err(RetryLoopError::Transient(err)),
483                    _ => Err(RetryLoopError::Permanent(err)),
484                }
485            }
486        }
487    })
488    .await?;
489    Ok(doc)
490}
491
/// Decision returned by a `maybe_modify` transform: `Replace` asks for the document to be
/// written back, `DontReplace` only hands it back to the caller.
#[derive(Debug, Clone)]
pub enum ModifyReturn<D> {
    Replace(D),
    DontReplace(D),
}

impl<D> ModifyReturn<D> {
    /// Extracts the wrapped document whichever variant it is in; despite the name, this
    /// never panics.
    pub fn unwrap(self) -> D {
        match self {
            ModifyReturn::Replace(inner) | ModifyReturn::DontReplace(inner) => inner,
        }
    }
}
507
508/// Modifies a document in cosmos by applying `transform` closure on the existing document and
509/// either inserting the returned value or just returning it to the caller
510pub async fn maybe_modify<
511    D: CosmosObject + DeserializeOwned + std::fmt::Debug,
512    P: Into<PartitionKeyIntermediate>,
513    F: Fn(D) -> Result<ModifyReturn<D>, warp::Rejection>,
514    C: ToString,
515    S: ToString,
516>(
517    collection_name: C,
518    pk: P,
519    document_id: S,
520    transform: F,
521) -> Result<ModifyReturn<D>, CosmosError> {
522    let state = COSMOS_STATE
523        .get()
524        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
525    let collection_name = collection_name.to_string();
526    let document_id = document_id.to_string();
527    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
528        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
529
530    let pk: String = pk.into().into();
531    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
532
533    let database_client = client.database_client(&state.cosmos_database);
534    let collection_client = database_client.collection_client(collection_name);
535    let document_client = collection_client
536        .clone()
537        .document_client(document_id, &pk)
538        .map_err(|e| {
539            new_cosmos_error_internal(format!("Could not create document client: {}", e))
540        })?;
541
542    let doc = retry_loop(MAX_RETRY_LOOPS, || async {
543        let resp = match document_client
544            .get_document()
545            .into_future()
546            .await
547            .map_err(|e| RetryLoopError::Permanent(new_cosmos_error("Could not get document", e)))?
548        {
549            GetDocumentResponse::Found(resp) => resp,
550            GetDocumentResponse::NotFound(resp) => {
551                return Err(RetryLoopError::Permanent(new_cosmos_error_kind(
552                    format!("Document not found: {:?}", resp),
553                    CosmosErrorKind::NotFound,
554                )));
555            }
556        };
557        let doc: D = resp.document.document;
558
559        // Perform changes to the document
560        let doc = transform(doc).map_err(|e| {
561            RetryLoopError::Permanent(new_cosmos_error_kind(
562                format!("Modification error: {:?}", e),
563                CosmosErrorKind::InternalError,
564            ))
565        })?;
566
567        let doc = match doc {
568            ModifyReturn::Replace(doc) => doc,
569            ModifyReturn::DontReplace(doc) => return Ok(ModifyReturn::DontReplace(doc)),
570        };
571
572        let c = collection_client
573            .create_document(doc.clone())
574            .is_upsert(true)
575            .if_match_condition(IfMatchCondition::Match(resp.etag))
576            .partition_key(&pk)
577            .map_err(|e| {
578                RetryLoopError::Permanent(new_cosmos_error(
579                    format!("Could not create document after modification: {}", e),
580                    e,
581                ))
582            })?;
583
584        match c.into_future().await {
585            Ok(_) => Result::Ok::<_, RetryLoopError<CosmosError>>(ModifyReturn::Replace(doc)),
586            Err(err) => {
587                let err = new_cosmos_error("Cosmos db error", err);
588                //NOTE: 412 means the document has been edited between read and write so it
589                //means we need to retry the entire read/write block
590                match err.kind {
591                    CosmosErrorKind::PreconditionFailed => Err(RetryLoopError::Transient(err)),
592                    _ => Err(RetryLoopError::Permanent(err)),
593                }
594            }
595        }
596    })
597    .await?;
598    Ok(doc)
599}
600
601/// Modify in cosmos but does not retry on failure which is useful when one needs to pass in a
602/// `FnOnce` closure
603pub async fn modify_no_retry<
604    D: CosmosObject + DeserializeOwned + std::fmt::Debug,
605    P: Into<PartitionKeyIntermediate>,
606    F: FnOnce(D) -> Result<D, warp::Rejection>,
607    C: ToString,
608    S: ToString,
609>(
610    collection_name: C,
611    pk: P,
612    document_id: S,
613    transform: F,
614) -> Result<D, CosmosError> {
615    let state = COSMOS_STATE
616        .get()
617        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
618    let collection_name = collection_name.to_string();
619    let document_id = document_id.to_string();
620    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
621        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
622
623    let pk: String = pk.into().into();
624    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
625    let database_client = client.database_client(&state.cosmos_database);
626    let collection_client = database_client.collection_client(collection_name);
627    let document_client = collection_client
628        .clone()
629        .document_client(document_id, &pk)
630        .map_err(|e| {
631            new_cosmos_error_internal(format!("Could not create document client: {}", e))
632        })?;
633    let resp = match document_client
634        .get_document()
635        .into_future()
636        .await
637        .map_err(|e| new_cosmos_error("Could not get document", e))?
638    {
639        GetDocumentResponse::Found(resp) => resp,
640        GetDocumentResponse::NotFound(resp) => {
641            return Err(new_cosmos_error_kind(
642                format!("Document not found: {:?}", resp),
643                CosmosErrorKind::NotFound,
644            ));
645        }
646    };
647    let doc: D = resp.document.document;
648
649    // Perform changes to the document
650    let doc = transform(doc).map_err(|e| {
651        new_cosmos_error_kind(
652            format!("Modification error: {:?}", e),
653            CosmosErrorKind::InternalError,
654        )
655    })?;
656    let c = collection_client
657        .create_document(doc.clone())
658        .is_upsert(true)
659        .if_match_condition(IfMatchCondition::Match(resp.etag))
660        .partition_key(&pk)
661        .map_err(|e| {
662            new_cosmos_error(
663                format!("Could not create document after modification: {}", e),
664                e,
665            )
666        })?;
667
668    let doc = match c.into_future().await {
669        Ok(_) => doc,
670        Err(err) => {
671            return Err(new_cosmos_error("Cosmos db error", err));
672        }
673    };
674    Ok(doc)
675}
676
677/// Modifies a document in cosmos by applying `transform` closure on the existing document and then
678/// inserting the returned document and returning the transformed document if successful
679/// If the transform closure fails then no insertion is performed and the error it fails with
680/// is returned
681pub async fn modify<
682    D: CosmosObject + DeserializeOwned + std::fmt::Debug,
683    P: Into<PartitionKeyIntermediate>,
684    F: Fn(D) -> Result<D, warp::Rejection>,
685    C: ToString,
686    S: ToString,
687>(
688    collection_name: C,
689    pk: P,
690    document_id: S,
691    transform: F,
692) -> Result<D, CosmosError> {
693    modify_async(collection_name, pk, document_id, |d| async { transform(d) }).await
694}
695
696pub async fn delete<C: ToString, S: ToString, P: Into<PartitionKeyIntermediate>>(
697    collection_name: C,
698    pk: P,
699    document_id: S,
700    etag: Option<String>,
701) -> Result<(), CosmosError> {
702    let state = COSMOS_STATE
703        .get()
704        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
705    let collection_name = collection_name.to_string();
706    let document_id = document_id.to_string();
707    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
708        .map_err(|_| new_cosmos_error_internal("Could not get authorization token"))?;
709
710    let pk: String = pk.into().into();
711
712    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
713    let database_client = client.database_client(&state.cosmos_database);
714    let collection_client = database_client.collection_client(collection_name);
715    let document_client = collection_client
716        .clone()
717        .document_client(document_id, &pk)
718        .map_err(|e| {
719            new_cosmos_error_internal(format!("Could not create document client: {}", e))
720        })?;
721
722    let del_doc = document_client.delete_document();
723    if let Some(etag) = etag {
724        del_doc
725            .if_match_condition(IfMatchCondition::Match(etag.to_string()))
726            .into_future()
727            .await
728            .map_err(into_cosmos_error("Could not delete document"))?;
729    } else {
730        del_doc
731            .into_future()
732            .await
733            .map_err(into_cosmos_error("Could not delete document"))?;
734    }
735    Ok(())
736}
737
738pub async fn query_crosspartition_etag<
739    D: DeserializeOwned + Send + Sync,
740    P: Into<PartitionKeyIntermediate>,
741    C: ToString,
742>(
743    collection_name: C,
744    pk: P,
745    query: String,
746    max_count: i32,
747    cross_partition: bool,
748) -> Result<Vec<(D, String)>, CosmosError> {
749    let state = COSMOS_STATE
750        .get()
751        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
752    let collection_name = collection_name.to_string();
753    let authorization_token = AuthorizationToken::primary_from_base64(&state.cosmos_master_key)
754        .map_err(|_| new_cosmos_error_internal("Could not get authorization token."))?;
755    let pk: String = pk.into().into();
756
757    let client = CosmosClient::new(state.cosmos_account.to_string(), authorization_token);
758    let database_client = client.database_client(&state.cosmos_database);
759    let collection_client = database_client.collection_client(collection_name.clone());
760    let mut documents: Vec<(D, String)> = vec![];
761
762    let q = Query::new(query.clone());
763    let mut query_documents_builder = collection_client
764        .query_documents(q)
765        .max_item_count(max_count);
766
767    if cross_partition {
768        query_documents_builder = query_documents_builder.query_cross_partition(true);
769    } else {
770        query_documents_builder = query_documents_builder.partition_key(&pk).map_err(|e| {
771            new_cosmos_error_internal(format!(
772                "Could not set partition key for query document builder: {}",
773                e
774            ))
775        })?;
776    }
777
778    let mut query_documents_stream = query_documents_builder.into_stream();
779    while let Some(query_documents_response) = query_documents_stream.next().await {
780        let query_documents_response =
781            query_documents_response.map_err(into_cosmos_error(format!(
782                "Could not get query documents response: {} {} {:?}.",
783                collection_name, query, pk
784            )))?;
785        let mut fetched_documents: Vec<(D, String)> = query_documents_response
786            .results
787            .into_iter()
788            .map(|(document, attributes): (D, _)| {
789                Ok((
790                    document,
791                    attributes
792                        .ok_or_else(|| new_cosmos_error_internal("No etag returned by query"))?
793                        .etag()
794                        .to_string(),
795                ))
796            })
797            .collect::<Result<Vec<_>, _>>()?;
798        documents.append(&mut fetched_documents);
799    }
800
801    Ok(documents)
802}
803
804pub async fn query_crosspartition<
805    D: DeserializeOwned + Send + Sync,
806    P: Into<PartitionKeyIntermediate>,
807    C: ToString,
808>(
809    collection_name: C,
810    pk: P,
811    query: String,
812    max_count: i32,
813    cross_partition: bool,
814) -> Result<Vec<D>, CosmosError> {
815    let v =
816        query_crosspartition_etag(collection_name, pk, query, max_count, cross_partition).await?;
817    Ok(v.into_iter().map(|(d, _)| d).collect())
818}
819
820pub async fn query<
821    D: DeserializeOwned + Send + Sync,
822    P: Into<PartitionKeyIntermediate>,
823    C: ToString,
824>(
825    collection_name: C,
826    pk: P,
827    query: String,
828    max_count: i32,
829) -> Result<Vec<D>, CosmosError> {
830    query_crosspartition(collection_name, pk, query, max_count, false).await
831}
832
833/// Gets the data from a blob storage
834pub async fn get_blob<S: AsRef<str>>(
835    id: S,
836    storage_container: &str,
837) -> Result<Vec<u8>, CosmosError> {
838    let state = COSMOS_STATE
839        .get()
840        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
841    if state.storage_account.is_none() || state.storage_master_key.is_none() {
842        return Err(new_cosmos_error_internal(
843            "Cosmos storage values are not initialized",
844        ));
845    }
846
847    let blob_client = StorageClient::new_access_key(
848        state.storage_account.as_ref().unwrap(),
849        state.storage_master_key.as_ref().unwrap(),
850    )
851    .container_client(storage_container)
852    .blob_client(id.as_ref());
853
854    let mut data = Vec::new();
855    let mut blob_stream = blob_client.get().into_stream();
856
857    while let Some(blob_response) = blob_stream.next().await {
858        let blob_response =
859            blob_response.map_err(into_cosmos_error("Could not get blob response"))?;
860        match blob_response.data.collect().await {
861            Ok(r) => data.extend_from_slice(&r[..]),
862            Err(err) => {
863                return Err(new_cosmos_error_kind(
864                    format!("Could not get blob from storage account {:?}", err),
865                    CosmosErrorKind::BlobError,
866                ));
867            }
868        }
869    }
870    Ok(data)
871}
872
873/// Uploads form data to the blob storage and returns the blob id
874pub async fn upload_blob(
875    mut f: FormData,
876    blob_type: &str,
877    expected_content_type: &str,
878    storage_container: &str,
879) -> Result<String, CosmosError> {
880    let state = COSMOS_STATE
881        .get()
882        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
883    if state.storage_account.is_none() || state.storage_master_key.is_none() {
884        return Err(new_cosmos_error_internal(
885            "Cosmos storage values are not initialized",
886        ));
887    }
888    while let Some(r) = f.next().await {
889        match r {
890            Ok(part) => {
891                if part.name() == blob_type {
892                    if let Some(g) = part.content_type() {
893                        let content_type = String::from(g);
894                        if content_type.starts_with(&expected_content_type) {
895                            match part.filename() {
896                                Some(n) => {
897                                    let pos = n.find('.');
898                                    if let Some(pos) = pos {
899                                        let ext = String::from(n);
900                                        let ext = &ext[pos..];
901
902                                        //FIXME(Jonathan): Here we are using a Vec<u8> when we probably want to
903                                        //use a bytes::Bytes. Additionally we are copying the buffer
904                                        //when moving or referencing is prerferrable.
905                                        //The reason we're doing it this way is because there are
906                                        //currently some versioning issues with azure
907                                        let mut buf: Vec<u8> = Vec::with_capacity(64);
908                                        let mut s = part.stream();
909                                        while let Some(r) = s.next().await {
910                                            match r {
911                                                Ok(mut b) => {
912                                                    buf.extend(b.copy_to_bytes(b.remaining()));
913                                                }
914                                                Err(err) => {
915                                                    return Err(new_cosmos_error_kind(
916                                                        format!(
917                                                            "Error getting {} data: {:?}",
918                                                            blob_type, err
919                                                        ),
920                                                        CosmosErrorKind::BlobError,
921                                                    ));
922                                                }
923                                            }
924                                        }
925
926                                        let mut blob_id = Uuid::new_v4().to_string();
927
928                                        // Add extension to blob id.
929                                        blob_id.push_str(ext);
930
931                                        let blob_client = StorageClient::new_access_key(
932                                            state.storage_account.as_ref().unwrap(),
933                                            state.storage_master_key.as_ref().unwrap(),
934                                        )
935                                        .container_client(storage_container)
936                                        .blob_client(&blob_id);
937
938                                        // Helps preventing spurious data to be uploaded.
939                                        let digest: Hash = md5::compute(&buf[..]).into();
940                                        {
941                                            match blob_client
942                                                .put_block_blob(buf)
943                                                .content_type(content_type)
944                                                .hash(digest)
945                                                .into_future()
946                                                .await
947                                            {
948                                                Ok(r) => r,
949                                                Err(err) => {
950                                                    return Err(new_cosmos_error_kind(
951                                                                format!("Could not add blob to storage account {:?}", err),
952                                                                CosmosErrorKind::BlobError));
953                                                }
954                                            };
955                                        }
956
957                                        // Set blob id.
958                                        return Ok(blob_id);
959                                    } else {
960                                        return Err(new_cosmos_error_kind(
961                                            "Could not get filename for blob.",
962                                            CosmosErrorKind::BlobError,
963                                        ));
964                                    }
965                                }
966                                None => {
967                                    return Err(new_cosmos_error_kind(
968                                        "Could not get filename for blob.",
969                                        CosmosErrorKind::BlobError,
970                                    ));
971                                }
972                            }
973                        } else {
974                            return Err(new_cosmos_error_kind(
975                                format!(
976                                    "Blob does not have a {} content-type.",
977                                    expected_content_type
978                                ),
979                                CosmosErrorKind::BlobError,
980                            ));
981                        }
982                    }
983                }
984            }
985            Err(err) => {
986                return Err(new_cosmos_error_kind(
987                    format!("Error getting multipart data {:?}", err),
988                    CosmosErrorKind::BlobError,
989                ));
990            }
991        }
992    }
993    Err(new_cosmos_error_kind(
994        "No blob provided",
995        CosmosErrorKind::BlobError,
996    ))
997}
998
999/// Uploads a new form data image to the blob storage and returns the image_id
1000pub async fn upload_image(f: FormData) -> Result<String, CosmosError> {
1001    let state = COSMOS_STATE
1002        .get()
1003        .ok_or_else(|| new_cosmos_error_internal("Cosmos state not initialized"))?;
1004    if state.image_storage_container.is_none() {
1005        return Err(new_cosmos_error_internal(
1006            "Cosmos image storage container is not initialized",
1007        ));
1008    }
1009    upload_blob(
1010        f,
1011        "image",
1012        "image",
1013        state.image_storage_container.as_ref().unwrap(),
1014    )
1015    .await
1016}
1017
1018/// The default value for amount of retry loops we do
1019const MAX_RETRY_LOOPS: usize = 4;
1020/// `retry_loop` is utilized in order to combat transient errors. A closure which generates a future
1021/// is used to generate the same future and run this untill completion, if the future returns success
1022/// with r then Ok(r) is returned. If the future returns an error then exponential backoff is tried
1023/// until eventually a max amount of tries is reached at which point the function returns the
1024/// latest error. The function runs at least once.
1025/// See here for microsoft documentation https://docs.microsoft.com/en-us/azure/azure-sql/database/troubleshoot-common-connectivity-issues#principles-for-retry
1026const RETRY_MAX_RANDOM: u64 = 5000;
1027const RETRY_START_WAIT: u64 = 5000;
1028pub enum RetryLoopError<E> {
1029    Permanent(E),
1030    Transient(E),
1031}
1032pub async fn retry_loop<F, A, R, E>(tries: usize, mut f: F) -> Result<R, E>
1033where
1034    F: FnMut() -> A,
1035    A: std::future::Future<Output = Result<R, RetryLoopError<E>>>,
1036{
1037    let mut counter = 0;
1038    let mut wait = RETRY_START_WAIT;
1039    loop {
1040        match f().await {
1041            Ok(r) => {
1042                return Ok(r);
1043            }
1044            Err(e) => {
1045                match e {
1046                    RetryLoopError::Permanent(e) => return Err(e),
1047                    RetryLoopError::Transient(e) => {
1048                        counter += 1;
1049                        let random_wait: u64 = rand::random::<u64>() % RETRY_MAX_RANDOM;
1050                        // Wait for an exponential backoff time + some random wait
1051                        sleep(Duration::from_millis(wait) + Duration::from_millis(random_wait))
1052                            .await;
1053                        if counter >= tries {
1054                            return Err(e);
1055                        }
1056                        //Exponential backoff
1057                        wait *= 2;
1058                    }
1059                }
1060            }
1061        }
1062    }
1063}
1064
#[cfg(test)]
mod util_tests {
    use super::*;
    use std::time::Instant;

    /// Extra milliseconds tolerated above the expected backoff, to absorb timer and scheduler
    /// latency.
    const TIMING_SLACK_MS: u128 = 500;

    /// Checks that transient errors are retried until success, and that the pause before retry
    /// `k` (0-based) is `RETRY_START_WAIT * 2^k` ms plus a jitter below `RETRY_MAX_RANDOM` ms.
    #[tokio::test]
    #[ignore]
    // Ignored since the backoff starts at RETRY_START_WAIT ms, so three retries take 35 s or more
    async fn retry_loop_test() {
        let calls = std::cell::RefCell::new(vec![]);
        retry_loop(8, || async {
            calls.borrow_mut().push(Instant::now());
            if calls.borrow().len() >= 4 {
                Ok(())
            } else {
                Err(RetryLoopError::Transient(()))
            }
        })
        .await
        .unwrap();
        let calls = calls.borrow();
        assert_eq!(calls.len(), 4);
        let mut expected_wait = u128::from(RETRY_START_WAIT);
        for pair in calls.windows(2) {
            let gap = pair[1].duration_since(pair[0]).as_millis();
            let max_gap = expected_wait + u128::from(RETRY_MAX_RANDOM) + TIMING_SLACK_MS;
            assert!(
                (expected_wait..=max_gap).contains(&gap),
                "retry gap of {} ms outside [{}, {}] ms",
                gap,
                expected_wait,
                max_gap
            );
            expected_wait *= 2;
        }
    }

    /// Checks that an always-transient error is attempted exactly `tries` times and then
    /// returned as an error.
    #[tokio::test]
    #[ignore]
    // Ignored since it requires quite a bit of time to retry several times
    async fn retry_loop_failure() {
        let calls = std::cell::RefCell::new(vec![]);
        let result = retry_loop(3, || async {
            calls.borrow_mut().push(Instant::now());
            Err::<(), _>(RetryLoopError::Transient(()))
        })
        .await;
        assert!(result.is_err());
        assert_eq!(calls.borrow().len(), 3);
    }
}