1use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use tracing::debug;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct GroundednessConfig {
18 #[serde(default = "default_similarity_threshold")]
20 pub similarity_threshold: f64,
21
22 #[serde(default)]
24 pub use_llm_fallback: bool,
25
26 #[serde(default = "default_confidence_threshold")]
28 pub confidence_threshold: f64,
29
30 #[serde(default = "default_true")]
32 pub auto_extract_claims: bool,
33}
34
/// Fallback for `GroundednessConfig::similarity_threshold` when the key is absent.
fn default_similarity_threshold() -> f64 {
    0.70_f64
}
38
/// Fallback for `GroundednessConfig::confidence_threshold` when the key is absent.
fn default_confidence_threshold() -> f64 {
    0.80_f64
}
42
/// Serde helper that makes a missing boolean field default to `true`.
fn default_true() -> bool {
    bool::from(true)
}
46
47impl Default for GroundednessConfig {
48 fn default() -> Self {
49 Self {
50 similarity_threshold: default_similarity_threshold(),
51 use_llm_fallback: false,
52 confidence_threshold: default_confidence_threshold(),
53 auto_extract_claims: true,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct SourceDocument {
61 pub id: String,
63
64 pub content: String,
66
67 #[serde(skip_serializing_if = "Option::is_none")]
69 pub title: Option<String>,
70
71 #[serde(skip_serializing_if = "Option::is_none")]
73 pub url: Option<String>,
74
75 #[serde(default = "default_relevance")]
77 pub relevance: f64,
78}
79
/// Fallback for `SourceDocument::relevance` when the key is absent.
fn default_relevance() -> f64 {
    1.0_f64
}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct Claim {
87 pub text: String,
89
90 #[serde(skip_serializing_if = "Option::is_none")]
92 pub start: Option<usize>,
93
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub end: Option<usize>,
97
98 #[serde(default)]
100 pub claim_type: ClaimType,
101}
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
105#[serde(rename_all = "snake_case")]
106pub enum ClaimType {
107 #[default]
109 Factual,
110 Opinion,
112 CommonKnowledge,
114 Procedural,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ClaimGroundingResult {
121 pub claim: Claim,
123
124 pub grounded: bool,
126
127 pub confidence: f64,
129
130 #[serde(skip_serializing_if = "Option::is_none")]
132 pub source_id: Option<String>,
133
134 #[serde(skip_serializing_if = "Option::is_none")]
136 pub source_excerpt: Option<String>,
137
138 pub similarity: f64,
140
141 pub method: GroundingMethod,
143
144 pub needs_review: bool,
146}
147
148#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
150#[serde(rename_all = "snake_case")]
151pub enum GroundingMethod {
152 LocalMatch,
154 ExactQuote,
156 SemanticSimilarity,
158 LlmVerification,
160 Skipped,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166pub struct GroundednessResult {
167 pub score: f64,
169
170 pub total_claims: usize,
172
173 pub grounded_claims: usize,
175
176 pub ungrounded_claims: usize,
178
179 pub needs_review_count: usize,
181
182 pub claim_results: Vec<ClaimGroundingResult>,
184
185 pub method: GroundingMethod,
187}
188
189pub struct GroundednessChecker {
191 config: GroundednessConfig,
192}
193
194impl Default for GroundednessChecker {
195 fn default() -> Self {
196 Self::new(GroundednessConfig::default())
197 }
198}
199
200impl GroundednessChecker {
201 pub fn new(config: GroundednessConfig) -> Self {
203 Self { config }
204 }
205
206 pub fn check(
208 &self,
209 response: &str,
210 sources: &[SourceDocument],
211 explicit_claims: Option<Vec<Claim>>,
212 ) -> GroundednessResult {
213 let claims = explicit_claims.unwrap_or_else(|| {
215 if self.config.auto_extract_claims {
216 self.extract_claims(response)
217 } else {
218 vec![Claim {
219 text: response.to_string(),
220 start: None,
221 end: None,
222 claim_type: ClaimType::Factual,
223 }]
224 }
225 });
226
227 if claims.is_empty() {
228 return GroundednessResult {
229 score: 1.0,
230 total_claims: 0,
231 grounded_claims: 0,
232 ungrounded_claims: 0,
233 needs_review_count: 0,
234 claim_results: vec![],
235 method: GroundingMethod::LocalMatch,
236 };
237 }
238
239 let mut claim_results = Vec::with_capacity(claims.len());
241 let mut grounded_count = 0;
242 let mut needs_review_count = 0;
243
244 for claim in claims {
245 let result = self.check_claim(&claim, sources);
246
247 if result.grounded {
248 grounded_count += 1;
249 }
250 if result.needs_review {
251 needs_review_count += 1;
252 }
253
254 claim_results.push(result);
255 }
256
257 let total = claim_results.len();
258 let score = if total > 0 {
259 grounded_count as f64 / total as f64
260 } else {
261 1.0
262 };
263
264 GroundednessResult {
265 score,
266 total_claims: total,
267 grounded_claims: grounded_count,
268 ungrounded_claims: total - grounded_count,
269 needs_review_count,
270 claim_results,
271 method: GroundingMethod::LocalMatch,
272 }
273 }
274
275 fn extract_claims(&self, response: &str) -> Vec<Claim> {
277 let mut claims = Vec::new();
278
279 for sentence in self.split_sentences(response) {
282 let trimmed = sentence.trim();
283 if trimmed.is_empty() {
284 continue;
285 }
286
287 if trimmed.len() < 10 {
289 continue;
290 }
291
292 let claim_type = self.classify_claim(trimmed);
294
295 claims.push(Claim {
296 text: trimmed.to_string(),
297 start: None,
298 end: None,
299 claim_type,
300 });
301 }
302
303 claims
304 }
305
306 fn split_sentences<'a>(&self, text: &'a str) -> Vec<&'a str> {
308 let mut sentences = Vec::new();
311 let mut start = 0;
312
313 for (i, c) in text.char_indices() {
314 if c == '.' || c == '!' || c == '?' {
315 let sentence = &text[start..=i];
316 if !sentence.trim().is_empty() {
317 sentences.push(sentence.trim());
318 }
319 start = i + 1;
320 }
321 }
322
323 if start < text.len() {
325 let remaining = &text[start..];
326 if !remaining.trim().is_empty() {
327 sentences.push(remaining.trim());
328 }
329 }
330
331 sentences
332 }
333
334 fn classify_claim(&self, text: &str) -> ClaimType {
336 let lower = text.to_lowercase();
337
338 let opinion_words = [
340 "i think",
341 "i believe",
342 "in my opinion",
343 "probably",
344 "might",
345 "could be",
346 "seems like",
347 "apparently",
348 ];
349 for word in &opinion_words {
350 if lower.contains(word) {
351 return ClaimType::Opinion;
352 }
353 }
354
355 let procedural_words = [
357 "to do this",
358 "first,",
359 "then,",
360 "finally,",
361 "step ",
362 "you should",
363 "you can",
364 "run the",
365 "execute",
366 ];
367 for word in &procedural_words {
368 if lower.contains(word) {
369 return ClaimType::Procedural;
370 }
371 }
372
373 ClaimType::Factual
374 }
375
376 fn check_claim(&self, claim: &Claim, sources: &[SourceDocument]) -> ClaimGroundingResult {
378 if claim.claim_type != ClaimType::Factual {
380 return ClaimGroundingResult {
381 claim: claim.clone(),
382 grounded: true,
383 confidence: 1.0,
384 source_id: None,
385 source_excerpt: None,
386 similarity: 1.0,
387 method: GroundingMethod::Skipped,
388 needs_review: false,
389 };
390 }
391
392 let claim_text = &claim.text;
393 let claim_lower = claim_text.to_lowercase();
394 let claim_words: HashSet<&str> = claim_lower.split_whitespace().collect();
395
396 let mut best_match: Option<(f64, &SourceDocument, String)> = None;
397
398 for source in sources {
399 let source_lower = source.content.to_lowercase();
400
401 if source_lower.contains(&claim_lower) {
403 return ClaimGroundingResult {
404 claim: claim.clone(),
405 grounded: true,
406 confidence: 1.0,
407 source_id: Some(source.id.clone()),
408 source_excerpt: Some(self.extract_excerpt(&source.content, claim_text)),
409 similarity: 1.0,
410 method: GroundingMethod::ExactQuote,
411 needs_review: false,
412 };
413 }
414
415 let source_words: HashSet<&str> = source_lower.split_whitespace().collect();
417 let intersection = claim_words.intersection(&source_words).count();
418 let union = claim_words.union(&source_words).count();
419
420 let jaccard = if union > 0 {
421 intersection as f64 / union as f64
422 } else {
423 0.0
424 };
425
426 let claim_coverage = if !claim_words.is_empty() {
428 intersection as f64 / claim_words.len() as f64
429 } else {
430 0.0
431 };
432
433 let similarity = (jaccard + claim_coverage) / 2.0;
435
436 if similarity > best_match.as_ref().map(|(s, _, _)| *s).unwrap_or(0.0) {
437 let excerpt = self.find_best_excerpt(&source.content, claim_text);
438 best_match = Some((similarity, source, excerpt));
439 }
440 }
441
442 if let Some((similarity, source, excerpt)) = best_match {
443 let grounded = similarity >= self.config.similarity_threshold;
444 let confidence = similarity;
445 let needs_review = !grounded
446 && similarity >= self.config.similarity_threshold * 0.7
447 && self.config.use_llm_fallback;
448
449 debug!(
450 "Claim grounding: similarity={:.2}, grounded={}, needs_review={}",
451 similarity, grounded, needs_review
452 );
453
454 ClaimGroundingResult {
455 claim: claim.clone(),
456 grounded,
457 confidence,
458 source_id: if grounded || needs_review {
459 Some(source.id.clone())
460 } else {
461 None
462 },
463 source_excerpt: if grounded || needs_review {
464 Some(excerpt)
465 } else {
466 None
467 },
468 similarity,
469 method: GroundingMethod::LocalMatch,
470 needs_review,
471 }
472 } else {
473 ClaimGroundingResult {
474 claim: claim.clone(),
475 grounded: false,
476 confidence: 0.0,
477 source_id: None,
478 source_excerpt: None,
479 similarity: 0.0,
480 method: GroundingMethod::LocalMatch,
481 needs_review: self.config.use_llm_fallback,
482 }
483 }
484 }
485
486 fn extract_excerpt(&self, source: &str, claim: &str) -> String {
488 let source_lower = source.to_lowercase();
489 let claim_lower = claim.to_lowercase();
490
491 if let Some(pos) = source_lower.find(&claim_lower) {
492 let start = pos.saturating_sub(50);
494 let end = (pos + claim.len() + 50).min(source.len());
495
496 let excerpt = &source[start..end];
497 if start > 0 {
498 format!("...{}", excerpt.trim())
499 } else {
500 excerpt.trim().to_string()
501 }
502 } else {
503 source.chars().take(200).collect()
505 }
506 }
507
508 fn find_best_excerpt(&self, source: &str, claim: &str) -> String {
510 let claim_words: Vec<&str> = claim.split_whitespace().take(5).collect();
511
512 let source_lower = source.to_lowercase();
514
515 for word in &claim_words {
516 let word_lower = word.to_lowercase();
517 if let Some(pos) = source_lower.find(&word_lower) {
518 let start = pos.saturating_sub(30);
519 let end = (pos + 150).min(source.len());
520 return format!("...{}...", source[start..end].trim());
521 }
522 }
523
524 source.chars().take(150).collect::<String>() + "..."
526 }
527}
528
529#[derive(Debug, Clone, Serialize, Deserialize)]
533pub struct LlmVerificationRequest {
534 pub claim: String,
536 pub sources: Vec<SourceDocument>,
538 pub system_prompt: String,
540}
541
542#[derive(Debug, Clone, Serialize, Deserialize)]
544pub struct LlmVerificationResponse {
545 pub grounded: bool,
547 pub confidence: f64,
549 pub supporting_source_id: Option<String>,
551 pub explanation: Option<String>,
553}
554
555#[async_trait]
560pub trait LlmGroundednessVerifier: Send + Sync {
561 async fn verify_claim(
563 &self,
564 request: LlmVerificationRequest,
565 ) -> Result<LlmVerificationResponse, String>;
566
567 async fn verify_claims_batch(
569 &self,
570 requests: Vec<LlmVerificationRequest>,
571 ) -> Result<Vec<LlmVerificationResponse>, String> {
572 let mut results = Vec::with_capacity(requests.len());
574 for req in requests {
575 results.push(self.verify_claim(req).await?);
576 }
577 Ok(results)
578 }
579}
580
/// System prompt attached to every `LlmVerificationRequest` built by
/// `LlmGroundednessChecker::check_with_llm`.
///
/// NOTE(review): the prompt asks for a `"source_id"` key, while
/// `LlmVerificationResponse` names its field `supporting_source_id`; make sure
/// replies in this format actually populate that field.
pub const DEFAULT_VERIFICATION_PROMPT: &str = r#"You are a groundedness verification assistant. Your task is to determine if a claim is supported by the provided source documents.

Respond with a JSON object containing:
- "grounded": true/false - whether the claim is supported
- "confidence": 0.0-1.0 - how confident you are
- "source_id": the ID of the supporting source, or null
- "explanation": brief explanation of your reasoning

Be strict: a claim is only grounded if the sources directly support it, not if it's merely plausible."#;
591
/// Stateless, deterministic stand-in for an LLM verifier, intended for tests.
#[derive(Default)]
pub struct MockLlmVerifier;
596#[async_trait]
597impl LlmGroundednessVerifier for MockLlmVerifier {
598 async fn verify_claim(
599 &self,
600 request: LlmVerificationRequest,
601 ) -> Result<LlmVerificationResponse, String> {
602 let claim_lower = request.claim.to_lowercase();
604 let claim_words: std::collections::HashSet<&str> = claim_lower
605 .split_whitespace()
606 .filter(|w| w.len() > 3)
607 .collect();
608
609 for source in &request.sources {
610 let source_lower = source.content.to_lowercase();
611 let matching_words = claim_words
612 .iter()
613 .filter(|w| source_lower.contains(*w))
614 .count();
615 let overlap = matching_words as f64 / claim_words.len().max(1) as f64;
616
617 if overlap > 0.5 {
618 return Ok(LlmVerificationResponse {
619 grounded: true,
620 confidence: overlap,
621 supporting_source_id: Some(source.id.clone()),
622 explanation: Some("Mock: word overlap detected".to_string()),
623 });
624 }
625 }
626
627 Ok(LlmVerificationResponse {
628 grounded: false,
629 confidence: 0.8,
630 supporting_source_id: None,
631 explanation: Some("Mock: no supporting source found".to_string()),
632 })
633 }
634}
635
636pub struct LlmGroundednessChecker<V: LlmGroundednessVerifier> {
638 config: GroundednessConfig,
639 local_checker: GroundednessChecker,
640 llm_verifier: V,
641}
642
643impl<V: LlmGroundednessVerifier> LlmGroundednessChecker<V> {
644 pub fn new(config: GroundednessConfig, llm_verifier: V) -> Self {
646 Self {
647 config: config.clone(),
648 local_checker: GroundednessChecker::new(config),
649 llm_verifier,
650 }
651 }
652
653 pub async fn check_with_llm(
655 &self,
656 response: &str,
657 sources: &[SourceDocument],
658 explicit_claims: Option<Vec<Claim>>,
659 ) -> GroundednessResult {
660 let mut result = self.local_checker.check(response, sources, explicit_claims);
662
663 if !self.config.use_llm_fallback {
665 return result;
666 }
667
668 let uncertain_indices: Vec<usize> = result
670 .claim_results
671 .iter()
672 .enumerate()
673 .filter(|(_, cr)| {
674 cr.method != GroundingMethod::Skipped
676 && cr.confidence < self.config.confidence_threshold
677 })
678 .map(|(i, _)| i)
679 .collect();
680
681 if uncertain_indices.is_empty() {
682 return result;
683 }
684
685 let requests: Vec<LlmVerificationRequest> = uncertain_indices
687 .iter()
688 .map(|&i| LlmVerificationRequest {
689 claim: result.claim_results[i].claim.text.clone(),
690 sources: sources.to_vec(),
691 system_prompt: DEFAULT_VERIFICATION_PROMPT.to_string(),
692 })
693 .collect();
694
695 match self.llm_verifier.verify_claims_batch(requests).await {
697 Ok(responses) => {
698 for (idx_offset, llm_response) in responses.into_iter().enumerate() {
700 let i = uncertain_indices[idx_offset];
701 result.claim_results[i].grounded = llm_response.grounded;
702 result.claim_results[i].confidence = llm_response.confidence;
703 result.claim_results[i].method = GroundingMethod::LlmVerification;
704 result.claim_results[i].source_id = llm_response.supporting_source_id;
705 }
706
707 result.grounded_claims = result
709 .claim_results
710 .iter()
711 .filter(|cr| cr.grounded)
712 .count();
713 result.ungrounded_claims = result
714 .claim_results
715 .iter()
716 .filter(|cr| !cr.grounded && cr.method != GroundingMethod::Skipped)
717 .count();
718 result.needs_review_count = 0;
719 result.score = if result.total_claims > 0 {
720 result.grounded_claims as f64 / result.total_claims as f64
721 } else {
722 1.0
723 };
724 result.method = GroundingMethod::LlmVerification;
726 }
727 Err(e) => {
728 debug!("LLM verification failed: {}", e);
729 }
731 }
732
733 result
734 }
735}
736
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a source document with no title/URL and full relevance.
    fn source(id: &str, content: &str) -> SourceDocument {
        SourceDocument {
            id: id.to_string(),
            content: content.to_string(),
            title: None,
            url: None,
            relevance: 1.0,
        }
    }

    /// Builds a claim without span information.
    fn claim(text: &str, claim_type: ClaimType) -> Claim {
        Claim {
            text: text.to_string(),
            start: None,
            end: None,
            claim_type,
        }
    }

    #[test]
    fn test_exact_match() {
        let text = "The capital of France is Paris.";
        let docs = [source("doc1", "The capital of France is Paris.")];

        let result = GroundednessChecker::default().check(
            text,
            &docs,
            Some(vec![claim(text, ClaimType::Factual)]),
        );

        assert_eq!(result.score, 1.0);
        assert_eq!(result.grounded_claims, 1);
    }

    #[test]
    fn test_partial_match() {
        let config = GroundednessConfig {
            similarity_threshold: 0.5,
            ..Default::default()
        };
        let text = "Paris is the capital of France.";
        let docs = [source(
            "doc1",
            "Paris is the capital city of France and has many monuments.",
        )];

        let result = GroundednessChecker::new(config).check(
            text,
            &docs,
            Some(vec![claim(text, ClaimType::Factual)]),
        );

        assert!(result.score > 0.0);
    }

    #[test]
    fn test_no_match() {
        let text = "The capital of France is Paris.";
        let docs = [source("doc1", "The weather today is sunny.")];

        let result = GroundednessChecker::default().check(
            text,
            &docs,
            Some(vec![claim(text, ClaimType::Factual)]),
        );

        assert_eq!(result.score, 0.0);
        assert_eq!(result.ungrounded_claims, 1);
    }

    #[test]
    fn test_opinion_skipped() {
        let text = "I think this is a good idea.";
        let docs: [SourceDocument; 0] = [];

        let result = GroundednessChecker::default().check(
            text,
            &docs,
            Some(vec![claim(text, ClaimType::Opinion)]),
        );

        assert_eq!(result.score, 1.0);
        assert_eq!(result.claim_results[0].method, GroundingMethod::Skipped);
    }

    #[test]
    fn test_auto_extract_claims() {
        let docs = [source(
            "doc1",
            "Python is a programming language. It was created by Guido van Rossum.",
        )];

        let result = GroundednessChecker::default().check(
            "Python is a programming language. It is very popular.",
            &docs,
            None,
        );

        assert!(result.total_claims >= 1);
    }
}