use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use mpp::server::VerificationError;
use tokio::sync::{Notify, RwLock};
use super::solana_types::SolanaChannelState;
/// One-shot state transition handed to [`SolanaChannelStore::update_channel`].
///
/// The closure receives the channel's current state (`None` when nothing is stored under the
/// id) and returns the state that should replace it. In `InMemorySolanaChannelStore`,
/// `Ok(Some(_))` stores the new state, `Ok(None)` removes the channel, and `Err(_)` aborts
/// the update without touching the store.
///
/// It is `Send` because it is moved into the `Send` future returned by `update_channel`.
pub type ChannelUpdater = Box<
    dyn Send
        + FnOnce(
            Option<SolanaChannelState>,
        ) -> Result<Option<SolanaChannelState>, VerificationError>,
>;
/// Async storage backend for Solana payment-channel state.
///
/// Methods return boxed `Send` futures instead of being `async fn`s so the trait stays
/// object-safe and can be used as `&dyn SolanaChannelStore`.
pub trait SolanaChannelStore: Send + Sync {
    /// Fetches the stored state for `channel_id`; `Ok(None)` means nothing is stored.
    fn get_channel(
        &self,
        channel_id: &str,
    ) -> Pin<
        Box<dyn Future<Output = Result<Option<SolanaChannelState>, VerificationError>> + Send + '_>,
    >;

    /// Runs `updater` against the current state of `channel_id`, persists what it returns,
    /// and yields that new state.
    ///
    /// NOTE(review): callers such as `deduct_from_channel` rely on the read-modify-write
    /// being atomic per store — confirm this for every non-in-memory implementation.
    fn update_channel(
        &self,
        channel_id: &str,
        updater: ChannelUpdater,
    ) -> Pin<
        Box<dyn Future<Output = Result<Option<SolanaChannelState>, VerificationError>> + Send + '_>,
    >;

    /// Resolves when `channel_id` may have been updated.
    ///
    /// The default implementation is ready immediately, so for stores that do not override
    /// it, callers that loop on this method effectively poll instead of blocking.
    fn wait_for_update(&self, _channel_id: &str) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(std::future::ready(()))
    }
}
/// Process-local [`SolanaChannelStore`] backed by `HashMap`s behind tokio `RwLock`s.
///
/// Nothing is persisted: all state lives only as long as this value.
pub struct InMemorySolanaChannelStore {
    /// Latest state of each stored channel, keyed by channel id.
    channels: RwLock<HashMap<String, SolanaChannelState>>,
    /// Wakeup handles for tasks blocked in `wait_for_update`, keyed by channel id.
    notifiers: RwLock<HashMap<String, Arc<Notify>>>,
}
impl Default for InMemorySolanaChannelStore {
    /// Creates a store with no channels and no registered waiters.
    fn default() -> Self {
        let channels = RwLock::new(HashMap::default());
        let notifiers = RwLock::new(HashMap::default());
        Self {
            channels,
            notifiers,
        }
    }
}
impl InMemorySolanaChannelStore {
    /// Creates an empty store; identical to the `Default` implementation.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl SolanaChannelStore for InMemorySolanaChannelStore {
    /// Returns a clone of the channel's current state, or `Ok(None)` if it is not stored.
    fn get_channel(
        &self,
        channel_id: &str,
    ) -> Pin<
        Box<dyn Future<Output = Result<Option<SolanaChannelState>, VerificationError>> + Send + '_>,
    > {
        let channel_id = channel_id.to_string();
        Box::pin(async move {
            let channels = self.channels.read().await;
            Ok(channels.get(&channel_id).cloned())
        })
    }

    /// Applies `updater` to the channel's current state while holding the channels write
    /// lock, so the read-modify-write is atomic with respect to other calls on this store.
    ///
    /// `Ok(Some(state))` from the updater stores `state`; `Ok(None)` removes the channel.
    /// If the updater returns `Err`, that error is returned, the store is left unchanged and
    /// no waiters are woken. On success, every task currently blocked in `wait_for_update`
    /// for this channel is woken.
    fn update_channel(
        &self,
        channel_id: &str,
        updater: ChannelUpdater,
    ) -> Pin<
        Box<dyn Future<Output = Result<Option<SolanaChannelState>, VerificationError>> + Send + '_>,
    > {
        let channel_id = channel_id.to_string();
        Box::pin(async move {
            let next = {
                let mut channels = self.channels.write().await;
                let next = updater(channels.get(&channel_id).cloned())?;
                match &next {
                    Some(state) => {
                        channels.insert(channel_id.clone(), state.clone());
                    }
                    None => {
                        channels.remove(&channel_id);
                    }
                }
                next
            };

            // Take the notifier out of the map under the write lock, then fire it.
            // `wait_for_update` registers its `Notified` future while holding this same
            // lock, so every waiter holding this `Notify` is already registered and will be
            // woken. Removing the entry keeps the map bounded to channels that currently
            // have waiters; the next `wait_for_update` installs a fresh `Notify`.
            let notify = self.notifiers.write().await.remove(&channel_id);
            if let Some(notify) = notify {
                notify.notify_waiters();
            }
            Ok(next)
        })
    }

    /// Resolves after the next successful `update_channel` for `channel_id` that starts
    /// after this future has registered itself.
    ///
    /// An update that completed before this call is not replayed, so callers should
    /// re-read the channel after waking rather than assume what changed.
    fn wait_for_update(&self, channel_id: &str) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        let channel_id = channel_id.to_string();
        Box::pin(async move {
            let mut notifiers = self.notifiers.write().await;
            let notify = Arc::clone(
                notifiers
                    .entry(channel_id)
                    .or_insert_with(|| Arc::new(Notify::new())),
            );
            // Create the `Notified` future BEFORE releasing the lock. tokio guarantees a
            // `Notified` receives `notify_waiters` wakeups from the moment it is created, so
            // an `update_channel` racing with this call cannot fire in the gap between
            // registration and waiting.
            let notified = notify.notified();
            drop(notifiers);
            notified.await;
        })
    }
}
/// Moves `amount` from a channel's authorized-but-unsettled balance into its settled amount
/// and returns the updated channel state.
///
/// The balance check and the write happen inside one `update_channel` call, so the check is
/// exactly as atomic as the store's `update_channel`. The in-memory store holds its write
/// lock for the whole updater.
///
/// # Errors
/// - channel-not-found if nothing is stored under `channel_id`;
/// - invalid-payload if `last_authorized_amount` or `settled_amount` does not parse as `u128`;
/// - insufficient-balance if `amount` exceeds `last_authorized_amount - settled_amount`
///   (the available balance is clamped to 0 when settled exceeds authorized);
/// - any error produced by the store itself.
pub async fn deduct_from_channel(
    store: &dyn SolanaChannelStore,
    channel_id: &str,
    amount: u128,
) -> Result<SolanaChannelState, VerificationError> {
    let apply_deduction = move |existing: Option<SolanaChannelState>| -> Result<
        Option<SolanaChannelState>,
        VerificationError,
    > {
        let Some(channel) = existing else {
            return Err(VerificationError::channel_not_found("channel not found"));
        };
        let authorized = channel
            .last_authorized_amount
            .parse::<u128>()
            .map_err(|_| VerificationError::invalid_payload("invalid last_authorized_amount"))?;
        let settled = channel
            .settled_amount
            .parse::<u128>()
            .map_err(|_| VerificationError::invalid_payload("invalid settled_amount"))?;

        // Clamp at zero so inconsistent state (settled > authorized) reads as "nothing left".
        let available = authorized.saturating_sub(settled);
        if amount > available {
            return Err(VerificationError::insufficient_balance(format!(
                "requested {amount}, available {available}",
            )));
        }

        // Cannot overflow: amount <= authorized - settled whenever amount is non-zero.
        let new_settled = settled + amount;
        Ok(Some(SolanaChannelState {
            settled_amount: new_settled.to_string(),
            ..channel
        }))
    };

    match store.update_channel(channel_id, Box::new(apply_deduction)).await? {
        Some(updated) => Ok(updated),
        None => Err(VerificationError::channel_not_found("channel not found")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mpp::solana_types::{AuthorizationMode, ChannelStatus};

    /// Fixture channel: authorized 500, settled 100, so 400 is available to deduct.
    fn test_channel(id: &str) -> SolanaChannelState {
        SolanaChannelState {
            channel_id: id.into(),
            payer: "alice".into(),
            recipient: "bob".into(),
            server_nonce: "nonce".into(),
            channel_program: "prog".into(),
            chain_id: "solana:mainnet-beta".into(),
            authorization_mode: AuthorizationMode::SwigSession,
            authority_wallet: "alice".into(),
            delegated_session_key: None,
            escrowed_amount: "1000".into(),
            last_authorized_amount: "500".into(),
            last_sequence: 1,
            settled_amount: "100".into(),
            status: ChannelStatus::Open,
            expires_at_unix: None,
            created_at: "2026-01-01T00:00:00Z".into(),
        }
    }

    /// Builds a store that already contains `test_channel(id)`.
    async fn seeded_store(id: &str) -> InMemorySolanaChannelStore {
        let store = InMemorySolanaChannelStore::new();
        let channel = test_channel(id);
        store
            .update_channel(id, Box::new(move |_| Ok(Some(channel))))
            .await
            .expect("no error");
        store
    }

    #[tokio::test]
    async fn get_missing_channel_returns_none() {
        let store = InMemorySolanaChannelStore::new();
        let result = store.get_channel("missing").await.expect("no error");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn update_creates_new_channel() {
        let store = InMemorySolanaChannelStore::new();
        let channel = test_channel("ch1");
        let ch = channel.clone();
        let result = store
            .update_channel("ch1", Box::new(move |_| Ok(Some(ch))))
            .await
            .expect("no error");
        assert!(result.is_some());
        let fetched = store.get_channel("ch1").await.expect("no error");
        assert_eq!(fetched.expect("exists").channel_id, channel.channel_id);
    }

    #[tokio::test]
    async fn deduct_succeeds_with_available_balance() {
        let store = seeded_store("ch1").await;
        let updated = deduct_from_channel(&store, "ch1", 200).await.expect("ok");
        assert_eq!(updated.settled_amount, "300");
    }

    #[tokio::test]
    async fn deduct_succeeds_for_exact_available_balance() {
        let store = seeded_store("ch1").await;
        let updated = deduct_from_channel(&store, "ch1", 400).await.expect("ok");
        assert_eq!(updated.settled_amount, "500");
    }

    #[tokio::test]
    async fn deduct_fails_with_insufficient_balance() {
        let store = seeded_store("ch1").await;
        let result = deduct_from_channel(&store, "ch1", 500).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn failed_deduct_leaves_state_unchanged() {
        let store = seeded_store("ch1").await;
        assert!(deduct_from_channel(&store, "ch1", 500).await.is_err());
        let state = store
            .get_channel("ch1")
            .await
            .expect("no error")
            .expect("exists");
        assert_eq!(state.settled_amount, "100");
    }

    #[tokio::test]
    async fn deduct_fails_for_missing_channel() {
        let store = InMemorySolanaChannelStore::new();
        assert!(deduct_from_channel(&store, "missing", 1).await.is_err());
    }
}