1use std::sync::Arc;
7use tracing::warn;
8
9use brainwires_core::message::Message;
10use brainwires_core::provider::{ChatOptions, Provider};
11
12use crate::InferenceTimer;
13
/// Verdict about whether a response is appropriate for the task it answers,
/// as produced by `LocalValidator`.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationResult {
    /// The response was judged appropriate for the task.
    Valid {
        /// Confidence in this verdict (values used in this module lie in `0.0..=1.0`).
        confidence: f32,
    },
    /// The response was judged inappropriate for the task.
    Invalid {
        /// Human-readable explanation of the problem.
        reason: String,
        /// How serious the problem is (presumably higher = worse; values used
        /// in this module lie in `0.0..=1.0`).
        severity: f32,
        /// Confidence in this verdict (values used in this module lie in `0.0..=1.0`).
        confidence: f32,
    },
    /// No verdict was reached, e.g. the response was too short to check, the
    /// model call failed, or the model output could not be parsed.
    Skipped,
}

impl ValidationResult {
    /// Returns `true` if this is [`ValidationResult::Valid`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        matches!(self, Self::Valid { .. })
    }

    /// Returns `true` if this is [`ValidationResult::Invalid`].
    #[must_use]
    pub fn is_invalid(&self) -> bool {
        matches!(self, Self::Invalid { .. })
    }

    /// Returns `true` if this is [`ValidationResult::Skipped`].
    #[must_use]
    pub fn is_skipped(&self) -> bool {
        matches!(self, Self::Skipped)
    }
}
46
47pub struct LocalValidator {
49 provider: Arc<dyn Provider>,
50 model_id: String,
51}
52
53impl LocalValidator {
54 pub fn new(provider: Arc<dyn Provider>, model_id: impl Into<String>) -> Self {
56 Self {
57 provider,
58 model_id: model_id.into(),
59 }
60 }
61
62 pub async fn validate(&self, task: &str, response: &str) -> ValidationResult {
66 let timer = InferenceTimer::new("validate_response", &self.model_id);
67
68 if response.trim().len() < 10 {
70 return ValidationResult::Skipped;
71 }
72
73 let system_prompt = self.build_validation_prompt();
74 let user_prompt = format!(
75 "Validate if this response is appropriate for the task.\n\nTask: {}\n\nResponse: {}\n\nOutput ONLY: VALID or INVALID:<reason>",
76 task,
77 if response.len() > 500 {
79 &response[..500]
80 } else {
81 response
82 }
83 );
84
85 let messages = vec![Message::user(&user_prompt)];
86 let options = ChatOptions::deterministic(50).system(system_prompt);
87
88 match self.provider.chat(&messages, None, &options).await {
89 Ok(chat_response) => {
90 let text = chat_response.message.text_or_summary();
91 let result = self.parse_validation(&text);
92 timer.finish(true);
93 result
94 }
95 Err(e) => {
96 warn!(target: "local_llm", "Response validation failed: {}", e);
97 timer.finish(false);
98 ValidationResult::Skipped
99 }
100 }
101 }
102
103 pub fn validate_heuristic(&self, task: &str, response: &str) -> ValidationResult {
107 let response_lower = response.to_lowercase();
108 let task_lower = task.to_lowercase();
109
110 let task_words: std::collections::HashSet<&str> = task_lower
114 .split_whitespace()
115 .filter(|w| w.len() > 3)
116 .collect();
117 let response_words: std::collections::HashSet<&str> = response_lower
118 .split_whitespace()
119 .filter(|w| w.len() > 3)
120 .collect();
121
122 let overlap = task_words.intersection(&response_words).count();
123 if overlap == 0 && task_words.len() > 3 {
124 return ValidationResult::Invalid {
125 reason: "Response appears unrelated to task".to_string(),
126 severity: 0.6,
127 confidence: 0.4,
128 };
129 }
130
131 let refusal_patterns = [
133 "i cannot",
134 "i can't",
135 "i'm unable",
136 "i am unable",
137 "sorry, i",
138 "i don't have",
139 "i do not have",
140 "as an ai",
141 ];
142
143 for pattern in refusal_patterns {
144 if response_lower.contains(pattern) {
145 return ValidationResult::Invalid {
146 reason: format!("Response contains refusal pattern: {}", pattern),
147 severity: 0.7,
148 confidence: 0.6,
149 };
150 }
151 }
152
153 let task_trimmed = task_lower.trim();
155 let response_trimmed = response_lower.trim();
156 if response_trimmed.starts_with(task_trimmed) && response.len() < task.len() * 2 {
157 return ValidationResult::Invalid {
158 reason: "Response appears to just repeat the task".to_string(),
159 severity: 0.5,
160 confidence: 0.5,
161 };
162 }
163
164 if task.len() > 100 && response.len() < 20 {
166 return ValidationResult::Invalid {
167 reason: "Response too short for complex task".to_string(),
168 severity: 0.4,
169 confidence: 0.4,
170 };
171 }
172
173 ValidationResult::Valid { confidence: 0.5 }
174 }
175
176 fn build_validation_prompt(&self) -> String {
178 r#"You are a response validator. Given a task and response, determine if the response is appropriate.
179
180Check for:
1811. Response addresses the task (not off-topic)
1822. Response doesn't contain confusion or self-correction
1833. Response isn't a refusal or "I can't do that"
1844. Response isn't just repeating the task
1855. Response has substance (not empty platitudes)
186
187Output format:
188- If valid: VALID
189- If invalid: INVALID:<brief reason>
190
191Be strict but fair. Only flag clear issues."#.to_string()
192 }
193
194 fn parse_validation(&self, output: &str) -> ValidationResult {
196 let trimmed = output.trim().to_uppercase();
197
198 if trimmed.starts_with("VALID") && !trimmed.contains("INVALID") {
199 return ValidationResult::Valid { confidence: 0.8 };
200 }
201
202 if trimmed.starts_with("INVALID") {
203 let reason = if let Some(idx) = trimmed.find(':') {
204 trimmed[idx + 1..].trim().to_string()
205 } else {
206 "Unspecified validation failure".to_string()
207 };
208
209 return ValidationResult::Invalid {
210 reason,
211 severity: 0.6,
212 confidence: 0.75,
213 };
214 }
215
216 ValidationResult::Skipped
218 }
219}
220
221pub struct LocalValidatorBuilder {
223 provider: Option<Arc<dyn Provider>>,
224 model_id: String,
225}
226
227impl Default for LocalValidatorBuilder {
228 fn default() -> Self {
229 Self {
230 provider: None,
231 model_id: "lfm2-350m".to_string(),
232 }
233 }
234}
235
236impl LocalValidatorBuilder {
237 pub fn new() -> Self {
239 Self::default()
240 }
241
242 pub fn provider(mut self, provider: Arc<dyn Provider>) -> Self {
244 self.provider = Some(provider);
245 self
246 }
247
248 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
250 self.model_id = model_id.into();
251 self
252 }
253
254 pub fn build(self) -> Option<LocalValidator> {
256 self.provider.map(|p| LocalValidator::new(p, self.model_id))
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Simplified stand-alone version of the refusal check used by
    /// `LocalValidator::validate_heuristic`; the task text is ignored.
    fn validate_heuristic_direct(_task: &str, response: &str) -> ValidationResult {
        const REFUSALS: [&str; 5] = ["i cannot", "i can't", "i'm unable", "sorry, i", "as an ai"];

        let lowered = response.to_lowercase();
        match REFUSALS.iter().find(|refusal| lowered.contains(**refusal)) {
            Some(refusal) => ValidationResult::Invalid {
                reason: format!("Refusal pattern: {}", refusal),
                severity: 0.7,
                confidence: 0.6,
            },
            None => ValidationResult::Valid { confidence: 0.5 },
        }
    }

    /// Simplified stand-alone version of the verdict parsing performed by
    /// `LocalValidator::parse_validation`.
    fn parse_validation_direct(output: &str) -> ValidationResult {
        let upper = output.trim().to_uppercase();

        if upper.starts_with("VALID") && !upper.contains("INVALID") {
            return ValidationResult::Valid { confidence: 0.8 };
        }
        if !upper.starts_with("INVALID") {
            return ValidationResult::Skipped;
        }

        let reason = upper.split_once(':').map_or_else(
            || "Unspecified".to_string(),
            |(_, rest)| rest.trim().to_string(),
        );
        ValidationResult::Invalid {
            reason,
            severity: 0.6,
            confidence: 0.75,
        }
    }

    #[test]
    fn test_validation_result_checks() {
        let cases = [
            (ValidationResult::Valid { confidence: 0.9 }, true, false),
            (
                ValidationResult::Invalid {
                    reason: "test".to_string(),
                    severity: 0.5,
                    confidence: 0.8,
                },
                false,
                true,
            ),
        ];

        for (result, expect_valid, expect_invalid) in cases {
            assert_eq!(result.is_valid(), expect_valid);
            assert_eq!(result.is_invalid(), expect_invalid);
        }
    }

    #[test]
    fn test_heuristic_validation_refusal() {
        let _validator = LocalValidatorBuilder::default();

        let outcome = validate_heuristic_direct(
            "Write a poem",
            "I'm sorry, I cannot write poems as an AI assistant.",
        );

        assert!(outcome.is_invalid());
    }

    #[test]
    fn test_heuristic_validation_valid() {
        let outcome = validate_heuristic_direct("Calculate 2+2", "The result of 2+2 is 4.");
        assert!(outcome.is_valid());
    }

    #[test]
    fn test_parse_validation() {
        assert!(parse_validation_direct("VALID").is_valid());
        assert!(parse_validation_direct("INVALID: Response is off-topic").is_invalid());
        assert!(matches!(
            parse_validation_direct("Maybe?"),
            ValidationResult::Skipped
        ));
    }
}