mongocrypt/
ctx.rs

1use std::{borrow::Borrow, ffi::CStr, marker::PhantomData, ptr};
2
3use crate::bson::{rawdoc, Document, RawDocument};
4use mongocrypt_sys as sys;
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    binary::{Binary, BinaryBuf, BinaryRef},
9    convert::{doc_binary, rawdoc_view, str_bytes_len},
10    error::{HasStatus, Result},
11    native::OwnedPtr,
12};
13
14pub struct CtxBuilder {
15    inner: OwnedPtr<sys::mongocrypt_ctx_t>,
16}
17
18impl HasStatus for CtxBuilder {
19    unsafe fn native_status(&self, status: *mut sys::mongocrypt_status_t) {
20        sys::mongocrypt_ctx_status(*self.inner.borrow(), status);
21    }
22}
23
24impl CtxBuilder {
25    /// Takes ownership of the given pointer, and will destroy it on drop.
26    pub(crate) fn steal(inner: *mut sys::mongocrypt_ctx_t) -> Self {
27        Self {
28            inner: OwnedPtr::steal(inner, sys::mongocrypt_ctx_destroy),
29        }
30    }
31
32    /// Set the key id to use for explicit encryption.
33    ///
34    /// It is an error to set both this and the key alt name.
35    ///
36    /// * `key_id` - The binary corresponding to the _id (a UUID) of the data
37    /// key to use from the key vault collection. Note, the UUID must be encoded with
38    /// RFC-4122 byte order.
39    pub fn key_id(self, key_id: &[u8]) -> Result<Self> {
40        let bin = BinaryRef::new(key_id);
41        unsafe {
42            if !sys::mongocrypt_ctx_setopt_key_id(*self.inner.borrow(), *bin.native()) {
43                return Err(self.status().as_error());
44            }
45        }
46        Ok(self)
47    }
48
49    /// Set the keyAltName to use for explicit encryption or
50    /// data key creation.
51    ///
52    /// For explicit encryption, it is an error to set both the keyAltName
53    /// and the key id.
54    ///
55    /// For creating data keys, call this function repeatedly to set
56    /// multiple keyAltNames.
57    pub fn key_alt_name(self, key_alt_name: &str) -> Result<Self> {
58        let mut bin: BinaryBuf = rawdoc! { "keyAltName": key_alt_name }.into();
59        unsafe {
60            if !sys::mongocrypt_ctx_setopt_key_alt_name(*self.inner.borrow(), *bin.native()) {
61                return Err(self.status().as_error());
62            }
63        }
64        Ok(self)
65    }
66
67    /// Set the keyMaterial to use for encrypting data.
68    ///
69    /// * `key_material` - The data encryption key to use.
70    pub fn key_material(self, key_material: &[u8]) -> Result<Self> {
71        let bson_bin = crate::bson::Binary {
72            subtype: crate::bson::spec::BinarySubtype::Generic,
73            bytes: key_material.to_vec(),
74        };
75        let mut bin: BinaryBuf = rawdoc! { "keyMaterial": bson_bin }.into();
76        unsafe {
77            if !sys::mongocrypt_ctx_setopt_key_material(*self.inner.borrow(), *bin.native()) {
78                return Err(self.status().as_error());
79            }
80        }
81        Ok(self)
82    }
83
84    /// Set the algorithm used for encryption to either
85    /// deterministic or random encryption. This value
86    /// should only be set when using explicit encryption.
87    pub fn algorithm(self, algorithm: Algorithm) -> Result<Self> {
88        unsafe {
89            if !sys::mongocrypt_ctx_setopt_algorithm(
90                *self.inner.borrow(),
91                algorithm.c_str().as_ptr(),
92                -1,
93            ) {
94                return Err(self.status().as_error());
95            }
96        }
97        Ok(self)
98    }
99
100    /// Identify the AWS KMS master key to use for creating a data key.
101    ///
102    /// This has been superseded by the more flexible `key_encryption_key`.
103    ///
104    /// * `region` - The AWS region.
105    /// * `cmk` - The Amazon Resource Name (ARN) of the customer master key (CMK).
106    #[cfg(test)]
107    pub(crate) fn masterkey_aws(self, region: &str, cmk: &str) -> Result<Self> {
108        let (region_bytes, region_len) = str_bytes_len(region)?;
109        let (cmk_bytes, cmk_len) = str_bytes_len(cmk)?;
110        unsafe {
111            if !sys::mongocrypt_ctx_setopt_masterkey_aws(
112                *self.inner.borrow(),
113                region_bytes,
114                region_len,
115                cmk_bytes,
116                cmk_len,
117            ) {
118                return Err(self.status().as_error());
119            }
120        }
121        Ok(self)
122    }
123
124    /// Identify a custom AWS endpoint when creating a data key.
125    /// This is used internally to construct the correct HTTP request
126    /// (with the Host header set to this endpoint). This endpoint
127    /// is persisted in the new data key, and will be returned via
128    /// `KmsCtx::endpoint`.
129    ///
130    /// This has been superseded by the more flexible `key_encryption_key`.
131    #[cfg(test)]
132    pub(crate) fn masterkey_aws_endpoint(self, endpoint: &str) -> Result<Self> {
133        let (bytes, len) = str_bytes_len(endpoint)?;
134        unsafe {
135            if !sys::mongocrypt_ctx_setopt_masterkey_aws_endpoint(*self.inner.borrow(), bytes, len)
136            {
137                return Err(self.status().as_error());
138            }
139        }
140        Ok(self)
141    }
142
143    /// Set key encryption key document for creating a data key or for rewrapping
144    /// datakeys.
145    ///
146    /// The following forms are accepted:
147    ///
148    /// AWS
149    /// {
150    ///    provider: "aws",
151    ///    region: <string>,
152    ///    key: <string>,
153    ///    endpoint: <optional string>
154    /// }
155    ///
156    /// Azure
157    /// {
158    ///    provider: "azure",
159    ///    keyVaultEndpoint: <string>,
160    ///    keyName: <string>,
161    ///    keyVersion: <optional string>
162    /// }
163    ///
164    /// GCP
165    /// {
166    ///    provider: "gcp",
167    ///    projectId: <string>,
168    ///    location: <string>,
169    ///    keyRing: <string>,
170    ///    keyName: <string>,
171    ///    keyVersion: <string>,
172    ///    endpoint: <optional string>
173    /// }
174    ///
175    /// Local
176    /// {
177    ///    provider: "local"
178    /// }
179    ///
180    /// KMIP
181    /// {
182    ///    provider: "kmip",
183    ///    keyId: <optional string>
184    ///    endpoint: <string>
185    /// }
186    pub fn key_encryption_key(self, key_encryption_key: &Document) -> Result<Self> {
187        let mut bin = doc_binary(key_encryption_key)?;
188        unsafe {
189            if !sys::mongocrypt_ctx_setopt_key_encryption_key(*self.inner.borrow(), *bin.native()) {
190                return Err(self.status().as_error());
191            }
192            Ok(self)
193        }
194    }
195
196    /// Set the contention factor used for explicit encryption.
197    /// The contention factor is only used for indexed Queryable Encryption.
198    pub fn contention_factor(self, contention_factor: i64) -> Result<Self> {
199        unsafe {
200            if !sys::mongocrypt_ctx_setopt_contention_factor(
201                *self.inner.borrow(),
202                contention_factor,
203            ) {
204                return Err(self.status().as_error());
205            }
206        }
207        Ok(self)
208    }
209
210    /// Set the index key id to use for explicit Queryable Encryption.
211    ///
212    /// If the index key id not set, the key id from `key_id` is used.
213    ///
214    /// * `key_id` - The _id (a UUID) of the data key to use from the key vault collection.
215    pub fn index_key_id(self, key_id: &crate::bson::Uuid) -> Result<Self> {
216        let bytes = key_id.bytes();
217        let bin = BinaryRef::new(&bytes);
218        unsafe {
219            if !sys::mongocrypt_ctx_setopt_index_key_id(*self.inner.borrow(), *bin.native()) {
220                return Err(self.status().as_error());
221            }
222        }
223        Ok(self)
224    }
225
226    /// Set the query type to use for explicit Queryable Encryption.
227    pub fn query_type(self, query_type: &str) -> Result<Self> {
228        let (s, len) = str_bytes_len(query_type)?;
229        unsafe {
230            if !sys::mongocrypt_ctx_setopt_query_type(*self.inner.borrow(), s, len) {
231                return Err(self.status().as_error());
232            }
233        }
234        Ok(self)
235    }
236
237    /// Set options for explicit encryption with [`Algorithm::Range`].
238    ///
239    /// `options` is a document of the form:
240    /// {
241    ///    "min": Optional<BSON value>,
242    ///    "max": Optional<BSON value>,
243    ///    "sparsity": Int64,
244    ///    "precision": Optional<Int32>,
245    ///    "trimFactor": Optional<Int32>
246    /// }
247    pub fn algorithm_range(self, options: Document) -> Result<Self> {
248        let mut bin = doc_binary(&options)?;
249        unsafe {
250            if !sys::mongocrypt_ctx_setopt_algorithm_range(*self.inner.borrow(), *bin.native()) {
251                return Err(self.status().as_error());
252            }
253        }
254        Ok(self)
255    }
256
257    /// Set options for explicit encryption with [`Algorithm::TextPreview`].
258    pub fn algorithm_text(self, options: Document) -> Result<Self> {
259        let mut bin = doc_binary(&options)?;
260        unsafe {
261            if !sys::mongocrypt_ctx_setopt_algorithm_text(*self.inner.borrow(), *bin.native()) {
262                return Err(self.status().as_error());
263            }
264        }
265        Ok(self)
266    }
267
268    fn into_ctx(self) -> Ctx {
269        Ctx { inner: self.inner }
270    }
271
272    /// Initialize a context to create a data key.
273    pub fn build_datakey(self) -> Result<Ctx> {
274        unsafe {
275            if !sys::mongocrypt_ctx_datakey_init(*self.inner.borrow()) {
276                return Err(self.status().as_error());
277            }
278        }
279        Ok(self.into_ctx())
280    }
281
282    /// Initialize a context for encryption.
283    ///
284    /// * `db` - The database name.
285    /// * `cmd` - The BSON command to be encrypted.
286    pub fn build_encrypt(self, db: &str, cmd: &RawDocument) -> Result<Ctx> {
287        let (db_bytes, db_len) = str_bytes_len(db)?;
288        let cmd_bin = BinaryRef::new(cmd.as_bytes());
289        unsafe {
290            if !sys::mongocrypt_ctx_encrypt_init(
291                *self.inner.borrow(),
292                db_bytes,
293                db_len,
294                *cmd_bin.native(),
295            ) {
296                return Err(self.status().as_error());
297            }
298        }
299        Ok(self.into_ctx())
300    }
301
302    /// Explicit helper method to encrypt a single BSON object. Contexts
303    /// created for explicit encryption will not go through mongocryptd.
304    ///
305    /// To specify a key_id, algorithm, or iv to use, please use the
306    /// corresponding methods before calling this.
307    ///
308    /// An error is returned if FLE 1 and Queryable Encryption incompatible options
309    /// are set.
310    ///
311    /// * `value` - the plaintext BSON value.
312    pub fn build_explicit_encrypt(self, value: crate::bson::RawBson) -> Result<Ctx> {
313        let mut bin: BinaryBuf = rawdoc! { "v": value }.into();
314        unsafe {
315            if !sys::mongocrypt_ctx_explicit_encrypt_init(*self.inner.borrow(), *bin.native()) {
316                return Err(self.status().as_error());
317            }
318        }
319        Ok(self.into_ctx())
320    }
321
322    /// Explicit helper method to encrypt a Match Expression or Aggregate Expression.
323    /// Contexts created for explicit encryption will not go through mongocryptd.
324    /// Requires query_type to be "range" or "rangePreview".
325    ///
326    /// NOTE: "rangePreview" is experimental only and is not intended for public use.
327    /// API for "rangePreview" may be removed in a future release.
328    ///
329    /// This method expects the passed-in BSON to be one of these forms:
330    ///
331    /// 1. A Match Expression of this form:
332    ///    {$and: [{<field>: {<op>: <value1>, {<field>: {<op>: <value2> }}]}
333    /// 2. An Aggregate Expression of this form:
334    ///    {$and: [{<op>: [<fieldpath>, <value1>]}, {<op>: [<fieldpath>, <value2>]}]
335    ///
336    /// <op> may be $lt, $lte, $gt, or $gte.
337    ///
338    /// The value of "v" is expected to be the BSON value passed to a driver
339    /// ClientEncryption.encryptExpression helper.
340    ///
341    /// Associated options for FLE 1:
342    /// - [CtxBuilder::key_id]
343    /// - [CtxBuilder::key_alt_name]
344    /// - [CtxBuilder::algorithm]
345    ///
346    /// Associated options for Queryable Encryption:
347    /// - [CtxBuilder::key_id]
348    /// - [CtxBuilder::index_key_id]
349    /// - [CtxBuilder::contention_factor]
350    /// - [CtxBuilder::query_type]
351    /// - [CtxBuilder::range_options]
352    ///
353    /// An error is returned if FLE 1 and Queryable Encryption incompatible options
354    /// are set.
355    pub fn build_explicit_encrypt_expression(
356        self,
357        value: crate::bson::RawDocumentBuf,
358    ) -> Result<Ctx> {
359        let mut bin: BinaryBuf = rawdoc! { "v": value }.into();
360        unsafe {
361            if !sys::mongocrypt_ctx_explicit_encrypt_expression_init(
362                *self.inner.borrow(),
363                *bin.native(),
364            ) {
365                return Err(self.status().as_error());
366            }
367        }
368        Ok(self.into_ctx())
369    }
370
371    /// Initialize a context for decryption.
372    ///
373    /// * `doc` - The document to be decrypted.
374    pub fn build_decrypt(self, doc: &RawDocument) -> Result<Ctx> {
375        let bin = BinaryRef::new(doc.as_bytes());
376        unsafe {
377            if !sys::mongocrypt_ctx_decrypt_init(*self.inner.borrow(), *bin.native()) {
378                return Err(self.status().as_error());
379            }
380        }
381        Ok(self.into_ctx())
382    }
383
384    /// Explicit helper method to decrypt a single BSON object.
385    ///
386    /// * `msg` - the encrypted BSON.
387    pub fn build_explicit_decrypt(self, msg: &[u8]) -> Result<Ctx> {
388        let bson_bin = crate::bson::Binary {
389            subtype: crate::bson::spec::BinarySubtype::Encrypted,
390            bytes: msg.into(),
391        };
392        let mut bin: BinaryBuf = rawdoc! { "v": bson_bin }.into();
393        unsafe {
394            if !sys::mongocrypt_ctx_explicit_decrypt_init(*self.inner.borrow(), *bin.native()) {
395                return Err(self.status().as_error());
396            }
397        }
398        Ok(self.into_ctx())
399    }
400
401    /// Initialize a context to rewrap datakeys.
402    ///
403    /// * `filter` - The filter to use for the find command on the key vault
404    /// collection to retrieve datakeys to rewrap.
405    pub fn build_rewrap_many_datakey(self, filter: &RawDocument) -> Result<Ctx> {
406        let bin = BinaryRef::new(filter.as_bytes());
407        unsafe {
408            if !sys::mongocrypt_ctx_rewrap_many_datakey_init(*self.inner.borrow(), *bin.native()) {
409                return Err(self.status().as_error());
410            }
411        }
412        Ok(self.into_ctx())
413    }
414}
415
/// The algorithm used for explicit encryption, set with [`CtxBuilder::algorithm`].
///
/// Each variant maps to the algorithm name string passed to libmongocrypt.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[non_exhaustive]
pub enum Algorithm {
    /// `AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic`.
    Deterministic,
    /// `AEAD_AES_256_CBC_HMAC_SHA_512-Random`.
    Random,
    /// `Indexed`.
    Indexed,
    /// `Unindexed`.
    Unindexed,
    /// `RangePreview` (deprecated).
    #[deprecated]
    RangePreview,
    /// `Range`; its options are set with [`CtxBuilder::algorithm_range`].
    Range,
    /// `TextPreview`; its options are set with [`CtxBuilder::algorithm_text`].
    TextPreview,
}

impl Algorithm {
    /// Returns the NUL-terminated algorithm name passed to libmongocrypt.
    fn c_str(&self) -> &'static CStr {
        let bytes: &'static [u8] = match self {
            Self::Deterministic => b"AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic\0",
            Self::Random => b"AEAD_AES_256_CBC_HMAC_SHA_512-Random\0",
            Self::Indexed => b"Indexed\0",
            Self::Unindexed => b"Unindexed\0",
            #[allow(deprecated)]
            Self::RangePreview => b"RangePreview\0",
            Self::Range => b"Range\0",
            Self::TextPreview => b"TextPreview\0",
        };
        // SAFETY: every literal above ends with exactly one `\0` and contains no
        // interior NUL bytes, which is the precondition of
        // `from_bytes_with_nul_unchecked`.
        unsafe { CStr::from_bytes_with_nul_unchecked(bytes) }
    }
}
444
445pub struct Ctx {
446    inner: OwnedPtr<sys::mongocrypt_ctx_t>,
447}
448
449// Functions on `mongocrypt_ctx_t` are not threadsafe but do not rely on any thread-local state, so `Ctx` is `Send` but not `Sync`.
450unsafe impl Send for Ctx {}
451
452impl HasStatus for Ctx {
453    unsafe fn native_status(&self, status: *mut sys::mongocrypt_status_t) {
454        sys::mongocrypt_ctx_status(*self.inner.borrow(), status);
455    }
456}
457
458/// Manages the state machine for encryption or decryption.
459impl Ctx {
460    /// Get the current state of a context.
461    pub fn state(&self) -> Result<State> {
462        let s = unsafe { sys::mongocrypt_ctx_state(*self.inner.borrow()) };
463        if s == sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_ERROR {
464            return Err(self.status().as_error());
465        }
466        Ok(State::from_native(s))
467    }
468
469    /// Get BSON necessary to run the mongo operation when in `State::NeedMongo*` states.
470    ///
471    /// The returned value:
472    /// * for `State::NeedMongoCollinfo[WithDb]`it is a listCollections filter.
473    /// * for `State::NeedMongoKeys` it is a find filter.
474    /// * for `State::NeedMongoMarkings` it is a command to send to mongocryptd.
475    pub fn mongo_op(&self) -> Result<&RawDocument> {
476        // Safety: `mongocrypt_ctx_mongo_op` updates the passed-in `Binary` to point to a chunk of
477        // BSON with the same lifetime as the underlying `Ctx`.  The `Binary` itself does not own
478        // the memory, and gets cleaned up at the end of the unsafe block.  Lifetime inference on
479        // the return type binds `op_bytes` to the same lifetime as `&self`, which is the correct
480        // one.
481        let op_bytes = unsafe {
482            let bin = Binary::new();
483            if !sys::mongocrypt_ctx_mongo_op(*self.inner.borrow(), *bin.native()) {
484                return Err(self.status().as_error());
485            }
486            bin.bytes()?
487        };
488        rawdoc_view(op_bytes)
489    }
490
491    /// Get the database to run the mongo operation.
492    ///
493    /// Only applies for [`State::NeedMongoCollinfoWithDb`].
494    pub fn mongo_db(&self) -> Result<&str> {
495        let cptr = unsafe { sys::mongocrypt_ctx_mongo_db(*self.inner.borrow()) };
496        if cptr.is_null() {
497            return Err(self.status().as_error());
498        }
499        // Lifetime safety: the returned cstr is valid for the lifetime of the underlying `Ctx`.
500        let cstr = unsafe { CStr::from_ptr(cptr) };
501        Ok(cstr.to_str()?)
502    }
503
504    /// Feed a BSON reply or result when this context is in
505    /// `State::NeedMongo*` states. This may be called multiple times
506    /// depending on the operation.
507    ///
508    /// `reply` is a BSON document result being fed back for this operation.
509    /// - For `State::NeedMongoCollinfo[WithDb]` it is a doc from a listCollections
510    /// cursor. (Note, if listCollections returned no result, do not call this
511    /// function.)
512    /// - For `State::NeedMongoKeys` it is a doc from a find cursor.
513    ///   (Note, if find returned no results, do not call this function.)
514    /// - For `State::NeedMongoMarkings` it is a reply from mongocryptd.
515    pub fn mongo_feed(&mut self, reply: &RawDocument) -> Result<()> {
516        let bin = BinaryRef::new(reply.as_bytes());
517        unsafe {
518            if !sys::mongocrypt_ctx_mongo_feed(*self.inner.borrow(), *bin.native()) {
519                return Err(self.status().as_error());
520            }
521        }
522        Ok(())
523    }
524
525    /// Call when done feeding the reply (or replies) back to the context.
526    pub fn mongo_done(&mut self) -> Result<()> {
527        unsafe {
528            if !sys::mongocrypt_ctx_mongo_done(*self.inner.borrow()) {
529                return Err(self.status().as_error());
530            }
531        }
532        Ok(())
533    }
534
535    /// Create a scope guard that provides handles to pending KMS requests.
536    pub fn kms_scope(&mut self) -> KmsScope<'_> {
537        KmsScope { ctx: self }
538    }
539
540    /// Call in response to the `State::NeedKmsCredentials` state
541    /// to set per-context KMS provider settings. These follow the same format
542    /// as `CryptBuilder::kms_providers`. If no keys are present in the
543    /// BSON input, the KMS provider settings configured for the `Crypt`
544    /// at initialization are used.
545    pub fn provide_kms_providers(&mut self, kms_providers_definition: &RawDocument) -> Result<()> {
546        let bin = BinaryRef::new(kms_providers_definition.as_bytes());
547        unsafe {
548            if !sys::mongocrypt_ctx_provide_kms_providers(*self.inner.borrow(), *bin.native()) {
549                return Err(self.status().as_error());
550            }
551        }
552        Ok(())
553    }
554
555    /// Perform the final encryption or decryption.
556    ///
557    /// If this context was initialized with `CtxBuilder::build_encrypt`, then
558    /// this BSON is the (possibly) encrypted command to send to the server.
559    ///
560    /// If this context was initialized with `CtxBuilder::build_decrypt`, then
561    /// this BSON is the decrypted result to return to the user.
562    ///
563    /// If this context was initialized with `CtxBuilder::build_explicit_encrypt`,
564    /// then this BSON has the form { "v": (BSON binary) } where the BSON binary
565    /// is the resulting encrypted value.
566    ///
567    /// If this context was initialized with `CtxBuilder::build_explicit_decrypt`,
568    /// then this BSON has the form { "v": (BSON value) } where the BSON value
569    /// is the resulting decrypted value.
570    ///
571    /// If this context was initialized with `CtxBuilder::build_datakey`, then
572    /// this BSON is the document containing the new data key to be inserted into
573    /// the key vault collection.
574    ///
575    /// If this context was initialized with `CtxBuilder::build_rewrap_many_datakey`,
576    /// then this BSON has the form { "v": [(BSON document), ...] } where each BSON
577    /// document in the array is a document containing a rewrapped datakey to be
578    /// bulk-updated into the key vault collection.
579    pub fn finalize(&mut self) -> Result<&RawDocument> {
580        let bytes = unsafe {
581            let bin = Binary::new();
582            if !sys::mongocrypt_ctx_finalize(*self.inner.borrow(), *bin.native()) {
583                return Err(self.status().as_error());
584            }
585            bin.bytes()?
586        };
587        rawdoc_view(bytes)
588    }
589}
590
591/// Indicates the state of the `Ctx`. Each state requires
592/// different handling. See [the integration
593/// guide](https://github.com/mongodb/libmongocrypt/blob/master/integrating.md#state-machine)
594/// for information on what to do for each state.
595#[derive(Debug, PartialEq, Eq, Clone, Copy)]
596#[non_exhaustive]
597pub enum State {
598    NeedMongoCollinfo,
599    NeedMongoCollinfoWithDb,
600    NeedMongoMarkings,
601    NeedMongoKeys,
602    NeedKms,
603    NeedKmsCredentials,
604    Ready,
605    Done,
606    Other(sys::mongocrypt_ctx_state_t),
607}
608
609impl State {
610    fn from_native(state: sys::mongocrypt_ctx_state_t) -> Self {
611        match state {
612            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_NEED_MONGO_COLLINFO => {
613                Self::NeedMongoCollinfo
614            }
615            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_NEED_MONGO_COLLINFO_WITH_DB => {
616                Self::NeedMongoCollinfoWithDb
617            }
618            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_NEED_MONGO_MARKINGS => {
619                Self::NeedMongoMarkings
620            }
621            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_NEED_MONGO_KEYS => Self::NeedMongoKeys,
622            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_NEED_KMS => Self::NeedKms,
623            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS => {
624                Self::NeedKmsCredentials
625            }
626            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_READY => Self::Ready,
627            sys::mongocrypt_ctx_state_t_MONGOCRYPT_CTX_DONE => Self::Done,
628            other => Self::Other(other),
629        }
630    }
631}
632
633/// A scope bounding the processing of (potentially multiple) KMS handles.
634pub struct KmsScope<'ctx> {
635    ctx: &'ctx Ctx,
636}
637
638// Handling multiple KMS requests is threadsafe, so `KmsScope` can be both `Send` and `Sync`.
639unsafe impl<'ctx> Send for KmsScope<'ctx> {}
640unsafe impl<'ctx> Sync for KmsScope<'ctx> {}
641
642// This is `Iterator`-like but does not impl that because it's encouraged for multiple `KmsCtx` to
643// be retrieved and processed in parallel, as reflected in the `&self` shared reference rather than
644// `Iterator`'s exclusive `next(&mut self)`.
645impl<'ctx> KmsScope<'ctx> {
646    /// Get the next KMS handle.
647    ///
648    /// Multiple KMS handles may be retrieved at once. Drivers may do this to fan
649    /// out multiple concurrent KMS HTTP requests. Feeding multiple KMS requests
650    /// is thread-safe.
651    ///
652    /// If KMS handles are being handled synchronously, the driver can reuse the same
653    /// TLS socket to send HTTP requests and receive responses.
654    pub fn next_kms_ctx(&self) -> Option<KmsCtx<'_>> {
655        let inner = unsafe { sys::mongocrypt_ctx_next_kms_ctx(*self.ctx.inner.borrow()) };
656        if inner.is_null() {
657            return None;
658        }
659        Some(KmsCtx {
660            inner,
661            _marker: PhantomData,
662        })
663    }
664}
665
666impl<'ctx> Drop for KmsScope<'ctx> {
667    fn drop(&mut self) {
668        unsafe {
669            // If this errors, it will show up in the next call to `ctx.status()` (or any other ctx call).
670            sys::mongocrypt_ctx_kms_done(*self.ctx.inner.borrow());
671        }
672    }
673}
674
675/// Manages a single KMS HTTP request/response.
676pub struct KmsCtx<'scope> {
677    inner: *mut sys::mongocrypt_kms_ctx_t,
678    _marker: PhantomData<&'scope mut ()>,
679}
680
681unsafe impl<'scope> Send for KmsCtx<'scope> {}
682unsafe impl<'scope> Sync for KmsCtx<'scope> {}
683
684impl<'scope> HasStatus for KmsCtx<'scope> {
685    unsafe fn native_status(&self, status: *mut sys::mongocrypt_status_t) {
686        sys::mongocrypt_kms_ctx_status(self.inner, status);
687    }
688}
689
690impl<'scope> KmsCtx<'scope> {
691    /// Get the HTTP request message for a KMS handle.
692    pub fn message(&self) -> Result<&'scope [u8]> {
693        // Safety: the message referenced has a lifetime that's valid until kms_done is called,
694        // which can't happen without ending 'scope.
695        unsafe {
696            let bin = Binary::new();
697            if !sys::mongocrypt_kms_ctx_message(self.inner, *bin.native()) {
698                return Err(self.status().as_error());
699            }
700            bin.bytes()
701        }
702    }
703
704    /// Get the hostname from which to connect over TLS.
705    ///
706    /// The endpoint consists of a hostname and port separated by a colon.
707    /// E.g. "example.com:123". A port is always present.
708    pub fn endpoint(&self) -> Result<&'scope str> {
709        let mut ptr: *const ::std::os::raw::c_char = ptr::null();
710        unsafe {
711            if !sys::mongocrypt_kms_ctx_endpoint(
712                self.inner,
713                &mut ptr as *mut *const ::std::os::raw::c_char,
714            ) {
715                return Err(self.status().as_error());
716            }
717            Ok(CStr::from_ptr(ptr).to_str()?)
718        }
719    }
720
721    /// How many microseconds to sleep before sending a request.
722    pub fn sleep_micros(&self) -> i64 {
723        unsafe { sys::mongocrypt_kms_ctx_usleep(self.inner) }
724    }
725
726    /// Whether a failed request can be retried.
727    pub fn retry_failure(&self) -> bool {
728        unsafe { sys::mongocrypt_kms_ctx_fail(self.inner) }
729    }
730
731    /// Indicates how many bytes to feed into `feed`.
732    pub fn bytes_needed(&self) -> u32 {
733        unsafe { sys::mongocrypt_kms_ctx_bytes_needed(self.inner) }
734    }
735
736    /// Feed bytes from the HTTP response.
737    ///
738    /// Feeding more bytes than what has been returned in `bytes_needed` is an error.
739    pub fn feed(&mut self, bytes: &[u8]) -> Result<()> {
740        let bin = BinaryRef::new(bytes);
741        unsafe {
742            if !sys::mongocrypt_kms_ctx_feed(self.inner, *bin.native()) {
743                return Err(self.status().as_error());
744            }
745        }
746        Ok(())
747    }
748
749    /// Get the KMS provider identifier associated with this KMS request.
750    ///
751    /// This is used to conditionally configure TLS connections based on the KMS
752    /// request. It is useful for KMIP, which authenticates with a client
753    /// certificate.
754    pub fn kms_provider(&self) -> Result<KmsProvider> {
755        let s = unsafe {
756            let ptr = sys::mongocrypt_kms_ctx_get_kms_provider(self.inner, ptr::null_mut());
757            CStr::from_ptr(ptr).to_str()?
758        };
759        Ok(KmsProvider::from_string(s))
760    }
761}
762
/// A KMS provider: a [`KmsProviderType`] plus an optional name.
///
/// Build one with the constructor matching its type (e.g. [`KmsProvider::aws`]).
/// A name can be attached with [`with_name`](KmsProvider::with_name).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KmsProvider {
    provider_type: KmsProviderType,
    name: Option<String>,
}

/// The supported KMS provider types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KmsProviderType {
    Aws,
    Azure,
    Gcp,
    Kmip,
    Local,
    Other(String),
}

impl KmsProvider {
    /// Shared constructor: a provider of the given type with no name.
    fn unnamed(provider_type: KmsProviderType) -> Self {
        Self {
            provider_type,
            name: None,
        }
    }

    /// Constructs an unnamed AWS KMS provider.
    pub fn aws() -> Self {
        Self::unnamed(KmsProviderType::Aws)
    }

    /// Constructs an unnamed Azure KMS provider.
    pub fn azure() -> Self {
        Self::unnamed(KmsProviderType::Azure)
    }

    /// Constructs an unnamed GCP KMS provider.
    pub fn gcp() -> Self {
        Self::unnamed(KmsProviderType::Gcp)
    }

    /// Constructs an unnamed local KMS provider.
    pub fn local() -> Self {
        Self::unnamed(KmsProviderType::Local)
    }

    /// Constructs an unnamed KMIP KMS provider.
    pub fn kmip() -> Self {
        Self::unnamed(KmsProviderType::Kmip)
    }

    /// Constructs an unnamed KMS provider whose type is the given custom string.
    pub fn other(other: impl Into<String>) -> Self {
        Self::unnamed(KmsProviderType::Other(other.into()))
    }

    /// Attaches `name` to this provider, which allows several providers of the
    /// same type to appear in one KMS provider list.
    pub fn with_name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// This KMS provider's type.
    pub fn provider_type(&self) -> &KmsProviderType {
        &self.provider_type
    }

    /// The name for this KMS provider, if one was set.
    pub fn name(&self) -> Option<&String> {
        self.name.as_ref()
    }

    /// Renders this provider as `"<type>"` or, when named, `"<type>:<name>"`.
    pub fn as_string(&self) -> String {
        let type_str: &str = match &self.provider_type {
            KmsProviderType::Aws => "aws",
            KmsProviderType::Azure => "azure",
            KmsProviderType::Gcp => "gcp",
            KmsProviderType::Local => "local",
            KmsProviderType::Kmip => "kmip",
            KmsProviderType::Other(custom) => custom.as_str(),
        };
        match &self.name {
            Some(name) => format!("{}:{}", type_str, name),
            None => type_str.to_owned(),
        }
    }

    /// Parses a provider from `"<type>"` or `"<type>:<name>"`. Only the first `:`
    /// separates the type from the name. An unrecognized type becomes
    /// [`KmsProviderType::Other`].
    pub fn from_string(name: &str) -> Self {
        let (type_str, provider_name) = name
            .split_once(':')
            .map_or((name, None), |(t, n)| (t, Some(n.to_owned())));
        let provider_type = match type_str {
            "aws" => KmsProviderType::Aws,
            "azure" => KmsProviderType::Azure,
            "gcp" => KmsProviderType::Gcp,
            "kmip" => KmsProviderType::Kmip,
            "local" => KmsProviderType::Local,
            custom => KmsProviderType::Other(custom.to_owned()),
        };
        Self {
            provider_type,
            name: provider_name,
        }
    }
}
889
890impl Serialize for KmsProvider {
891    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
892    where
893        S: serde::Serializer,
894    {
895        serializer.serialize_str(&self.as_string())
896    }
897}
898
899impl<'de> Deserialize<'de> for KmsProvider {
900    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
901    where
902        D: serde::Deserializer<'de>,
903    {
904        struct V;
905        impl<'de> serde::de::Visitor<'de> for V {
906            type Value = KmsProvider;
907
908            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
909                write!(formatter, "a string containing a KMS provider name")
910            }
911
912            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
913            where
914                E: serde::de::Error,
915            {
916                Ok(KmsProvider::from_string(v))
917            }
918        }
919        deserializer.deserialize_str(V)
920    }
921}