// assay_core/
// coverage.rs

//! Coverage metrics for Assay policies
//!
//! Analyzes traces to determine:
//! - Tool coverage: which tools from the policy were exercised
//! - Rule coverage: which rules were triggered
//! - Gap detection: high-risk tools never seen in traces

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt::Write as _;

11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ToolCoverage {
13    /// Total unique tools referenced in policy
14    pub total_tools_in_policy: usize,
15
16    /// Tools that appeared in at least one trace
17    pub tools_seen_in_traces: usize,
18
19    /// Coverage percentage
20    pub coverage_pct: f64,
21
22    /// Tools in policy but never seen
23    pub unseen_tools: Vec<String>,
24
25    /// Tools seen in traces but not in policy (potential gaps)
26    pub unexpected_tools: Vec<String>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct PolicyViolation {
31    pub trace_id: String,
32    pub tool: String,
33    pub error_code: String,
34    pub reason: String,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct PolicyWarning {
39    pub trace_id: String,
40    pub tool: String,
41    pub warning_code: String,
42    pub reason: String,
43}
44
45/// Coverage analysis result
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct CoverageReport {
48    /// Tool coverage metrics
49    pub tool_coverage: ToolCoverage,
50
51    /// Rule coverage metrics
52    pub rule_coverage: RuleCoverage,
53
54    /// High-risk gaps (blocklisted tools never seen)
55    pub high_risk_gaps: Vec<HighRiskGap>,
56
57    /// Policy violations found during analysis
58    #[serde(default)]
59    pub policy_violations: Vec<PolicyViolation>,
60
61    /// Policy warnings (e.g. unconstrained tools)
62    #[serde(default)]
63    pub policy_warnings: Vec<PolicyWarning>,
64
65    /// Overall coverage percentage
66    pub overall_coverage_pct: f64,
67
68    /// Whether coverage meets threshold
69    pub meets_threshold: bool,
70
71    /// Threshold that was checked
72    pub threshold: f64,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct RuleCoverage {
77    /// Total rules in policy
78    pub total_rules: usize,
79
80    /// Rules that were triggered (evaluated to allow or deny)
81    pub rules_triggered: usize,
82
83    /// Coverage percentage
84    pub coverage_pct: f64,
85
86    /// Rules that were never triggered
87    pub untriggered_rules: Vec<String>,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct HighRiskGap {
92    /// Tool name
93    pub tool: String,
94
95    /// Why it's high risk
96    pub reason: String,
97
98    /// Severity: "critical", "high", "medium"
99    pub severity: String,
100}
101
/// The parts of a recorded trace that coverage analysis consumes.
#[derive(Debug, Clone)]
pub struct TraceRecord {
    /// Identifier of the trace.
    pub trace_id: String,
    /// Names of the tools the trace called; repeats are allowed.
    pub tools_called: Vec<String>,
    /// IDs of the rules reported as triggered by this trace.
    pub rules_triggered: HashSet<String>,
}

/// Computes tool, rule and high-risk coverage of a policy over a set of
/// traces. Construct it with `CoverageAnalyzer::from_policy`.
pub struct CoverageAnalyzer {
    /// Every tool name the policy mentions (allow, deny, `require_args`,
    /// sequence rules, alias names and alias members).
    policy_tools: HashSet<String>,

    /// Tools treated as high risk (denied or forbidden ones), plus blocklist
    /// patterns stored as `*pattern*` markers.
    high_risk_tools: HashSet<String>,

    /// Generated IDs of the policy's sequence rules.
    rule_ids: Vec<String>,

    /// Alias name -> member tool names, as declared by the policy.
    aliases: HashMap<String, Vec<String>>,
}

125impl CoverageAnalyzer {
126    /// Create analyzer from a v1.1 policy
127    pub fn from_policy(policy: &crate::model::Policy) -> Self {
128        let mut policy_tools = HashSet::new();
129        let mut high_risk_tools = HashSet::new();
130        let mut rule_ids = Vec::new();
131
132        // Extract tools from policy.tools
133        if let Some(allow) = &policy.tools.allow {
134            for tool in allow {
135                policy_tools.insert(tool.clone());
136            }
137        }
138
139        if let Some(deny) = &policy.tools.deny {
140            for tool in deny {
141                policy_tools.insert(tool.clone());
142                high_risk_tools.insert(tool.clone()); // Denied = high risk
143            }
144        }
145
146        if let Some(require_args) = &policy.tools.require_args {
147            for tool in require_args.keys() {
148                policy_tools.insert(tool.clone());
149            }
150        }
151
152        // Extract tools from sequences
153        for (idx, rule) in policy.sequences.iter().enumerate() {
154            let rule_id = Self::rule_id(rule, idx);
155            rule_ids.push(rule_id);
156
157            match rule {
158                crate::model::SequenceRule::Require { tool } => {
159                    policy_tools.insert(tool.clone());
160                }
161                crate::model::SequenceRule::Eventually { tool, .. } => {
162                    policy_tools.insert(tool.clone());
163                }
164                crate::model::SequenceRule::MaxCalls { tool, .. } => {
165                    policy_tools.insert(tool.clone());
166                }
167                crate::model::SequenceRule::Before { first, then } => {
168                    policy_tools.insert(first.clone());
169                    policy_tools.insert(then.clone());
170                }
171                crate::model::SequenceRule::After { trigger, then, .. } => {
172                    policy_tools.insert(trigger.clone());
173                    policy_tools.insert(then.clone());
174                }
175                crate::model::SequenceRule::NeverAfter { trigger, forbidden } => {
176                    policy_tools.insert(trigger.clone());
177                    policy_tools.insert(forbidden.clone());
178                    high_risk_tools.insert(forbidden.clone()); // Forbidden = high risk
179                }
180                crate::model::SequenceRule::Sequence { tools, .. } => {
181                    for tool in tools {
182                        policy_tools.insert(tool.clone());
183                    }
184                }
185                crate::model::SequenceRule::Blocklist { pattern } => {
186                    // Pattern-based, mark as high risk indicator
187                    high_risk_tools.insert(format!("*{}*", pattern));
188                }
189            }
190        }
191
192        // Resolve aliases - add alias members to policy_tools
193        for (alias, members) in &policy.aliases {
194            policy_tools.insert(alias.clone());
195            for member in members {
196                policy_tools.insert(member.clone());
197            }
198        }
199
200        Self {
201            policy_tools,
202            high_risk_tools,
203            rule_ids,
204            aliases: policy.aliases.clone(),
205        }
206    }
207
208    /// Generate a rule ID from rule type and index
209    fn rule_id(rule: &crate::model::SequenceRule, _idx: usize) -> String {
210        match rule {
211            crate::model::SequenceRule::Require { tool } => {
212                format!("require_{}", tool.to_lowercase())
213            }
214            crate::model::SequenceRule::Eventually { tool, within } => {
215                format!("eventually_{}_{}", tool.to_lowercase(), within)
216            }
217            crate::model::SequenceRule::MaxCalls { tool, max } => {
218                format!("max_calls_{}_{}", tool.to_lowercase(), max)
219            }
220            crate::model::SequenceRule::Before { first, then } => {
221                format!(
222                    "before_{}_then_{}",
223                    first.to_lowercase(),
224                    then.to_lowercase()
225                )
226            }
227            crate::model::SequenceRule::After { trigger, then, .. } => {
228                format!(
229                    "after_{}_then_{}",
230                    trigger.to_lowercase(),
231                    then.to_lowercase()
232                )
233            }
234            crate::model::SequenceRule::NeverAfter { trigger, forbidden } => {
235                format!(
236                    "never_after_{}_forbidden_{}",
237                    trigger.to_lowercase(),
238                    forbidden.to_lowercase()
239                )
240            }
241            crate::model::SequenceRule::Sequence { tools, strict } => {
242                let mode = if *strict { "strict" } else { "seq" };
243                format!("{}_{}", mode, tools.join("_").to_lowercase())
244            }
245            crate::model::SequenceRule::Blocklist { pattern } => {
246                format!("blocklist_{}", pattern.to_lowercase())
247            }
248        }
249    }
250
251    /// Analyze coverage from a set of traces
252    pub fn analyze(&self, traces: &[TraceRecord], threshold: f64) -> CoverageReport {
253        let mut tools_seen: HashSet<String> = HashSet::new();
254        let mut rules_triggered: HashSet<String> = HashSet::new();
255        let mut unexpected_tools: HashSet<String> = HashSet::new();
256
257        // Collect all tools and triggered rules from traces
258        for trace in traces {
259            for tool in &trace.tools_called {
260                tools_seen.insert(tool.clone());
261
262                // Check if tool is in policy (including alias resolution)
263                if !self.is_policy_tool(tool) {
264                    unexpected_tools.insert(tool.clone());
265                }
266            }
267
268            for rule_id in &trace.rules_triggered {
269                rules_triggered.insert(rule_id.clone());
270            }
271        }
272
273        // Calculate tool coverage
274        let policy_tool_count = self.policy_tools.len();
275        let seen_policy_tools: HashSet<_> = tools_seen
276            .iter()
277            .filter(|t| self.is_policy_tool(t))
278            .cloned()
279            .collect();
280        let tools_seen_count = seen_policy_tools.len();
281
282        let unseen_tools: Vec<String> = self
283            .policy_tools
284            .iter()
285            .filter(|t| !self.is_tool_seen(t, &tools_seen))
286            .cloned()
287            .collect();
288
289        let tool_coverage_pct = if policy_tool_count > 0 {
290            (tools_seen_count as f64 / policy_tool_count as f64) * 100.0
291        } else {
292            100.0
293        };
294
295        // Calculate rule coverage
296        let total_rules = self.rule_ids.len();
297        let triggered_count = rules_triggered.len();
298
299        let untriggered_rules: Vec<String> = self
300            .rule_ids
301            .iter()
302            .filter(|r| !rules_triggered.contains(*r))
303            .cloned()
304            .collect();
305
306        let rule_coverage_pct = if total_rules > 0 {
307            (triggered_count as f64 / total_rules as f64) * 100.0
308        } else {
309            100.0
310        };
311
312        // Identify high-risk gaps
313        let high_risk_gaps: Vec<HighRiskGap> = self
314            .high_risk_tools
315            .iter()
316            .filter(|t| !t.starts_with('*')) // Skip patterns
317            .filter(|t| !self.is_tool_seen(t, &tools_seen))
318            .map(|t| HighRiskGap {
319                tool: t.clone(),
320                reason: "Tool is in deny list but never appeared in test traces".to_string(),
321                severity: "high".to_string(),
322            })
323            .collect();
324
325        // Overall coverage (average of tool and rule coverage)
326        let overall_coverage_pct = (tool_coverage_pct + rule_coverage_pct) / 2.0;
327        let meets_threshold = overall_coverage_pct >= threshold;
328
329        CoverageReport {
330            tool_coverage: ToolCoverage {
331                total_tools_in_policy: policy_tool_count,
332                tools_seen_in_traces: tools_seen_count,
333                coverage_pct: tool_coverage_pct,
334                unseen_tools,
335                unexpected_tools: unexpected_tools.into_iter().collect(),
336            },
337            rule_coverage: RuleCoverage {
338                total_rules,
339                rules_triggered: triggered_count,
340                coverage_pct: rule_coverage_pct,
341                untriggered_rules,
342            },
343            high_risk_gaps,
344            policy_violations: Vec::new(),
345            policy_warnings: Vec::new(),
346            overall_coverage_pct,
347            meets_threshold,
348            threshold,
349        }
350    }
351
352    /// Check if a tool is in the policy (including alias resolution)
353    fn is_policy_tool(&self, tool: &str) -> bool {
354        if self.policy_tools.contains(tool) {
355            return true;
356        }
357
358        // Check if tool is a member of any alias
359        for members in self.aliases.values() {
360            if members.contains(&tool.to_string()) {
361                return true;
362            }
363        }
364
365        false
366    }
367
368    /// Check if a tool (or any of its alias members) was seen
369    fn is_tool_seen(&self, tool: &str, seen: &HashSet<String>) -> bool {
370        if seen.contains(tool) {
371            return true;
372        }
373
374        // Check if this tool is an alias and any member was seen
375        if let Some(members) = self.aliases.get(tool) {
376            return members.iter().any(|m| seen.contains(m));
377        }
378
379        // Check if tool is a member of an alias that was seen
380        for (alias, members) in &self.aliases {
381            if members.contains(&tool.to_string()) && seen.contains(alias) {
382                return true;
383            }
384        }
385
386        false
387    }
388}
389
390impl CoverageReport {
391    /// Format as GitHub Actions annotation
392    pub fn to_github_annotation(&self) -> String {
393        let mut lines = Vec::new();
394
395        if !self.meets_threshold {
396            lines.push(format!(
397                "::error::Coverage {:.1}% is below threshold {:.1}%",
398                self.overall_coverage_pct, self.threshold
399            ));
400        }
401
402        for gap in &self.high_risk_gaps {
403            lines.push(format!(
404                "::warning::High-risk tool '{}' never tested: {}",
405                gap.tool, gap.reason
406            ));
407        }
408
409        for tool in &self.tool_coverage.unseen_tools {
410            lines.push(format!(
411                "::notice::Tool '{}' in policy but not covered by tests",
412                tool
413            ));
414        }
415
416        lines.join("\n")
417    }
418
419    /// Format as markdown summary
420    pub fn to_markdown(&self) -> String {
421        let status = if self.meets_threshold { "✅" } else { "❌" };
422
423        let mut md = format!(
424            "## Coverage Report {}\n\n\
425            | Metric | Value |\n\
426            |--------|-------|\n\
427            | Overall Coverage | {:.1}% |\n\
428            | Tool Coverage | {:.1}% ({}/{}) |\n\
429            | Rule Coverage | {:.1}% ({}/{}) |\n\
430            | Threshold | {:.1}% |\n\n",
431            status,
432            self.overall_coverage_pct,
433            self.tool_coverage.coverage_pct,
434            self.tool_coverage.tools_seen_in_traces,
435            self.tool_coverage.total_tools_in_policy,
436            self.rule_coverage.coverage_pct,
437            self.rule_coverage.rules_triggered,
438            self.rule_coverage.total_rules,
439            self.threshold,
440        );
441
442        if !self.high_risk_gaps.is_empty() {
443            md.push_str("### ⚠️ High-Risk Gaps\n\n");
444            for gap in &self.high_risk_gaps {
445                md.push_str(&format!("- **{}**: {}\n", gap.tool, gap.reason));
446            }
447            md.push('\n');
448        }
449
450        if !self.tool_coverage.unseen_tools.is_empty() {
451            md.push_str("### Uncovered Tools\n\n");
452            for tool in &self.tool_coverage.unseen_tools {
453                md.push_str(&format!("- `{}`\n", tool));
454            }
455            md.push('\n');
456        }
457
458        if !self.rule_coverage.untriggered_rules.is_empty() {
459            md.push_str("### Untriggered Rules\n\n");
460            for rule in &self.rule_coverage.untriggered_rules {
461                md.push_str(&format!("- `{}`\n", rule));
462            }
463            md.push('\n');
464        }
465
466        md
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Policy, SequenceRule, ToolsPolicy};
    use crate::on_error::ErrorPolicy;

    /// Turn string literals into owned `String`s.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Build a trace record from literal tool names and triggered rule IDs.
    fn trace(id: &str, tools: &[&str], rules: &[&str]) -> TraceRecord {
        TraceRecord {
            trace_id: id.to_string(),
            tools_called: names(tools),
            rules_triggered: rules.iter().map(|rule| rule.to_string()).collect(),
        }
    }

    /// Tools section with no allow, deny or argument entries.
    fn empty_tools() -> ToolsPolicy {
        ToolsPolicy {
            allow: None,
            deny: None,
            require_args: None,
            arg_constraints: None,
        }
    }

    /// Policy with three allowed tools, one denied tool and two sequence rules.
    fn make_policy() -> Policy {
        Policy {
            version: "1.1".to_string(),
            name: "test".to_string(),
            metadata: None,
            tools: ToolsPolicy {
                allow: Some(names(&[
                    "SearchKnowledgeBase",
                    "GetCustomerInfo",
                    "CreateTicket",
                ])),
                deny: Some(names(&["DeleteAccount"])),
                require_args: None,
                arg_constraints: None,
            },
            sequences: vec![
                SequenceRule::Before {
                    first: "SearchKnowledgeBase".to_string(),
                    then: "CreateTicket".to_string(),
                },
                SequenceRule::MaxCalls {
                    tool: "GetCustomerInfo".to_string(),
                    max: 3,
                },
            ],
            aliases: HashMap::new(),
            on_error: ErrorPolicy::default(),
        }
    }

    /// Report fixture that misses its threshold and has one entry in every list.
    fn sample_report() -> CoverageReport {
        CoverageReport {
            tool_coverage: ToolCoverage {
                total_tools_in_policy: 4,
                tools_seen_in_traces: 2,
                coverage_pct: 50.0,
                unseen_tools: names(&["CreateTicket"]),
                unexpected_tools: vec![],
            },
            rule_coverage: RuleCoverage {
                total_rules: 2,
                rules_triggered: 1,
                coverage_pct: 50.0,
                untriggered_rules: names(&["max_calls_api_3"]),
            },
            high_risk_gaps: vec![HighRiskGap {
                tool: "DeleteAccount".to_string(),
                reason: "Never tested".to_string(),
                severity: "high".to_string(),
            }],
            policy_violations: vec![],
            policy_warnings: vec![],
            overall_coverage_pct: 50.0,
            meets_threshold: false,
            threshold: 80.0,
        }
    }

    #[test]
    fn test_full_coverage() {
        let analyzer = CoverageAnalyzer::from_policy(&make_policy());
        let traces = [trace(
            "t1",
            // DeleteAccount is high risk, but it is exercised here.
            &[
                "SearchKnowledgeBase",
                "GetCustomerInfo",
                "CreateTicket",
                "DeleteAccount",
            ],
            &[
                "before_searchknowledgebase_then_createticket",
                "max_calls_getcustomerinfo_3",
            ],
        )];

        let report = analyzer.analyze(&traces, 80.0);

        assert_eq!(report.tool_coverage.tools_seen_in_traces, 4);
        assert!(report.tool_coverage.unseen_tools.is_empty());
        assert!(report.high_risk_gaps.is_empty());
        assert!(report.meets_threshold);
    }

    #[test]
    fn test_partial_coverage() {
        let analyzer = CoverageAnalyzer::from_policy(&make_policy());
        let traces = [trace("t1", &["SearchKnowledgeBase"], &[])];

        let report = analyzer.analyze(&traces, 80.0);
        let unseen = &report.tool_coverage.unseen_tools;

        assert_eq!(report.tool_coverage.tools_seen_in_traces, 1);
        assert!(unseen.contains(&"CreateTicket".to_string()));
        assert!(unseen.contains(&"GetCustomerInfo".to_string()));
        // DeleteAccount never appears, so it is reported as a gap.
        assert!(!report.high_risk_gaps.is_empty());
        assert!(!report.meets_threshold);
    }

    #[test]
    fn test_unexpected_tools() {
        let analyzer = CoverageAnalyzer::from_policy(&make_policy());
        let traces = [trace("t1", &["SearchKnowledgeBase", "UnknownTool"], &[])];

        let report = analyzer.analyze(&traces, 50.0);

        assert!(report
            .tool_coverage
            .unexpected_tools
            .contains(&"UnknownTool".to_string()));
    }

    #[test]
    fn test_alias_members_resolve() {
        let policy = Policy {
            tools: empty_tools(),
            sequences: vec![],
            aliases: HashMap::from([("Search".to_string(), names(&["A", "B"]))]),
            ..make_policy()
        };
        let report =
            CoverageAnalyzer::from_policy(&policy).analyze(&[trace("t1", &["A"], &[])], 0.0);

        assert_eq!(report.tool_coverage.total_tools_in_policy, 3);
        assert_eq!(report.tool_coverage.unseen_tools, names(&["B"]));
        assert!(report.tool_coverage.unexpected_tools.is_empty());
    }

    #[test]
    fn test_rule_ids_for_sequence_variants() {
        let policy = Policy {
            sequences: vec![
                SequenceRule::Require {
                    tool: "Search".to_string(),
                },
                SequenceRule::NeverAfter {
                    trigger: "Login".to_string(),
                    forbidden: "DropTable".to_string(),
                },
                SequenceRule::Sequence {
                    tools: names(&["A", "B"]),
                    strict: true,
                },
                SequenceRule::Blocklist {
                    pattern: "Exec".to_string(),
                },
            ],
            ..make_policy()
        };
        let report = CoverageAnalyzer::from_policy(&policy).analyze(&[], 0.0);

        assert_eq!(
            report.rule_coverage.untriggered_rules,
            names(&[
                "require_search",
                "never_after_login_forbidden_droptable",
                "strict_a_b",
                "blocklist_exec",
            ])
        );
        // Forbidden tools are high risk; blocklist patterns are not reported as gaps.
        assert!(report.high_risk_gaps.iter().any(|gap| gap.tool == "DropTable"));
        assert!(report
            .high_risk_gaps
            .iter()
            .all(|gap| !gap.tool.starts_with('*')));
    }

    #[test]
    fn test_empty_policy_is_fully_covered() {
        let policy = Policy {
            tools: empty_tools(),
            sequences: vec![],
            ..make_policy()
        };
        let report = CoverageAnalyzer::from_policy(&policy).analyze(&[], 100.0);

        assert_eq!(report.tool_coverage.coverage_pct, 100.0);
        assert_eq!(report.rule_coverage.coverage_pct, 100.0);
        assert!(report.meets_threshold);
    }

    #[test]
    fn test_github_annotation_format() {
        let annotation = sample_report().to_github_annotation();

        assert!(annotation.contains("::error::Coverage 50.0% is below threshold 80.0%"));
        assert!(annotation.contains("::warning::High-risk tool 'DeleteAccount'"));
        assert!(annotation.contains("::notice::Tool 'CreateTicket'"));
    }

    #[test]
    fn test_markdown_format() {
        let md = sample_report().to_markdown();

        assert!(md.starts_with("## Coverage Report "));
        assert!(md.contains("| Tool Coverage | 50.0% (2/4) |\n"));
        assert!(md.contains("| Rule Coverage | 50.0% (1/2) |\n"));
        assert!(md.contains("- **DeleteAccount**: Never tested\n"));
        assert!(md.contains("### Uncovered Tools\n\n- `CreateTicket`\n\n"));
        assert!(md.ends_with("### Untriggered Rules\n\n- `max_calls_api_3`\n\n"));
    }
}