Skip to main content

greentic_session/
inmemory.rs

1use crate::ReplyScope;
2use crate::error::SessionResult;
3use crate::error::{GreenticError, invalid_argument, not_found};
4use crate::store::SessionStore;
5use greentic_types::{EnvId, SessionData, SessionKey, TeamId, TenantCtx, TenantId, UserId};
6use parking_lot::RwLock;
7use std::collections::{HashMap, HashSet};
8use std::time::{Duration, Instant};
9use uuid::Uuid;
10
11/// Simple in-memory implementation backed by hash maps.
12pub struct InMemorySessionStore {
13    sessions: RwLock<HashMap<SessionKey, SessionEntry>>,
14    user_waits: RwLock<HashMap<UserLookupKey, HashSet<SessionKey>>>,
15    scope_index: RwLock<HashMap<ScopeLookupKey, ScopeEntry>>,
16}
17
18impl Default for InMemorySessionStore {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl InMemorySessionStore {
25    /// Constructs an empty store.
26    pub fn new() -> Self {
27        Self {
28            sessions: RwLock::new(HashMap::new()),
29            user_waits: RwLock::new(HashMap::new()),
30            scope_index: RwLock::new(HashMap::new()),
31        }
32    }
33
34    fn next_key() -> SessionKey {
35        SessionKey::new(Uuid::new_v4().to_string())
36    }
37
38    fn normalize_team(ctx: &TenantCtx) -> Option<&TeamId> {
39        ctx.team_id.as_ref().or(ctx.team.as_ref())
40    }
41
42    fn normalize_user(ctx: &TenantCtx) -> Option<&UserId> {
43        ctx.user_id.as_ref().or(ctx.user.as_ref())
44    }
45
46    fn ttl_deadline(ttl: Option<Duration>) -> Option<Instant> {
47        ttl.map(|value| Instant::now() + value)
48    }
49
50    fn is_expired(deadline: Option<Instant>) -> bool {
51        deadline
52            .map(|value| Instant::now() >= value)
53            .unwrap_or(false)
54    }
55
56    fn ctx_mismatch(expected: &TenantCtx, provided: &TenantCtx, reason: &str) -> GreenticError {
57        let expected_team = Self::normalize_team(expected)
58            .map(|t| t.as_str())
59            .unwrap_or("-");
60        let provided_team = Self::normalize_team(provided)
61            .map(|t| t.as_str())
62            .unwrap_or("-");
63        let expected_user_presence = if Self::normalize_user(expected).is_some() {
64            "present"
65        } else {
66            "missing"
67        };
68        let provided_user_presence = if Self::normalize_user(provided).is_some() {
69            "present"
70        } else {
71            "missing"
72        };
73        invalid_argument(format!(
74            "tenant context mismatch ({reason}): expected env={}, tenant={}, team={}, user={}, got env={}, tenant={}, team={}, user={}",
75            expected.env.as_str(),
76            expected.tenant_id.as_str(),
77            expected_team,
78            expected_user_presence,
79            provided.env.as_str(),
80            provided.tenant_id.as_str(),
81            provided_team,
82            provided_user_presence
83        ))
84    }
85
86    fn ensure_alignment(ctx: &TenantCtx, data: &SessionData) -> SessionResult<()> {
87        let stored = &data.tenant_ctx;
88        if ctx.env != stored.env || ctx.tenant_id != stored.tenant_id {
89            return Err(Self::ctx_mismatch(stored, ctx, "env/tenant must match"));
90        }
91        if Self::normalize_team(ctx) != Self::normalize_team(stored) {
92            return Err(Self::ctx_mismatch(stored, ctx, "team must match"));
93        }
94        if let Some(stored_user) = Self::normalize_user(stored) {
95            let Some(provided_user) = Self::normalize_user(ctx) else {
96                return Err(Self::ctx_mismatch(
97                    stored,
98                    ctx,
99                    "user required by session but missing in caller context",
100                ));
101            };
102            if stored_user != provided_user {
103                return Err(Self::ctx_mismatch(
104                    stored,
105                    ctx,
106                    "user must match stored session",
107                ));
108            }
109        }
110        Ok(())
111    }
112
113    fn ensure_ctx_preserved(existing: &TenantCtx, candidate: &TenantCtx) -> SessionResult<()> {
114        if existing.env != candidate.env || existing.tenant_id != candidate.tenant_id {
115            return Err(Self::ctx_mismatch(
116                existing,
117                candidate,
118                "env/tenant cannot change for an existing session",
119            ));
120        }
121        if Self::normalize_team(existing) != Self::normalize_team(candidate) {
122            return Err(Self::ctx_mismatch(
123                existing,
124                candidate,
125                "team cannot change for an existing session",
126            ));
127        }
128        match (
129            Self::normalize_user(existing),
130            Self::normalize_user(candidate),
131        ) {
132            (Some(a), Some(b)) if a == b => {}
133            (Some(_), Some(_)) | (Some(_), None) => {
134                return Err(Self::ctx_mismatch(
135                    existing,
136                    candidate,
137                    "user cannot change for an existing session",
138                ));
139            }
140            (None, Some(_)) => {
141                return Err(Self::ctx_mismatch(
142                    existing,
143                    candidate,
144                    "user cannot be introduced when none was stored",
145                ));
146            }
147            (None, None) => {}
148        }
149        Ok(())
150    }
151
152    fn ensure_user_matches(
153        ctx: &TenantCtx,
154        user: &UserId,
155        data: &SessionData,
156    ) -> SessionResult<()> {
157        if let Some(ctx_user) = Self::normalize_user(ctx)
158            && ctx_user != user
159        {
160            return Err(invalid_argument(
161                "user must match tenant context when registering a wait",
162            ));
163        }
164        if let Some(stored_user) = Self::normalize_user(&data.tenant_ctx) {
165            if stored_user != user {
166                return Err(invalid_argument(
167                    "user must match session data when registering a wait",
168                ));
169            }
170        } else {
171            return Err(invalid_argument(
172                "user required by wait but missing in session data",
173            ));
174        }
175        Ok(())
176    }
177
178    fn user_lookup_key(ctx: &TenantCtx, user: &UserId) -> UserLookupKey {
179        UserLookupKey::from_ctx(ctx, user)
180    }
181
182    fn scope_lookup_key(ctx: &TenantCtx, user: &UserId, scope: &ReplyScope) -> ScopeLookupKey {
183        ScopeLookupKey::from_ctx(ctx, user, scope)
184    }
185
186    fn remove_from_user_waits(&self, lookup: &UserLookupKey, key: &SessionKey) {
187        let mut waits = self.user_waits.write();
188        if let Some(entries) = waits.get_mut(lookup) {
189            entries.remove(key);
190            if entries.is_empty() {
191                waits.remove(lookup);
192            }
193        }
194    }
195
196    fn remove_scope_entry(&self, scope_key: &ScopeLookupKey) -> Option<ScopeEntry> {
197        self.scope_index.write().remove(scope_key)
198    }
199
200    fn purge_expired_session(&self, key: &SessionKey, entry: SessionEntry) {
201        if let Some(user_lookup) = &entry.wait_user {
202            self.remove_from_user_waits(user_lookup, key);
203        }
204        if let Some(scope_key) = &entry.scope_key {
205            self.remove_scope_entry(scope_key);
206        }
207    }
208}
209
210impl SessionStore for InMemorySessionStore {
211    fn create_session(&self, ctx: &TenantCtx, data: SessionData) -> SessionResult<SessionKey> {
212        Self::ensure_alignment(ctx, &data)?;
213        let key = Self::next_key();
214        let entry = SessionEntry {
215            data: data.clone(),
216            expires_at: None,
217            wait_user: None,
218            scope_key: None,
219        };
220        self.sessions.write().insert(key.clone(), entry);
221        Ok(key)
222    }
223
224    fn get_session(&self, key: &SessionKey) -> SessionResult<Option<SessionData>> {
225        let mut sessions = self.sessions.write();
226        let Some(entry) = sessions.get(key).cloned() else {
227            return Ok(None);
228        };
229        if Self::is_expired(entry.expires_at) {
230            sessions.remove(key);
231            drop(sessions);
232            self.purge_expired_session(key, entry);
233            return Ok(None);
234        }
235        Ok(Some(entry.data))
236    }
237
238    fn update_session(&self, key: &SessionKey, data: SessionData) -> SessionResult<()> {
239        let mut sessions = self.sessions.write();
240        let Some(previous) = sessions.get(key).cloned() else {
241            return Err(not_found(key));
242        };
243        Self::ensure_ctx_preserved(&previous.data.tenant_ctx, &data.tenant_ctx)?;
244        let entry = SessionEntry {
245            data: data.clone(),
246            expires_at: previous.expires_at,
247            wait_user: previous.wait_user.clone(),
248            scope_key: previous.scope_key.clone(),
249        };
250        sessions.insert(key.clone(), entry);
251        Ok(())
252    }
253
254    fn remove_session(&self, key: &SessionKey) -> SessionResult<()> {
255        if let Some(old) = self.sessions.write().remove(key) {
256            self.purge_expired_session(key, old);
257            Ok(())
258        } else {
259            Err(not_found(key))
260        }
261    }
262
263    fn register_wait(
264        &self,
265        ctx: &TenantCtx,
266        user_id: &UserId,
267        scope: &ReplyScope,
268        session_key: &SessionKey,
269        data: SessionData,
270        ttl: Option<Duration>,
271    ) -> SessionResult<()> {
272        Self::ensure_alignment(ctx, &data)?;
273        Self::ensure_user_matches(ctx, user_id, &data)?;
274        let user_lookup = Self::user_lookup_key(ctx, user_id);
275        let scope_key = Self::scope_lookup_key(ctx, user_id, scope);
276        let expires_at = Self::ttl_deadline(ttl);
277
278        let existing = self.sessions.read().get(session_key).cloned();
279        if let Some(existing) = &existing {
280            Self::ensure_ctx_preserved(&existing.data.tenant_ctx, &data.tenant_ctx)?;
281            if let Some(existing_user) = &existing.wait_user {
282                self.remove_from_user_waits(existing_user, session_key);
283            }
284            if let Some(existing_scope) = &existing.scope_key {
285                self.remove_scope_entry(existing_scope);
286            }
287        }
288        let entry = SessionEntry {
289            data,
290            expires_at,
291            wait_user: Some(user_lookup.clone()),
292            scope_key: Some(scope_key.clone()),
293        };
294        self.sessions.write().insert(session_key.clone(), entry);
295
296        let mut waits = self.user_waits.write();
297        waits
298            .entry(user_lookup.clone())
299            .or_default()
300            .insert(session_key.clone());
301        drop(waits);
302
303        let mut scopes = self.scope_index.write();
304        if let Some(existing) = scopes.get(&scope_key)
305            && existing.session_key != *session_key
306        {
307            self.remove_from_user_waits(&user_lookup, &existing.session_key);
308        }
309        scopes.insert(
310            scope_key,
311            ScopeEntry {
312                session_key: session_key.clone(),
313                expires_at,
314            },
315        );
316        Ok(())
317    }
318
319    fn find_wait_by_scope(
320        &self,
321        ctx: &TenantCtx,
322        user_id: &UserId,
323        scope: &ReplyScope,
324    ) -> SessionResult<Option<SessionKey>> {
325        let scope_key = Self::scope_lookup_key(ctx, user_id, scope);
326        let entry = self.scope_index.read().get(&scope_key).cloned();
327        let Some(entry) = entry else {
328            return Ok(None);
329        };
330        if Self::is_expired(entry.expires_at) {
331            self.remove_scope_entry(&scope_key);
332            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
333            let removed = self.sessions.write().remove(&entry.session_key);
334            if let Some(session_entry) = removed {
335                self.purge_expired_session(&entry.session_key, session_entry);
336            }
337            return Ok(None);
338        }
339        let Some(session) = self.get_session(&entry.session_key)? else {
340            self.remove_scope_entry(&scope_key);
341            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
342            return Ok(None);
343        };
344        let stored_ctx = &session.tenant_ctx;
345        if stored_ctx.env != ctx.env
346            || stored_ctx.tenant_id != ctx.tenant_id
347            || Self::normalize_team(stored_ctx) != Self::normalize_team(ctx)
348        {
349            self.remove_scope_entry(&scope_key);
350            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
351            return Ok(None);
352        }
353        if let Some(stored_user) = Self::normalize_user(stored_ctx)
354            && stored_user != user_id
355        {
356            self.remove_scope_entry(&scope_key);
357            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
358            return Ok(None);
359        }
360        Ok(Some(entry.session_key))
361    }
362
363    fn list_waits_for_user(
364        &self,
365        ctx: &TenantCtx,
366        user_id: &UserId,
367    ) -> SessionResult<Vec<SessionKey>> {
368        let lookup = UserLookupKey::from_ctx(ctx, user_id);
369        let keys: Vec<SessionKey> = self
370            .user_waits
371            .read()
372            .get(&lookup)
373            .map(|set| set.iter().cloned().collect())
374            .unwrap_or_default();
375        let mut available = Vec::new();
376        for key in keys {
377            let Some(data) = self.get_session(&key)? else {
378                self.remove_from_user_waits(&lookup, &key);
379                continue;
380            };
381            let stored_ctx = &data.tenant_ctx;
382            if stored_ctx.env != ctx.env
383                || stored_ctx.tenant_id != ctx.tenant_id
384                || Self::normalize_team(stored_ctx) != Self::normalize_team(ctx)
385            {
386                self.remove_from_user_waits(&lookup, &key);
387                continue;
388            }
389            if let Some(stored_user) = Self::normalize_user(stored_ctx)
390                && stored_user != user_id
391            {
392                self.remove_from_user_waits(&lookup, &key);
393                continue;
394            }
395            available.push(key);
396        }
397        Ok(available)
398    }
399
400    fn clear_wait(
401        &self,
402        ctx: &TenantCtx,
403        user_id: &UserId,
404        scope: &ReplyScope,
405    ) -> SessionResult<()> {
406        let scope_key = Self::scope_lookup_key(ctx, user_id, scope);
407        let entry = self.scope_index.write().remove(&scope_key);
408        if let Some(entry) = entry {
409            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
410            self.sessions.write().remove(&entry.session_key);
411        }
412        Ok(())
413    }
414
415    fn find_by_user(
416        &self,
417        ctx: &TenantCtx,
418        user: &UserId,
419    ) -> SessionResult<Option<(SessionKey, SessionData)>> {
420        let waits = self.list_waits_for_user(ctx, user)?;
421        match waits.len() {
422            0 => Ok(None),
423            1 => {
424                let key = waits.into_iter().next().expect("single wait entry");
425                let data = self.get_session(&key)?.ok_or_else(|| not_found(&key))?;
426                Ok(Some((key, data)))
427            }
428            _ => Err(invalid_argument(
429                "multiple waits exist for user; use scope-based routing instead",
430            )),
431        }
432    }
433}
434
435#[derive(Clone)]
436struct SessionEntry {
437    data: SessionData,
438    expires_at: Option<Instant>,
439    wait_user: Option<UserLookupKey>,
440    scope_key: Option<ScopeLookupKey>,
441}
442
443#[derive(Clone)]
444struct ScopeEntry {
445    session_key: SessionKey,
446    expires_at: Option<Instant>,
447}
448
449#[derive(Clone, Eq, PartialEq, Hash)]
450struct ScopeLookupKey {
451    env: EnvId,
452    tenant: TenantId,
453    team: Option<TeamId>,
454    user: UserId,
455    scope_hash: String,
456}
457
458impl ScopeLookupKey {
459    fn from_ctx(ctx: &TenantCtx, user: &UserId, scope: &ReplyScope) -> Self {
460        Self {
461            env: ctx.env.clone(),
462            tenant: ctx.tenant_id.clone(),
463            team: ctx.team_id.clone().or_else(|| ctx.team.clone()),
464            user: user.clone(),
465            scope_hash: scope.scope_hash(),
466        }
467    }
468}
469
470#[derive(Clone, Eq, PartialEq, Hash)]
471struct UserLookupKey {
472    env: EnvId,
473    tenant: TenantId,
474    team: Option<TeamId>,
475    user: UserId,
476}
477
478impl UserLookupKey {
479    fn from_ctx(ctx: &TenantCtx, user: &UserId) -> Self {
480        Self {
481            env: ctx.env.clone(),
482            tenant: ctx.tenant_id.clone(),
483            team: ctx.team_id.clone().or_else(|| ctx.team.clone()),
484            user: user.clone(),
485        }
486    }
487}