use std::sync::Arc;
use std::task::Poll;
use cdk_common::{Amount, Error, MeltQuoteState, MintQuoteState, NotificationPayload};
use futures::future::join_all;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use super::RecvFuture;
use crate::event::MintEvent;
use crate::wallet::subscription::ActiveSubscription;
use crate::{Wallet, WalletSubscription};
/// Result of one receive round: the notification that arrived (`None` when the subscriptions
/// stopped producing notifications) together with the subscriptions, handed back so they can be
/// polled again.
type SubscribeReceived = (Option<MintEvent<String>>, Vec<ActiveSubscription>);
/// A payment reported by [`PaymentStream`]: the quote id and, for bolt12 / onchain / custom mint
/// quotes, the amount that has been paid but not yet issued (`None` for paid bolt11 mint quotes
/// and paid melt quotes).
type PaymentValue = (String, Option<Amount>);
/// A stream of payment events for a set of wallet subscriptions.
///
/// The subscriptions are opened lazily on the first poll. The stream ends when a finalizing
/// payment is seen, when the subscriptions stop producing notifications, or when its cancel token
/// is cancelled.
// NOTE(review): `Debug` is presumably skipped because the boxed futures stored below cannot
// derive it — confirm before removing this allow.
#[allow(missing_debug_implementations)]
pub struct PaymentStream<'a> {
// Wallet used to open the subscriptions and to reach its local store.
wallet: &'a Wallet,
// Filters not yet subscribed; taken on the first poll.
filters: Option<Vec<WalletSubscription>>,
// Set once a finalizing item (or end of notifications) has been produced.
is_finalized: bool,
// Open subscriptions, present while no receive future currently owns them.
active_subscription: Option<Vec<ActiveSubscription>>,
// Token that ends the stream once cancelled.
cancel_token: CancellationToken,
// In-flight future opening one subscription per filter.
subscriber_future: Option<RecvFuture<'a, Vec<ActiveSubscription>>>,
// In-flight future waiting for the next notification; owns the subscriptions while alive.
subscription_receiver_future: Option<RecvFuture<'static, SubscribeReceived>>,
// Cached future resolving when `cancel_token` is cancelled.
cancellation_future: Option<RecvFuture<'a, ()>>,
}
impl<'a> PaymentStream<'a> {
pub fn new(wallet: &'a Wallet, filters: Vec<WalletSubscription>) -> Self {
Self {
wallet,
filters: Some(filters),
is_finalized: false,
active_subscription: None,
cancel_token: Default::default(),
subscriber_future: None,
subscription_receiver_future: None,
cancellation_future: None,
}
}
pub fn get_cancel_token(&self) -> CancellationToken {
self.cancel_token.clone()
}
fn poll_init_subscription(&mut self, cx: &mut std::task::Context<'_>) -> Option<()> {
if let Some(filters) = self.filters.take() {
let wallet = self.wallet;
self.subscriber_future = Some(Box::pin(async move {
let results = join_all(filters.into_iter().map(|w| wallet.subscribe(w))).await;
results
.into_iter()
.filter_map(|r| match r {
Ok(sub) => Some(sub),
Err(e) => {
tracing::warn!("Failed to create subscription: {}", e);
None
}
})
.collect::<Vec<_>>()
}));
}
let mut subscriber_future = self.subscriber_future.take()?;
match subscriber_future.poll_unpin(cx) {
Poll::Pending => {
self.subscriber_future = Some(subscriber_future);
Some(())
}
Poll::Ready(active_subscription) => {
self.active_subscription = Some(active_subscription);
None
}
}
}
fn poll_cancel(&mut self, cx: &mut std::task::Context<'_>) -> bool {
let mut cancellation_future = self.cancellation_future.take().unwrap_or_else(|| {
let cancel_token = self.cancel_token.clone();
Box::pin(async move { cancel_token.cancelled().await })
});
if cancellation_future.poll_unpin(cx).is_ready() {
self.subscription_receiver_future = None;
true
} else {
self.cancellation_future = Some(cancellation_future);
false
}
}
fn poll_event(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<PaymentValue, Error>>> {
let (subscription_receiver_future, active_subscription) = (
self.subscription_receiver_future.take(),
self.active_subscription.take(),
);
if subscription_receiver_future.is_none() && active_subscription.is_none() {
return Poll::Ready(Some(Err(Error::Internal)));
}
let localstore = Arc::clone(&self.wallet.localstore);
let mut receiver = subscription_receiver_future.unwrap_or_else(|| {
let mut subscription_receiver =
active_subscription.expect("active subscription object");
Box::pin(async move {
let mut futures: FuturesUnordered<_> = subscription_receiver
.iter_mut()
.map(|sub| sub.recv())
.collect();
if let Some(res) = futures.next().await {
drop(futures);
if let Some(event) = &res {
match event.inner() {
NotificationPayload::MintQuoteBolt11Response(info) => {
let quote_id = info.quote.clone();
if let Ok(Some(mut quote)) =
localstore.get_mint_quote("e_id).await
{
quote.state = info.state;
quote.amount_paid = info.amount.unwrap_or(Amount::ZERO);
if let Err(e) = localstore.add_mint_quote(quote).await {
tracing::warn!("Failed to update quote state: {}", e);
}
}
}
NotificationPayload::MintQuoteBolt12Response(info) => {
let quote_id = info.quote.clone();
if let Ok(Some(mut quote)) =
localstore.get_mint_quote("e_id).await
{
quote.amount_paid = info.amount_paid;
quote.amount_issued = info.amount_issued;
if let Err(e) = localstore.add_mint_quote(quote).await {
tracing::warn!("Failed to update quote state: {}", e);
}
}
}
NotificationPayload::MintQuoteOnchainResponse(info) => {
let quote_id = info.quote.clone();
if let Ok(Some(mut quote)) =
localstore.get_mint_quote("e_id).await
{
quote.amount_paid = info.amount_paid;
quote.amount_issued = info.amount_issued;
if let Err(e) = localstore.add_mint_quote(quote).await {
tracing::warn!("Failed to update quote state: {}", e);
}
}
}
NotificationPayload::CustomMintQuoteResponse(_, info) => {
let quote_id = info.quote.clone();
if let Ok(Some(mut quote)) =
localstore.get_mint_quote("e_id).await
{
quote.amount_paid = info.amount_paid;
quote.amount_issued = info.amount_issued;
if let Err(e) = localstore.add_mint_quote(quote).await {
tracing::warn!("Failed to update quote state: {}", e);
}
}
}
_ => (),
}
}
return (res, subscription_receiver);
}
drop(futures);
(None, subscription_receiver)
})
});
match receiver.poll_unpin(cx) {
Poll::Pending => {
self.subscription_receiver_future = Some(receiver);
Poll::Pending
}
Poll::Ready((notification, subscription)) => {
tracing::debug!("Receive payment notification {:?}", notification);
self.active_subscription = Some(subscription);
self.cancellation_future = None; match notification {
None => {
self.is_finalized = true;
Poll::Ready(None)
}
Some(info) => {
match info.into_inner() {
NotificationPayload::MintQuoteBolt11Response(info)
if info.state == MintQuoteState::Paid =>
{
self.is_finalized = true;
return Poll::Ready(Some(Ok((info.quote, None))));
}
NotificationPayload::MintQuoteBolt12Response(info) => {
let to_be_issued =
info.amount_paid.saturating_sub(info.amount_issued);
if to_be_issued > Amount::ZERO {
return Poll::Ready(Some(Ok((info.quote, Some(to_be_issued)))));
}
}
NotificationPayload::MintQuoteOnchainResponse(info) => {
let to_be_issued =
info.amount_paid.saturating_sub(info.amount_issued);
if to_be_issued > Amount::ZERO {
return Poll::Ready(Some(Ok((info.quote, Some(to_be_issued)))));
}
}
NotificationPayload::CustomMintQuoteResponse(_, info) => {
let to_be_issued =
info.amount_paid.saturating_sub(info.amount_issued);
if to_be_issued > Amount::ZERO {
return Poll::Ready(Some(Ok((info.quote, Some(to_be_issued)))));
}
}
NotificationPayload::MeltQuoteBolt11Response(info)
if info.state == MeltQuoteState::Paid =>
{
self.is_finalized = true;
return Poll::Ready(Some(Ok((info.quote, None))));
}
NotificationPayload::MeltQuoteOnchainResponse(info)
if info.state == MeltQuoteState::Paid =>
{
self.is_finalized = true;
return Poll::Ready(Some(Ok((info.quote, None))));
}
_ => {}
}
self.poll_event(cx)
}
}
}
}
}
}
impl Stream for PaymentStream<'_> {
    type Item = Result<PaymentValue, Error>;

    /// Yields the next payment reported by the subscriptions.
    ///
    /// The checks run in a fixed order:
    /// 1. A finalized stream ends immediately.
    /// 2. A cancelled stream ends next.
    /// 3. Subscription setup is then driven until it completes.
    /// 4. Only after that are notifications read.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let stream = self.get_mut();

        // `||` short-circuits, so cancellation is not polled once the stream is finalized.
        if stream.is_finalized || stream.poll_cancel(cx) {
            return Poll::Ready(None);
        }

        match stream.poll_init_subscription(cx) {
            // Setup returned `Pending`, so its future is responsible for waking this task.
            Some(()) => Poll::Pending,
            None => stream.poll_event(cx),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::pin::Pin;
    use std::task::{Context, Poll, Waker};

    use cdk_common::{
        Amount, CurrencyUnit, MintQuoteBolt12Response, MintQuoteCustomResponse,
        MintQuoteOnchainResponse, NotificationPayload,
    };
    use futures::Stream;

    use super::PaymentStream;
    use crate::event::MintEvent;
    use crate::nuts::SecretKey;
    use crate::wallet::subscription::ActiveSubscription;
    use crate::wallet::test_utils::{create_test_db, create_test_wallet};
    use crate::Wallet;

    /// Builds a stream that skips subscription setup and whose first receive resolves to
    /// `event` with no subscriptions handed back.
    fn stream_with_event(wallet: &Wallet, event: MintEvent<String>) -> PaymentStream<'_> {
        let mut stream = PaymentStream::new(wallet, Vec::new());
        stream.filters = None;
        stream.subscription_receiver_future = Some(Box::pin(async move {
            let subscriptions: Vec<ActiveSubscription> = Vec::new();
            (Some(event), subscriptions)
        }));
        stream
    }

    #[tokio::test]
    async fn mint_quote_notification_underflow_does_not_panic() {
        let db = create_test_db().await;
        let wallet = create_test_wallet(db).await;
        let pubkey = SecretKey::generate().public_key();
        let events = vec![
            MintEvent::new(NotificationPayload::MintQuoteBolt12Response(
                MintQuoteBolt12Response::<String> {
                    quote: "bolt12_quote".to_string(),
                    request: "test_request".to_string(),
                    amount: None,
                    unit: CurrencyUnit::Sat,
                    expiry: None,
                    pubkey,
                    amount_paid: Amount::from(50u64),
                    amount_issued: Amount::from(100u64),
                },
            )),
            MintEvent::new(NotificationPayload::MintQuoteOnchainResponse(
                MintQuoteOnchainResponse::<String> {
                    quote: "onchain_quote".to_string(),
                    request: "test_request".to_string(),
                    unit: CurrencyUnit::Sat,
                    expiry: None,
                    pubkey,
                    amount_paid: Amount::from(50u64),
                    amount_issued: Amount::from(100u64),
                },
            )),
            MintEvent::new(NotificationPayload::CustomMintQuoteResponse(
                "custom".to_string(),
                MintQuoteCustomResponse::<String> {
                    quote: "custom_quote".to_string(),
                    request: "test_request".to_string(),
                    amount: None,
                    amount_paid: Amount::from(50u64),
                    amount_issued: Amount::from(100u64),
                    unit: Some(CurrencyUnit::Sat),
                    expiry: None,
                    pubkey: Some(pubkey),
                    extra: serde_json::Value::Null,
                },
            )),
        ];

        for event in events {
            let mut stream = stream_with_event(&wallet, event);
            let mut cx = Context::from_waker(Waker::noop());
            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                Pin::new(&mut stream).poll_next(&mut cx)
            }));
            // The over-issued notification is skipped; with no subscriptions left the stream ends.
            assert!(matches!(
                result.expect("poll should not panic"),
                Poll::Ready(None)
            ));
        }
    }

    #[tokio::test]
    async fn bolt12_notification_yields_unissued_amount() {
        let db = create_test_db().await;
        let wallet = create_test_wallet(db).await;
        let event = MintEvent::new(NotificationPayload::MintQuoteBolt12Response(
            MintQuoteBolt12Response::<String> {
                quote: "bolt12_quote".to_string(),
                request: "test_request".to_string(),
                amount: None,
                unit: CurrencyUnit::Sat,
                expiry: None,
                pubkey: SecretKey::generate().public_key(),
                amount_paid: Amount::from(100u64),
                amount_issued: Amount::from(40u64),
            },
        ));

        let mut stream = stream_with_event(&wallet, event);
        let mut cx = Context::from_waker(Waker::noop());
        let Poll::Ready(Some(Ok((quote, to_be_issued)))) =
            Pin::new(&mut stream).poll_next(&mut cx)
        else {
            panic!("expected the bolt12 payment to be reported");
        };
        assert_eq!(quote, "bolt12_quote");
        assert_eq!(to_be_issued, Some(Amount::from(60u64)));
    }

    #[tokio::test]
    async fn cancelled_stream_ends() {
        let db = create_test_db().await;
        let wallet = create_test_wallet(db).await;
        let mut stream = PaymentStream::new(&wallet, Vec::new());
        stream.get_cancel_token().cancel();

        let mut cx = Context::from_waker(Waker::noop());
        assert!(matches!(
            Pin::new(&mut stream).poll_next(&mut cx),
            Poll::Ready(None)
        ));
        // Cancellation is sticky: later polls keep ending the stream.
        assert!(matches!(
            Pin::new(&mut stream).poll_next(&mut cx),
            Poll::Ready(None)
        ));
    }
}