1use crate::{Error, Result};
7use serde::{Deserialize, Serialize};
8use std::marker::PhantomData;
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12pub struct PendingSubscriptionSink {
17 id: Uuid,
18 method: String,
19 sender: Option<mpsc::UnboundedSender<serde_json::Value>>,
20}
21
22impl PendingSubscriptionSink {
23 pub fn new(id: Uuid, method: String, sender: mpsc::UnboundedSender<serde_json::Value>) -> Self {
25 Self {
26 id,
27 method,
28 sender: Some(sender),
29 }
30 }
31
32 pub async fn accept(mut self) -> Result<SubscriptionSink> {
37 let sender = self
38 .sender
39 .take()
40 .ok_or_else(|| Error::runtime_msg("Subscription already accepted or rejected"))?;
41
42 log::trace!(
44 "Subscription {} accepted for method {}",
45 self.id,
46 self.method
47 );
48
49 Ok(SubscriptionSink::new(self.id, sender, self.method))
50 }
51
52 pub async fn reject(self, reason: String) -> Result<()> {
57 log::trace!(
59 "Subscription {} rejected for method {}: {}",
60 self.id,
61 self.method,
62 reason
63 );
64
65 drop(self.sender);
67
68 Err(Error::runtime_msg(format!(
69 "Subscription rejected: {reason}"
70 )))
71 }
72
73 pub fn id(&self) -> Uuid {
75 self.id
76 }
77
78 pub fn method(&self) -> &str {
80 &self.method
81 }
82}
83
84pub struct SubscriptionSink {
89 id: Uuid,
90 sender: mpsc::UnboundedSender<serde_json::Value>,
91 method: String,
92}
93
94impl SubscriptionSink {
95 pub(crate) fn new(
97 id: Uuid,
98 sender: mpsc::UnboundedSender<serde_json::Value>,
99 method: String,
100 ) -> Self {
101 Self { id, sender, method }
102 }
103
104 pub async fn send(&self, value: serde_json::Value) -> Result<()> {
108 self.sender
109 .send(value)
110 .map_err(|_| Error::runtime_msg("Subscription channel closed"))?;
111 Ok(())
112 }
113
114 pub async fn send_value<T: Serialize>(&self, value: T) -> Result<()> {
119 let json_value = serde_json::to_value(value)
120 .map_err(|e| Error::runtime_msg(format!("Failed to serialize value: {e}")))?;
121 self.send(json_value).await
122 }
123
124 pub fn is_closed(&self) -> bool {
128 self.sender.is_closed()
129 }
130
131 pub fn id(&self) -> Uuid {
133 self.id
134 }
135
136 pub fn method(&self) -> &str {
138 &self.method
139 }
140}
141
142pub struct RpcSubscription<T> {
147 id: Uuid,
148 receiver: mpsc::UnboundedReceiver<serde_json::Value>,
149 _phantom: PhantomData<T>,
150}
151
152impl<T> RpcSubscription<T>
153where
154 T: for<'de> Deserialize<'de>,
155{
156 pub fn new(id: Uuid, receiver: mpsc::UnboundedReceiver<serde_json::Value>) -> Self {
158 Self {
159 id,
160 receiver,
161 _phantom: PhantomData,
162 }
163 }
164
165 pub async fn next(&mut self) -> Option<Result<T>> {
170 match self.receiver.recv().await {
171 Some(json_value) => match serde_json::from_value(json_value) {
172 Ok(value) => Some(Ok(value)),
173 Err(e) => Some(Err(Error::runtime_msg(format!(
174 "Failed to deserialize subscription data: {e}"
175 )))),
176 },
177 None => None,
178 }
179 }
180
181 pub async fn cancel(self) -> Result<()> {
186 log::trace!("Subscription {} canceled", self.id);
188
189 drop(self.receiver);
191 Ok(())
192 }
193
194 pub fn id(&self) -> Uuid {
196 self.id
197 }
198}
199
200#[derive(Debug, Clone, Serialize, Deserialize)]
202pub enum SubscriptionMessage {
203 Request {
205 id: Uuid,
206 method: String,
207 params: serde_json::Value,
208 },
209 Accept { id: Uuid },
211 Reject { id: Uuid, reason: String },
213 Data { id: Uuid, data: serde_json::Value },
215 Cancel { id: Uuid },
217}
218
219impl SubscriptionMessage {
220 pub fn id(&self) -> Uuid {
222 match self {
223 SubscriptionMessage::Request { id, .. } => *id,
224 SubscriptionMessage::Accept { id } => *id,
225 SubscriptionMessage::Reject { id, .. } => *id,
226 SubscriptionMessage::Data { id, .. } => *id,
227 SubscriptionMessage::Cancel { id } => *id,
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Accepting a pending sink yields a sink whose values reach the receiver.
    #[tokio::test]
    async fn test_pending_subscription_accept() {
        let (sender, mut receiver) = mpsc::unbounded_channel();
        let pending_sink =
            PendingSubscriptionSink::new(Uuid::new_v4(), "test_method".to_string(), sender);

        let accepted = pending_sink.accept().await.unwrap();
        assert_eq!(accepted.method(), "test_method");

        accepted.send_value("test_data").await.unwrap();
        let delivered = receiver.recv().await.unwrap();
        assert_eq!(delivered, json!("test_data"));
    }

    /// Rejecting a pending sink reports an error to the caller.
    #[tokio::test]
    async fn test_pending_subscription_reject() {
        // The receiver is kept alive (named binding) for the whole test.
        let (sender, _receiver) = mpsc::unbounded_channel();
        let pending_sink =
            PendingSubscriptionSink::new(Uuid::new_v4(), "test_method".to_string(), sender);

        let outcome = pending_sink.reject("Invalid parameters".to_string()).await;
        assert!(outcome.is_err());
    }

    /// A JSON string sent on the channel is decoded into a `String` item.
    #[tokio::test]
    async fn test_rpc_subscription() {
        let (sender, receiver) = mpsc::unbounded_channel();
        let mut typed: RpcSubscription<String> = RpcSubscription::new(Uuid::new_v4(), receiver);

        sender.send(json!("test_message")).unwrap();

        let item = typed.next().await.unwrap().unwrap();
        assert_eq!(item, "test_message");
    }

    /// Once the sender is gone, the subscription reports end-of-stream.
    #[tokio::test]
    async fn test_subscription_closed() {
        let (sender, receiver) = mpsc::unbounded_channel();
        let mut typed: RpcSubscription<String> = RpcSubscription::new(Uuid::new_v4(), receiver);

        drop(sender);

        let outcome = typed.next().await;
        assert!(outcome.is_none());
    }
}