use crate::{rpc::RpcResponse, shared_client::SharedClient, JsonRpcError, LightClientRpcError};
use futures::{stream::StreamExt, FutureExt};
use serde_json::value::RawValue;
use smoldot_light::platform::PlatformRef;
use std::{collections::HashMap, str::FromStr};
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;
/// Target attached to every `tracing` event emitted from this file.
const LOG_TARGET: &str = "subxt-light-client-background-task";
/// Outcome of a plain method call: the raw JSON result on success, or a
/// [`LightClientRpcError`] on failure.
pub type MethodResponse = Result<Box<RawValue>, LightClientRpcError>;
/// Outcome of a subscription call.
///
/// On success this carries the subscription ID together with the receiving half of an
/// unbounded channel; each item on that channel is either a raw JSON notification payload
/// or a [`JsonRpcError`].
pub type SubscriptionResponse = Result<
(SubscriptionId, mpsc::UnboundedReceiver<Result<Box<RawValue>, JsonRpcError>>),
LightClientRpcError,
>;
/// Identifier of a subscription, represented as a string.
pub type SubscriptionId = String;
/// Message sent from a [`BackgroundTaskHandle`] to the background task.
#[derive(Debug)]
enum Message {
/// A plain method call that yields a single response.
Request {
/// Name of the RPC method to call.
method: String,
/// Raw JSON parameters of the call, if any.
params: Option<Box<RawValue>>,
/// One-shot channel on which the outcome of the call is reported.
sender: oneshot::Sender<MethodResponse>,
},
/// A subscription call that yields a subscription ID followed by notifications.
Subscription {
/// Name of the RPC method that starts the subscription.
method: String,
/// Name of the RPC method that cancels the subscription.
unsubscribe_method: String,
/// Raw JSON parameters of the call, if any.
params: Option<Box<RawValue>>,
/// One-shot channel on which the outcome of the subscription request is reported.
sender: oneshot::Sender<SubscriptionResponse>,
},
}
/// Cloneable handle used to submit method calls and subscriptions to the background task.
#[derive(Clone, Debug)]
pub struct BackgroundTaskHandle {
/// Sending half of the unbounded channel whose receiving half is polled by the
/// background task.
to_backend: mpsc::UnboundedSender<Message>,
}
impl BackgroundTaskHandle {
    /// Submits a method call to the background task and waits for its response.
    ///
    /// # Errors
    ///
    /// Yields [`LightClientRpcError::BackgroundTaskDropped`] when the message cannot be
    /// handed to the background task, or when the response channel is closed before a
    /// response arrives. Any error reported by the background task is passed through
    /// unchanged.
    pub async fn request(&self, method: String, params: Option<Box<RawValue>>) -> MethodResponse {
        let (response_tx, response_rx) = oneshot::channel();
        let message = Message::Request { method, params, sender: response_tx };

        if self.to_backend.send(message).is_err() {
            return Err(LightClientRpcError::BackgroundTaskDropped);
        }

        // A closed response channel means the task went away without answering.
        response_rx.await.unwrap_or_else(|_| Err(LightClientRpcError::BackgroundTaskDropped))
    }

    /// Submits a subscription request to the background task and waits for the
    /// subscription ID together with the channel on which notifications are delivered.
    ///
    /// `unsubscribe_method` names the RPC method used to cancel the subscription.
    ///
    /// # Errors
    ///
    /// Yields [`LightClientRpcError::BackgroundTaskDropped`] when the message cannot be
    /// handed to the background task, or when the response channel is closed before a
    /// response arrives. Any error reported by the background task is passed through
    /// unchanged.
    pub async fn subscribe(
        &self,
        method: String,
        params: Option<Box<RawValue>>,
        unsubscribe_method: String,
    ) -> SubscriptionResponse {
        let (response_tx, response_rx) = oneshot::channel();
        let message =
            Message::Subscription { method, params, unsubscribe_method, sender: response_tx };

        if self.to_backend.send(message).is_err() {
            return Err(LightClientRpcError::BackgroundTaskDropped);
        }

        // A closed response channel means the task went away without answering.
        response_rx.await.unwrap_or_else(|_| Err(LightClientRpcError::BackgroundTaskDropped))
    }
}
/// Task that forwards messages received from [`BackgroundTaskHandle`]s to the light client
/// and routes the light client's JSON-RPC responses back to the waiting callers.
///
/// It does nothing until [`BackgroundTask::run`] is polled.
#[allow(clippy::type_complexity)]
pub struct BackgroundTask<TPlatform: PlatformRef, TChain> {
/// The two message sources polled by `run`.
channels: BackgroundTaskChannels<TPlatform>,
/// Client handle plus the bookkeeping of in-flight requests and subscriptions.
data: BackgroundTaskData<TPlatform, TChain>,
}
impl<TPlatform: PlatformRef, TChain> BackgroundTask<TPlatform, TChain> {
    /// Creates a background task for `chain_id` together with the handle used to send
    /// it requests.
    ///
    /// `client` is used to submit JSON-RPC requests and `from_back` is the source of the
    /// corresponding JSON-RPC responses.
    pub(crate) fn new(
        client: SharedClient<TPlatform, TChain>,
        chain_id: smoldot_light::ChainId,
        from_back: smoldot_light::JsonRpcResponses<TPlatform>,
    ) -> (BackgroundTask<TPlatform, TChain>, BackgroundTaskHandle) {
        let (to_backend, from_front) = mpsc::unbounded_channel();

        let channels = BackgroundTaskChannels {
            from_front: UnboundedReceiverStream::new(from_front),
            from_back,
        };
        let data = BackgroundTaskData {
            client,
            chain_id,
            last_request_id: 0,
            pending_subscriptions: HashMap::new(),
            requests: HashMap::new(),
            subscriptions: HashMap::new(),
        };

        (BackgroundTask { channels, data }, BackgroundTaskHandle { to_backend })
    }

    /// Drives the task until either message source is exhausted.
    ///
    /// Each iteration waits for whichever comes first: a message from a handle, which is
    /// passed to `handle_requests`, or a response from the light client, which is passed
    /// to `handle_rpc_response`. The loop ends as soon as one of the two sources yields
    /// `None`.
    pub async fn run(self) {
        let BackgroundTask { mut channels, mut data } = self;
        let chain_id = data.chain_id;

        loop {
            // Fresh futures are created and pinned on every iteration.
            tokio::pin! {
                let from_front_fut = channels.from_front.next().fuse();
                let from_back_fut = channels.from_back.next().fuse();
            }

            futures::select! {
                front_message = from_front_fut => {
                    match front_message {
                        Some(message) => {
                            tracing::trace!(
                                target: LOG_TARGET,
                                "Received register message {:?}",
                                message
                            );
                            data.handle_requests(message).await;
                        }
                        None => {
                            tracing::trace!(target: LOG_TARGET, "Subxt channel closed");
                            break;
                        }
                    }
                },
                back_message = from_back_fut => {
                    match back_message {
                        Some(back_message) => {
                            tracing::trace!(
                                target: LOG_TARGET,
                                "Received smoldot RPC chain {chain_id:?} result {}",
                                trim_message(&back_message),
                            );
                            data.handle_rpc_response(back_message);
                        }
                        None => {
                            tracing::trace!(target: LOG_TARGET, "Smoldot RPC responses channel closed");
                            break;
                        }
                    }
                }
            }
        }

        tracing::trace!(target: LOG_TARGET, "Task closed");
    }
}
/// The two message sources the background task waits on.
struct BackgroundTaskChannels<TPlatform: PlatformRef> {
/// Messages submitted through `BackgroundTaskHandle`s.
from_front: UnboundedReceiverStream<Message>,
/// JSON-RPC responses coming back from smoldot.
from_back: smoldot_light::JsonRpcResponses<TPlatform>,
}
/// Client handle and bookkeeping state of the background task.
struct BackgroundTaskData<TPlatform: PlatformRef, TChain> {
/// Client through which JSON-RPC requests are submitted.
client: SharedClient<TPlatform, TChain>,
/// Chain that every request is submitted to.
chain_id: smoldot_light::ChainId,
/// Most recently issued request ID; `next_id` advances it with wrapping addition.
last_request_id: usize,
/// Response channels of in-flight method calls, keyed by request ID.
requests: HashMap<usize, oneshot::Sender<MethodResponse>>,
/// Subscription requests still waiting for their subscription ID, keyed by request ID.
pending_subscriptions: HashMap<usize, PendingSubscription>,
/// Established subscriptions, keyed by subscription ID.
subscriptions: HashMap<String, ActiveSubscription>,
}
/// A subscription request whose subscription ID has not been received yet.
struct PendingSubscription {
/// Channel on which the outcome of the subscription request is reported to the caller.
response_sender: oneshot::Sender<SubscriptionResponse>,
/// RPC method to call in order to cancel the subscription once it is established.
unsubscribe_method: String,
}
/// A subscription for which a subscription ID has been received.
struct ActiveSubscription {
/// Channel on which notifications (or notification errors) are forwarded to the subscriber.
notification_sender: mpsc::UnboundedSender<Result<Box<RawValue>, JsonRpcError>>,
/// RPC method called to cancel the subscription when a notification can no longer be
/// delivered to the subscriber.
unsubscribe_method: String,
}
/// Returns at most the first 512 characters of `s`, for use in log output.
///
/// The cut is made on a character boundary, so the result is always valid UTF-8 and may
/// be longer than 512 bytes when `s` contains multi-byte characters.
fn trim_message(s: &str) -> &str {
    const MAX_SIZE: usize = 512;

    // A string shorter than MAX_SIZE bytes cannot hold more than MAX_SIZE characters,
    // so the character walk is only needed for longer inputs.
    if s.len() >= MAX_SIZE {
        if let Some((cut, _)) = s.char_indices().nth(MAX_SIZE) {
            return &s[..cut];
        }
    }
    s
}
impl<TPlatform: PlatformRef, TChain> BackgroundTaskData<TPlatform, TChain> {
fn next_id(&mut self) -> usize {
self.last_request_id = self.last_request_id.wrapping_add(1);
self.last_request_id
}
async fn handle_requests(&mut self, message: Message) {
match message {
Message::Request { method, params, sender } => {
let id = self.next_id();
let chain_id = self.chain_id;
let params = match ¶ms {
Some(params) => params.get(),
None => "null",
};
let request = format!(
r#"{{"jsonrpc":"2.0","id":"{id}", "method":"{method}","params":{params}}}"#
);
self.requests.insert(id, sender);
tracing::trace!(target: LOG_TARGET, "Tracking request id={id} chain={chain_id:?}");
let result = self.client.json_rpc_request(request, chain_id);
if let Err(err) = result {
tracing::warn!(
target: LOG_TARGET,
"Cannot send RPC request to lightclient {:?}",
err.to_string()
);
let sender = self.requests.remove(&id).expect("Channel is inserted above; qed");
if sender.send(Err(LightClientRpcError::SmoldotError(err.to_string()))).is_err()
{
tracing::warn!(
target: LOG_TARGET,
"Cannot send RPC request error to id={id}",
);
}
} else {
tracing::trace!(target: LOG_TARGET, "Submitted to smoldot request with id={id}");
}
},
Message::Subscription { method, unsubscribe_method, params, sender } => {
let id = self.next_id();
let chain_id = self.chain_id;
let params = match ¶ms {
Some(params) => params.get(),
None => "null",
};
let request = format!(
r#"{{"jsonrpc":"2.0","id":"{id}", "method":"{method}","params":{params}}}"#
);
tracing::trace!(target: LOG_TARGET, "Tracking subscription request id={id} chain={chain_id:?}");
let pending_subscription =
PendingSubscription { response_sender: sender, unsubscribe_method };
self.pending_subscriptions.insert(id, pending_subscription);
let result = self.client.json_rpc_request(request, chain_id);
if let Err(err) = result {
tracing::warn!(
target: LOG_TARGET,
"Cannot send RPC request to lightclient {:?}",
err.to_string()
);
let subscription_id_state = self
.pending_subscriptions
.remove(&id)
.expect("Channels are inserted above; qed");
if subscription_id_state
.response_sender
.send(Err(LightClientRpcError::SmoldotError(err.to_string())))
.is_err()
{
tracing::warn!(
target: LOG_TARGET,
"Cannot send RPC request error to id={id}",
);
}
} else {
tracing::trace!(target: LOG_TARGET, "Submitted to smoldot subscription request with id={id}");
}
},
};
}
fn handle_rpc_response(&mut self, response: String) {
let chain_id = self.chain_id;
tracing::trace!(target: LOG_TARGET, "Received from smoldot response='{}' chain={chain_id:?}", trim_message(&response));
match RpcResponse::from_str(&response) {
Ok(RpcResponse::Method { id, result }) => {
let Ok(id) = id.parse::<usize>() else {
tracing::warn!(target: LOG_TARGET, "Cannot send response. Id={id} chain={chain_id:?} is not a valid number");
return;
};
if let Some(sender) = self.requests.remove(&id) {
if sender.send(Ok(result)).is_err() {
tracing::warn!(
target: LOG_TARGET,
"Cannot send method response to id={id} chain={chain_id:?}",
);
}
} else if let Some(pending_subscription) = self.pending_subscriptions.remove(&id) {
let Ok(sub_id) = serde_json::from_str::<SubscriptionId>(result.get()) else {
tracing::warn!(
target: LOG_TARGET,
"Subscription id='{result}' chain={chain_id:?} is not a valid string",
);
return;
};
tracing::trace!(target: LOG_TARGET, "Received subscription id={sub_id} chain={chain_id:?}");
let (sub_tx, sub_rx) = mpsc::unbounded_channel();
if pending_subscription
.response_sender
.send(Ok((sub_id.clone(), sub_rx)))
.is_err()
{
tracing::warn!(
target: LOG_TARGET,
"Cannot send subscription ID response to id={id} chain={chain_id:?}",
);
return;
}
self.subscriptions.insert(
sub_id,
ActiveSubscription {
notification_sender: sub_tx,
unsubscribe_method: pending_subscription.unsubscribe_method,
},
);
} else {
tracing::warn!(
target: LOG_TARGET,
"Response id={id} chain={chain_id:?} is not tracked",
);
}
},
Ok(RpcResponse::MethodError { id, error }) => {
let Ok(id) = id.parse::<usize>() else {
tracing::warn!(target: LOG_TARGET, "Cannot send error. Id={id} chain={chain_id:?} is not a valid number");
return;
};
if let Some(sender) = self.requests.remove(&id) {
if sender
.send(Err(LightClientRpcError::JsonRpcError(JsonRpcError(error))))
.is_err()
{
tracing::warn!(
target: LOG_TARGET,
"Cannot send method response to id={id} chain={chain_id:?}",
);
}
} else if let Some(subscription_id_state) = self.pending_subscriptions.remove(&id) {
if subscription_id_state
.response_sender
.send(Err(LightClientRpcError::JsonRpcError(JsonRpcError(error))))
.is_err()
{
tracing::warn!(
target: LOG_TARGET,
"Cannot send method response to id {id} chain={chain_id:?}",
);
}
}
},
Ok(RpcResponse::Notification { method, subscription_id, result }) => {
let Some(active_subscription) = self.subscriptions.get_mut(&subscription_id) else {
tracing::warn!(
target: LOG_TARGET,
"Subscription response id={subscription_id} chain={chain_id:?} method={method} is not tracked",
);
return;
};
if active_subscription.notification_sender.send(Ok(result)).is_err() {
self.unsubscribe(&subscription_id, chain_id);
}
},
Ok(RpcResponse::NotificationError { method, subscription_id, error }) => {
let Some(active_subscription) = self.subscriptions.get_mut(&subscription_id) else {
tracing::warn!(
target: LOG_TARGET,
"Subscription error id={subscription_id} chain={chain_id:?} method={method} is not tracked",
);
return;
};
if active_subscription.notification_sender.send(Err(JsonRpcError(error))).is_err() {
self.unsubscribe(&subscription_id, chain_id);
}
},
Err(err) => {
tracing::warn!(target: LOG_TARGET, "cannot decode RPC response {:?}", err);
},
}
}
fn unsubscribe(&mut self, subscription_id: &str, chain_id: smoldot_light::ChainId) {
let Some(active_subscription) = self.subscriptions.remove(subscription_id) else {
return;
};
let unsub_id = self.next_id();
let request = format!(
r#"{{"jsonrpc":"2.0","id":"{}", "method":"{}","params":["{}"]}}"#,
unsub_id, active_subscription.unsubscribe_method, subscription_id
);
if let Err(err) = self.client.json_rpc_request(request, chain_id) {
tracing::warn!(
target: LOG_TARGET,
"Failed to unsubscribe id={subscription_id} chain={chain_id:?} method={:?} err={err:?}", active_subscription.unsubscribe_method
);
} else {
tracing::debug!(target: LOG_TARGET,"Unsubscribe id={subscription_id} chain={chain_id:?} method={:?}", active_subscription.unsubscribe_method);
}
}
}