Skip to main content

brainwires_reasoning/
validator.rs

//! Validator - Semantic Response Validation
//!
//! Uses a provider to perform semantic validation of responses,
//! complementing the pattern-based red-flagging system.
5
6use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Outcome of validating a response against its task.
///
/// Derives `PartialEq` so results can be compared directly (for example with
/// `assert_eq!`) instead of only through `matches!`.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationResult {
    /// Response is valid.
    Valid {
        /// Confidence in the validity assessment (0.0-1.0).
        confidence: f32,
    },
    /// Response has issues.
    Invalid {
        /// Description of the validation issue.
        reason: String,
        /// Severity of the issue (0.0-1.0).
        severity: f32,
        /// Confidence in the invalidity assessment (0.0-1.0).
        confidence: f32,
    },
    /// Validation was skipped (fallback to pattern-based).
    Skipped,
}

impl ValidationResult {
    /// Returns `true` only for the [`ValidationResult::Valid`] variant.
    ///
    /// [`ValidationResult::Skipped`] is neither valid nor invalid, so it
    /// yields `false` here.
    pub fn is_valid(&self) -> bool {
        matches!(self, Self::Valid { .. })
    }

    /// Returns `true` only for the [`ValidationResult::Invalid`] variant.
    ///
    /// [`ValidationResult::Skipped`] is neither valid nor invalid, so it
    /// yields `false` here.
    pub fn is_invalid(&self) -> bool {
        matches!(self, Self::Invalid { .. })
    }
}
46
47/// Validator for semantic response validation
48pub struct LocalValidator {
49    provider: Arc<dyn Provider>,
50    model_id: String,
51}
52
53impl LocalValidator {
54    /// Create a new validator
55    pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
56        Self {
57            provider,
58            model_id: model_id.into(),
59        }
60    }
61
62    /// Validate a response for the given task
63    ///
64    /// Performs semantic validation to catch issues that pattern matching might miss.
65    pub async fn validate(&self, task: &str, response: &str) -> ValidationResult {
66        let timer = InferenceTimer::new("validate_response", &self.model_id);
67
68        // Skip very short responses (likely already handled by pattern matching)
69        if response.trim().len() < 10 {
70            return ValidationResult::Skipped;
71        }
72
73        let system_prompt = self.build_validation_prompt();
74        let user_prompt = format!(
75            "Validate if this response is appropriate for the task.\n\nTask: {}\n\nResponse: {}\n\nOutput ONLY: VALID or INVALID:<reason>",
76            task,
77            // Truncate response for efficiency
78            if response.len() > 500 {
79                &response[..500]
80            } else {
81                response
82            }
83        );
84
85        let messages = vec![Message::user(&user_prompt)];
86        let options = ChatOptions::deterministic(50).system(system_prompt);
87
88        match self.provider.chat(&messages, None, &options).await {
89            Ok(chat_response) => {
90                let text = chat_response.message.text_or_summary();
91                let result = self.parse_validation(&text);
92                timer.finish(true);
93                result
94            }
95            Err(e) => {
96                warn!(target: "local_llm", "Response validation failed: {}", e);
97                timer.finish(false);
98                ValidationResult::Skipped
99            }
100        }
101    }
102
103    /// Quick heuristic validation (no LLM call)
104    ///
105    /// Use for fast pre-filtering before LLM validation.
106    pub fn validate_heuristic(&self, task: &str, response: &str) -> ValidationResult {
107        let response_lower = response.to_lowercase();
108        let task_lower = task.to_lowercase();
109
110        // Check for obvious issues
111
112        // 1. Response is completely off-topic (no shared words with task)
113        let task_words: std::collections::HashSet<&str> = task_lower
114            .split_whitespace()
115            .filter(|w| w.len() > 3)
116            .collect();
117        let response_words: std::collections::HashSet<&str> = response_lower
118            .split_whitespace()
119            .filter(|w| w.len() > 3)
120            .collect();
121
122        let overlap = task_words.intersection(&response_words).count();
123        if overlap == 0 && task_words.len() > 3 {
124            return ValidationResult::Invalid {
125                reason: "Response appears unrelated to task".to_string(),
126                severity: 0.6,
127                confidence: 0.4,
128            };
129        }
130
131        // 2. Response contains refusal patterns
132        let refusal_patterns = [
133            "i cannot",
134            "i can't",
135            "i'm unable",
136            "i am unable",
137            "sorry, i",
138            "i don't have",
139            "i do not have",
140            "as an ai",
141        ];
142
143        for pattern in refusal_patterns {
144            if response_lower.contains(pattern) {
145                return ValidationResult::Invalid {
146                    reason: format!("Response contains refusal pattern: {}", pattern),
147                    severity: 0.7,
148                    confidence: 0.6,
149                };
150            }
151        }
152
153        // 3. Response is just repeating the task
154        let task_trimmed = task_lower.trim();
155        let response_trimmed = response_lower.trim();
156        if response_trimmed.starts_with(task_trimmed) && response.len() < task.len() * 2 {
157            return ValidationResult::Invalid {
158                reason: "Response appears to just repeat the task".to_string(),
159                severity: 0.5,
160                confidence: 0.5,
161            };
162        }
163
164        // 4. Response is suspiciously short for a complex task
165        if task.len() > 100 && response.len() < 20 {
166            return ValidationResult::Invalid {
167                reason: "Response too short for complex task".to_string(),
168                severity: 0.4,
169                confidence: 0.4,
170            };
171        }
172
173        ValidationResult::Valid { confidence: 0.5 }
174    }
175
176    /// Build the system prompt for validation
177    fn build_validation_prompt(&self) -> String {
178        r#"You are a response validator. Given a task and response, determine if the response is appropriate.
179
180Check for:
1811. Response addresses the task (not off-topic)
1822. Response doesn't contain confusion or self-correction
1833. Response isn't a refusal or "I can't do that"
1844. Response isn't just repeating the task
1855. Response has substance (not empty platitudes)
186
187Output format:
188- If valid: VALID
189- If invalid: INVALID:<brief reason>
190
191Be strict but fair. Only flag clear issues."#.to_string()
192    }
193
194    /// Parse the LLM output to determine validity
195    fn parse_validation(&self, output: &str) -> ValidationResult {
196        let trimmed = output.trim().to_uppercase();
197
198        if trimmed.starts_with("VALID") && !trimmed.contains("INVALID") {
199            return ValidationResult::Valid { confidence: 0.8 };
200        }
201
202        if trimmed.starts_with("INVALID") {
203            let reason = if let Some(idx) = trimmed.find(':') {
204                trimmed[idx + 1..].trim().to_string()
205            } else {
206                "Unspecified validation failure".to_string()
207            };
208
209            return ValidationResult::Invalid {
210                reason,
211                severity: 0.6,
212                confidence: 0.75,
213            };
214        }
215
216        // Ambiguous output - treat as skipped
217        ValidationResult::Skipped
218    }
219}
220
221/// Builder for LocalValidator
222pub struct LocalValidatorBuilder {
223    provider: Option<Arc<dyn Provider>>,
224    model_id: String,
225}
226
227impl Default for LocalValidatorBuilder {
228    fn default() -> Self {
229        Self {
230            provider: None,
231            model_id: "lfm2-350m".to_string(),
232        }
233    }
234}
235
236impl LocalValidatorBuilder {
237    /// Create a new builder with default settings.
238    pub fn new() -> Self {
239        Self::default()
240    }
241
242    /// Set the provider to use for validation.
243    pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
244        self.provider = Some(provider);
245        self
246    }
247
248    /// Set the model ID to use for inference.
249    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
250        self.model_id = model_id.into();
251        self
252    }
253
254    /// Build the validator, returning `None` if no provider was set.
255    pub fn build(self) -> Option<LocalValidator> {
256        self.provider.map(|p| LocalValidator::new(p, self.model_id))
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    // `LocalValidator::new` requires an `Arc<dyn Provider>`, so the heuristic
    // and parsing tests below run against provider-free stand-ins
    // (`validate_heuristic_direct`, `parse_validation_direct`) rather than a
    // constructed validator.

    #[test]
    fn test_validation_result_checks() {
        let valid = ValidationResult::Valid { confidence: 0.9 };
        assert!(valid.is_valid());
        assert!(!valid.is_invalid());

        let invalid = ValidationResult::Invalid {
            reason: "test".to_string(),
            severity: 0.5,
            confidence: 0.8,
        };
        assert!(!invalid.is_valid());
        assert!(invalid.is_invalid());

        // A skipped validation is neither valid nor invalid.
        let skipped = ValidationResult::Skipped;
        assert!(!skipped.is_valid());
        assert!(!skipped.is_invalid());
    }

    #[test]
    fn test_heuristic_validation_refusal() {
        // Test refusal detection
        let result = validate_heuristic_direct(
            "Write a poem",
            "I'm sorry, I cannot write poems as an AI assistant.",
        );
        assert!(matches!(result, ValidationResult::Invalid { .. }));

        // Phrases from the longer forms of the refusal list are caught too.
        let result = validate_heuristic_direct(
            "Summarize the report",
            "I do not have access to that report.",
        );
        assert!(matches!(result, ValidationResult::Invalid { .. }));
    }

    #[test]
    fn test_heuristic_validation_valid() {
        let result = validate_heuristic_direct("Calculate 2+2", "The result of 2+2 is 4.");

        assert!(matches!(result, ValidationResult::Valid { .. }));
    }

    /// Provider-free refusal check using the same phrase list and reason text
    /// as the refusal step of `LocalValidator::validate_heuristic`.
    fn validate_heuristic_direct(_task: &str, response: &str) -> ValidationResult {
        let response_lower = response.to_lowercase();

        let refusal_patterns = [
            "i cannot",
            "i can't",
            "i'm unable",
            "i am unable",
            "sorry, i",
            "i don't have",
            "i do not have",
            "as an ai",
        ];

        for pattern in refusal_patterns {
            if response_lower.contains(pattern) {
                return ValidationResult::Invalid {
                    reason: format!("Response contains refusal pattern: {}", pattern),
                    severity: 0.7,
                    confidence: 0.6,
                };
            }
        }

        ValidationResult::Valid { confidence: 0.5 }
    }

    #[test]
    fn test_parse_validation() {
        // Test parsing logic
        assert!(matches!(
            parse_validation_direct("VALID"),
            ValidationResult::Valid { .. }
        ));

        // Verdict keywords are matched case-insensitively.
        assert!(matches!(
            parse_validation_direct("  valid  "),
            ValidationResult::Valid { .. }
        ));

        assert!(matches!(
            parse_validation_direct("INVALID: Response is off-topic"),
            ValidationResult::Invalid { .. }
        ));

        // Without a colon there is no reason text, so the fallback is used.
        assert!(matches!(
            parse_validation_direct("INVALID"),
            ValidationResult::Invalid { ref reason, .. }
                if reason == "Unspecified validation failure"
        ));

        assert!(matches!(
            parse_validation_direct("Maybe?"),
            ValidationResult::Skipped
        ));
    }

    /// Provider-free parser for the `VALID` / `INVALID:<reason>` output
    /// convention, used in place of the private method on `LocalValidator`.
    fn parse_validation_direct(output: &str) -> ValidationResult {
        let trimmed = output.trim().to_uppercase();

        if trimmed.starts_with("VALID") && !trimmed.contains("INVALID") {
            return ValidationResult::Valid { confidence: 0.8 };
        }

        if trimmed.starts_with("INVALID") {
            let reason = if let Some(idx) = trimmed.find(':') {
                trimmed[idx + 1..].trim().to_string()
            } else {
                "Unspecified validation failure".to_string()
            };

            return ValidationResult::Invalid {
                reason,
                severity: 0.6,
                confidence: 0.75,
            };
        }

        ValidationResult::Skipped
    }
}
359}