1use crate::envelope::Envelope;
4use crate::error::{Error, Result};
5use futures_util::Stream;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::sync::{mpsc, oneshot, Mutex, Notify};
12
13pub enum HandlerOutput {
16 Unary(Value),
17 Stream(Pin<Box<dyn Stream<Item = Value> + Send>>),
18}
19
20pub type Handler = Arc<
21 dyn Fn(Value) -> Pin<Box<dyn std::future::Future<Output = Result<HandlerOutput>> + Send>>
22 + Send
23 + Sync,
24>;
25
/// Which end of the connection a session represents.
///
/// In the visible code, the role only selects the parity of locally
/// allocated stream ids.
#[derive(Debug, Clone, Copy)]
pub enum Role {
    /// Allocates odd stream ids (1, 3, 5, ...).
    Initiator,
    /// Allocates even stream ids (2, 4, 6, ...).
    Responder,
}
31
32#[derive(Clone)]
34pub struct SessionTransport {
35 pub tx: mpsc::UnboundedSender<Envelope>,
36}
37
38impl SessionTransport {
39 pub fn send(&self, env: Envelope) -> Result<()> {
40 self.tx.send(env).map_err(|_| Error::Closed)
41 }
42}
43
44struct PendingUnary {
45 tx: oneshot::Sender<Result<Value>>,
46}
47
48struct ClientStreamState {
49 chunk_tx: mpsc::UnboundedSender<Value>,
50 end_notify: Arc<Notify>,
51 ended: bool,
52 granted: u64,
53 emitted: u64,
54 initial_credits: u64,
55}
56
57struct ServerStreamCtl {
58 grant_tx: mpsc::UnboundedSender<u64>,
59 cancel_tx: mpsc::UnboundedSender<()>,
60}
61
62struct Inner {
63 role: Role,
64 transport: SessionTransport,
65 next_stream_id: u64,
66 pending: HashMap<u64, PendingUnary>,
67 handlers: HashMap<String, Handler>,
68 client_streams: HashMap<u64, ClientStreamState>,
69 server_streams: HashMap<u64, ServerStreamCtl>,
70}
71
72#[derive(Clone)]
73pub struct Session {
74 inner: Arc<Mutex<Inner>>,
75}
76
77impl Session {
78 pub fn new(transport: SessionTransport, role: Role) -> Self {
79 let next = match role {
80 Role::Initiator => 1,
81 Role::Responder => 2,
82 };
83 Self {
84 inner: Arc::new(Mutex::new(Inner {
85 role,
86 transport,
87 next_stream_id: next,
88 pending: HashMap::new(),
89 handlers: HashMap::new(),
90 client_streams: HashMap::new(),
91 server_streams: HashMap::new(),
92 })),
93 }
94 }
95
96 pub async fn handle(&self, method: impl Into<String>, h: Handler) {
97 let mut g = self.inner.lock().await;
98 g.handlers.insert(method.into(), h);
99 }
100
101 fn next_stream_id_locked(inner: &mut Inner) -> u64 {
102 let sid = inner.next_stream_id;
103 inner.next_stream_id += 2;
104 sid
105 }
106
107 pub async fn call(&self, method: &str, params: Option<Value>) -> Result<Value> {
108 let (sid, rx) = {
109 let mut g = self.inner.lock().await;
110 let sid = Self::next_stream_id_locked(&mut g);
111 let (tx, rx) = oneshot::channel();
112 g.pending.insert(sid, PendingUnary { tx });
113 let mut env = Envelope::new();
114 env.insert("stream_id".into(), Value::from(sid));
115 env.insert("type".into(), Value::from("req"));
116 env.insert("seq".into(), Value::from(0));
117 env.insert("method".into(), Value::from(method));
118 if let Some(p) = params {
119 env.insert("params".into(), p);
120 }
121 g.transport.send(env)?;
122 (sid, rx)
123 };
124 let _ = sid;
125 rx.await.map_err(|_| Error::Closed)?
126 }
127
128 pub async fn stream(
129 &self,
130 method: &str,
131 params: Option<Value>,
132 credits: u64,
133 ) -> Result<ClientStream> {
134 let (sid, chunk_rx, end_notify) = {
135 let mut g = self.inner.lock().await;
136 let sid = Self::next_stream_id_locked(&mut g);
137 let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
138 let end_notify = Arc::new(Notify::new());
139 g.client_streams.insert(
140 sid,
141 ClientStreamState {
142 chunk_tx,
143 end_notify: end_notify.clone(),
144 ended: false,
145 granted: credits,
146 emitted: 0,
147 initial_credits: credits,
148 },
149 );
150 let mut env = Envelope::new();
151 env.insert("stream_id".into(), Value::from(sid));
152 env.insert("type".into(), Value::from("req"));
153 env.insert("seq".into(), Value::from(0));
154 env.insert("method".into(), Value::from(method));
155 env.insert("credits".into(), Value::from(credits));
156 if let Some(p) = params {
157 env.insert("params".into(), p);
158 }
159 g.transport.send(env)?;
160 (sid, chunk_rx, end_notify)
161 };
162 Ok(ClientStream {
163 session: self.clone(),
164 sid,
165 chunk_rx,
166 end_notify,
167 initial_credits: credits,
168 })
169 }
170
171 pub async fn dispatch(&self, env: Envelope) -> Result<()> {
174 let sid = env
175 .get("stream_id")
176 .and_then(|v| v.as_u64())
177 .ok_or_else(|| Error::InvalidEnvelope("missing stream_id".into()))?;
178 let t = env
179 .get("type")
180 .and_then(|v| v.as_str())
181 .ok_or_else(|| Error::InvalidEnvelope("missing type".into()))?
182 .to_string();
183
184 match t.as_str() {
185 "req" => {
186 let method = env
187 .get("method")
188 .and_then(|v| v.as_str())
189 .unwrap_or("")
190 .to_string();
191 let params = env.get("params").cloned().unwrap_or(Value::Null);
192 let initial_credits = env.get("credits").and_then(|v| v.as_u64()).unwrap_or(0);
193
194 let (handler, transport) = {
195 let g = self.inner.lock().await;
196 (g.handlers.get(&method).cloned(), g.transport.clone())
197 };
198 let handler = match handler {
199 Some(h) => h,
200 None => {
201 let mut err = Envelope::new();
202 err.insert("stream_id".into(), Value::from(sid));
203 err.insert("type".into(), Value::from("error"));
204 err.insert("seq".into(), Value::from(0));
205 err.insert(
206 "error".into(),
207 serde_json::json!({
208 "code": -32601,
209 "message": format!("method not found: {method}"),
210 }),
211 );
212 let _ = transport.send(err);
213 return Ok(());
214 }
215 };
216
217 let session = self.clone();
218 tokio::spawn(async move {
219 match handler(params).await {
220 Ok(HandlerOutput::Unary(value)) => {
221 let mut env = Envelope::new();
222 env.insert("stream_id".into(), Value::from(sid));
223 env.insert("type".into(), Value::from("res"));
224 env.insert("seq".into(), Value::from(0));
225 env.insert("result".into(), value);
226 let _ = transport.send(env);
227 }
228 Ok(HandlerOutput::Stream(stream)) => {
229 session
230 .run_server_stream(sid, stream, initial_credits)
231 .await;
232 }
233 Err(e) => {
234 let mut env = Envelope::new();
235 env.insert("stream_id".into(), Value::from(sid));
236 env.insert("type".into(), Value::from("error"));
237 env.insert("seq".into(), Value::from(0));
238 env.insert(
239 "error".into(),
240 serde_json::json!({
241 "code": -32000,
242 "message": e.to_string(),
243 }),
244 );
245 let _ = transport.send(env);
246 }
247 }
248 });
249 }
250 "res" => {
251 let mut g = self.inner.lock().await;
252 if let Some(ctl) = g.server_streams.get(&sid) {
254 let n = env.get("credits").and_then(|v| v.as_u64()).unwrap_or(0);
255 let _ = ctl.grant_tx.send(n);
256 return Ok(());
257 }
258 if let Some(p) = g.pending.remove(&sid) {
259 let result = env.get("result").cloned().unwrap_or(Value::Null);
260 let _ = p.tx.send(Ok(result));
261 }
262 }
263 "error" => {
264 let mut g = self.inner.lock().await;
265 if let Some(p) = g.pending.remove(&sid) {
266 let err = env.get("error").cloned().unwrap_or(Value::Null);
267 let code = err.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
268 let msg = err
269 .get("message")
270 .and_then(|v| v.as_str())
271 .unwrap_or("unknown error")
272 .to_string();
273 let _ = p.tx.send(Err(Error::Rpc { code, message: msg }));
274 }
275 }
276 "stream_chunk" => {
277 let mut g = self.inner.lock().await;
278 if let Some(s) = g.client_streams.get_mut(&sid) {
279 let result = env.get("result").cloned().unwrap_or(Value::Null);
280 let _ = s.chunk_tx.send(result);
281 }
282 }
283 "cancel" => {
284 let g = self.inner.lock().await;
285 if let Some(ctl) = g.server_streams.get(&sid) {
286 let _ = ctl.cancel_tx.send(());
287 }
288 }
289 "stream_end" => {
290 let mut g = self.inner.lock().await;
291 if let Some(s) = g.client_streams.remove(&sid) {
292 s.end_notify.notify_waiters();
293 drop(s);
295 }
296 }
297 _ => {}
298 }
299 Ok(())
300 }
301
302 async fn run_server_stream(
303 &self,
304 sid: u64,
305 mut src: Pin<Box<dyn Stream<Item = Value> + Send>>,
306 initial_credits: u64,
307 ) {
308 use futures_util::StreamExt;
309
310 let (grant_tx, mut grant_rx) = mpsc::unbounded_channel::<u64>();
311 let (cancel_tx, mut cancel_rx) = mpsc::unbounded_channel::<()>();
312 let transport = {
313 let mut g = self.inner.lock().await;
314 g.server_streams.insert(
315 sid,
316 ServerStreamCtl {
317 grant_tx,
318 cancel_tx,
319 },
320 );
321 g.transport.clone()
322 };
323
324 let mut granted = initial_credits;
325 let mut seq: u64 = 0;
326 let mut cancelled = false;
327
328 'outer: loop {
329 while granted == 0 && !cancelled {
330 tokio::select! {
331 Some(n) = grant_rx.recv() => { granted += n; }
332 Some(_) = cancel_rx.recv() => { cancelled = true; break; }
333 else => { break 'outer; }
334 }
335 }
336 if cancelled {
337 break;
338 }
339 tokio::select! {
341 next = src.next() => {
342 let value = match next { Some(v) => v, None => break };
343 granted = granted.saturating_sub(1);
344 let mut env = Envelope::new();
345 env.insert("stream_id".into(), Value::from(sid));
346 env.insert("type".into(), Value::from("stream_chunk"));
347 env.insert("seq".into(), Value::from(seq));
348 env.insert("result".into(), value);
349 if transport.send(env).is_err() { break; }
350 seq += 1;
351 }
352 Some(n) = grant_rx.recv() => { granted += n; }
353 Some(_) = cancel_rx.recv() => { cancelled = true; break; }
354 }
355 }
356
357 let mut end = Envelope::new();
358 end.insert("stream_id".into(), Value::from(sid));
359 end.insert("type".into(), Value::from("stream_end"));
360 end.insert("seq".into(), Value::from(seq));
361 end.insert(
362 "reason".into(),
363 Value::from(if cancelled { "cancelled" } else { "ok" }),
364 );
365 let _ = transport.send(end);
366
367 let mut g = self.inner.lock().await;
368 g.server_streams.remove(&sid);
369 }
370
371 async fn cancel_client_stream(&self, sid: u64) {
372 let transport = {
373 let mut g = self.inner.lock().await;
374 g.client_streams.remove(&sid);
375 g.transport.clone()
376 };
377 let mut env = Envelope::new();
378 env.insert("stream_id".into(), Value::from(sid));
379 env.insert("type".into(), Value::from("cancel"));
380 env.insert("seq".into(), Value::from(0));
381 let _ = transport.send(env);
382 }
383}
384
385impl Drop for Session {
386 fn drop(&mut self) {
387 }
389}
390
391pub struct ClientStream {
395 session: Session,
396 sid: u64,
397 chunk_rx: mpsc::UnboundedReceiver<Value>,
398 end_notify: Arc<Notify>,
399 initial_credits: u64,
400}
401
402impl ClientStream {
403 pub async fn next(&mut self) -> Option<Value> {
404 {
406 let mut g = self.session.inner.lock().await;
407 if let Some(s) = g.client_streams.get_mut(&self.sid) {
408 if !s.ended && s.emitted + 1 >= s.granted.saturating_sub(s.initial_credits / 2) {
409 s.granted += s.initial_credits;
410 let mut env = Envelope::new();
411 env.insert("stream_id".into(), Value::from(self.sid));
412 env.insert("type".into(), Value::from("res"));
413 env.insert("seq".into(), Value::from(0));
414 env.insert("credits".into(), Value::from(s.initial_credits));
415 let _ = g.transport.send(env);
416 }
417 }
418 }
419 let v = self.chunk_rx.recv().await;
420 if v.is_some() {
421 let mut g = self.session.inner.lock().await;
422 if let Some(s) = g.client_streams.get_mut(&self.sid) {
423 s.emitted += 1;
424 }
425 }
426 v
427 }
428
429 pub async fn cancel(&mut self) {
430 self.session.cancel_client_stream(self.sid).await;
431 }
432}
433
434impl Drop for ClientStream {
435 fn drop(&mut self) {
436 let session = self.session.clone();
438 let sid = self.sid;
439 tokio::spawn(async move {
440 session.cancel_client_stream(sid).await;
441 });
442 }
443}
444
445impl Stream for ClientStream {
446 type Item = Value;
447 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
448 let this = self.get_mut();
452 this.chunk_rx.poll_recv(cx)
453 }
454}
455
456#[allow(dead_code)]
458fn _force_use(role: Role, n: &Notify) {
459 let _ = role;
460 let _ = n;
461}