1use super::helpers::MethodSink;
30use super::{MethodResponse, MethodsError, ResponsePayload};
31use crate::server::LOG_TARGET;
32use crate::server::error::{DisconnectError, PendingSubscriptionAcceptError, SendTimeoutError, TrySendError};
33use crate::server::rpc_module::ConnectionId;
34use crate::{error::SubscriptionError, traits::IdProvider};
35use jsonrpsee_types::SubscriptionPayload;
36use jsonrpsee_types::response::SubscriptionPayloadError;
37use jsonrpsee_types::{ErrorObjectOwned, Id, SubscriptionId, SubscriptionResponse};
38use parking_lot::Mutex;
39use rustc_hash::FxHashMap;
40use serde::{Serialize, de::DeserializeOwned};
41use serde_json::value::RawValue;
42use std::{sync::Arc, time::Duration};
43use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
44
/// Shared registry of the currently registered subscriptions.
///
/// Each entry maps a [`SubscriptionKey`] (connection id + subscription id) to the sink used to
/// reach that connection and the receiving half of the channel whose sending half is held by the
/// subscription's [`IsUnsubscribed`] handle. Removing an entry drops that receiver, which closes
/// the channel and makes [`IsUnsubscribed::is_unsubscribed`] return `true`.
pub type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, mpsc::Receiver<()>)>>>;
/// Permit counting one subscription against a subscription limit (see [`BoundedSubscriptions`]).
pub type SubscriptionPermit = OwnedSemaphorePermit;
49
/// Conversion of a value into the [`SubscriptionCloseResponse`] describing what is sent when a
/// subscription ends.
///
/// NOTE(review): presumably implemented for the return types of subscription callbacks — confirm
/// against the callers.
pub trait IntoSubscriptionCloseResponse {
    /// Convert `self` into a [`SubscriptionCloseResponse`].
    fn into_response(self) -> SubscriptionCloseResponse;
}
56
/// What, if anything, is sent to the subscriber once the subscription has ended.
#[derive(Debug)]
pub enum SubscriptionCloseResponse {
    /// Nothing is sent.
    None,
    /// A final notification built from this message.
    Notif(SubscriptionMessage),
    /// A final error notification carrying this error.
    NotifErr(SubscriptionError),
}
95
96impl IntoSubscriptionCloseResponse for Result<(), SubscriptionError> {
97 fn into_response(self) -> SubscriptionCloseResponse {
98 match self {
99 Ok(()) => SubscriptionCloseResponse::None,
100 Err(e) => SubscriptionCloseResponse::NotifErr(e),
101 }
102 }
103}
104
105impl IntoSubscriptionCloseResponse for () {
106 fn into_response(self) -> SubscriptionCloseResponse {
107 SubscriptionCloseResponse::None
108 }
109}
110
111impl IntoSubscriptionCloseResponse for SubscriptionCloseResponse {
112 fn into_response(self) -> Self {
113 self
114 }
115}
116
/// Internal representation of a [`SubscriptionMessage`].
#[derive(Debug, Clone)]
pub enum SubscriptionMessageInner {
    /// An already-serialized, complete subscription notification that is sent unchanged.
    Complete(Box<RawValue>),
    /// Only the notification's `result` value; it is wrapped into a notification carrying the
    /// sink's subscription id and method name when sent.
    NeedsData(Box<RawValue>),
}
125
/// A message to be sent on a subscription.
///
/// Build it either from a raw JSON `result` value via `From<Box<RawValue>>` (the subscription id
/// and method name are added when it is sent), or as a complete notification via
/// [`SubscriptionMessage::new`].
#[derive(Debug, Clone)]
pub struct SubscriptionMessage(pub(crate) SubscriptionMessageInner);
129
130impl From<Box<RawValue>> for SubscriptionMessage {
131 fn from(json: Box<RawValue>) -> Self {
132 Self(SubscriptionMessageInner::NeedsData(json))
133 }
134}
135
136impl SubscriptionMessage {
137 pub fn new(method: &str, subscription: SubscriptionId, result: &impl Serialize) -> Result<Self, serde_json::Error> {
142 let json = serde_json::value::to_raw_value(&SubscriptionResponse::new(
143 method.into(),
144 SubscriptionPayload { subscription, result },
145 ))?;
146 Ok(Self::from_complete_message(json))
147 }
148
149 pub(crate) fn from_complete_message(msg: Box<RawValue>) -> Self {
150 SubscriptionMessage(SubscriptionMessageInner::Complete(msg))
151 }
152}
153
/// Key identifying one subscription: the connection it belongs to plus its subscription id.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SubscriptionKey {
    /// Connection the subscription was created on.
    pub(crate) conn_id: ConnectionId,
    /// Subscription id returned to the client when the subscription is accepted.
    pub(crate) sub_id: SubscriptionId<'static>,
}
160
/// Handle for checking whether a subscription has been unsubscribed.
///
/// Wraps the sending half of a channel whose receiving half is stored in the [`Subscribers`]
/// registry; once that receiver is dropped (the registry entry is removed) the channel reports
/// closed.
#[derive(Debug, Clone)]
pub struct IsUnsubscribed(mpsc::Sender<()>);
167
168impl IsUnsubscribed {
169 pub fn is_unsubscribed(&self) -> bool {
174 self.0.is_closed()
175 }
176
177 pub async fn unsubscribed(&self) {
184 self.0.closed().await;
185 }
186}
187
/// A subscription call that has not yet been accepted or rejected.
///
/// Call [`PendingSubscriptionSink::accept`] to answer the call with the subscription id and obtain
/// a [`SubscriptionSink`], or [`PendingSubscriptionSink::reject`] to answer it with an error.
#[derive(Debug)]
#[must_use = "PendingSubscriptionSink does nothing unless `accept` or `reject` is called"]
pub struct PendingSubscriptionSink {
    /// Sink used to send messages to the connection.
    pub(crate) inner: MethodSink,
    /// Method name used for this subscription's notifications.
    pub(crate) method: &'static str,
    /// Registry the subscription is inserted into when it is accepted.
    pub(crate) subscribers: Subscribers,
    /// Connection id and subscription id of this subscription.
    pub(crate) uniq_sub: SubscriptionKey,
    /// Id of the subscription call being answered.
    pub(crate) id: Id<'static>,
    /// Channel through which the response to the subscription call is also reported internally.
    pub(crate) subscribe: oneshot::Sender<MethodResponse>,
    /// Permit counting this subscription against the subscription limit
    /// (presumably obtained from `BoundedSubscriptions::acquire` — confirm).
    pub(crate) permit: OwnedSemaphorePermit,
}
213
214impl PendingSubscriptionSink {
215 pub async fn reject(self, err: impl Into<ErrorObjectOwned>) {
224 let err = MethodResponse::subscription_error(self.id, err.into());
225 _ = self.inner.send(err.to_json()).await;
226 _ = self.subscribe.send(err);
227 }
228
229 pub async fn accept(self) -> Result<SubscriptionSink, PendingSubscriptionAcceptError> {
235 let response = MethodResponse::subscription_response(
236 self.id,
237 ResponsePayload::success_borrowed(&self.uniq_sub.sub_id),
238 self.inner.max_response_size() as usize,
239 );
240 let success = response.is_success();
241
242 self.inner.send(response.to_json()).await.map_err(|_| PendingSubscriptionAcceptError)?;
249 self.subscribe.send(response).map_err(|_| PendingSubscriptionAcceptError)?;
250
251 if success {
252 let (tx, rx) = mpsc::channel(1);
253 self.subscribers.lock().insert(self.uniq_sub.clone(), (self.inner.clone(), rx));
254 Ok(SubscriptionSink {
255 inner: self.inner,
256 method: self.method,
257 subscribers: self.subscribers,
258 uniq_sub: self.uniq_sub,
259 unsubscribe: IsUnsubscribed(tx),
260 _permit: Arc::new(self.permit),
261 })
262 } else {
263 panic!(
264 "The subscription response was too big; adjust the `max_response_size` or change Subscription ID generation"
265 );
266 }
267 }
268
269 pub fn connection_id(&self) -> ConnectionId {
271 self.uniq_sub.conn_id
272 }
273
274 pub fn capacity(&self) -> usize {
276 self.inner.capacity()
277 }
278
279 pub fn max_capacity(&self) -> usize {
281 self.inner.max_capacity()
282 }
283
284 pub fn method_name(&self) -> &str {
286 self.method
287 }
288}
289
/// Sink for sending notifications on an accepted subscription.
///
/// Created, for example, by [`PendingSubscriptionSink::accept`].
#[derive(Debug, Clone)]
pub struct SubscriptionSink {
    /// Sink used to send messages to the connection.
    inner: MethodSink,
    /// Method name used for this subscription's notifications.
    method: &'static str,
    /// Registry this subscription is registered in.
    subscribers: Subscribers,
    /// Connection id and subscription id of this subscription.
    uniq_sub: SubscriptionKey,
    /// Reports whether the subscription has been unsubscribed.
    unsubscribe: IsUnsubscribed,
    /// Subscription-limit permit shared by all clones; released when the last clone is dropped.
    _permit: Arc<SubscriptionPermit>,
}
306
307impl SubscriptionSink {
308 pub fn subscription_id(&self) -> SubscriptionId<'static> {
310 self.uniq_sub.sub_id.clone()
311 }
312
313 pub fn method_name(&self) -> &str {
315 self.method
316 }
317
318 pub fn connection_id(&self) -> ConnectionId {
320 self.uniq_sub.conn_id
321 }
322
323 pub async fn send(&self, msg: impl Into<SubscriptionMessage>) -> Result<(), DisconnectError> {
334 let msg = msg.into();
335
336 if self.is_closed() {
338 return Err(DisconnectError(msg));
339 }
340
341 let json = sub_message_to_json(msg, &self.uniq_sub.sub_id, self.method);
342 self.inner.send(json).await
343 }
344
345 pub async fn send_timeout(
347 &self,
348 msg: impl Into<SubscriptionMessage>,
349 timeout: Duration,
350 ) -> Result<(), SendTimeoutError> {
351 let msg = msg.into();
352
353 if self.is_closed() {
355 return Err(SendTimeoutError::Closed(msg));
356 }
357
358 let json = sub_message_to_json(msg, &self.uniq_sub.sub_id, self.method);
359 self.inner.send_timeout(json, timeout).await
360 }
361
362 pub fn try_send(&mut self, msg: impl Into<SubscriptionMessage>) -> Result<(), TrySendError> {
369 let msg = msg.into();
370
371 if self.is_closed() {
373 return Err(TrySendError::Closed(msg));
374 }
375
376 let json = sub_message_to_json(msg, &self.uniq_sub.sub_id, self.method);
377 self.inner.try_send(json)
378 }
379
380 pub fn is_closed(&self) -> bool {
382 self.inner.is_closed() || !self.is_active_subscription()
383 }
384
385 pub async fn closed(&self) {
387 tokio::select! {
389 _ = self.inner.closed() => (),
390 _ = self.unsubscribe.unsubscribed() => (),
391 }
392 }
393
394 pub fn capacity(&self) -> usize {
396 self.inner.capacity()
397 }
398
399 pub fn max_capacity(&self) -> usize {
401 self.inner.max_capacity()
402 }
403
404 fn is_active_subscription(&self) -> bool {
405 !self.unsubscribe.is_unsubscribed()
406 }
407}
408
409impl Drop for SubscriptionSink {
410 fn drop(&mut self) {
411 if self.is_active_subscription() {
412 self.subscribers.lock().remove(&self.uniq_sub);
413 }
414 }
415}
416
/// Receiving side of a subscription, yielding the notifications sent for it.
#[derive(Debug)]
pub struct Subscription {
    /// Raw JSON messages received for this subscription.
    pub(crate) rx: mpsc::Receiver<Box<RawValue>>,
    /// Id of this subscription.
    pub(crate) sub_id: SubscriptionId<'static>,
}
423
424impl Subscription {
425 pub fn close(&mut self) {
427 tracing::trace!(target: LOG_TARGET, "[Subscription::close] Notifying");
428 self.rx.close();
429 }
430
431 pub fn subscription_id(&self) -> &SubscriptionId {
433 &self.sub_id
434 }
435
436 pub async fn next<T: DeserializeOwned>(&mut self) -> Option<Result<(T, SubscriptionId<'static>), MethodsError>> {
438 let raw = self.rx.recv().await?;
439
440 tracing::debug!(target: LOG_TARGET, "[Subscription::next]: rx {}", raw);
441
442 #[allow(clippy::let_and_return)]
444 let res = match serde_json::from_str::<SubscriptionResponse<T>>(raw.get()) {
445 Ok(r) => Some(Ok((r.params.result, r.params.subscription.into_owned()))),
446 Err(e) => {
447 match serde_json::from_str::<jsonrpsee_types::response::SubscriptionError<&RawValue>>(raw.get()) {
448 Ok(_) => None,
449 Err(_) => Some(Err(e.into())),
450 }
451 }
452 };
453 res
454 }
455}
456
457impl Drop for Subscription {
458 fn drop(&mut self) {
459 self.close();
460 }
461}
462
/// Limits how many subscription permits can be held at the same time.
#[derive(Debug, Clone)]
pub struct BoundedSubscriptions {
    /// Semaphore with one permit per allowed subscription; shared between clones.
    guard: Arc<Semaphore>,
    /// The configured maximum number of subscriptions.
    max: u32,
}
469
470impl BoundedSubscriptions {
471 pub fn new(max_subscriptions: u32) -> Self {
473 Self { guard: Arc::new(Semaphore::new(max_subscriptions as usize)), max: max_subscriptions }
474 }
475
476 pub fn acquire(&self) -> Option<SubscriptionPermit> {
480 Arc::clone(&self.guard).try_acquire_owned().ok()
481 }
482
483 pub const fn max(&self) -> u32 {
485 self.max
486 }
487}
488
/// State available when setting up a subscription.
#[derive(Debug)]
pub struct SubscriptionState<'a> {
    /// Connection the subscription is being created on.
    pub conn_id: ConnectionId,
    /// Provider used to generate subscription ids.
    pub id_provider: &'a dyn IdProvider,
    /// Permit counting the subscription against the subscription limit.
    pub subscription_permit: SubscriptionPermit,
}
499
500pub(crate) fn sub_message_to_json(msg: SubscriptionMessage, sub_id: &SubscriptionId, method: &str) -> Box<RawValue> {
501 match msg.0 {
502 SubscriptionMessageInner::Complete(msg) => msg,
503 SubscriptionMessageInner::NeedsData(result) => serde_json::value::to_raw_value(&SubscriptionResponse::new(
504 method.into(),
505 SubscriptionPayload { subscription: sub_id.clone(), result },
506 ))
507 .expect("Serialize infallible; qed"),
508 }
509}
510
511pub(crate) fn sub_err_to_json(error: SubscriptionError, sub_id: SubscriptionId, method: &str) -> Box<RawValue> {
512 serde_json::value::to_raw_value(&jsonrpsee_types::response::SubscriptionError::new(
513 method.into(),
514 SubscriptionPayloadError { subscription: sub_id, error },
515 ))
516 .expect("Serialize infallible; qed")
517}