use super::{
accumulator::{Accumulator, AccumulatorError},
composable_plaintext::ComposablePlaintext,
ComposableIndex,
};
use crate::{
encryption::{
text::{process_edge_ngram_raw, TokenFilter},
EncryptionError, Plaintext,
},
zerokms::IndexKey,
};
use hmac::{Hmac, Mac};
use sha2::Sha256;
/// HMAC instantiated with SHA-256; the keyed hash used to derive every index term in this module.
type HmacSha256 = Hmac<Sha256>;
/// Turns UTF-8 plaintext into HMAC-SHA256 terms for prefix matching.
///
/// Borrows the index key for `'k` and carries the tokenisation settings used by
/// `encrypt_with_salt` and `query_with_salt`.
pub struct PrefixIndexer<'k> {
/// Key whose bytes (`IndexKey::key`) key the HMAC in `create_hmac`.
index_key: &'k IndexKey,
/// Filters applied, in order, to the token list before hashing.
token_filters: Vec<TokenFilter>,
/// Minimum length forwarded to `process_edge_ngram_raw` when indexing.
min_length: usize,
/// Maximum length forwarded to `process_edge_ngram_raw` when indexing.
max_length: usize,
}
impl<'k> PrefixIndexer<'k> {
pub fn new(
index_key: &'k IndexKey,
token_filters: Vec<TokenFilter>,
min_length: usize,
max_length: usize,
) -> Self {
Self {
index_key,
token_filters,
min_length,
max_length,
}
}
pub fn create_hmac(&self) -> Result<HmacSha256, EncryptionError> {
Ok(HmacSha256::new_from_slice(self.index_key.key())?)
}
pub fn encrypt_with_salt<S>(
&self,
plaintext: &Plaintext,
salt: S,
) -> Result<Accumulator, EncryptionError>
where
S: AsRef<[u8]>,
{
match plaintext {
Plaintext::Utf8Str(Some(value)) => {
let tokens =
process_edge_ngram_raw(value.to_string(), self.min_length, self.max_length);
let terms = self
.token_filters
.iter()
.fold(tokens, |tokens, filter| filter.process(tokens));
let out = terms
.into_iter()
.map(|term| {
let mut mac = self.create_hmac()?;
mac.update(salt.as_ref());
mac.update(term.as_bytes());
Ok::<Vec<u8>, EncryptionError>(mac.finalize().into_bytes().to_vec())
})
.collect::<Result<Vec<_>, _>>()?;
Ok(Accumulator::Terms(out))
}
Plaintext::Utf8Str(None) => Ok(Accumulator::empty()),
_ => Err(EncryptionError::IndexingError(format!(
"{plaintext:?} is not supported by prefix indexes"
))),
}
}
pub fn query_with_salt<S>(
&self,
plaintext: &Plaintext,
salt: S,
) -> Result<Accumulator, EncryptionError>
where
S: AsRef<[u8]>,
{
match plaintext {
Plaintext::Utf8Str(Some(value)) => {
let tokens = vec![value.to_string()];
let terms = self
.token_filters
.iter()
.fold(tokens, |tokens, filter| filter.process(tokens));
let term = terms.first().ok_or_else(|| {
EncryptionError::IndexingError("Encryption terms were empty".to_string())
})?;
let mut mac = self.create_hmac()?;
mac.update(salt.as_ref());
mac.update(term.as_bytes());
Ok(Accumulator::Term(mac.finalize().into_bytes().to_vec()))
}
Plaintext::Utf8Str(None) => Ok(Accumulator::empty()),
_ => Err(EncryptionError::IndexingError(format!(
"{plaintext:?} is not supported by prefix indexes"
))),
}
}
}
/// Configuration for a prefix index: the token filters and length bounds that
/// are handed to a `PrefixIndexer` each time the index is composed.
#[derive(Debug)]
pub struct PrefixIndex {
/// Filters copied into each `PrefixIndexer` built by this index.
token_filters: Vec<TokenFilter>,
/// Minimum length passed to the `PrefixIndexer`.
min_length: usize,
/// Maximum length passed to the `PrefixIndexer`.
max_length: usize,
}
impl PrefixIndex {
    /// Creates a prefix index with the given filters and the default length
    /// bounds: minimum 3, maximum 10.
    pub fn new(token_filters: Vec<TokenFilter>) -> Self {
        Self {
            token_filters,
            min_length: 3,
            max_length: 10,
        }
    }

    /// Creates a prefix index with explicit filters and length bounds.
    ///
    /// The bounds are stored as given; no validation is performed here.
    pub fn new_with_opts(
        token_filters: Vec<TokenFilter>,
        min_length: usize,
        max_length: usize,
    ) -> Self {
        let (min_length, max_length) = (min_length, max_length);
        Self {
            token_filters,
            min_length,
            max_length,
        }
    }
}
impl Default for PrefixIndex {
fn default() -> Self {
Self::new(vec![TokenFilter::Downcase])
}
}
impl ComposableIndex for PrefixIndex {
    /// Builds index terms for `plaintext`, using each term already in
    /// `accumulator` as a salt.
    ///
    /// A single `Accumulator::Term` salts one `encrypt_with_salt` call. For
    /// `Accumulator::Terms`, every term salts its own call and the outputs are
    /// merged with `Accumulator::combine`, starting from `Accumulator::empty()`.
    ///
    /// # Errors
    ///
    /// Propagates the plaintext conversion error and any error from
    /// `encrypt_with_salt`; the first failure stops processing.
    fn compose_index(
        &self,
        key: &IndexKey,
        plaintext: ComposablePlaintext,
        accumulator: Accumulator,
    ) -> Result<Accumulator, EncryptionError> {
        let indexer = PrefixIndexer::new(
            key,
            self.token_filters.to_vec(),
            self.min_length,
            self.max_length,
        );
        let plaintext: Plaintext = plaintext.try_into()?;

        match accumulator {
            Accumulator::Term(salt) => indexer.encrypt_with_salt(&plaintext, salt),
            Accumulator::Terms(salts) => {
                let mut combined = Accumulator::empty();
                for salt in salts {
                    let produced = indexer.encrypt_with_salt(&plaintext, salt)?;
                    combined = combined.combine(produced);
                }
                Ok(combined)
            }
        }
    }

    /// Builds the query term for `plaintext`, using the single term in
    /// `accumulator` as the salt.
    ///
    /// # Errors
    ///
    /// Returns `AccumulatorError::MultipleTermsFound` (converted into
    /// [`EncryptionError`]) when `accumulator` is anything other than
    /// `Accumulator::Term`. Also propagates the plaintext conversion error and
    /// any error from `query_with_salt`.
    fn compose_query(
        &self,
        key: &IndexKey,
        plaintext: ComposablePlaintext,
        accumulator: Accumulator,
    ) -> Result<Accumulator, EncryptionError> {
        let indexer = PrefixIndexer::new(
            key,
            self.token_filters.to_vec(),
            self.min_length,
            self.max_length,
        );
        let plaintext: Plaintext = plaintext.try_into()?;

        if let Accumulator::Term(salt) = accumulator {
            indexer.query_with_salt(&plaintext, salt)
        } else {
            Err(AccumulatorError::MultipleTermsFound.into())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, NaiveDate};
    use rust_decimal::Decimal;

    /// Builds a downcasing `PrefixIndexer` over `key` with the given length bounds.
    fn prefixer(key: &IndexKey, min: usize, max: usize) -> PrefixIndexer<'_> {
        PrefixIndexer::new(key, vec![TokenFilter::Downcase], min, max)
    }

    /// Indexing with bounds 2..=4 yields three terms for the sample input.
    #[test]
    fn test_encrypt_term() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result =
            prefixer(&index_key, 2, 4).encrypt_with_salt(&Plaintext::new("Hello World"), [])?;
        assert_eq!(result.terms().len(), 3);
        Ok(())
    }

    /// `PrefixIndex::default()` behaves like an explicit downcase-only index
    /// and differs from an index without filters.
    #[test]
    fn test_encrypt_term_default() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result_default = PrefixIndex::default().compose_index(
            &index_key,
            "Hello World".into(),
            Accumulator::from_salt("salt"),
        )?;
        let result = PrefixIndex::new(vec![TokenFilter::Downcase]).compose_index(
            &index_key,
            "Hello World".into(),
            Accumulator::from_salt("salt"),
        )?;
        let result_no_filters = PrefixIndex::new(vec![]).compose_index(
            &index_key,
            "Hello World".into(),
            Accumulator::from_salt("salt"),
        )?;
        let terms = result_default.terms();
        assert_eq!(terms.to_vec(), result.terms());
        assert_ne!(terms.to_vec(), result_no_filters.terms());
        Ok(())
    }

    /// The salt must change every produced term.
    #[test]
    fn test_encrypt_with_salt() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result_no_salt =
            prefixer(&index_key, 2, 4).encrypt_with_salt(&Plaintext::new("Hello World"), [])?;
        let result_salt = prefixer(&index_key, 2, 4)
            .encrypt_with_salt(&Plaintext::new("Hello World"), "somesalt")?;

        let no_salt_terms = result_no_salt.terms();
        let salt_terms = result_salt.terms();

        // `zip` stops at the shorter side, so without these checks an empty
        // (or truncated) term list would make the pairwise loop pass vacuously.
        assert!(!no_salt_terms.is_empty());
        assert_eq!(no_salt_terms.len(), salt_terms.len());

        no_salt_terms
            .into_iter()
            .zip(salt_terms)
            .for_each(|(no_salt, salt)| {
                assert_ne!(no_salt, salt);
            });
        Ok(())
    }

    /// A query produces exactly one 32-byte (SHA-256 sized) term.
    #[test]
    fn test_query_single_word() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let result = prefixer(&index_key, 2, 4).query_with_salt(&Plaintext::new("Hello"), [])?;
        assert_eq!(result.exactly_one()?.term()?.len(), 32);
        Ok(())
    }

    /// A query for a shared prefix matches the indexed terms of both words.
    #[test]
    fn test_common_prefix_single_word() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let prefixer = prefixer(&index_key, 3, 6);
        let data_1 = prefixer.encrypt_with_salt(&Plaintext::new("supervise"), [])?;
        let data_2 = prefixer.encrypt_with_salt(&Plaintext::new("superhero"), [])?;
        let query = prefixer.query_with_salt(&Plaintext::new("super"), [])?;
        let q = query.exactly_one()?.term()?;
        assert!(data_1.terms().contains(&q));
        assert!(data_2.terms().contains(&q));
        Ok(())
    }

    /// `Utf8Str(None)` yields no terms for both indexing and querying,
    /// with or without a salt.
    #[test]
    fn test_encrypt_none() -> Result<(), Box<dyn std::error::Error>> {
        let index_key = IndexKey::from([1; 32]);
        let indexer = prefixer(&index_key, 2, 4);
        assert!(indexer
            .encrypt_with_salt(&Plaintext::Utf8Str(None), [])?
            .terms()
            .is_empty());
        assert!(indexer
            .encrypt_with_salt(&Plaintext::Utf8Str(None), "somesalt")?
            .terms()
            .is_empty());
        assert!(indexer
            .query_with_salt(&Plaintext::Utf8Str(None), [])?
            .terms()
            .is_empty());
        assert!(indexer
            .query_with_salt(&Plaintext::Utf8Str(None), "somesalt")?
            .terms()
            .is_empty());
        Ok(())
    }

    /// Every non-string plaintext is rejected with an `IndexingError`
    /// carrying the "not supported" message, for both indexing and querying.
    #[test]
    fn test_unsupported_plaintext() -> Result<(), Box<dyn std::error::Error>> {
        fn test_unsupported(indexer: &PrefixIndexer, plaintext: Plaintext) {
            let result = indexer.encrypt_with_salt(&plaintext, []);
            assert!(
                matches!(result, Err(EncryptionError::IndexingError(_))),
                "Expected IndexingError for {plaintext:?}, got {result:?}",
            );
            assert!(
                result.is_err_and(|e| e.to_string().contains("is not supported by prefix indexes"))
            );
            let result = indexer.query_with_salt(&plaintext, []);
            assert!(
                matches!(result, Err(EncryptionError::IndexingError(_))),
                "Expected IndexingError for {plaintext:?}, got {result:?}",
            );
            assert!(
                result.is_err_and(|e| e.to_string().contains("is not supported by prefix indexes"))
            );
        }

        let index_key = IndexKey::from([1; 32]);
        let indexer = prefixer(&index_key, 2, 4);
        test_unsupported(&indexer, Plaintext::new(100i32));
        test_unsupported(&indexer, Plaintext::new(100i16));
        test_unsupported(&indexer, Plaintext::new(100u64));
        test_unsupported(&indexer, Plaintext::new(100i64));
        test_unsupported(&indexer, Plaintext::new(100.77f64));
        test_unsupported(&indexer, Plaintext::new(Decimal::new(202, 2)));
        test_unsupported(&indexer, Plaintext::new(true));
        test_unsupported(
            &indexer,
            Plaintext::new(DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")?.to_utc()),
        );
        test_unsupported(
            &indexer,
            Plaintext::new(NaiveDate::from_ymd_opt(2002, 9, 6)),
        );
        test_unsupported(&indexer, Plaintext::new(serde_json::to_value("some json")?));
        Ok(())
    }
}