assay_core/validate/
mod.rs

1use crate::config::path_resolver::PathResolver;
2use crate::errors::diagnostic::{codes, Diagnostic};
3use crate::model::EvalConfig;
4use crate::model::Expected;
5use crate::providers::llm::LlmClient; // Import trait for .complete()
6use crate::providers::trace::TraceClient;
7use std::path::{Path, PathBuf};
8
/// Caller-supplied inputs that steer a [`validate`] run.
///
/// Every field has a natural "off" value, so the struct derives [`Default`]:
/// `ValidateOptions::default()` describes a run with no trace file, no
/// baseline file and strict replay disabled.
#[derive(Debug, Clone, Default)]
pub struct ValidateOptions {
    /// Trace file to check the configured tests against (the `--trace-file`
    /// path mentioned in diagnostics). `None` skips all trace-based checks.
    pub trace_file: Option<PathBuf>,
    /// Baseline file whose suite is compared with the config's suite (the
    /// `--baseline` path mentioned in diagnostics). `None` skips that check.
    pub baseline_file: Option<PathBuf>,
    /// When `true`, every replayed response is additionally checked for the
    /// precomputed data its test's expectation needs.
    pub replay_strict: bool,
}
15
16#[derive(Debug, Clone, Default)]
17pub struct ValidateReport {
18    pub diagnostics: Vec<Diagnostic>,
19}
20
21pub async fn validate(
22    cfg: &EvalConfig,
23    opts: &ValidateOptions,
24    resolver: &PathResolver,
25) -> anyhow::Result<ValidateReport> {
26    let mut diags = Vec::new();
27
28    // 1. Path Resolution Checks (E_PATH_NOT_FOUND)
29    // Actually the CLI loader does this, but we can double check config assets if any.
30    // For now, let's assume config is loaded correctly if we are here,
31    // but check the explicitly provided trace/baseline files if they exist.
32
33    if let Some(path) = &opts.trace_file {
34        if !path.exists() {
35            diags.push(
36                Diagnostic::new(
37                    codes::E_PATH_NOT_FOUND,
38                    format!("Trace file not found: {}", path.display()),
39                )
40                .with_context(serde_json::json!({ "path": path }))
41                .with_source("validate")
42                .with_fix_step("Ensure the --trace-file path is correct and accessible"),
43            );
44        }
45    }
46
47    if let Some(path) = &opts.baseline_file {
48        if !path.exists() {
49            diags.push(
50                Diagnostic::new(
51                    codes::E_PATH_NOT_FOUND,
52                    format!("Baseline file not found: {}", path.display()),
53                )
54                .with_context(serde_json::json!({ "path": path }))
55                .with_source("validate")
56                .with_fix_step("Ensure the --baseline path is correct and accessible"),
57            );
58        }
59    }
60
61    // Return early if basic files missing to avoid noise
62    if !diags.is_empty() {
63        return Ok(ValidateReport { diagnostics: diags });
64    }
65
66    // 2. Load Trace & Baseline for deeper checks
67    let trace_client = if let Some(path) = &opts.trace_file {
68        match TraceClient::from_path(path) {
69            Ok(client) => Some(client),
70            Err(e) => {
71                diags.push(
72                    Diagnostic::new(
73                        codes::E_TRACE_INVALID,
74                        format!("Failed to parse trace file: {}", e),
75                    )
76                    .with_source("trace")
77                    .with_context(serde_json::json!({ "path": path, "error": e.to_string() })),
78                );
79                return Ok(ValidateReport { diagnostics: diags });
80            }
81        }
82    } else {
83        None
84    };
85
86    let baseline = if let Some(path) = &opts.baseline_file {
87        match crate::baseline::Baseline::load(path) {
88            Ok(b) => Some(b),
89            Err(e) => {
90                diags.push(
91                    Diagnostic::new(
92                        codes::E_BASE_MISMATCH,
93                        format!("Failed to parse baseline: {}", e),
94                    )
95                    .with_source("baseline")
96                    .with_context(serde_json::json!({ "path": path, "error": e.to_string() })),
97                );
98                return Ok(ValidateReport { diagnostics: diags });
99            }
100        }
101    } else {
102        None
103    };
104
105    // 3. Trace Coverage (E_TRACE_MISS)
106    if let Some(client) = &trace_client {
107        for tc in &cfg.tests {
108            // We use the same lookup logic as TraceClient::complete
109            // But here we want to collect ALL misses, not just fail on first.
110            // Since `complete` is not exposed as "check only", we iterate.
111            // Actually TraceClient doesn't expose keys publicly yet.
112            // We might need to call complete and catch error?
113            // OR better: call complete() on client. Since it returns LlmResponse or Err(Diagnostic)
114
115            let res = client
116                .complete(&tc.input.prompt, tc.input.context.as_deref())
117                .await;
118            if let Err(e) = res {
119                // If it's a diagnostic, push it.
120                // We use try_map_error from errors module
121                if let Some(diag) = crate::errors::try_map_error(&e) {
122                    // Enrich with test_id
123                    let mut d = diag.clone();
124                    if let serde_json::Value::Object(ref mut map) = d.context {
125                        map.insert("test_id".into(), serde_json::json!(tc.id));
126                        map.insert("trace_file".into(), serde_json::json!(opts.trace_file));
127                    }
128                    d.source = "trace".to_string();
129                    diags.push(d);
130                } else {
131                    // Unexpected error?
132                    diags.push(
133                        Diagnostic::new("E_UNKNOWN", format!("Unexpected trace error: {}", e))
134                            .with_source("trace"),
135                    );
136                }
137            } else if let Ok(resp) = res {
138                // Check Strict Replay (Requirement 4)
139                if opts.replay_strict {
140                    validate_strict_requirements(tc, &resp, &mut diags, opts.trace_file.as_deref());
141                }
142
143                // Check Embedding Dims (Requirement 5)
144                // This is checking per-test, potentially spammy.
145                // Better to check once per trace? But we don't have access to all embeddings.
146                // We'll check via response meta if available.
147                check_embedding_dims(&resp, &mut diags, opts.trace_file.as_deref());
148
149                // Check Policy (Requirement 2: ArgsValid)
150                if let Expected::ArgsValid {
151                    policy: Some(policy_path),
152                    ..
153                } = &tc.expected
154                {
155                    // 1. Load Policy
156                    // For now, load fully. In future, cache via resolver.
157                    // We need to resolve relative to config?
158                    // resolver.resolve_path(policy_path)?
159                    let mut p_str = policy_path.clone();
160                    resolver.resolve_str(&mut p_str);
161                    let policy_file = std::path::PathBuf::from(p_str);
162                    if !policy_file.exists() {
163                        diags.push(
164                            Diagnostic::new(
165                                codes::E_PATH_NOT_FOUND,
166                                format!("Policy file not found: {}", policy_file.display()),
167                            )
168                            .with_source("validate")
169                            .with_context(serde_json::json!({ "path": policy_file })),
170                        );
171                    } else {
172                        match crate::model::Policy::load(&policy_file) {
173                            Ok(pol) => {
174                                // 2. Get Tool Calls from Trace
175                                let tool_calls =
176                                    resp.meta.get("tool_calls").and_then(|v| v.as_array());
177
178                                if let Some(calls) = tool_calls {
179                                    // Convert to policy value for engine
180                                    let policy_val = serde_json::to_value(
181                                        pol.tools.arg_constraints.unwrap_or_default(),
182                                    )
183                                    .unwrap_or(serde_json::Value::Null);
184
185                                    // Check for Allowed/Denied lists first?
186                                    // Let's use simple policy_engine:evaluate_tool_args which expects JSON schema map.
187                                    // Wait, Policy struct has complex structure.
188                                    // policy.tools.arg_constraints is Map<Tool, Schema>.
189                                    // policy.tools.allow/deny are lists.
190
191                                    // Simplified validation for v1.2.1: Just check args against schema if present.
192                                    // Detailed enforcement requires full policy engine context (TODO for v1.3)
193
194                                    for call in calls {
195                                        let tool_name = call
196                                            .get("tool_name")
197                                            .and_then(|s| s.as_str())
198                                            .unwrap_or("unknown");
199                                        let args =
200                                            call.get("args").unwrap_or(&serde_json::Value::Null);
201
202                                        // Need to construct the "policy" value expected by evaluate_tool_args
203                                        // It expects { "ToolName": Schema, ... }
204                                        // This is exactly `arg_constraints`.
205
206                                        let verdict = crate::policy_engine::evaluate_tool_args(
207                                            &policy_val,
208                                            tool_name,
209                                            args,
210                                        );
211
212                                        if let crate::policy_engine::VerdictStatus::Blocked =
213                                            verdict.status
214                                        {
215                                            let mut d = Diagnostic::new(
216                                                verdict.reason_code,
217                                                "Policy violation in tool call",
218                                            )
219                                            .with_source("policy")
220                                            .with_context(verdict.details);
221
222                                            // Add trace context
223                                            if let serde_json::Value::Object(ref mut map) =
224                                                d.context
225                                            {
226                                                map.insert("tool".into(), tool_name.into());
227                                                map.insert("test_id".into(), tc.id.clone().into());
228                                            }
229                                            diags.push(d);
230                                        }
231                                    }
232                                } else {
233                                    // No tool calls found in trace?
234                                    // If policy expects validation, maybe warn?
235                                }
236                            }
237                            Err(e) => {
238                                diags.push(
239                                    Diagnostic::new(
240                                        codes::E_CFG_PARSE,
241                                        format!("Failed to parse policy: {}", e),
242                                    )
243                                    .with_source("policy"),
244                                );
245                            }
246                        }
247                    }
248                }
249            }
250        }
251    }
252
253    // Baseline Compat (Requirement 3)
254    if let Some(base) = &baseline {
255        if base.suite != cfg.suite {
256            diags.push(
257                Diagnostic::new(codes::E_BASE_MISMATCH, "Baseline suite mismatch")
258                    .with_source("baseline")
259                    .with_context(serde_json::json!({
260                        "expected_suite": cfg.suite,
261                        "baseline_suite": base.suite,
262                        "baseline_file": opts.baseline_file
263                    }))
264                    .with_fix_step("Use the baseline file created for this suite")
265                    .with_fix_step("Or export a new baseline: assay ci ... --export-baseline ..."),
266            );
267        }
268    }
269
270    // Deduplicate diagnostics?
271    // E_EMB_DIMS might be spammy if every test fails.
272    // Simple dedup by code + message signature could be added later.
273
274    Ok(ValidateReport { diagnostics: diags })
275}
276
277fn validate_strict_requirements(
278    tc: &crate::model::TestCase,
279    resp: &crate::model::LlmResponse,
280    diags: &mut Vec<Diagnostic>,
281    trace_path: Option<&Path>,
282) {
283    let mut missing = Vec::new();
284
285    // Check Semantic Metrics -> Need Embeddings
286    if let Expected::SemanticSimilarityTo { .. } = &tc.expected {
287        if resp.meta.pointer("/assay/embeddings/response").is_none() {
288            missing.push(serde_json::json!({
289                "requirement": "embeddings",
290                "needed_by": ["semantic_similarity_to"],
291                "meta_path": "meta.assay.embeddings"
292            }));
293        }
294    }
295
296    // Check Judge -> Need Judge Results
297    // Only if expected is Faithfulness or Relevance
298    match &tc.expected {
299        Expected::Faithfulness { .. } => {
300            if resp.meta.pointer("/assay/judge/faithfulness").is_none() {
301                missing.push(serde_json::json!({
302                    "requirement": "judge_faithfulness",
303                    "needed_by": ["faithfulness"],
304                    "meta_path": "meta.assay.judge.faithfulness"
305                }));
306            }
307        }
308        Expected::Relevance { .. } => {
309            if resp.meta.pointer("/assay/judge/relevance").is_none() {
310                missing.push(serde_json::json!({
311                    "requirement": "judge_relevance",
312                    "needed_by": ["relevance"],
313                    "meta_path": "meta.assay.judge.relevance"
314                }));
315            }
316        }
317        _ => {}
318    }
319
320    if !missing.is_empty() {
321        diags.push(
322            Diagnostic::new(
323                codes::E_REPLAY_STRICT_MISSING,
324                "Strict replay requires precomputed data that is missing from trace",
325            )
326            .with_source("replay")
327            .with_context(serde_json::json!({
328                "replay_strict": true,
329                "trace_file": trace_path,
330                "missing": missing,
331                "test_id": tc.id
332            }))
333            .with_fix_step("Run `assay trace precompute-embeddings ...`")
334            .with_fix_step("Run `assay trace precompute-judge ...`"),
335        );
336    }
337}
338
339fn check_embedding_dims(
340    resp: &crate::model::LlmResponse,
341    diags: &mut Vec<Diagnostic>,
342    trace_path: Option<&Path>,
343) {
344    // Basic heuristic: if we have embeddings, check simple consistency?
345    // Or if we know expected model?
346    // For now, looking for obvious bad data (empty vectors)
347    // Or strict mismatch if we ever passed an embedder config (not available here yet).
348
349    if let Some(embeddings) = resp
350        .meta
351        .pointer("/assay/embeddings")
352        .and_then(|v| v.as_object())
353    {
354        if let Some(response_vec) = embeddings.get("response").and_then(|v| v.as_array()) {
355            if response_vec.is_empty() {
356                diags.push(
357                    Diagnostic::new(codes::E_EMB_DIMS, "Empty embedding vector found in trace")
358                        .with_source("trace")
359                        .with_context(serde_json::json!({ "trace_file": trace_path }))
360                        .with_fix_step("Regenerate embeddings with precompute-embeddings"),
361                );
362            }
363        }
364    }
365}