use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use astrid_core::SessionId;
use teloxide::types::ChatId;
use tokio::sync::RwLock;
/// Per-chat state tracked by `SessionMap`: the bound session and whether a turn is running.
#[derive(Debug, Clone)]
pub struct ChatSession {
    /// Session identifier bound to this chat.
    pub session_id: SessionId,
    /// `true` while a turn is marked as in progress for this chat.
    pub turn_in_progress: bool,
}
/// Outcome of [`SessionMap::try_start_existing_turn`].
#[derive(Debug, Clone, PartialEq)]
pub enum TurnStartResult {
    /// A turn was started on the chat's existing session; carries that session's id.
    Started(SessionId),
    /// A turn is already in progress, or a session creation for the chat is still outstanding.
    TurnBusy,
    /// The chat has no session and no creation claim.
    NoSession,
}
/// State guarded by the `SessionMap` lock.
struct Inner {
    /// Active sessions, keyed by chat.
    sessions: HashMap<ChatId, ChatSession>,
    /// Chats with an outstanding creation claim (taken via `try_claim_creation`,
    /// released by `finish_creation*`, `insert`, `cancel_creation` or `remove`).
    /// `try_claim_creation` only adds a chat here when it has no entry in `sessions`.
    creating: HashSet<ChatId>,
}
/// Shared map from Telegram chat to session, with per-chat turn and creation bookkeeping.
///
/// Cloning is cheap: every clone is a handle to the same `Arc<RwLock<_>>`-guarded
/// state, so updates through one clone are visible through all others.
#[derive(Clone)]
pub struct SessionMap {
    inner: Arc<RwLock<Inner>>,
}
impl Default for SessionMap {
fn default() -> Self {
Self::new()
}
}
impl SessionMap {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(Inner {
sessions: HashMap::new(),
creating: HashSet::new(),
})),
}
}
pub async fn get_session_id(&self, chat_id: ChatId) -> Option<SessionId> {
self.inner
.read()
.await
.sessions
.get(&chat_id)
.map(|s| s.session_id.clone())
}
pub async fn insert(&self, chat_id: ChatId, session_id: SessionId) {
let mut guard = self.inner.write().await;
guard.creating.remove(&chat_id);
guard.sessions.insert(
chat_id,
ChatSession {
session_id,
turn_in_progress: false,
},
);
}
pub async fn try_start_existing_turn(&self, chat_id: ChatId) -> TurnStartResult {
let mut guard = self.inner.write().await;
if guard.creating.contains(&chat_id) {
return TurnStartResult::TurnBusy;
}
match guard.sessions.get_mut(&chat_id) {
Some(session) if session.turn_in_progress => TurnStartResult::TurnBusy,
Some(session) => {
session.turn_in_progress = true;
TurnStartResult::Started(session.session_id.clone())
},
None => TurnStartResult::NoSession,
}
}
pub async fn try_claim_creation(&self, chat_id: ChatId) -> bool {
let mut guard = self.inner.write().await;
if guard.sessions.contains_key(&chat_id) || guard.creating.contains(&chat_id) {
false
} else {
guard.creating.insert(chat_id);
true
}
}
pub async fn finish_creation(&self, chat_id: ChatId, session_id: SessionId) {
let mut guard = self.inner.write().await;
guard.creating.remove(&chat_id);
guard.sessions.insert(
chat_id,
ChatSession {
session_id,
turn_in_progress: false,
},
);
}
pub async fn finish_creation_and_start_turn(
&self,
chat_id: ChatId,
session_id: SessionId,
) -> SessionId {
let mut guard = self.inner.write().await;
guard.creating.remove(&chat_id);
guard.sessions.insert(
chat_id,
ChatSession {
session_id: session_id.clone(),
turn_in_progress: true,
},
);
session_id
}
pub async fn cancel_creation(&self, chat_id: ChatId) {
self.inner.write().await.creating.remove(&chat_id);
}
pub async fn remove(&self, chat_id: ChatId) -> Option<SessionId> {
let mut guard = self.inner.write().await;
guard.creating.remove(&chat_id);
guard.sessions.remove(&chat_id).map(|s| s.session_id)
}
pub async fn try_start_turn(&self, chat_id: ChatId) -> bool {
let mut guard = self.inner.write().await;
if let Some(session) = guard.sessions.get_mut(&chat_id) {
if session.turn_in_progress {
false
} else {
session.turn_in_progress = true;
true
}
} else {
false
}
}
pub async fn is_turn_in_progress(&self, chat_id: ChatId) -> bool {
self.inner
.read()
.await
.sessions
.get(&chat_id)
.is_some_and(|s| s.turn_in_progress)
}
pub async fn set_turn_in_progress(&self, chat_id: ChatId, in_progress: bool) {
if let Some(session) = self.inner.write().await.sessions.get_mut(&chat_id) {
session.turn_in_progress = in_progress;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ChatId` in tests.
    fn chat(id: i64) -> ChatId {
        ChatId(id)
    }

    #[tokio::test]
    async fn empty_map_returns_none() {
        let map = SessionMap::new();
        assert!(map.get_session_id(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn insert_and_get() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(42), sid.clone()).await;
        assert_eq!(map.get_session_id(chat(42)).await, Some(sid));
        assert!(map.get_session_id(chat(99)).await.is_none());
    }

    #[tokio::test]
    async fn remove_returns_session_and_clears() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(1), sid.clone()).await;
        let removed = map.remove(chat(1)).await;
        assert_eq!(removed, Some(sid));
        assert!(map.get_session_id(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn remove_nonexistent_returns_none() {
        let map = SessionMap::new();
        assert!(map.remove(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn remove_releases_creation_claim() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        assert!(map.remove(chat(1)).await.is_none());
        // The claim was dropped, so a new creation can be claimed.
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_defaults_to_false() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_toggle() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        map.set_turn_in_progress(chat(1), true).await;
        assert!(map.is_turn_in_progress(chat(1)).await);
        map.set_turn_in_progress(chat(1), false).await;
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn try_start_turn_atomic() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(map.try_start_turn(chat(1)).await);
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(!map.try_start_turn(chat(1)).await);
        map.set_turn_in_progress(chat(1), false).await;
        assert!(map.try_start_turn(chat(1)).await);
    }

    #[tokio::test]
    async fn try_start_turn_no_session_returns_false() {
        let map = SessionMap::new();
        assert!(!map.try_start_turn(chat(999)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_for_unknown_chat_is_false() {
        let map = SessionMap::new();
        assert!(!map.is_turn_in_progress(chat(999)).await);
    }

    #[tokio::test]
    async fn set_turn_on_unknown_chat_is_noop() {
        let map = SessionMap::new();
        map.set_turn_in_progress(chat(999), true).await;
        assert!(!map.is_turn_in_progress(chat(999)).await);
    }

    #[tokio::test]
    async fn multiple_chats_independent() {
        let map = SessionMap::new();
        let sid1 = SessionId::new();
        let sid2 = SessionId::new();
        map.insert(chat(1), sid1.clone()).await;
        map.insert(chat(2), sid2.clone()).await;
        map.set_turn_in_progress(chat(1), true).await;
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(!map.is_turn_in_progress(chat(2)).await);
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid1));
        assert_eq!(map.get_session_id(chat(2)).await, Some(sid2));
    }

    #[tokio::test]
    async fn insert_overwrites_existing() {
        let map = SessionMap::new();
        let sid1 = SessionId::new();
        let sid2 = SessionId::new();
        map.insert(chat(1), sid1).await;
        map.set_turn_in_progress(chat(1), true).await;
        map.insert(chat(1), sid2.clone()).await;
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid2));
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn clone_shares_state() {
        let map1 = SessionMap::new();
        let map2 = map1.clone();
        let sid = SessionId::new();
        map1.insert(chat(1), sid.clone()).await;
        assert_eq!(map2.get_session_id(chat(1)).await, Some(sid));
    }

    #[tokio::test]
    async fn try_claim_creation_succeeds_when_no_session() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn try_claim_creation_fails_when_already_creating() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn try_claim_creation_fails_when_session_exists() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn finish_creation_inserts_session_and_clears_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        let sid = SessionId::new();
        map.finish_creation(chat(1), sid.clone()).await;
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid));
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn finish_creation_and_start_turn_marks_turn_active() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        let sid = SessionId::new();
        let returned = map.finish_creation_and_start_turn(chat(1), sid.clone()).await;
        assert_eq!(returned, sid);
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid));
        assert!(map.is_turn_in_progress(chat(1)).await);
        // The turn is already running and the session exists.
        assert!(!map.try_start_turn(chat(1)).await);
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn cancel_creation_clears_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        map.cancel_creation(chat(1)).await;
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn creating_blocks_try_start_existing_turn() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::TurnBusy
        ));
    }

    #[tokio::test]
    async fn try_start_existing_turn_without_session_reports_no_session() {
        let map = SessionMap::new();
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::NoSession
        ));
    }

    #[tokio::test]
    async fn try_start_existing_turn_starts_then_reports_busy() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(1), sid.clone()).await;
        let TurnStartResult::Started(got) = map.try_start_existing_turn(chat(1)).await else {
            panic!("expected TurnStartResult::Started");
        };
        assert_eq!(got, sid);
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::TurnBusy
        ));
    }
}