Skip to main content

claw_guard/
session.rs

1use std::sync::Arc;
2
3use chrono::{DateTime, Duration, Utc};
4use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
5use secrecy::ExposeSecret;
6use serde::{Deserialize, Serialize};
7use sqlx::{Row, SqlitePool};
8use uuid::Uuid;
9
10use crate::config::GuardConfig;
11use crate::error::{GuardError, GuardResult};
12use crate::types::GuardSession;
13
/// Offset-based pagination settings accepted by the list APIs.
///
/// The [`Default`] value asks for the first page with at most 50 items.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ListOptions {
    /// Upper bound on the number of items a single page may contain.
    pub limit: u64,
    /// How many matching items to skip before the page begins.
    pub offset: u64,
}

impl Default for ListOptions {
    /// First page (`offset == 0`) holding up to 50 items.
    fn default() -> Self {
        const DEFAULT_LIMIT: u64 = 50;
        ListOptions {
            offset: 0,
            limit: DEFAULT_LIMIT,
        }
    }
}
31
/// A single page of results produced by a list API.
///
/// `T` is the item type; the struct itself places no bounds on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ListPage<T> {
    /// Items contained in this page.
    pub items: Vec<T>,
    /// Offset at which the following page starts, or `None` when no further page exists.
    pub next_offset: Option<u64>,
}
40
41/// Session row returned by list operations.
42#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct SessionRecord {
44    /// Session identifier.
45    pub id: Uuid,
46    /// Agent identifier.
47    pub agent_id: Uuid,
48    /// Workspace identifier.
49    pub workspace_id: Uuid,
50    /// Role associated with the session.
51    pub role: String,
52    /// Session scopes.
53    pub scopes: Vec<String>,
54    /// Session creation time.
55    pub created_at: DateTime<Utc>,
56    /// Session expiration time.
57    pub expires_at: DateTime<Utc>,
58    /// Revocation flag.
59    pub revoked: bool,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63struct SessionClaims {
64    sub: String,
65    agent_id: String,
66    workspace_id: String,
67    role: String,
68    scopes: Vec<String>,
69    exp: usize,
70}
71
72/// Issues, validates, revokes, and lists guard sessions.
73#[derive(Clone)]
74pub struct SessionManager {
75    pool: SqlitePool,
76    config: Arc<GuardConfig>,
77}
78
79impl SessionManager {
80    /// Creates a session manager backed by the provided database pool.
81    pub fn new(pool: SqlitePool, config: Arc<GuardConfig>) -> Self {
82        Self { pool, config }
83    }
84
85    /// Creates a new JWT-backed session and stores it in SQLite.
86    pub async fn create_session(
87        &self,
88        agent_id: Uuid,
89        workspace_id: Uuid,
90        role: &str,
91        scopes: Vec<String>,
92        ttl_secs: u64,
93    ) -> GuardResult<GuardSession> {
94        let id = Uuid::new_v4();
95        let created_at = Utc::now();
96        let expires_at = created_at + Duration::seconds(ttl_secs as i64);
97        let claims = SessionClaims {
98            sub: id.to_string(),
99            agent_id: agent_id.to_string(),
100            workspace_id: workspace_id.to_string(),
101            role: role.to_owned(),
102            scopes: scopes.clone(),
103            exp: expires_at.timestamp() as usize,
104        };
105        let token = encode(
106            &Header::new(Algorithm::HS256),
107            &claims,
108            &EncodingKey::from_secret(self.config.jwt_secret.expose_secret().as_bytes()),
109        )?;
110
111        sqlx::query(
112            "INSERT INTO sessions (id, agent_id, workspace_id, role, scopes, created_at, expires_at, revoked)
113             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 0)",
114        )
115        .bind(id.to_string())
116        .bind(agent_id.to_string())
117        .bind(workspace_id.to_string())
118        .bind(role)
119        .bind(serde_json::to_string(&scopes)?)
120        .bind(created_at.timestamp_millis())
121        .bind(expires_at.timestamp_millis())
122        .execute(&self.pool)
123        .await?;
124
125        Ok(GuardSession {
126            id,
127            agent_id,
128            workspace_id,
129            role: role.to_owned(),
130            scopes,
131            expires_at,
132            token,
133        })
134    }
135
136    /// Validates a session JWT and confirms the backing row remains active.
137    pub async fn validate_session(&self, token: &str) -> GuardResult<GuardSession> {
138        let mut validation = Validation::new(Algorithm::HS256);
139        validation.validate_exp = true;
140        let token_data = decode::<SessionClaims>(
141            token,
142            &DecodingKey::from_secret(self.config.jwt_secret.expose_secret().as_bytes()),
143            &validation,
144        )
145        .map_err(|_| GuardError::InvalidToken)?;
146
147        let claims = token_data.claims;
148        let session_id = Uuid::parse_str(&claims.sub)?;
149        let row = sqlx::query(
150            "SELECT id, agent_id, workspace_id, role, scopes, created_at, expires_at, revoked
151             FROM sessions WHERE id = ?1",
152        )
153        .bind(session_id.to_string())
154        .fetch_optional(&self.pool)
155        .await?
156        .ok_or(GuardError::InvalidToken)?;
157
158        let record = row_to_session_record(&row)?;
159        if record.revoked {
160            return Err(GuardError::SessionRevoked);
161        }
162        if record.expires_at <= Utc::now() {
163            return Err(GuardError::SessionExpired);
164        }
165
166        Ok(GuardSession {
167            id: record.id,
168            agent_id: record.agent_id,
169            workspace_id: record.workspace_id,
170            role: record.role,
171            scopes: record.scopes,
172            expires_at: record.expires_at,
173            token: token.to_owned(),
174        })
175    }
176
177    /// Revokes a session row.
178    pub async fn revoke_session(&self, session_id: Uuid) -> GuardResult<()> {
179        sqlx::query("UPDATE sessions SET revoked = 1 WHERE id = ?1")
180            .bind(session_id.to_string())
181            .execute(&self.pool)
182            .await?;
183        Ok(())
184    }
185
186    /// Lists active sessions for an agent.
187    pub async fn list_active_sessions(
188        &self,
189        agent_id: Uuid,
190        opts: Option<ListOptions>,
191    ) -> GuardResult<ListPage<SessionRecord>> {
192        let opts = opts.unwrap_or_default();
193        let rows = sqlx::query(
194            "SELECT id, agent_id, workspace_id, role, scopes, created_at, expires_at, revoked
195             FROM sessions
196             WHERE agent_id = ?1 AND revoked = 0 AND expires_at > ?2
197             ORDER BY created_at DESC
198             LIMIT ?3 OFFSET ?4",
199        )
200        .bind(agent_id.to_string())
201        .bind(Utc::now().timestamp_millis())
202        .bind(opts.limit as i64)
203        .bind(opts.offset as i64)
204        .fetch_all(&self.pool)
205        .await?;
206        let items = rows
207            .iter()
208            .map(row_to_session_record)
209            .collect::<GuardResult<Vec<_>>>()?;
210        let next_offset = (items.len() as u64 == opts.limit).then_some(opts.offset + opts.limit);
211
212        Ok(ListPage { items, next_offset })
213    }
214
215    /// Confirms that a session row is still active.
216    pub async fn assert_session_active(&self, session: &GuardSession) -> GuardResult<()> {
217        let row = sqlx::query("SELECT revoked, expires_at FROM sessions WHERE id = ?1")
218            .bind(session.id.to_string())
219            .fetch_optional(&self.pool)
220            .await?
221            .ok_or(GuardError::InvalidToken)?;
222        let revoked = row.try_get::<i64, _>("revoked")? != 0;
223        if revoked {
224            return Err(GuardError::SessionRevoked);
225        }
226        let expires_at = from_ms(row.try_get("expires_at")?)?;
227        if expires_at <= Utc::now() {
228            return Err(GuardError::SessionExpired);
229        }
230        Ok(())
231    }
232}
233
234fn row_to_session_record(row: &sqlx::sqlite::SqliteRow) -> GuardResult<SessionRecord> {
235    Ok(SessionRecord {
236        id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
237        agent_id: Uuid::parse_str(&row.try_get::<String, _>("agent_id")?)?,
238        workspace_id: Uuid::parse_str(&row.try_get::<String, _>("workspace_id")?)?,
239        role: row.try_get("role")?,
240        scopes: serde_json::from_str(&row.try_get::<String, _>("scopes")?)?,
241        created_at: from_ms(row.try_get("created_at")?)?,
242        expires_at: from_ms(row.try_get("expires_at")?)?,
243        revoked: row.try_get::<i64, _>("revoked")? != 0,
244    })
245}
246
247fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
248    DateTime::from_timestamp_millis(value)
249        .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
250}