use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Instant;
use uuid::Uuid;
/// Scope of a subscription registered on a `Session`.
///
/// `PartialEq`/`Eq` are derived so callers (and tests) can compare the value returned
/// by `Session::remove_subscription` directly instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionSubscriptionType {
    /// Subscription scoped to the instance identified by `instance_id`.
    Instance { instance_id: String },
    /// Subscription not scoped to any single instance.
    All,
}
/// Wire mode recorded for a session.
///
/// A new session uses `WireMode::default()` (`BinaryJson`); `Session::complete_handshake`
/// overwrites it with the mode the caller passes in.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum WireMode {
    /// The default mode, in effect until a handshake records a different one.
    #[default]
    BinaryJson,
    /// The `Jsonl` mode.
    Jsonl,
}
/// Lifecycle state of a `Session`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Initial state assigned by `Session::new`.
    Connected,
    /// Handshake completed while authentication is required but not yet granted.
    Ready,
    /// Handshake completed and the session is authenticated (or needs no auth).
    Authenticated,
    /// Only reachable via `Session::set_state` in the visible code; presumably used
    /// during teardown — confirm with callers.
    Closing,
}
/// Per-connection state for a single client session.
///
/// Handshake data (`state`, `wire_mode`, `protocol_version`, `client_name`, `features`)
/// is changed through `&mut self`. The authentication flag, request counter, last-activity
/// timestamp and subscription table use atomics or mutexes, so they can be updated
/// through a shared `&Session`.
#[derive(Debug)]
pub struct Session {
    /// Session identifier; `Session::new` fills it with a random v4 UUID string.
    pub id: String,
    /// Remote socket address supplied to `Session::new`.
    pub remote_addr: SocketAddr,
    /// Current lifecycle state.
    state: SessionState,
    /// Wire mode recorded by the handshake.
    wire_mode: WireMode,
    /// Protocol version recorded by the handshake (`0` before it).
    protocol_version: u16,
    /// Optional client name recorded by the handshake.
    client_name: Option<String>,
    /// Feature names recorded by the handshake.
    features: HashSet<String>,
    /// Whether this session must authenticate before reaching `Authenticated`.
    auth_required: bool,
    /// Authentication flag; starts `true` when `auth_required` is `false`.
    authenticated: AtomicBool,
    /// Number of requests recorded so far.
    request_count: AtomicU64,
    /// Creation time, used to compute the session's age.
    created_at: Instant,
    /// Time of the most recent recorded request (creation time initially).
    last_activity: std::sync::Mutex<Instant>,
    /// Subscription id -> subscription scope.
    subscriptions: std::sync::Mutex<HashMap<String, SessionSubscriptionType>>,
}
impl Session {
    /// Creates a session for the peer at `remote_addr`.
    ///
    /// The session starts in [`SessionState::Connected`] with the default [`WireMode`],
    /// protocol version `0`, no client name, no features and no subscriptions, and gets a
    /// fresh random (v4) UUID as its `id`. When `auth_required` is `false` the session
    /// counts as authenticated from the start.
    pub fn new(remote_addr: SocketAddr, auth_required: bool) -> Self {
        // Share one timestamp so `created_at` and the initial `last_activity` are identical.
        let now = Instant::now();
        Self {
            id: Uuid::new_v4().to_string(),
            remote_addr,
            state: SessionState::Connected,
            wire_mode: WireMode::default(),
            protocol_version: 0,
            client_name: None,
            features: HashSet::new(),
            auth_required,
            authenticated: AtomicBool::new(!auth_required),
            request_count: AtomicU64::new(0),
            created_at: now,
            last_activity: std::sync::Mutex::new(now),
            subscriptions: std::sync::Mutex::new(HashMap::new()),
        }
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> SessionState {
        self.state
    }

    /// Overwrites the lifecycle state unconditionally.
    pub fn set_state(&mut self, state: SessionState) {
        self.state = state;
    }

    /// Returns the wire mode; `WireMode::default()` until `complete_handshake` runs.
    pub fn wire_mode(&self) -> WireMode {
        self.wire_mode
    }

    /// Returns the protocol version; `0` until `complete_handshake` runs.
    pub fn protocol_version(&self) -> u16 {
        self.protocol_version
    }

    /// Returns the client name recorded by the handshake, if any.
    pub fn client_name(&self) -> Option<&str> {
        self.client_name.as_deref()
    }

    /// Returns whether the session is currently flagged as authenticated.
    pub fn is_authenticated(&self) -> bool {
        self.authenticated.load(Ordering::Acquire)
    }

    /// Sets the authentication flag.
    ///
    /// This only updates the flag and does NOT change [`Session::state`]. The state is
    /// derived from the flag only when [`Session::complete_handshake`] runs, or set
    /// explicitly via [`Session::set_state`].
    pub fn set_authenticated(&self, authenticated: bool) {
        self.authenticated.store(authenticated, Ordering::Release);
    }

    /// Records the handshake parameters and advances the lifecycle state.
    ///
    /// Any previously recorded handshake values are replaced. The session moves to
    /// [`SessionState::Ready`] when authentication is required but the flag is not yet
    /// set, and to [`SessionState::Authenticated`] otherwise.
    pub fn complete_handshake(
        &mut self,
        protocol_version: u16,
        wire_mode: WireMode,
        client_name: Option<String>,
        features: HashSet<String>,
    ) {
        self.protocol_version = protocol_version;
        self.wire_mode = wire_mode;
        self.client_name = client_name;
        self.features = features;
        self.state = if self.auth_required && !self.is_authenticated() {
            SessionState::Ready
        } else {
            SessionState::Authenticated
        };
    }

    /// Counts one request and stamps the last-activity time with "now".
    pub fn record_request(&self) {
        self.request_count.fetch_add(1, Ordering::Relaxed);
        *self.lock_last_activity() = Instant::now();
    }

    /// Returns how many times [`Session::record_request`] has been called.
    pub fn request_count(&self) -> u64 {
        self.request_count.load(Ordering::Relaxed)
    }

    /// Returns the time since the last recorded request, or since creation if none.
    pub fn idle_duration(&self) -> std::time::Duration {
        self.lock_last_activity().elapsed()
    }

    /// Returns the time elapsed since the session was created.
    pub fn age(&self) -> std::time::Duration {
        self.created_at.elapsed()
    }

    /// Returns whether `feature` is in the feature set recorded by the handshake.
    pub fn has_feature(&self, feature: &str) -> bool {
        self.features.contains(feature)
    }

    /// Registers an unscoped subscription; identical to [`Session::add_all_subscription`].
    ///
    /// An existing subscription with the same id is replaced.
    pub fn add_subscription(&self, subscription_id: String) {
        self.add_all_subscription(subscription_id);
    }

    /// Registers a subscription scoped to `instance_id`.
    ///
    /// An existing subscription with the same id is replaced.
    pub fn add_instance_subscription(&self, subscription_id: String, instance_id: String) {
        self.lock_subscriptions().insert(
            subscription_id,
            SessionSubscriptionType::Instance { instance_id },
        );
    }

    /// Registers an unscoped ([`SessionSubscriptionType::All`]) subscription.
    ///
    /// An existing subscription with the same id is replaced.
    pub fn add_all_subscription(&self, subscription_id: String) {
        self.lock_subscriptions()
            .insert(subscription_id, SessionSubscriptionType::All);
    }

    /// Removes a subscription, returning its scope if it was registered.
    pub fn remove_subscription(&self, subscription_id: &str) -> Option<SessionSubscriptionType> {
        self.lock_subscriptions().remove(subscription_id)
    }

    /// Returns the instance id of an instance-scoped subscription.
    ///
    /// Returns `None` for unscoped subscriptions and for unknown ids.
    pub fn get_subscription_instance(&self, subscription_id: &str) -> Option<String> {
        match self.lock_subscriptions().get(subscription_id) {
            Some(SessionSubscriptionType::Instance { instance_id }) => Some(instance_id.clone()),
            Some(SessionSubscriptionType::All) | None => None,
        }
    }

    /// Returns the ids of all registered subscriptions, in unspecified (hash-map) order.
    pub fn subscriptions(&self) -> Vec<String> {
        self.lock_subscriptions().keys().cloned().collect()
    }

    /// Returns the number of registered subscriptions.
    pub fn subscription_count(&self) -> usize {
        self.lock_subscriptions().len()
    }

    /// Locks the subscription table, recovering the data if a previous holder panicked.
    ///
    /// Every critical section in this impl is a single insert/remove/lookup, so a poisoned
    /// lock cannot leave the map half-updated. Propagating the poison would only turn one
    /// panic into a panic on every later call.
    fn lock_subscriptions(
        &self,
    ) -> std::sync::MutexGuard<'_, HashMap<String, SessionSubscriptionType>> {
        self.subscriptions
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Locks the last-activity timestamp, recovering it if a previous holder panicked.
    ///
    /// The guarded value is a plain `Instant` that is only read or overwritten whole, so
    /// a poisoned value is still valid.
    fn lock_last_activity(&self) -> std::sync::MutexGuard<'_, Instant> {
        self.last_activity
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address used by every test session.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345)
    }

    #[test]
    fn test_session_creation() {
        let session = Session::new(test_addr(), false);
        assert_eq!(session.state(), SessionState::Connected);
        assert!(session.is_authenticated());
        assert_eq!(session.wire_mode(), WireMode::BinaryJson);
        assert_eq!(session.protocol_version(), 0);
        assert_eq!(session.client_name(), None);
        assert_eq!(session.request_count(), 0);
        assert_eq!(session.subscription_count(), 0);
    }

    #[test]
    fn test_session_with_auth() {
        let session = Session::new(test_addr(), true);
        assert!(!session.is_authenticated());
        session.set_authenticated(true);
        assert!(session.is_authenticated());
    }

    #[test]
    fn test_session_handshake() {
        let mut session = Session::new(test_addr(), false);
        session.complete_handshake(
            1,
            WireMode::BinaryJson,
            Some("test-client".to_string()),
            HashSet::from(["idempotency".to_string()]),
        );
        assert_eq!(session.state(), SessionState::Authenticated);
        assert_eq!(session.protocol_version(), 1);
        assert_eq!(session.wire_mode(), WireMode::BinaryJson);
        assert_eq!(session.client_name(), Some("test-client"));
        assert!(session.has_feature("idempotency"));
        assert!(!session.has_feature("unknown"));
    }

    #[test]
    fn test_handshake_pending_auth_is_ready() {
        // Auth required and not yet granted: the handshake must stop at `Ready`.
        let mut session = Session::new(test_addr(), true);
        session.complete_handshake(2, WireMode::Jsonl, None, HashSet::new());
        assert_eq!(session.state(), SessionState::Ready);
        assert_eq!(session.wire_mode(), WireMode::Jsonl);
        assert_eq!(session.protocol_version(), 2);
        assert_eq!(session.client_name(), None);
    }

    #[test]
    fn test_handshake_after_auth_is_authenticated() {
        let mut session = Session::new(test_addr(), true);
        session.set_authenticated(true);
        session.complete_handshake(1, WireMode::BinaryJson, None, HashSet::new());
        assert_eq!(session.state(), SessionState::Authenticated);
    }

    #[test]
    fn test_record_request() {
        let session = Session::new(test_addr(), false);
        session.record_request();
        session.record_request();
        assert_eq!(session.request_count(), 2);
        assert!(session.idle_duration() <= session.age());
    }

    #[test]
    fn test_session_subscriptions() {
        let session = Session::new(test_addr(), false);
        session.add_instance_subscription("sub-1".to_string(), "instance-1".to_string());
        session.add_all_subscription("sub-2".to_string());
        assert_eq!(session.subscriptions().len(), 2);
        assert_eq!(session.subscription_count(), 2);
        assert_eq!(
            session.get_subscription_instance("sub-1"),
            Some("instance-1".to_string())
        );
        assert_eq!(session.get_subscription_instance("sub-2"), None);
        assert_eq!(session.get_subscription_instance("missing"), None);
        assert!(session.remove_subscription("sub-1").is_some());
        assert_eq!(session.subscriptions().len(), 1);
        assert!(session.remove_subscription("sub-1").is_none());
    }

    #[test]
    fn test_add_subscription_is_unscoped() {
        let session = Session::new(test_addr(), false);
        session.add_subscription("sub-1".to_string());
        assert_eq!(session.subscription_count(), 1);
        assert_eq!(session.get_subscription_instance("sub-1"), None);
        assert!(matches!(
            session.remove_subscription("sub-1"),
            Some(SessionSubscriptionType::All)
        ));
    }

    #[test]
    fn test_subscription_id_is_replaced() {
        // Re-registering an id overwrites the previous scope rather than adding a second entry.
        let session = Session::new(test_addr(), false);
        session.add_instance_subscription("sub-1".to_string(), "instance-1".to_string());
        session.add_all_subscription("sub-1".to_string());
        assert_eq!(session.subscription_count(), 1);
        assert_eq!(session.get_subscription_instance("sub-1"), None);
    }
}