1use async_trait::async_trait;
2use gewe_core::{AppId, BotContext};
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, VecDeque};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Shared map from [`AppId`] to its [`BotContext`], guarded by an async
/// `tokio::sync::RwLock` and wrapped in an `Arc` so handles can be cloned and shared.
pub type BotRegistry = Arc<RwLock<HashMap<AppId, BotContext>>>;
9
10#[async_trait]
11pub trait SessionStore: Send + Sync {
12 async fn get_session(&self, app_id: &AppId) -> Option<BotContext>;
13 async fn put_session(&self, context: BotContext);
14 async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool;
16}
17
/// [`SessionStore`] that keeps every session in a `HashMap` behind a
/// `tokio::sync::RwLock`; data lives only in process memory.
///
/// The map sits behind an `Arc`, so clones of the store share the same sessions.
#[derive(Clone, Default)]
pub struct InMemorySessionStore {
    /// Sessions (context plus seen-message history) keyed by app id.
    inner: Arc<RwLock<HashMap<AppId, StoredEntry>>>,
}
22
/// Per-app record kept by the session stores: the bot context plus the history of
/// recently seen message ids used for de-duplication.
#[derive(Clone, Serialize, Deserialize)]
struct StoredEntry {
    context: BotContext,
    /// Message ids in arrival order, oldest at the front. The `mark_message_seen`
    /// implementations in this file cap it at 1024 entries. `#[serde(default)]` lets
    /// payloads without this field deserialize with an empty history.
    #[serde(default)]
    seen: VecDeque<i64>,
}
29
30#[async_trait]
31impl SessionStore for InMemorySessionStore {
32 async fn get_session(&self, app_id: &AppId) -> Option<BotContext> {
33 let map: tokio::sync::RwLockReadGuard<'_, HashMap<AppId, StoredEntry>> =
34 self.inner.read().await;
35 map.get(app_id).map(|entry| entry.context.clone())
36 }
37
38 async fn put_session(&self, context: BotContext) {
39 let mut map: tokio::sync::RwLockWriteGuard<'_, HashMap<AppId, StoredEntry>> =
40 self.inner.write().await;
41 map.insert(
42 context.app_id.clone(),
43 StoredEntry {
44 context,
45 seen: VecDeque::new(),
46 },
47 );
48 }
49
50 async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool {
51 let mut map: tokio::sync::RwLockWriteGuard<'_, HashMap<AppId, StoredEntry>> =
52 self.inner.write().await;
53 let entry = match map.get_mut(app_id) {
54 Some(entry) => entry,
55 None => return true,
56 };
57 if entry.seen.contains(&new_msg_id) {
58 return false;
59 }
60
61 entry.seen.push_back(new_msg_id);
62 const MAX_SEEN: usize = 1024;
64 if entry.seen.len() > MAX_SEEN {
65 entry.seen.pop_front();
66 }
67 true
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for `app_id` whose token is `token_<app_id>`.
    fn create_test_context(app_id: &str) -> BotContext {
        BotContext {
            app_id: AppId(app_id.to_string()),
            token: format!("token_{}", app_id),
            webhook_secret: None,
            description: None,
        }
    }

    #[tokio::test]
    async fn test_in_memory_store_default() {
        let store = InMemorySessionStore::default();
        let app_id = AppId("test_app".to_string());
        let result = store.get_session(&app_id).await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_put_and_get_session() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");

        store.put_session(ctx.clone()).await;

        let retrieved = store
            .get_session(&ctx.app_id)
            .await
            .expect("session was just stored");
        assert_eq!(retrieved.app_id.0, "app123");
        assert_eq!(retrieved.token, "token_app123");
    }

    #[tokio::test]
    async fn test_in_memory_store_overwrite_session() {
        let store = InMemorySessionStore::default();
        let ctx1 = BotContext {
            app_id: AppId("app123".to_string()),
            token: "token1".to_string(),
            webhook_secret: None,
            description: None,
        };
        let ctx2 = BotContext {
            app_id: AppId("app123".to_string()),
            token: "token2".to_string(),
            webhook_secret: Some("secret".to_string()),
            description: Some("updated".to_string()),
        };

        store.put_session(ctx1).await;
        store.put_session(ctx2).await;

        let retrieved = store
            .get_session(&AppId("app123".to_string()))
            .await
            .expect("session was just stored");
        assert_eq!(retrieved.token, "token2");
        assert_eq!(retrieved.webhook_secret, Some("secret".to_string()));
    }

    #[tokio::test]
    async fn test_in_memory_store_multiple_sessions() {
        let store = InMemorySessionStore::default();
        for name in ["app1", "app2", "app3"] {
            store.put_session(create_test_context(name)).await;
        }

        for name in ["app1", "app2", "app3"] {
            assert!(store
                .get_session(&AppId(name.to_string()))
                .await
                .is_some());
        }
        assert!(store
            .get_session(&AppId("app4".to_string()))
            .await
            .is_none());
    }

    #[tokio::test]
    async fn test_mark_message_seen_no_session() {
        let store = InMemorySessionStore::default();
        let app_id = AppId("nonexistent".to_string());

        // Without a session the store cannot de-duplicate, so it reports "new".
        assert!(store.mark_message_seen(&app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_first_time() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        assert!(store.mark_message_seen(&ctx.app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_duplicate() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        assert!(store.mark_message_seen(&ctx.app_id, 12345).await);
        assert!(!store.mark_message_seen(&ctx.app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_different_messages() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        assert!(store.mark_message_seen(&ctx.app_id, 1).await);
        assert!(store.mark_message_seen(&ctx.app_id, 2).await);
        assert!(store.mark_message_seen(&ctx.app_id, 3).await);

        assert!(!store.mark_message_seen(&ctx.app_id, 1).await);
        assert!(!store.mark_message_seen(&ctx.app_id, 2).await);
        assert!(!store.mark_message_seen(&ctx.app_id, 3).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_max_capacity() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        // Overfill the 1024-entry history so the oldest ids get evicted.
        for i in 0..1030 {
            store.mark_message_seen(&ctx.app_id, i).await;
        }

        // Id 0 was evicted, so it is reported as new again...
        let result_old = store.mark_message_seen(&ctx.app_id, 0).await;
        assert!(result_old);

        // ...while the most recent id is still remembered.
        let result_recent = store.mark_message_seen(&ctx.app_id, 1029).await;
        assert!(!result_recent);
    }

    #[tokio::test]
    async fn test_mark_message_seen_different_apps() {
        let store = InMemorySessionStore::default();
        let ctx1 = create_test_context("app1");
        let ctx2 = create_test_context("app2");
        store.put_session(ctx1.clone()).await;
        store.put_session(ctx2.clone()).await;

        // Histories are per app: the same id is new for each app once.
        assert!(store.mark_message_seen(&ctx1.app_id, 12345).await);
        assert!(store.mark_message_seen(&ctx2.app_id, 12345).await);

        assert!(!store.mark_message_seen(&ctx1.app_id, 12345).await);
        assert!(!store.mark_message_seen(&ctx2.app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_in_memory_store_clone() {
        let store1 = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store1.put_session(ctx.clone()).await;

        let store2 = store1.clone();

        let result1 = store1.get_session(&ctx.app_id).await;
        let result2 = store2.get_session(&ctx.app_id).await;

        assert!(result1.is_some());
        assert!(result2.is_some());
        assert_eq!(result1.unwrap().token, result2.unwrap().token);
    }

    #[test]
    fn test_stored_entry_serialize_deserialize() {
        let entry = StoredEntry {
            context: create_test_context("app123"),
            seen: VecDeque::from([1, 2, 3]),
        };

        let json = serde_json::to_string(&entry).unwrap();
        let deserialized: StoredEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.context.app_id.0, "app123");
        assert_eq!(deserialized.seen, VecDeque::from([1, 2, 3]));
    }

    #[test]
    fn test_stored_entry_default_seen() {
        let json = r#"{"context":{"appId":"app123","token":"token"}}"#;
        let entry: StoredEntry = serde_json::from_str(json).unwrap();
        assert!(entry.seen.is_empty());
    }

    #[test]
    fn test_bot_registry_type() {
        let registry: BotRegistry = Arc::new(RwLock::new(HashMap::new()));
        assert_eq!(Arc::strong_count(&registry), 1);
    }

    #[tokio::test]
    async fn test_bot_registry_operations() {
        let registry: BotRegistry = Arc::new(RwLock::new(HashMap::new()));
        let ctx = create_test_context("app123");

        {
            let mut map = registry.write().await;
            map.insert(ctx.app_id.clone(), ctx.clone());
        }

        {
            let map = registry.read().await;
            let retrieved = map.get(&ctx.app_id).expect("context was just inserted");
            assert_eq!(retrieved.token, "token_app123");
        }
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let store = Arc::new(InMemorySessionStore::default());
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        let store1 = Arc::clone(&store);
        let store2 = Arc::clone(&store);
        let app_id1 = ctx.app_id.clone();
        let app_id2 = ctx.app_id.clone();

        let handle1 = tokio::spawn(async move {
            for i in 0..100 {
                store1.mark_message_seen(&app_id1, i).await;
            }
        });

        let handle2 = tokio::spawn(async move {
            for i in 100..200 {
                store2.mark_message_seen(&app_id2, i).await;
            }
        });

        handle1.await.unwrap();
        handle2.await.unwrap();

        // Neither task's writes were lost.
        assert!(!store.mark_message_seen(&ctx.app_id, 50).await);
        assert!(!store.mark_message_seen(&ctx.app_id, 150).await);
    }
}
360
#[cfg(feature = "sqlite")]
pub mod sqlite_store {
    //! SQLite-backed [`SessionStore`](super::SessionStore) implementation.

    use super::{AppId, BotContext, SessionStore, StoredEntry};
    use async_trait::async_trait;
    use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
    use std::collections::VecDeque;
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;

    /// Number of most recent message ids remembered per session.
    const MAX_SEEN: usize = 1024;

    /// [`SessionStore`] that persists each session as one JSON-encoded
    /// [`StoredEntry`] row in a `sessions` table (`app_id` primary key, `payload` text).
    #[derive(Clone)]
    pub struct SqliteSessionStore {
        pool: SqlitePool,
        /// Serializes this store's read-modify-write cycles and is shared by clones.
        /// Without it, two concurrent `mark_message_seen` calls could load the same
        /// history, both report the id as new, and the later write would drop the
        /// other's id. A `mark_message_seen` racing `put_session` could also write a
        /// stale context back over the new one.
        /// NOTE: this only coordinates tasks within one process, not separate processes
        /// that open the same database file.
        write_lock: Arc<Mutex<()>>,
    }

    impl SqliteSessionStore {
        /// Opens a pool to `database_url` (at most 5 connections, 5 s acquire timeout)
        /// and creates the `sessions` table if it does not exist yet.
        ///
        /// # Errors
        /// Returns the `sqlx` error if connecting or creating the table fails.
        pub async fn connect(database_url: &str) -> sqlx::Result<Self> {
            let pool = SqlitePoolOptions::new()
                .max_connections(5)
                .acquire_timeout(Duration::from_secs(5))
                .connect(database_url)
                .await?;
            sqlx::query(
                r#"
CREATE TABLE IF NOT EXISTS sessions (
    app_id TEXT PRIMARY KEY,
    payload TEXT NOT NULL
);
"#,
            )
            .execute(&pool)
            .await?;
            Ok(Self {
                pool,
                write_lock: Arc::new(Mutex::new(())),
            })
        }

        /// Loads and decodes the entry stored for `app_id`.
        ///
        /// Returns `None` when no row exists. It also returns `None` after logging a
        /// warning when the query fails or the payload cannot be decoded.
        async fn load_entry(&self, app_id: &AppId) -> Option<StoredEntry> {
            let row: Result<Option<(String,)>, sqlx::Error> =
                sqlx::query_as("SELECT payload FROM sessions WHERE app_id = ?")
                    .bind(&app_id.0)
                    .fetch_optional(&self.pool)
                    .await;
            let (payload,) = match row {
                Ok(Some(row)) => row,
                Ok(None) => return None,
                Err(err) => {
                    tracing::warn!(?err, "failed to load session");
                    return None;
                }
            };
            match serde_json::from_str::<StoredEntry>(&payload) {
                Ok(entry) => Some(entry),
                Err(err) => {
                    tracing::warn!(?err, "failed to deserialize session");
                    None
                }
            }
        }

        /// Serializes `entry` and upserts it under its app id.
        ///
        /// This is best-effort: failures are logged, not returned.
        async fn save_entry(&self, entry: &StoredEntry) {
            let payload = match serde_json::to_string(entry) {
                Ok(payload) => payload,
                Err(err) => {
                    tracing::warn!(?err, "failed to serialize session");
                    return;
                }
            };
            let result =
                sqlx::query("INSERT OR REPLACE INTO sessions (app_id, payload) VALUES (?, ?)")
                    .bind(&entry.context.app_id.0)
                    .bind(payload)
                    .execute(&self.pool)
                    .await;
            if let Err(err) = result {
                tracing::warn!(?err, "failed to persist session");
            }
        }
    }

    #[async_trait]
    impl SessionStore for SqliteSessionStore {
        async fn get_session(&self, app_id: &AppId) -> Option<BotContext> {
            self.load_entry(app_id).await.map(|entry| entry.context)
        }

        /// Replaces the stored session; the new entry starts with an empty history.
        async fn put_session(&self, context: BotContext) {
            let entry = StoredEntry {
                context,
                seen: VecDeque::new(),
            };
            // Taken so a concurrent `mark_message_seen` cannot overwrite this entry
            // with the context it loaded earlier.
            let _guard = self.write_lock.lock().await;
            self.save_entry(&entry).await;
        }

        /// Returns `false` if `new_msg_id` was already recorded. Otherwise it records
        /// the id and returns `true`, also when no session exists or the write fails.
        async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool {
            // Held across load + save so the check-then-record cycle is atomic within
            // this process.
            let _guard = self.write_lock.lock().await;
            let mut entry = match self.load_entry(app_id).await {
                Some(entry) => entry,
                None => return true,
            };
            if entry.seen.contains(&new_msg_id) {
                return false;
            }
            entry.seen.push_back(new_msg_id);
            if entry.seen.len() > MAX_SEEN {
                entry.seen.pop_front();
            }
            self.save_entry(&entry).await;
            true
        }
    }
}
457
#[cfg(feature = "redis-store")]
pub mod redis_store {
    //! Redis-backed [`SessionStore`](super::SessionStore) implementation.

    use super::{AppId, BotContext, SessionStore, StoredEntry};
    use async_trait::async_trait;
    use redis::{aio::MultiplexedConnection, AsyncCommands, Client};
    use std::collections::VecDeque;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Number of most recent message ids remembered per session.
    const MAX_SEEN: usize = 1024;

    /// [`SessionStore`] that keeps each session as a JSON-encoded [`StoredEntry`]
    /// under the Redis key `<prefix>:<app_id>`.
    #[derive(Clone)]
    pub struct RedisSessionStore {
        client: Client,
        prefix: String,
        /// Serializes this store's read-modify-write cycles and is shared by clones.
        /// Without it, concurrent `mark_message_seen` calls could both report the same id
        /// as new and lose each other's ids, and a racing `put_session` could be
        /// overwritten with a stale context.
        /// NOTE: only tasks in this process are coordinated; other processes using the
        /// same keys are not (that would need WATCH/MULTI or a Lua script).
        write_lock: Arc<Mutex<()>>,
    }

    impl RedisSessionStore {
        /// Creates a store for the Redis server at `url`, namespacing keys with `prefix`.
        ///
        /// # Errors
        /// Returns the error from `Client::open` if `url` is invalid. No connection is
        /// made here; connections are opened per operation.
        pub fn new(url: &str, prefix: impl Into<String>) -> redis::RedisResult<Self> {
            Ok(Self {
                client: Client::open(url)?,
                prefix: prefix.into(),
                write_lock: Arc::new(Mutex::new(())),
            })
        }

        /// Redis key holding the session for `app_id`.
        fn key(&self, app_id: &AppId) -> String {
            format!("{}:{}", self.prefix, app_id.0)
        }

        /// Opens a multiplexed connection, or `None` if that fails.
        async fn connection(&self) -> Option<MultiplexedConnection> {
            self.client.get_multiplexed_async_connection().await.ok()
        }

        /// Reads and decodes the entry at `key`.
        ///
        /// Returns `None` if the key is missing, the GET fails, or the payload does not
        /// decode.
        async fn read_entry(conn: &mut MultiplexedConnection, key: &str) -> Option<StoredEntry> {
            let payload: Option<String> = conn.get(key).await.ok()?;
            payload.and_then(|p| serde_json::from_str::<StoredEntry>(&p).ok())
        }

        /// Serializes `entry` and SETs it at `key`.
        ///
        /// This is best-effort: errors are ignored.
        async fn write_entry(conn: &mut MultiplexedConnection, key: &str, entry: &StoredEntry) {
            if let Ok(payload) = serde_json::to_string(entry) {
                let _: redis::RedisResult<()> = conn.set(key, payload).await;
            }
        }

        /// Loads the entry for `app_id` over a fresh connection.
        async fn load_entry(&self, app_id: &AppId) -> Option<StoredEntry> {
            let key = self.key(app_id);
            let mut conn = self.connection().await?;
            Self::read_entry(&mut conn, &key).await
        }
    }

    #[async_trait]
    impl SessionStore for RedisSessionStore {
        async fn get_session(&self, app_id: &AppId) -> Option<BotContext> {
            self.load_entry(app_id).await.map(|entry| entry.context)
        }

        /// Replaces the stored session; the new entry starts with an empty history.
        async fn put_session(&self, context: BotContext) {
            let key = self.key(&context.app_id);
            let entry = StoredEntry {
                context,
                seen: VecDeque::new(),
            };
            // Taken so a concurrent `mark_message_seen` cannot overwrite this entry
            // with the context it loaded earlier.
            let _guard = self.write_lock.lock().await;
            if let Some(mut conn) = self.connection().await {
                Self::write_entry(&mut conn, &key, &entry).await;
            }
        }

        /// Returns `false` if `new_msg_id` was already recorded. Otherwise it records
        /// the id and returns `true`, also when no session exists or Redis is unreachable.
        async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool {
            // Held across GET + SET so the check-then-record cycle is atomic within
            // this process.
            let _guard = self.write_lock.lock().await;
            let key = self.key(app_id);
            // One connection serves both the read and the write-back.
            let mut conn = match self.connection().await {
                Some(conn) => conn,
                None => return true,
            };
            let mut entry = match Self::read_entry(&mut conn, &key).await {
                Some(entry) => entry,
                None => return true,
            };
            if entry.seen.contains(&new_msg_id) {
                return false;
            }
            entry.seen.push_back(new_msg_id);
            if entry.seen.len() > MAX_SEEN {
                entry.seen.pop_front();
            }
            Self::write_entry(&mut conn, &key, &entry).await;
            true
        }
    }
}