arbit 0.7.0

Security proxy for MCP (Model Context Protocol) — auth, rate limiting, payload filtering, and audit logging between AI agents and MCP servers
Documentation
use super::{AuditEntry, AuditLog, Outcome};
use async_trait::async_trait;
use rusqlite::{Connection, params};
use std::sync::{Arc, Mutex};
use std::time::SystemTime;
use tokio::sync::mpsc;

/// Audit log sink that stores entries in a SQLite `audit_log` table.
///
/// Entries handed to `record` are sent over an unbounded channel to a background
/// task (spawned by `with_rotation`) which performs the actual database writes.
pub struct SqliteAudit {
    /// Sending half of the entry channel. `flush` replaces it with `None`, which
    /// closes the channel and lets the background task's receive loop end.
    tx: Arc<Mutex<Option<mpsc::UnboundedSender<Arc<AuditEntry>>>>>,
    /// Join handle of the background writer task. `flush` takes it (leaving
    /// `None`) and awaits it.
    handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
}

impl SqliteAudit {
    pub fn new(path: &str) -> anyhow::Result<Self> {
        Self::with_rotation(path, None, None)
    }

    pub fn with_rotation(
        path: &str,
        max_entries: Option<usize>,
        max_age_days: Option<u64>,
    ) -> anyhow::Result<Self> {
        let conn = Connection::open(path)?;
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS audit_log (
                id        INTEGER PRIMARY KEY AUTOINCREMENT,
                ts        INTEGER NOT NULL,
                agent_id  TEXT    NOT NULL,
                method    TEXT    NOT NULL,
                tool      TEXT,
                arguments TEXT,
                outcome   TEXT    NOT NULL,
                reason    TEXT
            );",
        )?;
        // Migrate existing databases that don't have the arguments column yet.
        let _ = conn.execute_batch("ALTER TABLE audit_log ADD COLUMN arguments TEXT;");
        let conn = Arc::new(Mutex::new(conn));
        let (tx, mut rx) = mpsc::unbounded_channel::<Arc<AuditEntry>>();

        let handle = tokio::spawn(async move {
            while let Some(entry) = rx.recv().await {
                let ts = entry
                    .ts
                    .duration_since(SystemTime::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs() as i64;

                let (outcome_str, reason) = match &entry.outcome {
                    Outcome::Allowed => ("allowed".to_string(), None),
                    Outcome::Blocked(r) => ("blocked".to_string(), Some(r.clone())),
                    Outcome::Forwarded => ("forwarded".to_string(), None),
                    Outcome::Shadowed => ("shadowed".to_string(), None),
                };

                // Emit structured log for every persisted entry
                match &entry.outcome {
                    Outcome::Allowed => tracing::info!(
                        outcome = "allowed",
                        agent = %entry.agent_id,
                        method = %entry.method,
                        tool = entry.tool.as_deref().unwrap_or("-"),
                    ),
                    Outcome::Blocked(r) => tracing::info!(
                        outcome = "blocked",
                        agent = %entry.agent_id,
                        method = %entry.method,
                        tool = entry.tool.as_deref().unwrap_or("-"),
                        reason = %r,
                    ),
                    Outcome::Forwarded => tracing::info!(
                        outcome = "forwarded",
                        agent = %entry.agent_id,
                        method = %entry.method,
                    ),
                    Outcome::Shadowed => tracing::info!(
                        outcome = "shadowed",
                        agent = %entry.agent_id,
                        method = %entry.method,
                        tool = entry.tool.as_deref().unwrap_or("-"),
                    ),
                }

                let conn = conn.clone();
                tokio::task::spawn_blocking(move || {
                    if let Ok(c) = conn.lock() {
                        let args_json = entry
                            .arguments
                            .as_ref()
                            .and_then(|v| serde_json::to_string(v).ok());
                        if let Err(e) = c.execute(
                            "INSERT INTO audit_log (ts, agent_id, method, tool, arguments, outcome, reason)
                             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
                            params![
                                ts,
                                entry.agent_id,
                                entry.method,
                                entry.tool,
                                args_json,
                                outcome_str,
                                reason
                            ],
                        ) {
                            tracing::error!(
                                error = %e,
                                agent = %entry.agent_id,
                                "audit insert failed"
                            );
                        }

                        // Rotate by entry count — keep only the newest N rows
                        if let Some(max) = max_entries
                            && let Err(e) = c.execute(
                                "DELETE FROM audit_log WHERE id NOT IN \
                                 (SELECT id FROM audit_log ORDER BY id DESC LIMIT ?1)",
                                params![max as i64],
                            )
                        {
                            tracing::warn!(error = %e, "audit rotation (max_entries) failed");
                        }
                        // Rotate by age — purge entries older than max_age_days
                        if let Some(days) = max_age_days {
                            let cutoff = SystemTime::now()
                                .duration_since(SystemTime::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_secs() as i64
                                - (days as i64 * 86400);
                            if let Err(e) =
                                c.execute("DELETE FROM audit_log WHERE ts < ?1", params![cutoff])
                            {
                                tracing::warn!(error = %e, "audit rotation (max_age_days) failed");
                            }
                        }
                    }
                })
                .await
                .ok();
            }
        });

        Ok(Self {
            tx: Arc::new(Mutex::new(Some(tx))),
            handle: Arc::new(Mutex::new(Some(handle))),
        })
    }
}

#[async_trait]
impl AuditLog for SqliteAudit {
    fn record(&self, entry: Arc<AuditEntry>) {
        if let Ok(guard) = self.tx.lock()
            && let Some(tx) = guard.as_ref()
        {
            let _ = tx.send(entry);
        }
    }

    async fn flush(&self) {
        // Drop sender to signal EOF to the worker task
        {
            let mut guard = self.tx.lock().unwrap();
            *guard = None;
        }
        // Await the worker to finish processing all queued entries
        let handle = {
            let mut guard = self.handle.lock().unwrap();
            guard.take()
        };
        if let Some(h) = handle {
            let _ = h.await;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::audit::Outcome;
    use rusqlite::Connection;
    use rusqlite::types::FromSql;
    use std::time::{Duration, UNIX_EPOCH};
    use tempfile::NamedTempFile;

    /// Builds a fixture entry with the given outcome.
    ///
    /// The timestamp is fixed at 1_000_000 seconds after the Unix epoch.
    fn entry(outcome: Outcome) -> Arc<AuditEntry> {
        Arc::new(AuditEntry {
            ts: UNIX_EPOCH + Duration::from_secs(1_000_000),
            agent_id: "agent-x".to_string(),
            method: "tools/call".to_string(),
            tool: Some("read_file".to_string()),
            arguments: Some(serde_json::json!({"path": "/tmp/foo"})),
            outcome,
            request_id: "req-1".to_string(),
        })
    }

    /// Creates a temporary database file and returns it with its path.
    ///
    /// The returned `NamedTempFile` must stay bound for the duration of the
    /// test; dropping it removes the file.
    fn temp_db() -> (NamedTempFile, String) {
        let file = NamedTempFile::new().unwrap();
        let path = file.path().to_str().unwrap().to_owned();
        (file, path)
    }

    /// Returns the number of rows in `audit_log`.
    fn count_rows(path: &str) -> usize {
        let conn = Connection::open(path).unwrap();
        conn.query_row("SELECT COUNT(*) FROM audit_log", [], |r| r.get::<_, i64>(0))
            .unwrap() as usize
    }

    /// Reads one column of `audit_log` in insertion (`id`) order.
    ///
    /// `column` is interpolated into the SQL text, so pass literals only.
    fn fetch_column<T: FromSql>(path: &str, column: &str) -> Vec<T> {
        let conn = Connection::open(path).unwrap();
        let mut stmt = conn
            .prepare(&format!("SELECT {column} FROM audit_log ORDER BY id"))
            .unwrap();
        stmt.query_map([], |row| row.get::<_, T>(0))
            .unwrap()
            .map(Result::unwrap)
            .collect()
    }

    fn fetch_outcomes(path: &str) -> Vec<String> {
        fetch_column(path, "outcome")
    }

    fn fetch_reasons(path: &str) -> Vec<Option<String>> {
        fetch_column(path, "reason")
    }

    #[tokio::test]
    async fn records_are_persisted() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.flush().await;
        assert_eq!(count_rows(&path), 1);
    }

    #[tokio::test]
    async fn outcome_strings_are_correct() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.record(entry(Outcome::Forwarded));
        audit.record(entry(Outcome::Shadowed));
        audit.record(entry(Outcome::Blocked("denied".to_string())));
        audit.flush().await;
        let outcomes = fetch_outcomes(&path);
        assert_eq!(outcomes, ["allowed", "forwarded", "shadowed", "blocked"]);
    }

    #[tokio::test]
    async fn null_reason_stored_for_non_blocked_outcomes() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.record(entry(Outcome::Forwarded));
        audit.record(entry(Outcome::Shadowed));
        audit.flush().await;
        let reasons = fetch_reasons(&path);
        assert!(
            reasons.iter().all(|r| r.is_none()),
            "expected all NULL reasons"
        );
    }

    #[tokio::test]
    async fn blocked_reason_stored() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Blocked("rate limit".to_string())));
        audit.flush().await;
        let reasons = fetch_reasons(&path);
        assert_eq!(reasons, [Some("rate limit".to_string())]);
    }

    #[tokio::test]
    async fn arguments_stored_as_json() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.flush().await;
        let arguments: Vec<Option<String>> = fetch_column(&path, "arguments");
        assert_eq!(arguments, [Some(r#"{"path":"/tmp/foo"}"#.to_string())]);
    }

    #[tokio::test]
    async fn max_entries_rotation_keeps_newest() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::with_rotation(&path, Some(3), None).unwrap();
        for _ in 0..6 {
            audit.record(entry(Outcome::Allowed));
        }
        audit.flush().await;
        assert_eq!(count_rows(&path), 3, "rotation should keep only 3 rows");
        // Ids are assigned 1..=6 in insertion order, so the survivors must be
        // the last three inserts, not just any three rows.
        let ids: Vec<i64> = fetch_column(&path, "id");
        assert_eq!(ids, [4, 5, 6], "rotation should keep the newest rows");
    }

    #[tokio::test]
    async fn max_age_days_rotation_purges_old() {
        let (_db, path) = temp_db();
        // The fixture timestamp is 1_000_000 s after the epoch (January 1970),
        // far older than the 1-day limit, so the row is purged right after insert.
        let audit = SqliteAudit::with_rotation(&path, None, Some(1)).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.flush().await;
        assert_eq!(count_rows(&path), 0, "old entry should have been purged");
    }

    #[tokio::test]
    async fn flush_is_idempotent() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.flush().await;
        // Second flush should be a no-op and not panic
        audit.flush().await;
        assert_eq!(count_rows(&path), 1);
    }

    #[tokio::test]
    async fn record_after_flush_is_not_persisted() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        audit.record(entry(Outcome::Allowed));
        audit.flush().await;
        // The channel is closed now; recording must neither panic nor write.
        audit.record(entry(Outcome::Allowed));
        assert_eq!(count_rows(&path), 1);
    }

    #[tokio::test]
    async fn multiple_entries_all_persisted() {
        let (_db, path) = temp_db();
        let audit = SqliteAudit::new(&path).unwrap();
        for _ in 0..10 {
            audit.record(entry(Outcome::Forwarded));
        }
        audit.flush().await;
        assert_eq!(count_rows(&path), 10);
    }
}