use axum::extract::ws::{CloseFrame, Message, WebSocket};
use super::error::ProtocolError;
use super::state::SessionRuntime;
use super::wire::Frame;
/// Role a peer plays in a session: the server side or the client side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeerRole {
    /// Server side of the session.
    Server,
    /// Client side of the session.
    Client,
}

impl PeerRole {
    /// Returns the opposite role.
    ///
    /// Flipping twice yields the original role.
    pub fn flip(self) -> Self {
        match self {
            Self::Client => Self::Server,
            Self::Server => Self::Client,
        }
    }
}
/// WebSocket close status 1000 (normal closure); sent once the session completes.
const CLOSE_NORMAL: u16 = 1000;
/// WebSocket close status 1002 (protocol error); sent after reporting a protocol failure.
const CLOSE_PROTOCOL_ERROR: u16 = 1002;
/// WebSocket close status 1011 (internal error); sent when receiving from the socket fails.
const CLOSE_INTERNAL_ERROR: u16 = 1011;
/// Runs one session-typed protocol over `ws` until the session completes or fails.
///
/// Each iteration does the following:
///
/// 1. If the runtime reports completion, it sends an `End` frame, closes with
///    status 1000 and returns `Ok(())`.
/// 2. Otherwise, if the local cursor has a frame to emit (see
///    `next_outgoing_frame`), it applies that frame to the runtime first and
///    then sends it.
/// 3. Only when there is nothing to send does it wait for the peer's next
///    message.
///
/// `role` is forwarded to `apply_outgoing` / `apply_incoming`.
///
/// # Errors
///
/// - Any `ProtocolError` from applying an outgoing or incoming frame, or from
///   parsing an incoming text frame. An `Error` frame and a 1002 close are
///   sent first (best effort).
/// - `ProtocolError::MalformedFrame` if the peer sends a binary message. It is
///   reported the same way.
/// - `ProtocolError::Transport` in these cases:
///   - sending a frame fails;
///   - the socket yields an error (a 1011 close is attempted);
///   - the peer closes or ends the stream before the session is complete.
pub async fn drive(
    mut ws: WebSocket,
    mut runtime: SessionRuntime,
    role: PeerRole,
) -> Result<(), ProtocolError> {
    loop {
        if runtime.is_complete() {
            send_frame(&mut ws, &Frame::End).await?;
            close_normal(&mut ws).await;
            return Ok(());
        }

        // Our turn: validate the step against the local runtime before it goes on the wire.
        if let Some(out) = next_outgoing_frame(&runtime) {
            if let Err(e) = apply_outgoing(&mut runtime, &out, role).await {
                report_and_close(&mut ws, &e).await;
                return Err(e);
            }
            send_frame(&mut ws, &out).await?;
            continue;
        }

        // Peer's turn: wait for the next message.
        let msg = match ws.recv().await {
            Some(Ok(msg)) => msg,
            Some(Err(e)) => {
                let _ = close_internal(&mut ws).await;
                return Err(ProtocolError::Transport(e.to_string()));
            }
            None => {
                return Err(ProtocolError::Transport("peer closed mid-protocol".into()));
            }
        };

        match msg {
            Message::Text(text) => {
                let frame = match Frame::from_wire(&text) {
                    Ok(f) => f,
                    Err(e) => {
                        report_and_close(&mut ws, &e).await;
                        return Err(e);
                    }
                };
                if let Err(e) = apply_incoming(&mut runtime, frame, role) {
                    report_and_close(&mut ws, &e).await;
                    return Err(e);
                }
            }
            Message::Binary(_) => {
                let e = ProtocolError::MalformedFrame(
                    "binary frame received on a text-only session-typed channel".into(),
                );
                report_and_close(&mut ws, &e).await;
                return Err(e);
            }
            // Ping/pong are transport keep-alives, not protocol steps. axum's
            // WebSocket (via tungstenite) already answers every Ping with a
            // Pong, so replying here as well would only send a duplicate Pong.
            Message::Ping(_) | Message::Pong(_) => {}
            Message::Close(_) => {
                // Defensive: completion is checked at the top of every iteration
                // and `runtime` is only borrowed immutably between there and here,
                // so this branch is not expected to be taken.
                if runtime.is_complete() {
                    return Ok(());
                }
                return Err(ProtocolError::Transport("peer closed mid-protocol".into()));
            }
        }
    }
}
/// Builds the frame this side should send next, if the cursor says it is our turn.
///
/// Returns:
/// - for a `Send` cursor, a `Frame::Send` carrying the payload type's string form
///   and a `null` data value;
/// - for a `Select` cursor, a `Frame::Select` with the first label yielded by the
///   arms' `keys()` iterator (the tests expect `"alpha"` ahead of `"zeta"`), or
///   `None` if there are no arms;
/// - for every other cursor (e.g. `End`, or the receive-side cursor the tests
///   build with `SessionType::recv`), `None`.
pub(super) fn next_outgoing_frame(runtime: &SessionRuntime) -> Option<Frame> {
    use axon_frontend::session::SessionType;

    match runtime.cursor() {
        SessionType::Send { payload, .. } => {
            let payload_type = payload.to_string();
            Some(Frame::Send {
                payload_type,
                data: serde_json::Value::Null,
            })
        }
        SessionType::Select(arms) => arms
            .keys()
            .next()
            .cloned()
            .map(|label| Frame::Select { label }),
        // Nothing for us to emit: either the session is over or the peer acts next.
        _ => None,
    }
}
pub(super) async fn apply_outgoing(
runtime: &mut SessionRuntime,
frame: &Frame,
_role: PeerRole,
) -> Result<(), ProtocolError> {
match frame {
Frame::Send { payload_type, .. } => runtime.try_send(payload_type),
Frame::Select { label } => runtime.try_select(label),
Frame::End => runtime.try_end(),
Frame::Error { .. } => Ok(()), }
}
fn apply_incoming(
runtime: &mut SessionRuntime,
frame: Frame,
_role: PeerRole,
) -> Result<(), ProtocolError> {
match frame {
Frame::Send { payload_type, .. } => runtime.try_recv(&payload_type),
Frame::Select { label } => runtime.try_offer(&label),
Frame::End => runtime.try_end(),
Frame::Error { code, detail } => Err(ProtocolError::Transport(format!(
"peer reported `{code}`: {detail}"
))),
}
}
/// Serializes `frame` with `to_wire` and sends it as a text message.
///
/// # Errors
///
/// Returns `ProtocolError::Transport` with the socket error's text if the send fails.
async fn send_frame(ws: &mut WebSocket, frame: &Frame) -> Result<(), ProtocolError> {
    let wire = frame.to_wire();
    let outcome = ws.send(Message::Text(wire.into())).await;
    outcome.map_err(|err| ProtocolError::Transport(err.to_string()))
}
async fn report_and_close(ws: &mut WebSocket, err: &ProtocolError) {
let frame = Frame::Error {
code: err.code().to_string(),
detail: err.to_string(),
};
let _ = ws.send(Message::Text(frame.to_wire().into())).await;
let _ = close_with(ws, CLOSE_PROTOCOL_ERROR, err.code()).await;
}
/// Closes `ws` with status 1000 and reason `"session_end"`.
///
/// Best effort: this runs only after the session has completed, so a failure to
/// deliver the close frame is deliberately ignored.
async fn close_normal(ws: &mut WebSocket) {
    let (code, reason) = (CLOSE_NORMAL, "session_end");
    let _ = close_with(ws, code, reason).await;
}
/// Closes `ws` with status 1011 and reason `"internal"`.
///
/// # Errors
///
/// Propagates the `ProtocolError` from `close_with` if the close frame cannot be sent.
async fn close_internal(ws: &mut WebSocket) -> Result<(), ProtocolError> {
    let reason = "internal";
    close_with(ws, CLOSE_INTERNAL_ERROR, reason).await
}
/// Largest close reason, in bytes, that fits in a close frame.
///
/// RFC 6455 §5.5 caps control-frame payloads at 125 bytes, and the 2-byte status
/// code uses part of that budget.
const MAX_CLOSE_REASON_BYTES: usize = 123;

/// Sends a close frame with status `code` and the human-readable `reason`.
///
/// `reason` is truncated on a UTF-8 character boundary to
/// `MAX_CLOSE_REASON_BYTES`. Without this, a long reason (e.g. an error code
/// passed through from `report_and_close`) would yield an oversized control
/// frame. Such a frame violates RFC 6455 and may make the peer fail the
/// connection.
///
/// # Errors
///
/// Returns `ProtocolError::Transport` with the socket error's text if the close
/// frame cannot be sent.
async fn close_with(ws: &mut WebSocket, code: u16, reason: &str) -> Result<(), ProtocolError> {
    let frame = CloseFrame {
        code,
        reason: clamp_close_reason(reason).to_string().into(),
    };
    ws.send(Message::Close(Some(frame)))
        .await
        .map_err(|e| ProtocolError::Transport(e.to_string()))
}

/// Returns the longest prefix of `reason` that is at most `MAX_CLOSE_REASON_BYTES`
/// bytes long and ends on a character boundary.
fn clamp_close_reason(reason: &str) -> &str {
    if reason.len() <= MAX_CLOSE_REASON_BYTES {
        return reason;
    }
    let mut end = MAX_CLOSE_REASON_BYTES;
    // Index 0 is always a char boundary, so this loop terminates.
    while !reason.is_char_boundary(end) {
        end -= 1;
    }
    &reason[..end]
}
#[cfg(test)]
mod tests {
    use super::*;
    use axon_frontend::session::SessionType;

    /// `flip` swaps the two roles, and flipping twice is the identity.
    #[test]
    fn peer_role_flip_is_involutive() {
        let cases = [
            (PeerRole::Server, PeerRole::Client),
            (PeerRole::Client, PeerRole::Server),
        ];
        for (role, expected) in cases {
            assert_eq!(role.flip(), expected);
        }
        assert_eq!(PeerRole::Server.flip().flip(), PeerRole::Server);
    }

    /// A `Send` cursor yields a `Send` frame carrying the payload type name.
    #[test]
    fn next_outgoing_frame_for_send_cursor() {
        let runtime = SessionRuntime::new(SessionType::send("Msg", SessionType::End), None);
        match next_outgoing_frame(&runtime) {
            Some(Frame::Send { payload_type, .. }) => assert_eq!(payload_type, "Msg"),
            other => panic!("expected Send frame for Send cursor, got {other:?}"),
        }
    }

    /// A receive-side cursor means the peer acts next, so nothing is emitted.
    #[test]
    fn next_outgoing_frame_for_recv_cursor_is_none() {
        let runtime = SessionRuntime::new(SessionType::recv("Msg", SessionType::End), None);
        let frame = next_outgoing_frame(&runtime);
        assert!(frame.is_none());
    }

    /// Arms are listed out of order on purpose: the expected label is "alpha",
    /// not the first one written.
    #[test]
    fn next_outgoing_frame_for_select_picks_first_label() {
        let runtime = SessionRuntime::new(
            SessionType::select([
                ("zeta".into(), SessionType::End),
                ("alpha".into(), SessionType::End),
            ]),
            None,
        );
        match next_outgoing_frame(&runtime) {
            Some(Frame::Select { label }) => assert_eq!(label, "alpha"),
            other => panic!("expected Select(alpha), got {other:?}"),
        }
    }
}