Skip to main content

agent_phone/
session.rs

1//! Session: multiplexes streams over a transport, handles RPC + backpressure.
2
3use crate::envelope::Envelope;
4use crate::error::{Error, Result};
5use futures_util::Stream;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use tokio::sync::{mpsc, oneshot, Mutex, Notify};
12
13/// A handler runs on the responder when an RPC method is invoked. It may
14/// return a single value (unary) or a stream of values (server-streaming).
15pub enum HandlerOutput {
16    Unary(Value),
17    Stream(Pin<Box<dyn Stream<Item = Value> + Send>>),
18}
19
20pub type Handler = Arc<
21    dyn Fn(Value) -> Pin<Box<dyn std::future::Future<Output = Result<HandlerOutput>> + Send>>
22        + Send
23        + Sync,
24>;
25
/// Which side of the connection a session represents.
///
/// The role selects the first stream id a session allocates (see
/// `Session::new`).
#[derive(Clone, Copy, Debug)]
pub enum Role {
    /// Allocates stream ids starting at 1.
    Initiator,
    /// Allocates stream ids starting at 2.
    Responder,
}
31
32/// A transport that the Session uses to send + receive envelopes.
33#[derive(Clone)]
34pub struct SessionTransport {
35    pub tx: mpsc::UnboundedSender<Envelope>,
36}
37
38impl SessionTransport {
39    pub fn send(&self, env: Envelope) -> Result<()> {
40        self.tx.send(env).map_err(|_| Error::Closed)
41    }
42}
43
44struct PendingUnary {
45    tx: oneshot::Sender<Result<Value>>,
46}
47
48struct ClientStreamState {
49    chunk_tx: mpsc::UnboundedSender<Value>,
50    end_notify: Arc<Notify>,
51    ended: bool,
52    granted: u64,
53    emitted: u64,
54    initial_credits: u64,
55}
56
57struct ServerStreamCtl {
58    grant_tx: mpsc::UnboundedSender<u64>,
59    cancel_tx: mpsc::UnboundedSender<()>,
60}
61
62struct Inner {
63    role: Role,
64    transport: SessionTransport,
65    next_stream_id: u64,
66    pending: HashMap<u64, PendingUnary>,
67    handlers: HashMap<String, Handler>,
68    client_streams: HashMap<u64, ClientStreamState>,
69    server_streams: HashMap<u64, ServerStreamCtl>,
70}
71
72#[derive(Clone)]
73pub struct Session {
74    inner: Arc<Mutex<Inner>>,
75}
76
77impl Session {
78    pub fn new(transport: SessionTransport, role: Role) -> Self {
79        let next = match role {
80            Role::Initiator => 1,
81            Role::Responder => 2,
82        };
83        Self {
84            inner: Arc::new(Mutex::new(Inner {
85                role,
86                transport,
87                next_stream_id: next,
88                pending: HashMap::new(),
89                handlers: HashMap::new(),
90                client_streams: HashMap::new(),
91                server_streams: HashMap::new(),
92            })),
93        }
94    }
95
96    pub async fn handle(&self, method: impl Into<String>, h: Handler) {
97        let mut g = self.inner.lock().await;
98        g.handlers.insert(method.into(), h);
99    }
100
101    fn next_stream_id_locked(inner: &mut Inner) -> u64 {
102        let sid = inner.next_stream_id;
103        inner.next_stream_id += 2;
104        sid
105    }
106
107    pub async fn call(&self, method: &str, params: Option<Value>) -> Result<Value> {
108        let (sid, rx) = {
109            let mut g = self.inner.lock().await;
110            let sid = Self::next_stream_id_locked(&mut g);
111            let (tx, rx) = oneshot::channel();
112            g.pending.insert(sid, PendingUnary { tx });
113            let mut env = Envelope::new();
114            env.insert("stream_id".into(), Value::from(sid));
115            env.insert("type".into(), Value::from("req"));
116            env.insert("seq".into(), Value::from(0));
117            env.insert("method".into(), Value::from(method));
118            if let Some(p) = params {
119                env.insert("params".into(), p);
120            }
121            g.transport.send(env)?;
122            (sid, rx)
123        };
124        let _ = sid;
125        rx.await.map_err(|_| Error::Closed)?
126    }
127
128    pub async fn stream(
129        &self,
130        method: &str,
131        params: Option<Value>,
132        credits: u64,
133    ) -> Result<ClientStream> {
134        let (sid, chunk_rx, end_notify) = {
135            let mut g = self.inner.lock().await;
136            let sid = Self::next_stream_id_locked(&mut g);
137            let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
138            let end_notify = Arc::new(Notify::new());
139            g.client_streams.insert(
140                sid,
141                ClientStreamState {
142                    chunk_tx,
143                    end_notify: end_notify.clone(),
144                    ended: false,
145                    granted: credits,
146                    emitted: 0,
147                    initial_credits: credits,
148                },
149            );
150            let mut env = Envelope::new();
151            env.insert("stream_id".into(), Value::from(sid));
152            env.insert("type".into(), Value::from("req"));
153            env.insert("seq".into(), Value::from(0));
154            env.insert("method".into(), Value::from(method));
155            env.insert("credits".into(), Value::from(credits));
156            if let Some(p) = params {
157                env.insert("params".into(), p);
158            }
159            g.transport.send(env)?;
160            (sid, chunk_rx, end_notify)
161        };
162        Ok(ClientStream {
163            session: self.clone(),
164            sid,
165            chunk_rx,
166            end_notify,
167            initial_credits: credits,
168        })
169    }
170
171    /// Handle a single inbound envelope. Long-running work (handler invocation,
172    /// server stream pumping) is spawned as separate tokio tasks.
173    pub async fn dispatch(&self, env: Envelope) -> Result<()> {
174        let sid = env
175            .get("stream_id")
176            .and_then(|v| v.as_u64())
177            .ok_or_else(|| Error::InvalidEnvelope("missing stream_id".into()))?;
178        let t = env
179            .get("type")
180            .and_then(|v| v.as_str())
181            .ok_or_else(|| Error::InvalidEnvelope("missing type".into()))?
182            .to_string();
183
184        match t.as_str() {
185            "req" => {
186                let method = env
187                    .get("method")
188                    .and_then(|v| v.as_str())
189                    .unwrap_or("")
190                    .to_string();
191                let params = env.get("params").cloned().unwrap_or(Value::Null);
192                let initial_credits = env.get("credits").and_then(|v| v.as_u64()).unwrap_or(0);
193
194                let (handler, transport) = {
195                    let g = self.inner.lock().await;
196                    (g.handlers.get(&method).cloned(), g.transport.clone())
197                };
198                let handler = match handler {
199                    Some(h) => h,
200                    None => {
201                        let mut err = Envelope::new();
202                        err.insert("stream_id".into(), Value::from(sid));
203                        err.insert("type".into(), Value::from("error"));
204                        err.insert("seq".into(), Value::from(0));
205                        err.insert(
206                            "error".into(),
207                            serde_json::json!({
208                                "code": -32601,
209                                "message": format!("method not found: {method}"),
210                            }),
211                        );
212                        let _ = transport.send(err);
213                        return Ok(());
214                    }
215                };
216
217                let session = self.clone();
218                tokio::spawn(async move {
219                    match handler(params).await {
220                        Ok(HandlerOutput::Unary(value)) => {
221                            let mut env = Envelope::new();
222                            env.insert("stream_id".into(), Value::from(sid));
223                            env.insert("type".into(), Value::from("res"));
224                            env.insert("seq".into(), Value::from(0));
225                            env.insert("result".into(), value);
226                            let _ = transport.send(env);
227                        }
228                        Ok(HandlerOutput::Stream(stream)) => {
229                            session
230                                .run_server_stream(sid, stream, initial_credits)
231                                .await;
232                        }
233                        Err(e) => {
234                            let mut env = Envelope::new();
235                            env.insert("stream_id".into(), Value::from(sid));
236                            env.insert("type".into(), Value::from("error"));
237                            env.insert("seq".into(), Value::from(0));
238                            env.insert(
239                                "error".into(),
240                                serde_json::json!({
241                                    "code": -32000,
242                                    "message": e.to_string(),
243                                }),
244                            );
245                            let _ = transport.send(env);
246                        }
247                    }
248                });
249            }
250            "res" => {
251                let mut g = self.inner.lock().await;
252                // Credit grant for an active server stream takes priority.
253                if let Some(ctl) = g.server_streams.get(&sid) {
254                    let n = env.get("credits").and_then(|v| v.as_u64()).unwrap_or(0);
255                    let _ = ctl.grant_tx.send(n);
256                    return Ok(());
257                }
258                if let Some(p) = g.pending.remove(&sid) {
259                    let result = env.get("result").cloned().unwrap_or(Value::Null);
260                    let _ = p.tx.send(Ok(result));
261                }
262            }
263            "error" => {
264                let mut g = self.inner.lock().await;
265                if let Some(p) = g.pending.remove(&sid) {
266                    let err = env.get("error").cloned().unwrap_or(Value::Null);
267                    let code = err.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
268                    let msg = err
269                        .get("message")
270                        .and_then(|v| v.as_str())
271                        .unwrap_or("unknown error")
272                        .to_string();
273                    let _ = p.tx.send(Err(Error::Rpc { code, message: msg }));
274                }
275            }
276            "stream_chunk" => {
277                let mut g = self.inner.lock().await;
278                if let Some(s) = g.client_streams.get_mut(&sid) {
279                    let result = env.get("result").cloned().unwrap_or(Value::Null);
280                    let _ = s.chunk_tx.send(result);
281                }
282            }
283            "cancel" => {
284                let g = self.inner.lock().await;
285                if let Some(ctl) = g.server_streams.get(&sid) {
286                    let _ = ctl.cancel_tx.send(());
287                }
288            }
289            "stream_end" => {
290                let mut g = self.inner.lock().await;
291                if let Some(s) = g.client_streams.remove(&sid) {
292                    s.end_notify.notify_waiters();
293                    // chunk_tx is dropped → receiver side will see "no more".
294                    drop(s);
295                }
296            }
297            _ => {}
298        }
299        Ok(())
300    }
301
302    async fn run_server_stream(
303        &self,
304        sid: u64,
305        mut src: Pin<Box<dyn Stream<Item = Value> + Send>>,
306        initial_credits: u64,
307    ) {
308        use futures_util::StreamExt;
309
310        let (grant_tx, mut grant_rx) = mpsc::unbounded_channel::<u64>();
311        let (cancel_tx, mut cancel_rx) = mpsc::unbounded_channel::<()>();
312        let transport = {
313            let mut g = self.inner.lock().await;
314            g.server_streams.insert(
315                sid,
316                ServerStreamCtl {
317                    grant_tx,
318                    cancel_tx,
319                },
320            );
321            g.transport.clone()
322        };
323
324        let mut granted = initial_credits;
325        let mut seq: u64 = 0;
326        let mut cancelled = false;
327
328        'outer: loop {
329            while granted == 0 && !cancelled {
330                tokio::select! {
331                    Some(n) = grant_rx.recv() => { granted += n; }
332                    Some(_) = cancel_rx.recv() => { cancelled = true; break; }
333                    else => { break 'outer; }
334                }
335            }
336            if cancelled {
337                break;
338            }
339            // Pull next item — but also watch for cancel concurrently.
340            tokio::select! {
341                next = src.next() => {
342                    let value = match next { Some(v) => v, None => break };
343                    granted = granted.saturating_sub(1);
344                    let mut env = Envelope::new();
345                    env.insert("stream_id".into(), Value::from(sid));
346                    env.insert("type".into(), Value::from("stream_chunk"));
347                    env.insert("seq".into(), Value::from(seq));
348                    env.insert("result".into(), value);
349                    if transport.send(env).is_err() { break; }
350                    seq += 1;
351                }
352                Some(n) = grant_rx.recv() => { granted += n; }
353                Some(_) = cancel_rx.recv() => { cancelled = true; break; }
354            }
355        }
356
357        let mut end = Envelope::new();
358        end.insert("stream_id".into(), Value::from(sid));
359        end.insert("type".into(), Value::from("stream_end"));
360        end.insert("seq".into(), Value::from(seq));
361        end.insert(
362            "reason".into(),
363            Value::from(if cancelled { "cancelled" } else { "ok" }),
364        );
365        let _ = transport.send(end);
366
367        let mut g = self.inner.lock().await;
368        g.server_streams.remove(&sid);
369    }
370
371    async fn cancel_client_stream(&self, sid: u64) {
372        let transport = {
373            let mut g = self.inner.lock().await;
374            g.client_streams.remove(&sid);
375            g.transport.clone()
376        };
377        let mut env = Envelope::new();
378        env.insert("stream_id".into(), Value::from(sid));
379        env.insert("type".into(), Value::from("cancel"));
380        env.insert("seq".into(), Value::from(0));
381        let _ = transport.send(env);
382    }
383}
384
385impl Drop for Session {
386    fn drop(&mut self) {
387        // Channels close naturally when last Session reference goes away.
388    }
389}
390
391/// Drives the chunk-receiving side of a server-sent stream. Implements
392/// `futures::Stream<Item = serde_json::Value>` and emits credit grants
393/// automatically as the client consumes chunks.
394pub struct ClientStream {
395    session: Session,
396    sid: u64,
397    chunk_rx: mpsc::UnboundedReceiver<Value>,
398    end_notify: Arc<Notify>,
399    initial_credits: u64,
400}
401
402impl ClientStream {
403    pub async fn next(&mut self) -> Option<Value> {
404        // Auto-refresh credits halfway through the grant.
405        {
406            let mut g = self.session.inner.lock().await;
407            if let Some(s) = g.client_streams.get_mut(&self.sid) {
408                if !s.ended && s.emitted + 1 >= s.granted.saturating_sub(s.initial_credits / 2) {
409                    s.granted += s.initial_credits;
410                    let mut env = Envelope::new();
411                    env.insert("stream_id".into(), Value::from(self.sid));
412                    env.insert("type".into(), Value::from("res"));
413                    env.insert("seq".into(), Value::from(0));
414                    env.insert("credits".into(), Value::from(s.initial_credits));
415                    let _ = g.transport.send(env);
416                }
417            }
418        }
419        let v = self.chunk_rx.recv().await;
420        if v.is_some() {
421            let mut g = self.session.inner.lock().await;
422            if let Some(s) = g.client_streams.get_mut(&self.sid) {
423                s.emitted += 1;
424            }
425        }
426        v
427    }
428
429    pub async fn cancel(&mut self) {
430        self.session.cancel_client_stream(self.sid).await;
431    }
432}
433
434impl Drop for ClientStream {
435    fn drop(&mut self) {
436        // best-effort cancel; ignore if session already gone.
437        let session = self.session.clone();
438        let sid = self.sid;
439        tokio::spawn(async move {
440            session.cancel_client_stream(sid).await;
441        });
442    }
443}
444
445impl Stream for ClientStream {
446    type Item = Value;
447    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
448        // Light-weight implementation that doesn't auto-refill credits via the
449        // poll path. Callers using `.next().await` get the full behaviour via
450        // the inherent method above. This impl exists for `StreamExt` users.
451        let this = self.get_mut();
452        this.chunk_rx.poll_recv(cx)
453    }
454}
455
456// Unused fields silence warnings.
457#[allow(dead_code)]
458fn _force_use(role: Role, n: &Notify) {
459    let _ = role;
460    let _ = n;
461}