use crate::{packet::*, service::*};
use async_lock::Mutex;
use downcast::{AnySync, DowncastSync};
use futures::{
channel::mpsc::{channel, Receiver, Sender},
StreamExt,
};
use log_error::LogError;
use parking_lot::RwLock;
use rmp::encode;
use serde::{Deserializer, Serialize, Serializer};
use std::future::Future;
#[cfg(not(target_arch = "wasm32"))]
use std::io;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::{collections::HashMap, sync::Arc};
/// One RPC endpoint: pairs a transport [`Adapter`] with a local [`Service`] and tracks
/// outgoing requests that are still waiting for replies.
pub struct Session {
    /// Transport used to send and receive encoded packets.
    pub adapter: Arc<dyn Adapter>,
    /// Handler invoked for incoming notifications and requests.
    pub service: Arc<dyn Service>,
    /// When true (the default), `response_error` logs non-empty error messages.
    pub log_error: AtomicBool,
    /// Taken with `try_lock` while a task reads from the adapter, so only one task
    /// receives at a time; the others back off or wait on their own channel.
    recving: Mutex<()>,
    /// Source of request ids; starts at 1 and is advanced by `next_id`.
    id_counter: AtomicU32,
    /// Debug builds only: method recorded for each outgoing request id.
    /// NOTE(review): entries are inserted in this file but never removed here — confirm
    /// whether another module cleans them up.
    #[doc(hidden)]
    #[cfg(debug_assertions)]
    pub pending_method: RwLock<HashMap<u32, Method>>,
    /// Channels awaiting response / stream packets, keyed by request id.
    pending: RwLock<HashMap<u32, Sender<Packet>>>,
    /// Runtime handle captured at construction (if one was current); used by the
    /// blocking `*_sync` helpers.
    #[cfg(not(target_arch = "wasm32"))]
    pub tokio: Option<Box<tokio::runtime::Handle>>,
}
impl Session {
pub fn from<A: Adapter, S: Service>(adapter: A, service: S) -> Self {
Self::new(Arc::new(adapter), Arc::new(service))
}
pub fn new(adapter: Arc<dyn Adapter>, service: Arc<dyn Service>) -> Self {
Self {
adapter,
service,
log_error: true.into(),
recving: Mutex::new(()),
id_counter: AtomicU32::new(1),
#[cfg(debug_assertions)]
pending_method: RwLock::new(Default::default()),
pending: RwLock::new(Default::default()),
#[cfg(not(target_arch = "wasm32"))]
tokio: tokio::runtime::Handle::try_current().ok().map(Into::into),
}
}
#[inline]
pub fn downcast_adapter<A: Adapter>(&self) -> Option<Arc<A>> {
self.adapter.clone().downcast_arc().ok()
}
#[cfg(not(target_arch = "wasm32"))]
#[inline]
pub fn downcast_service<S: Service>(&self) -> Option<Arc<S>> {
self.service.clone().downcast_arc().ok()
}
#[cfg(not(target_arch = "wasm32"))]
#[inline]
pub fn mapped_service<T: Send + Sync + 'static>(&self) -> Arc<MappedService<T>> {
self.downcast_service().expect("MappedService")
}
}
impl Session {
/// Hands one fully encoded packet to the adapter.
#[inline(always)]
async fn send_pack(&self, pack: Vec<u8>) -> anyhow::Result<()> {
    let adapter = &self.adapter;
    adapter.send(pack).await
}
/// Sends a notification `[NOTIFY, method, args]`; no reply is expected.
#[inline(always)]
pub async fn notify<A: Serialize, M: Serialize>(
    &self,
    method: M,
    args: A,
) -> anyhow::Result<()> {
    let message = (NOTIFY, method, args);
    let encoded = rmp_serde::to_vec_named(&message)?;
    self.send_pack(encoded).await
}
/// Sends a notification whose method and args are transcoded directly from deserializers.
///
/// The frame `[NOTIFY, <method>, <args>]` is written by hand: array header and kind
/// first, then the two transcoded values.
pub async fn notify_de<'de, A: Deserializer<'de>, M: Deserializer<'de>>(
    &self,
    method: M,
    args: A,
) -> anyhow::Result<()> {
    let mut buf = Vec::with_capacity(128);
    encode::write_array_len(&mut buf, 3)?;
    encode::write_u32(&mut buf, NOTIFY)?;
    {
        let mut serializer = rmp_serde::Serializer::new(&mut buf);
        serde_transcode::transcode(method, &mut serializer)?;
        serde_transcode::transcode(args, &mut serializer)?;
    }
    self.send_pack(buf).await
}
/// Blocking variant of [`notify`](Self::notify) that runs on the captured runtime handle.
///
/// # Errors
/// Returns an `Unsupported` I/O error when no runtime handle was captured.
#[cfg(not(target_arch = "wasm32"))]
pub fn notify_sync<A: Serialize, M: Serialize>(
    &self,
    method: M,
    args: A,
) -> anyhow::Result<()> {
    match self.tokio.as_ref() {
        Some(handle) => handle.block_on(self.notify(method, args)),
        None => Err(io::Error::from(io::ErrorKind::Unsupported).into()),
    }
}
/// Sends `[REQUEST, id, method, args]`, waits for the matching reply and decodes it as `R`.
pub async fn request<A: Serialize, M: Serialize, R: FromPacket>(
    self: &Arc<Self>,
    method: M,
    args: A,
) -> anyhow::Result<R> {
    let id = self.next_id();
    #[cfg(debug_assertions)]
    {
        let recorded = serde_value::to_value(&method)?;
        self.pending_method.write().insert(id, recorded);
    }
    let pack = rmp_serde::to_vec_named(&(REQUEST, id, method, args))?;
    let reply = self.send_and_wait_response(id, pack).await?;
    R::from_packet(reply)
}
/// Sends a request and returns a [`StreamReceiver`] for the packets answering it.
pub async fn request_recver<A: Serialize, M: Serialize>(
    self: &Arc<Self>,
    method: M,
    args: A,
) -> anyhow::Result<StreamReceiver> {
    let id = self.next_id();
    #[cfg(debug_assertions)]
    {
        let recorded = serde_value::to_value(&method)?;
        self.pending_method.write().insert(id, recorded);
    }
    let pack = rmp_serde::to_vec_named(&(REQUEST, id, method, args))?;
    self.send_and_receive(id, pack).await
}
/// Sends a request whose method/args are transcoded from deserializers and transcodes the
/// reply payload into `result`.
///
/// # Errors
/// Fails if encoding or sending fails, if the reply carries an error string, or if the
/// reply payload cannot be decoded into `result`.
pub async fn request_de<'de, A: Deserializer<'de>, M: Deserializer<'de>, R: Serializer>(
    self: &Arc<Self>,
    method: M,
    args: A,
    result: R,
) -> anyhow::Result<R::Ok> {
    // Hand-written frame: [REQUEST, id, <method>, <args>].
    let mut buf = Vec::with_capacity(128);
    encode::write_array_len(&mut buf, 4)?;
    encode::write_u32(&mut buf, REQUEST)?;
    let id = self.next_id();
    encode::write_u32(&mut buf, id)?;
    {
        let mut serializer = rmp_serde::Serializer::new(&mut buf);
        serde_transcode::transcode(method, &mut serializer)?;
        serde_transcode::transcode(args, &mut serializer)?;
    }
    let reply = self.send_and_wait_response(id, buf).await?;
    match reply.error() {
        Some(err) => Err(anyhow::anyhow!("{err}")),
        None => reply
            .decode_to(result)
            .map_err(|err| anyhow::anyhow!("{err:?}")),
    }
}
/// Like [`request_recver`](Self::request_recver), but method/args are transcoded from
/// deserializers into a hand-written `[REQUEST, id, <method>, <args>]` frame.
pub async fn request_recver_de<'de, A: Deserializer<'de>, M: Deserializer<'de>>(
    self: &Arc<Self>,
    method: M,
    args: A,
) -> anyhow::Result<StreamReceiver> {
    let mut buf = Vec::with_capacity(128);
    encode::write_array_len(&mut buf, 4)?;
    encode::write_u32(&mut buf, REQUEST)?;
    let id = self.next_id();
    encode::write_u32(&mut buf, id)?;
    {
        let mut serializer = rmp_serde::Serializer::new(&mut buf);
        serde_transcode::transcode(method, &mut serializer)?;
        serde_transcode::transcode(args, &mut serializer)?;
    }
    self.send_and_receive(id, buf).await
}
/// Blocking variant of [`request`](Self::request) that runs on the captured runtime handle.
///
/// # Errors
/// Returns an `Unsupported` I/O error when no runtime handle was captured.
#[cfg(not(target_arch = "wasm32"))]
pub fn request_sync<A: Serialize, M: Serialize, R: FromPacket>(
    self: &Arc<Self>,
    method: M,
    args: A,
) -> anyhow::Result<R> {
    match self.tokio.as_ref() {
        Some(handle) => handle.block_on(self.request(method, args)),
        None => Err(io::Error::from(io::ErrorKind::Unsupported).into()),
    }
}
/// [`request`](Self::request) returning the raw reply [`Packet`] without decoding.
#[inline(always)]
pub async fn request_packet<A: Serialize, M: Serialize>(
    self: &Arc<Self>,
    method: M,
    args: A,
) -> anyhow::Result<Packet> {
    self.request::<A, M, Packet>(method, args).await
}
/// Sends a request and runs `callback` with the reply on a spawned task, without waiting.
///
/// The reply channel is registered *before* the request is sent. `handle_packet` ignores
/// responses whose id is not in `pending`, so registering afterwards could lose a fast
/// reply. If the send fails, the registration is removed again and the error is returned.
/// If the channel closes without a reply, the callback is simply not run.
pub async fn request_with<
    A: Serialize,
    M: Serialize,
    FUT: Future<Output = ()> + NeedSend,
    F: FnOnce(Packet) -> FUT + Send + 'static,
>(
    &self,
    method: M,
    args: A,
    callback: F,
) -> anyhow::Result<()> {
    let req_id = self.next_id();
    #[cfg(debug_assertions)]
    self.pending_method
        .write()
        .insert(req_id, serde_value::to_value(&method)?);
    let pack = rmp_serde::to_vec_named(&(REQUEST, req_id, method, args))?;
    let (sender, mut recver) = channel(2);
    self.pending.write().insert(req_id, sender);
    if let Err(err) = self.send_pack(pack).await {
        // Nothing will ever answer this id; don't leave the channel behind.
        self.pending.write().remove(&req_id);
        return Err(err);
    }
    let task = async move {
        // `None` means the sender was dropped without delivering a reply; skip the
        // callback instead of panicking inside the spawned task.
        if let Some(pack) = recver.next().await {
            callback(pack).await;
        }
    };
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_futures::spawn_local(task);
    #[cfg(not(target_arch = "wasm32"))]
    tokio::spawn(task);
    Ok(())
}
pub async fn loop_dispatch(self: &Arc<Self>) -> anyhow::Result<()> {
loop {
if let Err(err) = self.dispatch().await {
if let Some(_) = err.downcast_ref::<crate::error::Disconnect>() {
break Ok(());
} else {
break Err(err);
}
}
}
}
/// Receives and handles at most one packet.
///
/// If another task currently holds the receive lock, this logs at debug level and sleeps
/// for 1 ms instead. Receive/decode errors are returned.
pub async fn dispatch(self: &Arc<Self>) -> anyhow::Result<()> {
    use std::time::Duration;
    match self.try_recv().await {
        Some(received) => {
            let pack = received?;
            self.handle_packet(pack).await;
        }
        None => {
            log::debug!("dispatch warning: recv lock failed");
            let pause = Duration::from_millis(1);
            #[cfg(target_arch = "wasm32")]
            gloo_timers::future::sleep(pause).await;
            #[cfg(not(target_arch = "wasm32"))]
            tokio::time::sleep(pause).await;
        }
    }
    Ok(())
}
/// Blocking variant of [`dispatch`](Self::dispatch) on the captured runtime handle.
///
/// # Errors
/// Returns an `Unsupported` I/O error when no runtime handle was captured.
#[cfg(not(target_arch = "wasm32"))]
pub fn dispatch_sync(self: &Arc<Self>) -> anyhow::Result<()> {
    match self.tokio.as_ref() {
        Some(handle) => handle.block_on(self.dispatch()),
        None => Err(io::Error::from(io::ErrorKind::Unsupported).into()),
    }
}
/// Reads and decodes one packet from the adapter.
///
/// Returns `None` immediately if another task already holds the receive lock. Otherwise
/// the lock is held for the duration of the read and decode.
pub async fn try_recv(&self) -> Option<anyhow::Result<Packet>> {
    let _receiving = self.recving.try_lock()?;
    let outcome = match self.adapter.recv().await {
        Ok(raw) => Packet::from_pack(raw),
        Err(err) => Err(err),
    };
    Some(outcome)
}
/// Blocking variant of [`try_recv`](Self::try_recv).
///
/// Returns `None` if no runtime handle was captured.
#[cfg(not(target_arch = "wasm32"))]
pub fn try_recv_sync(&self) -> Option<anyhow::Result<Packet>> {
    let handle = self.tokio.as_ref()?;
    handle.block_on(self.try_recv())
}
/// Allocates a request id and encodes the prefix `[REQUEST, id, <method>, ...`.
///
/// The caller appends the already-encoded args to complete the 4-element array.
#[inline(always)]
fn prepare_request<M: Serialize>(&self, method: &M) -> anyhow::Result<(u32, Vec<u8>)> {
    let id = self.next_id();
    let mut buf = Vec::<u8>::with_capacity(0x30);
    encode::write_array_len(&mut buf, 4)?;
    encode::write_u32(&mut buf, REQUEST)?;
    encode::write_u32(&mut buf, id)?;
    rmp_serde::encode::write_named(&mut buf, method)?;
    Ok((id, buf))
}
/// Encodes a complete response frame `[RESPONSE, req_id, error, data]`.
///
/// With an error, the error string is written and the payload slot is nil (`data` is
/// ignored). Without one, the error slot is nil and the pre-encoded `data` bytes follow.
fn prepare_response(
    &self,
    req_id: u32,
    error: Option<&str>,
    data: &[u8],
) -> anyhow::Result<Vec<u8>> {
    let mut buf = Vec::<u8>::new();
    encode::write_array_len(&mut buf, 4)?;
    encode::write_u32(&mut buf, RESPONSE)?;
    encode::write_u32(&mut buf, req_id)?;
    match error {
        None => {
            encode::write_nil(&mut buf)?;
            buf.extend_from_slice(data);
        }
        Some(message) => {
            encode::write_str(&mut buf, message)?;
            encode::write_nil(&mut buf)?;
        }
    }
    Ok(buf)
}
/// Encodes the prefix `[NOTIFY, <method>, ...`; the caller appends the encoded args.
#[inline(always)]
fn prepare_notify<M: Serialize>(&self, method: &M) -> anyhow::Result<Vec<u8>> {
    let mut buf = Vec::<u8>::new();
    encode::write_array_len(&mut buf, 3)?;
    encode::write_u32(&mut buf, NOTIFY)?;
    rmp_serde::encode::write_named(&mut buf, method)?;
    Ok(buf)
}
/// Encodes the prefix `[RESPONSE_STREAM, req_id, end, ...`; the caller appends the item.
#[inline(always)]
fn prepare_stream(&self, req_id: u32, end: bool) -> anyhow::Result<Vec<u8>> {
    let mut buf = Vec::<u8>::new();
    encode::write_array_len(&mut buf, 4)?;
    encode::write_u32(&mut buf, RESPONSE_STREAM)?;
    encode::write_u32(&mut buf, req_id)?;
    encode::write_bool(&mut buf, end)?;
    Ok(buf)
}
/// Forwards a packet received on `from` out through this session.
///
/// - Notifications and stream items are re-framed and sent as-is.
/// - Requests are re-issued under a fresh id on this session. The reply (or a
///   `"transfer: ..."` error) is then sent back on `from` under the original id.
///
/// # Panics
/// Panics if given a response packet.
pub async fn transfer_packet(
    self: &Arc<Self>,
    from: &Self,
    packet: &Packet,
) -> anyhow::Result<()> {
    match &packet.meta {
        PackMeta::Notify { method } => {
            let mut forwarded = self.prepare_notify(method)?;
            forwarded.extend_from_slice(packet.data());
            self.send_pack(forwarded).await
        }
        PackMeta::Request { method, req } => {
            let origin_id = *req;
            let (relay_id, mut forwarded) = self.prepare_request(method)?;
            forwarded.extend_from_slice(packet.data());
            let reply = match self.send_and_wait_response(relay_id, forwarded).await {
                Ok(p) => from.prepare_response(origin_id, p.error(), p.data()),
                Err(err) => {
                    from.prepare_response(origin_id, Some(&format!("transfer: {err:?}")), &[])
                }
            }?;
            from.send_pack(reply).await
        }
        PackMeta::Response { .. } => unreachable!("can not transfer a response"),
        PackMeta::Stream { req, finished: true } => self.stream_packet_end(*req).await,
        PackMeta::Stream { req, finished: false } => {
            let mut forwarded = self.prepare_stream(*req, false)?;
            forwarded.extend_from_slice(packet.data());
            self.send_pack(forwarded).await
        }
    }
}
/// Sends a successful reply `[RESPONSE, req_id, (), args]`; the unit sits in the error slot.
#[inline(always)]
pub async fn response(&self, req_id: u32, args: impl Serialize) -> anyhow::Result<()> {
    let frame = (RESPONSE, req_id, (), args);
    let encoded = rmp_serde::to_vec_named(&frame)?;
    self.send_pack(encoded).await
}
/// Sends a successful reply whose payload is transcoded directly from a deserializer.
///
/// The frame matches the other response encoders (`response`, `prepare_response`):
/// a 4-element array `[RESPONSE, req_id, nil, <args>]`. The nil fills the error slot.
/// Previously the error slot was omitted, leaving a 3-element array whose payload
/// landed in the error position.
pub async fn response_de<'de, A: Deserializer<'de>>(
    &self,
    req_id: u32,
    args: A,
) -> anyhow::Result<()> {
    let mut data = Vec::with_capacity(128);
    encode::write_array_len(&mut data, 4)?;
    encode::write_u32(&mut data, RESPONSE)?;
    encode::write_u32(&mut data, req_id)?;
    // No error: nil in the error slot, exactly as `prepare_response` does.
    encode::write_nil(&mut data)?;
    {
        let mut ser = rmp_serde::Serializer::new(&mut data);
        serde_transcode::transcode(args, &mut ser)?;
    }
    self.send_pack(data).await
}
/// Blocking variant of [`response`](Self::response) on the captured runtime handle.
///
/// # Errors
/// Returns an `Unsupported` I/O error when no runtime handle was captured.
#[cfg(not(target_arch = "wasm32"))]
pub fn response_sync(&self, req_id: u32, args: impl Serialize) -> anyhow::Result<()> {
    match self.tokio.as_ref() {
        Some(handle) => handle.block_on(self.response(req_id, args)),
        None => Err(io::Error::from(io::ErrorKind::Unsupported).into()),
    }
}
/// Sends an error reply `[RESPONSE, req_id, err, ()]`.
///
/// Non-empty messages are also logged at error level when `log_error` is enabled.
pub async fn response_error(&self, req_id: u32, err: impl AsRef<str>) -> anyhow::Result<()> {
    let err: &str = err.as_ref();
    let should_log = !err.is_empty() && self.log_error.load(Ordering::SeqCst);
    if should_log {
        log::error!("response#{req_id}: {err}")
    }
    let frame = (RESPONSE, req_id, err, ());
    let encoded = rmp_serde::to_vec_named(&frame)?;
    self.send_pack(encoded).await
}
/// Sends one stream item `[RESPONSE_STREAM, req_id, false, args]`.
pub async fn stream_packet(&self, req_id: u32, args: impl Serialize) -> anyhow::Result<()> {
    let finished = false;
    let frame = (RESPONSE_STREAM, req_id, finished, args);
    let encoded = rmp_serde::to_vec_named(&frame)?;
    self.send_pack(encoded).await
}
/// Sends the end-of-stream marker `[RESPONSE_STREAM, req_id, true, ()]`.
pub async fn stream_packet_end(&self, req_id: u32) -> anyhow::Result<()> {
    let finished = true;
    let frame = (RESPONSE_STREAM, req_id, finished, ());
    let encoded = rmp_serde::to_vec_named(&frame)?;
    self.send_pack(encoded).await
}
/// Blocking variant of [`response_error`](Self::response_error) on the captured handle.
///
/// # Errors
/// Returns an `Unsupported` I/O error when no runtime handle was captured.
#[cfg(not(target_arch = "wasm32"))]
pub fn response_error_sync(&self, req_id: u32, err: impl AsRef<str>) -> anyhow::Result<()> {
    match self.tokio.as_ref() {
        Some(handle) => handle.block_on(self.response_error(req_id, err)),
        None => Err(io::Error::from(io::ErrorKind::Unsupported).into()),
    }
}
/// Returns the current counter value and advances it.
///
/// The counter starts at 1; `fetch_add` wraps on overflow.
fn next_id(&self) -> u32 {
    let counter = &self.id_counter;
    counter.fetch_add(1, Ordering::SeqCst)
}
async fn handle_packet(self: &Arc<Self>, pack: Packet) {
match &pack.meta {
PackMeta::Notify { .. } => {
let m = pack.method().cloned();
if let Err(err) = self.service.handle(self, pack).await {
log::error!("[urpc] handle {m:?}: {err:?}");
}
}
PackMeta::Request { req, .. } => {
let req_id = *req;
if let Err(err) = self.service.handle(self, pack).await {
self.response_error(req_id, &format!("{err:?}"))
.await
.log_error("response error");
}
}
PackMeta::Response { req, .. } => {
self.pending
.write()
.remove(req)
.map(|mut sender| sender.try_send(pack).map_err(|_| panic!("")));
}
PackMeta::Stream { req, finished } => {
let mut pending = self.pending.write();
if *finished {
pending
.remove(req)
.map(|mut sender| sender.try_send(pack).map_err(|_| panic!("")));
} else {
pending
.get_mut(req)
.map(|sender| sender.try_send(pack).map_err(|_| panic!("")));
}
}
};
}
/// Registers a reply channel for `req_id`, sends `pack`, and returns the receiving half.
///
/// The channel is registered before sending so a fast reply finds it. If the send fails,
/// the registration is removed again. Otherwise the sender would stay in `pending`
/// forever, since nothing will ever answer that id.
async fn send_and_receive(
    self: &Arc<Self>,
    req_id: u32,
    pack: Vec<u8>,
) -> anyhow::Result<StreamReceiver> {
    let (sender, recver) = channel(2);
    self.pending.write().insert(req_id, sender);
    if let Err(err) = self.send_pack(pack).await {
        self.pending.write().remove(&req_id);
        return Err(err);
    }
    Ok(StreamReceiver {
        session: self.clone(),
        recver,
    })
}
/// Sends `pack` and waits for the first packet routed back to `req_id`.
///
/// # Errors
/// Fails if sending or receiving fails, or with "no response received" if the channel
/// closes without delivering anything.
async fn send_and_wait_response(
    self: &Arc<Self>,
    req_id: u32,
    pack: Vec<u8>,
) -> anyhow::Result<Packet> {
    let mut receiver = self.send_and_receive(req_id, pack).await?;
    match receiver.recv().await? {
        Some(packet) => Ok(packet),
        None => Err(anyhow::anyhow!("no response received")),
    }
}
}
/// Transport underneath a [`Session`]: moves whole encoded packets in both directions.
#[async_trait::async_trait]
pub trait Adapter: Send + Sync + AnySync {
    /// Sends one complete packet; the session passes MessagePack bytes it encoded with
    /// `rmp` / `rmp_serde`.
    async fn send(&self, pack: Vec<u8>) -> anyhow::Result<()>;
    /// Receives the next complete packet as raw bytes; the session decodes it with
    /// `Packet::from_pack`.
    ///
    /// Errors propagate out of `Session::dispatch` and `StreamReceiver::recv`. An error
    /// that downcasts to `crate::error::Disconnect` makes `Session::loop_dispatch`
    /// return `Ok(())`.
    async fn recv(&self) -> anyhow::Result<Vec<u8>>;
}
// Provides `downcast_arc` on `Arc<dyn Adapter>` (used by `Session::downcast_adapter`).
downcast::impl_downcast_sync!(dyn Adapter);
/// Receiving half for the packets routed to one outstanding request id.
pub struct StreamReceiver {
    /// Session whose adapter this receiver may drive while waiting.
    session: Arc<Session>,
    /// Fed by `Session::handle_packet` through the sender stored in `pending`.
    recver: Receiver<Packet>,
}
impl StreamReceiver {
    /// Waits for the next packet routed to this request.
    ///
    /// While waiting, the call also drives the session: if no other task holds the receive
    /// lock, it reads packets from the adapter and dispatches them itself.
    ///
    /// Returns:
    /// - `Ok(None)` once the channel is closed and drained, e.g. after the final stream
    ///   packet removed the pending entry. Previously a closed channel fell through to a
    ///   blocking adapter read that could wait forever for unrelated traffic.
    /// - `Err` when reading or decoding from the adapter fails.
    pub async fn recv(&mut self) -> anyhow::Result<Option<Packet>> {
        loop {
            match self.recver.try_next() {
                // A packet was already routed to us.
                Ok(Some(pack)) => return Ok(Some(pack)),
                // Closed and empty: nothing more can ever arrive.
                Ok(None) => return Ok(None),
                // Empty but still open: go fetch more traffic.
                Err(_) => {}
            }
            match self.session.try_recv().await {
                // Another task owns the adapter and will route our packet to the channel.
                None => return Ok(self.recver.next().await),
                Some(Ok(pack)) => self.session.handle_packet(pack).await,
                Some(Err(err)) => return Err(err),
            }
        }
    }

    /// Like [`recv`](Self::recv), then decodes the packet with [`FromPacket`].
    pub async fn recv_decode<R: FromPacket>(&mut self) -> anyhow::Result<Option<R>> {
        self.recv().await?.map(R::from_packet).transpose()
    }
}
/// Per-packet context for replying to an incoming packet.
/// NOTE(review): construction is not in this file — presumably the service layer builds it.
pub struct RpcContext {
    /// Session the packet arrived on; replies are sent through it.
    pub session: Arc<Session>,
    /// The incoming packet being handled.
    pub packet: Packet,
    /// Set once a reply (or a stream, via `into_stream`) has been claimed, so later reply
    /// attempts are skipped.
    pub(crate) responsed: AtomicBool,
}
/// Shared handle to an [`RpcContext`].
pub type ArcRpcContext = Arc<RpcContext>;
impl RpcContext {
    /// Turns this context into a [`StreamSender`] for streaming replies (claims the reply).
    pub fn into_stream(self: Arc<Self>) -> StreamSender {
        self.into()
    }

    /// Whether a reply (or stream) has already been claimed for this packet.
    pub fn responsed(&self) -> bool {
        self.responsed.load(Ordering::SeqCst)
    }

    /// Sends the reply for this request, at most once.
    ///
    /// The claim uses an atomic `swap`. Later calls, and calls for packets without a
    /// request id, return `Ok(())` without sending. An `Err` from `into_return` is sent
    /// as an error response.
    pub async fn response<C, R: MethodReturn<C>>(&self, res: R) -> anyhow::Result<()> {
        if self.responsed.swap(true, Ordering::SeqCst) {
            return Ok(());
        }
        if let Some(req) = self.packet.request_id() {
            match res.into_return() {
                Ok(res) => self.session.response(req, res).await,
                Err(err) => self.session.response_error(req, format!("{err:?}")).await,
            }
        } else {
            Ok(())
        }
    }

    /// Blocking, best-effort variant of [`response`](Self::response); send errors are
    /// discarded.
    ///
    /// Claims the reply with the same atomic `swap` as `response`. The previous
    /// check-then-store let two racing callers both send a reply.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn response_sync<C, R: MethodReturn<C>>(&self, res: R) {
        if self.responsed.swap(true, Ordering::SeqCst) {
            return;
        }
        if let Some(req) = self.packet.request_id() {
            match res.into_return() {
                Ok(res) => self.session.response_sync(req, res),
                Err(err) => self.session.response_error_sync(req, format!("{err:?}")),
            }
            .ok();
        }
    }

    /// Best-effort error reply (send errors are discarded), claimed atomically so only
    /// one reply is ever sent.
    pub async fn response_error(&self, err: impl ToString) {
        if self.responsed.swap(true, Ordering::SeqCst) {
            return;
        }
        if let Some(req) = self.packet.request_id() {
            self.session.response_error(req, err.to_string()).await.ok();
        }
    }
}
/// Sends streamed reply items for one request.
pub struct StreamSender {
    /// Session the stream items are sent on.
    pub session: Arc<Session>,
    /// Request id the stream answers. `end` sets it to 0, and a 0 value tells the
    /// non-wasm `Drop` impl not to send another end marker.
    pub req_id: u32,
}
impl From<ArcRpcContext> for StreamSender {
fn from(ctx: ArcRpcContext) -> Self {
ctx.responsed.store(true, Ordering::SeqCst);
Self {
session: ctx.session.clone(),
req_id: ctx.packet.request_id().expect("reqid"),
}
}
}
impl StreamSender {
    /// Sends one stream item for this request.
    pub async fn send<R: Serialize>(&self, res: R) -> anyhow::Result<()> {
        let Self { session, req_id } = self;
        session.stream_packet(*req_id, res).await
    }

    /// Sends the end-of-stream marker and consumes the sender.
    ///
    /// `req_id` is zeroed afterwards, even if sending failed, so no second end marker is
    /// sent when the value is dropped.
    pub async fn end(mut self) -> anyhow::Result<()> {
        let outcome = self.session.stream_packet_end(self.req_id).await;
        self.req_id = 0;
        outcome
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl Drop for StreamSender {
fn drop(&mut self) {
let req_id = self.req_id;
if req_id > 0 {
let session = self.session.clone();
tokio::spawn(async move {
session
.stream_packet_end(req_id)
.await
.log_error("stream end");
});
}
}
}