use crate::envelope::Envelope;
use crate::error::{Error, Result};
use futures_util::Stream;
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{mpsc, oneshot, Mutex, Notify};
/// What a registered [`Handler`] produced for one request.
pub enum HandlerOutput {
    /// A single value, answered with one `res` envelope carrying it as `result`.
    Unary(Value),
    /// A sequence of values, forwarded as `stream_chunk` envelopes (one credit per
    /// chunk) and terminated by a `stream_end` envelope.
    Stream(Pin<Box<dyn Stream<Item = Value> + Send + 'static>>),
}
/// A method implementation registered with `Session::handle`.
///
/// It is invoked with the request's `params` (`Value::Null` when the request has
/// none) and resolves to a [`HandlerOutput`] or an error. The `Arc` lets the session
/// clone it out of its handler table and run it on a spawned task.
pub type Handler = Arc<
    dyn Fn(Value) -> Pin<Box<dyn std::future::Future<Output = Result<HandlerOutput>> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
/// Which side of the connection a [`Session`] plays.
///
/// In this module the role selects the parity of locally allocated stream ids, so
/// the two peers of one connection are expected to use opposite roles.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Role {
    /// Allocates odd stream ids (1, 3, 5, ...).
    Initiator,
    /// Allocates even stream ids (2, 4, 6, ...).
    Responder,
}
/// Outbound side of a session: envelopes passed to `send` are pushed onto an
/// unbounded channel.
///
/// NOTE(review): presumably a writer task owns the receiving half and puts the
/// envelopes on the wire; confirm with whoever constructs this value.
#[derive(Clone)]
pub struct SessionTransport {
    /// Sender half of the outbound envelope queue.
    pub tx: mpsc::UnboundedSender<Envelope>,
}
impl SessionTransport {
pub fn send(&self, env: Envelope) -> Result<()> {
self.tx.send(env).map_err(|_| Error::Closed)
}
}
/// A unary `call` waiting for its reply.
///
/// `dispatch` completes it with the `result` of a `res` envelope, or with an
/// `Error::Rpc` built from an `error` envelope.
struct PendingUnary {
    /// Wakes the task awaiting `Session::call`.
    tx: oneshot::Sender<Result<Value>>,
}
/// Session-side bookkeeping for one outgoing `stream` call, stored in
/// `Inner::client_streams` under the locally allocated stream id.
struct ClientStreamState {
    /// Forwards `stream_chunk` results to the `ClientStream`'s receiver. It is the
    /// only sender, so removing this entry lets the receiver observe end-of-stream.
    chunk_tx: mpsc::UnboundedSender<Value>,
    /// Notified with `notify_waiters` when `stream_end` arrives.
    /// NOTE(review): nothing in this file awaits it.
    end_notify: Arc<Notify>,
    /// Checked before granting more credit.
    /// NOTE(review): never set to `true` anywhere in this file.
    ended: bool,
    /// Total credit granted to the peer so far (initial window plus refills).
    granted: u64,
    /// Number of chunks the consumer has received through `ClientStream::next`.
    emitted: u64,
    /// Size of the initial credit window; each refill grant uses this size too.
    initial_credits: u64,
}
/// Control handles for a stream this session serves for the peer. `dispatch` uses
/// them to forward the peer's credit grants and cancellation to the task that runs
/// the stream.
struct ServerStreamCtl {
    /// Carries the `credits` amount of each `res` envelope the peer sends.
    grant_tx: mpsc::UnboundedSender<u64>,
    /// Signals that the peer sent `cancel`.
    cancel_tx: mpsc::UnboundedSender<()>,
}
/// Mutable session state shared by every `Session` clone behind one async mutex.
struct Inner {
    /// Role given to `Session::new`.
    /// NOTE(review): stored but not read anywhere in this file.
    role: Role,
    /// Outbound envelope queue.
    transport: SessionTransport,
    /// Next locally allocated stream id; advanced by 2 per allocation.
    next_stream_id: u64,
    /// Unary calls awaiting `res`/`error`, keyed by our stream id.
    pending: HashMap<u64, PendingUnary>,
    /// Method name -> handler for inbound `req` envelopes.
    handlers: HashMap<String, Handler>,
    /// Our outgoing stream calls, keyed by our stream id.
    client_streams: HashMap<u64, ClientStreamState>,
    /// Streams served for the peer, keyed by the peer's stream id.
    server_streams: HashMap<u64, ServerStreamCtl>,
}
/// Handle to a multiplexed RPC session over one transport.
///
/// Cloning is cheap; all clones share the same state through the inner `Arc`.
#[derive(Clone)]
pub struct Session {
    inner: Arc<Mutex<Inner>>,
}
impl Session {
pub fn new(transport: SessionTransport, role: Role) -> Self {
let next = match role {
Role::Initiator => 1,
Role::Responder => 2,
};
Self {
inner: Arc::new(Mutex::new(Inner {
role,
transport,
next_stream_id: next,
pending: HashMap::new(),
handlers: HashMap::new(),
client_streams: HashMap::new(),
server_streams: HashMap::new(),
})),
}
}
pub async fn handle(&self, method: impl Into<String>, h: Handler) {
let mut g = self.inner.lock().await;
g.handlers.insert(method.into(), h);
}
fn next_stream_id_locked(inner: &mut Inner) -> u64 {
let sid = inner.next_stream_id;
inner.next_stream_id += 2;
sid
}
pub async fn call(&self, method: &str, params: Option<Value>) -> Result<Value> {
let (sid, rx) = {
let mut g = self.inner.lock().await;
let sid = Self::next_stream_id_locked(&mut g);
let (tx, rx) = oneshot::channel();
g.pending.insert(sid, PendingUnary { tx });
let mut env = Envelope::new();
env.insert("stream_id".into(), Value::from(sid));
env.insert("type".into(), Value::from("req"));
env.insert("seq".into(), Value::from(0));
env.insert("method".into(), Value::from(method));
if let Some(p) = params {
env.insert("params".into(), p);
}
g.transport.send(env)?;
(sid, rx)
};
let _ = sid;
rx.await.map_err(|_| Error::Closed)?
}
pub async fn stream(
&self,
method: &str,
params: Option<Value>,
credits: u64,
) -> Result<ClientStream> {
let (sid, chunk_rx, end_notify) = {
let mut g = self.inner.lock().await;
let sid = Self::next_stream_id_locked(&mut g);
let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
let end_notify = Arc::new(Notify::new());
g.client_streams.insert(
sid,
ClientStreamState {
chunk_tx,
end_notify: end_notify.clone(),
ended: false,
granted: credits,
emitted: 0,
initial_credits: credits,
},
);
let mut env = Envelope::new();
env.insert("stream_id".into(), Value::from(sid));
env.insert("type".into(), Value::from("req"));
env.insert("seq".into(), Value::from(0));
env.insert("method".into(), Value::from(method));
env.insert("credits".into(), Value::from(credits));
if let Some(p) = params {
env.insert("params".into(), p);
}
g.transport.send(env)?;
(sid, chunk_rx, end_notify)
};
Ok(ClientStream {
session: self.clone(),
sid,
chunk_rx,
end_notify,
initial_credits: credits,
})
}
pub async fn dispatch(&self, env: Envelope) -> Result<()> {
let sid = env
.get("stream_id")
.and_then(|v| v.as_u64())
.ok_or_else(|| Error::InvalidEnvelope("missing stream_id".into()))?;
let t = env
.get("type")
.and_then(|v| v.as_str())
.ok_or_else(|| Error::InvalidEnvelope("missing type".into()))?
.to_string();
match t.as_str() {
"req" => {
let method = env
.get("method")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let params = env.get("params").cloned().unwrap_or(Value::Null);
let initial_credits = env.get("credits").and_then(|v| v.as_u64()).unwrap_or(0);
let (handler, transport) = {
let g = self.inner.lock().await;
(g.handlers.get(&method).cloned(), g.transport.clone())
};
let handler = match handler {
Some(h) => h,
None => {
let mut err = Envelope::new();
err.insert("stream_id".into(), Value::from(sid));
err.insert("type".into(), Value::from("error"));
err.insert("seq".into(), Value::from(0));
err.insert(
"error".into(),
serde_json::json!({
"code": -32601,
"message": format!("method not found: {method}"),
}),
);
let _ = transport.send(err);
return Ok(());
}
};
let session = self.clone();
tokio::spawn(async move {
match handler(params).await {
Ok(HandlerOutput::Unary(value)) => {
let mut env = Envelope::new();
env.insert("stream_id".into(), Value::from(sid));
env.insert("type".into(), Value::from("res"));
env.insert("seq".into(), Value::from(0));
env.insert("result".into(), value);
let _ = transport.send(env);
}
Ok(HandlerOutput::Stream(stream)) => {
session
.run_server_stream(sid, stream, initial_credits)
.await;
}
Err(e) => {
let mut env = Envelope::new();
env.insert("stream_id".into(), Value::from(sid));
env.insert("type".into(), Value::from("error"));
env.insert("seq".into(), Value::from(0));
env.insert(
"error".into(),
serde_json::json!({
"code": -32000,
"message": e.to_string(),
}),
);
let _ = transport.send(env);
}
}
});
}
"res" => {
let mut g = self.inner.lock().await;
if let Some(ctl) = g.server_streams.get(&sid) {
let n = env.get("credits").and_then(|v| v.as_u64()).unwrap_or(0);
let _ = ctl.grant_tx.send(n);
return Ok(());
}
if let Some(p) = g.pending.remove(&sid) {
let result = env.get("result").cloned().unwrap_or(Value::Null);
let _ = p.tx.send(Ok(result));
}
}
"error" => {
let mut g = self.inner.lock().await;
if let Some(p) = g.pending.remove(&sid) {
let err = env.get("error").cloned().unwrap_or(Value::Null);
let code = err.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
let msg = err
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error")
.to_string();
let _ = p.tx.send(Err(Error::Rpc { code, message: msg }));
}
}
"stream_chunk" => {
let mut g = self.inner.lock().await;
if let Some(s) = g.client_streams.get_mut(&sid) {
let result = env.get("result").cloned().unwrap_or(Value::Null);
let _ = s.chunk_tx.send(result);
}
}
"cancel" => {
let g = self.inner.lock().await;
if let Some(ctl) = g.server_streams.get(&sid) {
let _ = ctl.cancel_tx.send(());
}
}
"stream_end" => {
let mut g = self.inner.lock().await;
if let Some(s) = g.client_streams.remove(&sid) {
s.end_notify.notify_waiters();
drop(s);
}
}
_ => {}
}
Ok(())
}
async fn run_server_stream(
&self,
sid: u64,
mut src: Pin<Box<dyn Stream<Item = Value> + Send>>,
initial_credits: u64,
) {
use futures_util::StreamExt;
let (grant_tx, mut grant_rx) = mpsc::unbounded_channel::<u64>();
let (cancel_tx, mut cancel_rx) = mpsc::unbounded_channel::<()>();
let transport = {
let mut g = self.inner.lock().await;
g.server_streams.insert(
sid,
ServerStreamCtl {
grant_tx,
cancel_tx,
},
);
g.transport.clone()
};
let mut granted = initial_credits;
let mut seq: u64 = 0;
let mut cancelled = false;
'outer: loop {
while granted == 0 && !cancelled {
tokio::select! {
Some(n) = grant_rx.recv() => { granted += n; }
Some(_) = cancel_rx.recv() => { cancelled = true; break; }
else => { break 'outer; }
}
}
if cancelled {
break;
}
tokio::select! {
next = src.next() => {
let value = match next { Some(v) => v, None => break };
granted = granted.saturating_sub(1);
let mut env = Envelope::new();
env.insert("stream_id".into(), Value::from(sid));
env.insert("type".into(), Value::from("stream_chunk"));
env.insert("seq".into(), Value::from(seq));
env.insert("result".into(), value);
if transport.send(env).is_err() { break; }
seq += 1;
}
Some(n) = grant_rx.recv() => { granted += n; }
Some(_) = cancel_rx.recv() => { cancelled = true; break; }
}
}
let mut end = Envelope::new();
end.insert("stream_id".into(), Value::from(sid));
end.insert("type".into(), Value::from("stream_end"));
end.insert("seq".into(), Value::from(seq));
end.insert(
"reason".into(),
Value::from(if cancelled { "cancelled" } else { "ok" }),
);
let _ = transport.send(end);
let mut g = self.inner.lock().await;
g.server_streams.remove(&sid);
}
async fn cancel_client_stream(&self, sid: u64) {
let transport = {
let mut g = self.inner.lock().await;
g.client_streams.remove(&sid);
g.transport.clone()
};
let mut env = Envelope::new();
env.insert("stream_id".into(), Value::from(sid));
env.insert("type".into(), Value::from("cancel"));
env.insert("seq".into(), Value::from(0));
let _ = transport.send(env);
}
}
/// No-op destructor. It runs whenever any clone of the handle is dropped and
/// performs no cleanup.
///
/// NOTE(review): pending calls and streams are not torn down here; confirm whether
/// that is intended or whether this impl can simply be removed.
impl Drop for Session {
    fn drop(&mut self) {
    }
}
/// Consumer handle for a server-streaming call opened with `Session::stream`.
///
/// Dropping it triggers `Session::cancel_client_stream` for its id on a spawned task
/// (best-effort).
pub struct ClientStream {
    /// Session owning this stream's state and transport.
    session: Session,
    /// Stream id allocated by this side.
    sid: u64,
    /// Chunks forwarded by `dispatch`; yields `None` once the session-side state is
    /// removed and the buffer is drained.
    chunk_rx: mpsc::UnboundedReceiver<Value>,
    /// Shared with the session-side state, which notifies it on `stream_end`.
    /// NOTE(review): not awaited by any method in this file.
    end_notify: Arc<Notify>,
    /// Initial credit window recorded at creation.
    /// NOTE(review): not read in this file; refills use the session-side copy.
    initial_credits: u64,
}
impl ClientStream {
    /// Waits for the next chunk. Returns `None` once the stream has finished and all
    /// buffered chunks have been consumed.
    ///
    /// Before waiting it refills the peer's credit. When the outstanding window
    /// (`granted - emitted`) has shrunk to about half of `initial_credits`, another
    /// `initial_credits` are granted with a `res` envelope carrying `credits`.
    pub async fn next(&mut self) -> Option<Value> {
        {
            let mut g = self.session.inner.lock().await;
            if let Some(s) = g.client_streams.get_mut(&self.sid) {
                let low_water = s.granted.saturating_sub(s.initial_credits / 2);
                // A zero-credit grant can never unblock the sender, so don't send one.
                if !s.ended && s.initial_credits > 0 && s.emitted + 1 >= low_water {
                    s.granted = s.granted.saturating_add(s.initial_credits);
                    let credits = s.initial_credits;
                    let mut env = Envelope::new();
                    env.insert("stream_id".into(), Value::from(self.sid));
                    env.insert("type".into(), Value::from("res"));
                    env.insert("seq".into(), Value::from(0));
                    env.insert("credits".into(), Value::from(credits));
                    let _ = g.transport.send(env);
                }
            }
        }
        let v = self.chunk_rx.recv().await;
        if v.is_some() {
            let mut g = self.session.inner.lock().await;
            if let Some(s) = g.client_streams.get_mut(&self.sid) {
                s.emitted += 1;
            }
        }
        v
    }

    /// Cancels the stream through `Session::cancel_client_stream`, which drops the
    /// session-side state and asks the peer to stop serving it.
    pub async fn cancel(&mut self) {
        self.session.cancel_client_stream(self.sid).await;
    }
}
/// Best-effort cancellation when the consumer drops the stream.
impl Drop for ClientStream {
    fn drop(&mut self) {
        // `tokio::spawn` panics when called outside a Tokio runtime, and a panicking
        // destructor can abort the process if it runs during unwinding. Only schedule
        // the cancel when a runtime is reachable from this thread.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let session = self.session.clone();
            let sid = self.sid;
            handle.spawn(async move {
                session.cancel_client_stream(sid).await;
            });
        }
    }
}
impl Stream for ClientStream {
type Item = Value;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
this.chunk_rx.poll_recv(cx)
}
}
/// Never-called helper that only references `Role` and `Notify`.
///
/// NOTE(review): appears to exist only to keep those names referenced; confirm it
/// can be deleted.
#[allow(dead_code)]
fn _force_use(role: Role, n: &Notify) {
    let _ = (role, n);
}