mod create_data_key;
mod encrypt;
use std::time::Duration;
use mongocrypt::{ctx::KmsProvider, Crypt};
use serde::{Deserialize, Serialize};
use typed_builder::TypedBuilder;
#[cfg(feature = "bson-3")]
use crate::bson_compat::RawBsonRefExt as _;
use crate::{
bson::{doc, spec::BinarySubtype, Binary, RawBinaryRef, RawDocumentBuf},
client::options::TlsOptions,
coll::options::CollectionOptions,
error::{Error, Result},
options::{ReadConcern, WriteConcern},
results::DeleteResult,
Client,
Collection,
Cursor,
Namespace,
};
use super::{options::KmsProviders, state_machine::CryptExecutor};
pub use super::client_builder::EncryptedClientBuilder;
pub use crate::action::csfle::encrypt::{EncryptKey, RangeOptions};
#[cfg(feature = "text-indexes-unstable")]
pub use crate::action::csfle::encrypt::{
PrefixOptions,
SubstringOptions,
SuffixOptions,
TextOptions,
};
/// Handle for explicit client-side encryption operations backed by a key vault collection.
///
/// Provides data key management (lookup, deletion, alternate-name maintenance) and explicit
/// decryption. Construct one with [`ClientEncryption::new`] or [`ClientEncryption::builder`].
pub struct ClientEncryption {
    // libmongocrypt handle used to build explicit encryption/decryption contexts.
    crypt: Crypt,
    // Executor that drives libmongocrypt contexts to completion (see `decrypt`).
    exec: CryptExecutor,
    // Key vault collection; data key documents are read and written as raw BSON.
    key_vault: Collection<RawDocumentBuf>,
}
impl ClientEncryption {
    /// Creates a `ClientEncryption` with default options.
    ///
    /// Shorthand for `ClientEncryption::builder(..).build()`.
    ///
    /// * `key_vault_client` - client used to access the key vault collection.
    /// * `key_vault_namespace` - database and collection holding data key documents.
    /// * `kms_providers` - one `(provider, configuration document, TLS options)` entry per
    ///   KMS provider.
    ///
    /// # Errors
    /// Fails under the same conditions as [`ClientEncryptionBuilder::build`].
    pub fn new(
        key_vault_client: Client,
        key_vault_namespace: Namespace,
        kms_providers: impl IntoIterator<
            Item = (KmsProvider, crate::bson::Document, Option<TlsOptions>),
        >,
    ) -> Result<Self> {
        let builder = Self::builder(key_vault_client, key_vault_namespace, kms_providers);
        builder.build()
    }

    /// Starts building a `ClientEncryption`.
    ///
    /// Optional settings such as [`ClientEncryptionBuilder::key_cache_expiration`] can be
    /// applied before calling [`ClientEncryptionBuilder::build`].
    pub fn builder(
        key_vault_client: Client,
        key_vault_namespace: Namespace,
        kms_providers: impl IntoIterator<
            Item = (KmsProvider, crate::bson::Document, Option<TlsOptions>),
        >,
    ) -> ClientEncryptionBuilder {
        // The builder stores the providers eagerly as a Vec.
        let kms_providers: Vec<_> = kms_providers.into_iter().collect();
        ClientEncryptionBuilder {
            key_cache_expiration: None,
            kms_providers,
            key_vault_namespace,
            key_vault_client,
        }
    }

    /// Deletes the data key whose `_id` equals `id` from the key vault collection.
    pub async fn delete_key(&self, id: &Binary) -> Result<DeleteResult> {
        let filter = doc! { "_id": id };
        self.key_vault.delete_one(filter).await
    }

    /// Fetches the data key whose `_id` equals `id`, or `None` if there is no such key.
    pub async fn get_key(&self, id: &Binary) -> Result<Option<RawDocumentBuf>> {
        let filter = doc! { "_id": id };
        self.key_vault.find_one(filter).await
    }

    /// Returns a cursor over every document in the key vault collection.
    pub async fn get_keys(&self) -> Result<Cursor<RawDocumentBuf>> {
        let match_all = crate::bson::Document::new();
        self.key_vault.find(match_all).await
    }

    /// Adds `key_alt_name` to the `keyAltNames` of the data key whose `_id` equals `id`.
    ///
    /// Uses `$addToSet`, so a name that is already present is not duplicated. Returns the
    /// document reported by `find_one_and_update`, or `None` if no key matched `id`.
    pub async fn add_key_alt_name(
        &self,
        id: &Binary,
        key_alt_name: &str,
    ) -> Result<Option<RawDocumentBuf>> {
        let filter = doc! { "_id": id };
        let update = doc! { "$addToSet": { "keyAltNames": key_alt_name } };
        self.key_vault.find_one_and_update(filter, update).await
    }

    /// Removes `key_alt_name` from the `keyAltNames` of the data key whose `_id` equals `id`.
    ///
    /// If `key_alt_name` was the only entry, the `keyAltNames` field is removed entirely
    /// instead of being left as an empty array. Returns the document reported by
    /// `find_one_and_update`, or `None` if no key matched `id`.
    pub async fn remove_key_alt_name(
        &self,
        id: &Binary,
        key_alt_name: &str,
    ) -> Result<Option<RawDocumentBuf>> {
        // True when `key_alt_name` is the sole alternate name on the key.
        let is_sole_name = doc! { "$eq": ["$keyAltNames", [key_alt_name]] };
        // The existing names with every occurrence of `key_alt_name` filtered out.
        let remaining_names = doc! {
            "$filter": {
                "input": "$keyAltNames",
                "cond": { "$ne": ["$$this", key_alt_name] },
            }
        };
        // Pipeline-style update so `$$REMOVE` can drop the field conditionally.
        let pipeline = vec![doc! {
            "$set": {
                "keyAltNames": {
                    "$cond": [is_sole_name, "$$REMOVE", remaining_names]
                }
            }
        }];
        self.key_vault
            .find_one_and_update(doc! { "_id": id }, pipeline)
            .await
    }

    /// Fetches a data key whose `keyAltNames` contains `key_alt_name`, or `None` if none does.
    pub async fn get_key_by_alt_name(
        &self,
        key_alt_name: impl AsRef<str>,
    ) -> Result<Option<RawDocumentBuf>> {
        let name = key_alt_name.as_ref();
        self.key_vault.find_one(doc! { "keyAltNames": name }).await
    }

    /// Decrypts an encrypted binary value using an explicit-decryption context.
    ///
    /// # Errors
    /// * An invalid-argument error if `value` is not of subtype [`BinarySubtype::Encrypted`].
    /// * An internal error if the decryption result lacks its `v` field.
    /// * Errors from building or running the libmongocrypt context are propagated.
    pub async fn decrypt(&self, value: RawBinaryRef<'_>) -> Result<crate::bson::RawBson> {
        if value.subtype != BinarySubtype::Encrypted {
            return Err(Error::invalid_argument(format!(
                "Invalid binary subtype for decrypt: expected {:?}, got {:?}",
                BinarySubtype::Encrypted,
                value.subtype
            )));
        }
        let decrypt_ctx = self
            .crypt
            .ctx_builder()
            .build_explicit_decrypt(value.bytes)?;
        let decrypted = self.exec.run_ctx(decrypt_ctx, None).await?;
        // The plaintext is returned under the `v` key of the result document.
        let plaintext = decrypted.get("v")?;
        plaintext
            .map(|v| v.to_raw_bson())
            .ok_or_else(|| Error::internal("invalid decryption result"))
    }
}
/// Builder for [`ClientEncryption`], obtained from [`ClientEncryption::builder`].
pub struct ClientEncryptionBuilder {
    // Client used to reach the key vault collection.
    key_vault_client: Client,
    // Database and collection holding data key documents.
    key_vault_namespace: Namespace,
    // Raw `(provider, configuration document, TLS options)` entries; converted with
    // `KmsProviders::new` when `build` runs.
    kms_providers: Vec<(KmsProvider, crate::bson::Document, Option<TlsOptions>)>,
    // When `Some`, forwarded to the `Crypt` builder in `build` as a millisecond count.
    key_cache_expiration: Option<Duration>,
}
impl ClientEncryptionBuilder {
    /// Sets the key cache expiration passed to libmongocrypt.
    ///
    /// Accepts a `Duration` or `None`. With `None` (the default), `build` leaves the
    /// expiration unset on the `Crypt` builder, so libmongocrypt's own default applies.
    pub fn key_cache_expiration(self, expiration: impl Into<Option<Duration>>) -> Self {
        Self {
            key_cache_expiration: expiration.into(),
            ..self
        }
    }

    /// Builds the [`ClientEncryption`].
    ///
    /// The key vault collection is accessed with majority read and write concern.
    ///
    /// # Errors
    /// Returns an error if:
    /// * the KMS provider entries are rejected;
    /// * libmongocrypt cannot be configured or built; or
    /// * the key cache expiration exceeds `u64::MAX` milliseconds.
    pub fn build(self) -> Result<ClientEncryption> {
        let Self {
            key_vault_client,
            key_vault_namespace,
            kms_providers,
            key_cache_expiration,
        } = self;

        let providers = KmsProviders::new(kms_providers)?;

        let mut crypt_builder = Crypt::builder()
            .kms_providers(&providers.credentials_doc()?)?
            .use_need_kms_credentials_state()
            .use_range_v2()?
            .retry_kms(true)?;
        if let Some(expiration) = key_cache_expiration {
            // `Duration::as_millis` yields a u128, but the Crypt builder takes a u64,
            // so reject durations that do not fit.
            let millis = u64::try_from(expiration.as_millis()).map_err(|_| {
                Error::invalid_argument(format!(
                    "key_cache_expiration must not exceed {} milliseconds, got {:?}",
                    u64::MAX,
                    expiration
                ))
            })?;
            crypt_builder = crypt_builder.key_cache_expiration(millis)?;
        }
        let crypt = crypt_builder.build()?;

        // Key vault reads and writes always use majority concern.
        let key_vault_options = CollectionOptions::builder()
            .write_concern(WriteConcern::majority())
            .read_concern(ReadConcern::majority())
            .build();
        let key_vault = key_vault_client
            .database(&key_vault_namespace.db)
            .collection_with_options(&key_vault_namespace.coll, key_vault_options);

        // The namespace is no longer needed here, so it is moved rather than cloned.
        let exec = CryptExecutor::new_explicit(
            key_vault_client.weak(),
            key_vault_namespace,
            providers,
        )?;

        Ok(ClientEncryption {
            crypt,
            exec,
            key_vault,
        })
    }
}
/// The master key protecting a data key, with one variant per KMS provider.
///
/// Serialized untagged: the payload's fields are written directly, with no variant
/// discriminator. Use [`MasterKey::provider`] to obtain the corresponding KMS provider.
///
/// NOTE(review): because of `#[serde(untagged)]`, deserialization tries the variants in
/// declaration order. `KmipMasterKey` has only optional fields and unknown fields are not
/// denied, so most maps that fail to match `Aws`, `Azure`, or `Gcp` deserialize as `Kmip`.
/// `Local` is reached only if the `Kmip` attempt fails. Every payload's `name` is
/// `#[serde(skip)]`, so named providers do not survive a serialization round trip.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
#[non_exhaustive]
// NOTE(review): `missing_docs` is silenced here without explanation. The variants now
// carry docs, so this allowance may be removable.
#[allow(missing_docs)]
pub enum MasterKey {
    /// Master key for the AWS KMS provider.
    Aws(AwsMasterKey),
    /// Master key for the Azure KMS provider.
    Azure(AzureMasterKey),
    /// Master key for the GCP KMS provider.
    Gcp(GcpMasterKey),
    /// Master key for the KMIP KMS provider.
    Kmip(KmipMasterKey),
    /// Master key for the local KMS provider.
    Local(LocalMasterKey),
}
/// Master key held by the AWS KMS provider.
///
/// Fields serialize with camelCase names, and `None` fields are omitted. Every builder
/// setter accepts `impl Into<_>`, and every field has a default.
/// NOTE(review): this means `region` and `key` default to empty strings if never set —
/// confirm callers always provide them.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct AwsMasterKey {
    /// Optional KMS provider name. When set, [`MasterKey::provider`] returns the AWS
    /// provider qualified with this name. Skipped by serde (never serialized or
    /// deserialized).
    #[serde(skip)]
    pub name: Option<String>,
    /// AWS region of the key (serialized as `region`).
    pub region: String,
    /// Identifier of the AWS KMS key (serialized as `key`).
    pub key: String,
    /// Optional custom endpoint (serialized as `endpoint`; omitted when `None`).
    pub endpoint: Option<String>,
}
impl From<AwsMasterKey> for MasterKey {
fn from(aws_master_key: AwsMasterKey) -> Self {
Self::Aws(aws_master_key)
}
}
/// Master key held by the Azure KMS provider.
///
/// Fields serialize with camelCase names, and `None` fields are omitted. Every builder
/// setter accepts `impl Into<_>`, and every field has a default.
/// NOTE(review): this means `key_vault_endpoint` and `key_name` default to empty strings
/// if never set — confirm callers always provide them.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct AzureMasterKey {
    /// Optional KMS provider name. When set, [`MasterKey::provider`] returns the Azure
    /// provider qualified with this name. Skipped by serde (never serialized or
    /// deserialized).
    #[serde(skip)]
    pub name: Option<String>,
    /// Key vault endpoint (serialized as `keyVaultEndpoint`).
    pub key_vault_endpoint: String,
    /// Name of the key (serialized as `keyName`).
    pub key_name: String,
    /// Optional key version (serialized as `keyVersion`; omitted when `None`).
    pub key_version: Option<String>,
}
impl From<AzureMasterKey> for MasterKey {
fn from(azure_master_key: AzureMasterKey) -> Self {
Self::Azure(azure_master_key)
}
}
/// Master key held by the GCP KMS provider.
///
/// Fields serialize with camelCase names, and `None` fields are omitted. Every builder
/// setter accepts `impl Into<_>`, and every field has a default.
/// NOTE(review): this means the required `String` fields default to empty strings if
/// never set — confirm callers always provide them.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct GcpMasterKey {
    /// Optional KMS provider name. When set, [`MasterKey::provider`] returns the GCP
    /// provider qualified with this name. Skipped by serde (never serialized or
    /// deserialized).
    #[serde(skip)]
    pub name: Option<String>,
    /// GCP project id (serialized as `projectId`).
    pub project_id: String,
    /// Key location (serialized as `location`).
    pub location: String,
    /// Key ring name (serialized as `keyRing`).
    pub key_ring: String,
    /// Key name (serialized as `keyName`).
    pub key_name: String,
    /// Optional key version (serialized as `keyVersion`; omitted when `None`).
    pub key_version: Option<String>,
    /// Optional custom endpoint (serialized as `endpoint`; omitted when `None`).
    pub endpoint: Option<String>,
}
impl From<GcpMasterKey> for MasterKey {
fn from(gcp_master_key: GcpMasterKey) -> Self {
Self::Gcp(gcp_master_key)
}
}
/// Master key held by the local KMS provider.
///
/// The only field, `name`, is skipped by serde, so this always serializes as an empty
/// document.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct LocalMasterKey {
    /// Optional KMS provider name. When set, [`MasterKey::provider`] returns the local
    /// provider qualified with this name. Skipped by serde (never serialized or
    /// deserialized).
    #[serde(skip)]
    pub name: Option<String>,
}
impl From<LocalMasterKey> for MasterKey {
fn from(local_master_key: LocalMasterKey) -> Self {
Self::Local(local_master_key)
}
}
/// Master key held by the KMIP KMS provider.
///
/// Fields serialize with camelCase names, and `None` fields are omitted. Every field is
/// optional, and every builder setter accepts `impl Into<_>`.
#[serde_with::skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct KmipMasterKey {
    /// Optional KMS provider name. When set, [`MasterKey::provider`] returns the KMIP
    /// provider qualified with this name. Skipped by serde (never serialized or
    /// deserialized).
    #[serde(skip)]
    pub name: Option<String>,
    /// Optional KMIP key identifier (serialized as `keyId`).
    pub key_id: Option<String>,
    /// Optional `delegated` flag (serialized as `delegated`).
    /// NOTE(review): presumably asks the KMIP server to perform the key operations
    /// itself — confirm against libmongocrypt.
    pub delegated: Option<bool>,
    /// Optional custom endpoint (serialized as `endpoint`).
    pub endpoint: Option<String>,
}
impl From<KmipMasterKey> for MasterKey {
fn from(kmip_master_key: KmipMasterKey) -> Self {
Self::Kmip(kmip_master_key)
}
}
impl MasterKey {
    /// Returns the [`KmsProvider`] this master key belongs to.
    ///
    /// If the key carries a `name`, the provider is qualified with it via
    /// `KmsProvider::with_name`. Otherwise the unnamed provider is returned.
    pub fn provider(&self) -> KmsProvider {
        let (unnamed, name) = match self {
            Self::Aws(key) => (KmsProvider::aws(), &key.name),
            Self::Azure(key) => (KmsProvider::azure(), &key.name),
            Self::Gcp(key) => (KmsProvider::gcp(), &key.name),
            Self::Kmip(key) => (KmsProvider::kmip(), &key.name),
            Self::Local(key) => (KmsProvider::local(), &key.name),
        };
        match name {
            Some(name) => unnamed.with_name(name.clone()),
            None => unnamed,
        }
    }
}