use cdk_common::wallet::{OperationData, SwapOperationData, SwapSagaState, WalletSaga};
use tracing::instrument;
use crate::dhke::hash_to_curve;
use crate::nuts::{PreMintSecrets, State};
use crate::wallet::recovery::{RecoveryAction, RecoveryHelpers};
use crate::wallet::saga::{CompensatingAction, RevertProofReservation};
use crate::{Error, Wallet};
impl Wallet {
#[instrument(skip(self, saga))]
pub(crate) async fn resume_swap_saga(
&self,
saga: &WalletSaga,
) -> Result<RecoveryAction, Error> {
let state = match &saga.state {
cdk_common::wallet::WalletSagaState::Swap(s) => s,
_ => {
return Err(Error::Custom(format!(
"Invalid saga state type for swap saga {}",
saga.id
)))
}
};
let data = match &saga.data {
OperationData::Swap(d) => d,
_ => {
return Err(Error::Custom(format!(
"Invalid operation data type for swap saga {}",
saga.id
)))
}
};
match state {
SwapSagaState::ProofsReserved => {
tracing::info!(
"Swap saga {} in ProofsReserved state - compensating",
saga.id
);
self.compensate_swap(&saga.id).await?;
Ok(RecoveryAction::Compensated)
}
SwapSagaState::SwapRequested => {
tracing::info!(
"Swap saga {} in SwapRequested state - checking mint for proof states",
saga.id
);
self.recover_or_compensate_swap(&saga.id, data).await
}
}
}
async fn recover_or_compensate_swap(
&self,
saga_id: &uuid::Uuid,
data: &SwapOperationData,
) -> Result<RecoveryAction, Error> {
if self.check_db_for_swap_success(saga_id, data).await? {
return Ok(RecoveryAction::Recovered);
}
let reserved_proofs = self.localstore.get_reserved_proofs(saga_id).await?;
if !reserved_proofs.is_empty() {
if let Some(new_proofs) = self
.try_replay_swap_request(
saga_id,
"Swap",
data.blinded_messages.as_deref(),
data.counter_start,
data.counter_end,
&reserved_proofs,
)
.await?
{
let input_ys: Vec<_> = reserved_proofs.iter().map(|p| p.y).collect();
self.localstore.update_proofs(new_proofs, input_ys).await?;
self.localstore.delete_saga(saga_id).await?;
return Ok(RecoveryAction::Recovered);
}
}
let should_restore = if reserved_proofs.is_empty() {
tracing::warn!(
"Reserved proofs missing for swap saga {}, assuming spent and attempting restore.",
saga_id
);
true
} else {
match self.are_proofs_spent(&reserved_proofs).await {
Ok(true) => {
tracing::info!(
"Swap saga {} - input proofs spent, recovering outputs via /restore",
saga_id
);
true
}
Ok(false) => {
tracing::info!(
"Swap saga {} - input proofs not spent, compensating",
saga_id
);
false
}
Err(e) => {
tracing::warn!(
"Swap saga {} - can't check proof states ({}), skipping",
saga_id,
e
);
return Ok(RecoveryAction::Skipped);
}
}
};
if should_restore {
match self
.complete_swap_from_restore(saga_id, data, &reserved_proofs)
.await
{
Ok(_) => Ok(RecoveryAction::Recovered),
Err(e) => {
if reserved_proofs.is_empty() {
tracing::warn!(
"Restore failed for orphaned saga {} ({}). Cleaning up.",
saga_id,
e
);
self.localstore.delete_saga(saga_id).await?;
Ok(RecoveryAction::Recovered)
} else {
Err(e)
}
}
}
} else {
self.compensate_swap(saga_id).await?;
Ok(RecoveryAction::Compensated)
}
}
async fn check_db_for_swap_success(
&self,
saga_id: &uuid::Uuid,
data: &SwapOperationData,
) -> Result<bool, Error> {
if let (Some(blinded_messages), Some(start), Some(end)) = (
data.blinded_messages.as_deref(),
data.counter_start,
data.counter_end,
) {
if !blinded_messages.is_empty() {
let keyset_id = blinded_messages[0].keyset_id;
if let Ok(premint_secrets) =
PreMintSecrets::restore_batch(keyset_id, &self.seed, start, end)
{
let ys_result: Result<Vec<crate::nuts::PublicKey>, _> = premint_secrets
.secrets
.iter()
.map(|p| hash_to_curve(&p.secret.to_bytes()))
.collect();
if let Ok(ys) = ys_result {
if let Ok(existing_proofs) = self.localstore.get_proofs_by_ys(ys).await {
if !existing_proofs.is_empty() {
tracing::info!(
"Swap saga {} - new proofs found in DB, cleaning up",
saga_id
);
self.localstore.delete_saga(saga_id).await?;
return Ok(true);
}
}
}
}
}
}
Ok(false)
}
async fn complete_swap_from_restore(
&self,
saga_id: &uuid::Uuid,
data: &SwapOperationData,
reserved_proofs: &[cdk_common::wallet::ProofInfo],
) -> Result<(), Error> {
let new_proofs = self
.restore_outputs(
saga_id,
"Swap",
data.blinded_messages.as_deref(),
data.counter_start,
data.counter_end,
)
.await?;
let input_ys: Vec<_> = reserved_proofs.iter().map(|p| p.y).collect();
match new_proofs {
Some(proofs) => {
self.localstore.update_proofs(proofs, input_ys).await?;
}
None => {
tracing::warn!(
"Swap saga {} - couldn't restore outputs, marking inputs as spent. \
Run wallet.restore() to recover any missing proofs.",
saga_id
);
self.localstore
.update_proofs_state(input_ys, State::Spent)
.await?;
}
}
self.localstore.delete_saga(saga_id).await?;
Ok(())
}
async fn compensate_swap(&self, saga_id: &uuid::Uuid) -> Result<(), Error> {
let reserved_proofs = self.localstore.get_reserved_proofs(saga_id).await?;
let proof_ys = reserved_proofs.iter().map(|p| p.y).collect();
RevertProofReservation {
localstore: self.localstore.clone(),
proof_ys,
saga_id: *saga_id,
}
.execute()
.await
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use cdk_common::nuts::{CheckStateResponse, CurrencyUnit, ProofState, State};
    use cdk_common::wallet::{
        OperationData, SwapOperationData, SwapSagaState, WalletSaga, WalletSagaState,
    };
    use cdk_common::Amount;

    use crate::wallet::test_utils::*;

    /// Operation data shared by every swap saga in these tests.
    ///
    /// It uses 100 in, 90 out, `counter_start` 0, `counter_end` 10, and no
    /// persisted blinded messages.
    fn swap_operation_data() -> OperationData {
        OperationData::Swap(SwapOperationData {
            input_amount: Amount::from(100),
            output_amount: Amount::from(90),
            counter_start: Some(0),
            counter_end: Some(10),
            blinded_messages: None,
        })
    }

    /// A `ProofsReserved` saga is compensated without contacting the mint.
    ///
    /// The reserved proof is back to `Unspent` and the saga record is removed.
    #[tokio::test]
    async fn test_recover_swap_proofs_reserved() {
        let db = create_test_db().await;
        let mint_url = test_mint_url();
        let keyset_id = test_keyset_id();
        let saga_id = uuid::Uuid::new_v4();
        let proof_info = test_proof_info(keyset_id, 100, mint_url.clone());
        let proof_y = proof_info.y;
        db.update_proofs(vec![proof_info], vec![]).await.unwrap();
        db.reserve_proofs(vec![proof_y], &saga_id).await.unwrap();
        let saga = WalletSaga::new(
            saga_id,
            WalletSagaState::Swap(SwapSagaState::ProofsReserved),
            Amount::from(100),
            mint_url.clone(),
            CurrencyUnit::Sat,
            swap_operation_data(),
        );
        db.add_saga(saga).await.unwrap();
        let wallet = create_test_wallet(db.clone()).await;
        let report = wallet.recover_incomplete_sagas().await.unwrap();
        assert_eq!(report.compensated, 1);
        assert_eq!(report.recovered, 0);
        assert_eq!(report.failed, 0);
        let proofs = db
            .get_proofs(None, None, Some(vec![State::Unspent]), None)
            .await
            .unwrap();
        assert_eq!(proofs.len(), 1);
        assert_eq!(proofs[0].y, proof_y);
        assert!(db.get_saga(&saga_id).await.unwrap().is_none());
    }

    /// A `SwapRequested` saga whose inputs the mint reports as `Unspent` is
    /// compensated. The proof is `Unspent` again and the saga is removed.
    #[tokio::test]
    async fn test_recover_swap_requested_proofs_not_spent() {
        let db = create_test_db().await;
        let mint_url = test_mint_url();
        let keyset_id = test_keyset_id();
        let saga_id = uuid::Uuid::new_v4();
        let proof_info = test_proof_info(keyset_id, 100, mint_url.clone());
        let proof_y = proof_info.y;
        db.update_proofs(vec![proof_info], vec![]).await.unwrap();
        db.reserve_proofs(vec![proof_y], &saga_id).await.unwrap();
        let saga = WalletSaga::new(
            saga_id,
            WalletSagaState::Swap(SwapSagaState::SwapRequested),
            Amount::from(100),
            mint_url.clone(),
            CurrencyUnit::Sat,
            swap_operation_data(),
        );
        db.add_saga(saga).await.unwrap();
        let mock_client = Arc::new(MockMintConnector::new());
        mock_client.set_check_state_response(Ok(CheckStateResponse {
            states: vec![ProofState {
                y: proof_y,
                state: State::Unspent,
                witness: None,
            }],
        }));
        let wallet = create_test_wallet_with_mock(db.clone(), mock_client).await;
        let report = wallet.recover_incomplete_sagas().await.unwrap();
        assert_eq!(report.compensated, 1);
        assert_eq!(report.recovered, 0);
        let proofs = db
            .get_proofs(None, None, Some(vec![State::Unspent]), None)
            .await
            .unwrap();
        assert_eq!(proofs.len(), 1);
        assert!(db.get_saga(&saga_id).await.unwrap().is_none());
    }

    /// A proof put into `Reserved` without being linked to the saga is not
    /// released by compensation.
    ///
    /// It stays `Reserved` with no `used_by_operation`, while the saga itself
    /// is still compensated and removed.
    #[tokio::test]
    async fn test_recover_swap_proofs_reserved_without_operation_link_leaves_reserved_proof() {
        let db = create_test_db().await;
        let mint_url = test_mint_url();
        let keyset_id = test_keyset_id();
        let saga_id = uuid::Uuid::new_v4();
        let proof_info = test_proof_info(keyset_id, 100, mint_url.clone());
        let proof_y = proof_info.y;
        db.update_proofs(vec![proof_info], vec![]).await.unwrap();
        db.update_proofs_state(vec![proof_y], State::Reserved)
            .await
            .unwrap();
        let saga = WalletSaga::new(
            saga_id,
            WalletSagaState::Swap(SwapSagaState::ProofsReserved),
            Amount::from(100),
            mint_url,
            CurrencyUnit::Sat,
            swap_operation_data(),
        );
        db.add_saga(saga).await.unwrap();
        let wallet = create_test_wallet(db.clone()).await;
        let report = wallet.recover_incomplete_sagas().await.unwrap();
        assert_eq!(report.compensated, 1);
        assert_eq!(report.recovered, 0);
        let reserved = db.get_proofs_by_ys(vec![proof_y]).await.unwrap();
        assert_eq!(reserved.len(), 1);
        assert_eq!(reserved[0].state, State::Reserved);
        assert_eq!(reserved[0].used_by_operation, None);
        assert!(db.get_saga(&saga_id).await.unwrap().is_none());
    }

    /// When the mint's proof-state check fails, a `SwapRequested` saga is
    /// skipped.
    ///
    /// The proof stays reserved for the saga and the saga record is kept for a
    /// later retry.
    #[tokio::test]
    async fn test_recover_swap_requested_mint_unreachable() {
        let db = create_test_db().await;
        let mint_url = test_mint_url();
        let keyset_id = test_keyset_id();
        let saga_id = uuid::Uuid::new_v4();
        let proof_info = test_proof_info(keyset_id, 100, mint_url.clone());
        let proof_y = proof_info.y;
        db.update_proofs(vec![proof_info], vec![]).await.unwrap();
        db.reserve_proofs(vec![proof_y], &saga_id).await.unwrap();
        let saga = WalletSaga::new(
            saga_id,
            WalletSagaState::Swap(SwapSagaState::SwapRequested),
            Amount::from(100),
            mint_url.clone(),
            CurrencyUnit::Sat,
            swap_operation_data(),
        );
        db.add_saga(saga).await.unwrap();
        let mock_client = Arc::new(MockMintConnector::new());
        mock_client
            .set_check_state_response(Err(crate::Error::Custom("Connection refused".to_string())));
        let wallet = create_test_wallet_with_mock(db.clone(), mock_client).await;
        let report = wallet.recover_incomplete_sagas().await.unwrap();
        assert_eq!(report.skipped, 1);
        assert_eq!(report.compensated, 0);
        assert_eq!(report.recovered, 0);
        let reserved = db.get_reserved_proofs(&saga_id).await.unwrap();
        assert_eq!(reserved.len(), 1);
        assert!(db.get_saga(&saga_id).await.unwrap().is_some());
    }
}