use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use time::format_description::well_known::Iso8601;
use time::OffsetDateTime;
use super::ws::{WsMessage, WsResponse};
use crate::protocol::core::parse_authorization;
use crate::protocol::methods::tempo::session_method::deduct_from_channel;
use crate::protocol::methods::tempo::session_receipt::SessionReceipt;
use crate::protocol::traits::{ChargeMethod, SessionMethod};
/// Configuration for a metered WebSocket session driven by [`ws_session`].
///
/// `G` is the payload generator; `ws_session` requires it to be a
/// `Stream<Item = String> + Send + Unpin + 'static`.
pub struct WsSessionOptions<G> {
    /// Channel store the per-item charge is deducted from; its channel state is
    /// also read to build voucher prompts and the final session receipt.
    pub store: Arc<dyn crate::protocol::methods::tempo::session_method::ChannelStore>,
    /// Identifier of the payment channel to charge.
    pub channel_id: String,
    /// Challenge identifier recorded in the final `SessionReceipt`.
    pub challenge_id: String,
    /// Amount deducted from the channel before each generated item is sent.
    /// NOTE(review): presumably in the same unit as the channel's `spent` /
    /// `deposit` amounts — confirm.
    pub tick_cost: u128,
    /// Payload source; each item is sent as a `WsResponse::Data` frame once paid for.
    pub generate: G,
    /// Maximum time, in milliseconds, to wait for a channel update before
    /// retrying a failed deduction.
    pub poll_interval_ms: u64,
}
/// Streams metered data over a WebSocket, charging `tick_cost` per item.
///
/// For every item produced by `options.generate`, this first tries to deduct `tick_cost`
/// from the payment channel `options.channel_id` via [`deduct_from_channel`]. While that
/// deduction fails, it sends a `WsResponse::NeedVoucher` frame describing the channel
/// state (when the channel can be read from the store). It then waits for either a store
/// update on that channel or `poll_interval_ms` milliseconds, whichever comes first, and
/// retries. Once the tick is paid, the item is sent as a `WsResponse::Data` frame.
///
/// When the generator is exhausted, or a data frame fails to send, a final
/// `WsResponse::Receipt` built from the channel's current state is sent; failure to send
/// it is ignored. If a `NeedVoucher` frame fails to send, the function returns immediately
/// without attempting a receipt.
///
/// NOTE(review): every `deduct_from_channel` error is treated as "insufficient funds", so a
/// persistent non-balance error keeps this waiting indefinitely — confirm that is intended.
pub async fn ws_session<G, S>(sender: &mut S, options: WsSessionOptions<G>)
where
    G: futures_core::Stream<Item = String> + Send + Unpin + 'static,
    S: futures_util::Sink<String, Error = Box<dyn std::error::Error + Send + Sync>> + Send + Unpin,
{
    let WsSessionOptions {
        store,
        channel_id,
        challenge_id,
        tick_cost,
        generate,
        poll_interval_ms,
    } = options;
    // The polling fallback never changes, so build the Duration once.
    let poll_interval = tokio::time::Duration::from_millis(poll_interval_ms);
    let mut stream = std::pin::pin!(generate);

    while let Some(value) = stream.next().await {
        // Block until this tick is paid for, prompting the client for a voucher meanwhile.
        // `is_err()` drops the deduction result before any of the awaits below.
        while deduct_from_channel(&*store, &channel_id, tick_cost).await.is_err() {
            if let Ok(Some(ch)) = store.get_channel(&channel_id).await {
                let msg = WsResponse::NeedVoucher {
                    channel_id: channel_id.clone(),
                    // Saturate instead of `+`: plain addition panics in debug builds and
                    // silently wraps in release builds on overflow.
                    required_cumulative: ch.spent.saturating_add(tick_cost).to_string(),
                    accepted_cumulative: ch.highest_voucher_amount.to_string(),
                    deposit: ch.deposit.to_string(),
                };
                if sender.send(msg.to_text()).await.is_err() {
                    // The client can no longer be reached; stop without a receipt.
                    return;
                }
            }
            // Retry on whichever comes first: a store change or the poll timeout.
            tokio::select! {
                _ = store.wait_for_update(&channel_id) => {},
                _ = tokio::time::sleep(poll_interval) => {},
            }
        }

        let msg = WsResponse::Data { data: value };
        if sender.send(msg.to_text()).await.is_err() {
            break;
        }
    }

    // Best-effort final receipt reflecting the channel's state at session end.
    if let Ok(Some(ch)) = store.get_channel(&channel_id).await {
        let timestamp = OffsetDateTime::now_utc()
            .format(&Iso8601::DEFAULT)
            .expect("ISO 8601 formatting cannot fail");
        let mut receipt = SessionReceipt::new(
            timestamp,
            &challenge_id,
            &channel_id,
            ch.highest_voucher_amount.to_string(),
            ch.spent.to_string(),
        );
        receipt.units = Some(ch.units);
        let msg = WsResponse::Receipt {
            receipt: serde_json::to_value(&receipt)
                .unwrap_or_else(|_| serde_json::json!({"error": "serialization failed"})),
        };
        let _ = sender.send(msg.to_text()).await;
    }
}
/// Reads client frames and submits every credential message for session verification.
///
/// Each text frame is decoded as a [`WsMessage`]; only `WsMessage::Credential` frames are
/// acted on. Their credential string is parsed with [`parse_authorization`] and handed to
/// `mpp.verify_session`. Frames that fail to decode or parse are skipped, and the
/// verification outcome is discarded. Processing stops when the stream ends or yields
/// its first error.
pub async fn process_incoming_vouchers<M, S, R>(receiver: &mut R, mpp: &crate::server::Mpp<M, S>)
where
    M: ChargeMethod,
    S: SessionMethod,
    R: futures_util::Stream<Item = Result<String, Box<dyn std::error::Error + Send + Sync>>>
        + Send
        + Unpin,
{
    while let Some(frame) = receiver.next().await {
        // A transport error ends processing, exactly like the stream closing.
        let Ok(text) = frame else {
            break;
        };
        let credential = match serde_json::from_str(&text) {
            Ok(WsMessage::Credential { credential }) => credential,
            // Undecodable frames and non-credential messages are ignored.
            _ => continue,
        };
        if let Ok(parsed) = parse_authorization(&credential) {
            // The verification result is not inspected here.
            let _ = mpp.verify_session(&parsed).await;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::methods::tempo::session_method::InMemoryChannelStore;

    /// A concrete store coerces into the trait-object field, a plain stream is accepted
    /// as the generator, and every configured value is kept exactly as given.
    #[test]
    fn test_ws_session_options_fields() {
        let store = Arc::new(InMemoryChannelStore::new());
        let opts = WsSessionOptions {
            store,
            channel_id: "0xabc".to_string(),
            challenge_id: "ch-1".to_string(),
            tick_cost: 1000,
            generate: futures_util::stream::empty::<String>(),
            poll_interval_ms: 100,
        };
        assert_eq!(opts.channel_id, "0xabc");
        assert_eq!(opts.challenge_id, "ch-1");
        assert_eq!(opts.tick_cost, 1000);
        assert_eq!(opts.poll_interval_ms, 100);
    }
}