use std::convert::Infallible;
use axum::response::sse::Event;
use futures::stream::{self, Stream};
use serde_json::json;
use super::error::ProtocolError;
use super::state::SessionRuntime;
use super::wire::Frame;
use super::ws::{apply_outgoing, next_outgoing_frame, PeerRole};
/// SSE `event:` name used for `Frame::Send` (payload: `payload_type` + `data`).
const SSE_EVENT_SEND: &str = "axon.send";
/// SSE `event:` name used for `Frame::Select` (payload: `label`).
const SSE_EVENT_SELECT: &str = "axon.select";
/// SSE `event:` name used for `Frame::End` (payload: empty JSON object).
const SSE_EVENT_END: &str = "axon.end";
/// SSE `event:` name used for `Frame::Error` (payload: `code` + `detail`).
const SSE_EVENT_ERROR: &str = "axon.error";
pub(super) fn frame_to_sse_event(frame: &Frame) -> Event {
match frame {
Frame::Send { payload_type, data } => Event::default()
.event(SSE_EVENT_SEND)
.data(json!({ "payload_type": payload_type, "data": data }).to_string()),
Frame::Select { label } => Event::default()
.event(SSE_EVENT_SELECT)
.data(json!({ "label": label }).to_string()),
Frame::End => Event::default()
.event(SSE_EVENT_END)
.data("{}".to_string()),
Frame::Error { code, detail } => Event::default()
.event(SSE_EVENT_ERROR)
.data(json!({ "code": code, "detail": detail }).to_string()),
}
}
/// Builds the SSE event stream for a session driven as the server peer.
///
/// The stream first runs `preflight_polarity_check`. If the runtime's schema
/// does not project to SSE, the stream yields exactly one `axon.error` event
/// and then closes.
///
/// Otherwise each poll advances the runtime by one step (`step_runtime`) and
/// yields the corresponding event, until one of these happens:
///
/// * the runtime reports `is_complete()`: an `axon.end` event is yielded,
///   then the stream closes;
/// * applying an outgoing frame fails: an `axon.error` event is yielded, then
///   the stream closes;
/// * `next_outgoing_frame` returns `None`: the stream closes with no further
///   event.
///
/// A schema that never completes (e.g. a recursive send loop) yields events
/// indefinitely. Every item is `Ok`, because failures are reported in-band as
/// error events.
pub fn drive_sse_producer(
    runtime: SessionRuntime,
) -> impl Stream<Item = Result<Event, Infallible>> + Send + 'static {
    // Evaluate the preflight before the runtime moves into the stream state.
    let preflight = preflight_polarity_check(&runtime);
    let init = (WalkerState { runtime, done: false }, preflight);
    stream::unfold(init, |(mut state, preflight)| async move {
        if state.done {
            return None;
        }
        // A failed preflight is reported once, before the runtime is touched.
        if let Some(err) = preflight {
            state.done = true;
            return Some((Ok(error_event(&err)), (state, None)));
        }
        let event = match step_runtime(&mut state.runtime).await {
            StepOutcome::Event(event, becomes_done) => {
                state.done = becomes_done;
                event
            }
            StepOutcome::Error(err) => {
                state.done = true;
                error_event(&err)
            }
            StepOutcome::Done => return None,
        };
        Some((Ok(event), (state, None)))
    })
}

/// Per-stream state threaded through `stream::unfold`.
#[derive(Debug)]
struct WalkerState {
    /// The session runtime being driven (as `PeerRole::Server`).
    runtime: SessionRuntime,
    /// Set once a terminal event (end or error) has been yielded; the next
    /// poll then closes the stream.
    done: bool,
}

/// Result of advancing the runtime by at most one outgoing frame.
#[derive(Debug)]
enum StepOutcome {
    /// An event to yield. The flag is `true` when the event is terminal, i.e.
    /// the `End` event emitted once the runtime is complete.
    Event(Event, bool),
    /// Applying an outgoing frame failed; reported as a single error event.
    Error(ProtocolError),
    /// The runtime produced no outgoing frame; the stream closes silently.
    Done,
}

/// Returns an error when the runtime's schema does not project to SSE
/// (`projects_to_sse()` is false), and `None` when it does.
fn preflight_polarity_check(runtime: &SessionRuntime) -> Option<ProtocolError> {
    if runtime.schema().projects_to_sse() {
        None
    } else {
        Some(ProtocolError::UnexpectedFrame {
            cursor_kind: "non-sse-polarity-schema",
            frame_kind: "sse-projection-requested",
        })
    }
}

/// Advances the runtime by at most one outgoing frame.
///
/// Once the runtime is complete this returns a terminal `End` event without
/// consulting `next_outgoing_frame`. Otherwise it takes the next outgoing
/// frame, applies it as the server peer, and returns that frame's event. If
/// applying the frame fails, it returns the error instead.
///
/// `apply_outgoing` is awaited directly rather than wrapped in
/// `futures::executor::block_on`. This function runs inside a stream polled by
/// an async runtime (e.g. tokio), and blocking that thread on a nested
/// executor stalls the worker. It can also deadlock if `apply_outgoing` waits
/// on anything the outer runtime drives.
async fn step_runtime(runtime: &mut SessionRuntime) -> StepOutcome {
    if runtime.is_complete() {
        return StepOutcome::Event(frame_to_sse_event(&Frame::End), true);
    }
    let Some(frame) = next_outgoing_frame(runtime) else {
        return StepOutcome::Done;
    };
    match apply_outgoing(runtime, &frame, PeerRole::Server).await {
        Ok(_) => StepOutcome::Event(frame_to_sse_event(&frame), false),
        Err(e) => StepOutcome::Error(e),
    }
}
fn error_event(err: &ProtocolError) -> Event {
let frame = Frame::Error {
code: err.code().to_string(),
detail: err.to_string(),
};
frame_to_sse_event(&frame)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axon_frontend::session::SessionType;
    use futures::StreamExt;

    // Stream items are `Result<Event, Infallible>`, so `is_ok()` only confirms
    // that an item arrived. The item counts and the final `None` checks are
    // what pin the stream's shape.

    #[tokio::test]
    async fn producer_fragment_stream_emits_one_event_per_step_then_closes() {
        let schema = SessionType::send("A", SessionType::End);
        let mut stream = Box::pin(drive_sse_producer(SessionRuntime::new(schema, None)));
        assert!(stream.next().await.expect("first event").is_ok());
        assert!(stream.next().await.expect("second event").is_ok());
        assert!(stream.next().await.is_none(), "stream should close after End");
    }

    #[tokio::test]
    async fn recursive_token_stream_emits_indefinitely() {
        let schema = SessionType::rec("X", SessionType::send("Token", SessionType::var("X")));
        let mut stream = Box::pin(drive_sse_producer(SessionRuntime::new(schema, None)));
        for i in 0..16 {
            let item = stream
                .next()
                .await
                .unwrap_or_else(|| panic!("token #{i} should arrive"));
            assert!(item.is_ok(), "token #{i} should be Ok");
        }
    }

    #[tokio::test]
    async fn non_sse_polarity_short_circuits_with_one_error_event() {
        let schema = SessionType::send("Q", SessionType::recv("Ack", SessionType::End));
        let mut stream = Box::pin(drive_sse_producer(SessionRuntime::new(schema, None)));
        assert!(stream.next().await.expect("error event").is_ok());
        assert!(stream.next().await.is_none(), "stream should close after preflight error");
    }

    #[tokio::test]
    async fn credit_exhaustion_at_runtime_emits_one_event_then_one_error_then_closes() {
        let schema = SessionType::send("A", SessionType::send("B", SessionType::End));
        let mut stream = Box::pin(drive_sse_producer(SessionRuntime::new(schema, Some(1))));
        assert!(stream.next().await.expect("send A").is_ok());
        assert!(stream.next().await.expect("error after credit exhaustion").is_ok());
        assert!(stream.next().await.is_none(), "stream should close after error");
    }

    #[test]
    fn frame_to_sse_event_is_total_over_the_closed_frame_catalog() {
        // A plain array suffices; it is only iterated (clippy::useless_vec).
        let cases = [
            Frame::Send { payload_type: "T".into(), data: serde_json::json!(null) },
            Frame::Select { label: "a".into() },
            Frame::End,
            Frame::Error { code: "c".into(), detail: "d".into() },
        ];
        for case in &cases {
            let _ = frame_to_sse_event(case);
        }
    }
}