use serde::{Deserialize, Serialize};
/// How [`session_step`] answers a `Connect` event when the session already
/// has at least one connection. When there are no existing connections,
/// every policy accepts.
///
/// Serialized in kebab-case (`"kick-old"`, `"reject-new"`, `"allow-multiple"`),
/// whereas [`SessionEvent`] and [`SessionEffect`] use snake_case tags.
/// NOTE(review): the casing mismatch is part of the serialized form; changing
/// it changes the strings other hosts see, so treat it as fixed unless every
/// consumer is updated together.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum SessionPolicy {
/// Accept the new connection and report `AcceptAndKickExisting`.
/// This is the policy returned by `SessionPolicy::default()`.
KickOld,
/// Refuse the new connection (`RejectConnection`).
RejectNew,
/// Always accept (`AcceptConnection`), however many connections exist.
AllowMultiple,
}
impl Default for SessionPolicy {
fn default() -> Self {
SessionPolicy::KickOld
}
}
/// An input to [`session_step`].
///
/// Serialized as an internally tagged object whose `"kind"` field holds the
/// snake_case variant name, e.g.
/// `{"kind":"connect","existing_connection_count":1}` or
/// `{"kind":"ttl_tick","suspended_at_ms":0,"now_ms":5,"ttl_ms":10}`.
///
/// `PartialEq`/`Eq` are derived (all fields are integers) for consistency with
/// [`SessionPolicy`] and [`SessionEffect`], so events can be compared directly,
/// e.g. after a serialization round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SessionEvent {
    /// A new connection is arriving for the session.
    Connect {
        /// How many connections the session already has; `0` means the
        /// session is currently unoccupied.
        existing_connection_count: u32,
    },
    /// Asks whether a suspended session has outlived its time-to-live.
    TtlTick {
        /// When the session was suspended, in milliseconds.
        suspended_at_ms: u64,
        /// The current time, in milliseconds. Presumably measured on the
        /// same clock/epoch as `suspended_at_ms` — confirm with the producer.
        /// A value earlier than `suspended_at_ms` is treated as zero elapsed.
        now_ms: u64,
        /// Time-to-live in milliseconds; the session expires once the elapsed
        /// time reaches this value (inclusive).
        ttl_ms: u64,
    },
}
/// Per-session configuration consulted by [`session_step`].
///
/// Serialized as `{"policy": "<kebab-case policy>"}`. `Default` yields the
/// default [`SessionPolicy`]. `PartialEq`/`Eq` are derived for consistency with
/// the other session types so states can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SessionState {
    /// How a new connection is handled when the session already has one.
    pub policy: SessionPolicy,
}
/// Decision returned by [`session_step`]. This type only describes the
/// decision; carrying it out is left to whoever receives it.
///
/// Serialized as an internally tagged object carrying only the snake_case
/// variant name, e.g. `{"kind":"accept_and_kick_existing"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum SessionEffect {
/// Returned for `Connect` when the session has no existing connections, or
/// under `AllowMultiple` regardless of the count.
AcceptConnection,
/// Returned for `Connect` under `KickOld` when connections already exist.
/// Presumably the caller accepts the new connection and closes the old
/// ones — confirm against the consumer.
AcceptAndKickExisting,
/// Returned for `Connect` under `RejectNew` when connections already exist.
RejectConnection,
/// Returned for `TtlTick` once the elapsed time reaches `ttl_ms`.
ExpireSession,
/// Returned for `TtlTick` while the elapsed time is still below `ttl_ms`.
RetainSession,
}
/// Decides how to respond to one [`SessionEvent`] given the session's
/// [`SessionState`].
///
/// Pure function: it only reads its arguments and returns the chosen
/// [`SessionEffect`]. It is `#[must_use]` because discarding the decision is
/// always a bug.
///
/// # `Connect`
///
/// | existing connections | policy          | effect                    |
/// |----------------------|-----------------|---------------------------|
/// | 0                    | any             | `AcceptConnection`        |
/// | > 0                  | `AllowMultiple` | `AcceptConnection`        |
/// | > 0                  | `RejectNew`     | `RejectConnection`        |
/// | > 0                  | `KickOld`       | `AcceptAndKickExisting`   |
///
/// # `TtlTick`
///
/// Elapsed time is `now_ms - suspended_at_ms`, clamped to 0 when `now_ms` is
/// earlier than `suspended_at_ms` (e.g. clock skew between hosts). The session
/// expires once elapsed time reaches `ttl_ms` (inclusive), so `ttl_ms == 0`
/// always expires; otherwise it is retained.
#[must_use]
pub fn session_step(state: &SessionState, event: &SessionEvent) -> SessionEffect {
    match *event {
        // An unoccupied session accepts the connection under every policy.
        SessionEvent::Connect {
            existing_connection_count: 0,
        } => SessionEffect::AcceptConnection,
        // At least one connection exists, so the policy decides.
        SessionEvent::Connect { .. } => match state.policy {
            SessionPolicy::AllowMultiple => SessionEffect::AcceptConnection,
            SessionPolicy::RejectNew => SessionEffect::RejectConnection,
            SessionPolicy::KickOld => SessionEffect::AcceptAndKickExisting,
        },
        SessionEvent::TtlTick {
            suspended_at_ms,
            now_ms,
            ttl_ms,
        } => {
            // saturating_sub: a suspension timestamp "in the future" counts as
            // zero elapsed instead of underflowing.
            let elapsed_ms = now_ms.saturating_sub(suspended_at_ms);
            if elapsed_ms >= ttl_ms {
                SessionEffect::ExpireSession
            } else {
                SessionEffect::RetainSession
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state with the given policy.
    fn state(policy: SessionPolicy) -> SessionState {
        SessionState { policy }
    }

    /// Builds a `Connect` event with `n` existing connections.
    fn connect(n: u32) -> SessionEvent {
        SessionEvent::Connect {
            existing_connection_count: n,
        }
    }

    /// Builds a `TtlTick` event.
    fn tick(suspended_at_ms: u64, now_ms: u64, ttl_ms: u64) -> SessionEvent {
        SessionEvent::TtlTick {
            suspended_at_ms,
            now_ms,
            ttl_ms,
        }
    }

    #[test]
    fn default_policy_is_kick_old() {
        assert_eq!(SessionPolicy::default(), SessionPolicy::KickOld);
        assert_eq!(SessionState::default().policy, SessionPolicy::KickOld);
    }

    #[test]
    fn kick_old_on_first_connection_accepts() {
        let s = state(SessionPolicy::KickOld);
        assert_eq!(session_step(&s, &connect(0)), SessionEffect::AcceptConnection);
    }

    #[test]
    fn kick_old_on_second_connection_kicks() {
        let s = state(SessionPolicy::KickOld);
        for n in [1, 2, u32::MAX] {
            assert_eq!(
                session_step(&s, &connect(n)),
                SessionEffect::AcceptAndKickExisting
            );
        }
    }

    #[test]
    fn reject_new_rejects_when_occupied() {
        let s = state(SessionPolicy::RejectNew);
        for n in [1, 2, u32::MAX] {
            assert_eq!(session_step(&s, &connect(n)), SessionEffect::RejectConnection);
        }
        assert_eq!(session_step(&s, &connect(0)), SessionEffect::AcceptConnection);
    }

    #[test]
    fn allow_multiple_always_accepts() {
        let s = state(SessionPolicy::AllowMultiple);
        for n in [0, 1, 2, 10, 100] {
            assert_eq!(session_step(&s, &connect(n)), SessionEffect::AcceptConnection);
        }
    }

    #[test]
    fn ttl_not_yet_elapsed_retains() {
        let s = SessionState::default();
        assert_eq!(
            session_step(&s, &tick(1_000, 1_500, 1_000)),
            SessionEffect::RetainSession
        );
    }

    /// Pins the inclusive boundary: elapsed == ttl already expires.
    #[test]
    fn ttl_exactly_at_deadline_expires() {
        let s = SessionState::default();
        assert_eq!(
            session_step(&s, &tick(1_000, 2_000, 1_000)),
            SessionEffect::ExpireSession
        );
    }

    #[test]
    fn ttl_elapsed_expires() {
        let s = SessionState::default();
        assert_eq!(
            session_step(&s, &tick(1_000, 2_001, 1_000)),
            SessionEffect::ExpireSession
        );
    }

    #[test]
    fn ttl_clock_skew_does_not_underflow() {
        let s = SessionState::default();
        assert_eq!(
            session_step(&s, &tick(2_000, 1_000, 1_000)),
            SessionEffect::RetainSession
        );
        assert_eq!(
            session_step(&s, &tick(u64::MAX, 0, 1)),
            SessionEffect::RetainSession
        );
    }

    #[test]
    fn json_roundtrip_for_cross_host_transport() {
        let state = state(SessionPolicy::KickOld);
        let event = connect(2);
        let state_s = serde_json::to_string(&state).unwrap();
        let event_s = serde_json::to_string(&event).unwrap();
        let state2: SessionState = serde_json::from_str(&state_s).unwrap();
        let event2: SessionEvent = serde_json::from_str(&event_s).unwrap();
        assert_eq!(
            session_step(&state2, &event2),
            SessionEffect::AcceptAndKickExisting
        );
    }

    /// Pins the exact JSON shapes other hosts see, including the deliberate
    /// kebab-case policy vs snake_case `kind` tag difference.
    #[test]
    fn json_wire_format_is_stable() {
        use serde_json::json;

        assert_eq!(
            serde_json::to_value(SessionPolicy::KickOld).unwrap(),
            json!("kick-old")
        );
        assert_eq!(
            serde_json::to_value(SessionPolicy::RejectNew).unwrap(),
            json!("reject-new")
        );
        assert_eq!(
            serde_json::to_value(SessionPolicy::AllowMultiple).unwrap(),
            json!("allow-multiple")
        );
        assert_eq!(
            serde_json::to_value(state(SessionPolicy::RejectNew)).unwrap(),
            json!({"policy": "reject-new"})
        );
        assert_eq!(
            serde_json::to_value(connect(3)).unwrap(),
            json!({"kind": "connect", "existing_connection_count": 3})
        );
        assert_eq!(
            serde_json::to_value(tick(1, 2, 3)).unwrap(),
            json!({"kind": "ttl_tick", "suspended_at_ms": 1, "now_ms": 2, "ttl_ms": 3})
        );
        assert_eq!(
            serde_json::to_value(SessionEffect::AcceptAndKickExisting).unwrap(),
            json!({"kind": "accept_and_kick_existing"})
        );

        let parsed: SessionEvent =
            serde_json::from_str(r#"{"kind":"connect","existing_connection_count":1}"#).unwrap();
        assert!(matches!(
            parsed,
            SessionEvent::Connect {
                existing_connection_count: 1
            }
        ));
    }
}