1use std::pin::Pin;
2use std::task::Poll;
3
4use futures::stream::Stream;
5use futures::TryStreamExt;
6use tokio::sync::broadcast;
7use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
8use tokio_stream::wrappers::BroadcastStream;
9use tokio_stream::StreamExt;
10
11use ark::lightning::AsPaymentHash;
12
13use crate::Wallet;
14use crate::movement::{Movement, PaymentMethod};
15use crate::subsystem::{LightningMovement, Subsystem};
16
17#[derive(Debug, Clone)]
19pub enum WalletNotification {
20 MovementCreated {
22 movement: Movement,
23 },
24 MovementUpdated {
26 movement: Movement,
27 },
28 ChannelLagging,
33}
34
35pub struct NotificationStream {
42 rx: BroadcastStream<WalletNotification>,
43}
44
45impl NotificationStream {
46 pub(crate) fn new(rx: broadcast::Receiver<WalletNotification>) -> Self {
47 Self {
48 rx: BroadcastStream::new(rx),
49 }
50 }
51
52 pub fn movements(self) -> impl Stream<Item = Movement> + Unpin + Send {
54 self.filter_map(|n| match n {
55 WalletNotification::MovementCreated { movement } => Some(movement),
56 WalletNotification::MovementUpdated { movement } => Some(movement),
57 WalletNotification::ChannelLagging => None,
58 })
59 }
60
61 pub fn filter_arkoor_address_movements(
63 self,
64 address: ark::Address,
65 ) -> impl Stream<Item = Movement> + Unpin + Send {
66 self.movements().filter(move |m| {
67 if !m.subsystem.is_subsystem(Subsystem::ARKOOR) {
68 return false;
69 }
70
71 m.received_on.iter().any(|d| match d.destination {
72 PaymentMethod::Ark(ref a) if *a == address => true,
73 _ => false,
74 })
75 })
76 }
77
78 pub fn filter_lightning_payment_movements(
82 self,
83 payment: impl AsPaymentHash,
84 ) -> impl Stream<Item = Movement> + Unpin + Send {
85 let payment_hash = payment.as_payment_hash();
86 self.movements().filter(move |m| {
87 if !m.subsystem.is_subsystem(Subsystem::LIGHTNING_RECEIVE)
88 && !m.subsystem.is_subsystem(Subsystem::LIGHTNING_SEND)
89 {
90 return false;
91 }
92
93 if LightningMovement::get_payment_hash(&m.metadata) == Some(payment_hash) {
94 return true;
95 }
96
97 for d in &m.received_on {
98 match d.destination {
99 PaymentMethod::Invoice(ref i) if i.payment_hash() == payment_hash => {
100 return true;
101 },
102 _ => {},
103 }
104 }
105
106 false
107 })
108 }
109
110 pub fn into_raw_stream(self) -> BroadcastStream<WalletNotification> {
115 self.rx
116 }
117}
118
119impl Stream for NotificationStream {
120 type Item = WalletNotification;
121 fn poll_next(
122 mut self: Pin<&mut Self>,
123 cx: &mut std::task::Context<'_>,
124 ) -> Poll<Option<Self::Item>> {
125 match self.rx.try_poll_next_unpin(cx) {
126 Poll::Pending => Poll::Pending,
127 Poll::Ready(None) => Poll::Ready(None),
128 Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(m)),
129 Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(_)))) => {
130 Poll::Ready(Some(WalletNotification::ChannelLagging))
131 },
132 }
133 }
134}
135
136#[derive(Clone)]
137pub(crate) struct NotificationDispatch {
138 tx: broadcast::Sender<WalletNotification>,
139}
140
141impl NotificationDispatch {
142 pub fn new() -> Self {
143 let (tx, _rx) = broadcast::channel(64);
144 Self { tx }
145 }
146
147 pub fn subscribe(&self) -> NotificationStream {
148 NotificationStream::new(self.tx.subscribe())
149 }
150
151 fn dispatch(&self, n: WalletNotification) {
152 let _ = self.tx.send(n);
153 }
154
155 pub fn dispatch_movement_created(&self, movement: Movement) {
156 self.dispatch(WalletNotification::MovementCreated { movement });
157 }
158
159 pub fn dispatch_movement_updated(&self, movement: Movement) {
160 self.dispatch(WalletNotification::MovementUpdated { movement });
161 }
162}
163
164impl Wallet {
165 pub fn subscribe_notifications(&self) -> NotificationStream {
184 self.notifications.subscribe()
185 }
186}