use super::TransformContextBuilder;
use super::{DownChainProtocol, UpChainProtocol};
use crate::frame::MessageType;
use crate::frame::{
CassandraFrame, CassandraOperation, CassandraResult, Frame, value::GenericValue,
};
use crate::message::{Message, MessageIdMap, Messages};
use crate::transforms::protect::key_management::KeyManager;
pub use crate::transforms::protect::key_management::KeyManagerConfig;
use crate::transforms::{ChainState, Transform, TransformBuilder};
use anyhow::Result;
use async_trait::async_trait;
use cql3_parser::cassandra_statement::CassandraStatement;
use cql3_parser::common::Identifier;
use cql3_parser::insert::InsertValues;
use cql3_parser::select::SelectElement;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
mod aws_kms;
mod crypto;
mod key_management;
mod local_kek;
mod pkcs_11;
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct ProtectConfig {
pub keyspace_table_columns: HashMap<String, HashMap<String, Vec<String>>>,
pub key_manager: KeyManagerConfig,
}
const NAME: &str = "Protect";
#[typetag::serde(name = "Protect")]
#[async_trait(?Send)]
impl crate::transforms::TransformConfig for ProtectConfig {
async fn get_builder(
&self,
_transform_context: crate::transforms::TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
Ok(Box::new(Protect {
keyspace_table_columns: self
.keyspace_table_columns
.iter()
.map(|(k, v)| {
(
Identifier::Quoted(k.clone()),
v.iter()
.map(|(k, v)| {
(
Identifier::Quoted(k.clone()),
v.iter().map(|x| Identifier::Quoted(x.clone())).collect(),
)
})
.collect(),
)
})
.collect(),
key_source: self.key_manager.build().await?,
key_id: "XXXXXXX".to_string(),
requests: MessageIdMap::default(),
}))
}
fn up_chain_protocol(&self) -> UpChainProtocol {
UpChainProtocol::MustBeOneOf(vec![MessageType::Cassandra])
}
fn down_chain_protocol(&self) -> DownChainProtocol {
DownChainProtocol::SameAsUpChain
}
}
#[derive(Clone)]
struct Protect {
keyspace_table_columns: HashMap<Identifier, HashMap<Identifier, Vec<Identifier>>>,
key_source: KeyManager,
key_id: String,
requests: MessageIdMap<Message>,
}
impl TransformBuilder for Protect {
fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(self.clone())
}
fn get_name(&self) -> &'static str {
NAME
}
}
impl Protect {
fn get_protected_columns(&self, statement: &CassandraStatement) -> &[Identifier] {
if let Some(table_name) = statement.get_table_name() {
if let Some(tables) = self.keyspace_table_columns.get(
table_name
.keyspace
.as_ref()
.unwrap_or(&Identifier::default()),
) {
if let Some(columns) = tables.get(&table_name.name) {
return columns;
}
}
}
&[]
}
async fn encrypt_columns(&self, statement: &mut CassandraStatement) -> Result<bool> {
let mut invalidate_cache = false;
let columns_to_encrypt = self.get_protected_columns(statement);
match statement {
CassandraStatement::Insert(insert) => {
for (i, col_name) in insert.columns.iter().enumerate() {
if columns_to_encrypt.contains(col_name) {
match &mut insert.values {
InsertValues::Values(value_operands) => {
if let Some(value) = value_operands.get_mut(i) {
*value = crypto::encrypt(value, &self.key_source, &self.key_id)
.await?;
invalidate_cache = true
}
}
InsertValues::Json(_) => todo!("parse json and encrypt."),
}
}
}
}
CassandraStatement::Update(update) => {
for assignment in &mut update.assignments {
if columns_to_encrypt.contains(&assignment.name.column) {
assignment.value =
crypto::encrypt(&assignment.value, &self.key_source, &self.key_id)
.await?;
invalidate_cache = true;
}
}
}
_ => {
}
}
Ok(invalidate_cache)
}
async fn decrypt_results(
&self,
statement: &CassandraStatement,
rows: &mut Vec<Vec<GenericValue>>,
) -> Result<bool> {
let mut invalidate_cache = false;
if let CassandraStatement::Select(select) = &statement {
let columns_to_decrypt = self.get_protected_columns(statement);
for (i, col) in select.columns.iter().enumerate() {
if let SelectElement::Column(col) = col {
if columns_to_decrypt.contains(&col.name) {
for row in &mut *rows {
if let Some(message_value) = row.get_mut(i) {
*message_value =
crypto::decrypt(message_value, &self.key_source, &self.key_id)
.await?;
invalidate_cache = true;
}
}
}
}
}
}
Ok(invalidate_cache)
}
}
#[async_trait]
impl Transform for Protect {
fn get_name(&self) -> &'static str {
NAME
}
async fn transform<'shorter, 'longer: 'shorter>(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
) -> Result<Messages> {
for message in chain_state.requests.iter_mut() {
let mut invalidate_cache = false;
if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = message.frame() {
for statement in operation.queries() {
invalidate_cache |= self.encrypt_columns(statement).await.unwrap();
}
}
if invalidate_cache {
message.invalidate_cache();
}
}
chain_state.clone_requests_into_hashmap(&mut self.requests);
let mut responses = chain_state.call_next_transform().await?;
for response in &mut responses {
if let Some(request_id) = response.request_id() {
let mut request = self.requests.remove(&request_id).unwrap();
let mut invalidate_cache = false;
if let Some(Frame::Cassandra(CassandraFrame { operation, .. })) = request.frame() {
if let Some(Frame::Cassandra(CassandraFrame {
operation: CassandraOperation::Result(CassandraResult::Rows { rows, .. }),
..
})) = response.frame()
{
for statement in operation.queries() {
invalidate_cache |= self.decrypt_results(statement, rows).await?
}
}
}
if invalidate_cache {
response.invalidate_cache();
}
}
}
Ok(responses)
}
}