1use crate::error::{redis_error, serde_error};
2use crate::model::{Cas, Session, SessionKey};
3use crate::store::SessionStore;
4use greentic_types::GResult;
5use redis::{Commands, Connection, RedisResult, Script, Value};
6use serde::{Deserialize, Serialize};
7use time::OffsetDateTime;
8
/// Key prefix used when the caller does not supply one (see `RedisSessionStore::new`).
const DEFAULT_NAMESPACE: &str = "greentic:session";
/// Segment that takes the tenant's place in lookup keys: `<namespace>:lookup:<session key>`.
const LOOKUP_SEGMENT: &str = "lookup";
11
/// Lua script run inside Redis to perform a compare-and-set write of a session document.
///
/// Inputs: `KEYS[1]` is the data key; `ARGV[1]` is the expected CAS, `ARGV[2]` the new JSON
/// payload, `ARGV[3]` the TTL in seconds and `ARGV[4]` the CAS value reported on success.
///
/// The reply is a two-element array `{status, cas}`:
/// * `{0, 0}` - the key does not exist;
/// * `{1, current}` - the stored document's `cas` field differs from the expected value
///   (a missing `cas` field is treated as `0`);
/// * `{2, new_cas}` - the payload was written; a positive TTL is applied with `EXPIRE`,
///   otherwise any existing expiry is removed with `PERSIST`.
static UPDATE_LUA: &str = r#"
local key = KEYS[1]
local expected = tonumber(ARGV[1])
local payload = ARGV[2]
local ttl = tonumber(ARGV[3])
local new_cas = tonumber(ARGV[4])
local existing = redis.call("GET", key)
if not existing then
 return {0, 0}
end
local doc = cjson.decode(existing)
local current = tonumber(doc.cas or 0)
if current ~= expected then
 return {1, current or 0}
end
redis.call("SET", key, payload)
if ttl and ttl > 0 then
 redis.call("EXPIRE", key, ttl)
else
 redis.call("PERSIST", key)
end
return {2, new_cas}
"#;
35
/// Session store backed by Redis.
///
/// Session documents are stored as JSON under `<namespace>:<tenant>:<session key>`. A
/// companion `<namespace>:lookup:<session key>` entry holds the tenant id, so operations that
/// only receive a session key can locate the matching data key.
#[derive(Clone)]
pub struct RedisSessionStore {
    /// Client from which a connection is opened for each store operation.
    client: redis::Client,
    /// Prefix prepended to every key this store builds.
    namespace: String,
    /// Compare-and-set script built from `UPDATE_LUA`; used by `update_cas` and `touch`.
    update_script: Script,
}
42
43impl RedisSessionStore {
44 pub fn new(client: redis::Client) -> Self {
46 Self::with_namespace(client, DEFAULT_NAMESPACE)
47 }
48
49 pub fn with_namespace(client: redis::Client, namespace: impl Into<String>) -> Self {
51 Self {
52 client,
53 namespace: namespace.into(),
54 update_script: Script::new(UPDATE_LUA),
55 }
56 }
57
58 fn connection(&self) -> GResult<Connection> {
59 self.client.get_connection().map_err(redis_error)
60 }
61
62 fn data_key(&self, tenant_id: &str, key: &SessionKey) -> String {
63 format!("{}:{}:{}", self.namespace, tenant_id, key.as_str())
64 }
65
66 fn lookup_key(&self, key: &SessionKey) -> String {
67 format!("{}:{}:{}", self.namespace, LOOKUP_SEGMENT, key.as_str())
68 }
69
70 fn resolve_tenant(&self, conn: &mut Connection, key: &SessionKey) -> GResult<Option<String>> {
71 let lookup_key = self.lookup_key(key);
72 conn.get(&lookup_key).map_err(redis_error)
73 }
74
75 fn load_envelope(&self, conn: &mut Connection, key: &str) -> GResult<Option<SessionEnvelope>> {
76 let payload: Option<String> = conn.get(key).map_err(redis_error)?;
77 let envelope = payload
78 .map(|raw| serde_json::from_str(&raw))
79 .transpose()
80 .map_err(serde_error)?;
81 Ok(envelope)
82 }
83
84 fn serialize_envelope(envelope: &SessionEnvelope) -> GResult<String> {
85 serde_json::to_string(envelope).map_err(serde_error)
86 }
87
88 fn ttl_arg(session: &Session) -> i64 {
89 i64::from(session.ttl_secs)
90 }
91
92 fn set_payload(
93 conn: &mut Connection,
94 key: &str,
95 payload: &str,
96 ttl_secs: u32,
97 ) -> RedisResult<()> {
98 if ttl_secs > 0 {
99 let _: Value = redis::cmd("SET")
100 .arg(key)
101 .arg(payload)
102 .arg("EX")
103 .arg(ttl_secs)
104 .query(conn)?;
105 } else {
106 let _: Value = redis::cmd("SET").arg(key).arg(payload).query(conn)?;
107 }
108 Ok(())
109 }
110
111 fn sync_lookup(
112 &self,
113 conn: &mut Connection,
114 key: &SessionKey,
115 tenant_id: &str,
116 ttl_secs: u32,
117 ) -> RedisResult<()> {
118 let lookup_key = self.lookup_key(key);
119 if ttl_secs > 0 {
120 let _: Value = redis::cmd("SET")
121 .arg(&lookup_key)
122 .arg(tenant_id)
123 .arg("EX")
124 .arg(ttl_secs)
125 .query(conn)?;
126 } else {
127 let _: Value = redis::cmd("SET")
128 .arg(&lookup_key)
129 .arg(tenant_id)
130 .query(conn)?;
131 }
132 Ok(())
133 }
134
135 fn touch_lookup(
136 &self,
137 conn: &mut Connection,
138 key: &SessionKey,
139 ttl_secs: u32,
140 ) -> RedisResult<()> {
141 let lookup_key = self.lookup_key(key);
142 if ttl_secs > 0 {
143 let _: i64 = redis::cmd("EXPIRE")
144 .arg(&lookup_key)
145 .arg(ttl_secs)
146 .query(conn)?;
147 } else {
148 let _: i64 = redis::cmd("PERSIST").arg(&lookup_key).query(conn)?;
149 }
150 Ok(())
151 }
152
153 fn purge_lookup(&self, conn: &mut Connection, key: &SessionKey) {
154 let lookup_key = self.lookup_key(key);
155 let _ = redis::cmd("DEL").arg(&lookup_key).query::<i64>(conn);
156 }
157}
158
159impl SessionStore for RedisSessionStore {
160 fn get(&self, key: &SessionKey) -> GResult<Option<(Session, Cas)>> {
161 let mut conn = self.connection()?;
162 let Some(tenant_id) = self.resolve_tenant(&mut conn, key)? else {
163 return Ok(None);
164 };
165 let redis_key = self.data_key(&tenant_id, key);
166 if let Some(envelope) = self.load_envelope(&mut conn, &redis_key)? {
167 return Ok(Some((envelope.session, Cas::from(envelope.cas))));
168 } else {
169 self.purge_lookup(&mut conn, key);
170 }
171 Ok(None)
172 }
173
174 fn put(&self, mut session: Session) -> GResult<Cas> {
175 let mut conn = self.connection()?;
176 let tenant_id = session.tenant_id().to_owned();
177 let redis_key = self.data_key(&tenant_id, &session.key);
178 let now = OffsetDateTime::now_utc();
179 session.updated_at = now;
180 session.normalize();
181
182 let existing_cas = self
183 .load_envelope(&mut conn, &redis_key)?
184 .map(|envelope| Cas::from(envelope.cas).next());
185 let cas = existing_cas.unwrap_or_else(Cas::initial);
186 let envelope = SessionEnvelope::new(session, cas);
187 let payload = Self::serialize_envelope(&envelope)?;
188 Self::set_payload(&mut conn, &redis_key, &payload, envelope.session.ttl_secs)
189 .map_err(redis_error)?;
190 self.sync_lookup(
191 &mut conn,
192 &envelope.session.key,
193 &tenant_id,
194 envelope.session.ttl_secs,
195 )
196 .map_err(redis_error)?;
197 Ok(cas)
198 }
199
200 fn update_cas(&self, mut session: Session, expected: Cas) -> GResult<Result<Cas, Cas>> {
201 let mut conn = self.connection()?;
202 let tenant_id = session.tenant_id().to_owned();
203 let redis_key = self.data_key(&tenant_id, &session.key);
204 let now = OffsetDateTime::now_utc();
205 session.updated_at = now;
206 session.normalize();
207
208 let new_cas = expected.next();
209 let envelope = SessionEnvelope::new(session, new_cas);
210 let payload = Self::serialize_envelope(&envelope)?;
211 let ttl = Self::ttl_arg(&envelope.session);
212
213 let (status, cas_value): (i64, u64) = self
214 .update_script
215 .key(redis_key.clone())
216 .arg(expected.value() as i64)
217 .arg(payload)
218 .arg(ttl)
219 .arg(new_cas.value() as i64)
220 .invoke(&mut conn)
221 .map_err(redis_error)?;
222
223 match status {
224 0 => {
225 self.purge_lookup(&mut conn, &envelope.session.key);
226 Ok(Err(Cas::none()))
227 }
228 1 => Ok(Err(Cas::from(cas_value))),
229 2 => {
230 self.touch_lookup(&mut conn, &envelope.session.key, envelope.session.ttl_secs)
231 .map_err(redis_error)?;
232 Ok(Ok(new_cas))
233 }
234 _ => Ok(Err(Cas::none())),
235 }
236 }
237
238 fn delete(&self, key: &SessionKey) -> GResult<bool> {
239 let mut conn = self.connection()?;
240 let Some(tenant_id) = self.resolve_tenant(&mut conn, key)? else {
241 return Ok(false);
242 };
243 let redis_key = self.data_key(&tenant_id, key);
244 let lookup_key = self.lookup_key(key);
245 let removed: i64 = conn.del(&redis_key).map_err(redis_error)?;
246 conn.del::<_, i64>(&lookup_key).map_err(redis_error)?;
247 Ok(removed > 0)
248 }
249
250 fn touch(&self, key: &SessionKey, ttl_secs: Option<u32>) -> GResult<bool> {
251 let mut conn = self.connection()?;
252 let Some(tenant_id) = self.resolve_tenant(&mut conn, key)? else {
253 return Ok(false);
254 };
255 let redis_key = self.data_key(&tenant_id, key);
256 let Some(mut envelope) = self.load_envelope(&mut conn, &redis_key)? else {
257 self.purge_lookup(&mut conn, key);
258 return Ok(false);
259 };
260
261 let now = OffsetDateTime::now_utc();
262 envelope.session.updated_at = now;
263 if let Some(ttl) = ttl_secs {
264 envelope.session.ttl_secs = ttl;
265 }
266
267 let payload = Self::serialize_envelope(&envelope)?;
268 let ttl = Self::ttl_arg(&envelope.session);
269
270 let (status, _): (i64, u64) = self
271 .update_script
272 .key(redis_key.clone())
273 .arg(envelope.cas)
274 .arg(payload)
275 .arg(ttl)
276 .arg(envelope.cas)
277 .invoke(&mut conn)
278 .map_err(redis_error)?;
279
280 if status == 2 {
281 self.touch_lookup(&mut conn, key, envelope.session.ttl_secs)
282 .map_err(redis_error)?;
283 Ok(true)
284 } else {
285 if status == 0 {
286 self.purge_lookup(&mut conn, key);
287 }
288 Ok(false)
289 }
290 }
291}
292
/// JSON document stored at a session's data key: the session plus its CAS counter.
#[derive(Serialize, Deserialize)]
struct SessionEnvelope {
    /// Raw CAS value; the update script reads it from the stored document as `doc.cas`.
    cas: u64,
    /// The session payload itself.
    session: Session,
}
298
299impl SessionEnvelope {
300 fn new(mut session: Session, cas: Cas) -> Self {
301 session.normalize();
302 Self {
303 cas: cas.value(),
304 session,
305 }
306 }
307}