Skip to main content

greentic_session/
inmemory.rs

1use crate::ReplyScope;
2use crate::error::SessionResult;
3use crate::error::{GreenticError, invalid_argument, not_found};
4use crate::store::SessionStore;
5use greentic_types::{EnvId, SessionData, SessionKey, TeamId, TenantCtx, TenantId, UserId};
6use parking_lot::RwLock;
7use std::collections::{HashMap, HashSet};
8use std::time::{Duration, Instant};
9use uuid::Uuid;
10
11/// Simple in-memory implementation backed by hash maps.
12pub struct InMemorySessionStore {
13    sessions: RwLock<HashMap<SessionKey, SessionEntry>>,
14    user_waits: RwLock<HashMap<UserLookupKey, HashSet<SessionKey>>>,
15    scope_index: RwLock<HashMap<ScopeLookupKey, ScopeEntry>>,
16}
17
18impl Default for InMemorySessionStore {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl InMemorySessionStore {
25    /// Constructs an empty store.
26    pub fn new() -> Self {
27        Self {
28            sessions: RwLock::new(HashMap::new()),
29            user_waits: RwLock::new(HashMap::new()),
30            scope_index: RwLock::new(HashMap::new()),
31        }
32    }
33
34    fn next_key() -> SessionKey {
35        SessionKey::new(Uuid::new_v4().to_string())
36    }
37
38    fn normalize_team(ctx: &TenantCtx) -> Option<&TeamId> {
39        ctx.team_id.as_ref().or(ctx.team.as_ref())
40    }
41
42    fn normalize_user(ctx: &TenantCtx) -> Option<&UserId> {
43        ctx.user_id.as_ref().or(ctx.user.as_ref())
44    }
45
46    fn ttl_deadline(ttl: Option<Duration>) -> Option<Instant> {
47        ttl.map(|value| Instant::now() + value)
48    }
49
50    fn is_expired(deadline: Option<Instant>) -> bool {
51        deadline
52            .map(|value| Instant::now() >= value)
53            .unwrap_or(false)
54    }
55
56    fn ctx_mismatch(expected: &TenantCtx, provided: &TenantCtx, reason: &str) -> GreenticError {
57        let expected_team = Self::normalize_team(expected)
58            .map(|t| t.as_str())
59            .unwrap_or("-");
60        let provided_team = Self::normalize_team(provided)
61            .map(|t| t.as_str())
62            .unwrap_or("-");
63        let expected_user = Self::normalize_user(expected)
64            .map(|u| u.as_str())
65            .unwrap_or("-");
66        let provided_user = Self::normalize_user(provided)
67            .map(|u| u.as_str())
68            .unwrap_or("-");
69        invalid_argument(format!(
70            "tenant context mismatch ({reason}): expected env={}, tenant={}, team={}, user={}, got env={}, tenant={}, team={}, user={}",
71            expected.env.as_str(),
72            expected.tenant_id.as_str(),
73            expected_team,
74            expected_user,
75            provided.env.as_str(),
76            provided.tenant_id.as_str(),
77            provided_team,
78            provided_user
79        ))
80    }
81
82    fn ensure_alignment(ctx: &TenantCtx, data: &SessionData) -> SessionResult<()> {
83        let stored = &data.tenant_ctx;
84        if ctx.env != stored.env || ctx.tenant_id != stored.tenant_id {
85            return Err(Self::ctx_mismatch(stored, ctx, "env/tenant must match"));
86        }
87        if Self::normalize_team(ctx) != Self::normalize_team(stored) {
88            return Err(Self::ctx_mismatch(stored, ctx, "team must match"));
89        }
90        if let Some(stored_user) = Self::normalize_user(stored) {
91            let Some(provided_user) = Self::normalize_user(ctx) else {
92                return Err(Self::ctx_mismatch(
93                    stored,
94                    ctx,
95                    "user required by session but missing in caller context",
96                ));
97            };
98            if stored_user != provided_user {
99                return Err(Self::ctx_mismatch(
100                    stored,
101                    ctx,
102                    "user must match stored session",
103                ));
104            }
105        }
106        Ok(())
107    }
108
109    fn ensure_ctx_preserved(existing: &TenantCtx, candidate: &TenantCtx) -> SessionResult<()> {
110        if existing.env != candidate.env || existing.tenant_id != candidate.tenant_id {
111            return Err(Self::ctx_mismatch(
112                existing,
113                candidate,
114                "env/tenant cannot change for an existing session",
115            ));
116        }
117        if Self::normalize_team(existing) != Self::normalize_team(candidate) {
118            return Err(Self::ctx_mismatch(
119                existing,
120                candidate,
121                "team cannot change for an existing session",
122            ));
123        }
124        match (
125            Self::normalize_user(existing),
126            Self::normalize_user(candidate),
127        ) {
128            (Some(a), Some(b)) if a == b => {}
129            (Some(_), Some(_)) | (Some(_), None) => {
130                return Err(Self::ctx_mismatch(
131                    existing,
132                    candidate,
133                    "user cannot change for an existing session",
134                ));
135            }
136            (None, Some(_)) => {
137                return Err(Self::ctx_mismatch(
138                    existing,
139                    candidate,
140                    "user cannot be introduced when none was stored",
141                ));
142            }
143            (None, None) => {}
144        }
145        Ok(())
146    }
147
148    fn ensure_user_matches(
149        ctx: &TenantCtx,
150        user: &UserId,
151        data: &SessionData,
152    ) -> SessionResult<()> {
153        if let Some(ctx_user) = Self::normalize_user(ctx)
154            && ctx_user != user
155        {
156            return Err(invalid_argument(
157                "user must match tenant context when registering a wait",
158            ));
159        }
160        if let Some(stored_user) = Self::normalize_user(&data.tenant_ctx) {
161            if stored_user != user {
162                return Err(invalid_argument(
163                    "user must match session data when registering a wait",
164                ));
165            }
166        } else {
167            return Err(invalid_argument(
168                "user required by wait but missing in session data",
169            ));
170        }
171        Ok(())
172    }
173
174    fn user_lookup_key(ctx: &TenantCtx, user: &UserId) -> UserLookupKey {
175        UserLookupKey::from_ctx(ctx, user)
176    }
177
178    fn scope_lookup_key(ctx: &TenantCtx, user: &UserId, scope: &ReplyScope) -> ScopeLookupKey {
179        ScopeLookupKey::from_ctx(ctx, user, scope)
180    }
181
182    fn remove_from_user_waits(&self, lookup: &UserLookupKey, key: &SessionKey) {
183        let mut waits = self.user_waits.write();
184        if let Some(entries) = waits.get_mut(lookup) {
185            entries.remove(key);
186            if entries.is_empty() {
187                waits.remove(lookup);
188            }
189        }
190    }
191
192    fn remove_scope_entry(&self, scope_key: &ScopeLookupKey) -> Option<ScopeEntry> {
193        self.scope_index.write().remove(scope_key)
194    }
195
196    fn purge_expired_session(&self, key: &SessionKey, entry: SessionEntry) {
197        if let Some(user_lookup) = &entry.wait_user {
198            self.remove_from_user_waits(user_lookup, key);
199        }
200        if let Some(scope_key) = &entry.scope_key {
201            self.remove_scope_entry(scope_key);
202        }
203    }
204}
205
206impl SessionStore for InMemorySessionStore {
207    fn create_session(&self, ctx: &TenantCtx, data: SessionData) -> SessionResult<SessionKey> {
208        Self::ensure_alignment(ctx, &data)?;
209        let key = Self::next_key();
210        let entry = SessionEntry {
211            data: data.clone(),
212            expires_at: None,
213            wait_user: None,
214            scope_key: None,
215        };
216        self.sessions.write().insert(key.clone(), entry);
217        Ok(key)
218    }
219
220    fn get_session(&self, key: &SessionKey) -> SessionResult<Option<SessionData>> {
221        let mut sessions = self.sessions.write();
222        let Some(entry) = sessions.get(key).cloned() else {
223            return Ok(None);
224        };
225        if Self::is_expired(entry.expires_at) {
226            sessions.remove(key);
227            drop(sessions);
228            self.purge_expired_session(key, entry);
229            return Ok(None);
230        }
231        Ok(Some(entry.data))
232    }
233
234    fn update_session(&self, key: &SessionKey, data: SessionData) -> SessionResult<()> {
235        let mut sessions = self.sessions.write();
236        let Some(previous) = sessions.get(key).cloned() else {
237            return Err(not_found(key));
238        };
239        Self::ensure_ctx_preserved(&previous.data.tenant_ctx, &data.tenant_ctx)?;
240        let entry = SessionEntry {
241            data: data.clone(),
242            expires_at: previous.expires_at,
243            wait_user: previous.wait_user.clone(),
244            scope_key: previous.scope_key.clone(),
245        };
246        sessions.insert(key.clone(), entry);
247        Ok(())
248    }
249
250    fn remove_session(&self, key: &SessionKey) -> SessionResult<()> {
251        if let Some(old) = self.sessions.write().remove(key) {
252            self.purge_expired_session(key, old);
253            Ok(())
254        } else {
255            Err(not_found(key))
256        }
257    }
258
259    fn register_wait(
260        &self,
261        ctx: &TenantCtx,
262        user_id: &UserId,
263        scope: &ReplyScope,
264        session_key: &SessionKey,
265        data: SessionData,
266        ttl: Option<Duration>,
267    ) -> SessionResult<()> {
268        Self::ensure_alignment(ctx, &data)?;
269        Self::ensure_user_matches(ctx, user_id, &data)?;
270        let user_lookup = Self::user_lookup_key(ctx, user_id);
271        let scope_key = Self::scope_lookup_key(ctx, user_id, scope);
272        let expires_at = Self::ttl_deadline(ttl);
273
274        let existing = self.sessions.read().get(session_key).cloned();
275        if let Some(existing) = &existing {
276            Self::ensure_ctx_preserved(&existing.data.tenant_ctx, &data.tenant_ctx)?;
277            if let Some(existing_user) = &existing.wait_user {
278                self.remove_from_user_waits(existing_user, session_key);
279            }
280            if let Some(existing_scope) = &existing.scope_key {
281                self.remove_scope_entry(existing_scope);
282            }
283        }
284        let entry = SessionEntry {
285            data,
286            expires_at,
287            wait_user: Some(user_lookup.clone()),
288            scope_key: Some(scope_key.clone()),
289        };
290        self.sessions.write().insert(session_key.clone(), entry);
291
292        let mut waits = self.user_waits.write();
293        waits
294            .entry(user_lookup.clone())
295            .or_default()
296            .insert(session_key.clone());
297        drop(waits);
298
299        let mut scopes = self.scope_index.write();
300        if let Some(existing) = scopes.get(&scope_key)
301            && existing.session_key != *session_key
302        {
303            self.remove_from_user_waits(&user_lookup, &existing.session_key);
304        }
305        scopes.insert(
306            scope_key,
307            ScopeEntry {
308                session_key: session_key.clone(),
309                expires_at,
310            },
311        );
312        Ok(())
313    }
314
315    fn find_wait_by_scope(
316        &self,
317        ctx: &TenantCtx,
318        user_id: &UserId,
319        scope: &ReplyScope,
320    ) -> SessionResult<Option<SessionKey>> {
321        let scope_key = Self::scope_lookup_key(ctx, user_id, scope);
322        let entry = self.scope_index.read().get(&scope_key).cloned();
323        let Some(entry) = entry else {
324            return Ok(None);
325        };
326        if Self::is_expired(entry.expires_at) {
327            self.remove_scope_entry(&scope_key);
328            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
329            let removed = self.sessions.write().remove(&entry.session_key);
330            if let Some(session_entry) = removed {
331                self.purge_expired_session(&entry.session_key, session_entry);
332            }
333            return Ok(None);
334        }
335        let Some(session) = self.get_session(&entry.session_key)? else {
336            self.remove_scope_entry(&scope_key);
337            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
338            return Ok(None);
339        };
340        let stored_ctx = &session.tenant_ctx;
341        if stored_ctx.env != ctx.env
342            || stored_ctx.tenant_id != ctx.tenant_id
343            || Self::normalize_team(stored_ctx) != Self::normalize_team(ctx)
344        {
345            self.remove_scope_entry(&scope_key);
346            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
347            return Ok(None);
348        }
349        if let Some(stored_user) = Self::normalize_user(stored_ctx)
350            && stored_user != user_id
351        {
352            self.remove_scope_entry(&scope_key);
353            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
354            return Ok(None);
355        }
356        Ok(Some(entry.session_key))
357    }
358
359    fn list_waits_for_user(
360        &self,
361        ctx: &TenantCtx,
362        user_id: &UserId,
363    ) -> SessionResult<Vec<SessionKey>> {
364        let lookup = UserLookupKey::from_ctx(ctx, user_id);
365        let keys: Vec<SessionKey> = self
366            .user_waits
367            .read()
368            .get(&lookup)
369            .map(|set| set.iter().cloned().collect())
370            .unwrap_or_default();
371        let mut available = Vec::new();
372        for key in keys {
373            let Some(data) = self.get_session(&key)? else {
374                self.remove_from_user_waits(&lookup, &key);
375                continue;
376            };
377            let stored_ctx = &data.tenant_ctx;
378            if stored_ctx.env != ctx.env
379                || stored_ctx.tenant_id != ctx.tenant_id
380                || Self::normalize_team(stored_ctx) != Self::normalize_team(ctx)
381            {
382                self.remove_from_user_waits(&lookup, &key);
383                continue;
384            }
385            if let Some(stored_user) = Self::normalize_user(stored_ctx)
386                && stored_user != user_id
387            {
388                self.remove_from_user_waits(&lookup, &key);
389                continue;
390            }
391            available.push(key);
392        }
393        Ok(available)
394    }
395
396    fn clear_wait(
397        &self,
398        ctx: &TenantCtx,
399        user_id: &UserId,
400        scope: &ReplyScope,
401    ) -> SessionResult<()> {
402        let scope_key = Self::scope_lookup_key(ctx, user_id, scope);
403        let entry = self.scope_index.write().remove(&scope_key);
404        if let Some(entry) = entry {
405            self.remove_from_user_waits(&UserLookupKey::from_ctx(ctx, user_id), &entry.session_key);
406            self.sessions.write().remove(&entry.session_key);
407        }
408        Ok(())
409    }
410
411    fn find_by_user(
412        &self,
413        ctx: &TenantCtx,
414        user: &UserId,
415    ) -> SessionResult<Option<(SessionKey, SessionData)>> {
416        let waits = self.list_waits_for_user(ctx, user)?;
417        match waits.len() {
418            0 => Ok(None),
419            1 => {
420                let key = waits.into_iter().next().expect("single wait entry");
421                let data = self.get_session(&key)?.ok_or_else(|| not_found(&key))?;
422                Ok(Some((key, data)))
423            }
424            _ => Err(invalid_argument(
425                "multiple waits exist for user; use scope-based routing instead",
426            )),
427        }
428    }
429}
430
431#[derive(Clone)]
432struct SessionEntry {
433    data: SessionData,
434    expires_at: Option<Instant>,
435    wait_user: Option<UserLookupKey>,
436    scope_key: Option<ScopeLookupKey>,
437}
438
439#[derive(Clone)]
440struct ScopeEntry {
441    session_key: SessionKey,
442    expires_at: Option<Instant>,
443}
444
445#[derive(Clone, Eq, PartialEq, Hash)]
446struct ScopeLookupKey {
447    env: EnvId,
448    tenant: TenantId,
449    team: Option<TeamId>,
450    user: UserId,
451    scope_hash: String,
452}
453
454impl ScopeLookupKey {
455    fn from_ctx(ctx: &TenantCtx, user: &UserId, scope: &ReplyScope) -> Self {
456        Self {
457            env: ctx.env.clone(),
458            tenant: ctx.tenant_id.clone(),
459            team: ctx.team_id.clone().or_else(|| ctx.team.clone()),
460            user: user.clone(),
461            scope_hash: scope.scope_hash(),
462        }
463    }
464}
465
466#[derive(Clone, Eq, PartialEq, Hash)]
467struct UserLookupKey {
468    env: EnvId,
469    tenant: TenantId,
470    team: Option<TeamId>,
471    user: UserId,
472}
473
474impl UserLookupKey {
475    fn from_ctx(ctx: &TenantCtx, user: &UserId) -> Self {
476        Self {
477            env: ctx.env.clone(),
478            tenant: ctx.tenant_id.clone(),
479            team: ctx.team_id.clone().or_else(|| ctx.team.clone()),
480            user: user.clone(),
481        }
482    }
483}