pub mod query;
mod table_entry;
pub use self::{
    query::QueryBuilder,
    table_entry::{TableAttribute, TableEntry, TryFromTableAttr},
};
use crate::{
    crypto::*,
    errors::*,
    traits::{Decryptable, PrimaryKey, PrimaryKeyError, PrimaryKeyParts, Searchable},
    Identifiable, IndexType,
};
use aws_sdk_dynamodb::types::{AttributeValue, Delete, Put, TransactWriteItem};
use cipherstash_client::{
    config::{console_config::ConsoleConfig, zero_kms_config::ZeroKMSConfig},
    credentials::{
        auto_refresh::AutoRefresh,
        service_credentials::{ServiceCredentials, ServiceToken},
        Credentials,
    },
    encryption::{Encryption, Plaintext},
    zero_kms::ZeroKMS,
};
use log::info;
use std::{
    borrow::Cow,
    collections::{HashMap, HashSet},
    ops::Deref,
};
/// Marker backend for an [`EncryptedTable`] that has no DynamoDB client attached.
///
/// A headless table can still prepare patches and encrypt/decrypt items; the request-sending
/// methods (`get`, `put`, `delete`) are only implemented for `EncryptedTable<Dynamo>`.
pub struct Headless;
/// Backend for an [`EncryptedTable`] that sends requests to a real DynamoDB table.
pub struct Dynamo {
    /// SDK client used to send requests; also reachable directly through `Deref`.
    pub(crate) db: aws_sdk_dynamodb::Client,
    /// Name of the DynamoDB table that `get`, `put` and `delete` target.
    pub(crate) table_name: String,
}
impl Deref for Dynamo {
    type Target = aws_sdk_dynamodb::Client;

    /// Exposes the wrapped SDK client so its request builders can be called on `Dynamo` itself.
    fn deref(&self) -> &aws_sdk_dynamodb::Client {
        let Self { db, .. } = self;
        db
    }
}
/// Encryption front-end for a DynamoDB table, generic over its storage backend `D`.
///
/// `D` is [`Dynamo`] (the default) when the table issues requests itself, or [`Headless`] when
/// only encryption/decryption and patch generation are needed.
pub struct EncryptedTable<D = Dynamo> {
    /// Storage backend (a [`Dynamo`] client or the [`Headless`] marker).
    db: D,
    /// Encryption client used to seal/unseal records and to HMAC primary and index keys.
    cipher: Box<Encryption<AutoRefresh<ServiceCredentials>>>,
}
impl<D> EncryptedTable<D> {
    /// Borrows the encryption client this table uses for sealing, unsealing and key HMACs.
    pub fn cipher(&self) -> &Encryption<impl Credentials<Token = ServiceToken>> {
        &*self.cipher
    }
}
impl EncryptedTable<Headless> {
    /// Builds a table handle that can encrypt and decrypt but has no DynamoDB client.
    ///
    /// Console and ZeroKMS configuration are read from the environment (decryption logging is
    /// requested before the environment is applied). The dataset config is then fetched from
    /// ZeroKMS to obtain the index root key for the encryption client.
    ///
    /// # Errors
    /// Returns [`InitError`] if either configuration fails to build or the dataset config cannot
    /// be loaded.
    pub async fn init_headless() -> Result<Self, InitError> {
        info!("Initializing...");

        let console_config = ConsoleConfig::builder().with_env().build()?;
        let zero_kms_config = ZeroKMSConfig::builder()
            .decryption_log(true)
            .with_env()
            .console_config(&console_config)
            .build_with_client_key()?;

        // Gather the client arguments in the same order the constructor takes them.
        let base_url = zero_kms_config.base_url();
        let credentials = AutoRefresh::new(zero_kms_config.credentials());
        let decryption_log_path = zero_kms_config.decryption_log_path();
        let client_key = zero_kms_config.client_key();
        let zero_kms = ZeroKMS::new_with_client_key(
            &base_url,
            credentials,
            decryption_log_path.as_deref(),
            client_key,
        );

        info!("Fetching dataset config...");
        let dataset_config = zero_kms.load_dataset_config().await?;
        let encryption = Encryption::new(dataset_config.index_root_key, zero_kms);

        info!("Ready!");
        Ok(Self {
            db: Headless,
            cipher: Box::new(encryption),
        })
    }
}
/// A batch of DynamoDB writes, as produced by `EncryptedTable::create_put_patch` and
/// `EncryptedTable::create_delete_patch`.
pub struct DynamoRecordPatch {
    /// Complete items to write with a `Put` action.
    pub put_records: Vec<HashMap<String, AttributeValue>>,
    /// `pk`/`sk` keys of items to remove with a `Delete` action.
    pub delete_records: Vec<PrimaryKeyParts>,
}
/// A record split into the pieces needed for sealing, but not yet encrypted.
///
/// Built by [`PreparedRecord::prepare_record`] and consumed by `EncryptedTable::create_put_patch`.
pub struct PreparedRecord {
    /// Protected index definitions of the record type, as `(index name, index type)` pairs.
    protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
    /// Names of the record type's protected attributes; passed to `Sealer::seal`.
    protected_attributes: Cow<'static, [Cow<'static, str>]>,
    /// Primary key, type name, index inputs and unsealed attribute values awaiting sealing.
    sealer: Sealer,
}
/// A delete request resolved to a primary key whose HMAC has not yet been computed.
pub struct PreparedDelete {
    /// Key of the record to delete; converted by `EncryptedTable::encrypt_primary_key_parts`.
    primary_key: PreparedPrimaryKey,
    /// Protected indexes of the record type; used to derive the sort keys of the index entries
    /// that are deleted alongside the record.
    protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
}
impl PreparedDelete {
    /// Prepares a delete for the record of type `S` identified by `k`.
    ///
    /// The key is split into `pk`/`sk` using the type's name and optional sort-key prefix.
    pub fn new<S: Searchable>(k: impl Into<S::PrimaryKey>) -> Self {
        let key: S::PrimaryKey = k.into();
        let type_name = S::type_name();
        let sort_key_prefix = S::sort_key_prefix();
        let parts = key.into_parts(&type_name, sort_key_prefix.as_deref());
        Self::new_from_parts::<S>(parts)
    }

    /// Prepares a delete for the record of type `S` from already-split key parts.
    pub fn new_from_parts<S: Searchable>(k: PrimaryKeyParts) -> Self {
        Self {
            primary_key: PreparedPrimaryKey::new_from_parts::<S>(k),
            protected_indexes: S::protected_indexes(),
        }
    }

    /// Returns a copy of the prepared primary key of the record being deleted.
    pub fn prepared_primary_key(&self) -> PreparedPrimaryKey {
        PreparedPrimaryKey::clone(&self.primary_key)
    }
}
impl PreparedRecord {
    /// Assembles a prepared record from its already-computed parts.
    pub(crate) fn new(
        protected_indexes: Cow<'static, [(Cow<'static, str>, IndexType)]>,
        protected_attributes: Cow<'static, [Cow<'static, str>]>,
        sealer: Sealer,
    ) -> Self {
        Self {
            protected_indexes,
            protected_attributes,
            sealer,
        }
    }

    /// Iterates over the record's protected attributes as `(name, plaintext)` pairs.
    pub fn protected(&self) -> impl Iterator<Item = (&str, &Plaintext)> {
        self.sealer
            .unsealed
            .protected()
            .iter()
            .map(|(key, (plaintext, _descriptor))| (key.as_str(), plaintext))
    }

    /// Iterates over the record's unprotected attributes as `(name, attribute)` pairs.
    pub fn unprotected(&self) -> impl Iterator<Item = (&str, &TableAttribute)> {
        self.sealer
            .unsealed
            .unprotected()
            .iter()
            .map(|(key, attr)| (key.as_str(), attr))
    }

    /// Splits `record` into the pieces needed to seal it.
    ///
    /// The primary key is split into `pk`/`sk` using the type name and sort-key prefix. For every
    /// protected index of `R`, the record's attribute and the index definition are collected.
    ///
    /// # Errors
    /// Returns [`SealError::MissingAttribute`] naming the first protected index for which either
    /// the record yields no attribute or `R::index_by_name` yields no index.
    pub fn prepare_record<R>(record: R) -> Result<Self, SealError>
    where
        R: Searchable + Identifiable,
    {
        let type_name = R::type_name();
        let PrimaryKeyParts { pk, sk } = record
            .get_primary_key()
            .into_parts(&type_name, R::sort_key_prefix().as_deref());
        let protected_indexes = R::protected_indexes();
        let protected_attributes = R::protected_attributes();
        let unsealed_indexes = protected_indexes
            .iter()
            .map(|(index_name, index_type)| {
                record
                    .attribute_for_index(index_name, *index_type)
                    .and_then(|attr| {
                        R::index_by_name(index_name, *index_type)
                            .map(|index| (attr, index, index_name.clone(), *index_type))
                    })
                    // Lazily build the error so the name is only allocated on failure.
                    .ok_or_else(|| SealError::MissingAttribute(index_name.to_string()))
            })
            .collect::<Result<Vec<_>, _>>()?;
        let unsealed = record.into_unsealed();
        let sealer = Sealer {
            pk,
            sk,
            is_sk_encrypted: R::is_sk_encrypted(),
            is_pk_encrypted: R::is_pk_encrypted(),
            type_name,
            unsealed_indexes,
            unsealed,
        };
        Ok(PreparedRecord::new(
            protected_indexes,
            protected_attributes,
            sealer,
        ))
    }

    /// Returns a copy of the record's primary key parts, before any key HMAC is applied.
    pub fn primary_key_parts(&self) -> PrimaryKeyParts {
        PrimaryKeyParts {
            pk: self.sealer.pk.clone(),
            sk: self.sealer.sk.clone(),
        }
    }

    /// Returns the record's type name.
    pub fn type_name(&self) -> &str {
        &self.sealer.type_name
    }
}
impl DynamoRecordPatch {
    /// Converts this patch into transactional write actions against `table_name`.
    ///
    /// All `Put` actions come first, in `put_records` order, followed by all `Delete` actions
    /// keyed on the `pk`/`sk` attributes.
    ///
    /// # Errors
    /// Returns [`BuildError`] if any `Put` or `Delete` request fails to build.
    pub fn into_transact_write_items(
        self,
        table_name: &str,
    ) -> Result<Vec<TransactWriteItem>, BuildError> {
        let Self {
            put_records,
            delete_records,
        } = self;
        let mut actions = Vec::with_capacity(put_records.len() + delete_records.len());

        for item in put_records {
            let put = Put::builder()
                .table_name(table_name)
                .set_item(Some(item))
                .build()?;
            actions.push(TransactWriteItem::builder().put(put).build());
        }

        for PrimaryKeyParts { pk, sk } in delete_records {
            let delete = Delete::builder()
                .table_name(table_name)
                .key("pk", AttributeValue::S(pk))
                .key("sk", AttributeValue::S(sk))
                .build()?;
            actions.push(TransactWriteItem::builder().delete(delete).build());
        }

        Ok(actions)
    }
}
impl<D> EncryptedTable<D> {
    pub fn query<S>(&self) -> QueryBuilder<S, &Self>
    where
        S: Searchable,
    {
        QueryBuilder::with_backend(self)
    }
    pub async fn unseal_all(
        &self,
        items: impl IntoIterator<Item = HashMap<String, AttributeValue>>,
        spec: UnsealSpec<'_>,
    ) -> Result<Vec<Unsealed>, DecryptError> {
        let table_entries = SealedTableEntry::vec_from(items)?;
        let results = SealedTableEntry::unseal_all(table_entries, spec, &self.cipher).await?;
        Ok(results)
    }
    pub async fn unseal(
        &self,
        item: HashMap<String, AttributeValue>,
        spec: UnsealSpec<'_>,
    ) -> Result<Unsealed, DecryptError> {
        let table_entry = SealedTableEntry::try_from(item)?;
        let result = table_entry.unseal(spec, &self.cipher).await?;
        Ok(result)
    }
    pub async fn decrypt_all<T: Decryptable>(
        &self,
        items: impl IntoIterator<Item = HashMap<String, AttributeValue>>,
    ) -> Result<Vec<T>, DecryptError> {
        let items = self
            .unseal_all(items, UnsealSpec::new_for_decryptable::<T>())
            .await?;
        Ok(items
            .into_iter()
            .map(|x| x.into_value::<T>())
            .collect::<Result<Vec<_>, _>>()?)
    }
    pub async fn decrypt<T: Decryptable>(
        &self,
        item: HashMap<String, AttributeValue>,
    ) -> Result<T, DecryptError> {
        let item = self
            .unseal(item, UnsealSpec::new_for_decryptable::<T>())
            .await?;
        Ok(item.into_value()?)
    }
    pub async fn create_delete_patch(
        &self,
        delete: PreparedDelete,
    ) -> Result<DynamoRecordPatch, DeleteError> {
        let PrimaryKeyParts { pk, sk } = self.encrypt_primary_key_parts(delete.primary_key)?;
        let delete_records = all_index_keys(&sk, delete.protected_indexes)
            .into_iter()
            .map(|x| Ok::<_, DeleteError>(b64_encode(hmac(&x, Some(pk.as_str()), &self.cipher)?)))
            .chain([Ok(sk)])
            .map(|sk| {
                let sk = sk?;
                Ok::<_, DeleteError>(PrimaryKeyParts { pk: pk.clone(), sk })
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(DynamoRecordPatch {
            put_records: vec![],
            delete_records,
        })
    }
    pub async fn create_put_patch(
        &self,
        record: PreparedRecord,
        index_predicate: impl FnMut(&str, &TableAttribute) -> bool,
    ) -> Result<DynamoRecordPatch, PutError> {
        let mut seen_sk = HashSet::new();
        let PreparedRecord {
            protected_attributes,
            protected_indexes,
            sealer,
        } = record;
        let sealed = sealer.seal(protected_attributes, &self.cipher, 12).await?;
        let mut put_records = Vec::with_capacity(sealed.len());
        let mut delete_records = vec![];
        let PrimaryKeyParts { pk, sk } = sealed.primary_key();
        let (root, index_entries) = sealed.into_table_entries(index_predicate);
        seen_sk.insert(root.inner().sk.clone());
        put_records.push(root.try_into()?);
        for entry in index_entries.into_iter() {
            seen_sk.insert(entry.inner().sk.clone());
            put_records.push(entry.try_into()?);
        }
        for index_sk in all_index_keys(&sk, protected_indexes) {
            let index_sk = b64_encode(hmac(&index_sk, Some(pk.as_str()), &self.cipher)?);
            if seen_sk.contains(&index_sk) {
                continue;
            }
            delete_records.push(PrimaryKeyParts {
                pk: pk.clone(),
                sk: index_sk,
            });
        }
        Ok(DynamoRecordPatch {
            put_records,
            delete_records,
        })
    }
    pub fn encrypt_primary_key_parts(
        &self,
        prepared_primary_key: PreparedPrimaryKey,
    ) -> Result<PrimaryKeyParts, PrimaryKeyError> {
        let PrimaryKeyParts { mut pk, mut sk } = prepared_primary_key.primary_key_parts;
        if prepared_primary_key.is_pk_encrypted {
            pk = b64_encode(hmac(&pk, None, &self.cipher)?);
        }
        if prepared_primary_key.is_sk_encrypted {
            sk = b64_encode(hmac(&sk, Some(pk.as_str()), &self.cipher)?);
        }
        Ok(PrimaryKeyParts { pk, sk })
    }
}
impl EncryptedTable<Dynamo> {
    pub async fn init(
        db: aws_sdk_dynamodb::Client,
        table_name: impl Into<String>,
    ) -> Result<Self, InitError> {
        let table = EncryptedTable::init_headless().await?;
        Ok(Self {
            db: Dynamo {
                table_name: table_name.into(),
                db,
            },
            cipher: table.cipher,
        })
    }
    pub async fn get<T>(&self, k: impl Into<T::PrimaryKey>) -> Result<Option<T>, GetError>
    where
        T: Decryptable + Identifiable,
    {
        let PrimaryKeyParts { pk, sk } =
            self.encrypt_primary_key_parts(PreparedPrimaryKey::new::<T>(k))?;
        let result = self
            .db
            .get_item()
            .table_name(&self.db.table_name)
            .key("pk", AttributeValue::S(pk))
            .key("sk", AttributeValue::S(sk))
            .send()
            .await
            .map_err(|e| GetError::Aws(format!("{e:?}")))?;
        if let Some(item) = result.item {
            Ok(Some(self.decrypt(item).await?))
        } else {
            Ok(None)
        }
    }
    pub async fn delete<E: Searchable + Identifiable>(
        &self,
        k: impl Into<E::PrimaryKey>,
    ) -> Result<(), DeleteError> {
        let transact_items = self
            .create_delete_patch(PreparedDelete::new::<E>(k))
            .await?
            .into_transact_write_items(&self.db.table_name)?;
        for items in transact_items.chunks(100) {
            self.db
                .transact_write_items()
                .set_transact_items(Some(items.to_vec()))
                .send()
                .await
                .map_err(|e| DeleteError::Aws(format!("{e:?}")))?;
        }
        Ok(())
    }
    pub async fn put<T>(&self, record: T) -> Result<(), PutError>
    where
        T: Searchable + Identifiable,
    {
        let record = PreparedRecord::prepare_record(record)?;
        let transact_items = self
            .create_put_patch(
                record,
                |_, _| true,
            )
            .await?
            .into_transact_write_items(&self.db.table_name)?;
        for items in transact_items.chunks(100) {
            self.db
                .transact_write_items()
                .set_transact_items(Some(items.to_vec()))
                .send()
                .await
                .map_err(|e| PutError::Aws(format!("{e:?}")))?;
        }
        Ok(())
    }
}