use std::collections::HashMap;
use core_types::{SessionId, Timestamp, TransportDomain};
/// Exponential-backoff parameters applied when a session is marked as reconnecting.
///
/// Both values are in milliseconds. The computed delay starts at `base_delay_ms` and is
/// never allowed to exceed `max_delay_ms`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReconnectBackoffPolicy {
    /// Delay (ms) used for the first reconnect attempt; later attempts scale it up.
    pub base_delay_ms: u64,
    /// Upper bound (ms) on any computed reconnect delay.
    pub max_delay_ms: u64,
}

impl Default for ReconnectBackoffPolicy {
    /// A 200 ms initial delay, capped at 5 seconds.
    fn default() -> Self {
        let base_delay_ms = 200;
        let max_delay_ms = 5_000;
        Self {
            base_delay_ms,
            max_delay_ms,
        }
    }
}
/// Coarse connection state tracked for each session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionLifecycle {
    /// Opened but not yet reported as connected.
    Connecting,
    /// Reported as connected.
    Connected,
    /// Connection lost; a retry has been scheduled.
    Reconnecting,
    /// Explicitly closed.
    Closed,
}
/// Bookkeeping kept for a single session, including its reconnect/backoff state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionRecord {
    /// Identifier assigned when the session was opened.
    pub id: SessionId,
    /// Transport domain supplied when the session was opened.
    pub domain: TransportDomain,
    /// Target supplied when the session was opened (format is not validated here).
    pub target: String,
    /// Time at which the record was created.
    pub opened_at: Timestamp,
    /// Current lifecycle state.
    pub lifecycle: SessionLifecycle,
    /// How many times the session has been marked as reconnecting.
    pub reconnect_attempts: u32,
    /// Backoff delay (ms) computed for the most recent reconnect; 0 for a new session.
    pub reconnect_backoff_ms: u64,
    /// When the next reconnect attempt is due, if one has been scheduled.
    pub next_retry_at: Option<Timestamp>,
    /// Reason recorded with the most recent reconnect, if any.
    pub last_error: Option<String>,
}
/// Number of tracked sessions in each [`SessionLifecycle`] state.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct SessionLifecycleCounts {
    /// Sessions in [`SessionLifecycle::Connecting`].
    pub connecting: usize,
    /// Sessions in [`SessionLifecycle::Connected`].
    pub connected: usize,
    /// Sessions in [`SessionLifecycle::Reconnecting`].
    pub reconnecting: usize,
    /// Sessions in [`SessionLifecycle::Closed`].
    pub closed: usize,
}
/// Tracks transport sessions through their lifecycle and reconnect backoff.
pub trait SessionManager {
/// Registers a new session for `target` within `domain` and returns its id.
fn open_session(&mut self, domain: TransportDomain, target: impl Into<String>) -> SessionId;
/// Returns the record for `id`, or `None` if no such session is tracked.
fn get_session(&self, id: SessionId) -> Option<&SessionRecord>;
/// Replaces the backoff policy used when computing reconnect delays.
fn set_backoff_policy(&mut self, policy: ReconnectBackoffPolicy);
/// Records that the session is connected.
fn mark_connected(&mut self, id: SessionId);
/// Records that the session lost its connection, with `reason` describing why.
fn mark_reconnecting(&mut self, id: SessionId, reason: impl Into<String>);
/// Records that the session has been closed.
fn mark_closed(&mut self, id: SessionId);
}
/// In-memory [`SessionManager`] backed by a `HashMap` keyed by session id.
///
/// Closed sessions are kept in the map rather than removed.
#[derive(Default)]
pub struct SimpleSessionManager {
    /// Every session ever opened through this manager, keyed by id.
    sessions: HashMap<SessionId, SessionRecord>,
    /// Backoff policy applied to subsequent reconnects.
    policy: ReconnectBackoffPolicy,
    /// Last id handed out; incremented before each new session is created.
    next_id: u64,
}
impl SimpleSessionManager {
    /// Tallies how many tracked sessions are in each [`SessionLifecycle`] state.
    ///
    /// Every record in the manager is counted exactly once. That includes closed
    /// sessions, which stay in the map after being marked closed.
    pub fn lifecycle_counts(&self) -> SessionLifecycleCounts {
        self.sessions
            .values()
            .map(|record| record.lifecycle)
            .fold(SessionLifecycleCounts::default(), |mut tally, state| {
                let slot = match state {
                    SessionLifecycle::Connecting => &mut tally.connecting,
                    SessionLifecycle::Connected => &mut tally.connected,
                    SessionLifecycle::Reconnecting => &mut tally.reconnecting,
                    SessionLifecycle::Closed => &mut tally.closed,
                };
                *slot += 1;
                tally
            })
    }
}
impl SessionManager for SimpleSessionManager {
    /// Creates a new record in the `Connecting` state and returns its id.
    ///
    /// The internal counter is incremented before use, so a default-constructed
    /// manager passes 1 to `SessionId::new` for its first session.
    fn open_session(&mut self, domain: TransportDomain, target: impl Into<String>) -> SessionId {
        self.next_id += 1;
        let id = SessionId::new(self.next_id);
        self.sessions.insert(
            id,
            SessionRecord {
                id,
                domain,
                target: target.into(),
                opened_at: Timestamp::now(),
                lifecycle: SessionLifecycle::Connecting,
                reconnect_attempts: 0,
                reconnect_backoff_ms: 0,
                next_retry_at: None,
                last_error: None,
            },
        );
        id
    }

    /// Returns the record for `id`, if tracked.
    fn get_session(&self, id: SessionId) -> Option<&SessionRecord> {
        self.sessions.get(&id)
    }

    /// Replaces the backoff policy.
    ///
    /// Only delays computed by later `mark_reconnecting` calls are affected.
    /// Retries that are already scheduled are not recomputed.
    fn set_backoff_policy(&mut self, policy: ReconnectBackoffPolicy) {
        self.policy = policy;
    }

    /// Marks the session connected and clears its pending retry, backoff and last error.
    ///
    /// Unknown ids are ignored. `reconnect_attempts` is *not* reset, so a later
    /// `mark_reconnecting` continues the exponential schedule from the accumulated
    /// count. NOTE(review): confirm whether attempts should reset on a successful connect.
    fn mark_connected(&mut self, id: SessionId) {
        if let Some(record) = self.sessions.get_mut(&id) {
            record.lifecycle = SessionLifecycle::Connected;
            record.reconnect_backoff_ms = 0;
            record.next_retry_at = None;
            record.last_error = None;
        }
    }

    /// Moves the session to `Reconnecting` and schedules its next retry.
    ///
    /// The session moves to `Reconnecting` from any state, including `Closed`.
    /// The delay is `base_delay_ms * 2^(attempts - 1)`. The exponent is capped at 16,
    /// and the result saturates and is then clamped to `max_delay_ms`.
    /// Unknown ids are ignored, and `reason` is then never converted.
    fn mark_reconnecting(&mut self, id: SessionId, reason: impl Into<String>) {
        let policy = self.policy;
        if let Some(record) = self.sessions.get_mut(&id) {
            record.lifecycle = SessionLifecycle::Reconnecting;
            // Saturate instead of `+= 1`. An unchecked increment panics in debug builds.
            // In release it wraps to 0, which would silently reset the backoff schedule.
            record.reconnect_attempts = record.reconnect_attempts.saturating_add(1);
            let shift = record.reconnect_attempts.saturating_sub(1).min(16);
            let factor = 1u64 << shift;
            let backoff = policy
                .base_delay_ms
                .saturating_mul(factor)
                .min(policy.max_delay_ms);
            record.reconnect_backoff_ms = backoff;
            // Scale ms by 1_000_000; presumably `Timestamp` counts nanoseconds (confirm).
            // Saturate so a far-future clock value cannot overflow the addition.
            let delay = u128::from(backoff) * 1_000_000;
            record.next_retry_at = Some(Timestamp(Timestamp::now().0.saturating_add(delay)));
            record.last_error = Some(reason.into());
        }
    }

    /// Marks the session closed and drops any pending retry time.
    ///
    /// The record stays in the map, so it is still counted by `lifecycle_counts`.
    /// The attempt count, last backoff and last error are left unchanged.
    fn mark_closed(&mut self, id: SessionId) {
        if let Some(record) = self.sessions.get_mut(&id) {
            record.lifecycle = SessionLifecycle::Closed;
            // A closed session has no retry pending; don't leave a stale deadline behind.
            record.next_retry_at = None;
        }
    }
}