astrid_telegram/
session.rs1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use astrid_core::SessionId;
7use teloxide::types::ChatId;
8use tokio::sync::RwLock;
9
10pub struct ChatSession {
12 pub session_id: SessionId,
14 pub turn_in_progress: bool,
16}
17
18pub enum TurnStartResult {
20 Started(SessionId),
22 TurnBusy,
24 NoSession,
26}
27
28struct Inner {
30 sessions: HashMap<ChatId, ChatSession>,
31 creating: HashSet<ChatId>,
34}
35
36#[derive(Clone)]
38pub struct SessionMap {
39 inner: Arc<RwLock<Inner>>,
40}
41
42impl Default for SessionMap {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl SessionMap {
49 pub fn new() -> Self {
51 Self {
52 inner: Arc::new(RwLock::new(Inner {
53 sessions: HashMap::new(),
54 creating: HashSet::new(),
55 })),
56 }
57 }
58
59 pub async fn get_session_id(&self, chat_id: ChatId) -> Option<SessionId> {
61 self.inner
62 .read()
63 .await
64 .sessions
65 .get(&chat_id)
66 .map(|s| s.session_id.clone())
67 }
68
69 pub async fn insert(&self, chat_id: ChatId, session_id: SessionId) {
74 let mut guard = self.inner.write().await;
75 guard.creating.remove(&chat_id);
76 guard.sessions.insert(
77 chat_id,
78 ChatSession {
79 session_id,
80 turn_in_progress: false,
81 },
82 );
83 }
84
85 pub async fn try_start_existing_turn(&self, chat_id: ChatId) -> TurnStartResult {
90 let mut guard = self.inner.write().await;
91 if guard.creating.contains(&chat_id) {
92 return TurnStartResult::TurnBusy;
93 }
94 match guard.sessions.get_mut(&chat_id) {
95 Some(session) if session.turn_in_progress => TurnStartResult::TurnBusy,
96 Some(session) => {
97 session.turn_in_progress = true;
98 TurnStartResult::Started(session.session_id.clone())
99 },
100 None => TurnStartResult::NoSession,
101 }
102 }
103
104 pub async fn try_claim_creation(&self, chat_id: ChatId) -> bool {
110 let mut guard = self.inner.write().await;
111 if guard.sessions.contains_key(&chat_id) || guard.creating.contains(&chat_id) {
112 false
113 } else {
114 guard.creating.insert(chat_id);
115 true
116 }
117 }
118
119 pub async fn finish_creation(&self, chat_id: ChatId, session_id: SessionId) {
122 let mut guard = self.inner.write().await;
123 guard.creating.remove(&chat_id);
124 guard.sessions.insert(
125 chat_id,
126 ChatSession {
127 session_id,
128 turn_in_progress: false,
129 },
130 );
131 }
132
133 pub async fn finish_creation_and_start_turn(
137 &self,
138 chat_id: ChatId,
139 session_id: SessionId,
140 ) -> SessionId {
141 let mut guard = self.inner.write().await;
142 guard.creating.remove(&chat_id);
143 guard.sessions.insert(
144 chat_id,
145 ChatSession {
146 session_id: session_id.clone(),
147 turn_in_progress: true,
148 },
149 );
150 session_id
151 }
152
153 pub async fn cancel_creation(&self, chat_id: ChatId) {
155 self.inner.write().await.creating.remove(&chat_id);
156 }
157
158 pub async fn remove(&self, chat_id: ChatId) -> Option<SessionId> {
164 let mut guard = self.inner.write().await;
165 guard.creating.remove(&chat_id);
166 guard.sessions.remove(&chat_id).map(|s| s.session_id)
167 }
168
169 pub async fn try_start_turn(&self, chat_id: ChatId) -> bool {
174 let mut guard = self.inner.write().await;
175 if let Some(session) = guard.sessions.get_mut(&chat_id) {
176 if session.turn_in_progress {
177 false
178 } else {
179 session.turn_in_progress = true;
180 true
181 }
182 } else {
183 false
184 }
185 }
186
187 pub async fn is_turn_in_progress(&self, chat_id: ChatId) -> bool {
189 self.inner
190 .read()
191 .await
192 .sessions
193 .get(&chat_id)
194 .is_some_and(|s| s.turn_in_progress)
195 }
196
197 pub async fn set_turn_in_progress(&self, chat_id: ChatId, in_progress: bool) {
199 if let Some(session) = self.inner.write().await.sessions.get_mut(&chat_id) {
200 session.turn_in_progress = in_progress;
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ChatId` from a raw integer.
    fn chat(id: i64) -> ChatId {
        ChatId(id)
    }

    // --- Basic lookup, insert and remove ---

    #[tokio::test]
    async fn empty_map_returns_none() {
        let map = SessionMap::new();
        assert!(map.get_session_id(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn insert_and_get() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(42), sid.clone()).await;

        assert_eq!(map.get_session_id(chat(42)).await, Some(sid));
        assert!(map.get_session_id(chat(99)).await.is_none());
    }

    #[tokio::test]
    async fn remove_returns_session_and_clears() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(1), sid.clone()).await;

        let removed = map.remove(chat(1)).await;
        assert_eq!(removed, Some(sid));
        assert!(map.get_session_id(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn remove_nonexistent_returns_none() {
        let map = SessionMap::new();
        assert!(map.remove(chat(1)).await.is_none());
    }

    #[tokio::test]
    async fn remove_clears_creation_claim() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);

        // No session exists yet, so nothing is returned...
        assert!(map.remove(chat(1)).await.is_none());
        // ...but the claim is released and can be taken again.
        assert!(map.try_claim_creation(chat(1)).await);
    }

    // --- Turn flag ---

    #[tokio::test]
    async fn turn_in_progress_defaults_to_false() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_toggle() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;

        map.set_turn_in_progress(chat(1), true).await;
        assert!(map.is_turn_in_progress(chat(1)).await);

        map.set_turn_in_progress(chat(1), false).await;
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn try_start_turn_atomic() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;

        // First attempt wins and sets the flag.
        assert!(map.try_start_turn(chat(1)).await);
        assert!(map.is_turn_in_progress(chat(1)).await);

        // Second attempt is rejected while the flag is set.
        assert!(!map.try_start_turn(chat(1)).await);

        // Clearing the flag allows a new turn.
        map.set_turn_in_progress(chat(1), false).await;
        assert!(map.try_start_turn(chat(1)).await);
    }

    #[tokio::test]
    async fn try_start_turn_no_session_returns_false() {
        let map = SessionMap::new();
        assert!(!map.try_start_turn(chat(999)).await);
    }

    #[tokio::test]
    async fn turn_in_progress_for_unknown_chat_is_false() {
        let map = SessionMap::new();
        assert!(!map.is_turn_in_progress(chat(999)).await);
    }

    #[tokio::test]
    async fn set_turn_on_unknown_chat_is_noop() {
        let map = SessionMap::new();
        map.set_turn_in_progress(chat(999), true).await;
        // No entry is created as a side effect.
        assert!(!map.is_turn_in_progress(chat(999)).await);
    }

    #[tokio::test]
    async fn multiple_chats_independent() {
        let map = SessionMap::new();
        let sid1 = SessionId::new();
        let sid2 = SessionId::new();
        map.insert(chat(1), sid1.clone()).await;
        map.insert(chat(2), sid2.clone()).await;

        map.set_turn_in_progress(chat(1), true).await;
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(!map.is_turn_in_progress(chat(2)).await);

        assert_eq!(map.get_session_id(chat(1)).await, Some(sid1));
        assert_eq!(map.get_session_id(chat(2)).await, Some(sid2));
    }

    #[tokio::test]
    async fn insert_overwrites_existing() {
        let map = SessionMap::new();
        let sid1 = SessionId::new();
        let sid2 = SessionId::new();

        map.insert(chat(1), sid1).await;
        map.set_turn_in_progress(chat(1), true).await;

        map.insert(chat(1), sid2.clone()).await;

        // The new session replaces the old one and starts with a clear flag.
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid2));
        assert!(!map.is_turn_in_progress(chat(1)).await);
    }

    #[tokio::test]
    async fn clone_shares_state() {
        let map1 = SessionMap::new();
        let map2 = map1.clone();
        let sid = SessionId::new();

        map1.insert(chat(1), sid.clone()).await;
        assert_eq!(map2.get_session_id(chat(1)).await, Some(sid));
    }

    // --- Creation claims ---

    #[tokio::test]
    async fn try_claim_creation_succeeds_when_no_session() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn try_claim_creation_fails_when_already_creating() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn try_claim_creation_fails_when_session_exists() {
        let map = SessionMap::new();
        map.insert(chat(1), SessionId::new()).await;
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn finish_creation_inserts_session_and_clears_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);

        let sid = SessionId::new();
        map.finish_creation(chat(1), sid.clone()).await;

        assert_eq!(map.get_session_id(chat(1)).await, Some(sid));
        // A session now exists, so a fresh claim is refused.
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn finish_creation_and_start_turn_marks_turn_in_progress() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);

        let sid = SessionId::new();
        let returned = map
            .finish_creation_and_start_turn(chat(1), sid.clone())
            .await;

        assert_eq!(returned, sid);
        assert_eq!(map.get_session_id(chat(1)).await, Some(sid));
        // The turn flag is already set, so another turn cannot start.
        assert!(map.is_turn_in_progress(chat(1)).await);
        assert!(!map.try_start_turn(chat(1)).await);
        // The session exists, so a new creation claim is refused.
        assert!(!map.try_claim_creation(chat(1)).await);
    }

    #[tokio::test]
    async fn cancel_creation_clears_lock() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        map.cancel_creation(chat(1)).await;
        // The claim can be taken again after cancellation.
        assert!(map.try_claim_creation(chat(1)).await);
    }

    // --- try_start_existing_turn ---

    #[tokio::test]
    async fn creating_blocks_try_start_existing_turn() {
        let map = SessionMap::new();
        assert!(map.try_claim_creation(chat(1)).await);
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::TurnBusy
        ));
    }

    #[tokio::test]
    async fn try_start_existing_turn_without_session_reports_no_session() {
        let map = SessionMap::new();
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::NoSession
        ));
    }

    #[tokio::test]
    async fn try_start_existing_turn_starts_then_reports_busy() {
        let map = SessionMap::new();
        let sid = SessionId::new();
        map.insert(chat(1), sid.clone()).await;

        // First call starts the turn and hands back the stored session ID.
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::Started(ref got) if *got == sid
        ));
        assert!(map.is_turn_in_progress(chat(1)).await);

        // Second call sees the flag and reports busy.
        assert!(matches!(
            map.try_start_existing_turn(chat(1)).await,
            TurnStartResult::TurnBusy
        ));
    }
}