use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use tokio::sync::mpsc;
use uuid::Uuid;
/// Server-side handle for a subscription request that has not been answered yet.
///
/// The handle is consumed by exactly one of [`accept`](Self::accept) or
/// [`reject`](Self::reject), so "answered twice" cannot be expressed.
#[derive(Debug)]
pub struct PendingSubscriptionSink {
    id: Uuid,
    method: String,
    // Stored directly rather than as an `Option`: both `accept` and `reject`
    // take `self` by value, so the sender is always present while the handle
    // exists.
    sender: mpsc::UnboundedSender<serde_json::Value>,
}

impl PendingSubscriptionSink {
    /// Creates a pending subscription.
    ///
    /// * `id` - identifier of the subscription.
    /// * `method` - name of the RPC method being subscribed to.
    /// * `sender` - channel into which values are pushed once the subscription
    ///   is accepted.
    pub fn new(id: Uuid, method: String, sender: mpsc::UnboundedSender<serde_json::Value>) -> Self {
        Self { id, method, sender }
    }

    /// Accepts the subscription.
    ///
    /// Returns a [`SubscriptionSink`] that pushes values into the channel given
    /// to [`new`](Self::new).
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` is kept for API compatibility.
    pub async fn accept(self) -> Result<SubscriptionSink> {
        log::trace!(
            "Subscription {} accepted for method {}",
            self.id,
            self.method
        );
        Ok(SubscriptionSink::new(self.id, self.sender, self.method))
    }

    /// Rejects the subscription and drops this handle's sender.
    ///
    /// # Errors
    ///
    /// Always returns `Err` whose message contains `reason`.
    pub async fn reject(self, reason: String) -> Result<()> {
        log::trace!(
            "Subscription {} rejected for method {}: {}",
            self.id,
            self.method,
            reason
        );
        // Dropping the sender lets the receiving side observe closure
        // (provided no other clone of the sender is alive).
        drop(self.sender);
        Err(Error::runtime_msg(format!(
            "Subscription rejected: {reason}"
        )))
    }

    /// Returns the subscription identifier.
    pub fn id(&self) -> Uuid {
        self.id
    }

    /// Returns the name of the subscribed RPC method.
    pub fn method(&self) -> &str {
        &self.method
    }
}
pub struct SubscriptionSink {
id: Uuid,
sender: mpsc::UnboundedSender<serde_json::Value>,
method: String,
}
impl SubscriptionSink {
pub(crate) fn new(
id: Uuid,
sender: mpsc::UnboundedSender<serde_json::Value>,
method: String,
) -> Self {
Self { id, sender, method }
}
pub async fn send(&self, value: serde_json::Value) -> Result<()> {
self.sender
.send(value)
.map_err(|_| Error::runtime_msg("Subscription channel closed"))?;
Ok(())
}
pub async fn send_value<T: Serialize>(&self, value: T) -> Result<()> {
let json_value = serde_json::to_value(value)
.map_err(|e| Error::runtime_msg(format!("Failed to serialize value: {e}")))?;
self.send(json_value).await
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
pub fn id(&self) -> Uuid {
self.id
}
pub fn method(&self) -> &str {
&self.method
}
}
pub struct RpcSubscription<T> {
id: Uuid,
receiver: mpsc::UnboundedReceiver<serde_json::Value>,
_phantom: PhantomData<T>,
}
impl<T> RpcSubscription<T>
where
T: for<'de> Deserialize<'de>,
{
pub fn new(id: Uuid, receiver: mpsc::UnboundedReceiver<serde_json::Value>) -> Self {
Self {
id,
receiver,
_phantom: PhantomData,
}
}
pub async fn next(&mut self) -> Option<Result<T>> {
match self.receiver.recv().await {
Some(json_value) => match serde_json::from_value(json_value) {
Ok(value) => Some(Ok(value)),
Err(e) => Some(Err(Error::runtime_msg(format!(
"Failed to deserialize subscription data: {e}"
)))),
},
None => None,
}
}
pub async fn cancel(self) -> Result<()> {
log::trace!("Subscription {} canceled", self.id);
drop(self.receiver);
Ok(())
}
pub fn id(&self) -> Uuid {
self.id
}
}
/// Messages exchanged about a subscription.
///
/// Every variant carries the subscription `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubscriptionMessage {
    /// Subscription request for `method` with JSON `params`.
    Request {
        id: Uuid,
        method: String,
        params: serde_json::Value,
    },
    /// Acceptance of subscription `id`.
    Accept { id: Uuid },
    /// Rejection of subscription `id` with a textual `reason`.
    Reject { id: Uuid, reason: String },
    /// A JSON `data` payload for subscription `id`.
    Data { id: Uuid, data: serde_json::Value },
    /// Cancellation of subscription `id`.
    Cancel { id: Uuid },
}

impl SubscriptionMessage {
    /// Returns the subscription id carried by this message, whatever its variant.
    pub fn id(&self) -> Uuid {
        // `Uuid` is `Copy`, so binding `id` by value out of `*self` copies it.
        match *self {
            Self::Request { id, .. }
            | Self::Accept { id }
            | Self::Reject { id, .. }
            | Self::Data { id, .. }
            | Self::Cancel { id } => id,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_pending_subscription_accept() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let pending = PendingSubscriptionSink::new(Uuid::new_v4(), "test_method".to_string(), tx);
        let sink = pending.accept().await.unwrap();
        assert_eq!(sink.method(), "test_method");
        sink.send_value("test_data").await.unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received, json!("test_data"));
    }

    #[tokio::test]
    async fn test_pending_subscription_reject() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let pending = PendingSubscriptionSink::new(Uuid::new_v4(), "test_method".to_string(), tx);
        let result = pending.reject("Invalid parameters".to_string()).await;
        assert!(result.is_err());
    }

    /// Sending after the subscriber dropped its receiver must report an error
    /// rather than silently losing the value.
    #[tokio::test]
    async fn test_sink_send_after_receiver_dropped() {
        let (tx, rx) = mpsc::unbounded_channel();
        let pending = PendingSubscriptionSink::new(Uuid::new_v4(), "test_method".to_string(), tx);
        let sink = pending.accept().await.unwrap();
        drop(rx);
        assert!(sink.is_closed());
        assert!(sink.send_value("late").await.is_err());
    }

    #[tokio::test]
    async fn test_rpc_subscription() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut subscription: RpcSubscription<String> = RpcSubscription::new(Uuid::new_v4(), rx);
        tx.send(json!("test_message")).unwrap();
        let received = subscription.next().await.unwrap().unwrap();
        assert_eq!(received, "test_message");
    }

    /// A payload of the wrong shape yields `Some(Err(_))`, and the
    /// subscription keeps delivering later items.
    #[tokio::test]
    async fn test_rpc_subscription_deserialize_error() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut subscription: RpcSubscription<String> = RpcSubscription::new(Uuid::new_v4(), rx);
        tx.send(json!(42)).unwrap();
        assert!(matches!(subscription.next().await, Some(Err(_))));
        tx.send(json!("ok")).unwrap();
        assert_eq!(subscription.next().await.unwrap().unwrap(), "ok");
    }

    #[tokio::test]
    async fn test_subscription_closed() {
        let (tx, rx) = mpsc::unbounded_channel();
        let mut subscription: RpcSubscription<String> = RpcSubscription::new(Uuid::new_v4(), rx);
        drop(tx);
        let result = subscription.next().await;
        assert!(result.is_none());
    }

    /// Cancelling drops the receiver, which the sending side can observe.
    #[tokio::test]
    async fn test_rpc_subscription_cancel_closes_channel() {
        let (tx, rx) = mpsc::unbounded_channel::<serde_json::Value>();
        let subscription: RpcSubscription<String> = RpcSubscription::new(Uuid::new_v4(), rx);
        subscription.cancel().await.unwrap();
        assert!(tx.is_closed());
    }

    #[test]
    fn test_subscription_message_id() {
        let id = Uuid::new_v4();
        let messages = [
            SubscriptionMessage::Request {
                id,
                method: "m".to_string(),
                params: json!(null),
            },
            SubscriptionMessage::Accept { id },
            SubscriptionMessage::Reject {
                id,
                reason: "r".to_string(),
            },
            SubscriptionMessage::Data { id, data: json!(1) },
            SubscriptionMessage::Cancel { id },
        ];
        assert!(messages.iter().all(|m| m.id() == id));
    }
}