Skip to main content

astrid_telegram/
session.rs

1//! Session mapping: Telegram `ChatId` → daemon `SessionId`.
2
3use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use astrid_core::SessionId;
7use teloxide::types::ChatId;
8use tokio::sync::RwLock;
9
10/// Per-chat session state.
11pub struct ChatSession {
12    /// Daemon session ID.
13    pub session_id: SessionId,
14    /// Whether a turn is currently in progress (prevents double-send).
15    pub turn_in_progress: bool,
16}
17
18/// Result of attempting to start a turn for a chat.
19pub enum TurnStartResult {
20    /// Turn started successfully; contains the session ID.
21    Started(SessionId),
22    /// A turn is already in progress (or a session is being created).
23    TurnBusy,
24    /// No session exists for this chat.
25    NoSession,
26}
27
28/// Interior state guarded by a single `RwLock`.
29struct Inner {
30    sessions: HashMap<ChatId, ChatSession>,
31    /// Chats that are currently creating a session (prevents duplicate
32    /// `create_session` calls when concurrent messages race).
33    creating: HashSet<ChatId>,
34}
35
36/// Maps Telegram chat IDs to daemon sessions.
37#[derive(Clone)]
38pub struct SessionMap {
39    inner: Arc<RwLock<Inner>>,
40}
41
42impl Default for SessionMap {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl SessionMap {
49    /// Create an empty session map.
50    pub fn new() -> Self {
51        Self {
52            inner: Arc::new(RwLock::new(Inner {
53                sessions: HashMap::new(),
54                creating: HashSet::new(),
55            })),
56        }
57    }
58
59    /// Get the session ID for a chat, if one exists.
60    pub async fn get_session_id(&self, chat_id: ChatId) -> Option<SessionId> {
61        self.inner
62            .read()
63            .await
64            .sessions
65            .get(&chat_id)
66            .map(|s| s.session_id.clone())
67    }
68
69    /// Insert a new session mapping.
70    ///
71    /// Also clears any in-progress creation lock for this chat to keep
72    /// internal invariants consistent.
73    pub async fn insert(&self, chat_id: ChatId, session_id: SessionId) {
74        let mut guard = self.inner.write().await;
75        guard.creating.remove(&chat_id);
76        guard.sessions.insert(
77            chat_id,
78            ChatSession {
79                session_id,
80                turn_in_progress: false,
81            },
82        );
83    }
84
85    /// Atomically check if a session exists and start a turn.
86    ///
87    /// Also returns `TurnBusy` if a session is currently being created for
88    /// this chat (prevents the caller from starting a duplicate creation).
89    pub async fn try_start_existing_turn(&self, chat_id: ChatId) -> TurnStartResult {
90        let mut guard = self.inner.write().await;
91        if guard.creating.contains(&chat_id) {
92            return TurnStartResult::TurnBusy;
93        }
94        match guard.sessions.get_mut(&chat_id) {
95            Some(session) if session.turn_in_progress => TurnStartResult::TurnBusy,
96            Some(session) => {
97                session.turn_in_progress = true;
98                TurnStartResult::Started(session.session_id.clone())
99            },
100            None => TurnStartResult::NoSession,
101        }
102    }
103
104    /// Atomically claim the right to create a session for this chat.
105    ///
106    /// Returns `true` if the caller should proceed with `create_session`.
107    /// Returns `false` if a session already exists or another task is
108    /// already creating one.
109    pub async fn try_claim_creation(&self, chat_id: ChatId) -> bool {
110        let mut guard = self.inner.write().await;
111        if guard.sessions.contains_key(&chat_id) || guard.creating.contains(&chat_id) {
112            false
113        } else {
114            guard.creating.insert(chat_id);
115            true
116        }
117    }
118
119    /// Complete session creation: insert the session and clear the creation
120    /// lock.
121    pub async fn finish_creation(&self, chat_id: ChatId, session_id: SessionId) {
122        let mut guard = self.inner.write().await;
123        guard.creating.remove(&chat_id);
124        guard.sessions.insert(
125            chat_id,
126            ChatSession {
127                session_id,
128                turn_in_progress: false,
129            },
130        );
131    }
132
133    /// Atomically complete session creation and start a turn in one lock
134    /// acquisition. Prevents a race where another message starts the turn
135    /// between `finish_creation` and `try_start_existing_turn`.
136    pub async fn finish_creation_and_start_turn(
137        &self,
138        chat_id: ChatId,
139        session_id: SessionId,
140    ) -> SessionId {
141        let mut guard = self.inner.write().await;
142        guard.creating.remove(&chat_id);
143        guard.sessions.insert(
144            chat_id,
145            ChatSession {
146                session_id: session_id.clone(),
147                turn_in_progress: true,
148            },
149        );
150        session_id
151    }
152
153    /// Cancel session creation (on failure) and clear the creation lock.
154    pub async fn cancel_creation(&self, chat_id: ChatId) {
155        self.inner.write().await.creating.remove(&chat_id);
156    }
157
158    /// Remove a session mapping.
159    ///
160    /// Also clears any in-progress creation lock for this chat so a
161    /// concurrent `finish_creation_and_start_turn` doesn't silently
162    /// re-insert the session after a `/reset`.
163    pub async fn remove(&self, chat_id: ChatId) -> Option<SessionId> {
164        let mut guard = self.inner.write().await;
165        guard.creating.remove(&chat_id);
166        guard.sessions.remove(&chat_id).map(|s| s.session_id)
167    }
168
169    /// Atomically check and start a turn for this chat.
170    ///
171    /// Returns `true` if the turn was started (was not already in progress).
172    /// Returns `false` if a turn is already in progress or no session exists.
173    pub async fn try_start_turn(&self, chat_id: ChatId) -> bool {
174        let mut guard = self.inner.write().await;
175        if let Some(session) = guard.sessions.get_mut(&chat_id) {
176            if session.turn_in_progress {
177                false
178            } else {
179                session.turn_in_progress = true;
180                true
181            }
182        } else {
183            false
184        }
185    }
186
187    /// Check if a turn is currently in progress for this chat.
188    pub async fn is_turn_in_progress(&self, chat_id: ChatId) -> bool {
189        self.inner
190            .read()
191            .await
192            .sessions
193            .get(&chat_id)
194            .is_some_and(|s| s.turn_in_progress)
195    }
196
197    /// Mark a turn as finished for this chat.
198    pub async fn set_turn_in_progress(&self, chat_id: ChatId, in_progress: bool) {
199        if let Some(session) = self.inner.write().await.sessions.get_mut(&chat_id) {
200            session.turn_in_progress = in_progress;
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    fn chat(id: i64) -> ChatId {
        ChatId(id)
    }

    #[tokio::test]
    async fn empty_map_returns_none() {
        let map = SessionMap::new();
        assert!(map.get_session_id(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn insert_and_get() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(42), sid.clone()).await;

        assert_eq!(map.get_session_id(chat(42)).await, Some(sid));
        assert!(map.get_session_id(chat(99)).await.is_none());
    }

    #[tokio::test]
    async fn remove_returns_session_and_clears() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(1), sid.clone()).await;

        let removed = map.remove(chat(1)).await;
        assert_eq!(removed, Some(sid));
        assert!(map.get_session_id(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn remove_nonexistent_returns_none() {
        let map = SessionMap::new();
        assert!(map.remove(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn turn_in_progress_defaults_to_false() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_toggle() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;

        map.set_turn_in_progress(chat(1), true).await;
        assert!(map.is_turn_in_progress(chat(1)).await);

        map.set_turn_in_progress(chat(1), false).await;
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn try_start_turn_atomic() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;

        // First call succeeds and sets in_progress.
        assert!(map.try_start_turn(chat(1)).await);
        assert!(map.is_turn_in_progress(chat(1)).await);

        // Second call fails because already in progress.
        assert!(!map.try_start_turn(chat(1)).await);

        // After clearing, can start again.
        map.set_turn_in_progress(chat(1), false).await;
        assert!(map.try_start_turn(chat(1)).await);
    }

    #[tokio::test]
    async fn try_start_turn_no_session_returns_false() {
        let map = SessionMap::new();
        assert!(!map.try_start_turn(chat(999)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_for_unknown_chat_is_false() {
        let map = SessionMap::new();
        assert!(!map.is_turn_in_progress(chat(999)).await);
    }

    #[tokio::test]
    async fn set_turn_on_unknown_chat_is_noop() {
        let map = SessionMap::new();
        // Should not panic.
        map.set_turn_in_progress(chat(999), true).await;
        assert!(!map.is_turn_in_progress(chat(999)).await);
    }

    #[tokio::test]
    async fn multiple_chats_independent() {
        let map = SessionMap::new();
        let sid1 = SessionId::new();
        let sid2 = SessionId::new();
        map.insert(chat(1), sid1.clone()).await;
        map.insert(chat(2), sid2.clone()).await;

        map.set_turn_in_progress(chat(1), true).await;
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(!map.is_turn_in_progress(chat(2)).await);

        assert_eq!(map.get_session_id(chat(1)).await, Some(sid1));
        assert_eq!(map.get_session_id(chat(2)).await, Some(sid2));
    }

    #[tokio::test]
    async fn insert_overwrites_existing() {
        let map = SessionMap::new();
        let sid1 = SessionId::new();
        let sid2 = SessionId::new();

        map.insert(chat(1), sid1).await;
        map.set_turn_in_progress(chat(1), true).await;

        // Overwrite with new session.
        map.insert(chat(1), sid2.clone()).await;

        assert_eq!(map.get_session_id(chat(1)).await, Some(sid2));
        // turn_in_progress should be reset to false.
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn clone_shares_state() {
        let map1 = SessionMap::new();
        let map2 = map1.clone();
        let sid = SessionId::new();

        map1.insert(chat(1), sid.clone()).await;
        assert_eq!(map2.get_session_id(chat(1)).await, Some(sid));
    }

    // --- try_start_existing_turn ---

    #[tokio::test]
    async fn try_start_existing_turn_without_session_returns_no_session() {
        let map = SessionMap::new();
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::NoSession
        ));
    }

    #[tokio::test]
    async fn try_start_existing_turn_starts_then_reports_busy() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(1), sid.clone()).await;

        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::Started(ref s) if *s == sid
        ));
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::TurnBusy
        ));
    }

    // --- creation lock ---

    #[tokio::test]
    async fn try_claim_creation_succeeds_when_no_session() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn try_claim_creation_fails_when_already_creating() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        // Second call for same chat should fail.
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn try_claim_creation_fails_when_session_exists() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn finish_creation_inserts_session_and_clears_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);

        let sid = SessionId::new();
        map.finish_creation(chat(1), sid.clone()).await;

        assert_eq!(map.get_session_id(chat(1)).await, Some(sid.clone()));
        assert!(!map.is_turn_in_progress(chat(1)).await);
        // If the creation lock were still held this would report `TurnBusy`.
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::Started(ref s) if *s == sid
        ));
        // A session now exists, so a new claim is refused.
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn finish_creation_and_start_turn_marks_turn_in_progress() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);

        let sid = SessionId::new();
        let started = map.finish_creation_and_start_turn(chat(1), sid.clone()).await;

        assert_eq!(started, sid);
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid));
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn cancel_creation_clears_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        map.cancel_creation(chat(1)).await;
        // Lock is cleared, can try again.
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn remove_clears_creation_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        // No session yet, but the pending claim is dropped.
        assert!(map.remove(chat(1)).await.is_none());
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn insert_clears_creation_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);

        map.insert(chat(1), SessionId::new()).await;
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::Started(_)
        ));
    }

    #[tokio::test]
    async fn creating_blocks_try_start_existing_turn() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        // While creating, try_start_existing_turn should return TurnBusy.
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::TurnBusy
        ));
    }
}