Skip to main content

agentox_core/checks/
runner.rs

1//! Check runner — orchestrates check execution against an MCP session.
2
3use crate::checks::types::{CheckCategory, CheckResult};
4use crate::client::{HttpSseTransport, McpSession, StdioTransport};
5use crate::protocol::mcp_types::{InitializeResult, Tool};
6use std::time::Duration;
7
8/// Trait that all audit checks implement.
9#[async_trait::async_trait]
10pub trait Check: Send + Sync {
11    /// Unique check ID (e.g., "CONF-001").
12    fn id(&self) -> &str;
13
14    /// Human-readable name.
15    fn name(&self) -> &str;
16
17    /// Category of this check.
18    fn category(&self) -> CheckCategory;
19
20    /// Run the check. May return multiple findings.
21    async fn run(&self, ctx: &mut CheckContext) -> Vec<CheckResult>;
22}
23
24/// Context provided to checks during execution.
25pub struct CheckContext {
26    /// The active MCP session.
27    pub session: McpSession,
28    /// Original connection target used to create reconnectable disposable sessions.
29    pub target: ConnectionTarget,
30    /// Parsed initialize result.
31    pub init_result: Option<InitializeResult>,
32    /// Raw initialize response string.
33    pub raw_init_response: Option<String>,
34    /// Cached tools list.
35    pub tools: Option<Vec<Tool>>,
36    /// Per-request transport timeout.
37    pub request_timeout: Duration,
38}
39
/// How to reach the MCP server under audit.
///
/// Stored in a `CheckContext` so that new sessions can be opened on demand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionTarget {
    /// Launch the server as a child process using this command line and talk
    /// to it over stdio.
    Stdio { command: String },
    /// Connect to an already-running server at this HTTP/SSE endpoint URL.
    HttpSse { endpoint: String },
}
45
46impl CheckContext {
47    /// Create a new check context.
48    pub fn new(session: McpSession, target: ConnectionTarget) -> Self {
49        Self {
50            session,
51            target,
52            init_result: None,
53            raw_init_response: None,
54            tools: None,
55            request_timeout: Duration::from_secs(30),
56        }
57    }
58
59    /// Spawn a fresh session without initializing it.
60    pub async fn fresh_session(&self) -> Result<McpSession, crate::error::SessionError> {
61        match &self.target {
62            ConnectionTarget::Stdio { command } => {
63                let mut transport = StdioTransport::spawn_quiet(command)
64                    .await
65                    .map_err(crate::error::SessionError::Transport)?;
66                transport.set_read_timeout(self.request_timeout);
67                Ok(McpSession::new(Box::new(transport)))
68            }
69            ConnectionTarget::HttpSse { endpoint } => {
70                let transport = HttpSseTransport::new(endpoint.clone(), self.request_timeout);
71                Ok(McpSession::new(Box::new(transport)))
72            }
73        }
74    }
75
76    /// Spawn a fresh disposable session for destructive tests.
77    /// The caller is responsible for shutting it down.
78    /// Uses `spawn_quiet` to suppress server stderr noise.
79    pub async fn disposable_session(&self) -> Result<McpSession, crate::error::SessionError> {
80        let mut session = self.fresh_session().await?;
81        session.initialize().await?;
82        Ok(session)
83    }
84}
85
86/// Runs a set of checks against a session.
87pub struct CheckRunner {
88    checks: Vec<Box<dyn Check>>,
89}
90
91impl CheckRunner {
92    pub fn new() -> Self {
93        Self { checks: Vec::new() }
94    }
95
96    /// Register a single check.
97    pub fn register(&mut self, check: Box<dyn Check>) {
98        self.checks.push(check);
99    }
100
101    /// Register all default conformance checks.
102    pub fn register_conformance_checks(&mut self) {
103        use crate::checks::conformance::*;
104        self.register(Box::new(InitializeCapabilities));
105        self.register(Box::new(JsonRpcStructure));
106        self.register(Box::new(ToolsListValid));
107        self.register(Box::new(ToolInputSchemaValid));
108        self.register(Box::new(MalformedRequestHandling));
109        self.register(Box::new(UnknownMethodHandling));
110        self.register(Box::new(ErrorCodeCorrectness));
111        self.register(Box::new(CapabilityNegotiation));
112        self.register(Box::new(ProtocolVersionValidation));
113        self.register(Box::new(InitializedNotificationOrder));
114    }
115
116    /// Register all default security checks.
117    pub fn register_security_checks(&mut self) {
118        use crate::checks::security::*;
119        self.register(Box::new(PromptInjectionEchoSafety));
120        self.register(Box::new(ToolParameterBoundaryValidation));
121        self.register(Box::new(ErrorLeakageDetection));
122        self.register(Box::new(ResourceExhaustionGuardrail));
123    }
124
125    /// Register behavioral checks (reserved for future versions).
126    pub fn register_behavioral_checks(&mut self) {
127        use crate::checks::behavioral::*;
128        self.register(Box::new(IdempotencyBaseline));
129        self.register(Box::new(SchemaOutputAlignment));
130        self.register(Box::new(DeterministicErrorSemantics));
131    }
132
133    /// Register default checks for v0.4 (conformance + security + behavioral).
134    pub fn register_default_v0_4_checks(&mut self) {
135        self.register_conformance_checks();
136        self.register_security_checks();
137        self.register_behavioral_checks();
138    }
139
140    /// Get the total number of registered checks.
141    pub fn check_count(&self) -> usize {
142        self.checks.len()
143    }
144
145    /// Run all registered checks and return all results.
146    pub async fn run_all(&self, ctx: &mut CheckContext) -> Vec<CheckResult> {
147        self.run_all_with_progress(ctx, |_, _, _| {}).await
148    }
149
150    /// Run all registered checks with a progress callback.
151    ///
152    /// The callback is invoked after each check completes with:
153    /// - `check_id`: the ID of the check that just finished (e.g., "CONF-001")
154    /// - `check_name`: the human-readable name
155    /// - `results`: the results produced by this check
156    pub async fn run_all_with_progress<F>(
157        &self,
158        ctx: &mut CheckContext,
159        mut on_check_done: F,
160    ) -> Vec<CheckResult>
161    where
162        F: FnMut(&str, &str, &[CheckResult]),
163    {
164        let mut results = Vec::new();
165        for check in &self.checks {
166            tracing::info!(check_id = %check.id(), name = %check.name(), "running check");
167            let start = std::time::Instant::now();
168            let mut check_results = check.run(ctx).await;
169            let elapsed = start.elapsed().as_millis() as u64;
170            for r in &mut check_results {
171                r.duration_ms = elapsed;
172            }
173            on_check_done(check.id(), check.name(), &check_results);
174            results.extend(check_results);
175        }
176        results
177    }
178}
179
180impl Default for CheckRunner {
181    fn default() -> Self {
182        Self::new()
183    }
184}