use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use cdk::mint::QuoteId;
use cdk::nuts::nut17::NotificationPayload;
use cdk::subscription::SubId;
use cdk::ws::{
notification_to_ws_message, NotificationInner, WsErrorBody, WsMessageOrResponse,
WsMethodRequest, WsRequest,
};
use futures::StreamExt;
use tokio::sync::mpsc;
use crate::MintState;
mod error;
mod subscribe;
mod unsubscribe;
/// Maximum number of subscriptions a single WebSocket connection may hold at
/// once; a `subscribe` request beyond this cap is rejected without allocating
/// a pub/sub subscriber (see `tests::test_per_connection_subscription_count_limit`).
pub(crate) const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100;
/// Upper bound on the number of filters accepted in one subscription request.
///
/// This limit is intentionally independent of the mint's `max_inputs`
/// (see `tests::test_subscription_filter_count_not_tied_to_max_inputs`).
/// NOTE(review): the check itself is not in this file — presumably enforced in
/// `subscribe::handle`; confirm.
pub(crate) const MAX_FILTERS_PER_SUBSCRIPTION: usize = 1000;
/// Dispatches a single parsed client request to its handler and builds the reply.
///
/// `Subscribe` requests go to `subscribe::handle` and `Unsubscribe` requests to
/// `unsubscribe::handle`; either handler may mutate `context`. A handler error
/// is not propagated: it is converted into a [`WsErrorBody`] and embedded in
/// the reply, which is tagged with the request's `id`.
///
/// # Errors
///
/// Fails only when the reply cannot be converted into a JSON value.
async fn process(
    context: &mut WsContext,
    body: WsRequest,
) -> Result<serde_json::Value, serde_json::Error> {
    let outcome = match body.method {
        WsMethodRequest::Subscribe(params) => subscribe::handle(context, params).await,
        WsMethodRequest::Unsubscribe(request) => unsubscribe::handle(context, request).await,
    };
    let reply: WsMessageOrResponse = (body.id, outcome.map_err(WsErrorBody::from)).into();
    serde_json::to_value(reply)
}
pub use error::WsError;
/// Per-connection state for a single WebSocket client.
///
/// Created by `main_websocket` for the lifetime of one socket and passed
/// mutably to the subscribe/unsubscribe handlers through `process`.
pub struct WsContext {
    /// Mint state for this connection (presumably consumed by the request
    /// handlers — not read in this file).
    state: MintState,
    /// Subscriptions currently open on this connection, keyed by subscription
    /// id. Each value is the handle of a background task (presumably spawned by
    /// `subscribe::handle`); the `Drop` impl aborts all of them. The connection
    /// loop also uses this map to discard notifications for removed ids.
    subscriptions: HashMap<Arc<SubId>, tokio::task::JoinHandle<()>>,
    /// Sending half of the channel whose receiver `main_websocket` drains and
    /// forwards to the client. NOTE(review): presumably cloned into each
    /// subscription task — confirm in `subscribe`.
    publisher: mpsc::Sender<(Arc<SubId>, NotificationPayload<QuoteId>)>,
}
impl Drop for WsContext {
    /// Aborts every subscription task still registered on this connection.
    ///
    /// Dropping a tokio `JoinHandle` on its own only detaches the task, so each
    /// handle is aborted explicitly; draining also leaves the map empty.
    fn drop(&mut self) {
        self.subscriptions
            .drain()
            .for_each(|(_sub_id, task)| task.abort());
    }
}
/// Drives one client WebSocket connection until it closes.
///
/// Two event sources are multiplexed with `tokio::select!`:
///
/// * notifications queued by this connection's subscription tasks on the
///   `mpsc` channel whose sender is stored in [`WsContext`]; each one is
///   serialized and forwarded to the client unless its subscription has
///   already been removed, and
/// * frames from the client; text (and UTF-8-lossy binary) frames are parsed
///   as [`WsRequest`]s, dispatched through `process`, and answered with the
///   serialized response.
///
/// The loop exits when the client sends a Close frame, the socket yields an
/// error or ends, a send to the client fails, or a response cannot be
/// serialized. Returning drops `context`, whose `Drop` impl aborts every
/// subscription task still registered.
pub async fn main_websocket(mut socket: WebSocket, state: MintState) {
    // Bounded (capacity 100) channel from the subscription tasks to this loop.
    let (publisher, mut subscriber) = mpsc::channel(100);
    let mut context = WsContext {
        state,
        subscriptions: HashMap::new(),
        publisher,
    };

    loop {
        tokio::select! {
            // `context` keeps a sender alive, so `recv` never yields `None` here.
            Some((sub_id, payload)) = subscriber.recv() => {
                // Skip notifications for subscriptions removed after the
                // notification was queued (e.g. by a racing unsubscribe).
                if !context.subscriptions.contains_key(&sub_id) {
                    continue;
                }
                let notification = notification_to_ws_message(NotificationInner {
                    sub_id,
                    payload,
                });
                let message = match serde_json::to_string(&notification) {
                    Ok(message) => message,
                    Err(err) => {
                        tracing::error!("Could not serialize notification: {}", err);
                        continue;
                    }
                };
                if let Err(err) = socket.send(Message::Text(message.into())).await {
                    tracing::error!("Could not send websocket message: {}", err);
                    break;
                }
            }
            // Bind the whole `Option`: a `Some(..)` pattern here would only
            // disable this branch at end-of-stream, leaving the loop parked on
            // `subscriber.recv()` (which never closes) for a client that is gone
            // and keeping its subscription tasks alive indefinitely.
            from_ws = socket.next() => {
                let text = match from_ws {
                    Some(Ok(Message::Text(text))) => text.to_string(),
                    Some(Ok(Message::Binary(bin))) => String::from_utf8_lossy(&bin).to_string(),
                    Some(Ok(Message::Ping(payload))) => {
                        if let Err(e) = socket.send(Message::Pong(payload)).await {
                            tracing::error!("failed to send pong: {e}");
                            break;
                        }
                        continue;
                    }
                    // RFC 6455 permits unsolicited pongs as a one-way heartbeat,
                    // so receiving one is not an error.
                    Some(Ok(Message::Pong(_payload))) => {
                        tracing::debug!("Ignoring unsolicited pong");
                        continue;
                    }
                    Some(Ok(Message::Close(frame))) => {
                        if let Some(CloseFrame { code, reason }) = frame {
                            tracing::info!("ws-close: code={code:?} reason='{reason}'");
                        } else {
                            tracing::info!("ws-close: no frame");
                        }
                        // Best-effort reply; the peer may already be gone.
                        let _ = socket
                            .send(Message::Close(Some(CloseFrame {
                                code: axum::extract::ws::close_code::NORMAL,
                                reason: "bye!".into(),
                            })))
                            .await;
                        break;
                    }
                    Some(Err(err)) => {
                        tracing::error!("ws-error: {err}");
                        break;
                    }
                    None => {
                        tracing::info!("ws-close: stream ended");
                        break;
                    }
                };
                let request = match serde_json::from_str::<WsRequest>(&text) {
                    Ok(request) => request,
                    Err(err) => {
                        tracing::error!("Could not parse request: {}", err);
                        continue;
                    }
                };
                match process(&mut context, request).await {
                    Ok(result) => {
                        if let Err(err) = socket
                            .send(Message::Text(result.to_string().into()))
                            .await
                        {
                            tracing::error!("Could not send response: {}", err);
                            break;
                        }
                    }
                    Err(err) => {
                        tracing::error!("Error serializing response: {}", err);
                        break;
                    }
                }
            }
        }
    }
}
/// Unit tests for per-connection subscription bookkeeping: cleanup on
/// unsubscribe and on context drop, the per-connection subscription cap, and
/// independence of the filter-count limit from the mint's `max_inputs`.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;
    use cdk::mint::{Mint, QuoteId};
    use cdk::nuts::nut02::KeySetVersion;
    use cdk::nuts::{CurrencyUnit, MintInfo};
    use cdk::subscription::{Params, SubId};
    use cdk::ws::WsUnsubscribeRequest;
    use cdk_signatory::db_signatory::DbSignatory;
    use cdk_signatory::signatory::{RotateKeyArguments, Signatory};
    use cdk_sqlite::mint::memory;
    use super::*;
    use crate::cache::HttpCache;

    /// Builds an in-memory `Mint` with a single `Sat` keyset; `max_inputs` and
    /// `max_outputs` are passed straight through to `Mint::new`.
    async fn create_test_mint_with_limits(max_inputs: usize, max_outputs: usize) -> Arc<Mint> {
        let localstore = Arc::new(memory::empty().await.expect("in-memory db"));
        // Fixed all-zero seed: acceptable for tests only.
        let seed = [0u8; 32];
        // unit -> (input fee in ppk, denominations 1, 2, 4, ..., 128)
        let mut supported_units = HashMap::new();
        let amounts: Vec<u64> = (0..8).map(|i| 2u64.pow(i)).collect();
        supported_units.insert(CurrencyUnit::Sat, (0u64, amounts));
        let signatory = Arc::new(
            DbSignatory::new(
                localstore.clone(),
                &seed,
                supported_units.clone(),
                HashMap::new(),
            )
            .await
            .expect("signatory"),
        );
        // Rotate a keyset for every supported unit using its configured fee and amounts.
        for (unit, (fee, amounts)) in &supported_units {
            signatory
                .rotate_keyset(RotateKeyArguments {
                    unit: unit.clone(),
                    amounts: amounts.clone(),
                    input_fee_ppk: *fee,
                    keyset_id_type: KeySetVersion::Version00,
                    final_expiry: None,
                })
                .await
                .expect("rotate keyset");
        }
        Arc::new(
            Mint::new(
                MintInfo::default(),
                signatory,
                localstore,
                HashMap::new(),
                max_inputs,
                max_outputs,
            )
            .await
            .expect("mint"),
        )
    }

    /// Test mint with generous limits (1000 inputs / 1000 outputs).
    async fn create_test_mint() -> Arc<Mint> {
        create_test_mint_with_limits(1000, 1000).await
    }

    /// Bolt11 mint-quote subscription `Params` with id `sub_id` and a single
    /// filter built from a fresh `QuoteId`.
    fn make_params(sub_id: &str) -> Params {
        Params {
            kind: cdk::nuts::nut17::Kind::Bolt11MintQuote,
            filters: vec![QuoteId::new().to_string()],
            id: Arc::new(SubId::from(sub_id)),
        }
    }

    /// Builds a `WsContext` around `mint` with an empty HTTP cache.
    ///
    /// `_receiver` is a named binding, so it lives until this function returns
    /// and is then dropped: the returned context's channel has no receiver.
    /// The tests below only inspect subscriber counts, never notifications.
    fn make_context(mint: Arc<Mint>) -> WsContext {
        let state = MintState {
            mint,
            cache: Arc::new(HttpCache::default()),
        };
        let (publisher, _receiver) = tokio::sync::mpsc::channel(100);
        WsContext {
            state,
            subscriptions: HashMap::new(),
            publisher,
        }
    }

    /// An explicit unsubscribe must release the pub/sub subscriber created by
    /// the matching subscribe.
    #[tokio::test]
    async fn test_unsubscribe_cleans_up_active_subscription() {
        let mint = create_test_mint().await;
        let pubsub = mint.pubsub_manager();
        let mut context = make_context(mint);
        subscribe::handle(&mut context, make_params("sub-1"))
            .await
            .expect("subscribe");
        // Yield so other tasks get scheduled before counting (presumably the
        // spawned subscription task) — NOTE(review): confirm this is sufficient.
        tokio::task::yield_now().await;
        assert_eq!(
            pubsub.active_subscribers(),
            1,
            "should have 1 active subscriber after subscribe"
        );
        unsubscribe::handle(
            &mut context,
            WsUnsubscribeRequest {
                sub_id: Arc::new(SubId::from("sub-1")),
            },
        )
        .await
        .expect("unsubscribe");
        // Yield again so the aborted task can be torn down before counting.
        tokio::task::yield_now().await;
        assert_eq!(
            pubsub.active_subscribers(),
            0,
            "active_subscribers should be 0 after explicit unsubscribe"
        );
    }

    /// Dropping the context (as happens when the connection loop exits) must
    /// release every subscriber it still holds, via `WsContext`'s `Drop` impl.
    #[tokio::test]
    async fn test_context_drop_cleans_up_active_subscriptions() {
        let mint = create_test_mint().await;
        let pubsub = mint.pubsub_manager();
        let mut context = make_context(mint);
        subscribe::handle(&mut context, make_params("sub-A"))
            .await
            .expect("subscribe A");
        subscribe::handle(&mut context, make_params("sub-B"))
            .await
            .expect("subscribe B");
        tokio::task::yield_now().await;
        assert_eq!(
            pubsub.active_subscribers(),
            2,
            "should have 2 active subscribers"
        );
        drop(context);
        tokio::task::yield_now().await;
        assert_eq!(
            pubsub.active_subscribers(),
            0,
            "active_subscribers should be 0 after context drop (disconnect)"
        );
    }

    /// Exactly `MAX_SUBSCRIPTIONS_PER_CONNECTION` subscriptions succeed; the
    /// next one is rejected and must not allocate a pub/sub subscriber.
    #[tokio::test]
    async fn test_per_connection_subscription_count_limit() {
        let mint = create_test_mint().await;
        let pubsub = mint.pubsub_manager();
        let mut context = make_context(mint);
        for i in 0..MAX_SUBSCRIPTIONS_PER_CONNECTION {
            subscribe::handle(&mut context, make_params(&format!("sub-cap-{i}")))
                .await
                .expect("subscribe before cap should succeed");
        }
        tokio::task::yield_now().await;
        assert_eq!(
            pubsub.active_subscribers(),
            MAX_SUBSCRIPTIONS_PER_CONNECTION,
            "should have subscribers up to the per-connection cap"
        );
        let over_cap = subscribe::handle(
            &mut context,
            make_params(&format!("sub-cap-{MAX_SUBSCRIPTIONS_PER_CONNECTION}")),
        )
        .await;
        assert!(
            over_cap.is_err(),
            "subscription over the per-connection cap should be rejected"
        );
        assert_eq!(
            pubsub.active_subscribers(),
            MAX_SUBSCRIPTIONS_PER_CONNECTION,
            "rejected subscription should not allocate a pub/sub subscriber"
        );
    }

    /// With a mint limited to 2 inputs/outputs, a subscription carrying 5
    /// filters must still be accepted: the filter limit is not `max_inputs`.
    #[tokio::test]
    async fn test_subscription_filter_count_not_tied_to_max_inputs() {
        let mint = create_test_mint_with_limits(2, 2).await;
        let mut context = make_context(mint);
        let params = Params {
            kind: cdk::nuts::nut17::Kind::Bolt11MintQuote,
            filters: (0..5).map(|_| QuoteId::new().to_string()).collect(),
            id: Arc::new(SubId::from("sub-many-filters")),
        };
        let result = subscribe::handle(&mut context, params).await;
        assert!(
            result.is_ok(),
            "subscription filter count must not be capped by mint max_inputs; got {:?}",
            result.as_ref().err()
        );
    }
}