use super::helpers::MethodSink;
use super::{MethodResponse, MethodsError, ResponsePayload};
use crate::server::LOG_TARGET;
use crate::server::error::{DisconnectError, PendingSubscriptionAcceptError, SendTimeoutError, TrySendError};
use crate::server::rpc_module::ConnectionId;
use crate::{error::SubscriptionError, traits::IdProvider};
use jsonrpsee_types::SubscriptionPayload;
use jsonrpsee_types::response::SubscriptionPayloadError;
use jsonrpsee_types::{ErrorObjectOwned, Id, SubscriptionId, SubscriptionResponse};
use parking_lot::Mutex;
use rustc_hash::FxHashMap;
use serde::{Serialize, de::DeserializeOwned};
use serde_json::value::RawValue;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
/// Shared, mutex-protected registry of accepted subscriptions.
///
/// Each entry maps a [`SubscriptionKey`] to the [`MethodSink`] of the subscription and the
/// receiving half of the channel whose sending half is wrapped in the sink's [`IsUnsubscribed`].
/// Dropping an entry drops that receiver, which closes the channel observed by `IsUnsubscribed`.
pub type Subscribers = Arc<Mutex<FxHashMap<SubscriptionKey, (MethodSink, mpsc::Receiver<()>)>>>;
/// Permit for one subscription slot; an alias for tokio's [`OwnedSemaphorePermit`].
pub type SubscriptionPermit = OwnedSemaphorePermit;
/// Conversion of a value into a [`SubscriptionCloseResponse`].
///
/// NOTE(review): presumably applied to the value returned by a subscription callback;
/// the call site is outside this file — confirm there.
pub trait IntoSubscriptionCloseResponse {
/// Consumes `self` and produces the corresponding close response.
fn into_response(self) -> SubscriptionCloseResponse;
}
/// Describes what, if anything, should be reported when a subscription ends.
///
/// NOTE(review): how each variant is acted upon is decided by the consumer of this type,
/// which is not part of this file — confirm there.
#[derive(Debug)]
pub enum SubscriptionCloseResponse {
/// Nothing to report.
None,
/// Carries an ordinary [`SubscriptionMessage`].
Notif(SubscriptionMessage),
/// Carries a [`SubscriptionError`].
NotifErr(SubscriptionError),
}
impl IntoSubscriptionCloseResponse for Result<(), SubscriptionError> {
    /// Maps `Ok(())` to [`SubscriptionCloseResponse::None`] and `Err(e)` to
    /// [`SubscriptionCloseResponse::NotifErr`] carrying `e`.
    fn into_response(self) -> SubscriptionCloseResponse {
        let Err(failure) = self else {
            return SubscriptionCloseResponse::None;
        };
        SubscriptionCloseResponse::NotifErr(failure)
    }
}
/// The unit type carries no information, so it always converts to
/// [`SubscriptionCloseResponse::None`].
impl IntoSubscriptionCloseResponse for () {
fn into_response(self) -> SubscriptionCloseResponse {
SubscriptionCloseResponse::None
}
}
/// Identity conversion: a close response is returned unchanged.
impl IntoSubscriptionCloseResponse for SubscriptionCloseResponse {
fn into_response(self) -> Self {
self
}
}
/// Internal representation of a [`SubscriptionMessage`].
#[derive(Debug, Clone)]
pub enum SubscriptionMessageInner {
/// Already-serialized notification; `sub_message_to_json` forwards it untouched.
Complete(Box<RawValue>),
/// Only the serialized `result`; `sub_message_to_json` wraps it with the method name and
/// subscription ID before it is sent.
NeedsData(Box<RawValue>),
}
/// A message that can be sent on a subscription via [`SubscriptionSink`].
///
/// Thin wrapper around [`SubscriptionMessageInner`]; build it with [`SubscriptionMessage::new`]
/// or convert from a `Box<RawValue>`.
#[derive(Debug, Clone)]
pub struct SubscriptionMessage(pub(crate) SubscriptionMessageInner);
impl From<Box<RawValue>> for SubscriptionMessage {
fn from(json: Box<RawValue>) -> Self {
Self(SubscriptionMessageInner::NeedsData(json))
}
}
impl SubscriptionMessage {
    /// Builds a fully serialized subscription notification for `method` and `subscription`
    /// carrying `result`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the notification cannot be serialized.
    pub fn new(method: &str, subscription: SubscriptionId, result: &impl Serialize) -> Result<Self, serde_json::Error> {
        let payload = SubscriptionPayload { subscription, result };
        let notification = SubscriptionResponse::new(method.into(), payload);
        serde_json::value::to_raw_value(&notification).map(Self::from_complete_message)
    }

    /// Wraps an already complete, serialized notification.
    pub(crate) fn from_complete_message(msg: Box<RawValue>) -> Self {
        Self(SubscriptionMessageInner::Complete(msg))
    }
}
/// Key of a subscription in [`Subscribers`]: the connection it belongs to plus its subscription ID.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SubscriptionKey {
/// Connection the subscription was created on.
pub(crate) conn_id: ConnectionId,
/// Owned subscription ID.
pub(crate) sub_id: SubscriptionId<'static>,
}
/// Handle for observing whether a subscription is still registered.
///
/// Wraps the sending half of an `mpsc` channel on which nothing is sent in this file;
/// only the open/closed state of the channel is inspected.
#[derive(Debug, Clone)]
pub struct IsUnsubscribed(mpsc::Sender<()>);
impl IsUnsubscribed {
pub fn is_unsubscribed(&self) -> bool {
self.0.is_closed()
}
pub async fn unsubscribed(&self) {
self.0.closed().await;
}
}
/// A subscription call that has been received but not yet answered.
///
/// It must be consumed with either `accept` or `reject`.
#[derive(Debug)]
#[must_use = "PendingSubscriptionSink does nothing unless `accept` or `reject` is called"]
pub struct PendingSubscriptionSink {
/// Sink the response to the subscribe call is written to.
pub(crate) inner: MethodSink,
/// Method name reported by `method_name` and handed to the accepted sink.
pub(crate) method: &'static str,
/// Registry the subscription is inserted into when it is accepted.
pub(crate) subscribers: Subscribers,
/// Connection ID and subscription ID identifying this subscription.
pub(crate) uniq_sub: SubscriptionKey,
/// ID used when building the response to the subscribe call.
pub(crate) id: Id<'static>,
/// One-shot channel that also receives the response built by `accept`/`reject`.
pub(crate) subscribe: oneshot::Sender<MethodResponse>,
/// Subscription permit; moved into the `SubscriptionSink` on accept.
pub(crate) permit: OwnedSemaphorePermit,
}
impl PendingSubscriptionSink {
    /// Rejects the subscription by answering the subscribe call with `err`.
    ///
    /// The error response is written to the sink and also delivered on the internal one-shot
    /// channel. Both sends are best effort: failures are deliberately ignored.
    pub async fn reject(self, err: impl Into<ErrorObjectOwned>) {
        let rejection = MethodResponse::subscription_error(self.id, err.into());
        let _ = self.inner.send(rejection.to_json()).await;
        let _ = self.subscribe.send(rejection);
    }

    /// Accepts the subscription: answers the subscribe call with the subscription ID, registers
    /// the subscription in the shared registry and returns the sink used to send notifications.
    ///
    /// # Errors
    ///
    /// Returns [`PendingSubscriptionAcceptError`] if the response cannot be written to the sink
    /// or cannot be delivered on the internal one-shot channel.
    ///
    /// # Panics
    ///
    /// Panics if the built response is not a success, i.e. the subscription response did not
    /// fit within the sink's `max_response_size`.
    pub async fn accept(self) -> Result<SubscriptionSink, PendingSubscriptionAcceptError> {
        let payload = ResponsePayload::success_borrowed(&self.uniq_sub.sub_id);
        let size_limit = self.inner.max_response_size() as usize;
        let response = MethodResponse::subscription_response(self.id, payload, size_limit);

        // Must be read now: `response` is moved into the one-shot channel below.
        let accepted = response.is_success();

        // The same response goes out twice: once to the sink, once on the one-shot channel.
        if self.inner.send(response.to_json()).await.is_err() {
            return Err(PendingSubscriptionAcceptError);
        }
        if self.subscribe.send(response).is_err() {
            return Err(PendingSubscriptionAcceptError);
        }

        if !accepted {
            panic!(
                "The subscription response was too big; adjust the `max_response_size` or change Subscription ID generation"
            );
        }

        // The receiver lives in the registry; removing the entry closes the channel, which is
        // how the returned sink learns that it has been unsubscribed.
        let (tx, rx) = mpsc::channel(1);
        self.subscribers.lock().insert(self.uniq_sub.clone(), (self.inner.clone(), rx));

        Ok(SubscriptionSink {
            inner: self.inner,
            method: self.method,
            subscribers: self.subscribers,
            uniq_sub: self.uniq_sub,
            unsubscribe: IsUnsubscribed(tx),
            _permit: Arc::new(self.permit),
        })
    }

    /// Returns the ID of the connection this subscription belongs to.
    pub fn connection_id(&self) -> ConnectionId {
        let key = &self.uniq_sub;
        key.conn_id
    }

    /// Returns the current capacity reported by the underlying sink.
    pub fn capacity(&self) -> usize {
        let sink = &self.inner;
        sink.capacity()
    }

    /// Returns the maximum capacity reported by the underlying sink.
    pub fn max_capacity(&self) -> usize {
        let sink = &self.inner;
        sink.max_capacity()
    }

    /// Returns the stored method name.
    pub fn method_name(&self) -> &str {
        self.method
    }
}
/// Sink of an accepted subscription, used to send notifications to the subscriber.
///
/// Created by `PendingSubscriptionSink::accept`.
#[derive(Debug, Clone)]
pub struct SubscriptionSink {
/// Sink the notifications are written to.
inner: MethodSink,
/// Method name passed along when a message is turned into JSON.
method: &'static str,
/// Shared registry; this subscription's entry is removed when the sink is dropped while active.
subscribers: Subscribers,
/// Connection ID and subscription ID identifying this subscription.
uniq_sub: SubscriptionKey,
/// Tells whether the subscription's registry entry (and with it the channel) is gone.
unsubscribe: IsUnsubscribed,
/// Subscription permit shared between clones of this sink; never read, only held.
_permit: Arc<SubscriptionPermit>,
}
impl SubscriptionSink {
pub fn subscription_id(&self) -> SubscriptionId<'static> {
self.uniq_sub.sub_id.clone()
}
pub fn method_name(&self) -> &str {
self.method
}
pub fn connection_id(&self) -> ConnectionId {
self.uniq_sub.conn_id
}
pub async fn send(&self, msg: impl Into<SubscriptionMessage>) -> Result<(), DisconnectError> {
let msg = msg.into();
if self.is_closed() {
return Err(DisconnectError(msg));
}
let json = sub_message_to_json(msg, &self.uniq_sub.sub_id, self.method);
self.inner.send(json).await
}
pub async fn send_timeout(
&self,
msg: impl Into<SubscriptionMessage>,
timeout: Duration,
) -> Result<(), SendTimeoutError> {
let msg = msg.into();
if self.is_closed() {
return Err(SendTimeoutError::Closed(msg));
}
let json = sub_message_to_json(msg, &self.uniq_sub.sub_id, self.method);
self.inner.send_timeout(json, timeout).await
}
pub fn try_send(&mut self, msg: impl Into<SubscriptionMessage>) -> Result<(), TrySendError> {
let msg = msg.into();
if self.is_closed() {
return Err(TrySendError::Closed(msg));
}
let json = sub_message_to_json(msg, &self.uniq_sub.sub_id, self.method);
self.inner.try_send(json)
}
pub fn is_closed(&self) -> bool {
self.inner.is_closed() || !self.is_active_subscription()
}
pub async fn closed(&self) {
tokio::select! {
_ = self.inner.closed() => (),
_ = self.unsubscribe.unsubscribed() => (),
}
}
pub fn capacity(&self) -> usize {
self.inner.capacity()
}
pub fn max_capacity(&self) -> usize {
self.inner.max_capacity()
}
fn is_active_subscription(&self) -> bool {
!self.unsubscribe.is_unsubscribed()
}
}
impl Drop for SubscriptionSink {
    /// Removes this subscription's entry from the shared registry, unless the subscription has
    /// already been unsubscribed.
    fn drop(&mut self) {
        if !self.is_active_subscription() {
            return;
        }
        let mut registry = self.subscribers.lock();
        registry.remove(&self.uniq_sub);
    }
}
/// Receiving side of a subscription: yields the raw JSON messages sent for it.
#[derive(Debug)]
pub struct Subscription {
/// Channel on which the serialized subscription messages arrive.
pub(crate) rx: mpsc::Receiver<Box<RawValue>>,
/// ID of the subscription.
pub(crate) sub_id: SubscriptionId<'static>,
}
impl Subscription {
/// Closes the receiving channel; also invoked when the `Subscription` is dropped.
pub fn close(&mut self) {
tracing::trace!(target: LOG_TARGET, "[Subscription::close] Notifying");
self.rx.close();
}
/// Returns a reference to the subscription ID.
pub fn subscription_id(&self) -> &SubscriptionId {
&self.sub_id
}
/// Waits for the next message and tries to decode its `result` as `T`.
///
/// Returns:
/// - `None` when the channel yields no more messages, or when the message is a subscription
///   error notification;
/// - `Some(Ok((value, subscription_id)))` when the message decodes as a subscription
///   notification carrying a `T`;
/// - `Some(Err(_))` with the original decode error when the message is neither of the above.
pub async fn next<T: DeserializeOwned>(&mut self) -> Option<Result<(T, SubscriptionId<'static>), MethodsError>> {
// `?` ends the stream with `None` once the channel has no more messages.
let raw = self.rx.recv().await?;
tracing::debug!(target: LOG_TARGET, "[Subscription::next]: rx {}", raw);
// NOTE(review): the explicit `res` binding looks redundant to clippy; presumably it is needed
// so that the temporaries borrowing from `raw` are dropped before `raw` itself — confirm before
// removing the binding or this `allow`.
#[allow(clippy::let_and_return)]
let res = match serde_json::from_str::<SubscriptionResponse<T>>(raw.get()) {
Ok(r) => Some(Ok((r.params.result, r.params.subscription.into_owned()))),
Err(e) => {
// Not a regular notification: check whether it is a subscription error notification
// before reporting the decode failure.
match serde_json::from_str::<jsonrpsee_types::response::SubscriptionError<&RawValue>>(raw.get()) {
Ok(_) => None,
Err(_) => Some(Err(e.into())),
}
}
};
res
}
}
/// Dropping a `Subscription` closes its receiving channel via `close`.
impl Drop for Subscription {
fn drop(&mut self) {
self.close();
}
}
/// Limits the number of concurrently held subscription permits.
///
/// Clones share the same semaphore.
#[derive(Debug, Clone)]
pub struct BoundedSubscriptions {
/// Semaphore the permits are acquired from.
guard: Arc<Semaphore>,
/// The limit this instance was created with.
max: u32,
}
impl BoundedSubscriptions {
    /// Creates a limiter that hands out at most `max_subscriptions` permits at a time.
    ///
    /// The number of semaphore permits is capped at [`Semaphore::MAX_PERMITS`]:
    /// `Semaphore::new` panics above that bound, which a large `u32` limit would exceed on
    /// 32-bit targets. [`BoundedSubscriptions::max`] still reports the requested value.
    pub fn new(max_subscriptions: u32) -> Self {
        // `u32 -> usize` is lossless on the 32- and 64-bit targets; the clamp only takes effect
        // when the requested limit is larger than what the semaphore can represent.
        let permits = (max_subscriptions as usize).min(Semaphore::MAX_PERMITS);
        Self { guard: Arc::new(Semaphore::new(permits)), max: max_subscriptions }
    }

    /// Tries to take one permit without waiting.
    ///
    /// Returns `None` when no permit could be acquired.
    pub fn acquire(&self) -> Option<SubscriptionPermit> {
        Arc::clone(&self.guard).try_acquire_owned().ok()
    }

    /// Returns the limit this instance was created with.
    pub const fn max(&self) -> u32 {
        self.max
    }
}
/// Per-call state bundled together for a subscription.
///
/// NOTE(review): presumably handed to the subscription method callback when a subscribe call
/// is dispatched; the consumer is outside this file — confirm there.
#[derive(Debug)]
pub struct SubscriptionState<'a> {
/// ID of the connection the call arrived on.
pub conn_id: ConnectionId,
/// Borrowed ID provider.
pub id_provider: &'a dyn IdProvider,
/// Permit reserved for the subscription.
pub subscription_permit: SubscriptionPermit,
}
/// Turns a [`SubscriptionMessage`] into the serialized notification to put on the wire.
///
/// A `Complete` message is returned as is. A `NeedsData` message only holds the `result`, so it
/// is wrapped in a subscription notification for `method` and `sub_id` and serialized.
///
/// # Panics
///
/// Panics if serializing the wrapped notification fails.
pub(crate) fn sub_message_to_json(msg: SubscriptionMessage, sub_id: &SubscriptionId, method: &str) -> Box<RawValue> {
    let result = match msg.0 {
        SubscriptionMessageInner::Complete(json) => return json,
        SubscriptionMessageInner::NeedsData(result) => result,
    };
    let payload = SubscriptionPayload { subscription: sub_id.clone(), result };
    let notification = SubscriptionResponse::new(method.into(), payload);
    serde_json::value::to_raw_value(&notification).expect("Serialize infallible; qed")
}
pub(crate) fn sub_err_to_json(error: SubscriptionError, sub_id: SubscriptionId, method: &str) -> Box<RawValue> {
serde_json::value::to_raw_value(&jsonrpsee_types::response::SubscriptionError::new(
method.into(),
SubscriptionPayloadError { subscription: sub_id, error },
))
.expect("Serialize infallible; qed")
}