use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use cdk_common::nut17::ws::{
WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest,
};
use cdk_common::nut17::{Kind, NotificationId};
use cdk_common::parking_lot::RwLock;
use cdk_common::pub_sub::remote_consumer::{
Consumer, InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport,
};
use cdk_common::pub_sub::{Error as PubsubError, Spec, Subscriber};
use cdk_common::subscription::WalletParams;
use cdk_common::ws_client::{connect as ws_connect, WsError};
use cdk_common::{CheckStateRequest, Method, PaymentMethod, RoutePath};
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::event::MintEvent;
use crate::mint_url::MintUrl;
use crate::wallet::MintConnector;
/// Notification payload delivered to wallet subscribers, parameterised over `String` ids.
pub type NotificationPayload = crate::nuts::NotificationPayload<String>;
/// Handle to an active subscription on a mint, backed by a [`SubscriptionClient`] transport.
pub type ActiveSubscription = RemoteActiveConsumer<SubscriptionClient>;
/// Wallet-side manager of mint subscriptions.
///
/// Keeps one shared [`Consumer`] per mint URL and routes new subscriptions to it. The
/// consumer for a mint is created on the first subscription to that mint.
#[derive(Clone)]
pub struct SubscriptionManager {
    /// One consumer per mint. The map sits behind `Arc<RwLock<_>>`, so clones of the
    /// manager share it.
    all_connections: Arc<RwLock<HashMap<MintUrl, Arc<Consumer<SubscriptionClient>>>>>,
    /// HTTP connector cloned into every per-mint [`SubscriptionClient`].
    http_client: Arc<dyn MintConnector + Send + Sync>,
    /// Forwarded to `Consumer::new`. Presumably selects HTTP polling over WebSocket
    /// streaming when `true`; NOTE(review): confirm against `remote_consumer`.
    prefer_http: bool,
}
impl Debug for SubscriptionManager {
    /// Lists the mint URLs that currently have a subscription consumer.
    ///
    /// Formatting never mutates the map, so only a shared (read) lock is taken. The key
    /// list is collected before writing, so the guard is released before the formatter
    /// runs. An exclusive lock here would block concurrent `subscribe` calls. It would also
    /// deadlock if the caller already held a guard on the same map.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mints = self
            .all_connections
            .read()
            .keys()
            .cloned()
            .collect::<Vec<_>>();
        write!(f, "Subscription Manager connected to {:?}", mints)
    }
}
impl SubscriptionManager {
    /// Creates a manager with no mint connections yet.
    ///
    /// * `http_client` - connector cloned into every per-mint [`SubscriptionClient`].
    /// * `prefer_http` - forwarded to each [`Consumer`] when it is created.
    pub fn new(http_client: Arc<dyn MintConnector + Send + Sync>, prefer_http: bool) -> Self {
        Self {
            all_connections: Arc::new(RwLock::new(HashMap::new())),
            http_client,
            prefer_http,
        }
    }

    /// Subscribes to `filter` on the mint at `mint_url`.
    ///
    /// The per-mint [`Consumer`] is created on the first subscription for that URL and
    /// reused for every later one.
    ///
    /// # Errors
    ///
    /// Returns whatever [`PubsubError`] the consumer's `subscribe` reports.
    pub fn subscribe(
        &self,
        mint_url: MintUrl,
        filter: WalletParams,
    ) -> Result<RemoteActiveConsumer<SubscriptionClient>, PubsubError> {
        // Resolve (or lazily create) the consumer under the write lock. Hand out a cloned
        // `Arc` so the lock is released before subscribing. Holding it across `subscribe`
        // would serialize subscriptions to every mint behind this single lock.
        let consumer = {
            let mut connections = self.all_connections.write();
            let entry = connections.entry(mint_url.clone()).or_insert_with(|| {
                Consumer::new(
                    SubscriptionClient {
                        mint_url,
                        http_client: self.http_client.clone(),
                        req_id: 0.into(),
                    },
                    self.prefer_http,
                    (),
                )
            });
            Arc::clone(entry)
        };
        consumer.subscribe(filter)
    }
}
/// [`Spec`] for wallet subscriptions to a mint.
///
/// It declares the subscription id, topic (`NotificationId<String>`) and event
/// (`MintEvent<String>`) types used by the remote consumer.
#[derive(Clone, Default, Debug)]
pub struct MintSubTopics {}
#[async_trait::async_trait]
impl Spec for MintSubTopics {
    type SubscriptionId = String;
    type Event = MintEvent<String>;
    type Topic = NotificationId<String>;
    type Context = ();

    /// Builds a fresh, stateless spec instance; the unit context carries no data.
    fn new_instance(_context: Self::Context) -> Arc<Self>
    where
        Self: Sized,
    {
        Arc::new(Self::default())
    }

    /// Intentionally empty: no events are fetched and nothing is sent to `_reply_to`.
    async fn fetch_events(self: &Arc<Self>, _topics: Vec<Self::Topic>, _reply_to: Subscriber<Self>)
    where
        Self: Sized,
    {
        // No-op by design.
    }
}
/// Per-mint [`Transport`] that delivers notifications in one of two ways: a WebSocket
/// stream (`stream`) or polling of the mint's HTTP API (`poll`).
#[derive(Debug)]
#[allow(dead_code)]
pub struct SubscriptionClient {
    /// Connector used for HTTP polling and for obtaining WebSocket auth tokens.
    http_client: Arc<dyn MintConnector + Send + Sync>,
    /// Base URL of the mint; the `v1/ws` endpoint is derived from it.
    mint_url: MintUrl,
    /// Counter that supplies ids for outgoing WebSocket requests (incremented per request).
    req_id: AtomicUsize,
}
#[allow(dead_code)]
impl SubscriptionClient {
    /// Reserves the next request id for this connection.
    fn next_req_id(&self) -> usize {
        self.req_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Serializes a WebSocket subscribe request for topic `params` under subscription `id`.
    ///
    /// Returns the request id paired with the JSON text. Returns `None` (after logging)
    /// if serialization fails.
    fn get_sub_request(
        &self,
        id: String,
        params: NotificationId<String>,
    ) -> Option<(usize, String)> {
        let (kind, filter) = match params {
            NotificationId::ProofState(proof_y) => (Kind::ProofState, proof_y.to_string()),
            NotificationId::MeltQuoteBolt11(quote) => (Kind::Bolt11MeltQuote, quote),
            NotificationId::MeltQuoteBolt12(quote) => (Kind::Bolt12MeltQuote, quote),
            NotificationId::MintQuoteBolt11(quote) => (Kind::Bolt11MintQuote, quote),
            NotificationId::MintQuoteBolt12(quote) => (Kind::Bolt12MintQuote, quote),
            NotificationId::MintQuoteCustom(method, quote) => {
                (Kind::Custom(format!("{}_mint_quote", method)), quote)
            }
            NotificationId::MeltQuoteCustom(method, quote) => {
                (Kind::Custom(format!("{}_melt_quote", method)), quote)
            }
        };

        let subscribe = WsMethodRequest::Subscribe(WalletParams {
            kind,
            filters: vec![filter],
            id: id.into(),
        });
        let request: WsRequest<_> = (subscribe, self.next_req_id()).into();

        match serde_json::to_string(&request) {
            Ok(json) => Some((request.id, json)),
            Err(err) => {
                tracing::error!("Could not serialize subscribe message: {:?}", err);
                None
            }
        }
    }

    /// Serializes a WebSocket unsubscribe request for `sub_id`.
    ///
    /// Returns `None` (after logging) if serialization fails.
    fn get_unsub_request(&self, sub_id: String) -> Option<String> {
        let unsubscribe = WsMethodRequest::Unsubscribe(WsUnsubscribeRequest { sub_id });
        let request: WsRequest<_> = (unsubscribe, self.next_req_id()).into();

        serde_json::to_string(&request)
            .inspect_err(|err| {
                tracing::error!("Could not serialize unsubscribe message: {:?}", err);
            })
            .ok()
    }
}
/// [`Transport`] for a single mint.
///
/// Either streams notifications over the mint's WebSocket endpoint, or polls the mint's
/// HTTP API for the current state of each subscribed topic.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl Transport for SubscriptionClient {
    type Spec = MintSubTopics;

    /// Generates a new random (UUID v4) subscription id.
    fn new_name(&self) -> <Self::Spec as Spec>::SubscriptionId {
        Uuid::new_v4().to_string()
    }

    /// Runs the WebSocket session for this mint by delegating to `stream_client`.
    async fn stream(
        &self,
        ctrls: mpsc::Receiver<StreamCtrl<Self::Spec>>,
        topics: Vec<SubscribeMessage<Self::Spec>>,
        reply_to: InternalRelay<Self::Spec>,
    ) -> Result<(), PubsubError> {
        stream_client(self, ctrls, topics, reply_to).await
    }

    /// Performs one HTTP polling round for `topics` and sends each fetched state to
    /// `reply_to`.
    ///
    /// All proof-state topics are batched into a single `post_check_state` call. Every
    /// other topic is fetched with its own request. If one quote's request fails, or
    /// returns an unexpected response variant, it is logged and skipped. The remaining
    /// topics are still polled.
    ///
    /// # Errors
    ///
    /// Returns `PubsubError::Internal` if the batched proof-state check fails. In that
    /// case no quote topics are polled in this round.
    async fn poll(
        &self,
        topics: Vec<SubscribeMessage<Self::Spec>>,
        reply_to: InternalRelay<Self::Spec>,
    ) -> Result<(), PubsubError> {
        // Gather every proof-state topic so they can be checked with a single request.
        let proofs = topics
            .iter()
            .filter_map(|(_, x)| match &x {
                NotificationId::ProofState(p) => Some(*p),
                _ => None,
            })
            .collect::<Vec<_>>();
        if !proofs.is_empty() {
            // `?` here aborts the whole round, including the quote polling below.
            for state in self
                .http_client
                .post_check_state(CheckStateRequest { ys: proofs })
                .await
                .map_err(|e| PubsubError::Internal(Box::new(e)))?
                .states
            {
                reply_to.send(MintEvent::new(NotificationPayload::ProofState(state)));
            }
        }
        // Quote topics are polled one at a time; proof states were handled above.
        for topic in topics
            .into_iter()
            .map(|(_, x)| x)
            .filter(|x| !matches!(x, NotificationId::ProofState(_)))
        {
            match topic {
                NotificationId::MintQuoteBolt11(id) => {
                    let response = match self
                        .http_client
                        .get_mint_quote_status(PaymentMethod::BOLT11, &id)
                        .await
                    {
                        Ok(success) => match success {
                            cdk_common::MintQuoteResponse::Bolt11(r) => r,
                            _ => {
                                tracing::error!("Unexpected response type for MintBolt11 {}", id);
                                continue;
                            }
                        },
                        Err(err) => {
                            tracing::error!("Error with MintBolt11 {} with {:?}", id, err);
                            continue;
                        }
                    };
                    reply_to.send(MintEvent::new(
                        NotificationPayload::MintQuoteBolt11Response(response),
                    ));
                }
                NotificationId::MeltQuoteBolt11(id) => {
                    let response = match self
                        .http_client
                        .get_melt_quote_status(PaymentMethod::BOLT11, &id)
                        .await
                    {
                        Ok(success) => match success {
                            cdk_common::MeltQuoteResponse::Bolt11(r) => r,
                            _ => {
                                tracing::error!("Unexpected response type for MeltBolt11 {}", id);
                                continue;
                            }
                        },
                        Err(err) => {
                            tracing::error!("Error with MeltBolt11 {} with {:?}", id, err);
                            continue;
                        }
                    };
                    reply_to.send(MintEvent::new(
                        NotificationPayload::MeltQuoteBolt11Response(response),
                    ));
                }
                NotificationId::MintQuoteBolt12(id) => {
                    let response = match self
                        .http_client
                        .get_mint_quote_status(PaymentMethod::BOLT12, &id)
                        .await
                    {
                        Ok(success) => match success {
                            cdk_common::MintQuoteResponse::Bolt12(r) => r,
                            _ => {
                                tracing::error!("Unexpected response type for MintBolt12 {}", id);
                                continue;
                            }
                        },
                        Err(err) => {
                            tracing::error!("Error with MintBolt12 {} with {:?}", id, err);
                            continue;
                        }
                    };
                    reply_to.send(MintEvent::new(
                        NotificationPayload::MintQuoteBolt12Response(response),
                    ));
                }
                NotificationId::MeltQuoteBolt12(id) => {
                    let response = match self
                        .http_client
                        .get_melt_quote_status(PaymentMethod::BOLT12, &id)
                        .await
                    {
                        Ok(success) => match success {
                            cdk_common::MeltQuoteResponse::Bolt12(r) => r,
                            _ => {
                                tracing::error!("Unexpected response type for MeltBolt12 {}", id);
                                continue;
                            }
                        },
                        Err(err) => {
                            tracing::error!("Error with MeltBolt12 {} with {:?}", id, err);
                            continue;
                        }
                    };
                    reply_to.send(MintEvent::new(
                        NotificationPayload::MeltQuoteBolt12Response(response),
                    ));
                }
                NotificationId::MintQuoteCustom(method, id) => {
                    // The `Custom` variant carries a pair; only its second element is forwarded.
                    let (_, response) = match self
                        .http_client
                        .get_mint_quote_status(PaymentMethod::Custom(method.clone()), &id)
                        .await
                    {
                        Ok(success) => match success {
                            cdk_common::MintQuoteResponse::Custom(r) => r,
                            _ => {
                                tracing::error!(
                                    "Unexpected response type for Custom Mint Quote {}",
                                    id
                                );
                                continue;
                            }
                        },
                        Err(err) => {
                            tracing::error!("Error with Custom Mint Quote {} with {:?}", id, err);
                            continue;
                        }
                    };
                    reply_to.send(MintEvent::new(
                        NotificationPayload::CustomMintQuoteResponse(method, response),
                    ));
                }
                NotificationId::MeltQuoteCustom(method, id) => {
                    let response = match self
                        .http_client
                        .get_melt_quote_status(PaymentMethod::Custom(method.clone()), &id)
                        .await
                    {
                        Ok(success) => match success {
                            cdk_common::MeltQuoteResponse::Custom((_, r)) => r,
                            _ => {
                                tracing::error!(
                                    "Unexpected response type for Custom Melt Quote {}",
                                    id
                                );
                                continue;
                            }
                        },
                        Err(err) => {
                            tracing::error!("Error with Custom Melt Quote {} with {:?}", id, err);
                            continue;
                        }
                    };
                    reply_to.send(MintEvent::new(
                        NotificationPayload::CustomMeltQuoteResponse(method, response),
                    ));
                }
                // Only `ProofState` can reach this arm, and it was filtered out above.
                _ => {}
            }
        }
        Ok(())
    }
}
async fn stream_client(
client: &SubscriptionClient,
mut ctrl: mpsc::Receiver<StreamCtrl<MintSubTopics>>,
topics: Vec<SubscribeMessage<MintSubTopics>>,
reply_to: InternalRelay<MintSubTopics>,
) -> Result<(), PubsubError> {
let mut url = client
.mint_url
.join_paths(&["v1", "ws"])
.expect("Could not join paths");
if url.scheme() == "https" {
url.set_scheme("wss").expect("Could not set scheme");
} else {
url.set_scheme("ws").expect("Could not set scheme");
}
let mut headers: Vec<(&str, String)> = Vec::new();
{
let auth_wallet = client.http_client.get_auth_wallet().await;
let token = match auth_wallet.as_ref() {
Some(auth_wallet) => {
let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws);
match auth_wallet.get_auth_for_request(&endpoint).await {
Ok(token) => token,
Err(err) => {
tracing::warn!("Failed to get auth token: {:?}", err);
None
}
}
}
None => None,
};
if let Some(auth_token) = token {
let header_key = match &auth_token {
cdk_common::AuthToken::ClearAuth(_) => "Clear-auth",
cdk_common::AuthToken::BlindAuth(_) => "Blind-auth",
};
let header_value = auth_token.to_string();
headers.push((header_key, header_value));
}
}
let url_str = url.to_string();
let header_refs: Vec<(&str, &str)> = headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
tracing::debug!("Connecting to {}", url);
let (mut sender, mut receiver) = ws_connect(&url_str, &header_refs).await.map_err(|err| {
tracing::error!("Error connecting: {err:?}");
map_ws_error(err)
})?;
tracing::debug!("Connected to {}", url);
for (name, index) in topics {
let (_, req) = if let Some(req) = client.get_sub_request(name, index) {
req
} else {
continue;
};
let _ = sender.send(req).await;
}
loop {
tokio::select! {
Some(msg) = ctrl.recv() => {
match msg {
StreamCtrl::Subscribe(msg) => {
let (_, req) = if let Some(req) = client.get_sub_request(msg.0, msg.1) {
req
} else {
continue;
};
let _ = sender.send(req).await;
}
StreamCtrl::Unsubscribe(msg) => {
let req = if let Some(req) = client.get_unsub_request(msg) {
req
} else {
continue;
};
let _ = sender.send(req).await;
}
StreamCtrl::Stop => {
if let Err(err) = sender.close().await {
tracing::error!("Closing error {err:?}");
}
break;
}
};
}
msg = receiver.recv() => {
let msg = match msg {
Some(Ok(msg)) => msg,
Some(Err(_)) => {
if let Err(err) = sender.close().await {
tracing::error!("Closing error {err:?}");
}
break;
}
None => break,
};
let msg = match serde_json::from_str::<WsMessageOrResponse<String>>(&msg) {
Ok(msg) => msg,
Err(_) => continue,
};
match msg {
WsMessageOrResponse::Notification(ref payload) => {
reply_to.send(payload.params.payload.clone());
}
WsMessageOrResponse::Response(response) => {
tracing::debug!("Received response from server: {:?}", response);
}
WsMessageOrResponse::ErrorResponse(error) => {
tracing::debug!("Received an error from server: {:?}", error);
return Err(PubsubError::InternalStr(error.error.message));
}
}
}
}
}
Ok(())
}
/// Converts a WebSocket client error into a pub/sub error.
///
/// `WsError::Connection` becomes `PubsubError::NotSupported`. Every other variant is
/// wrapped as `PubsubError::InternalStr` using its `Display` text.
fn map_ws_error(err: WsError) -> PubsubError {
    if matches!(err, WsError::Connection(_)) {
        PubsubError::NotSupported
    } else {
        PubsubError::InternalStr(err.to_string())
    }
}