use std::collections::BTreeMap;
use crate::crypto::envelope::RewrapOutcome;
use crate::crypto::EnvelopeCipher;
use crate::db::queries::wallet_rotate as queries;
use crate::db::DbPool;
use crate::error::AppResult;
/// Rows fetched per page by `rotate_table` when the caller passes `None` for `batch_size`.
const DEFAULT_BATCH_SIZE: i64 = 100;
/// Maximum number of pages `rotate_table` fetches in one call when the caller passes `None`
/// for `max_batches`.
const DEFAULT_MAX_BATCHES: i64 = 1000;
/// Counters describing one `WalletRotateService::rotate_table` run.
///
/// Every row the run examines increments `processed` and exactly one of
/// `rewrapped`, `skipped`, or `failed`, so in a returned summary
/// `processed == rewrapped + skipped + failed`.
///
/// Field order is significant: serde serializes fields in declaration order.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize)]
pub struct RotateSummary {
    /// Total rows examined.
    pub processed: u64,
    /// Rows whose envelope was rewrapped and successfully written back.
    pub rewrapped: u64,
    /// Rows the cipher reported as `Skipped` (nothing written).
    pub skipped: u64,
    /// Rows whose rewrap or write-back failed; each failure is also logged and metered.
    pub failed: u64,
    /// Id of the last row examined (the pagination cursor); `0` if no rows were seen.
    pub last_id: i64,
}
/// Counts returned by `WalletRotateService::key_status_table`, built from the
/// `(String, i64)` pairs yielded by the `key_status_*` queries and ordered by key.
/// NOTE(review): the key's meaning (key id vs. status label) is defined by those queries — confirm.
pub type KeyStatusSummary = BTreeMap<String, i64>;
/// Service that rewraps encrypted wallet rows through an [`EnvelopeCipher`] and reports
/// per-key counts for the wallet tables.
#[derive(Clone)]
pub struct WalletRotateService {
    /// Database pool used to page through rows and to write rewrapped data back.
    pool: DbPool,
    /// Cipher that rewraps each row's stored envelope string.
    cipher: EnvelopeCipher,
}
impl WalletRotateService {
pub fn new(pool: DbPool, cipher: EnvelopeCipher) -> Self {
Self { pool, cipher }
}
pub async fn rotate_table(
&self,
table: WalletTable,
batch_size: Option<i64>,
max_batches: Option<i64>,
) -> AppResult<RotateSummary> {
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE).max(1);
let max_batches = max_batches.unwrap_or(DEFAULT_MAX_BATCHES).max(1);
let mut summary = RotateSummary::default();
let mut after_id: i64 = 0;
for _ in 0..max_batches {
let rows = match table {
WalletTable::Credential => {
queries::iter_credential_rows(&self.pool, after_id, batch_size).await?
}
WalletTable::Keychain => {
queries::iter_keychain_rows(&self.pool, after_id, batch_size).await?
}
};
if rows.is_empty() {
break;
}
for row in &rows {
summary.processed += 1;
summary.last_id = row.id;
match self.cipher.rewrap_storage_string(&row.data_encrypted).await {
Ok(RewrapOutcome::Skipped { .. }) => {
summary.skipped += 1;
crate::metrics::record_wallet_rotate(table.as_label(), "skipped");
}
Ok(RewrapOutcome::Rewrapped {
new_storage_string, ..
}) => {
let update_result = match table {
WalletTable::Credential => {
queries::update_credential_data(
&self.pool,
row.id,
&new_storage_string,
)
.await
}
WalletTable::Keychain => {
queries::update_keychain_data(
&self.pool,
row.id,
&new_storage_string,
)
.await
}
};
match update_result {
Ok(()) => {
summary.rewrapped += 1;
crate::metrics::record_wallet_rotate(
table.as_label(),
"rewrapped",
);
}
Err(e) => {
summary.failed += 1;
crate::metrics::record_wallet_rotate(
table.as_label(),
"failed_write",
);
tracing::warn!(
table = %table.as_label(),
id = row.id,
error = %e,
"wallet_rotate.update failed"
);
}
}
}
Err(e) => {
summary.failed += 1;
let status = classify_failure(&e);
crate::metrics::record_wallet_rotate(table.as_label(), status);
tracing::warn!(
table = %table.as_label(),
id = row.id,
status = %status,
error = %e,
"wallet_rotate.rewrap failed"
);
}
}
}
after_id = summary.last_id;
if (rows.len() as i64) < batch_size {
break;
}
}
Ok(summary)
}
pub async fn key_status_table(&self, table: WalletTable) -> AppResult<KeyStatusSummary> {
let pairs = match table {
WalletTable::Credential => queries::key_status_credential(&self.pool).await?,
WalletTable::Keychain => queries::key_status_keychain(&self.pool).await?,
};
Ok(pairs.into_iter().collect())
}
}
/// Which wallet table a rotation or key-status operation targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WalletTable {
    /// The wallet credential table.
    Credential,
    /// The wallet keychain table.
    Keychain,
}

impl WalletTable {
    /// Stable lowercase label used in metrics and log fields.
    ///
    /// `const` so the label can also be used in constant contexts.
    pub const fn as_label(self) -> &'static str {
        match self {
            WalletTable::Credential => "credential",
            WalletTable::Keychain => "keychain",
        }
    }
}
/// Maps a rewrap error to a coarse bucket used as a metrics status and log field.
///
/// This is a heuristic over the error's `Display` text. Matching is case-insensitive,
/// and the checks run in this priority order:
/// - `"parse_error"`: mentions "not a wallet envelope", or both "envelope" and "base64";
/// - `"failed_unwrap"`: mentions "unwrap";
/// - `"failed_wrap"`: mentions "wrap" (checked after "unwrap", which also contains "wrap");
/// - `"failed"`: anything else.
///
/// If upstream error wording changes, an error may land in a different bucket.
fn classify_failure(err: &crate::error::AppError) -> &'static str {
    // Lowercase so e.g. "Unwrap failed" is not misfiled as a wrap failure via its "wrap" suffix.
    let msg = err.to_string().to_lowercase();
    let mentions = |needle: &str| msg.contains(needle);
    if mentions("not a wallet envelope") || (mentions("envelope") && mentions("base64")) {
        "parse_error"
    } else if mentions("unwrap") {
        "failed_unwrap"
    } else if mentions("wrap") {
        "failed_wrap"
    } else {
        "failed"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AppError;

    /// Each table variant maps to its fixed metrics label.
    #[test]
    fn wallet_table_labels() {
        let expectations = [
            (WalletTable::Credential, "credential"),
            (WalletTable::Keychain, "keychain"),
        ];
        for &(table, label) in expectations.iter() {
            assert_eq!(table.as_label(), label);
        }
    }

    /// Representative error messages land in the expected buckets.
    #[test]
    fn classify_failure_buckets() {
        let cases = [
            (
                AppError::Encryption("not a wallet envelope record: bad input".to_string()),
                "parse_error",
            ),
            (
                AppError::Encryption(
                    "LocalDevKms cannot unwrap a DEK from provider 'gcp-kms'".to_string(),
                ),
                "failed_unwrap",
            ),
            (
                AppError::Encryption("kms wrap call failed".to_string()),
                "failed_wrap",
            ),
            (AppError::Internal("unknown".to_string()), "failed"),
        ];
        for (err, bucket) in cases.iter() {
            assert_eq!(classify_failure(err), *bucket);
        }
    }

    /// A fresh summary starts with every counter and the cursor at zero.
    #[test]
    fn rotate_summary_defaults_zero() {
        let summary = RotateSummary::default();
        assert_eq!(
            (summary.processed, summary.rewrapped, summary.skipped, summary.failed),
            (0, 0, 0, 0)
        );
        assert_eq!(summary.last_id, 0);
    }
}