use std::sync::Arc;
use mssql_auth::AeadEncryptor;
use tds_protocol::crypto::CryptoMetadata;
use tds_protocol::token::{ColMetaData, ColumnData};
use crate::encryption::EncryptionContext;
use crate::error::{Error, Result};
/// Per-result-set decryption state, indexed by column ordinal.
///
/// Entry `i` is `Some` when column `i` carries crypto metadata and `None`
/// for plaintext columns.
pub(crate) struct ColumnDecryptor {
    columns: Vec<Option<ColumnDecryptionInfo>>,
}

/// Everything needed to decrypt and then interpret one encrypted column.
struct ColumnDecryptionInfo {
    /// Column metadata describing the plaintext (base) type, used to decode
    /// the decrypted bytes.
    base_column: ColumnData,
    /// Encryptor resolved from the column's CEK table entry.
    encryptor: Arc<AeadEncryptor>,
}
impl ColumnDecryptor {
    /// Builds a decryptor for a result set from its column metadata.
    ///
    /// Each column with crypto metadata gets an encryptor and a plaintext
    /// (base) column description. Plaintext columns get `None`.
    ///
    /// # Errors
    ///
    /// - `Error::Encryption` if an encrypted column is present but `meta` has
    ///   no CEK table. The table is only required when at least one column
    ///   references it, so a result set without encrypted columns never hits
    ///   this error.
    /// - `Error::Encryption` if a column's CEK table ordinal is out of range.
    /// - Any error returned by `EncryptionContext::get_encryptor`.
    pub(crate) async fn from_metadata(meta: &ColMetaData, ctx: &EncryptionContext) -> Result<Self> {
        let mut columns = Vec::with_capacity(meta.columns.len());
        for col in &meta.columns {
            let Some(crypto) = &col.crypto_metadata else {
                columns.push(None);
                continue;
            };

            // Only demand the CEK table once an encrypted column actually
            // needs it. Checking it up front would reject result sets that
            // contain no encrypted columns at all.
            let cek_table = meta.cek_table.as_ref().ok_or_else(|| {
                Error::Encryption("encrypted result set has no CEK table".to_string())
            })?;

            let cek_entry = cek_table.get(crypto.cek_table_ordinal).ok_or_else(|| {
                Error::Encryption(format!(
                    "CEK table ordinal {} out of range (table has {} entries)",
                    crypto.cek_table_ordinal,
                    cek_table.len()
                ))
            })?;

            let encryptor = ctx.get_encryptor(cek_entry).await?;
            let base_column = build_base_column(col, crypto);
            columns.push(Some(ColumnDecryptionInfo {
                encryptor,
                base_column,
            }));
        }
        Ok(Self { columns })
    }

    /// Returns `true` if column `ordinal` exists and is encrypted.
    #[inline]
    pub(crate) fn is_encrypted(&self, ordinal: usize) -> bool {
        self.columns.get(ordinal).is_some_and(|c| c.is_some())
    }

    /// Decrypts the raw `ciphertext` of column `ordinal`.
    ///
    /// Returns the plaintext bytes together with the column metadata that
    /// describes the plaintext (base) type, so the caller can decode them.
    ///
    /// # Errors
    ///
    /// - `Error::Encryption` if `ordinal` is out of range or the column is not
    ///   encrypted.
    /// - `Error::Encryption` if the encryptor fails to decrypt `ciphertext`.
    pub(crate) fn decrypt_column_value(
        &self,
        ordinal: usize,
        ciphertext: &[u8],
    ) -> Result<(Vec<u8>, &ColumnData)> {
        let info = self
            .columns
            .get(ordinal)
            .and_then(|c| c.as_ref())
            .ok_or_else(|| {
                Error::Encryption(format!("column {ordinal} is not encrypted or out of range"))
            })?;
        let plaintext = info.encryptor.decrypt(ciphertext).map_err(|e| {
            Error::Encryption(format!("column {ordinal} decryption failed: {e}"))
        })?;
        Ok((plaintext, &info.base_column))
    }
}
/// Derives the plaintext column description for an encrypted column.
///
/// The column's name and flags are kept from `col`. The type-related fields
/// come from `crypto`'s base_* fields. The result carries no crypto metadata,
/// since it describes already-decrypted data.
fn build_base_column(col: &ColumnData, crypto: &CryptoMetadata) -> ColumnData {
    let name = col.name.clone();
    let type_info = crypto.base_type_info.clone();
    ColumnData {
        crypto_metadata: None,
        flags: col.flags,
        name,
        type_id: crypto.base_type_id(),
        col_type: crypto.base_col_type,
        user_type: crypto.base_user_type,
        type_info,
    }
}
/// Summarises the decryptor without exposing key material: it reports only
/// the total column count and how many of those columns are encrypted.
impl std::fmt::Debug for ColumnDecryptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.columns.len();
        // Flattening skips the `None` (plaintext) slots.
        let encrypted = self.columns.iter().flatten().count();
        f.debug_struct("ColumnDecryptor")
            .field("column_count", &total)
            .field("encrypted_count", &encrypted)
            .finish()
    }
}