use std::collections::BTreeMap;
use axon_frontend::session::{Payload, SessionType};
use serde::{Deserialize, Serialize};
use super::error::ProtocolError;
/// Credit-based flow-control window: a fixed `budget` plus the number of
/// credits currently `available` to spend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct CreditWindow {
    /// Ceiling that `refill` never raises `available` past.
    pub budget: u64,
    /// Credits that can still be spent right now.
    pub available: u64,
}

impl CreditWindow {
    /// Builds a window whose whole `budget` is available.
    pub fn new(budget: u64) -> Self {
        Self {
            available: budget,
            budget,
        }
    }

    /// Spends one credit and returns how many remain afterwards.
    ///
    /// Returns `None`, leaving the window untouched, when no credit is left.
    fn try_consume(&mut self) -> Option<u64> {
        let remaining = self.available.checked_sub(1)?;
        self.available = remaining;
        Some(remaining)
    }

    /// Gives back one credit, unless `available` is already at (or above)
    /// `budget`, in which case nothing changes.
    fn refill(&mut self) {
        if self.available >= self.budget {
            return;
        }
        self.available += 1;
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionRuntime {
schema: SessionType,
cursor: SessionType,
credit: Option<CreditWindow>,
}
impl SessionRuntime {
pub fn new(schema: SessionType, budget: Option<u64>) -> Self {
let cursor = schema.unfold_head();
Self {
schema,
cursor,
credit: budget.map(CreditWindow::new),
}
}
pub fn schema(&self) -> &SessionType {
&self.schema
}
pub fn cursor(&self) -> &SessionType {
&self.cursor
}
pub fn credit(&self) -> Option<CreditWindow> {
self.credit
}
pub fn is_complete(&self) -> bool {
matches!(self.cursor, SessionType::End)
}
pub fn seal(&self) -> Option<SealedRuntime> {
if self.is_complete() {
return None;
}
Some(SealedRuntime {
version: SEALED_RUNTIME_VERSION,
schema: self.schema.clone(),
cursor: self.cursor.clone(),
credit: self.credit,
})
}
pub fn resume(
sealed: SealedRuntime,
declared_schema: &SessionType,
) -> Result<Self, ResumeError> {
if sealed.version != SEALED_RUNTIME_VERSION {
return Err(ResumeError::UnsupportedVersion(sealed.version));
}
if !sealed.schema.equiv(declared_schema) {
return Err(ResumeError::SchemaMismatch);
}
Ok(SessionRuntime {
schema: sealed.schema,
cursor: sealed.cursor,
credit: sealed.credit,
})
}
pub fn try_send(&mut self, got: &str) -> Result<(), ProtocolError> {
if self.is_complete() {
return Err(ProtocolError::AlreadyComplete { frame_kind: "send" });
}
let (expected_payload, cont) = match &self.cursor {
SessionType::Send { payload, cont, .. } => (payload.clone(), cont.clone()),
other => {
return Err(ProtocolError::UnexpectedFrame {
cursor_kind: kind_of(other),
frame_kind: "send",
});
}
};
let got_payload = Payload::new(got);
if expected_payload != got_payload {
return Err(ProtocolError::PayloadMismatch {
expected: expected_payload,
got: got_payload,
});
}
if let Some(w) = self.credit.as_mut() {
if w.try_consume().is_none() {
return Err(ProtocolError::CreditExhausted {
payload: expected_payload,
budget: w.budget,
});
}
}
self.advance(*cont);
Ok(())
}
pub fn try_recv(&mut self, got: &str) -> Result<(), ProtocolError> {
if self.is_complete() {
return Err(ProtocolError::AlreadyComplete { frame_kind: "send" });
}
let (expected_payload, cont) = match &self.cursor {
SessionType::Recv { payload, cont, .. } => (payload.clone(), cont.clone()),
other => {
return Err(ProtocolError::UnexpectedFrame {
cursor_kind: kind_of(other),
frame_kind: "send",
});
}
};
let got_payload = Payload::new(got);
if expected_payload != got_payload {
return Err(ProtocolError::PayloadMismatch {
expected: expected_payload,
got: got_payload,
});
}
if let Some(w) = self.credit.as_mut() {
w.refill();
}
self.advance(*cont);
Ok(())
}
pub fn try_select(&mut self, label: &str) -> Result<(), ProtocolError> {
self.advance_into_arm(label, true)
}
pub fn try_offer(&mut self, label: &str) -> Result<(), ProtocolError> {
self.advance_into_arm(label, false)
}
pub fn try_end(&mut self) -> Result<(), ProtocolError> {
match &self.cursor {
SessionType::End => Ok(()),
other => Err(ProtocolError::UnexpectedFrame {
cursor_kind: kind_of(other),
frame_kind: "end",
}),
}
}
fn advance(&mut self, next: SessionType) {
self.cursor = next.unfold_head();
}
fn advance_into_arm(&mut self, label: &str, internal: bool) -> Result<(), ProtocolError> {
if self.is_complete() {
return Err(ProtocolError::AlreadyComplete {
frame_kind: if internal { "select" } else { "branch" },
});
}
let arms = match (&self.cursor, internal) {
(SessionType::Select(m), true) => m.clone(),
(SessionType::Branch(m), false) => m.clone(),
(other, _) => {
return Err(ProtocolError::UnexpectedFrame {
cursor_kind: kind_of(other),
frame_kind: if internal { "select" } else { "select" },
});
}
};
match arms.get(label) {
Some(cont) => {
let cont = cont.clone();
self.advance(cont);
Ok(())
}
None => Err(ProtocolError::UnknownLabel {
label: label.to_string(),
expected: keys_of(&arms),
}),
}
}
}
/// Short tag naming the constructor at the head of `t`; used as the
/// `cursor_kind` in protocol-error diagnostics.
fn kind_of(t: &SessionType) -> &'static str {
    match t {
        SessionType::Send { .. } => "send",
        SessionType::Recv { .. } => "recv",
        SessionType::Select(_) => "select",
        SessionType::Branch(_) => "branch",
        SessionType::Rec(..) => "rec",
        SessionType::Var(_) => "var",
        SessionType::End => "end",
    }
}
/// The labels of a choice's arms, in the map's (sorted) key order.
fn keys_of(m: &BTreeMap<String, SessionType>) -> Vec<String> {
    m.keys().map(String::to_owned).collect::<Vec<_>>()
}
/// Envelope format version stamped by [`SessionRuntime::seal`].
/// [`SessionRuntime::resume`] accepts only envelopes carrying exactly this value.
pub const SEALED_RUNTIME_VERSION: u8 = 1;
/// Serializable snapshot of a [`SessionRuntime`] taken mid-protocol. It is
/// produced by [`SessionRuntime::seal`] and consumed by
/// [`SessionRuntime::resume`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SealedRuntime {
    /// Envelope format version; see [`SEALED_RUNTIME_VERSION`].
    pub version: u8,
    /// The full protocol the runtime was checking against.
    pub schema: SessionType,
    /// The residual protocol at the moment of sealing.
    pub cursor: SessionType,
    /// The live credit window (not just the budget), if flow control was on.
    pub credit: Option<CreditWindow>,
}
impl SealedRuntime {
    /// Encodes the envelope as JSON bytes.
    ///
    /// # Panics
    ///
    /// Panics if `serde_json` reports a serialization error. The envelope's
    /// fields are expected to always be JSON-encodable.
    pub fn to_bytes(&self) -> Vec<u8> {
        let encoded = serde_json::to_vec(self);
        encoded.expect("SealedRuntime ⇒ JSON is total")
    }

    /// Decodes an envelope previously written by [`SealedRuntime::to_bytes`].
    ///
    /// # Errors
    ///
    /// [`ResumeError::Malformed`] carrying the decoder's message when `b` is
    /// not a valid JSON envelope.
    pub fn from_bytes(b: &[u8]) -> Result<Self, ResumeError> {
        match serde_json::from_slice::<Self>(b) {
            Ok(sealed) => Ok(sealed),
            Err(err) => Err(ResumeError::Malformed(err.to_string())),
        }
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResumeError {
UnsupportedVersion(u8),
SchemaMismatch,
Malformed(String),
}
impl std::fmt::Display for ResumeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ResumeError::UnsupportedVersion(v) => write!(
f,
"sealed runtime envelope version {v} is newer than this runtime supports \
(current = {SEALED_RUNTIME_VERSION})"
),
ResumeError::SchemaMismatch => f.write_str(
"sealed runtime's declared protocol does not match the live socket's session type \
— the protocol drifted between seal and resume",
),
ResumeError::Malformed(detail) => write!(f, "sealed runtime envelope is malformed: {detail}"),
}
}
}
impl std::error::Error for ResumeError {}
#[cfg(test)]
mod tests {
    // Unit tests covering the credit window, each step function, the
    // seal/resume envelope, and an end-to-end recursive dialogue.
    use super::*;
    // Consume drops `available` to zero and then refuses; refill climbs back
    // up but never past `budget`.
    #[test]
    fn credit_window_decrements_and_refills_within_budget() {
        let mut w = CreditWindow::new(2);
        assert_eq!(w.try_consume(), Some(1));
        assert_eq!(w.try_consume(), Some(0));
        assert!(w.try_consume().is_none());
        w.refill();
        assert_eq!(w.available, 1);
        w.refill();
        assert_eq!(w.available, 2);
        w.refill();
        assert_eq!(w.available, 2);
    }
    #[test]
    fn try_send_advances_on_matching_payload() {
        let schema = SessionType::send("Msg", SessionType::End);
        let mut r = SessionRuntime::new(schema, None);
        r.try_send("Msg").expect("step");
        assert!(r.is_complete());
    }
    // A rejected send must leave the cursor where it was.
    #[test]
    fn try_send_rejects_wrong_payload() {
        let schema = SessionType::send("Msg", SessionType::End);
        let mut r = SessionRuntime::new(schema, None);
        match r.try_send("WrongType") {
            Err(ProtocolError::PayloadMismatch { expected, got }) => {
                assert_eq!(expected, Payload::new("Msg"));
                assert_eq!(got, Payload::new("WrongType"));
            }
            other => panic!("expected PayloadMismatch, got {other:?}"),
        }
        assert!(matches!(r.cursor(), SessionType::Send { .. }));
    }
    // Only `cursor_kind` is pinned here; `frame_kind` is matched with `..`.
    #[test]
    fn try_recv_rejects_when_cursor_is_send() {
        let schema = SessionType::send("Msg", SessionType::End);
        let mut r = SessionRuntime::new(schema, None);
        match r.try_recv("Msg") {
            Err(ProtocolError::UnexpectedFrame { cursor_kind: "send", .. }) => {}
            other => panic!("expected UnexpectedFrame(send→send), got {other:?}"),
        }
    }
    // With a budget of 1, the second consecutive send has no credit left.
    #[test]
    fn credit_exhaustion_blocks_send_at_zero() {
        let schema = SessionType::send("A", SessionType::send("B", SessionType::End));
        let mut r = SessionRuntime::new(schema, Some(1));
        r.try_send("A").expect("first send");
        assert_eq!(r.credit().unwrap().available, 0);
        match r.try_send("B") {
            Err(ProtocolError::CreditExhausted { payload, budget: 1 }) => {
                assert_eq!(payload, Payload::new("B"));
            }
            other => panic!("expected CreditExhausted, got {other:?}"),
        }
    }
    // A receive between two sends returns the credit the first send spent.
    #[test]
    fn recv_refills_credit_capped_at_budget() {
        let schema = SessionType::send(
            "A",
            SessionType::recv("Ack", SessionType::send("B", SessionType::End)),
        );
        let mut r = SessionRuntime::new(schema, Some(1));
        r.try_send("A").expect("send A");
        assert_eq!(r.credit().unwrap().available, 0);
        r.try_recv("Ack").expect("recv Ack refills");
        assert_eq!(r.credit().unwrap().available, 1);
        r.try_send("B").expect("send B uses refilled credit");
        assert!(r.is_complete());
    }
    #[test]
    fn select_advances_into_named_arm() {
        let schema = SessionType::select([
            ("ask".into(), SessionType::send("Q", SessionType::End)),
            ("quit".into(), SessionType::End),
        ]);
        let mut r = SessionRuntime::new(schema, None);
        r.try_select("ask").expect("select ask");
        assert!(matches!(r.cursor(), SessionType::Send { .. }));
        r.try_send("Q").expect("send Q");
        assert!(r.is_complete());
    }
    // The error lists the valid labels in sorted (BTreeMap) order.
    #[test]
    fn select_rejects_unknown_label() {
        let schema = SessionType::select([
            ("ask".into(), SessionType::End),
            ("quit".into(), SessionType::End),
        ]);
        let mut r = SessionRuntime::new(schema, None);
        match r.try_select("nope") {
            Err(ProtocolError::UnknownLabel { label, expected }) => {
                assert_eq!(label, "nope");
                assert_eq!(expected, vec!["ask".to_string(), "quit".to_string()]);
            }
            other => panic!("expected UnknownLabel, got {other:?}"),
        }
    }
    #[test]
    fn offer_advances_into_peer_selected_arm() {
        let schema = SessionType::branch([
            ("ack".into(), SessionType::End),
            ("err".into(), SessionType::End),
        ]);
        let mut r = SessionRuntime::new(schema, None);
        r.try_offer("ack").expect("offer ack");
        assert!(r.is_complete());
    }
    // Looping through `var("X")` re-enters the body each round; with a budget
    // of 1 this also checks that each Ack refills the credit for the next send.
    #[test]
    fn recursion_unfolds_one_step_at_a_time() {
        let schema = SessionType::rec(
            "X",
            SessionType::send("A", SessionType::recv("Ack", SessionType::var("X"))),
        );
        let mut r = SessionRuntime::new(schema, Some(1));
        for _ in 0..5 {
            r.try_send("A").expect("send");
            r.try_recv("Ack").expect("recv");
        }
        assert!(matches!(r.cursor(), SessionType::Send { .. }));
        assert!(!r.is_complete());
    }
    #[test]
    fn post_end_traffic_is_rejected() {
        let mut r = SessionRuntime::new(SessionType::End, None);
        r.try_end().expect("end on End is OK");
        match r.try_send("X") {
            Err(ProtocolError::AlreadyComplete { frame_kind: "send" }) => {}
            other => panic!("expected AlreadyComplete, got {other:?}"),
        }
    }
    #[test]
    fn seal_returns_none_at_end_and_some_otherwise() {
        let schema = SessionType::send("A", SessionType::End);
        let r = SessionRuntime::new(schema.clone(), None);
        let sealed = r.seal().expect("snapshot at non-End cursor");
        assert_eq!(sealed.version, SEALED_RUNTIME_VERSION);
        assert_eq!(sealed.schema, schema);
        assert!(matches!(sealed.cursor, SessionType::Send { .. }));
        let mut r = SessionRuntime::new(schema, None);
        r.try_send("A").unwrap();
        assert!(r.is_complete());
        assert!(r.seal().is_none(), "no snapshot once cursor is at End");
    }
    #[test]
    fn seal_carries_live_credit_window_not_just_budget() {
        let schema = SessionType::send("A", SessionType::send("B", SessionType::End));
        let mut r = SessionRuntime::new(schema, Some(2));
        r.try_send("A").unwrap();
        let sealed = r.seal().expect("snapshot mid-protocol");
        assert_eq!(sealed.credit, Some(CreditWindow { budget: 2, available: 1 }));
    }
    // seal → bytes → from_bytes → resume must reproduce cursor and credit.
    #[test]
    fn resume_round_trips_through_seal_then_unbinds_to_the_same_cursor() {
        let schema = SessionType::recv(
            "Msg",
            SessionType::send("Ack", SessionType::End),
        );
        let r0 = SessionRuntime::new(schema.clone(), None);
        let sealed = r0.seal().expect("snapshot before recv");
        let bytes = sealed.to_bytes();
        let recovered = SealedRuntime::from_bytes(&bytes).expect("parse");
        assert_eq!(recovered, sealed);
        let r1 = SessionRuntime::resume(recovered, &schema).expect("resume");
        assert_eq!(r1.cursor(), r0.cursor());
        assert_eq!(r1.credit(), r0.credit());
    }
    #[test]
    fn resume_after_partial_progress_continues_from_the_residual() {
        let schema = SessionType::send("A", SessionType::send("B", SessionType::End));
        let mut r0 = SessionRuntime::new(schema.clone(), Some(2));
        r0.try_send("A").unwrap();
        let bytes = r0.seal().unwrap().to_bytes();
        let recovered = SealedRuntime::from_bytes(&bytes).unwrap();
        let mut r1 = SessionRuntime::resume(recovered, &schema).unwrap();
        assert_eq!(r1.credit().unwrap().available, 1);
        r1.try_send("B").expect("send B from resumed cursor");
        assert!(r1.is_complete());
    }
    #[test]
    fn resume_rejects_a_schema_mismatch() {
        let schema_a = SessionType::send("A", SessionType::End);
        let schema_b = SessionType::send("B", SessionType::End);
        let r0 = SessionRuntime::new(schema_a.clone(), None);
        let sealed = r0.seal().unwrap();
        assert_eq!(
            SessionRuntime::resume(sealed.clone(), &schema_b).err(),
            Some(ResumeError::SchemaMismatch)
        );
        assert!(SessionRuntime::resume(sealed, &schema_a).is_ok());
    }
    #[test]
    fn resume_rejects_a_future_envelope_version() {
        let schema = SessionType::send("A", SessionType::End);
        let r = SessionRuntime::new(schema.clone(), None);
        let mut sealed = r.seal().unwrap();
        sealed.version = SEALED_RUNTIME_VERSION + 7;
        assert_eq!(
            SessionRuntime::resume(sealed, &schema).err(),
            Some(ResumeError::UnsupportedVersion(SEALED_RUNTIME_VERSION + 7))
        );
    }
    #[test]
    fn resume_rejects_malformed_envelope_bytes() {
        let garbage = b"{not valid JSON";
        assert!(matches!(
            SealedRuntime::from_bytes(garbage),
            Err(ResumeError::Malformed(_))
        ));
    }
    // Schemas that differ only in the name of the recursion variable are
    // expected to be accepted by `equiv`.
    #[test]
    fn resume_accepts_alpha_equivalent_schemas() {
        let schema_x = SessionType::rec("X", SessionType::send("T", SessionType::var("X")));
        let schema_y = SessionType::rec("Y", SessionType::send("T", SessionType::var("Y")));
        let r = SessionRuntime::new(schema_x.clone(), None);
        let sealed = r.seal().unwrap();
        assert!(SessionRuntime::resume(sealed, &schema_y).is_ok());
    }
    // Pins the top-level JSON keys of the envelope (each field is present,
    // even when `credit` is None).
    #[test]
    fn sealed_runtime_is_json_compatible_with_serde_roundtrip() {
        let schema = SessionType::send("X", SessionType::End);
        let r = SessionRuntime::new(schema, None);
        let bytes = r.seal().unwrap().to_bytes();
        let value: serde_json::Value =
            serde_json::from_slice(&bytes).expect("envelope is well-formed JSON");
        assert!(value.get("version").is_some());
        assert!(value.get("schema").is_some());
        assert!(value.get("cursor").is_some());
        assert!(value.get("credit").is_some());
    }
    // End-to-end client view: one full ask/token/recv round, a second ask
    // answered with "done", then End.
    #[test]
    fn realistic_chat_dialogue_runs_to_completion() {
        let schema = SessionType::rec(
            "X",
            SessionType::select([
                (
                    "ask".into(),
                    SessionType::send(
                        "Utterance",
                        SessionType::branch([
                            ("token".into(), SessionType::recv("Token", SessionType::var("X"))),
                            ("done".into(), SessionType::End),
                        ]),
                    ),
                ),
                ("cancel".into(), SessionType::End),
            ]),
        );
        let mut client = SessionRuntime::new(schema, Some(4));
        client.try_select("ask").unwrap();
        client.try_send("Utterance").unwrap();
        client.try_offer("token").unwrap();
        client.try_recv("Token").unwrap();
        client.try_select("ask").unwrap();
        client.try_send("Utterance").unwrap();
        client.try_offer("done").unwrap();
        client.try_end().unwrap();
        assert!(client.is_complete());
    }
}