1use super::helpers::MethodSink;
30use super::{MethodResponse, MethodsError, ResponsePayload};
31use crate::server::error::{DisconnectError, PendingSubscriptionAcceptError, SendTimeoutError, TrySendError};
32use crate::server::rpc_module::ConnectionId;
33use crate::server::LOG_TARGET;
34use crate::{error::StringError, traits::IdProvider};
35use jsonrpsee_types::SubscriptionPayload;
36use jsonrpsee_types::{response::SubscriptionError, ErrorObjectOwned, Id, SubscriptionId, SubscriptionResponse};
37use parking_lot::Mutex;
38use rustc_hash::FxHashMap;
39use serde::{de::DeserializeOwned, Serialize};
40use std::{sync::Arc, time::Duration};
41use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
42
/// Shared map of active subscriptions.
///
/// Each entry stores the connection's sink and the receiving half of an unsubscribe channel;
/// the matching sender lives in the subscription's `IsUnsubscribed` handle
/// (see `PendingSubscriptionSink::accept`).
pub type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, mpsc::Receiver<()>)>>>;
/// Permit counting one subscription against a [`BoundedSubscriptions`] limit.
pub type SubscriptionPermit = OwnedSemaphorePermit;
47
/// Converts a value into a [`SubscriptionCloseResponse`].
///
/// Presumably used on a subscription's final return value to decide what, if anything, is
/// sent when the subscription ends — confirm at call sites.
pub trait IntoSubscriptionCloseResponse {
	/// Converts `self` into a [`SubscriptionCloseResponse`].
	fn into_response(self) -> SubscriptionCloseResponse;
}
54
/// What to send, if anything, when a subscription is closed.
#[derive(Debug)]
pub enum SubscriptionCloseResponse {
	/// Send nothing.
	None,
	/// Send the message as a regular notification.
	Notif(SubscriptionMessage),
	/// Send the message as an error notification
	/// (presumably rendered under the `"error"` key — confirm where this is consumed).
	NotifErr(SubscriptionMessage),
}
93
94impl IntoSubscriptionCloseResponse for Result<(), StringError> {
95 fn into_response(self) -> SubscriptionCloseResponse {
96 match self {
97 Ok(()) => SubscriptionCloseResponse::None,
98 Err(e) => SubscriptionCloseResponse::NotifErr(e.0.into()),
99 }
100 }
101}
102
/// `()` never produces a closing message.
impl IntoSubscriptionCloseResponse for () {
	fn into_response(self) -> SubscriptionCloseResponse {
		SubscriptionCloseResponse::None
	}
}
108
/// Identity conversion.
///
/// A [`SubscriptionCloseResponse`] can be used wherever an
/// [`IntoSubscriptionCloseResponse`] is expected.
impl IntoSubscriptionCloseResponse for SubscriptionCloseResponse {
	fn into_response(self) -> Self {
		self
	}
}
114
/// Internal representation of a [`SubscriptionMessage`].
#[derive(Debug, Clone)]
pub enum SubscriptionMessageInner {
	/// A complete JSON-RPC notification; forwarded unchanged by `sub_message_to_json`.
	Complete(String),
	/// Only the JSON value of the notification's `result`/`error` member.
	/// The surrounding notification (method, subscription ID) is built by `sub_message_to_json`.
	NeedsData(String),
}
123
/// A message to be delivered to a subscriber.
#[derive(Debug, Clone)]
pub struct SubscriptionMessage(pub(crate) SubscriptionMessageInner);
127
128impl SubscriptionMessage {
129 pub fn from_json(t: &impl Serialize) -> Result<Self, serde_json::Error> {
133 serde_json::to_string(t).map(|json| SubscriptionMessage(SubscriptionMessageInner::NeedsData(json)))
134 }
135
136 pub fn new(method: &str, subscription: SubscriptionId, result: &impl Serialize) -> Result<Self, serde_json::Error> {
141 let json = serde_json::to_string(&SubscriptionResponse::new(
142 method.into(),
143 SubscriptionPayload { subscription, result },
144 ))?;
145 Ok(Self::from_complete_message(json))
146 }
147
148 pub(crate) fn from_complete_message(msg: String) -> Self {
149 SubscriptionMessage(SubscriptionMessageInner::Complete(msg))
150 }
151
152 pub(crate) fn empty() -> Self {
153 Self::from_complete_message(String::new())
154 }
155}
156
157impl<T> From<T> for SubscriptionMessage
158where
159 T: AsRef<str>,
160{
161 fn from(s: T) -> Self {
162 let json_str = {
164 let s = s.as_ref();
165 let mut res = String::with_capacity(s.len() + 2);
166 res.push('"');
167 res.push_str(s);
168 res.push('"');
169 res
170 };
171
172 SubscriptionMessage(SubscriptionMessageInner::NeedsData(json_str))
173 }
174}
175
/// Key identifying a subscription in [`Subscribers`].
///
/// Combines the owning connection and the subscription ID.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SubscriptionKey {
	/// Connection the subscription belongs to.
	pub(crate) conn_id: ConnectionId,
	/// ID assigned to the subscription.
	pub(crate) sub_id: SubscriptionId<'static>,
}
182
/// Selects the member of a notification's `params` object that carries the payload.
#[derive(Debug, Clone, Copy)]
pub(crate) enum SubNotifResultOrError {
	/// Payload goes under `"result"`.
	Result,
	/// Payload goes under `"error"`.
	Error,
}

impl SubNotifResultOrError {
	/// Returns the JSON key for this variant (`"result"` or `"error"`).
	///
	/// The key is a string literal, so it is returned as `&'static str` rather than being
	/// tied to the lifetime of `self`.
	pub(crate) const fn as_str(&self) -> &'static str {
		match self {
			Self::Result => "result",
			Self::Error => "error",
		}
	}
}
197
198#[derive(Debug, Clone)]
203pub struct IsUnsubscribed(mpsc::Sender<()>);
204
205impl IsUnsubscribed {
206 pub fn is_unsubscribed(&self) -> bool {
211 self.0.is_closed()
212 }
213
214 pub async fn unsubscribed(&self) {
221 self.0.closed().await;
222 }
223}
224
/// A subscription request that has not yet been accepted or rejected.
#[derive(Debug)]
#[must_use = "PendingSubscriptionSink does nothing unless `accept` or `reject` is called"]
pub struct PendingSubscriptionSink {
	/// Sink used to write responses to the connection.
	pub(crate) inner: MethodSink,
	/// Name of the subscription method.
	pub(crate) method: &'static str,
	/// Shared map of active subscriptions; this subscription is inserted on a successful `accept`.
	pub(crate) subscribers: Subscribers,
	/// Connection + subscription ID identifying this subscription.
	pub(crate) uniq_sub: SubscriptionKey,
	/// Request ID used in the accept/reject response.
	pub(crate) id: Id<'static>,
	/// Channel on which the accept/reject `MethodResponse` is also delivered.
	pub(crate) subscribe: oneshot::Sender<MethodResponse>,
	/// Subscription-limit permit; moved into the `SubscriptionSink` on accept.
	pub(crate) permit: OwnedSemaphorePermit,
}
250
251impl PendingSubscriptionSink {
252 pub async fn reject(self, err: impl Into<ErrorObjectOwned>) {
261 let err = MethodResponse::subscription_error(self.id, err.into());
262 _ = self.inner.send(err.to_result()).await;
263 _ = self.subscribe.send(err);
264 }
265
266 pub async fn accept(self) -> Result<SubscriptionSink, PendingSubscriptionAcceptError> {
272 let response = MethodResponse::subscription_response(
273 self.id,
274 ResponsePayload::success_borrowed(&self.uniq_sub.sub_id),
275 self.inner.max_response_size() as usize,
276 );
277 let success = response.is_success();
278
279 self.inner.send(response.to_result()).await.map_err(|_| PendingSubscriptionAcceptError)?;
286 self.subscribe.send(response).map_err(|_| PendingSubscriptionAcceptError)?;
287
288 if success {
289 let (tx, rx) = mpsc::channel(1);
290 self.subscribers.lock().insert(self.uniq_sub.clone(), (self.inner.clone(), rx));
291 Ok(SubscriptionSink {
292 inner: self.inner,
293 method: self.method,
294 subscribers: self.subscribers,
295 uniq_sub: self.uniq_sub,
296 unsubscribe: IsUnsubscribed(tx),
297 _permit: Arc::new(self.permit),
298 })
299 } else {
300 panic!("The subscription response was too big; adjust the `max_response_size` or change Subscription ID generation");
301 }
302 }
303
304 pub fn connection_id(&self) -> ConnectionId {
306 self.uniq_sub.conn_id
307 }
308
309 pub fn capacity(&self) -> usize {
311 self.inner.capacity()
312 }
313
314 pub fn max_capacity(&self) -> usize {
316 self.inner.max_capacity()
317 }
318
319 pub fn method_name(&self) -> &str {
321 self.method
322 }
323}
324
/// An accepted subscription, used to push notifications to the subscriber.
#[derive(Debug, Clone)]
pub struct SubscriptionSink {
	/// Sink used to write notifications to the connection.
	inner: MethodSink,
	/// Name of the subscription method.
	method: &'static str,
	/// Shared map of active subscriptions; this subscription's entry is removed on drop.
	subscribers: Subscribers,
	/// Connection + subscription ID identifying this subscription.
	uniq_sub: SubscriptionKey,
	/// Reports "unsubscribed" once this subscription's receiver in `subscribers` is dropped or closed.
	unsubscribe: IsUnsubscribed,
	/// Subscription-limit permit, shared between clones; released when the last clone is dropped.
	_permit: Arc<SubscriptionPermit>,
}
341
342impl SubscriptionSink {
343 pub fn subscription_id(&self) -> SubscriptionId<'static> {
345 self.uniq_sub.sub_id.clone()
346 }
347
348 pub fn method_name(&self) -> &str {
350 self.method
351 }
352
353 pub fn connection_id(&self) -> ConnectionId {
355 self.uniq_sub.conn_id
356 }
357
358 pub async fn send(&self, msg: SubscriptionMessage) -> Result<(), DisconnectError> {
369 if self.is_closed() {
371 return Err(DisconnectError(msg));
372 }
373
374 let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &self.uniq_sub.sub_id, self.method);
375 self.inner.send(json).await.map_err(Into::into)
376 }
377
378 pub async fn send_timeout(&self, msg: SubscriptionMessage, timeout: Duration) -> Result<(), SendTimeoutError> {
380 if self.is_closed() {
382 return Err(SendTimeoutError::Closed(msg));
383 }
384
385 let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &self.uniq_sub.sub_id, self.method);
386 self.inner.send_timeout(json, timeout).await.map_err(Into::into)
387 }
388
389 pub fn try_send(&mut self, msg: SubscriptionMessage) -> Result<(), TrySendError> {
396 if self.is_closed() {
398 return Err(TrySendError::Closed(msg));
399 }
400
401 let json = sub_message_to_json(msg, SubNotifResultOrError::Result, &self.uniq_sub.sub_id, self.method);
402 self.inner.try_send(json).map_err(Into::into)
403 }
404
405 pub fn is_closed(&self) -> bool {
407 self.inner.is_closed() || !self.is_active_subscription()
408 }
409
410 pub async fn closed(&self) {
412 tokio::select! {
414 _ = self.inner.closed() => (),
415 _ = self.unsubscribe.unsubscribed() => (),
416 }
417 }
418
419 pub fn capacity(&self) -> usize {
421 self.inner.capacity()
422 }
423
424 pub fn max_capacity(&self) -> usize {
426 self.inner.max_capacity()
427 }
428
429 fn is_active_subscription(&self) -> bool {
430 !self.unsubscribe.is_unsubscribed()
431 }
432}
433
434impl Drop for SubscriptionSink {
435 fn drop(&mut self) {
436 if self.is_active_subscription() {
437 self.subscribers.lock().remove(&self.uniq_sub);
438 }
439 }
440}
441
/// Receiving side of a subscription.
///
/// Raw JSON notifications arrive on a channel and are decoded by [`Subscription::next`].
#[derive(Debug)]
pub struct Subscription {
	/// Channel of raw JSON notification strings.
	pub(crate) rx: mpsc::Receiver<String>,
	/// ID of this subscription.
	pub(crate) sub_id: SubscriptionId<'static>,
}
448
449impl Subscription {
450 pub fn close(&mut self) {
452 tracing::trace!(target: LOG_TARGET, "[Subscription::close] Notifying");
453 self.rx.close();
454 }
455
456 pub fn subscription_id(&self) -> &SubscriptionId {
458 &self.sub_id
459 }
460
461 pub async fn next<T: DeserializeOwned>(&mut self) -> Option<Result<(T, SubscriptionId<'static>), MethodsError>> {
463 let raw = self.rx.recv().await?;
464
465 tracing::debug!(target: LOG_TARGET, "[Subscription::next]: rx {}", raw);
466
467 #[allow(clippy::let_and_return)]
469 let res = match serde_json::from_str::<SubscriptionResponse<T>>(&raw) {
470 Ok(r) => Some(Ok((r.params.result, r.params.subscription.into_owned()))),
471 Err(e) => match serde_json::from_str::<SubscriptionError<serde_json::Value>>(&raw) {
472 Ok(_) => None,
473 Err(_) => Some(Err(e.into())),
474 },
475 };
476 res
477 }
478}
479
480impl Drop for Subscription {
481 fn drop(&mut self) {
482 self.close();
483 }
484}
485
486#[derive(Debug, Clone)]
488pub struct BoundedSubscriptions {
489 guard: Arc<Semaphore>,
490 max: u32,
491}
492
493impl BoundedSubscriptions {
494 pub fn new(max_subscriptions: u32) -> Self {
496 Self { guard: Arc::new(Semaphore::new(max_subscriptions as usize)), max: max_subscriptions }
497 }
498
499 pub fn acquire(&self) -> Option<SubscriptionPermit> {
503 Arc::clone(&self.guard).try_acquire_owned().ok()
504 }
505
506 pub const fn max(&self) -> u32 {
508 self.max
509 }
510}
511
/// State needed to set up a subscription.
///
/// Presumably passed to subscription call handlers — confirm at call sites.
#[derive(Debug)]
pub struct SubscriptionState<'a> {
	/// ID of the connection the subscription belongs to.
	pub conn_id: ConnectionId,
	/// Source of subscription IDs (presumably used to generate the new ID — confirm).
	pub id_provider: &'a dyn IdProvider,
	/// Permit counting this subscription against the limit
	/// (presumably obtained from `BoundedSubscriptions::acquire`).
	pub subscription_permit: SubscriptionPermit,
}
522
523pub(crate) fn sub_message_to_json(
524 msg: SubscriptionMessage,
525 result_or_err: SubNotifResultOrError,
526 sub_id: &SubscriptionId,
527 method: &str,
528) -> String {
529 let result_or_err = result_or_err.as_str();
530
531 match msg.0 {
532 SubscriptionMessageInner::Complete(msg) => msg,
533 SubscriptionMessageInner::NeedsData(result) => {
534 let sub_id = serde_json::to_string(&sub_id).expect("valid JSON; qed");
535 format!(
536 r#"{{"jsonrpc":"2.0","method":"{method}","params":{{"subscription":{sub_id},"{result_or_err}":{result}}}}}"#,
537 )
538 }
539 }
540}