gewe_session/
lib.rs

1use async_trait::async_trait;
2use gewe_core::{AppId, BotContext};
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, VecDeque};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8pub type BotRegistry = Arc<RwLock<HashMap<AppId, BotContext>>>;
9
10#[async_trait]
11pub trait SessionStore: Send + Sync {
12    async fn get_session(&self, app_id: &AppId) -> Option<BotContext>;
13    async fn put_session(&self, context: BotContext);
14    /// returns true if this message id is first seen
15    async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool;
16}
17
18#[derive(Clone, Default)]
19pub struct InMemorySessionStore {
20    inner: Arc<RwLock<HashMap<AppId, StoredEntry>>>,
21}
22
/// Everything a store keeps for one app: its session context plus a bounded
/// history of recently processed message ids used for de-duplication.
///
/// The SQLite and Redis backends persist this as JSON, so the field names
/// (and their order in the emitted JSON) are part of the stored format.
#[derive(Clone, Serialize, Deserialize)]
struct StoredEntry {
    /// Session context for the app.
    context: BotContext,
    /// Recently seen message ids, oldest at the front (appended with
    /// `push_back`, evicted with `pop_front`). Deserializes to an empty queue
    /// when the field is absent from a stored payload.
    #[serde(default)]
    seen: VecDeque<i64>,
}
29
30#[async_trait]
31impl SessionStore for InMemorySessionStore {
32    async fn get_session(&self, app_id: &AppId) -> Option<BotContext> {
33        let map: tokio::sync::RwLockReadGuard<'_, HashMap<AppId, StoredEntry>> =
34            self.inner.read().await;
35        map.get(app_id).map(|entry| entry.context.clone())
36    }
37
38    async fn put_session(&self, context: BotContext) {
39        let mut map: tokio::sync::RwLockWriteGuard<'_, HashMap<AppId, StoredEntry>> =
40            self.inner.write().await;
41        map.insert(
42            context.app_id.clone(),
43            StoredEntry {
44                context,
45                seen: VecDeque::new(),
46            },
47        );
48    }
49
50    async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool {
51        let mut map: tokio::sync::RwLockWriteGuard<'_, HashMap<AppId, StoredEntry>> =
52            self.inner.write().await;
53        let entry = match map.get_mut(app_id) {
54            Some(entry) => entry,
55            None => return true,
56        };
57        if entry.seen.contains(&new_msg_id) {
58            return false;
59        }
60
61        entry.seen.push_back(new_msg_id);
62        // 防止无限增长,简单裁剪
63        const MAX_SEEN: usize = 1024;
64        if entry.seen.len() > MAX_SEEN {
65            entry.seen.pop_front();
66        }
67        true
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal context whose token is derived from `app_id`.
    fn create_test_context(app_id: &str) -> BotContext {
        BotContext {
            app_id: AppId(app_id.to_string()),
            token: format!("token_{}", app_id),
            webhook_secret: None,
            description: None,
        }
    }

    #[tokio::test]
    async fn test_in_memory_store_default() {
        let store = InMemorySessionStore::default();
        let app_id = AppId("test_app".to_string());
        assert!(store.get_session(&app_id).await.is_none());
    }

    #[tokio::test]
    async fn test_in_memory_store_put_and_get_session() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");

        store.put_session(ctx.clone()).await;

        let retrieved = store
            .get_session(&ctx.app_id)
            .await
            .expect("session should exist after put_session");
        assert_eq!(retrieved.app_id.0, "app123");
        assert_eq!(retrieved.token, "token_app123");
    }

    #[tokio::test]
    async fn test_in_memory_store_overwrite_session() {
        let store = InMemorySessionStore::default();
        let ctx1 = BotContext {
            app_id: AppId("app123".to_string()),
            token: "token1".to_string(),
            webhook_secret: None,
            description: None,
        };
        let ctx2 = BotContext {
            app_id: AppId("app123".to_string()),
            token: "token2".to_string(),
            webhook_secret: Some("secret".to_string()),
            description: Some("updated".to_string()),
        };

        store.put_session(ctx1).await;
        store.put_session(ctx2).await;

        let retrieved = store
            .get_session(&AppId("app123".to_string()))
            .await
            .expect("session should exist after overwrite");
        assert_eq!(retrieved.token, "token2");
        assert_eq!(retrieved.webhook_secret, Some("secret".to_string()));
        assert_eq!(retrieved.description, Some("updated".to_string()));
    }

    #[tokio::test]
    async fn test_in_memory_store_multiple_sessions() {
        let store = InMemorySessionStore::default();
        for id in ["app1", "app2", "app3"] {
            store.put_session(create_test_context(id)).await;
        }

        for id in ["app1", "app2", "app3"] {
            assert!(
                store.get_session(&AppId(id.to_string())).await.is_some(),
                "{} should be stored",
                id
            );
        }
        assert!(store
            .get_session(&AppId("app4".to_string()))
            .await
            .is_none());
    }

    #[tokio::test]
    async fn test_mark_message_seen_no_session() {
        let store = InMemorySessionStore::default();
        let app_id = AppId("nonexistent".to_string());

        // Unknown apps have no dedup window, so every message counts as new.
        assert!(store.mark_message_seen(&app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_first_time() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        assert!(store.mark_message_seen(&ctx.app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_duplicate() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        assert!(store.mark_message_seen(&ctx.app_id, 12345).await);
        // Second sighting of the same id is a duplicate.
        assert!(!store.mark_message_seen(&ctx.app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_different_messages() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        for id in 1..=3 {
            assert!(store.mark_message_seen(&ctx.app_id, id).await);
        }
        for id in 1..=3 {
            assert!(!store.mark_message_seen(&ctx.app_id, id).await);
        }
    }

    #[tokio::test]
    async fn test_mark_message_seen_window_holds_exactly_max() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        // Exactly MAX_SEEN (1024) ids fit without evicting anything.
        for i in 0..1024 {
            assert!(store.mark_message_seen(&ctx.app_id, i).await);
        }
        assert!(!store.mark_message_seen(&ctx.app_id, 0).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_max_capacity() {
        let store = InMemorySessionStore::default();
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        // 1030 ids against a 1024-entry window: ids 0..=5 get evicted.
        for i in 0..1030 {
            assert!(store.mark_message_seen(&ctx.app_id, i).await);
        }

        // Evicted, so it looks new again.
        assert!(store.mark_message_seen(&ctx.app_id, 0).await);
        // Most recent id is still remembered.
        assert!(!store.mark_message_seen(&ctx.app_id, 1029).await);
    }

    #[tokio::test]
    async fn test_mark_message_seen_different_apps() {
        let store = InMemorySessionStore::default();
        let ctx1 = create_test_context("app1");
        let ctx2 = create_test_context("app2");
        store.put_session(ctx1.clone()).await;
        store.put_session(ctx2.clone()).await;

        // Dedup windows are per app: the same id is new for each of them.
        assert!(store.mark_message_seen(&ctx1.app_id, 12345).await);
        assert!(store.mark_message_seen(&ctx2.app_id, 12345).await);

        // ...but repeats within one app are caught.
        assert!(!store.mark_message_seen(&ctx1.app_id, 12345).await);
        assert!(!store.mark_message_seen(&ctx2.app_id, 12345).await);
    }

    #[tokio::test]
    async fn test_in_memory_store_clone() {
        let store1 = InMemorySessionStore::default();
        let store2 = store1.clone();
        let ctx = create_test_context("app123");

        // Write through one handle *after* cloning; the other handle can only
        // observe it if both share the same underlying map.
        store1.put_session(ctx.clone()).await;

        let result1 = store1
            .get_session(&ctx.app_id)
            .await
            .expect("visible via original handle");
        let result2 = store2
            .get_session(&ctx.app_id)
            .await
            .expect("visible via cloned handle");
        assert_eq!(result1.token, result2.token);

        // Dedup state is shared as well.
        assert!(store1.mark_message_seen(&ctx.app_id, 7).await);
        assert!(!store2.mark_message_seen(&ctx.app_id, 7).await);
    }

    #[test]
    fn test_stored_entry_serialize_deserialize() {
        let entry = StoredEntry {
            context: create_test_context("app123"),
            seen: VecDeque::from([1, 2, 3]),
        };

        let json = serde_json::to_string(&entry).unwrap();
        let deserialized: StoredEntry = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.context.app_id.0, "app123");
        assert_eq!(deserialized.seen, VecDeque::from([1, 2, 3]));
    }

    #[test]
    fn test_stored_entry_default_seen() {
        // `seen` must default to empty when absent from the payload.
        let json = r#"{"context":{"appId":"app123","token":"token"}}"#;
        let entry: StoredEntry = serde_json::from_str(json).unwrap();
        assert!(entry.seen.is_empty());
    }

    #[test]
    fn test_bot_registry_type() {
        let registry: BotRegistry = Arc::new(RwLock::new(HashMap::new()));
        assert_eq!(Arc::strong_count(&registry), 1);
    }

    #[tokio::test]
    async fn test_bot_registry_operations() {
        let registry: BotRegistry = Arc::new(RwLock::new(HashMap::new()));
        let ctx = create_test_context("app123");

        registry
            .write()
            .await
            .insert(ctx.app_id.clone(), ctx.clone());

        let map = registry.read().await;
        let retrieved = map.get(&ctx.app_id).expect("context should be registered");
        assert_eq!(retrieved.token, "token_app123");
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let store = Arc::new(InMemorySessionStore::default());
        let ctx = create_test_context("app123");
        store.put_session(ctx.clone()).await;

        let store1 = Arc::clone(&store);
        let store2 = Arc::clone(&store);
        let app_id1 = ctx.app_id.clone();
        let app_id2 = ctx.app_id.clone();

        let handle1 = tokio::spawn(async move {
            for i in 0..100 {
                store1.mark_message_seen(&app_id1, i).await;
            }
        });
        let handle2 = tokio::spawn(async move {
            for i in 100..200 {
                store2.mark_message_seen(&app_id2, i).await;
            }
        });

        handle1.await.unwrap();
        handle2.await.unwrap();

        // Ids from both tasks must have been recorded.
        assert!(!store.mark_message_seen(&ctx.app_id, 50).await);
        assert!(!store.mark_message_seen(&ctx.app_id, 150).await);
    }
}
360
/// SQLite-backed [`SessionStore`] (enabled by the `sqlite` feature).
#[cfg(feature = "sqlite")]
pub mod sqlite_store {
    use super::{AppId, BotContext, SessionStore, StoredEntry};
    use async_trait::async_trait;
    use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
    use std::time::Duration;

    /// Maximum number of message ids remembered per app; the oldest are
    /// evicted first once the window is full.
    const MAX_SEEN: usize = 1024;

    /// Stores one JSON-encoded `StoredEntry` per app in the `sessions` table
    /// (`app_id` primary key, `payload` text).
    #[derive(Clone)]
    pub struct SqliteSessionStore {
        pool: SqlitePool,
    }

    impl SqliteSessionStore {
        /// Opens a pool to `database_url` and creates the `sessions` table if
        /// it is missing.
        ///
        /// The pool uses at most 5 connections and a 5-second acquire timeout.
        ///
        /// # Errors
        ///
        /// Returns the underlying `sqlx` error if connecting or creating the
        /// table fails.
        pub async fn connect(database_url: &str) -> sqlx::Result<Self> {
            let pool = SqlitePoolOptions::new()
                .max_connections(5)
                .acquire_timeout(Duration::from_secs(5))
                .connect(database_url)
                .await?;
            sqlx::query(
                r#"
CREATE TABLE IF NOT EXISTS sessions (
    app_id TEXT PRIMARY KEY,
    payload TEXT NOT NULL
);
"#,
            )
            .execute(&pool)
            .await?;
            Ok(Self { pool })
        }

        /// Fetches and decodes the stored entry for `app_id`.
        ///
        /// Returns `None` in three cases: there is no row, the query fails, or
        /// the payload cannot be decoded. The last two are logged.
        async fn load_entry(&self, app_id: &AppId) -> Option<StoredEntry> {
            let row: Option<(String,)> =
                match sqlx::query_as("SELECT payload FROM sessions WHERE app_id = ?")
                    .bind(&app_id.0)
                    .fetch_optional(&self.pool)
                    .await
                {
                    Ok(row) => row,
                    Err(err) => {
                        tracing::warn!(?err, "failed to load session");
                        return None;
                    }
                };
            let (payload,) = row?;
            match serde_json::from_str::<StoredEntry>(&payload) {
                Ok(entry) => Some(entry),
                Err(err) => {
                    tracing::warn!(?err, "failed to decode stored session");
                    None
                }
            }
        }

        /// Serializes `entry` and upserts it under `entry.context.app_id`.
        ///
        /// Failures are logged rather than returned because the
        /// `SessionStore` methods have no error channel.
        async fn save_entry(&self, entry: &StoredEntry) {
            let payload = match serde_json::to_string(entry) {
                Ok(payload) => payload,
                Err(err) => {
                    tracing::warn!(?err, "failed to serialize session");
                    return;
                }
            };
            let result =
                sqlx::query("INSERT OR REPLACE INTO sessions (app_id, payload) VALUES (?, ?)")
                    .bind(&entry.context.app_id.0)
                    .bind(payload)
                    .execute(&self.pool)
                    .await;
            if let Err(err) = result {
                tracing::warn!(?err, "failed to persist session");
            }
        }
    }

    #[async_trait]
    impl SessionStore for SqliteSessionStore {
        async fn get_session(&self, app_id: &AppId) -> Option<BotContext> {
            self.load_entry(app_id).await.map(|entry| entry.context)
        }

        /// Upserts `context` and keeps the app's existing dedup window, if
        /// any, so re-registering a bot does not make handled messages look
        /// new again.
        async fn put_session(&self, context: BotContext) {
            let seen = self
                .load_entry(&context.app_id)
                .await
                .map(|entry| entry.seen)
                .unwrap_or_default();
            let entry = StoredEntry { context, seen };
            self.save_entry(&entry).await;
        }

        /// Returns `true` the first time `new_msg_id` is recorded for `app_id`
        /// (and for apps with no stored session), `false` for a repeat.
        ///
        /// NOTE(review): the load and save are not atomic. Concurrent
        /// deliveries of the same id can both report `true`, and concurrent
        /// updates for different ids can overwrite each other's window. Wrap
        /// both steps in a `BEGIN IMMEDIATE` transaction if stronger
        /// guarantees are needed.
        async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool {
            let mut entry = match self.load_entry(app_id).await {
                Some(entry) => entry,
                None => return true,
            };
            if entry.seen.contains(&new_msg_id) {
                return false;
            }
            entry.seen.push_back(new_msg_id);
            // `while` rather than `if`: a persisted window may already be over
            // the cap (e.g. if the limit was lowered since it was written).
            while entry.seen.len() > MAX_SEEN {
                entry.seen.pop_front();
            }
            self.save_entry(&entry).await;
            true
        }
    }
}
457
/// Redis-backed [`SessionStore`] (enabled by the `redis-store` feature).
#[cfg(feature = "redis-store")]
pub mod redis_store {
    use super::{AppId, BotContext, SessionStore, StoredEntry};
    use async_trait::async_trait;
    use redis::{AsyncCommands, Client};

    /// Maximum number of message ids remembered per app; the oldest are
    /// evicted first once the window is full.
    const MAX_SEEN: usize = 1024;

    /// Stores one JSON-encoded `StoredEntry` per app under the key
    /// `"{prefix}:{app_id}"`.
    #[derive(Clone)]
    pub struct RedisSessionStore {
        client: Client,
        prefix: String,
    }

    impl RedisSessionStore {
        /// Creates a store for the Redis server at `url`, namespacing keys with `prefix`.
        ///
        /// # Errors
        ///
        /// Returns the error from `Client::open` if `url` is not usable.
        pub fn new(url: &str, prefix: impl Into<String>) -> redis::RedisResult<Self> {
            Ok(Self {
                client: Client::open(url)?,
                prefix: prefix.into(),
            })
        }

        /// Redis key holding the entry for `app_id`.
        fn key(&self, app_id: &AppId) -> String {
            format!("{}:{}", self.prefix, app_id.0)
        }

        /// Fetches and decodes the entry for `app_id`.
        ///
        /// Returns `None` on a missing key, a connection or command error, or
        /// an undecodable payload.
        ///
        /// NOTE(review): every call opens a fresh multiplexed connection;
        /// consider caching one (or using a connection manager) if this is hot.
        async fn load_entry(&self, app_id: &AppId) -> Option<StoredEntry> {
            let mut conn = self.client.get_multiplexed_async_connection().await.ok()?;
            let payload: Option<String> = conn.get(self.key(app_id)).await.ok()?;
            payload.and_then(|p| serde_json::from_str::<StoredEntry>(&p).ok())
        }

        /// Serializes `entry` and `SET`s it under the key for `app_id`.
        ///
        /// Best-effort: serialization, connection, and command failures are
        /// ignored because the `SessionStore` methods have no error channel.
        async fn save_entry(&self, app_id: &AppId, entry: &StoredEntry) {
            if let Ok(payload) = serde_json::to_string(entry) {
                if let Ok(mut conn) = self.client.get_multiplexed_async_connection().await {
                    let _: redis::RedisResult<()> = conn.set(self.key(app_id), payload).await;
                }
            }
        }
    }

    #[async_trait]
    impl SessionStore for RedisSessionStore {
        async fn get_session(&self, app_id: &AppId) -> Option<BotContext> {
            self.load_entry(app_id).await.map(|entry| entry.context)
        }

        /// Stores `context` and keeps the app's existing dedup window, if any,
        /// so re-registering a bot does not make handled messages look new
        /// again.
        async fn put_session(&self, context: BotContext) {
            let seen = self
                .load_entry(&context.app_id)
                .await
                .map(|entry| entry.seen)
                .unwrap_or_default();
            let entry = StoredEntry { context, seen };
            self.save_entry(&entry.context.app_id, &entry).await;
        }

        /// Returns `true` the first time `new_msg_id` is recorded for `app_id`
        /// (and for apps with no stored session), `false` for a repeat.
        ///
        /// NOTE(review): GET-then-SET is not atomic. Concurrent deliveries of
        /// the same id can both report `true`, and concurrent updates can
        /// overwrite each other. A `WATCH`/`MULTI` transaction or a Lua script
        /// would close this gap.
        async fn mark_message_seen(&self, app_id: &AppId, new_msg_id: i64) -> bool {
            let mut entry = match self.load_entry(app_id).await {
                Some(entry) => entry,
                None => return true,
            };
            if entry.seen.contains(&new_msg_id) {
                return false;
            }
            entry.seen.push_back(new_msg_id);
            // `while` rather than `if`: a persisted window may already be over
            // the cap (e.g. if the limit was lowered since it was written).
            while entry.seen.len() > MAX_SEEN {
                entry.seen.pop_front();
            }
            self.save_entry(app_id, &entry).await;
            true
        }
    }
}