1use std::sync::Arc;
2
3use chrono::{DateTime, Duration, Utc};
4use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
5use secrecy::ExposeSecret;
6use serde::{Deserialize, Serialize};
7use sqlx::{Row, SqlitePool};
8use uuid::Uuid;
9
10use crate::config::GuardConfig;
11use crate::error::{GuardError, GuardResult};
12use crate::types::GuardSession;
13
/// Paging parameters for list-style queries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ListOptions {
    /// Maximum number of items to return in one page.
    pub limit: u64,
    /// Number of matching items to skip before the first returned item.
    pub offset: u64,
}
22
23impl Default for ListOptions {
24 fn default() -> Self {
25 Self {
26 limit: 50,
27 offset: 0,
28 }
29 }
30}
31
/// One page of results from a paginated listing.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ListPage<T> {
    /// Items contained in this page, in the order the producer returned them.
    pub items: Vec<T>,
    /// Offset to pass when requesting the following page, or `None` when the producer
    /// reports that there is no further page.
    pub next_offset: Option<u64>,
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
43pub struct SessionRecord {
44 pub id: Uuid,
46 pub agent_id: Uuid,
48 pub workspace_id: Uuid,
50 pub role: String,
52 pub scopes: Vec<String>,
54 pub created_at: DateTime<Utc>,
56 pub expires_at: DateTime<Utc>,
58 pub revoked: bool,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63struct SessionClaims {
64 sub: String,
65 agent_id: String,
66 workspace_id: String,
67 role: String,
68 scopes: Vec<String>,
69 exp: usize,
70}
71
72#[derive(Clone)]
74pub struct SessionManager {
75 pool: SqlitePool,
76 config: Arc<GuardConfig>,
77}
78
79impl SessionManager {
80 pub fn new(pool: SqlitePool, config: Arc<GuardConfig>) -> Self {
82 Self { pool, config }
83 }
84
85 pub async fn create_session(
87 &self,
88 agent_id: Uuid,
89 workspace_id: Uuid,
90 role: &str,
91 scopes: Vec<String>,
92 ttl_secs: u64,
93 ) -> GuardResult<GuardSession> {
94 let id = Uuid::new_v4();
95 let created_at = Utc::now();
96 let expires_at = created_at + Duration::seconds(ttl_secs as i64);
97 let claims = SessionClaims {
98 sub: id.to_string(),
99 agent_id: agent_id.to_string(),
100 workspace_id: workspace_id.to_string(),
101 role: role.to_owned(),
102 scopes: scopes.clone(),
103 exp: expires_at.timestamp() as usize,
104 };
105 let token = encode(
106 &Header::new(Algorithm::HS256),
107 &claims,
108 &EncodingKey::from_secret(self.config.jwt_secret.expose_secret().as_bytes()),
109 )?;
110
111 sqlx::query(
112 "INSERT INTO sessions (id, agent_id, workspace_id, role, scopes, created_at, expires_at, revoked)
113 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 0)",
114 )
115 .bind(id.to_string())
116 .bind(agent_id.to_string())
117 .bind(workspace_id.to_string())
118 .bind(role)
119 .bind(serde_json::to_string(&scopes)?)
120 .bind(created_at.timestamp_millis())
121 .bind(expires_at.timestamp_millis())
122 .execute(&self.pool)
123 .await?;
124
125 Ok(GuardSession {
126 id,
127 agent_id,
128 workspace_id,
129 role: role.to_owned(),
130 scopes,
131 expires_at,
132 token,
133 })
134 }
135
136 pub async fn validate_session(&self, token: &str) -> GuardResult<GuardSession> {
138 let mut validation = Validation::new(Algorithm::HS256);
139 validation.validate_exp = true;
140 let token_data = decode::<SessionClaims>(
141 token,
142 &DecodingKey::from_secret(self.config.jwt_secret.expose_secret().as_bytes()),
143 &validation,
144 )
145 .map_err(|_| GuardError::InvalidToken)?;
146
147 let claims = token_data.claims;
148 let session_id = Uuid::parse_str(&claims.sub)?;
149 let row = sqlx::query(
150 "SELECT id, agent_id, workspace_id, role, scopes, created_at, expires_at, revoked
151 FROM sessions WHERE id = ?1",
152 )
153 .bind(session_id.to_string())
154 .fetch_optional(&self.pool)
155 .await?
156 .ok_or(GuardError::InvalidToken)?;
157
158 let record = row_to_session_record(&row)?;
159 if record.revoked {
160 return Err(GuardError::SessionRevoked);
161 }
162 if record.expires_at <= Utc::now() {
163 return Err(GuardError::SessionExpired);
164 }
165
166 Ok(GuardSession {
167 id: record.id,
168 agent_id: record.agent_id,
169 workspace_id: record.workspace_id,
170 role: record.role,
171 scopes: record.scopes,
172 expires_at: record.expires_at,
173 token: token.to_owned(),
174 })
175 }
176
177 pub async fn revoke_session(&self, session_id: Uuid) -> GuardResult<()> {
179 sqlx::query("UPDATE sessions SET revoked = 1 WHERE id = ?1")
180 .bind(session_id.to_string())
181 .execute(&self.pool)
182 .await?;
183 Ok(())
184 }
185
186 pub async fn list_active_sessions(
188 &self,
189 agent_id: Uuid,
190 opts: Option<ListOptions>,
191 ) -> GuardResult<ListPage<SessionRecord>> {
192 let opts = opts.unwrap_or_default();
193 let rows = sqlx::query(
194 "SELECT id, agent_id, workspace_id, role, scopes, created_at, expires_at, revoked
195 FROM sessions
196 WHERE agent_id = ?1 AND revoked = 0 AND expires_at > ?2
197 ORDER BY created_at DESC
198 LIMIT ?3 OFFSET ?4",
199 )
200 .bind(agent_id.to_string())
201 .bind(Utc::now().timestamp_millis())
202 .bind(opts.limit as i64)
203 .bind(opts.offset as i64)
204 .fetch_all(&self.pool)
205 .await?;
206 let items = rows
207 .iter()
208 .map(row_to_session_record)
209 .collect::<GuardResult<Vec<_>>>()?;
210 let next_offset = (items.len() as u64 == opts.limit).then_some(opts.offset + opts.limit);
211
212 Ok(ListPage { items, next_offset })
213 }
214
215 pub async fn assert_session_active(&self, session: &GuardSession) -> GuardResult<()> {
217 let row = sqlx::query("SELECT revoked, expires_at FROM sessions WHERE id = ?1")
218 .bind(session.id.to_string())
219 .fetch_optional(&self.pool)
220 .await?
221 .ok_or(GuardError::InvalidToken)?;
222 let revoked = row.try_get::<i64, _>("revoked")? != 0;
223 if revoked {
224 return Err(GuardError::SessionRevoked);
225 }
226 let expires_at = from_ms(row.try_get("expires_at")?)?;
227 if expires_at <= Utc::now() {
228 return Err(GuardError::SessionExpired);
229 }
230 Ok(())
231 }
232}
233
234fn row_to_session_record(row: &sqlx::sqlite::SqliteRow) -> GuardResult<SessionRecord> {
235 Ok(SessionRecord {
236 id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
237 agent_id: Uuid::parse_str(&row.try_get::<String, _>("agent_id")?)?,
238 workspace_id: Uuid::parse_str(&row.try_get::<String, _>("workspace_id")?)?,
239 role: row.try_get("role")?,
240 scopes: serde_json::from_str(&row.try_get::<String, _>("scopes")?)?,
241 created_at: from_ms(row.try_get("created_at")?)?,
242 expires_at: from_ms(row.try_get("expires_at")?)?,
243 revoked: row.try_get::<i64, _>("revoked")? != 0,
244 })
245}
246
247fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
248 DateTime::from_timestamp_millis(value)
249 .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
250}