// greentic_session/redis_store.rs

1use crate::error::{redis_error, serde_error};
2use crate::model::{Cas, Session, SessionKey};
3use crate::store::SessionStore;
4use greentic_types::GResult;
5use redis::{Commands, Connection, RedisResult, Script, Value};
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
/// Key prefix used when the store is built without an explicit namespace.
const DEFAULT_NAMESPACE: &str = "greentic:session";
/// Middle segment of the `<namespace>:lookup:<session key>` index entries that
/// map a session key to the tenant owning it.
const LOOKUP_SEGMENT: &str = "lookup";

/// Server-side compare-and-swap for a stored session envelope.
///
/// `KEYS[1]` is the data key; `ARGV` is `[expected cas, new payload, ttl secs, new cas]`.
/// The reply is a two-element array `{status, cas}`:
/// * `{0, 0}`           - nothing is stored under the key;
/// * `{1, current_cas}` - the stored `cas` field differs from the expected one;
/// * `{2, new_cas}`     - the payload was written (with `EX ttl` when ttl > 0).
///
/// A plain `SET` already discards any previous TTL, so the non-expiring branch
/// needs no extra `PERSIST`.
const UPDATE_LUA: &str = r#"
local key = KEYS[1]
local expected = tonumber(ARGV[1])
local payload = ARGV[2]
local ttl = tonumber(ARGV[3])
local new_cas = tonumber(ARGV[4])

local existing = redis.call("GET", key)
if not existing then
  return {0, 0}
end

local current = tonumber(cjson.decode(existing).cas or 0)
if current ~= expected then
  return {1, current or 0}
end

if ttl and ttl > 0 then
  redis.call("SET", key, payload, "EX", ttl)
else
  redis.call("SET", key, payload)
end
return {2, new_cas}
"#;

36#[derive(Clone)]
37pub struct RedisSessionStore {
38    client: redis::Client,
39    namespace: String,
40    update_script: Script,
41}
42
43impl RedisSessionStore {
44    /// Creates a store using the default namespace prefix.
45    pub fn new(client: redis::Client) -> Self {
46        Self::with_namespace(client, DEFAULT_NAMESPACE)
47    }
48
49    /// Creates a store with a custom namespace prefix.
50    pub fn with_namespace(client: redis::Client, namespace: impl Into<String>) -> Self {
51        Self {
52            client,
53            namespace: namespace.into(),
54            update_script: Script::new(UPDATE_LUA),
55        }
56    }
57
58    fn connection(&self) -> GResult<Connection> {
59        self.client.get_connection().map_err(redis_error)
60    }
61
62    fn data_key(&self, tenant_id: &str, key: &SessionKey) -> String {
63        format!("{}:{}:{}", self.namespace, tenant_id, key.as_str())
64    }
65
66    fn lookup_key(&self, key: &SessionKey) -> String {
67        format!("{}:{}:{}", self.namespace, LOOKUP_SEGMENT, key.as_str())
68    }
69
70    fn resolve_tenant(&self, conn: &mut Connection, key: &SessionKey) -> GResult<Option<String>> {
71        let lookup_key = self.lookup_key(key);
72        conn.get(&lookup_key).map_err(redis_error)
73    }
74
75    fn load_envelope(&self, conn: &mut Connection, key: &str) -> GResult<Option<SessionEnvelope>> {
76        let payload: Option<String> = conn.get(key).map_err(redis_error)?;
77        let envelope = payload
78            .map(|raw| serde_json::from_str(&raw))
79            .transpose()
80            .map_err(serde_error)?;
81        Ok(envelope)
82    }
83
84    fn serialize_envelope(envelope: &SessionEnvelope) -> GResult<String> {
85        serde_json::to_string(envelope).map_err(serde_error)
86    }
87
88    fn ttl_arg(session: &Session) -> i64 {
89        i64::from(session.ttl_secs)
90    }
91
92    fn set_payload(
93        conn: &mut Connection,
94        key: &str,
95        payload: &str,
96        ttl_secs: u32,
97    ) -> RedisResult<()> {
98        if ttl_secs > 0 {
99            let _: Value = redis::cmd("SET")
100                .arg(key)
101                .arg(payload)
102                .arg("EX")
103                .arg(ttl_secs)
104                .query(conn)?;
105        } else {
106            let _: Value = redis::cmd("SET").arg(key).arg(payload).query(conn)?;
107        }
108        Ok(())
109    }
110
111    fn sync_lookup(
112        &self,
113        conn: &mut Connection,
114        key: &SessionKey,
115        tenant_id: &str,
116        ttl_secs: u32,
117    ) -> RedisResult<()> {
118        let lookup_key = self.lookup_key(key);
119        if ttl_secs > 0 {
120            let _: Value = redis::cmd("SET")
121                .arg(&lookup_key)
122                .arg(tenant_id)
123                .arg("EX")
124                .arg(ttl_secs)
125                .query(conn)?;
126        } else {
127            let _: Value = redis::cmd("SET")
128                .arg(&lookup_key)
129                .arg(tenant_id)
130                .query(conn)?;
131        }
132        Ok(())
133    }
134
135    fn touch_lookup(
136        &self,
137        conn: &mut Connection,
138        key: &SessionKey,
139        ttl_secs: u32,
140    ) -> RedisResult<()> {
141        let lookup_key = self.lookup_key(key);
142        if ttl_secs > 0 {
143            let _: i64 = redis::cmd("EXPIRE")
144                .arg(&lookup_key)
145                .arg(ttl_secs)
146                .query(conn)?;
147        } else {
148            let _: i64 = redis::cmd("PERSIST").arg(&lookup_key).query(conn)?;
149        }
150        Ok(())
151    }
152
153    fn purge_lookup(&self, conn: &mut Connection, key: &SessionKey) {
154        let lookup_key = self.lookup_key(key);
155        let _ = redis::cmd("DEL").arg(&lookup_key).query::<i64>(conn);
156    }
157}
158
159impl SessionStore for RedisSessionStore {
160    fn get(&self, key: &SessionKey) -> GResult<Option<(Session, Cas)>> {
161        let mut conn = self.connection()?;
162        let Some(tenant_id) = self.resolve_tenant(&mut conn, key)? else {
163            return Ok(None);
164        };
165        let redis_key = self.data_key(&tenant_id, key);
166        if let Some(envelope) = self.load_envelope(&mut conn, &redis_key)? {
167            return Ok(Some((envelope.session, Cas::from(envelope.cas))));
168        } else {
169            self.purge_lookup(&mut conn, key);
170        }
171        Ok(None)
172    }
173
174    fn put(&self, mut session: Session) -> GResult<Cas> {
175        let mut conn = self.connection()?;
176        let tenant_id = session.tenant_id().to_owned();
177        let redis_key = self.data_key(&tenant_id, &session.key);
178        let now = OffsetDateTime::now_utc();
179        session.updated_at = now;
180        session.normalize();
181
182        let existing_cas = self
183            .load_envelope(&mut conn, &redis_key)?
184            .map(|envelope| Cas::from(envelope.cas).next());
185        let cas = existing_cas.unwrap_or_else(Cas::initial);
186        let envelope = SessionEnvelope::new(session, cas);
187        let payload = Self::serialize_envelope(&envelope)?;
188        Self::set_payload(&mut conn, &redis_key, &payload, envelope.session.ttl_secs)
189            .map_err(redis_error)?;
190        self.sync_lookup(
191            &mut conn,
192            &envelope.session.key,
193            &tenant_id,
194            envelope.session.ttl_secs,
195        )
196        .map_err(redis_error)?;
197        Ok(cas)
198    }
199
200    fn update_cas(&self, mut session: Session, expected: Cas) -> GResult<Result<Cas, Cas>> {
201        let mut conn = self.connection()?;
202        let tenant_id = session.tenant_id().to_owned();
203        let redis_key = self.data_key(&tenant_id, &session.key);
204        let now = OffsetDateTime::now_utc();
205        session.updated_at = now;
206        session.normalize();
207
208        let new_cas = expected.next();
209        let envelope = SessionEnvelope::new(session, new_cas);
210        let payload = Self::serialize_envelope(&envelope)?;
211        let ttl = Self::ttl_arg(&envelope.session);
212
213        let (status, cas_value): (i64, u64) = self
214            .update_script
215            .key(redis_key.clone())
216            .arg(expected.value() as i64)
217            .arg(payload)
218            .arg(ttl)
219            .arg(new_cas.value() as i64)
220            .invoke(&mut conn)
221            .map_err(redis_error)?;
222
223        match status {
224            0 => {
225                self.purge_lookup(&mut conn, &envelope.session.key);
226                Ok(Err(Cas::none()))
227            }
228            1 => Ok(Err(Cas::from(cas_value))),
229            2 => {
230                self.touch_lookup(&mut conn, &envelope.session.key, envelope.session.ttl_secs)
231                    .map_err(redis_error)?;
232                Ok(Ok(new_cas))
233            }
234            _ => Ok(Err(Cas::none())),
235        }
236    }
237
238    fn delete(&self, key: &SessionKey) -> GResult<bool> {
239        let mut conn = self.connection()?;
240        let Some(tenant_id) = self.resolve_tenant(&mut conn, key)? else {
241            return Ok(false);
242        };
243        let redis_key = self.data_key(&tenant_id, key);
244        let lookup_key = self.lookup_key(key);
245        let removed: i64 = conn.del(&redis_key).map_err(redis_error)?;
246        conn.del::<_, i64>(&lookup_key).map_err(redis_error)?;
247        Ok(removed > 0)
248    }
249
250    fn touch(&self, key: &SessionKey, ttl_secs: Option<u32>) -> GResult<bool> {
251        let mut conn = self.connection()?;
252        let Some(tenant_id) = self.resolve_tenant(&mut conn, key)? else {
253            return Ok(false);
254        };
255        let redis_key = self.data_key(&tenant_id, key);
256        let Some(mut envelope) = self.load_envelope(&mut conn, &redis_key)? else {
257            self.purge_lookup(&mut conn, key);
258            return Ok(false);
259        };
260
261        let now = OffsetDateTime::now_utc();
262        envelope.session.updated_at = now;
263        if let Some(ttl) = ttl_secs {
264            envelope.session.ttl_secs = ttl;
265        }
266
267        let payload = Self::serialize_envelope(&envelope)?;
268        let ttl = Self::ttl_arg(&envelope.session);
269
270        let (status, _): (i64, u64) = self
271            .update_script
272            .key(redis_key.clone())
273            .arg(envelope.cas)
274            .arg(payload)
275            .arg(ttl)
276            .arg(envelope.cas)
277            .invoke(&mut conn)
278            .map_err(redis_error)?;
279
280        if status == 2 {
281            self.touch_lookup(&mut conn, key, envelope.session.ttl_secs)
282                .map_err(redis_error)?;
283            Ok(true)
284        } else {
285            if status == 0 {
286                self.purge_lookup(&mut conn, key);
287            }
288            Ok(false)
289        }
290    }
291}
292
293#[derive(Serialize, Deserialize)]
294struct SessionEnvelope {
295    cas: u64,
296    session: Session,
297}
298
299impl SessionEnvelope {
300    fn new(mut session: Session, cas: Cas) -> Self {
301        session.normalize();
302        Self {
303            cas: cas.value(),
304            session,
305        }
306    }
307}