1use serde::{Deserialize, Serialize};
2
3use schemars::JsonSchema;
4
/// Coarse estimate of how demanding a request is.
///
/// Produced by `estimate_complexity` and used as the lookup key for
/// `AutoClassifyConfig::hint_for`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComplexityTier {
    /// Lowest tier.
    Simple,
    /// Middle tier; the fallback when neither other tier applies.
    Standard,
    /// Highest tier.
    Complex,
}
17
/// Lowercase words and phrases whose presence suggests a reasoning-heavy request.
///
/// Matched as plain substrings of the lowercased message, so an entry also
/// matches inside longer words.
const REASONING_KEYWORDS: &[&str] = &[
    // Explanation and analysis verbs.
    "explain",
    "why",
    "analyze",
    "compare",
    // Engineering tasks.
    "design",
    "implement",
    "refactor",
    "debug",
    "optimize",
    "architecture",
    // Both spellings are listed so either one is recognised.
    "trade-off",
    "tradeoff",
    // Explicit requests for deliberate reasoning.
    "reasoning",
    "step by step",
    "think through",
    "evaluate",
    "critique",
    "pros and cons",
];
39
40pub fn estimate_complexity(message: &str) -> ComplexityTier {
48 let lower = message.to_lowercase();
49 let len = message.len();
50
51 let keyword_count = REASONING_KEYWORDS
52 .iter()
53 .filter(|kw| lower.contains(**kw))
54 .count();
55
56 let has_code_fence = message.contains("```");
57
58 if len > 200 || has_code_fence || keyword_count >= 2 {
59 return ComplexityTier::Complex;
60 }
61
62 if len < 50 && keyword_count == 0 {
63 return ComplexityTier::Simple;
64 }
65
66 ComplexityTier::Standard
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
77pub struct AutoClassifyConfig {
78 #[serde(default)]
80 pub simple_hint: Option<String>,
81 #[serde(default)]
83 pub standard_hint: Option<String>,
84 #[serde(default)]
86 pub complex_hint: Option<String>,
87 #[serde(default = "default_cost_optimized_hint")]
89 pub cost_optimized_hint: String,
90}
91
/// Serde fallback for `AutoClassifyConfig::cost_optimized_hint`.
fn default_cost_optimized_hint() -> String {
    String::from("cost-optimized")
}
95
96impl Default for AutoClassifyConfig {
97 fn default() -> Self {
98 Self {
99 simple_hint: None,
100 standard_hint: None,
101 complex_hint: None,
102 cost_optimized_hint: default_cost_optimized_hint(),
103 }
104 }
105}
106
107impl AutoClassifyConfig {
108 pub fn hint_for(&self, tier: ComplexityTier) -> Option<&str> {
110 match tier {
111 ComplexityTier::Simple => self.simple_hint.as_deref(),
112 ComplexityTier::Standard => self.standard_hint.as_deref(),
113 ComplexityTier::Complex => self.complex_hint.as_deref(),
114 }
115 }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
122pub struct EvalConfig {
123 #[serde(default)]
125 pub enabled: bool,
126 #[serde(default = "default_min_quality_score")]
129 pub min_quality_score: f64,
130 #[serde(default = "default_max_retries")]
132 pub max_retries: u32,
133}
134
/// Default quality threshold: serde fallback for `EvalConfig::min_quality_score`
/// and the retry cut-off used by `evaluate_response`.
fn default_min_quality_score() -> f64 {
    0.5_f64
}
138
/// Serde fallback for `EvalConfig::max_retries`.
fn default_max_retries() -> u32 {
    1_u32
}
142
143impl Default for EvalConfig {
144 fn default() -> Self {
145 Self {
146 enabled: false,
147 min_quality_score: default_min_quality_score(),
148 max_retries: default_max_retries(),
149 }
150 }
151}
152
153#[derive(Debug, Clone)]
155pub struct EvalResult {
156 pub score: f64,
158 pub checks: Vec<EvalCheck>,
160 pub retry_hint: Option<String>,
162}
163
/// One named heuristic check and its contribution to the overall score.
#[derive(Debug, Clone)]
pub struct EvalCheck {
    /// Stable identifier of the check, e.g. `"non_empty"`.
    pub name: &'static str,
    /// Whether the response satisfied the check.
    pub passed: bool,
    /// Weight this check carries in the score.
    pub weight: f64,
}
170
/// Lowercase words whose presence in a query suggests the answer should contain code.
///
/// Matched as plain substrings of the lowercased query, so an entry also
/// matches inside longer words.
const CODE_KEYWORDS: &[&str] = &[
    // Things one writes.
    "code",
    "function",
    "implement",
    "class",
    "struct",
    "module",
    "script",
    "program",
    // Things one fixes.
    "bug",
    "error",
    "compile",
    "syntax",
    "refactor",
];
187
188pub fn evaluate_response(
197 query: &str,
198 response: &str,
199 complexity: ComplexityTier,
200 auto_classify: Option<&AutoClassifyConfig>,
201) -> EvalResult {
202 let mut checks = Vec::new();
203
204 let non_empty = !response.trim().is_empty();
206 checks.push(EvalCheck {
207 name: "non_empty",
208 passed: non_empty,
209 weight: 0.3,
210 });
211
212 let lower_resp = response.to_lowercase();
214 let cop_out_phrases = [
215 "i don't know",
216 "i'm not sure",
217 "i cannot",
218 "i can't help",
219 "as an ai",
220 ];
221 let is_cop_out = cop_out_phrases
222 .iter()
223 .any(|phrase| lower_resp.starts_with(phrase));
224 let not_cop_out = !is_cop_out || response.len() > 200; checks.push(EvalCheck {
226 name: "not_cop_out",
227 passed: not_cop_out,
228 weight: 0.25,
229 });
230
231 let min_len = match complexity {
233 ComplexityTier::Simple => 5,
234 ComplexityTier::Standard => 20,
235 ComplexityTier::Complex => 50,
236 };
237 let sufficient_length = response.len() >= min_len;
238 checks.push(EvalCheck {
239 name: "sufficient_length",
240 passed: sufficient_length,
241 weight: 0.2,
242 });
243
244 let query_lower = query.to_lowercase();
246 let expects_code = CODE_KEYWORDS.iter().any(|kw| query_lower.contains(kw));
247 let has_code = response.contains("```") || response.contains(" "); let code_check_passed = !expects_code || has_code;
249 checks.push(EvalCheck {
250 name: "code_presence",
251 passed: code_check_passed,
252 weight: 0.25,
253 });
254
255 let total_weight: f64 = checks.iter().map(|c| c.weight).sum();
257 let earned: f64 = checks.iter().filter(|c| c.passed).map(|c| c.weight).sum();
258 let score = if total_weight > 0.0 {
259 earned / total_weight
260 } else {
261 1.0
262 };
263
264 let retry_hint = if score <= default_min_quality_score() {
266 let next_tier = match complexity {
268 ComplexityTier::Simple => Some(ComplexityTier::Standard),
269 ComplexityTier::Standard => Some(ComplexityTier::Complex),
270 ComplexityTier::Complex => None, };
272 next_tier.and_then(|tier| {
273 auto_classify
274 .and_then(|ac| ac.hint_for(tier))
275 .map(String::from)
276 })
277 } else {
278 None
279 };
280
281 EvalResult {
282 score,
283 checks,
284 retry_hint,
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with a distinct hint on every tier, shared by the retry tests.
    fn fully_hinted_config() -> AutoClassifyConfig {
        AutoClassifyConfig {
            simple_hint: Some(String::from("fast")),
            standard_hint: Some(String::from("default")),
            complex_hint: Some(String::from("reasoning")),
            ..AutoClassifyConfig::default()
        }
    }

    // ---- estimate_complexity ----

    /// Very short greetings with no keyword land in the lowest tier.
    #[test]
    fn simple_short_message() {
        for greeting in ["hi", "hello", "yes"] {
            assert_eq!(
                estimate_complexity(greeting),
                ComplexityTier::Simple,
                "input: {:?}",
                greeting
            );
        }
    }

    /// Anything over 200 bytes is complex regardless of content.
    #[test]
    fn complex_long_message() {
        let padded: String = std::iter::repeat('a').take(201).collect();
        assert_eq!(estimate_complexity(&padded), ComplexityTier::Complex);
    }

    /// A code fence alone is enough to classify a message as complex.
    #[test]
    fn complex_code_fence() {
        let with_fence = "Here is some code:\n```rust\nfn main() {}\n```";
        let tier = estimate_complexity(with_fence);
        assert_eq!(tier, ComplexityTier::Complex);
    }

    /// Two or more reasoning keywords push a short message to complex.
    #[test]
    fn complex_multiple_reasoning_keywords() {
        let prompt = "Please explain why this design is better and analyze the trade-off";
        let tier = estimate_complexity(prompt);
        assert_eq!(tier, ComplexityTier::Complex);
    }

    /// A mid-length message without keywords is standard.
    #[test]
    fn standard_medium_message() {
        let prompt = "Can you help me find a good restaurant in this area please?";
        let tier = estimate_complexity(prompt);
        assert_eq!(tier, ComplexityTier::Standard);
    }

    /// A single keyword keeps even a short message out of the simple tier.
    #[test]
    fn standard_short_with_one_keyword() {
        let tier = estimate_complexity("explain this");
        assert_eq!(tier, ComplexityTier::Standard);
    }

    // ---- AutoClassifyConfig ----

    /// Each tier reads its own field; an unset field yields `None`.
    #[test]
    fn auto_classify_maps_tiers_to_hints() {
        let config = AutoClassifyConfig {
            simple_hint: Some(String::from("fast")),
            complex_hint: Some(String::from("reasoning")),
            // `standard_hint` stays `None` via the default.
            ..AutoClassifyConfig::default()
        };
        let expectations = [
            (ComplexityTier::Simple, Some("fast")),
            (ComplexityTier::Standard, None),
            (ComplexityTier::Complex, Some("reasoning")),
        ];
        for (tier, expected) in expectations {
            assert_eq!(config.hint_for(tier), expected);
        }
    }

    // ---- evaluate_response ----

    /// An empty response fails enough checks to sit at or below the threshold.
    #[test]
    fn empty_response_scores_low() {
        let outcome = evaluate_response("hello", "", ComplexityTier::Simple, None);
        assert!(outcome.score <= 0.5, "empty response should score low");
    }

    /// A direct, adequate answer to a simple question passes every check.
    #[test]
    fn good_response_scores_high() {
        let outcome = evaluate_response(
            "what is 2+2?",
            "The answer is 4.",
            ComplexityTier::Simple,
            None,
        );
        let score = outcome.score;
        assert!(
            score >= 0.9,
            "good simple response should score high, got {}",
            score
        );
    }

    /// A response opening with a refusal phrase cannot get a perfect score.
    #[test]
    fn cop_out_response_penalized() {
        let outcome = evaluate_response(
            "explain quantum computing",
            "I don't know much about that.",
            ComplexityTier::Standard,
            None,
        );
        let score = outcome.score;
        assert!(score < 1.0, "cop-out should be penalized, got {}", score);
    }

    /// A code-related query answered in plain prose fails `code_presence`.
    #[test]
    fn code_query_without_code_response_penalized() {
        let outcome = evaluate_response(
            "write a function to sort an array",
            "You should use a sorting algorithm.",
            ComplexityTier::Standard,
            None,
        );
        let code_check = outcome
            .checks
            .iter()
            .find(|check| check.name == "code_presence");
        assert!(
            matches!(code_check, Some(check) if !check.passed),
            "code check should fail"
        );
    }

    /// A low score at the simple tier escalates to the standard tier's hint.
    #[test]
    fn retry_hint_escalation() {
        let config = fully_hinted_config();
        let outcome = evaluate_response("hello", "", ComplexityTier::Simple, Some(&config));
        assert_eq!(outcome.retry_hint, Some(String::from("default")));
    }

    /// There is no tier above complex, so no hint is offered.
    #[test]
    fn no_retry_when_already_complex() {
        let config = fully_hinted_config();
        let outcome = evaluate_response(
            "explain everything",
            "",
            ComplexityTier::Complex,
            Some(&config),
        );
        assert_eq!(outcome.retry_hint, None);
    }

    // ---- defaults ----

    /// `EvalConfig::default()` is disabled with one retry and a 0.5 threshold.
    #[test]
    fn max_retries_defaults() {
        let EvalConfig {
            enabled,
            min_quality_score,
            max_retries,
        } = EvalConfig::default();
        assert!(!enabled);
        assert_eq!(max_retries, 1);
        assert!((min_quality_score - 0.5).abs() < f64::EPSILON);
    }

    /// The cost-optimized hint has a non-empty stock value.
    #[test]
    fn cost_optimized_hint_default() {
        let AutoClassifyConfig {
            cost_optimized_hint,
            ..
        } = AutoClassifyConfig::default();
        assert_eq!(cost_optimized_hint, "cost-optimized");
    }
}