#![cfg(feature = "tokio")]
mod formats;
use formats::{format_sem_term_binary, format_sem_term_ore, format_sem_term_ore_array};
use std::borrow::Cow;
use std::sync::Arc;
use crate::{
encryption::{self, Encrypted, EncryptedEntry, EncryptedSteVecTerm, IndexTerm, QueryOp},
zerokms::{self, RecordDecryptError},
};
use crate::{
encryption::StorageBuilder,
zerokms::{GenerateKeyPayload, IndexKey},
};
use super::zerokms::EncryptedRecord;
use cipherstash_config::{column::IndexType, ColumnConfig};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
use crate::encryption::{PlaintextTarget, Queryable, ScopedCipher};
use cipherstash_config::column::Index;
use zerokms_protocol::{Context, UnverifiedContext};
/// Encrypts a batch of prepared plaintexts into EQL payloads.
///
/// The effective keyset is `opts.keyset_id` when set, otherwise the keyset of
/// `cipher`. Every plaintext is first turned into an [`EncryptionTarget`];
/// one data key is then requested for each storage target (query targets
/// need none). Output order matches the order of `plaintexts`.
///
/// * Storage targets are encrypted with the next data key and converted with
///   `to_eql_ciphertext`.
/// * Query targets only build an index term and are converted with
///   `to_eql_ciphertext_from_sem_term`, so their payload has no ciphertext.
///
/// # Errors
///
/// Returns an [`EqlError`] if target preparation, data-key generation,
/// encryption, or index-term construction fails.
///
/// # Panics
///
/// Panics if fewer data keys were returned than there are storage targets.
pub async fn encrypt_eql<'a, C>(
    cipher: Arc<ScopedCipher<C>>,
    plaintexts: Vec<PreparedPlaintext<'a>>,
    opts: &EqlEncryptOpts<'a>,
) -> Result<Vec<EqlCiphertext>, EqlError>
where
    C: Send + Sync + 'static,
    for<'b> &'b C: stack_auth::AuthStrategy,
{
    use std::collections::VecDeque;

    let effective_keyset_id = opts.keyset_id.unwrap_or(cipher.keyset_id());
    let targets: Vec<EncryptionTarget> =
        to_encryption_targets(cipher.index_key(), plaintexts, effective_keyset_id)?;

    // Keys are consumed front-to-back, one per storage target, in target order.
    let mut data_keys = VecDeque::from(
        cipher
            .generate_data_keys(
                generate_data_key_payloads(opts, &targets),
                opts.service_token.clone(),
                opts.unverified_context.clone(),
            )
            .await?,
    );

    targets
        .into_iter()
        .map(|target| -> Result<EqlCiphertext, EqlError> {
            match target {
                EncryptionTarget::ForStorage(identifier, builder) => {
                    let encrypted = builder.build_for_encryption().encrypt(
                        data_keys
                            .pop_front()
                            .expect("insufficient data keys to encrypt all plaintexts"),
                    )?;
                    to_eql_ciphertext(encrypted, &identifier)
                }
                EncryptionTarget::ForQuery(identifier, plaintext, index_type, query_op) => {
                    let index = Index::new(index_type.clone());
                    let index_term =
                        (index, plaintext).build_queryable(cipher.clone(), query_op)?;
                    to_eql_ciphertext_from_sem_term(index_term, &identifier)
                }
            }
        })
        .collect()
}
pub async fn decrypt_eql<'a, C>(
cipher: Arc<ScopedCipher<C>>,
ciphertexts: impl IntoIterator<Item = EqlCiphertext>,
opts: &EqlDecryptOpts<'a>,
) -> Result<Vec<encryption::Plaintext>, EqlError>
where
C: Send + Sync + 'static,
for<'b> &'b C: stack_auth::AuthStrategy,
{
use crate::{encryption::DecryptOptions, zerokms::WithContext};
let decrypt_opts = DecryptOptions {
keyset_id: opts.keyset_id,
service_token: opts.service_token.clone(),
unverified_context: opts.unverified_context.clone(),
};
let ciphertexts = ciphertexts
.into_iter()
.map(|eql| {
let ciphertext = eql
.body
.ciphertext
.ok_or_else(|| EqlError::MissingCiphertext(eql.identifier.clone()))?;
Ok(WithContext {
record: ciphertext,
context: opts.lock_context.clone(),
})
})
.collect::<Result<Vec<_>, EqlError>>()?;
Ok(cipher
.decrypt(ciphertexts, &decrypt_opts)
.await
.map_err(|err| convert_zerokms_error(err, cipher.keyset_id(), opts.keyset_id))?
.into_iter()
.map(|decrypted| encryption::Plaintext::from_slice(&decrypted))
.collect::<Result<Vec<_>, _>>()?)
}
pub async fn decrypt_eql_fallible<'a, C>(
cipher: Arc<ScopedCipher<C>>,
ciphertexts: impl IntoIterator<Item = EqlCiphertext>,
opts: &EqlDecryptOpts<'a>,
) -> Result<Vec<Result<encryption::Plaintext, EqlError>>, EqlError>
where
C: Send + Sync + 'static,
for<'b> &'b C: stack_auth::AuthStrategy,
{
use crate::{encryption::DecryptOptions, zerokms::WithContext};
let decrypt_opts = DecryptOptions {
keyset_id: opts.keyset_id,
service_token: opts.service_token.clone(),
unverified_context: opts.unverified_context.clone(),
};
let inputs: Vec<_> = ciphertexts.into_iter().collect();
let input_count = inputs.len();
let mut results: Vec<Option<Result<encryption::Plaintext, EqlError>>> =
(0..input_count).map(|_| None).collect();
let mut valid_payloads: Vec<(usize, WithContext<EncryptedRecord>)> =
Vec::with_capacity(input_count);
for (index, eql) in inputs.into_iter().enumerate() {
match eql.body.ciphertext {
Some(ciphertext) => {
valid_payloads.push((
index,
WithContext {
record: ciphertext,
context: opts.lock_context.clone(),
},
));
}
None => {
results[index] = Some(Err(EqlError::MissingCiphertext(eql.identifier)));
}
}
}
let (indices, payloads): (Vec<usize>, Vec<_>) = valid_payloads.into_iter().unzip();
let decrypt_results = cipher
.decrypt_fallible(payloads, &decrypt_opts)
.await
.map_err(|err| convert_zerokms_error(err, cipher.keyset_id(), opts.keyset_id))?;
for (index, decrypt_result) in indices.into_iter().zip(decrypt_results) {
results[index] = Some(match decrypt_result {
Ok(bytes) => encryption::Plaintext::from_slice(&bytes).map_err(Into::into),
Err(err) => Err(EqlError::from(err)),
});
}
Ok(results
.into_iter()
.map(|r| r.expect("all result slots filled"))
.collect())
}
/// Version number written into the `version` field of every [`EqlCiphertext`]
/// built in this module.
pub const EQL_SCHEMA_VERSION: u16 = 2;
/// Table/column pair naming the column an EQL payload belongs to.
///
/// Serialized with the short keys `t` (table) and `c` (column).
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct Identifier {
/// Table name; serialized as `t`.
#[serde(rename = "t")]
pub table: String,
/// Column name; serialized as `c`.
#[serde(rename = "c")]
pub column: String,
}
impl Identifier {
    /// Creates an identifier from anything convertible into owned table and
    /// column names.
    pub fn new(table: impl Into<String>, column: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            column: column.into(),
        }
    }

    /// Returns the table name.
    pub fn table(&self) -> &str {
        self.table.as_str()
    }

    /// Returns the column name.
    pub fn column(&self) -> &str {
        self.column.as_str()
    }
}
/// Top-level EQL payload: column identifier, schema version, and body.
///
/// The body is flattened, so its fields are serialized alongside `i` and `v`
/// rather than under a nested key.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EqlCiphertext {
/// Column this payload belongs to; serialized as `i`.
#[serde(rename = "i")]
pub identifier: Identifier,
/// Schema version; serialized as `v`. Payloads built in this module use
/// [`EQL_SCHEMA_VERSION`].
#[serde(rename = "v")]
pub version: u16,
/// Ciphertext and searchable metadata, flattened into this object.
#[serde(flatten)]
pub body: EqlCiphertextBody,
}
/// Body of an EQL payload: optional ciphertext plus searchable metadata.
///
/// Also used for the individual entries stored in [`EqlSEM::ste_vec`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EqlCiphertextBody {
/// Encrypted record, (de)serialized under key `c` through
/// `formats::mp_base85`. Omitted from output when `None` and defaulted
/// when the key is absent on input. Query payloads built by
/// `to_eql_ciphertext_from_sem_term` leave this as `None`.
#[serde(
rename = "c",
default,
with = "formats::mp_base85",
skip_serializing_if = "Option::is_none"
)]
pub ciphertext: Option<EncryptedRecord>,
/// Searchable metadata, flattened into this object.
#[serde(flatten)]
pub sem: EqlSEM,
/// Serialized as `a`, omitted when `None`. `to_eql_ciphertext` sets it to
/// `Some(parent_is_array)` on ste_vec entries and to `None` elsewhere.
#[serde(rename = "a", skip_serializing_if = "Option::is_none")]
pub is_array_item: Option<bool>,
}
/// Searchable metadata attached to an EQL payload.
///
/// Every field is optional and is skipped during serialization when `None`.
/// Fields are populated from index terms by `EqlSEM::update` and, for
/// ste_vec entries, by `to_eql_ciphertext`.
#[derive(Clone, Debug, Deserialize, Serialize, Default)]
pub struct EqlSEM {
/// Serialized as `ob`. Set from `IndexTerm::OreFull`, `IndexTerm::OreArray`
/// and `IndexTerm::OreLeft`.
#[serde(rename = "ob", skip_serializing_if = "Option::is_none")]
pub ore_block_u64_8_256: Option<Vec<String>>,
/// Serialized as `bf`. Set from `IndexTerm::BitMap`.
#[serde(rename = "bf", skip_serializing_if = "Option::is_none")]
pub bloom_filter: Option<Vec<u16>>,
/// Serialized as `hm`. Set from `IndexTerm::Binary`.
#[serde(rename = "hm", skip_serializing_if = "Option::is_none")]
pub hmac_256: Option<String>,
/// Serialized as `s`. Hex-encoded selector bytes, set from
/// `IndexTerm::SteVecSelector` or from a ste_vec entry's tokenized selector.
#[serde(rename = "s", skip_serializing_if = "Option::is_none")]
pub selector: Option<String>,
/// Serialized as `b3`. Hex encoding of an `EncryptedSteVecTerm::Mac` term.
#[serde(rename = "b3", skip_serializing_if = "Option::is_none")]
pub blake3: Option<String>,
/// Serialized as `ocf`. Hex encoding of an `EncryptedSteVecTerm::OreFixed` term.
#[serde(rename = "ocf", skip_serializing_if = "Option::is_none")]
pub ore_cllw_u64_8: Option<String>,
/// Serialized as `ocv`. Hex encoding of an `EncryptedSteVecTerm::OreVariable` term.
#[serde(rename = "ocv", skip_serializing_if = "Option::is_none")]
pub ore_cllw_var_8: Option<String>,
/// Serialized as `sv`. One body per ste_vec entry, each carrying its own
/// ciphertext and metadata.
#[serde(rename = "sv", skip_serializing_if = "Option::is_none")]
pub ste_vec: Option<Vec<EqlCiphertextBody>>,
}
impl EqlSEM {
    /// Stores a single index term in the metadata field that corresponds to
    /// its variant, overwriting any value already held by that field.
    ///
    /// `IndexTerm::SteQueryVec` and `IndexTerm::Null` leave the metadata
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics on `IndexTerm::BinaryVec`, which has no mapping here.
    fn update(&mut self, term: IndexTerm) {
        match term {
            IndexTerm::Binary(bytes) => self.hmac_256 = Some(format_sem_term_binary(&bytes)),
            IndexTerm::BitMap(bits) => self.bloom_filter = Some(bits),

            // All three ORE term shapes are stored in the same field.
            IndexTerm::OreFull(bytes) => {
                self.ore_block_u64_8_256 = Some(format_sem_term_ore(&bytes))
            }
            IndexTerm::OreLeft(bytes) => {
                self.ore_block_u64_8_256 = Some(format_sem_term_ore(&bytes))
            }
            IndexTerm::OreArray(bytes) => {
                self.ore_block_u64_8_256 = Some(format_sem_term_ore_array(&bytes))
            }

            IndexTerm::SteVecSelector(selector) => {
                self.selector = Some(hex::encode(selector.as_bytes()))
            }
            IndexTerm::SteVecTerm(EncryptedSteVecTerm::Mac(bytes)) => {
                self.blake3 = Some(hex::encode(bytes))
            }
            IndexTerm::SteVecTerm(EncryptedSteVecTerm::OreFixed(ore)) => {
                self.ore_cllw_u64_8 = Some(hex::encode(ore))
            }
            IndexTerm::SteVecTerm(EncryptedSteVecTerm::OreVariable(ore)) => {
                self.ore_cllw_var_8 = Some(hex::encode(&ore))
            }

            IndexTerm::BinaryVec(_) => unimplemented!("this is not used by EQL"),
            IndexTerm::SteQueryVec(_) | IndexTerm::Null => {}
        }
    }
}
/// Wraps a single index term in an EQL payload that carries no ciphertext.
///
/// The payload's metadata starts empty and receives only `index_term`; the
/// version is [`EQL_SCHEMA_VERSION`]. The function always returns `Ok`.
fn to_eql_ciphertext_from_sem_term(
    index_term: IndexTerm,
    identifier: &Identifier,
) -> Result<EqlCiphertext, EqlError> {
    let mut body = EqlCiphertextBody {
        ciphertext: None,
        sem: EqlSEM::default(),
        is_array_item: None,
    };
    body.sem.update(index_term);

    Ok(EqlCiphertext {
        identifier: identifier.clone(),
        version: EQL_SCHEMA_VERSION,
        body,
    })
}
fn to_eql_ciphertext(
encrypted: Encrypted,
identifier: &Identifier,
) -> Result<EqlCiphertext, EqlError> {
let mut sem = EqlSEM::default();
match encrypted {
Encrypted::Record(ciphertext, terms) => {
for term in terms {
sem.update(term);
}
Ok(EqlCiphertext {
identifier: identifier.to_owned(),
version: EQL_SCHEMA_VERSION,
body: EqlCiphertextBody {
ciphertext: Some(ciphertext),
sem,
is_array_item: None,
},
})
}
Encrypted::SteVec(ste_vec) => {
let ciphertext = ste_vec.root_ciphertext()?.clone();
let ste_vec_sem: Vec<EqlCiphertextBody> = ste_vec
.into_iter()
.map(
|EncryptedEntry {
tokenized_selector,
term,
record,
parent_is_array,
}| {
let sem = match term {
EncryptedSteVecTerm::Mac(bytes) => EqlSEM {
selector: Some(hex::encode(tokenized_selector.as_bytes())),
blake3: Some(hex::encode(bytes)),
..Default::default()
},
EncryptedSteVecTerm::OreFixed(ore) => EqlSEM {
selector: Some(hex::encode(tokenized_selector.as_bytes())),
ore_cllw_u64_8: Some(hex::encode(ore)),
..Default::default()
},
EncryptedSteVecTerm::OreVariable(ore) => EqlSEM {
selector: Some(hex::encode(tokenized_selector.as_bytes())),
ore_cllw_var_8: Some(hex::encode(&ore)),
..Default::default()
},
};
EqlCiphertextBody {
ciphertext: Some(record),
sem,
is_array_item: Some(parent_is_array),
}
},
)
.collect();
sem.ste_vec = Some(ste_vec_sem);
Ok(EqlCiphertext {
identifier: identifier.to_owned(),
version: EQL_SCHEMA_VERSION,
body: EqlCiphertextBody {
ciphertext: Some(ciphertext),
sem,
is_array_item: None,
},
})
}
}
}
/// Error type returned by the EQL encrypt/decrypt helpers in this module.
///
/// Several variants are not constructed anywhere in this part of the file;
/// presumably they are raised by callers elsewhere (TODO confirm).
#[derive(Error, Debug)]
pub enum EqlError {
/// Wraps a `serde_json::Error`; the message is forwarded unchanged.
#[error(transparent)]
CiphertextCouldNotBeSerialised(#[from] serde_json::Error),
#[error("Encrypted column could not be parsed")]
ColumnCouldNotBeParsed,
#[error("Encrypted column is null")]
ColumnIsNull,
#[error("Column '{column}' in table '{table}' could not be deserialised")]
ColumnCouldNotBeDeserialised { table: String, column: String },
#[error("Column '{column}' in table '{table}' could not be encrypted")]
ColumnCouldNotBeEncrypted { table: String, column: String },
#[error("Column configuration for column '{column}' in table '{table}' does not match the encrypted column")]
ColumnConfigurationMismatch { table: String, column: String },
/// Produced by `convert_zerokms_error` when a decrypt error's message
/// contains "Failed to retrieve key"; `keyset_id` is the override if one
/// was given, otherwise the cipher's keyset.
#[error("Could not decrypt data using keyset '{keyset_id}'")]
CouldNotDecryptDataForKeyset { keyset_id: String },
#[error("InvalidIndexTerm")]
InvalidIndexTerm,
/// An EQL payload passed to a decrypt function had no ciphertext.
#[error("EQL payload for column '{}' in table '{}' is missing ciphertext", _0.column(), _0.table())]
MissingCiphertext(Identifier),
#[error("KeysetId `{id}` could not be parsed using `SET CIPHERSTASH.KEYSET_ID`. KeysetId should be a valid UUID")]
KeysetIdCouldNotBeParsed { id: String },
#[error("Keyset Id could not be set using `SET CIPHERSTASH.KEYSET_ID`")]
KeysetIdCouldNotBeSet,
#[error("Keyset Name could not be set using `SET CIPHERSTASH.KEYSET_NAME`")]
KeysetNameCouldNotBeSet,
#[error("Missing encrypt configuration for column type `{plaintext_type}`")]
MissingEncryptConfiguration { plaintext_type: &'static str },
#[error("Decrypted column could not be encoded as the expected type")]
PlaintextCouldNotBeEncoded,
/// Wraps an `encryption::EncryptionError`; the message is forwarded unchanged.
#[error(transparent)]
Pipeline(#[from] encryption::EncryptionError),
/// Wraps an `encryption::TypeParseError`; the message is forwarded unchanged.
#[error(transparent)]
PlaintextCouldNotBeDecoded(#[from] encryption::TypeParseError),
// NOTE(review): "identifer" is misspelled in the message below. It is left
// as-is here because callers or tests may match on the exact text.
#[error("Missing keyset identifer")]
MissingKeysetIdentifier,
#[error("Cannot SET CIPHERSTASH.KEYSET if a default keyset has been configured")]
UnexpectedSetKeyset,
#[error("Column '{column}' in table '{table}' has no Encrypt configuration")]
UnknownColumn { table: String, column: String },
#[error("Unknown keyset name or id '{keyset}'. Check the configured credentials")]
UnknownKeysetIdentifier { keyset: String },
#[error("Table '{table}' has no Encrypt configuration")]
UnknownTable { table: String },
#[error("Unknown Index Term for column '{}' in table '{}'", _0.column(), _0.table())]
UnknownIndexTerm(Identifier),
/// Wraps a `zerokms::Error`, including those that `convert_zerokms_error`
/// passes through unchanged.
#[error("ZeroKMS error '{}'", _0)]
ZeroKMS(#[from] zerokms::Error),
/// Wraps a `RecordDecryptError`.
#[error("Record decryption error '{}'", _0)]
RecordDecrypt(#[from] RecordDecryptError),
}
/// What a prepared plaintext is being encrypted for.
#[derive(Debug)]
pub enum EqlOperation<'a> {
/// Encrypt for storage; `to_encryption_targets` maps this to
/// `EncryptionTarget::ForStorage`.
Store,
/// Build a query term for the given index type and query operation;
/// `to_encryption_targets` maps this to `EncryptionTarget::ForQuery`.
Query(&'a IndexType, QueryOp),
}
/// A plaintext bundled with everything `encrypt_eql` needs to process it.
pub struct PreparedPlaintext<'a> {
// Column the resulting payload is attributed to.
identifier: Identifier,
// Value to encrypt or to build a query term from.
plaintext: encryption::Plaintext,
// Whether this value is stored or used in a query.
eql_op: EqlOperation<'a>,
// Column configuration; only read for `EqlOperation::Store` in
// `to_encryption_targets`.
column_config: Cow<'a, ColumnConfig>,
}
impl<'a> PreparedPlaintext<'a> {
pub fn new(
column_config: Cow<'a, ColumnConfig>,
identifier: Identifier,
plaintext: encryption::Plaintext,
eql_op: EqlOperation<'a>,
) -> Self {
Self {
identifier,
plaintext,
eql_op,
column_config,
}
}
}
/// Intermediate form of a prepared plaintext, produced by
/// `to_encryption_targets` and consumed by `encrypt_eql`.
enum EncryptionTarget<'a> {
/// A value to encrypt for storage; consumes one data key in `encrypt_eql`.
ForStorage(Identifier, StorageBuilder<'a, encryption::Plaintext>),
/// A value to turn into a query index term; needs no data key.
ForQuery(Identifier, encryption::Plaintext, &'a IndexType, QueryOp),
}
/// Builds one data-key request per storage target, in target order.
///
/// Each request pairs the target builder's descriptor with
/// `opts.lock_context`. Query targets are skipped because they do not
/// consume a data key.
///
/// Accepts a slice; existing `&Vec<_>` call sites coerce automatically.
fn generate_data_key_payloads<'a>(
    opts: &EqlEncryptOpts<'a>,
    targets: &'a [EncryptionTarget<'a>],
) -> Vec<GenerateKeyPayload<'a>> {
    targets
        .iter()
        .filter_map(|target| match target {
            EncryptionTarget::ForStorage(_, builder) => Some(GenerateKeyPayload::new(
                builder.descriptor(),
                opts.lock_context.clone(),
            )),
            EncryptionTarget::ForQuery(_, _, _, _) => None,
        })
        .collect()
}
/// Converts prepared plaintexts into encryption targets, preserving order.
///
/// * `EqlOperation::Store` builds an encryptable from the plaintext and its
///   column configuration using `index_key` and `effective_keyset_id`.
/// * `EqlOperation::Query` forwards the plaintext, index type, and query
///   operation untouched; the column configuration is not used.
///
/// # Errors
///
/// Returns the first `EncryptionError` raised by `build_encryptable`.
fn to_encryption_targets<'a>(
    index_key: &'a IndexKey,
    plaintexts: Vec<PreparedPlaintext<'a>>,
    effective_keyset_id: Uuid,
) -> Result<Vec<EncryptionTarget<'a>>, encryption::EncryptionError> {
    use crate::encryption::Encryptable;

    plaintexts
        .into_iter()
        .map(
            move |prepared_plaintext| -> Result<EncryptionTarget, encryption::EncryptionError> {
                let PreparedPlaintext {
                    identifier,
                    plaintext,
                    eql_op,
                    column_config,
                } = prepared_plaintext;
                match eql_op {
                    EqlOperation::Store => Ok(EncryptionTarget::ForStorage(
                        identifier,
                        // `into_owned` moves an owned config and clones only a
                        // borrowed one.
                        PlaintextTarget::new(plaintext, column_config.into_owned())
                            .build_encryptable(index_key, effective_keyset_id)?,
                    )),
                    EqlOperation::Query(index_type, query_op) => Ok(EncryptionTarget::ForQuery(
                        identifier, plaintext, index_type, query_op,
                    )),
                }
            },
        )
        .collect()
}
/// Options for `decrypt_eql` and `decrypt_eql_fallible`.
#[derive(Debug, Default)]
pub struct EqlDecryptOpts<'a> {
/// Optional keyset override. Forwarded to `DecryptOptions` and, when a key
/// cannot be retrieved, reported in the error in place of the cipher's keyset.
pub keyset_id: Option<Uuid>,
/// Context attached to every ciphertext handed to the cipher.
pub lock_context: Cow<'a, [Context]>,
/// Forwarded unchanged to `DecryptOptions::service_token`.
pub service_token: Option<Cow<'a, crate::credentials::ServiceToken>>,
/// Forwarded unchanged to `DecryptOptions::unverified_context`.
pub unverified_context: Option<Cow<'a, UnverifiedContext>>,
}
/// Options for `encrypt_eql`.
#[derive(Debug, Default)]
pub struct EqlEncryptOpts<'a> {
/// Optional keyset override; when `None` the cipher's own keyset is used.
pub keyset_id: Option<Uuid>,
/// Context attached to every data-key request.
pub lock_context: Cow<'a, [Context]>,
/// Passed to `generate_data_keys`.
pub service_token: Option<Cow<'a, crate::credentials::ServiceToken>>,
/// Passed to `generate_data_keys`.
pub unverified_context: Option<Cow<'a, UnverifiedContext>>,
/// NOTE(review): not read by any code visible in this part of the file;
/// confirm where (or whether) it is consumed.
pub index_types: Option<Cow<'a, [IndexType]>>,
}
fn convert_zerokms_error(
err: zerokms::Error,
cipher_keyset_id: Uuid,
keyset_id_override: Option<Uuid>,
) -> EqlError {
match err {
zerokms::Error::Decrypt(_) => {
let error_msg = err.to_string();
if error_msg.contains("Failed to retrieve key") {
EqlError::CouldNotDecryptDataForKeyset {
keyset_id: keyset_id_override
.map(|id| id.to_string())
.unwrap_or(cipher_keyset_id.to_string()),
}
} else {
EqlError::ZeroKMS(err)
}
}
_ => EqlError::ZeroKMS(err),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an EQL payload for `test_table.test_column` with no ciphertext.
    fn payload_without_ciphertext() -> EqlCiphertext {
        EqlCiphertext {
            identifier: Identifier::new("test_table", "test_column"),
            version: EQL_SCHEMA_VERSION,
            body: EqlCiphertextBody {
                ciphertext: None,
                sem: EqlSEM::default(),
                is_array_item: None,
            },
        }
    }

    /// A payload without ciphertext maps to `EqlError::MissingCiphertext`.
    #[test]
    fn test_eql_ciphertext_missing_ciphertext_field() {
        let eql = payload_without_ciphertext();
        let outcome = match eql.body.ciphertext {
            Some(record) => Ok(record),
            None => Err(EqlError::MissingCiphertext(eql.identifier)),
        };
        assert!(matches!(outcome, Err(EqlError::MissingCiphertext(_))));
    }

    /// Malformed base85 input must be reported as an error rather than
    /// being treated as an absent value.
    #[test]
    fn test_mp_base85_deserialize_invalid_input_returns_error() {
        use serde::de::value::{Error as ValueError, StrDeserializer};
        use serde::de::IntoDeserializer;

        let malformed: StrDeserializer<ValueError> = "not-valid-base85!!!".into_deserializer();
        let outcome: Result<Option<EncryptedRecord>, ValueError> =
            formats::mp_base85::deserialize(malformed);
        assert!(
            outcome.is_err(),
            "Invalid base85 input should return error, not Ok(None)"
        );
    }

    /// The per-item error for a payload without ciphertext is
    /// `EqlError::MissingCiphertext`.
    #[test]
    fn test_fallible_contract_missing_ciphertext_is_per_item_error() {
        let eql = payload_without_ciphertext();
        let per_item: Result<encryption::Plaintext, EqlError> =
            Err(EqlError::MissingCiphertext(eql.identifier.clone()));
        assert!(matches!(per_item, Err(EqlError::MissingCiphertext(_))));
    }
}