use ahash::HashMap;
use jsonrpsee::{
ConnectionId, MethodResponse, MethodSink,
server::{
IntoSubscriptionCloseResponse, MethodCallback, Methods, RegisterMethodError,
ResponsePayload,
},
types::{ErrorObjectOwned, Id, Params, error::ErrorCode},
};
use parking_lot::Mutex;
use serde_json::value::{RawValue, to_raw_value};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::{mpsc, oneshot};
use super::error::ServerError;
/// JSON-RPC method name of the notifications that carry values on an open channel.
pub const NOTIF_METHOD_NAME: &str = "xrpc.ch.val";

/// JSON-RPC method name a client calls to cancel an open channel.
pub const CANCEL_METHOD_NAME: &str = "xrpc.cancel";

/// Server-assigned identifier of a channel.
pub type ChannelId = u64;
/// Shared registry of open channels, keyed by `(connection, id of the subscribe request)`.
///
/// Each value holds the connection's [`MethodSink`], the receiving half of the channel's
/// unsubscribe signal (dropping it completes `IsUnsubscribed::unsubscribed`), and the
/// channel's [`ChannelId`].
pub type Subscribers = Arc<
    Mutex<
        HashMap<
            (ConnectionId, Id<'static>),
            (MethodSink, mpsc::Receiver<()>, ChannelId),
        >,
    >,
>;
/// A channel subscription that has been requested but not yet accepted.
///
/// One is created for every call of a method registered through `register_channel_raw`.
/// [`PendingSubscriptionSink::accept`] answers the request and turns it into a
/// [`SubscriptionSink`]. If it is dropped without being accepted, the `subscribe` sender is
/// dropped and the waiting method call resolves to an internal-error response.
///
/// NOTE(review): the `must_use` message mentions `reject`, but no `reject` method is visible
/// in this file.
#[derive(Debug)]
#[must_use = "PendingSubscriptionSink does nothing unless `accept` or `reject` is called"]
pub struct PendingSubscriptionSink {
    /// Sink for messages to the client's connection.
    pub(crate) inner: MethodSink,
    /// Method name used for notifications on this channel (`register_channel_raw` sets it to
    /// `NOTIF_METHOD_NAME`).
    pub(crate) method: &'static str,
    /// Shared channel registry that the channel is added to on successful acceptance.
    pub(crate) subscribers: Subscribers,
    /// Id of the subscribe request. The subscription response answers it, and together with
    /// `connection_id` it forms the registry key that `xrpc.cancel` looks up.
    pub(crate) id: Id<'static>,
    /// Delivers the subscription response to the future returned by the method callback.
    pub(crate) subscribe: oneshot::Sender<MethodResponse>,
    /// Server-assigned identifier of the channel being opened.
    pub(crate) channel_id: ChannelId,
    /// Connection the subscribe request arrived on.
    pub(crate) connection_id: ConnectionId,
}
impl PendingSubscriptionSink {
    /// Accepts the pending channel subscription.
    ///
    /// Builds a subscription response whose result is this channel's [`ChannelId`], then sends
    /// it both to the client and to the method-call future waiting on `subscribe`. On success it
    /// returns a [`SubscriptionSink`] for pushing channel messages.
    ///
    /// When the response is a success, the channel is registered in the shared subscriber map
    /// under `(connection_id, id)`. Removing that entry, as `xrpc.cancel` does, completes
    /// [`SubscriptionSink::closed`].
    ///
    /// # Errors
    ///
    /// Returns an error message, instead of panicking, when:
    /// - the response cannot be written to the connection,
    /// - the method-call future waiting for the response has gone away, or
    /// - the subscription response exceeded the sink's `max_response_size`. In this case the
    ///   resulting error response has already been delivered.
    ///
    /// In the first two cases, any registration made for this channel is rolled back.
    pub async fn accept(self) -> Result<SubscriptionSink, String> {
        let Self {
            inner,
            method,
            subscribers,
            id,
            subscribe,
            channel_id,
            connection_id,
        } = self;

        let response = MethodResponse::subscription_response(
            id.clone(),
            ResponsePayload::success_borrowed(&channel_id),
            inner.max_response_size() as usize,
        );
        let success = response.is_success();
        let key = (connection_id, id);

        // Register *before* the response goes out. The client already knows the request id
        // (the cancel key), so it may send `xrpc.cancel` as soon as it sees the response.
        // Registering afterwards would let such a cancel miss the entry and leave the channel
        // running until the connection closes.
        let unsubscribe = success.then(|| {
            let (tx, rx) = mpsc::channel(1);
            subscribers
                .lock()
                .insert(key.clone(), (inner.clone(), rx, channel_id));
            tx
        });

        // The same response is sent to the transport and to the waiting method-call future.
        let delivered = match inner.send(response.to_json()).await {
            Ok(_) => subscribe
                .send(response)
                .map_err(|e| format!("accept error: {}", e.as_json())),
            Err(e) => Err(e.to_string()),
        };

        if let Err(e) = delivered {
            if unsubscribe.is_some() {
                // Roll back our registration, but only if the entry is still ours. A newer
                // subscription reusing the same request id would have overwritten it.
                let mut map = subscribers.lock();
                if map
                    .get(&key)
                    .is_some_and(|(_, _, registered)| *registered == channel_id)
                {
                    map.remove(&key);
                }
            }
            return Err(e);
        }

        let Some(tx) = unsubscribe else {
            return Err(
                "The subscription response was too big; adjust the `max_response_size` or change Subscription ID generation"
                    .to_string(),
            );
        };

        tracing::debug!(
            "Accepting subscription (conn_id={}, chann_id={})",
            connection_id.0,
            channel_id
        );
        Ok(SubscriptionSink {
            inner,
            method,
            unsubscribe: IsUnsubscribed(tx),
            channel_id,
        })
    }

    /// Server-assigned identifier of the channel this subscription would open.
    pub fn channel_id(&self) -> ChannelId {
        self.channel_id
    }
}
/// Unsubscribe signal held by a [`SubscriptionSink`].
///
/// Wraps the sending half of an `mpsc` channel. `accept` stores the receiving half in the
/// shared subscriber map, so dropping that entry (for example on `xrpc.cancel`) completes
/// [`IsUnsubscribed::unsubscribed`].
#[derive(Debug, Clone)]
pub struct IsUnsubscribed(mpsc::Sender<()>);

impl IsUnsubscribed {
    /// Waits until the paired receiver has been dropped.
    pub async fn unsubscribed(&self) {
        let Self(sender) = self;
        sender.closed().await
    }
}
/// Handle for pushing messages on an accepted channel.
///
/// Obtained from [`PendingSubscriptionSink::accept`].
#[derive(Debug, Clone)]
pub struct SubscriptionSink {
    /// Sink for messages to the client's connection.
    inner: MethodSink,
    /// Method name used when building notifications for this channel.
    method: &'static str,
    /// Completes once this channel's entry, and with it the receiver, is dropped from the registry.
    unsubscribe: IsUnsubscribed,
    /// Server-assigned identifier of this channel.
    channel_id: ChannelId,
}
impl SubscriptionSink {
    /// Method name used for notifications on this channel.
    pub fn method_name(&self) -> &str {
        self.method
    }

    /// Server-assigned identifier of this channel.
    pub fn channel_id(&self) -> ChannelId {
        self.channel_id
    }

    /// Sends a raw JSON message to the client.
    ///
    /// # Errors
    ///
    /// If the underlying sink already reports itself closed, nothing is sent and a
    /// `"disconnect error: …"` message is returned. If the send itself fails, the error's
    /// text is returned.
    pub async fn send(&self, msg: Box<serde_json::value::RawValue>) -> Result<(), String> {
        if self.is_closed() {
            Err(format!("disconnect error: {msg}"))
        } else {
            self.inner.send(msg).await.map_err(|err| err.to_string())
        }
    }

    /// Whether the underlying method sink reports itself closed.
    pub fn is_closed(&self) -> bool {
        self.inner.is_closed()
    }

    /// Resolves once either the method sink closes or the channel is unsubscribed.
    pub async fn closed(&self) {
        let sink_closed = self.inner.closed();
        let cancelled = self.unsubscribe.unsubscribed();
        tokio::select! {
            _ = sink_closed => (),
            _ = cancelled => (),
        }
    }
}
fn create_notif_message(
sink: &SubscriptionSink,
result: &impl serde::Serialize,
) -> anyhow::Result<Box<RawValue>> {
let method = sink.method_name();
let channel_id = sink.channel_id();
let result = serde_json::to_value(result)?;
let msg = serde_json::json!({
"jsonrpc": "2.0",
"method": method,
"params": [channel_id, result]
});
tracing::debug!("Sending notification: {}", msg);
Ok(to_raw_value(&msg)?)
}
/// Builds the `xrpc.ch.close` notification telling the client that `channel_id` is closed.
fn close_payload(channel_id: ChannelId) -> serde_json::Value {
    let params = [channel_id];
    serde_json::json!({ "jsonrpc": "2.0", "method": "xrpc.ch.close", "params": params })
}
/// Wraps the close payload for `channel_id` in a successful JSON-RPC response with a `null` id.
fn close_channel_response(channel_id: ChannelId) -> MethodResponse {
    // The close payload is a small fixed-shape object, so a 1 KiB cap is comfortably enough.
    let max_size = 1024;
    let payload = ResponsePayload::success(close_payload(channel_id));
    MethodResponse::response(Id::Null, payload, max_size)
}
/// A set of JSON-RPC methods plus the state needed to serve channel subscriptions.
///
/// `Default` registers the `xrpc.cancel` handler, `register_channel` adds channel-opening
/// methods, and `From` converts the module into [`Methods`].
#[derive(Debug, Clone)]
pub struct RpcModule {
    /// Counter from which every new channel takes its [`ChannelId`].
    id_provider: Arc<AtomicU64>,
    /// Registry of open channels, shared with the cancel handler and subscription callbacks.
    channels: Subscribers,
    /// The registered method callbacks.
    methods: Methods,
}
impl From<RpcModule> for Methods {
fn from(module: RpcModule) -> Methods {
module.methods
}
}
impl Default for RpcModule {
    /// Creates an empty module with the `xrpc.cancel` handler pre-registered.
    fn default() -> Self {
        let channels = Subscribers::default();

        // `xrpc.cancel [request_id]` handler. It removes the channel registered under
        // (this connection, request_id), which drops its unsubscribe receiver, and replies with
        // the channel-close payload. An unknown channel produces a "channel not found" error.
        let cancel_handler = MethodCallback::Unsubscription(Arc::new({
            let channels = channels.clone();
            move |id,
                  params: Params,
                  connection_id: ConnectionId,
                  _max_response,
                  _extensions| {
                let find_and_remove = || -> Result<ChannelId, ServerError> {
                    let [sub_id]: [Id<'_>; 1] = params.parse()?;
                    let sub_id = sub_id.into_owned();
                    tracing::debug!("Got cancel request (id={sub_id})");
                    let removed = channels.lock().remove(&(connection_id, sub_id));
                    removed
                        .map(|(_, _, channel_id)| channel_id)
                        .ok_or_else(|| ServerError::from(anyhow::anyhow!("channel not found")))
                };

                match find_and_remove() {
                    Ok(channel_id) => {
                        let resp = close_channel_response(channel_id);
                        tracing::debug!("Sending close message: {}", resp.as_json());
                        resp
                    }
                    Err(e) => {
                        let error: ErrorObjectOwned = e.into();
                        MethodResponse::error(id, error)
                    }
                }
            }
        }));

        let mut methods = Methods::default();
        methods
            .verify_and_insert(CANCEL_METHOD_NAME, cancel_handler)
            .expect("Inserting a method into an empty methods map is infallible.");

        Self {
            id_provider: Arc::new(AtomicU64::new(0)),
            channels,
            methods,
        }
    }
}
impl RpcModule {
pub fn register_channel<R, F>(
&mut self,
subscribe_method_name: &'static str,
callback: F,
) -> Result<&mut MethodCallback, RegisterMethodError>
where
F: (Fn(Params) -> tokio::sync::broadcast::Receiver<R>) + Send + Sync + 'static,
R: serde::Serialize + Clone + Send + 'static,
{
self.register_channel_raw(subscribe_method_name, {
move |params, pending| {
let mut receiver = callback(params);
tokio::spawn(async move {
let sink = pending.accept().await.unwrap();
tracing::debug!("Channel created: chann_id={}", sink.channel_id);
loop {
tokio::select! {
action = receiver.recv() => {
match action {
Ok(msg) => {
match create_notif_message(&sink, &msg) {
Ok(msg) => {
if let Err(e) = sink.send(msg).await {
tracing::error!("Failed to send message: {:?}", e);
break;
}
}
Err(e) => {
tracing::error!("Failed to serialize channel message: {:?}", e);
break;
}
}
}
Err(RecvError::Closed) => {
if let Ok(payload) = to_raw_value(&close_payload(sink.channel_id())) {
let _ = sink.send(payload).await;
}
break;
}
Err(RecvError::Lagged(_)) => {
}
}
},
_ = sink.closed() => {
break;
}
}
}
tracing::debug!("Send notification task ended (chann_id={})", sink.channel_id);
});
}
})
}
fn register_channel_raw<R, F>(
&mut self,
subscribe_method_name: &'static str,
callback: F,
) -> Result<&mut MethodCallback, RegisterMethodError>
where
F: (Fn(Params, PendingSubscriptionSink) -> R) + Send + Sync + 'static,
R: IntoSubscriptionCloseResponse,
{
self.methods.verify_method_name(subscribe_method_name)?;
let subscribers = self.channels.clone();
self.methods.verify_and_insert(
subscribe_method_name,
MethodCallback::Subscription(Arc::new({
let id_provider = self.id_provider.clone();
move |id, params, method_sink, conn, _extensions| {
let channel_id = id_provider.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel();
let sink = PendingSubscriptionSink {
inner: method_sink.clone(),
method: NOTIF_METHOD_NAME,
subscribers: subscribers.clone(),
id: id.clone().into_owned(),
subscribe: tx,
channel_id,
connection_id: conn.conn_id,
};
callback(params, sink);
let id = id.clone().into_owned();
Box::pin(async move {
match rx.await {
Ok(rp) => rp,
Err(_) => MethodResponse::error(id, ErrorCode::InternalError),
}
})
}
})),
)
}
}