use crate::error::ValidationError;
use crate::message::{SessionId, Timestamp};
use crate::QueueError;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::Instant;
// Unit tests for this module live in the sibling file `sessions_tests.rs`
// and are compiled only for test builds.
#[cfg(test)]
#[path = "sessions_tests.rs"]
mod tests;
/// Source of message metadata from which session keys are derived.
///
/// Implementors must provide [`get_metadata`](Self::get_metadata); the
/// enumeration helpers have default implementations.
pub trait SessionKeyExtractor {
    /// Returns the metadata value stored under `key`, or `None` when absent.
    fn get_metadata(&self, key: &str) -> Option<String>;

    /// Returns every metadata key this extractor can enumerate.
    ///
    /// The default enumerates nothing, which makes
    /// [`get_all_metadata`](Self::get_all_metadata) return an empty map.
    fn list_metadata_keys(&self) -> Vec<String> {
        Vec::new()
    }

    /// Builds a map of every listed key that currently has a value.
    ///
    /// Listed keys whose lookup yields `None` are skipped; when a key is listed
    /// more than once, the value from its last lookup is kept.
    fn get_all_metadata(&self) -> HashMap<String, String> {
        let mut snapshot = HashMap::new();
        for name in self.list_metadata_keys() {
            if let Some(found) = self.get_metadata(&name) {
                snapshot.insert(name, found);
            }
        }
        snapshot
    }
}
/// Strategy for deriving a [`SessionId`] from message metadata.
///
/// Returning `None` means no session key was produced for the message.
/// The `Send + Sync` bound lets a generator be shared across threads.
pub trait SessionKeyGenerator: Send + Sync {
    /// Attempts to build a session key from the metadata exposed by `extractor`.
    fn generate_key(&self, extractor: &dyn SessionKeyExtractor) -> Option<SessionId>;
}
/// Builds a session key by joining the values of several metadata fields.
///
/// Every configured field must be present. The values are concatenated in
/// field order with `separator` between them.
// NOTE(review): values that themselves contain `separator` can produce the same
// key for different field combinations — confirm inputs cannot contain it.
pub struct CompositeKeyStrategy {
    fields: Vec<String>,
    separator: String,
}

impl CompositeKeyStrategy {
    /// Creates a strategy over `fields` (in order) joined by `separator`.
    pub fn new(fields: Vec<String>, separator: &str) -> Self {
        Self {
            fields,
            separator: separator.to_string(),
        }
    }
}

impl SessionKeyGenerator for CompositeKeyStrategy {
    /// Returns `None` when no fields are configured, when any field is missing
    /// from the metadata, or when the joined key is rejected by `SessionId::new`.
    fn generate_key(&self, extractor: &dyn SessionKeyExtractor) -> Option<SessionId> {
        if self.fields.is_empty() {
            return None;
        }
        // Collecting into `Option<Vec<_>>` bails out at the first missing field
        // instead of querying every remaining field and comparing lengths.
        let values = self
            .fields
            .iter()
            .map(|field| extractor.get_metadata(field))
            .collect::<Option<Vec<String>>>()?;
        SessionId::new(values.join(&self.separator)).ok()
    }
}
/// Derives a session key from a single metadata field, optionally prefixed.
pub struct SingleFieldStrategy {
    field_name: String,
    prefix: Option<String>,
}

impl SingleFieldStrategy {
    /// Creates a strategy that reads `field_name`. When `prefix` is given, the
    /// key becomes `"{prefix}-{value}"`.
    pub fn new(field_name: &str, prefix: Option<&str>) -> Self {
        let prefix = prefix.map(str::to_string);
        Self {
            field_name: field_name.to_string(),
            prefix,
        }
    }
}

impl SessionKeyGenerator for SingleFieldStrategy {
    /// Returns `None` if the field is absent or the resulting key fails
    /// `SessionId::new` validation.
    fn generate_key(&self, extractor: &dyn SessionKeyExtractor) -> Option<SessionId> {
        extractor
            .get_metadata(&self.field_name)
            .map(|raw| match self.prefix.as_deref() {
                Some(p) => format!("{}-{}", p, raw),
                None => raw,
            })
            .and_then(|key| SessionId::new(key).ok())
    }
}
/// Generator that never produces a session key.
pub struct NoOrderingStrategy;

impl SessionKeyGenerator for NoOrderingStrategy {
    /// Always returns `None`; the extractor is never consulted.
    fn generate_key(&self, _: &dyn SessionKeyExtractor) -> Option<SessionId> {
        Option::None
    }
}
/// Tries a chain of generators in order and yields the first key produced.
pub struct FallbackStrategy {
    strategies: Vec<Box<dyn SessionKeyGenerator>>,
}

impl FallbackStrategy {
    /// Creates a fallback chain; earlier entries take priority.
    pub fn new(strategies: Vec<Box<dyn SessionKeyGenerator>>) -> Self {
        FallbackStrategy { strategies }
    }
}

impl SessionKeyGenerator for FallbackStrategy {
    /// Returns the first `Some` produced by the chain. Returns `None` when every
    /// strategy yields nothing or the chain is empty. Strategies after the first
    /// success are not invoked.
    fn generate_key(&self, extractor: &dyn SessionKeyExtractor) -> Option<SessionId> {
        self.strategies
            .iter()
            .find_map(|candidate| candidate.generate_key(extractor))
    }
}
/// Time-bounded exclusive claim on a session by a named owner.
#[derive(Debug, Clone)]
pub struct SessionLock {
    session_id: SessionId,
    owner: String,
    acquired_at: Instant,
    expires_at: Instant,
    lock_duration: Duration,
}

impl SessionLock {
    /// Creates a lock on `session_id` held by `owner`, valid for
    /// `lock_duration` starting now.
    pub fn new(session_id: SessionId, owner: String, lock_duration: Duration) -> Self {
        let acquired_at = Instant::now();
        let expires_at = acquired_at + lock_duration;
        Self {
            session_id,
            owner,
            acquired_at,
            expires_at,
            lock_duration,
        }
    }

    /// The locked session.
    pub fn session_id(&self) -> &SessionId {
        &self.session_id
    }

    /// Identifier of the lock holder.
    pub fn owner(&self) -> &str {
        &self.owner
    }

    /// When the lock was first acquired (unchanged by [`renew`](Self::renew)).
    pub fn acquired_at(&self) -> Instant {
        self.acquired_at
    }

    /// Deadline after which the lock counts as expired.
    pub fn expires_at(&self) -> Instant {
        self.expires_at
    }

    /// Duration used for the most recent acquisition or renewal.
    pub fn lock_duration(&self) -> Duration {
        self.lock_duration
    }

    /// `true` once the current time has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }

    /// Time left before expiry, or `Duration::ZERO` if already expired.
    pub fn time_remaining(&self) -> Duration {
        self.expires_at.saturating_duration_since(Instant::now())
    }

    /// Returns a copy whose deadline is `extension` from now.
    ///
    /// `acquired_at` is preserved and `lock_duration` becomes `extension`.
    pub fn renew(&self, extension: Duration) -> Self {
        Self {
            expires_at: Instant::now() + extension,
            lock_duration: extension,
            ..self.clone()
        }
    }
}
/// In-memory registry of [`SessionLock`]s keyed by session id.
///
/// The map sits behind an async `RwLock`. Every mutating operation holds the
/// write guard across its whole check-then-update sequence.
pub struct SessionLockManager {
    locks: Arc<RwLock<HashMap<SessionId, SessionLock>>>,
    default_lock_duration: Duration,
}

impl SessionLockManager {
    /// Creates an empty manager. New locks, and renewals without an explicit
    /// extension, last `default_lock_duration`.
    pub fn new(default_lock_duration: Duration) -> Self {
        let locks = Arc::new(RwLock::new(HashMap::new()));
        Self {
            locks,
            default_lock_duration,
        }
    }

    /// Acquires the lock on `session_id` for `owner` without waiting.
    ///
    /// * No entry, or an expired entry: a fresh lock is stored and returned.
    /// * Live lock already held by `owner`: the existing lock is returned
    ///   unchanged (its deadline is not extended).
    ///
    /// # Errors
    /// `QueueError::SessionLocked` when a different owner holds a live lock.
    pub async fn try_acquire_lock(
        &self,
        session_id: SessionId,
        owner: String,
    ) -> Result<SessionLock, QueueError> {
        let mut table = self.locks.write().await;
        match table.get(&session_id) {
            Some(held) if !held.is_expired() => {
                if held.owner() == owner {
                    Ok(held.clone())
                } else {
                    // NOTE(review): `locked_until` is the current time, not the
                    // holder's expiry — confirm whether callers rely on it.
                    Err(QueueError::SessionLocked {
                        session_id: session_id.to_string(),
                        locked_until: Timestamp::now(),
                    })
                }
            }
            _ => {
                let granted =
                    SessionLock::new(session_id.clone(), owner, self.default_lock_duration);
                table.insert(session_id, granted.clone());
                Ok(granted)
            }
        }
    }

    /// Same as [`try_acquire_lock`](Self::try_acquire_lock); it does not wait
    /// for a held lock to be released.
    pub async fn acquire_lock(
        &self,
        session_id: SessionId,
        owner: String,
    ) -> Result<SessionLock, QueueError> {
        self.try_acquire_lock(session_id, owner).await
    }

    /// Resets the deadline of `owner`'s lock to `extension` from now. When
    /// `extension` is `None`, the manager's default duration is used.
    ///
    /// Expiry is not checked: an expired entry that has not been replaced can
    /// still be renewed by its owner.
    ///
    /// # Errors
    /// * `QueueError::SessionNotFound` when no lock entry exists.
    /// * `QueueError::SessionLocked` when the entry belongs to another owner.
    pub async fn renew_lock(
        &self,
        session_id: &SessionId,
        owner: &str,
        extension: Option<Duration>,
    ) -> Result<SessionLock, QueueError> {
        let mut table = self.locks.write().await;
        let renewed = match table.get(session_id) {
            None => {
                return Err(QueueError::SessionNotFound {
                    session_id: session_id.to_string(),
                })
            }
            Some(held) if held.owner() != owner => {
                return Err(QueueError::SessionLocked {
                    session_id: session_id.to_string(),
                    locked_until: Timestamp::now(),
                })
            }
            Some(held) => held.renew(extension.unwrap_or(self.default_lock_duration)),
        };
        table.insert(session_id.clone(), renewed.clone());
        Ok(renewed)
    }

    /// Removes `owner`'s lock entry for `session_id`, whether expired or not.
    ///
    /// # Errors
    /// * `QueueError::SessionNotFound` when no lock entry exists.
    /// * `QueueError::SessionLocked` when the entry belongs to another owner.
    pub async fn release_lock(
        &self,
        session_id: &SessionId,
        owner: &str,
    ) -> Result<(), QueueError> {
        let mut table = self.locks.write().await;
        match table.get(session_id) {
            None => Err(QueueError::SessionNotFound {
                session_id: session_id.to_string(),
            }),
            Some(held) if held.owner() != owner => Err(QueueError::SessionLocked {
                session_id: session_id.to_string(),
                locked_until: Timestamp::now(),
            }),
            Some(_) => {
                table.remove(session_id);
                Ok(())
            }
        }
    }

    /// `true` when a non-expired lock exists for `session_id`.
    pub async fn is_locked(&self, session_id: &SessionId) -> bool {
        let table = self.locks.read().await;
        match table.get(session_id) {
            Some(held) => !held.is_expired(),
            None => false,
        }
    }

    /// Returns a copy of the lock on `session_id` if it has not expired.
    pub async fn get_lock(&self, session_id: &SessionId) -> Option<SessionLock> {
        let table = self.locks.read().await;
        match table.get(session_id) {
            Some(held) if !held.is_expired() => Some(held.clone()),
            _ => None,
        }
    }

    /// Drops every expired entry and returns how many were removed.
    pub async fn cleanup_expired_locks(&self) -> usize {
        let mut table = self.locks.write().await;
        let before = table.len();
        table.retain(|_, held| !held.is_expired());
        before - table.len()
    }

    /// Number of stored entries, including expired ones not yet cleaned up.
    pub async fn lock_count(&self) -> usize {
        self.locks.read().await.len()
    }

    /// Number of stored entries that have not expired.
    pub async fn active_lock_count(&self) -> usize {
        let table = self.locks.read().await;
        table
            .iter()
            .filter(|(_, held)| !held.is_expired())
            .count()
    }
}
/// Binding of a session to a consumer that holds until `expires_at`.
///
/// Activity is tracked separately through [`touch`](Self::touch), which does
/// not move the deadline.
#[derive(Debug, Clone)]
pub struct SessionAffinity {
    session_id: SessionId,
    consumer_id: String,
    assigned_at: Instant,
    expires_at: Instant,
    affinity_duration: Duration,
    last_activity: Instant,
}

impl SessionAffinity {
    /// Binds `session_id` to `consumer_id` for `affinity_duration` from now.
    pub fn new(session_id: SessionId, consumer_id: String, affinity_duration: Duration) -> Self {
        let assigned_at = Instant::now();
        Self {
            session_id,
            consumer_id,
            assigned_at,
            expires_at: assigned_at + affinity_duration,
            affinity_duration,
            last_activity: assigned_at,
        }
    }

    /// The bound session.
    pub fn session_id(&self) -> &SessionId {
        &self.session_id
    }

    /// The consumer the session is bound to.
    pub fn consumer_id(&self) -> &str {
        &self.consumer_id
    }

    /// Duration given at construction; [`extend`](Self::extend) leaves it as is.
    pub fn affinity_duration(&self) -> Duration {
        self.affinity_duration
    }

    /// When the binding was created.
    pub fn assigned_at(&self) -> Instant {
        self.assigned_at
    }

    /// `true` once the current time has reached the deadline.
    pub fn is_expired(&self) -> bool {
        self.expires_at <= Instant::now()
    }

    /// Time left before expiry, or `Duration::ZERO` if already expired.
    pub fn time_remaining(&self) -> Duration {
        self.expires_at.saturating_duration_since(Instant::now())
    }

    /// Records activity now; the expiry deadline is not changed.
    pub fn touch(&mut self) {
        self.last_activity = Instant::now();
    }

    /// Time elapsed since the last recorded activity (never negative).
    pub fn idle_time(&self) -> Duration {
        Instant::now().saturating_duration_since(self.last_activity)
    }

    /// Returns a copy whose deadline is `additional_duration` from now.
    ///
    /// The deadline is reset relative to the current time rather than added to
    /// the old deadline. `affinity_duration` is left unchanged.
    pub fn extend(&self, additional_duration: Duration) -> Self {
        Self {
            expires_at: Instant::now() + additional_duration,
            ..self.clone()
        }
    }
}
#[derive(Clone)]
pub struct SessionAffinityTracker {
affinities: Arc<RwLock<HashMap<SessionId, SessionAffinity>>>,
default_affinity_duration: Duration,
}
impl SessionAffinityTracker {
pub fn new(default_affinity_duration: Duration) -> Self {
Self {
affinities: Arc::new(RwLock::new(HashMap::new())),
default_affinity_duration,
}
}
pub async fn assign_session(
&self,
session_id: SessionId,
consumer_id: String,
) -> Result<SessionAffinity, QueueError> {
let mut affinities = self.affinities.write().await;
if let Some(existing) = affinities.get(&session_id) {
if !existing.is_expired() {
if existing.consumer_id() != consumer_id {
return Err(QueueError::SessionLocked {
session_id: session_id.to_string(),
locked_until: Timestamp::now(), });
}
return Ok(existing.clone());
}
}
let affinity = SessionAffinity::new(
session_id.clone(),
consumer_id,
self.default_affinity_duration,
);
affinities.insert(session_id, affinity.clone());
Ok(affinity)
}
pub async fn get_consumer(&self, session_id: &SessionId) -> Option<String> {
let affinities = self.affinities.read().await;
affinities
.get(session_id)
.filter(|affinity| !affinity.is_expired())
.map(|affinity| affinity.consumer_id().to_string())
}
pub async fn get_affinity(&self, session_id: &SessionId) -> Option<SessionAffinity> {
let affinities = self.affinities.read().await;
affinities
.get(session_id)
.filter(|affinity| !affinity.is_expired())
.cloned()
}
pub async fn has_affinity(&self, session_id: &SessionId) -> bool {
self.get_consumer(session_id).await.is_some()
}
pub async fn touch_session(&self, session_id: &SessionId) -> Result<(), QueueError> {
let mut affinities = self.affinities.write().await;
if let Some(affinity) = affinities.get_mut(session_id) {
if !affinity.is_expired() {
affinity.touch();
return Ok(());
}
}
Err(QueueError::SessionNotFound {
session_id: session_id.to_string(),
})
}
pub async fn release_session(
&self,
session_id: &SessionId,
consumer_id: &str,
) -> Result<(), QueueError> {
let mut affinities = self.affinities.write().await;
if let Some(affinity) = affinities.get(session_id) {
if affinity.consumer_id() != consumer_id {
return Err(QueueError::ValidationError(
ValidationError::InvalidFormat {
field: "consumer_id".to_string(),
message: format!(
"Session owned by {}, cannot release from {}",
affinity.consumer_id(),
consumer_id
),
},
));
}
}
affinities.remove(session_id);
Ok(())
}
pub async fn extend_affinity(
&self,
session_id: &SessionId,
consumer_id: &str,
additional_duration: Duration,
) -> Result<SessionAffinity, QueueError> {
let mut affinities = self.affinities.write().await;
if let Some(affinity) = affinities.get(session_id) {
if affinity.consumer_id() != consumer_id {
return Err(QueueError::ValidationError(
ValidationError::InvalidFormat {
field: "consumer_id".to_string(),
message: format!(
"Session owned by {}, cannot extend from {}",
affinity.consumer_id(),
consumer_id
),
},
));
}
let extended = affinity.extend(additional_duration);
affinities.insert(session_id.clone(), extended.clone());
return Ok(extended);
}
Err(QueueError::SessionNotFound {
session_id: session_id.to_string(),
})
}
pub async fn get_consumer_sessions(&self, consumer_id: &str) -> Vec<SessionId> {
let affinities = self.affinities.read().await;
affinities
.iter()
.filter(|(_, affinity)| !affinity.is_expired() && affinity.consumer_id() == consumer_id)
.map(|(session_id, _)| session_id.clone())
.collect()
}
pub async fn cleanup_expired(&self) -> usize {
let mut affinities = self.affinities.write().await;
let expired: Vec<SessionId> = affinities
.iter()
.filter(|(_, affinity)| affinity.is_expired())
.map(|(session_id, _)| session_id.clone())
.collect();
let count = expired.len();
for session_id in expired {
affinities.remove(&session_id);
}
count
}
pub async fn affinity_count(&self) -> usize {
let affinities = self.affinities.read().await;
affinities.len()
}
pub async fn active_affinity_count(&self) -> usize {
let affinities = self.affinities.read().await;
affinities
.values()
.filter(|affinity| !affinity.is_expired())
.count()
}
}
/// Bookkeeping for one active session: owning consumer, start time, last
/// activity and the number of messages recorded.
#[derive(Debug, Clone)]
pub struct SessionInfo {
    session_id: SessionId,
    consumer_id: String,
    started_at: Instant,
    last_activity: Instant,
    message_count: u32,
}

impl SessionInfo {
    /// Starts tracking `session_id` for `consumer_id` with zero messages.
    pub fn new(session_id: SessionId, consumer_id: String) -> Self {
        let now = Instant::now();
        Self {
            session_id,
            consumer_id,
            started_at: now,
            last_activity: now,
            message_count: 0,
        }
    }

    /// The tracked session.
    pub fn session_id(&self) -> &SessionId {
        &self.session_id
    }

    /// The consumer that started the session.
    pub fn consumer_id(&self) -> &str {
        &self.consumer_id
    }

    /// When tracking began.
    pub fn started_at(&self) -> Instant {
        self.started_at
    }

    /// Time of the most recent message or touch.
    pub fn last_activity(&self) -> Instant {
        self.last_activity
    }

    /// Messages recorded so far (saturates at `u32::MAX`).
    pub fn message_count(&self) -> u32 {
        self.message_count
    }

    /// Time elapsed since the session started (never negative).
    pub fn duration(&self) -> Duration {
        Instant::now().saturating_duration_since(self.started_at)
    }

    /// Time elapsed since the last activity (never negative).
    pub fn idle_time(&self) -> Duration {
        Instant::now().saturating_duration_since(self.last_activity)
    }

    /// Records one more message and refreshes `last_activity`.
    ///
    /// The counter saturates at `u32::MAX`. A plain `+= 1` panics in debug
    /// builds and wraps to zero in default release builds. A wrapped count
    /// would silently stop any message-count limit from triggering.
    pub fn increment_message_count(&mut self) {
        self.message_count = self.message_count.saturating_add(1);
        self.last_activity = Instant::now();
    }

    /// Refreshes `last_activity` without counting a message.
    pub fn touch(&mut self) {
        self.last_activity = Instant::now();
    }
}
/// Limits used to decide when an active session should be closed.
#[derive(Debug, Clone)]
pub struct SessionLifecycleConfig {
    /// Maximum total lifetime of a session.
    pub max_session_duration: Duration,
    /// Message count that a session may not exceed.
    pub max_messages_per_session: u32,
    /// Maximum idle time since the last recorded activity.
    pub session_timeout: Duration,
}

impl Default for SessionLifecycleConfig {
    /// Two hours of lifetime, 1000 messages, and a 30-minute idle timeout.
    fn default() -> Self {
        const TWO_HOURS: Duration = Duration::from_secs(2 * 60 * 60);
        const THIRTY_MINUTES: Duration = Duration::from_secs(30 * 60);
        Self {
            max_session_duration: TWO_HOURS,
            max_messages_per_session: 1000,
            session_timeout: THIRTY_MINUTES,
        }
    }
}
/// Tracks active sessions and decides when they should be closed according to
/// a [`SessionLifecycleConfig`]. Clones share the same session map.
#[derive(Debug, Clone)]
pub struct SessionLifecycleManager {
    active_sessions: Arc<RwLock<HashMap<SessionId, SessionInfo>>>,
    config: SessionLifecycleConfig,
}

impl SessionLifecycleManager {
    /// Creates a manager with no active sessions.
    pub fn new(config: SessionLifecycleConfig) -> Self {
        Self {
            active_sessions: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// `true` when `info` strictly exceeds any configured limit: total
    /// duration, message count, or idle time (checked in that order).
    fn exceeds_limits(&self, info: &SessionInfo) -> bool {
        info.duration() > self.config.max_session_duration
            || info.message_count() > self.config.max_messages_per_session
            || info.idle_time() > self.config.session_timeout
    }

    /// Begins tracking `session_id` for `consumer_id`.
    ///
    /// # Errors
    /// `QueueError::ValidationError` when the session is already active.
    pub async fn start_session(
        &self,
        session_id: SessionId,
        consumer_id: String,
    ) -> Result<(), QueueError> {
        let mut sessions = self.active_sessions.write().await;
        if sessions.contains_key(&session_id) {
            return Err(QueueError::ValidationError(
                ValidationError::InvalidFormat {
                    field: "session_id".to_string(),
                    message: format!("Session {} is already active", session_id),
                },
            ));
        }
        sessions.insert(
            session_id.clone(),
            SessionInfo::new(session_id, consumer_id),
        );
        Ok(())
    }

    /// Stops tracking `session_id`.
    ///
    /// # Errors
    /// `QueueError::SessionNotFound` when the session is not active.
    pub async fn stop_session(&self, session_id: &SessionId) -> Result<(), QueueError> {
        let mut sessions = self.active_sessions.write().await;
        if sessions.remove(session_id).is_none() {
            return Err(QueueError::SessionNotFound {
                session_id: session_id.to_string(),
            });
        }
        Ok(())
    }

    /// Counts one message for `session_id` and refreshes its activity time.
    ///
    /// # Errors
    /// `QueueError::SessionNotFound` when the session is not active.
    pub async fn record_message(&self, session_id: &SessionId) -> Result<(), QueueError> {
        let mut sessions = self.active_sessions.write().await;
        let session_info =
            sessions
                .get_mut(session_id)
                .ok_or_else(|| QueueError::SessionNotFound {
                    session_id: session_id.to_string(),
                })?;
        session_info.increment_message_count();
        Ok(())
    }

    /// Refreshes the activity time of `session_id` without counting a message.
    ///
    /// # Errors
    /// `QueueError::SessionNotFound` when the session is not active.
    pub async fn touch_session(&self, session_id: &SessionId) -> Result<(), QueueError> {
        let mut sessions = self.active_sessions.write().await;
        let session_info =
            sessions
                .get_mut(session_id)
                .ok_or_else(|| QueueError::SessionNotFound {
                    session_id: session_id.to_string(),
                })?;
        session_info.touch();
        Ok(())
    }

    /// Snapshot of the bookkeeping for `session_id`, if active.
    pub async fn get_session_info(&self, session_id: &SessionId) -> Option<SessionInfo> {
        let sessions = self.active_sessions.read().await;
        sessions.get(session_id).cloned()
    }

    /// `true` when `session_id` is active and exceeds any configured limit.
    /// Unknown sessions return `false`.
    pub async fn should_close_session(&self, session_id: &SessionId) -> bool {
        let sessions = self.active_sessions.read().await;
        sessions
            .get(session_id)
            .map_or(false, |info| self.exceeds_limits(info))
    }

    /// Ids of all active sessions that exceed a configured limit. Nothing is
    /// removed.
    pub async fn get_sessions_to_close(&self) -> Vec<SessionId> {
        let sessions = self.active_sessions.read().await;
        sessions
            .iter()
            .filter(|(_, info)| self.exceeds_limits(info))
            .map(|(session_id, _)| session_id.clone())
            .collect()
    }

    /// Removes every session that exceeds a configured limit and returns their
    /// ids.
    pub async fn cleanup_expired_sessions(&self) -> Vec<SessionId> {
        // Scan and remove under one write guard. Scanning under a read guard and
        // re-locking for writing leaves a window in which a session can be
        // stopped and restarted under the same id; the fresh session would then
        // be removed and reported even though it is within its limits.
        let mut sessions = self.active_sessions.write().await;
        let mut closed = Vec::new();
        sessions.retain(|session_id, info| {
            let keep = !self.exceeds_limits(info);
            if !keep {
                closed.push(session_id.clone());
            }
            keep
        });
        closed
    }

    /// Number of active sessions.
    pub async fn session_count(&self) -> usize {
        let sessions = self.active_sessions.read().await;
        sessions.len()
    }

    /// Ids of all active sessions.
    pub async fn get_active_sessions(&self) -> Vec<SessionId> {
        let sessions = self.active_sessions.read().await;
        sessions.keys().cloned().collect()
    }

    /// Ids of active sessions started by `consumer_id`.
    pub async fn get_consumer_sessions(&self, consumer_id: &str) -> Vec<SessionId> {
        let sessions = self.active_sessions.read().await;
        sessions
            .iter()
            .filter(|(_, info)| info.consumer_id() == consumer_id)
            .map(|(session_id, _)| session_id.clone())
            .collect()
    }
}