Skip to main content

construct/agent/
eval.rs

1use serde::{Deserialize, Serialize};
2
3use schemars::JsonSchema;
4
5// ── Complexity estimation ───────────────────────────────────────
6
/// Coarse, three-level complexity bucket for a user message, as produced by
/// [`estimate_complexity`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComplexityTier {
    /// Short and simple: greetings, yes/no questions, quick lookups.
    Simple,
    /// The middle ground — neither trivially simple nor deeply complex.
    Standard,
    /// Long or reasoning-heavy: code, multi-step work, analysis.
    Complex,
}
17
/// Lowercase substrings whose presence hints at a reasoning-heavy request.
///
/// `estimate_complexity` checks each entry against the lowercased message, so
/// multi-word entries such as `"step by step"` must appear verbatim.
const REASONING_KEYWORDS: &[&str] = &[
    // Explanation and analysis.
    "explain", "why", "analyze", "compare",
    // Engineering work.
    "design", "implement", "refactor", "debug", "optimize", "architecture",
    // Weighing options and structured thinking.
    "trade-off", "tradeoff", "reasoning", "step by step", "think through",
    "evaluate", "critique", "pros and cons",
];
39
40/// Estimate the complexity of a user message without an LLM call.
41///
42/// Rules (applied in order):
43/// - **Complex**: message > 200 chars, OR contains a code fence, OR ≥ 2
44///   reasoning keywords.
45/// - **Simple**: message < 50 chars AND no reasoning keywords.
46/// - **Standard**: everything else.
47pub fn estimate_complexity(message: &str) -> ComplexityTier {
48    let lower = message.to_lowercase();
49    let len = message.len();
50
51    let keyword_count = REASONING_KEYWORDS
52        .iter()
53        .filter(|kw| lower.contains(**kw))
54        .count();
55
56    let has_code_fence = message.contains("```");
57
58    if len > 200 || has_code_fence || keyword_count >= 2 {
59        return ComplexityTier::Complex;
60    }
61
62    if len < 50 && keyword_count == 0 {
63        return ComplexityTier::Simple;
64    }
65
66    ComplexityTier::Standard
67}
68
69// ── Auto-classify config ────────────────────────────────────────
70
71/// Configuration for automatic complexity-based classification.
72///
73/// When the rule-based classifier in `QueryClassificationConfig` produces no
74/// match, the eval layer can fall back to `estimate_complexity` and map the
75/// resulting tier to a routing hint.
76#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
77pub struct AutoClassifyConfig {
78    /// Hint to use for `Simple` complexity tier (e.g. `"fast"`).
79    #[serde(default)]
80    pub simple_hint: Option<String>,
81    /// Hint to use for `Standard` complexity tier.
82    #[serde(default)]
83    pub standard_hint: Option<String>,
84    /// Hint to use for `Complex` complexity tier (e.g. `"reasoning"`).
85    #[serde(default)]
86    pub complex_hint: Option<String>,
87    /// Hint prefix for cost-optimized routing (default: `"cost-optimized"`).
88    #[serde(default = "default_cost_optimized_hint")]
89    pub cost_optimized_hint: String,
90}
91
/// Serde default for `AutoClassifyConfig::cost_optimized_hint`; also used by
/// `AutoClassifyConfig::default` so both construction paths agree.
fn default_cost_optimized_hint() -> String {
    String::from("cost-optimized")
}
95
96impl Default for AutoClassifyConfig {
97    fn default() -> Self {
98        Self {
99            simple_hint: None,
100            standard_hint: None,
101            complex_hint: None,
102            cost_optimized_hint: default_cost_optimized_hint(),
103        }
104    }
105}
106
107impl AutoClassifyConfig {
108    /// Map a complexity tier to the configured hint, if any.
109    pub fn hint_for(&self, tier: ComplexityTier) -> Option<&str> {
110        match tier {
111            ComplexityTier::Simple => self.simple_hint.as_deref(),
112            ComplexityTier::Standard => self.standard_hint.as_deref(),
113            ComplexityTier::Complex => self.complex_hint.as_deref(),
114        }
115    }
116}
117
118// ── Post-response eval ──────────────────────────────────────────
119
120/// Configuration for the post-response quality evaluator.
121#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
122pub struct EvalConfig {
123    /// Enable the eval quality gate.
124    #[serde(default)]
125    pub enabled: bool,
126    /// Minimum quality score (0.0–1.0) to accept a response.
127    /// Below this threshold, a retry with a higher-tier model is suggested.
128    #[serde(default = "default_min_quality_score")]
129    pub min_quality_score: f64,
130    /// Maximum retries with escalated models before accepting whatever we get.
131    #[serde(default = "default_max_retries")]
132    pub max_retries: u32,
133}
134
/// Default for `EvalConfig::min_quality_score`.
///
/// `evaluate_response` also uses this value as its built-in retry threshold.
fn default_min_quality_score() -> f64 {
    0.5_f64
}
138
/// Default for `EvalConfig::max_retries`: at most one escalated retry.
fn default_max_retries() -> u32 {
    1_u32
}
142
143impl Default for EvalConfig {
144    fn default() -> Self {
145        Self {
146            enabled: false,
147            min_quality_score: default_min_quality_score(),
148            max_retries: default_max_retries(),
149        }
150    }
151}
152
153/// Result of evaluating a response against quality heuristics.
154#[derive(Debug, Clone)]
155pub struct EvalResult {
156    /// Aggregate quality score from 0.0 (terrible) to 1.0 (excellent).
157    pub score: f64,
158    /// Individual check outcomes (for observability).
159    pub checks: Vec<EvalCheck>,
160    /// If score < threshold, the suggested higher-tier hint for retry.
161    pub retry_hint: Option<String>,
162}
163
/// One named heuristic check and whether the response passed it.
#[derive(Clone, Debug)]
pub struct EvalCheck {
    /// Identifier of the check (e.g. `"non_empty"`, `"code_presence"`).
    pub name: &'static str,
    /// `true` when the response satisfied this check.
    pub passed: bool,
    /// Relative contribution of this check to the aggregate score.
    pub weight: f64,
}
170
/// Lowercase query substrings suggesting the user expects code in the answer.
///
/// `evaluate_response` checks each entry against the lowercased query.
const CODE_KEYWORDS: &[&str] = &[
    "code", "function", "implement", "class", "struct", "module", "script",
    "program", "bug", "error", "compile", "syntax", "refactor",
];
187
188/// Evaluate a response against heuristic quality checks. No LLM call.
189///
190/// Checks:
191/// 1. **Non-empty**: response must not be empty.
192/// 2. **Not a cop-out**: response must not be just "I don't know" or similar.
193/// 3. **Sufficient length**: response length should be proportional to query complexity.
194/// 4. **Code presence**: if the query mentions code keywords, the response should
195///    contain a code block.
196pub fn evaluate_response(
197    query: &str,
198    response: &str,
199    complexity: ComplexityTier,
200    auto_classify: Option<&AutoClassifyConfig>,
201) -> EvalResult {
202    let mut checks = Vec::new();
203
204    // Check 1: Non-empty
205    let non_empty = !response.trim().is_empty();
206    checks.push(EvalCheck {
207        name: "non_empty",
208        passed: non_empty,
209        weight: 0.3,
210    });
211
212    // Check 2: Not a cop-out
213    let lower_resp = response.to_lowercase();
214    let cop_out_phrases = [
215        "i don't know",
216        "i'm not sure",
217        "i cannot",
218        "i can't help",
219        "as an ai",
220    ];
221    let is_cop_out = cop_out_phrases
222        .iter()
223        .any(|phrase| lower_resp.starts_with(phrase));
224    let not_cop_out = !is_cop_out || response.len() > 200; // long responses with caveats are fine
225    checks.push(EvalCheck {
226        name: "not_cop_out",
227        passed: not_cop_out,
228        weight: 0.25,
229    });
230
231    // Check 3: Sufficient length for complexity
232    let min_len = match complexity {
233        ComplexityTier::Simple => 5,
234        ComplexityTier::Standard => 20,
235        ComplexityTier::Complex => 50,
236    };
237    let sufficient_length = response.len() >= min_len;
238    checks.push(EvalCheck {
239        name: "sufficient_length",
240        passed: sufficient_length,
241        weight: 0.2,
242    });
243
244    // Check 4: Code presence when expected
245    let query_lower = query.to_lowercase();
246    let expects_code = CODE_KEYWORDS.iter().any(|kw| query_lower.contains(kw));
247    let has_code = response.contains("```") || response.contains("    "); // code block or indented
248    let code_check_passed = !expects_code || has_code;
249    checks.push(EvalCheck {
250        name: "code_presence",
251        passed: code_check_passed,
252        weight: 0.25,
253    });
254
255    // Compute weighted score
256    let total_weight: f64 = checks.iter().map(|c| c.weight).sum();
257    let earned: f64 = checks.iter().filter(|c| c.passed).map(|c| c.weight).sum();
258    let score = if total_weight > 0.0 {
259        earned / total_weight
260    } else {
261        1.0
262    };
263
264    // Determine retry hint: if score is low, suggest escalating
265    let retry_hint = if score <= default_min_quality_score() {
266        // Try to escalate: Simple→Standard→Complex
267        let next_tier = match complexity {
268            ComplexityTier::Simple => Some(ComplexityTier::Standard),
269            ComplexityTier::Standard => Some(ComplexityTier::Complex),
270            ComplexityTier::Complex => None, // already at max
271        };
272        next_tier.and_then(|tier| {
273            auto_classify
274                .and_then(|ac| ac.hint_for(tier))
275                .map(String::from)
276        })
277    } else {
278        None
279    };
280
281    EvalResult {
282        score,
283        checks,
284        retry_hint,
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Find the check called `name` in `result`; panics if it is missing.
    fn check<'a>(result: &'a EvalResult, name: &str) -> &'a EvalCheck {
        result
            .checks
            .iter()
            .find(|c| c.name == name)
            .unwrap_or_else(|| panic!("missing check `{name}`"))
    }

    /// Config with a hint for every tier, shared by the retry-hint tests.
    fn all_hints() -> AutoClassifyConfig {
        AutoClassifyConfig {
            simple_hint: Some("fast".into()),
            standard_hint: Some("default".into()),
            complex_hint: Some("reasoning".into()),
            ..Default::default()
        }
    }

    // ── estimate_complexity ─────────────────────────────────────

    #[test]
    fn simple_short_message() {
        assert_eq!(estimate_complexity("hi"), ComplexityTier::Simple);
        assert_eq!(estimate_complexity("hello"), ComplexityTier::Simple);
        assert_eq!(estimate_complexity("yes"), ComplexityTier::Simple);
    }

    #[test]
    fn complex_long_message() {
        let long = "a".repeat(201);
        assert_eq!(estimate_complexity(&long), ComplexityTier::Complex);
    }

    #[test]
    fn complex_code_fence() {
        let msg = "Here is some code:\n```rust\nfn main() {}\n```";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Complex);
    }

    #[test]
    fn complex_multiple_reasoning_keywords() {
        let msg = "Please explain why this design is better and analyze the trade-off";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Complex);
    }

    #[test]
    fn standard_medium_message() {
        // 50+ chars but no code fence, < 2 reasoning keywords
        let msg = "Can you help me find a good restaurant in this area please?";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Standard);
    }

    #[test]
    fn standard_short_with_one_keyword() {
        // < 50 chars but has 1 reasoning keyword → still not Simple
        let msg = "explain this";
        assert_eq!(estimate_complexity(msg), ComplexityTier::Standard);
    }

    // ── auto_classify ───────────────────────────────────────────

    #[test]
    fn auto_classify_maps_tiers_to_hints() {
        let ac = AutoClassifyConfig {
            simple_hint: Some("fast".into()),
            standard_hint: None,
            complex_hint: Some("reasoning".into()),
            ..Default::default()
        };
        assert_eq!(ac.hint_for(ComplexityTier::Simple), Some("fast"));
        assert_eq!(ac.hint_for(ComplexityTier::Standard), None);
        assert_eq!(ac.hint_for(ComplexityTier::Complex), Some("reasoning"));
    }

    #[test]
    fn default_auto_classify_has_no_tier_hints() {
        let ac = AutoClassifyConfig::default();
        assert_eq!(ac.hint_for(ComplexityTier::Simple), None);
        assert_eq!(ac.hint_for(ComplexityTier::Standard), None);
        assert_eq!(ac.hint_for(ComplexityTier::Complex), None);
    }

    // ── evaluate_response ───────────────────────────────────────

    #[test]
    fn empty_response_scores_low() {
        let result = evaluate_response("hello", "", ComplexityTier::Simple, None);
        assert!(result.score <= 0.5, "empty response should score low");
        assert!(!check(&result, "non_empty").passed);
    }

    #[test]
    fn good_response_scores_high() {
        let result = evaluate_response(
            "what is 2+2?",
            "The answer is 4.",
            ComplexityTier::Simple,
            None,
        );
        assert!(
            result.score >= 0.9,
            "good simple response should score high, got {}",
            result.score
        );
        assert_eq!(result.retry_hint, None);
    }

    #[test]
    fn cop_out_response_penalized() {
        let result = evaluate_response(
            "explain quantum computing",
            "I don't know much about that.",
            ComplexityTier::Standard,
            None,
        );
        assert!(
            result.score < 1.0,
            "cop-out should be penalized, got {}",
            result.score
        );
        assert!(!check(&result, "not_cop_out").passed);
    }

    #[test]
    fn code_query_without_code_response_penalized() {
        let result = evaluate_response(
            "write a function to sort an array",
            "You should use a sorting algorithm.",
            ComplexityTier::Standard,
            None,
        );
        assert!(
            !check(&result, "code_presence").passed,
            "code check should fail"
        );
    }

    #[test]
    fn fenced_code_satisfies_code_presence() {
        let result = evaluate_response(
            "write a function to sort an array",
            "```rust\nfn sort(v: &mut [i32]) { v.sort(); }\n```",
            ComplexityTier::Standard,
            None,
        );
        assert!(check(&result, "code_presence").passed);
    }

    #[test]
    fn retry_hint_escalation() {
        let ac = all_hints();
        // Empty response for a Simple query → should suggest Standard hint
        let result = evaluate_response("hello", "", ComplexityTier::Simple, Some(&ac));
        assert_eq!(result.retry_hint, Some("default".into()));
    }

    #[test]
    fn no_retry_when_already_complex() {
        let ac = all_hints();
        // Empty response for Complex → no escalation possible
        let result =
            evaluate_response("explain everything", "", ComplexityTier::Complex, Some(&ac));
        assert_eq!(result.retry_hint, None);
    }

    #[test]
    fn no_retry_hint_without_auto_classify() {
        // Low score, but no hints are configured to escalate to.
        let result = evaluate_response("hello", "", ComplexityTier::Simple, None);
        assert_eq!(result.retry_hint, None);
    }

    #[test]
    fn max_retries_defaults() {
        let config = EvalConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.max_retries, 1);
        assert!((config.min_quality_score - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn cost_optimized_hint_default() {
        let config = AutoClassifyConfig::default();
        assert_eq!(config.cost_optimized_hint, "cost-optimized");
    }
}