// claw-guard 0.1.0
//
// Security and policy engine for ClawDB.
use std::sync::Arc;

use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use sqlx::{Row, SqlitePool};
use uuid::Uuid;

use crate::config::GuardConfig;
use crate::error::{GuardError, GuardResult};

/// A freshly issued or validated session, including its signed bearer token.
///
/// `Debug` is implemented by hand so the bearer `token` is never written to
/// logs or panic messages; `Serialize`/`Deserialize` still include the token.
#[derive(Clone, Serialize, Deserialize, PartialEq)]
pub struct ClawSession {
    /// Identifier of the session row; also carried in the JWT `session_id` claim.
    pub session_id: Uuid,
    /// Agent the session was issued to.
    pub agent_id: Uuid,
    /// Name of the role the session was issued under.
    pub role: String,
    /// Scopes granted to this session.
    pub scopes: Vec<String>,
    /// UTC instant after which the session is no longer accepted.
    pub expires_at: DateTime<Utc>,
    /// Signed JWT presented by the agent; treat it as a secret.
    pub token: String,
}

impl std::fmt::Debug for ClawSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every field except `token`, which is a bearer credential.
        f.debug_struct("ClawSession")
            .field("session_id", &self.session_id)
            .field("agent_id", &self.agent_id)
            .field("role", &self.role)
            .field("scopes", &self.scopes)
            .field("expires_at", &self.expires_at)
            .field("token", &"<redacted>")
            .finish()
    }
}

/// Summary of one active session, as returned by
/// `SessionManager::list_active_sessions`. Unlike `ClawSession`, it carries
/// no token.
#[derive(Debug, Clone, PartialEq)]
pub struct SessionInfo {
    /// Identifier of the session row.
    pub session_id: Uuid,
    /// Agent that owns the session.
    pub agent_id: Uuid,
    /// Name of the role joined from the `roles` table.
    pub role: String,
    /// Scopes stored on the session row (JSON-decoded).
    pub scopes: Vec<String>,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// When the session stops being valid.
    pub expires_at: DateTime<Utc>,
}

/// One page of results from `SessionManager::list_active_sessions`.
#[derive(Debug, Clone, PartialEq)]
pub struct PaginatedSessions {
    /// Sessions on this page, newest `created_at` first.
    pub sessions: Vec<SessionInfo>,
    /// Offset to request for the following page. `Some` only when this page
    /// was full (it held exactly `limit` rows); `None` signals the last page.
    pub next_offset: Option<u64>,
}

/// JWT payload signed into `ClawSession::token`.
///
/// Identifiers are carried as strings and parsed back into `Uuid`s on
/// validation. Field order here is the serialized claim order.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SessionClaims {
    /// Session row id, as a hyphenated UUID string.
    session_id: String,
    /// Owning agent id, as a hyphenated UUID string.
    agent_id: String,
    /// Role name at issue time.
    role: String,
    /// Scopes granted at issue time.
    scopes: Vec<String>,
    /// Registered `exp` claim: expiry as Unix seconds (from `expires_at.timestamp()`).
    exp: usize,
}

/// Issues, validates, revokes, and lists agent sessions.
///
/// Sessions are persisted in the `sessions` table of `pool`, and tokens are
/// HS256 JWTs keyed by `config.jwt_secret`. Cloning shares the same pool and
/// config.
#[derive(Clone)]
pub struct SessionManager {
    /// Database holding the `sessions` and `roles` tables.
    pool: SqlitePool,
    /// Guard configuration; `jwt_secret` signs and verifies tokens.
    config: Arc<GuardConfig>,
}

impl SessionManager {
    /// Builds a manager over `pool`, signing and verifying tokens with
    /// `config.jwt_secret`.
    pub fn new(pool: SqlitePool, config: Arc<GuardConfig>) -> Self {
        Self { pool, config }
    }

    /// Issues a new session for `agent_id` under `role` that lives for `ttl_secs`.
    ///
    /// The role row is created on first use (an existing role keeps the scopes
    /// it was created with). The session row stores the `scopes` passed here.
    /// The returned token is an HS256 JWT embedding the session's claims.
    ///
    /// # Errors
    /// Fails if the role lookup/insert, scope serialization, token encoding,
    /// or the session insert fails.
    ///
    /// # Panics
    /// `chrono`'s `Duration::seconds` and `DateTime + Duration` panic on
    /// out-of-range values, so an absurdly large `ttl_secs` panics here.
    pub async fn create_session(
        &self,
        agent_id: Uuid,
        role: &str,
        scopes: Vec<String>,
        ttl_secs: i64,
    ) -> GuardResult<ClawSession> {
        let session_id = Uuid::new_v4();
        let role_id = self.ensure_role(role, &scopes).await?;
        let created_at = Utc::now();
        let expires_at = created_at + Duration::seconds(ttl_secs);
        let claims = SessionClaims {
            session_id: session_id.to_string(),
            agent_id: agent_id.to_string(),
            role: role.to_owned(),
            scopes: scopes.clone(),
            // A pre-epoch expiry clamps to 0 (already expired) instead of
            // wrapping to a huge "never expires" value; an expiry past
            // `usize::MAX` saturates. The stored `expires_at` remains the
            // authoritative expiry, re-checked in `validate_session`.
            exp: usize::try_from(expires_at.timestamp().max(0)).unwrap_or(usize::MAX),
        };
        let token = encode(
            &Header::new(Algorithm::HS256),
            &claims,
            &EncodingKey::from_secret(self.config.jwt_secret.as_bytes()),
        )?;

        sqlx::query(
            "INSERT INTO sessions (id, agent_id, role_id, scopes, created_at, expires_at, revoked)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, 0)",
        )
        .bind(session_id.to_string())
        .bind(agent_id.to_string())
        .bind(role_id.to_string())
        .bind(serde_json::to_string(&scopes)?)
        .bind(created_at)
        .bind(expires_at)
        .execute(&self.pool)
        .await?;

        Ok(ClawSession {
            session_id,
            agent_id,
            role: role.to_owned(),
            scopes,
            expires_at,
            token,
        })
    }

    /// Verifies `token` and returns the session it refers to.
    ///
    /// The JWT signature and `exp` claim are checked first. The `session_id`
    /// claim is then looked up in the database, whose `revoked` flag and
    /// `expires_at` decide validity. Agent id, role name, and scopes in the
    /// result come from the database row, not from the token's claims.
    ///
    /// # Errors
    /// - JWT decoding/verification failures (bad signature, expired `exp`, ...).
    /// - [`GuardError::SessionNotFound`] if no session row matches.
    /// - [`GuardError::SessionRevoked`] if the session was revoked.
    /// - [`GuardError::SessionExpired`] if the stored `expires_at` has passed.
    /// - Database, UUID-parsing, or scope-deserialization errors.
    pub async fn validate_session(&self, token: &str) -> GuardResult<ClawSession> {
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = true;
        let token_data = decode::<SessionClaims>(
            token,
            &DecodingKey::from_secret(self.config.jwt_secret.as_bytes()),
            &validation,
        )?;
        let claims = token_data.claims;
        let session_id = Uuid::parse_str(&claims.session_id)?;
        let row = sqlx::query(
            "SELECT s.agent_id, s.scopes, s.expires_at, s.revoked, r.name AS role_name
             FROM sessions s
             JOIN roles r ON r.id = s.role_id
             WHERE s.id = ?1",
        )
        .bind(session_id.to_string())
        .fetch_optional(&self.pool)
        .await?
        .ok_or(GuardError::SessionNotFound(session_id))?;

        let revoked: bool = row.try_get("revoked")?;
        if revoked {
            return Err(GuardError::SessionRevoked(session_id));
        }

        // Re-check expiry against the stored value so the database, not the
        // token, is the source of truth.
        let expires_at: DateTime<Utc> = row.try_get("expires_at")?;
        if expires_at < Utc::now() {
            return Err(GuardError::SessionExpired {
                session_id,
                expired_at: expires_at,
            });
        }

        let scopes: String = row.try_get("scopes")?;
        let scopes = serde_json::from_str(&scopes)?;
        let agent_id = Uuid::parse_str(&row.try_get::<String, _>("agent_id")?)?;
        let role: String = row.try_get("role_name")?;

        Ok(ClawSession {
            session_id,
            agent_id,
            role,
            scopes,
            expires_at,
            token: token.to_owned(),
        })
    }

    /// Marks session `session_id` as revoked.
    ///
    /// # Errors
    /// [`GuardError::SessionNotFound`] if no session row has that id, or a
    /// database error.
    pub async fn revoke_session(&self, session_id: Uuid) -> GuardResult<()> {
        let result = sqlx::query("UPDATE sessions SET revoked = 1 WHERE id = ?1")
            .bind(session_id.to_string())
            .execute(&self.pool)
            .await?;

        if result.rows_affected() == 0 {
            return Err(GuardError::SessionNotFound(session_id));
        }

        Ok(())
    }

    /// Lists `agent_id`'s non-revoked, unexpired sessions, newest first.
    ///
    /// Returns at most `limit` sessions, skipping the first `offset`.
    /// `next_offset` is `Some(offset + limit)` when the page came back full,
    /// and `None` otherwise (or if that sum would overflow).
    ///
    /// # Errors
    /// Database, UUID-parsing, or scope-deserialization errors.
    pub async fn list_active_sessions(
        &self,
        agent_id: Uuid,
        limit: u64,
        offset: u64,
    ) -> GuardResult<PaginatedSessions> {
        // SQLite treats a negative LIMIT as "no limit" and a negative OFFSET as
        // 0, so a wrapping `as i64` cast of a huge offset would silently return
        // the first page. Saturate at i64::MAX instead.
        let sql_limit = i64::try_from(limit).unwrap_or(i64::MAX);
        let sql_offset = i64::try_from(offset).unwrap_or(i64::MAX);

        // The expiry cutoff is bound as a `DateTime<Utc>`, encoded the same way
        // `create_session` encodes `expires_at`. SQLite's `CURRENT_TIMESTAMP`
        // uses a different text layout (`YYYY-MM-DD HH:MM:SS`, no `T`), which
        // would make the text comparison wrong for sessions expiring today.
        let rows = sqlx::query(
            "SELECT s.id, s.agent_id, s.scopes, s.created_at, s.expires_at, r.name AS role_name
             FROM sessions s
             JOIN roles r ON r.id = s.role_id
             WHERE s.agent_id = ?1 AND s.revoked = 0 AND s.expires_at > ?2
             ORDER BY s.created_at DESC
             LIMIT ?3 OFFSET ?4",
        )
        .bind(agent_id.to_string())
        .bind(Utc::now())
        .bind(sql_limit)
        .bind(sql_offset)
        .fetch_all(&self.pool)
        .await?;

        let sessions = rows
            .into_iter()
            .map(|row| {
                let session_id = Uuid::parse_str(&row.try_get::<String, _>("id")?)?;
                let scopes: String = row.try_get("scopes")?;
                Ok(SessionInfo {
                    session_id,
                    agent_id: Uuid::parse_str(&row.try_get::<String, _>("agent_id")?)?,
                    role: row.try_get("role_name")?,
                    scopes: serde_json::from_str(&scopes)?,
                    created_at: row.try_get("created_at")?,
                    expires_at: row.try_get("expires_at")?,
                })
            })
            .collect::<GuardResult<Vec<_>>>()?;

        // A full page means more rows may follow. `checked_add` means an offset
        // near u64::MAX yields "no next page" instead of an overflow panic.
        let page_is_full = !sessions.is_empty() && sessions.len() as u64 == limit;
        let next_offset = if page_is_full {
            offset.checked_add(limit)
        } else {
            None
        };
        Ok(PaginatedSessions {
            sessions,
            next_offset,
        })
    }

    /// Returns the id of the role named `role`, inserting it with `scopes` if
    /// it does not exist yet. An existing role's stored scopes are left as is.
    ///
    /// If the insert fails (for example because a concurrent caller inserted
    /// the same name first and a uniqueness constraint rejected this row), the
    /// lookup is retried. The insert error is only surfaced when the role
    /// still cannot be found.
    async fn ensure_role(&self, role: &str, scopes: &[String]) -> GuardResult<Uuid> {
        if let Some(existing) = self.find_role_id(role).await? {
            return Ok(existing);
        }

        let role_id = Uuid::new_v4();
        let inserted =
            sqlx::query("INSERT INTO roles (id, name, description, scopes) VALUES (?1, ?2, NULL, ?3)")
                .bind(role_id.to_string())
                .bind(role)
                .bind(serde_json::to_string(scopes)?)
                .execute(&self.pool)
                .await;

        match inserted {
            Ok(_) => Ok(role_id),
            Err(insert_err) => match self.find_role_id(role).await? {
                Some(existing) => Ok(existing),
                None => Err(insert_err.into()),
            },
        }
    }

    /// Looks up the id of the role named `role`, if such a row exists.
    async fn find_role_id(&self, role: &str) -> GuardResult<Option<Uuid>> {
        let row = sqlx::query("SELECT id FROM roles WHERE name = ?1")
            .bind(role)
            .fetch_optional(&self.pool)
            .await?;
        match row {
            Some(row) => {
                let id: String = row.try_get("id")?;
                Ok(Some(Uuid::parse_str(&id)?))
            }
            None => Ok(None),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use sqlx::{sqlite::SqlitePoolOptions, Executor};

    use super::*;
    use crate::config::GuardConfig;

    /// Builds a `SessionManager` over a fresh single-connection in-memory
    /// database containing minimal `roles` and `sessions` tables.
    async fn setup() -> (SessionManager, SqlitePool) {
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("in-memory db should connect");
        pool.execute(
            "CREATE TABLE roles (id UUID PRIMARY KEY, name TEXT UNIQUE, description TEXT, scopes TEXT[]);
             CREATE TABLE sessions (id UUID PRIMARY KEY, agent_id UUID, role_id UUID, scopes TEXT[], created_at TIMESTAMPTZ, expires_at TIMESTAMPTZ, revoked BOOL);",
        )
        .await
        .expect("tables should be created");
        let config = Arc::new(GuardConfig {
            db_path: ":memory:".to_owned(),
            jwt_secret: crate::config::ZeroizeString::new("secret"),
            policy_dir: "policies".into(),
            tls_cert_path: "certs/server.crt".into(),
            tls_key_path: "certs/server.key".into(),
            risk_thresholds: crate::config::RiskThresholds::default(),
            sensitive_resources: Vec::new(),
            audit_flush_interval_ms: 100,
            audit_batch_size: 500,
        });
        (SessionManager::new(pool.clone(), config), pool)
    }

    #[tokio::test]
    async fn session_jwt_round_trip() {
        let (manager, _) = setup().await;
        let agent_id = Uuid::new_v4();
        let session = manager
            .create_session(agent_id, "analyst", vec!["tool:*".to_owned()], 60)
            .await
            .expect("session should be created");
        let validated = manager
            .validate_session(&session.token)
            .await
            .expect("session should validate");
        assert_eq!(validated.session_id, session.session_id);

        manager
            .revoke_session(session.session_id)
            .await
            .expect("session should be revoked");
        assert!(matches!(
            manager.validate_session(&session.token).await,
            Err(GuardError::SessionRevoked(_))
        ));
    }

    #[tokio::test]
    async fn list_active_sessions_skips_revoked_and_paginates() {
        let (manager, _) = setup().await;
        let agent_id = Uuid::new_v4();
        let mut created = Vec::new();
        for _ in 0..3 {
            let session = manager
                .create_session(agent_id, "analyst", vec!["tool:*".to_owned()], 60)
                .await
                .expect("session should be created");
            created.push(session.session_id);
        }
        manager
            .revoke_session(created[0])
            .await
            .expect("session should be revoked");

        let first_page = manager
            .list_active_sessions(agent_id, 1, 0)
            .await
            .expect("listing should succeed");
        assert_eq!(first_page.sessions.len(), 1);
        assert_eq!(first_page.next_offset, Some(1));

        let everything = manager
            .list_active_sessions(agent_id, 10, 0)
            .await
            .expect("listing should succeed");
        assert_eq!(everything.sessions.len(), 2);
        assert_eq!(everything.next_offset, None);
        assert!(everything
            .sessions
            .iter()
            .all(|s| s.session_id != created[0] && s.agent_id == agent_id && s.role == "analyst"));
    }

    #[tokio::test]
    async fn revoking_unknown_session_reports_not_found() {
        let (manager, _) = setup().await;
        let missing = Uuid::new_v4();
        assert!(matches!(
            manager.revoke_session(missing).await,
            Err(GuardError::SessionNotFound(id)) if id == missing
        ));
    }

    #[tokio::test]
    async fn malformed_token_is_rejected() {
        let (manager, _) = setup().await;
        assert!(manager.validate_session("not-a-jwt").await.is_err());
    }
}