use std::collections::VecDeque;
use std::sync::Arc;
use async_trait::async_trait;
use cdk_common::database::{self, WalletDatabase};
use tracing::instrument;
use crate::nuts::{PublicKey, State};
use crate::Error;
/// An undo step that can be queued in a [`Compensations`] queue and run later by
/// [`execute_compensations`].
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub(crate) trait CompensatingAction: Send + Sync {
    /// Short, static identifier for this action, used in log messages.
    fn name(&self) -> &'static str;

    /// Performs the undo step.
    ///
    /// # Errors
    /// Returns an error if the undo step could not be completed.
    async fn execute(&self) -> Result<(), Error>;
}
/// Queue of pending compensating actions; the action at the front runs first.
///
/// `add_compensation` inserts at the front, so actions run in reverse registration order.
pub(crate) type Compensations = VecDeque<Box<dyn CompensatingAction>>;

/// Creates an empty [`Compensations`] queue.
pub(crate) fn new_compensations() -> Compensations {
    Compensations::new()
}
/// Runs and removes every queued compensating action, front to back.
///
/// A failing action is logged and the remaining actions still run, so this always
/// returns `Ok(())`. An empty queue is a no-op and logs nothing.
pub(crate) async fn execute_compensations(compensations: &mut Compensations) -> Result<(), Error> {
    if !compensations.is_empty() {
        tracing::warn!("Running {} compensating actions", compensations.len());

        // Take actions off the queue one at a time (rather than draining it up front) so
        // any action not yet reached stays queued if this future is dropped mid-way.
        while let Some(action) = compensations.pop_front() {
            tracing::debug!("Running compensation: {}", action.name());

            let outcome = action.execute().await;
            if let Err(err) = outcome {
                tracing::error!(
                    "Compensation {} failed: {}. Continuing...",
                    action.name(),
                    err
                );
            }
        }
    }

    Ok(())
}
/// Discards every queued compensating action without running it.
///
/// Nothing inside is awaited; the function is only `async` in its signature.
pub(crate) async fn clear_compensations(compensations: &mut Compensations) {
    // Truncating to zero drops each queued action in place, exactly like `clear()`.
    compensations.truncate(0);
}
/// Registers `action` so that it runs before every action registered earlier.
///
/// Because new actions go to the front of the queue, [`execute_compensations`] undoes
/// steps in reverse registration order (last in, first out).
pub(crate) async fn add_compensation(
    compensations: &mut Compensations,
    action: Box<dyn CompensatingAction>,
) {
    // Index 0 is the front of the deque; this is equivalent to `push_front`.
    compensations.insert(0, action);
}
/// Compensation that moves `Reserved`/`Pending` proofs back to `Unspent` and then
/// deletes the associated saga record.
pub(crate) struct RevertProofReservation {
    /// Identifier of the saga record to delete after the proofs are reverted.
    pub saga_id: uuid::Uuid,
    /// `Y` values of the proofs this compensation may revert.
    pub proof_ys: Vec<PublicKey>,
    /// Wallet database holding both the proofs and the saga record.
    pub localstore: Arc<dyn WalletDatabase<database::Error> + Send + Sync>,
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl CompensatingAction for RevertProofReservation {
    /// Moves the tracked proofs that are still `Reserved` or `Pending` back to `Unspent`,
    /// then deletes the saga record on a best-effort basis.
    ///
    /// Proofs in any other state, or no longer present in the store, are left untouched.
    /// A repeated run therefore makes no further proof-state changes.
    ///
    /// # Errors
    /// Returns [`Error::Database`] if reading the proofs or updating their state fails.
    /// A failure to delete the saga record is only logged.
    #[instrument(skip_all)]
    async fn execute(&self) -> Result<(), Error> {
        let current_proofs = self
            .localstore
            .get_proofs_by_ys(self.proof_ys.clone())
            .await
            .map_err(Error::Database)?;

        // Only proofs still held in an in-flight state are reverted.
        // NOTE(review): `Pending` proofs are reverted too; confirm the mint cannot
        // already have accepted them when this compensation runs.
        let ys_to_revert: Vec<_> = current_proofs
            .into_iter()
            .filter(|p| matches!(p.state, State::Reserved | State::Pending))
            .map(|p| p.y)
            .collect();

        // Log after filtering so the count reflects what is actually reverted rather
        // than how many proofs were requested.
        tracing::info!(
            "Compensation: Reverting {} of {} proofs from Reserved/Pending to Unspent",
            ys_to_revert.len(),
            self.proof_ys.len()
        );

        if !ys_to_revert.is_empty() {
            self.localstore
                .update_proofs_state(ys_to_revert, State::Unspent)
                .await
                .map_err(Error::Database)?;
        }

        // Saga deletion is best effort: the proofs have already been released.
        if let Err(e) = self.localstore.delete_saga(&self.saga_id).await {
            tracing::warn!(
                "Compensation: Failed to delete saga {}: {}. Will be cleaned up on recovery.",
                self.saga_id,
                e
            );
        }

        Ok(())
    }

    fn name(&self) -> &'static str {
        "RevertProofReservation"
    }
}
#[cfg(test)]
pub mod test_utils {
    //! Fixtures shared by the saga and compensation tests.

    use std::str::FromStr;
    use std::sync::Arc;

    use cdk_common::database::WalletDatabase;
    use cdk_common::mint_url::MintUrl;
    use cdk_common::nuts::{CurrencyUnit, Id, Proof, State};
    use cdk_common::secret::Secret;
    use cdk_common::wallet::{
        OperationData, ProofInfo, SwapOperationData, SwapSagaState, WalletSaga, WalletSagaState,
    };
    use cdk_common::{Amount, SecretKey};

    /// Opens a fresh, empty in-memory SQLite wallet database.
    pub async fn create_test_db(
    ) -> Arc<dyn WalletDatabase<cdk_common::database::Error> + Send + Sync> {
        let db = cdk_sqlite::wallet::memory::empty().await.unwrap();
        Arc::new(db)
    }

    /// Fixed keyset id used by the test fixtures.
    pub fn test_keyset_id() -> Id {
        Id::from_str("00916bbf7ef91a36").unwrap()
    }

    /// Placeholder mint URL used by the test fixtures.
    pub fn test_mint_url() -> MintUrl {
        MintUrl::from_str("https://test-mint.example.com").unwrap()
    }

    /// Builds a proof with a freshly generated secret and `C`, and no witness, DLEQ or
    /// P2PK data.
    pub fn test_proof(keyset_id: Id, amount: u64) -> Proof {
        let secret = Secret::generate();
        let c = SecretKey::generate().public_key();
        Proof {
            amount: Amount::from(amount),
            keyset_id,
            secret,
            c,
            witness: None,
            dleq: None,
            p2pk_e: None,
        }
    }

    /// Wraps a fresh [`test_proof`] in a sat-denominated [`ProofInfo`] with the given
    /// mint and state.
    pub fn test_proof_info(
        keyset_id: Id,
        amount: u64,
        mint_url: MintUrl,
        state: State,
    ) -> ProofInfo {
        ProofInfo::new(
            test_proof(keyset_id, amount),
            mint_url,
            state,
            CurrencyUnit::Sat,
        )
        .unwrap()
    }

    /// Builds a swap saga in the `ProofsReserved` state with a random id.
    pub fn test_simple_saga(mint_url: MintUrl) -> WalletSaga {
        let swap_data = SwapOperationData {
            input_amount: Amount::from(1000),
            output_amount: Amount::from(990),
            counter_start: Some(0),
            counter_end: Some(10),
            blinded_messages: None,
        };
        WalletSaga::new(
            uuid::Uuid::new_v4(),
            WalletSagaState::Swap(SwapSagaState::ProofsReserved),
            Amount::from(1000),
            mint_url,
            CurrencyUnit::Sat,
            OperationData::Swap(swap_data),
        )
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the compensation queue helpers and `RevertProofReservation`.

    use super::*;

    /// Executing the same `RevertProofReservation` twice must succeed both times and
    /// leave exactly one `Unspent` proof.
    #[tokio::test]
    async fn test_revert_proof_reservation_is_idempotent() {
        let db = test_utils::create_test_db().await;
        let mint_url = test_utils::test_mint_url();
        let keyset_id = test_utils::test_keyset_id();
        let proof_info =
            test_utils::test_proof_info(keyset_id, 100, mint_url.clone(), State::Reserved);
        let proof_y = proof_info.y;
        db.update_proofs(vec![proof_info], vec![]).await.unwrap();
        let saga = test_utils::test_simple_saga(mint_url);
        let saga_id = saga.id;
        db.add_saga(saga).await.unwrap();
        let compensation = RevertProofReservation {
            localstore: db.clone(),
            proof_ys: vec![proof_y],
            saga_id,
        };
        // The second call runs after the proof is already Unspent.
        compensation.execute().await.unwrap();
        compensation.execute().await.unwrap();
        let proofs = db
            .get_proofs(None, None, Some(vec![State::Unspent]), None)
            .await
            .unwrap();
        assert_eq!(proofs.len(), 1);
    }

    /// A saga id that was never stored must not make the compensation fail; the proof
    /// is still reverted to `Unspent`.
    #[tokio::test]
    async fn test_revert_proof_reservation_handles_missing_saga() {
        let db = test_utils::create_test_db().await;
        let mint_url = test_utils::test_mint_url();
        let keyset_id = test_utils::test_keyset_id();
        let proof_info =
            test_utils::test_proof_info(keyset_id, 100, mint_url.clone(), State::Reserved);
        let proof_y = proof_info.y;
        db.update_proofs(vec![proof_info], vec![]).await.unwrap();
        // Deliberately not added to the database.
        let saga_id = uuid::Uuid::new_v4();
        let compensation = RevertProofReservation {
            localstore: db.clone(),
            proof_ys: vec![proof_y],
            saga_id,
        };
        compensation.execute().await.unwrap();
        let proofs = db
            .get_proofs(None, None, Some(vec![State::Unspent]), None)
            .await
            .unwrap();
        assert_eq!(proofs.len(), 1);
    }

    /// Only proofs listed in `proof_ys` are reverted; other reserved proofs stay reserved.
    #[tokio::test]
    async fn test_revert_proof_reservation_only_affects_specified_proofs() {
        let db = test_utils::create_test_db().await;
        let mint_url = test_utils::test_mint_url();
        let keyset_id = test_utils::test_keyset_id();
        let proof_info_1 =
            test_utils::test_proof_info(keyset_id, 100, mint_url.clone(), State::Reserved);
        let proof_info_2 =
            test_utils::test_proof_info(keyset_id, 200, mint_url.clone(), State::Reserved);
        let proof_y_1 = proof_info_1.y;
        let proof_y_2 = proof_info_2.y;
        db.update_proofs(vec![proof_info_1, proof_info_2], vec![])
            .await
            .unwrap();
        let saga = test_utils::test_simple_saga(mint_url);
        let saga_id = saga.id;
        db.add_saga(saga).await.unwrap();
        // Only the first proof is handed to the compensation.
        let compensation = RevertProofReservation {
            localstore: db.clone(),
            proof_ys: vec![proof_y_1],
            saga_id,
        };
        compensation.execute().await.unwrap();
        let unspent = db
            .get_proofs(None, None, Some(vec![State::Unspent]), None)
            .await
            .unwrap();
        assert_eq!(unspent.len(), 1);
        assert_eq!(unspent[0].y, proof_y_1);
        let reserved = db
            .get_proofs(None, None, Some(vec![State::Reserved]), None)
            .await
            .unwrap();
        assert_eq!(reserved.len(), 1);
        assert_eq!(reserved[0].y, proof_y_2);
    }

    /// Test double that appends its name to a shared log when executed and can be
    /// configured to return an error.
    struct MockCompensation {
        name: &'static str,
        // Shared log of executed action names, in execution order.
        execution_order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        should_fail: bool,
    }

    impl MockCompensation {
        /// Mock whose `execute` succeeds.
        fn new(
            name: &'static str,
            execution_order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        ) -> Self {
            Self {
                name,
                execution_order,
                should_fail: false,
            }
        }

        /// Mock whose `execute` records its name and then returns an error.
        fn failing(
            name: &'static str,
            execution_order: Arc<std::sync::Mutex<Vec<&'static str>>>,
        ) -> Self {
            Self {
                name,
                execution_order,
                should_fail: true,
            }
        }
    }

    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    impl CompensatingAction for MockCompensation {
        async fn execute(&self) -> Result<(), Error> {
            // Record first, so failing mocks also show up in the log.
            self.execution_order.lock().unwrap().push(self.name);
            if self.should_fail {
                Err(Error::Custom("Intentional test failure".to_string()))
            } else {
                Ok(())
            }
        }

        fn name(&self) -> &'static str {
            self.name
        }
    }

    /// Actions run in reverse registration order.
    #[tokio::test]
    async fn test_compensations_lifo_order() {
        let mut compensations = new_compensations();
        let execution_order = Arc::new(std::sync::Mutex::new(Vec::new()));
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("first", execution_order.clone())),
        )
        .await;
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("second", execution_order.clone())),
        )
        .await;
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("third", execution_order.clone())),
        )
        .await;
        execute_compensations(&mut compensations).await.unwrap();
        let order = execution_order.lock().unwrap();
        assert_eq!(order.as_slice(), &["third", "second", "first"]);
    }

    /// A failing action neither aborts the remaining actions nor makes
    /// `execute_compensations` return an error.
    #[tokio::test]
    async fn test_compensations_continues_on_error() {
        let mut compensations = new_compensations();
        let execution_order = Arc::new(std::sync::Mutex::new(Vec::new()));
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("first", execution_order.clone())),
        )
        .await;
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::failing(
                "second_fails",
                execution_order.clone(),
            )),
        )
        .await;
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("third", execution_order.clone())),
        )
        .await;
        let result = execute_compensations(&mut compensations).await;
        assert!(result.is_ok());
        let order = execution_order.lock().unwrap();
        assert_eq!(order.as_slice(), &["third", "second_fails", "first"]);
    }

    /// Executing or clearing an empty queue succeeds and leaves it empty.
    #[tokio::test]
    async fn test_compensations_empty_queue() {
        let mut compensations = new_compensations();
        let result = execute_compensations(&mut compensations).await;
        assert!(result.is_ok());
        clear_compensations(&mut compensations).await;
        assert!(compensations.is_empty());
    }

    /// Cleared actions are dropped without ever being executed.
    #[tokio::test]
    async fn test_clear_compensations() {
        let mut compensations = new_compensations();
        let execution_order = Arc::new(std::sync::Mutex::new(Vec::new()));
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("first", execution_order.clone())),
        )
        .await;
        add_compensation(
            &mut compensations,
            Box::new(MockCompensation::new("second", execution_order.clone())),
        )
        .await;
        assert!(!compensations.is_empty());
        clear_compensations(&mut compensations).await;
        assert!(compensations.is_empty());
        // Nothing is left to run, so the execution log must stay empty.
        execute_compensations(&mut compensations).await.unwrap();
        assert!(execution_order.lock().unwrap().is_empty());
    }

    /// A freshly created queue has no actions.
    #[tokio::test]
    async fn test_new_compensations_creates_empty_queue() {
        let compensations = new_compensations();
        assert!(compensations.is_empty());
    }
}