1use std::pin::Pin;
2use std::task::Poll;
3
4use futures::stream::Stream;
5use futures::TryStreamExt;
6use tokio::sync::broadcast;
7use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
8use tokio_stream::wrappers::BroadcastStream;
9use tokio_stream::StreamExt;
10
11use ark::lightning::{AsPaymentHash, Invoice, Offer, PaymentHash};
12
13use crate::Wallet;
14use crate::movement::{Movement, PaymentMethod};
15use crate::subsystem::Subsystem;
16
17#[derive(Debug, Clone)]
19pub enum WalletNotification {
20 MovementCreated {
22 movement: Movement,
23 },
24 MovementUpdated {
26 movement: Movement,
27 },
28 ChannelLagging,
33}
34
35impl WalletNotification {
36 pub fn movement(&self) -> Option<&Movement> {
38 match self {
39 Self::MovementCreated { movement } => Some(movement),
40 Self::MovementUpdated { movement } => Some(movement),
41 Self::ChannelLagging => None,
42 }
43 }
44
45 pub fn lightning_invoice(&self) -> Option<&Invoice> {
49 self.movement().and_then(|m| m.lightning_invoice())
50 }
51
52 pub fn lightning_offer(&self) -> Option<&Offer> {
56 self.movement().and_then(|m| m.lightning_offer())
57 }
58
59 pub fn lightning_payment_hash(&self) -> Option<PaymentHash> {
63 self.movement().and_then(|m| m.lightning_payment_hash())
64 }
65}
66
67pub struct NotificationStream {
74 rx: BroadcastStream<WalletNotification>,
75}
76
77impl NotificationStream {
78 pub(crate) fn new(rx: broadcast::Receiver<WalletNotification>) -> Self {
79 Self {
80 rx: BroadcastStream::new(rx),
81 }
82 }
83
84 pub fn movements(self) -> impl Stream<Item = Movement> + Unpin + Send {
86 self.filter_map(|n| match n {
87 WalletNotification::MovementCreated { movement } => Some(movement),
88 WalletNotification::MovementUpdated { movement } => Some(movement),
89 WalletNotification::ChannelLagging => None,
90 })
91 }
92
93 pub fn filter_arkoor_address_movements(
95 self,
96 address: ark::Address,
97 ) -> impl Stream<Item = Movement> + Unpin + Send {
98 self.movements().filter(move |m| {
99 if !m.subsystem.is_subsystem(Subsystem::ARKOOR) {
100 return false;
101 }
102
103 m.received_on.iter().any(|d| match d.destination {
104 PaymentMethod::Ark(ref a) if *a == address => true,
105 _ => false,
106 })
107 })
108 }
109
110 pub fn filter_lightning_payment_movements(
114 self,
115 payment: impl AsPaymentHash,
116 ) -> impl Stream<Item = Movement> + Unpin + Send {
117 let payment_hash = payment.as_payment_hash();
118 self.movements().filter(move |m| m.lightning_payment_hash() == Some(payment_hash))
119 }
120
121 pub fn into_raw_stream(self) -> BroadcastStream<WalletNotification> {
126 self.rx
127 }
128}
129
130impl Stream for NotificationStream {
131 type Item = WalletNotification;
132 fn poll_next(
133 mut self: Pin<&mut Self>,
134 cx: &mut std::task::Context<'_>,
135 ) -> Poll<Option<Self::Item>> {
136 match self.rx.try_poll_next_unpin(cx) {
137 Poll::Pending => Poll::Pending,
138 Poll::Ready(None) => Poll::Ready(None),
139 Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(m)),
140 Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
141 Poll::Ready(Some(WalletNotification::ChannelLagging))
142 },
143 }
144 }
145}
146
147#[derive(Clone)]
148pub(crate) struct NotificationDispatch {
149 tx: broadcast::Sender<WalletNotification>,
150}
151
152impl NotificationDispatch {
153 pub fn new() -> Self {
154 let (tx, _rx) = broadcast::channel(64);
155 Self { tx }
156 }
157
158 pub fn subscribe(&self) -> NotificationStream {
159 NotificationStream::new(self.tx.subscribe())
160 }
161
162 fn dispatch(&self, n: WalletNotification) {
163 let _ = self.tx.send(n);
164 }
165
166 pub fn dispatch_movement_created(&self, movement: Movement) {
167 self.dispatch(WalletNotification::MovementCreated { movement });
168 }
169
170 pub fn dispatch_movement_updated(&self, movement: Movement) {
171 self.dispatch(WalletNotification::MovementUpdated { movement });
172 }
173}
174
175impl Wallet {
176 pub fn subscribe_notifications(&self) -> NotificationStream {
195 self.notifications.subscribe()
196 }
197}