Skip to main content

graphrag_core/evaluation/
pipeline_validation.rs

//! Pipeline validation framework
//!
//! This module provides tools to validate each phase of the GraphRAG pipeline,
//! ensuring that every step produces the expected outputs before the next one runs.
5
6use crate::{Document, TextChunk, Entity, Relationship};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Validation result for a pipeline phase
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PhaseValidation {
13    /// Phase name
14    pub phase_name: String,
15    /// Whether the phase passed validation
16    pub passed: bool,
17    /// Validation checks performed
18    pub checks: Vec<ValidationCheck>,
19    /// Warnings (non-fatal issues)
20    pub warnings: Vec<String>,
21    /// Metrics collected during validation
22    pub metrics: HashMap<String, f64>,
23}
24
25/// A single validation check
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ValidationCheck {
28    /// Name of the check
29    pub name: String,
30    /// Whether the check passed
31    pub passed: bool,
32    /// Expected value or condition
33    pub expected: String,
34    /// Actual value observed
35    pub actual: String,
36    /// Detailed message
37    pub message: String,
38}
39
40/// Validator for document processing phase
41pub struct DocumentProcessingValidator;
42
43impl DocumentProcessingValidator {
44    /// Validate document processing results
45    pub fn validate(document: &Document, chunks: &[TextChunk]) -> PhaseValidation {
46        let mut checks = Vec::new();
47        let mut warnings = Vec::new();
48        let mut metrics = HashMap::new();
49
50        // Check 1: Document is not empty
51        checks.push(ValidationCheck {
52            name: "document_not_empty".to_string(),
53            passed: !document.content.is_empty(),
54            expected: "Non-empty content".to_string(),
55            actual: format!("{} characters", document.content.len()),
56            message: if document.content.is_empty() {
57                "Document content is empty".to_string()
58            } else {
59                "Document contains content".to_string()
60            },
61        });
62
63        // Check 2: Chunks were created
64        checks.push(ValidationCheck {
65            name: "chunks_created".to_string(),
66            passed: !chunks.is_empty(),
67            expected: "At least 1 chunk".to_string(),
68            actual: format!("{} chunks", chunks.len()),
69            message: if chunks.is_empty() {
70                "No chunks were created from document".to_string()
71            } else {
72                format!("Successfully created {} chunks", chunks.len())
73            },
74        });
75
76        // Check 3: Chunks cover document content
77        if !chunks.is_empty() {
78            let total_chunk_chars: usize = chunks.iter().map(|c| c.content.len()).sum();
79            let coverage_ratio = total_chunk_chars as f64 / document.content.len() as f64;
80
81            checks.push(ValidationCheck {
82                name: "content_coverage".to_string(),
83                passed: coverage_ratio >= 0.9, // At least 90% coverage
84                expected: "Coverage ratio >= 0.9".to_string(),
85                actual: format!("{:.2}", coverage_ratio),
86                message: format!("Chunks cover {:.1}% of original content", coverage_ratio * 100.0),
87            });
88
89            metrics.insert("coverage_ratio".to_string(), coverage_ratio);
90        }
91
92        // Check 4: No chunk is empty
93        let empty_chunks = chunks.iter().filter(|c| c.content.trim().is_empty()).count();
94        checks.push(ValidationCheck {
95            name: "no_empty_chunks".to_string(),
96            passed: empty_chunks == 0,
97            expected: "0 empty chunks".to_string(),
98            actual: format!("{} empty chunks", empty_chunks),
99            message: if empty_chunks > 0 {
100                format!("Found {} empty chunks", empty_chunks)
101            } else {
102                "All chunks have content".to_string()
103            },
104        });
105
106        // Check 5: Chunk metadata is populated
107        let chunks_with_metadata = chunks
108            .iter()
109            .filter(|c|
110                c.metadata.chapter.is_some() ||
111                !c.metadata.keywords.is_empty() ||
112                c.metadata.summary.is_some()
113            )
114            .count();
115
116        let metadata_ratio = if chunks.is_empty() {
117            0.0
118        } else {
119            chunks_with_metadata as f64 / chunks.len() as f64
120        };
121
122        if metadata_ratio < 0.5 {
123            warnings.push(format!(
124                "Only {}/{} chunks have enriched metadata ({}%)",
125                chunks_with_metadata, chunks.len(), (metadata_ratio * 100.0) as u32
126            ));
127        }
128
129        checks.push(ValidationCheck {
130            name: "metadata_enrichment".to_string(),
131            passed: true, // Metadata enrichment is optional - always pass but collect metrics
132            expected: "Metadata enrichment (optional)".to_string(),
133            actual: format!("{}/{} chunks", chunks_with_metadata, chunks.len()),
134            message: format!("{:.1}% of chunks have metadata", metadata_ratio * 100.0),
135        });
136
137        metrics.insert("metadata_ratio".to_string(), metadata_ratio);
138        metrics.insert("chunks_count".to_string(), chunks.len() as f64);
139        metrics.insert("avg_chunk_size".to_string(),
140            chunks.iter().map(|c| c.content.len()).sum::<usize>() as f64 / chunks.len().max(1) as f64
141        );
142
143        let passed = checks.iter().all(|c| c.passed);
144
145        PhaseValidation {
146            phase_name: "Document Processing".to_string(),
147            passed,
148            checks,
149            warnings,
150            metrics,
151        }
152    }
153}
154
155/// Validator for entity extraction phase
156pub struct EntityExtractionValidator;
157
158impl EntityExtractionValidator {
159    /// Validate entity extraction results
160    pub fn validate(chunks: &[TextChunk], entities: &[Entity]) -> PhaseValidation {
161        let mut checks = Vec::new();
162        let mut warnings = Vec::new();
163        let mut metrics = HashMap::new();
164
165        // Check 1: Entities were extracted
166        checks.push(ValidationCheck {
167            name: "entities_extracted".to_string(),
168            passed: !entities.is_empty(),
169            expected: "At least 1 entity".to_string(),
170            actual: format!("{} entities", entities.len()),
171            message: if entities.is_empty() {
172                "No entities were extracted".to_string()
173            } else {
174                format!("Successfully extracted {} entities", entities.len())
175            },
176        });
177
178        // Check 2: Entity confidence scores are valid
179        let invalid_confidence = entities
180            .iter()
181            .filter(|e| e.confidence < 0.0 || e.confidence > 1.0)
182            .count();
183
184        checks.push(ValidationCheck {
185            name: "confidence_scores_valid".to_string(),
186            passed: invalid_confidence == 0,
187            expected: "All confidences in [0.0, 1.0]".to_string(),
188            actual: format!("{} invalid scores", invalid_confidence),
189            message: if invalid_confidence > 0 {
190                format!("{} entities have invalid confidence scores", invalid_confidence)
191            } else {
192                "All confidence scores are valid".to_string()
193            },
194        });
195
196        // Check 3: Entity types are populated
197        let missing_types = entities.iter().filter(|e| e.entity_type.is_empty()).count();
198        checks.push(ValidationCheck {
199            name: "entity_types_populated".to_string(),
200            passed: missing_types == 0,
201            expected: "All entities have types".to_string(),
202            actual: format!("{} without types", missing_types),
203            message: if missing_types > 0 {
204                format!("{} entities missing entity_type", missing_types)
205            } else {
206                "All entities have types assigned".to_string()
207            },
208        });
209
210        // Check 4: Entity names are not empty
211        let empty_names = entities.iter().filter(|e| e.name.trim().is_empty()).count();
212        checks.push(ValidationCheck {
213            name: "entity_names_valid".to_string(),
214            passed: empty_names == 0,
215            expected: "All entities have names".to_string(),
216            actual: format!("{} empty names", empty_names),
217            message: if empty_names > 0 {
218                format!("{} entities have empty names", empty_names)
219            } else {
220                "All entities have valid names".to_string()
221            },
222        });
223
224        // Check 5: Entity mentions reference valid chunks
225        if !entities.is_empty() {
226            let chunk_ids: Vec<_> = chunks.iter().map(|c| &c.id).collect();
227            let invalid_mentions = entities
228                .iter()
229                .flat_map(|e| &e.mentions)
230                .filter(|m| !chunk_ids.contains(&&m.chunk_id))
231                .count();
232
233            checks.push(ValidationCheck {
234                name: "entity_mentions_valid".to_string(),
235                passed: invalid_mentions == 0,
236                expected: "All mentions reference valid chunks".to_string(),
237                actual: format!("{} invalid references", invalid_mentions),
238                message: if invalid_mentions > 0 {
239                    format!("{} entity mentions reference non-existent chunks", invalid_mentions)
240                } else {
241                    "All entity mentions are valid".to_string()
242                },
243            });
244
245            if invalid_mentions > 0 {
246                warnings.push("Some entity mentions reference non-existent chunks".to_string());
247            }
248        }
249
250        // Metrics
251        metrics.insert("entities_count".to_string(), entities.len() as f64);
252        if !entities.is_empty() {
253            metrics.insert("avg_confidence".to_string(),
254                entities.iter().map(|e| e.confidence as f64).sum::<f64>() / entities.len() as f64
255            );
256            metrics.insert("avg_mentions_per_entity".to_string(),
257                entities.iter().map(|e| e.mentions.len()).sum::<usize>() as f64 / entities.len() as f64
258            );
259        }
260
261        // Warning: Low average confidence
262        if let Some(&avg_conf) = metrics.get("avg_confidence") {
263            if avg_conf < 0.5 {
264                warnings.push(format!("Low average entity confidence: {:.2}", avg_conf));
265            }
266        }
267
268        let passed = checks.iter().all(|c| c.passed);
269
270        PhaseValidation {
271            phase_name: "Entity Extraction".to_string(),
272            passed,
273            checks,
274            warnings,
275            metrics,
276        }
277    }
278}
279
280/// Validator for relationship extraction phase
281pub struct RelationshipExtractionValidator;
282
283impl RelationshipExtractionValidator {
284    /// Validate relationship extraction results
285    pub fn validate(entities: &[Entity], relationships: &[Relationship]) -> PhaseValidation {
286        let mut checks = Vec::new();
287        let mut warnings = Vec::new();
288        let mut metrics = HashMap::new();
289
290        // Check 1: Relationships were extracted (if entities exist)
291        if !entities.is_empty() {
292            let has_relationships = !relationships.is_empty();
293            checks.push(ValidationCheck {
294                name: "relationships_extracted".to_string(),
295                passed: has_relationships,
296                expected: "At least 1 relationship".to_string(),
297                actual: format!("{} relationships", relationships.len()),
298                message: if !has_relationships {
299                    "No relationships extracted despite entities present".to_string()
300                } else {
301                    format!("Extracted {} relationships", relationships.len())
302                },
303            });
304
305            if !has_relationships {
306                warnings.push("No relationships found between entities".to_string());
307            }
308        }
309
310        // Check 2: Relationship confidence scores are valid
311        let invalid_confidence = relationships
312            .iter()
313            .filter(|r| r.confidence < 0.0 || r.confidence > 1.0)
314            .count();
315
316        checks.push(ValidationCheck {
317            name: "relationship_confidence_valid".to_string(),
318            passed: invalid_confidence == 0,
319            expected: "All confidences in [0.0, 1.0]".to_string(),
320            actual: format!("{} invalid", invalid_confidence),
321            message: if invalid_confidence > 0 {
322                format!("{} relationships have invalid confidence", invalid_confidence)
323            } else {
324                "All relationship confidences valid".to_string()
325            },
326        });
327
328        // Check 3: Relationship types are populated
329        let missing_types = relationships.iter().filter(|r| r.relation_type.is_empty()).count();
330        checks.push(ValidationCheck {
331            name: "relationship_types_populated".to_string(),
332            passed: missing_types == 0,
333            expected: "All relationships typed".to_string(),
334            actual: format!("{} untyped", missing_types),
335            message: if missing_types > 0 {
336                format!("{} relationships missing type", missing_types)
337            } else {
338                "All relationships have types".to_string()
339            },
340        });
341
342        // Check 4: Source and target entities exist
343        let entity_ids: Vec<_> = entities.iter().map(|e| &e.id).collect();
344        let orphan_relationships = relationships
345            .iter()
346            .filter(|r| !entity_ids.contains(&&r.source) || !entity_ids.contains(&&r.target))
347            .count();
348
349        checks.push(ValidationCheck {
350            name: "relationship_entities_exist".to_string(),
351            passed: orphan_relationships == 0,
352            expected: "All relationships reference valid entities".to_string(),
353            actual: format!("{} orphaned", orphan_relationships),
354            message: if orphan_relationships > 0 {
355                format!("{} relationships reference non-existent entities", orphan_relationships)
356            } else {
357                "All relationships have valid entity references".to_string()
358            },
359        });
360
361        if orphan_relationships > 0 {
362            warnings.push("Some relationships reference entities that don't exist in the graph".to_string());
363        }
364
365        // Metrics
366        metrics.insert("relationships_count".to_string(), relationships.len() as f64);
367        if !entities.is_empty() {
368            metrics.insert("relationships_per_entity".to_string(),
369                relationships.len() as f64 / entities.len() as f64
370            );
371        }
372        if !relationships.is_empty() {
373            metrics.insert("avg_relationship_confidence".to_string(),
374                relationships.iter().map(|r| r.confidence as f64).sum::<f64>() / relationships.len() as f64
375            );
376        }
377
378        let passed = checks.iter().all(|c| c.passed);
379
380        PhaseValidation {
381            phase_name: "Relationship Extraction".to_string(),
382            passed,
383            checks,
384            warnings,
385            metrics,
386        }
387    }
388}
389
390/// Validator for graph construction phase
391pub struct GraphConstructionValidator;
392
393impl GraphConstructionValidator {
394    /// Validate constructed knowledge graph
395    pub fn validate(
396        documents: usize,
397        chunks: usize,
398        entities: usize,
399        relationships: usize,
400    ) -> PhaseValidation {
401        let mut checks = Vec::new();
402        let mut warnings = Vec::new();
403        let mut metrics = HashMap::new();
404
405        // Check 1: Graph has content
406        checks.push(ValidationCheck {
407            name: "graph_not_empty".to_string(),
408            passed: entities > 0 || documents > 0,
409            expected: "At least some nodes".to_string(),
410            actual: format!("{} entities, {} docs", entities, documents),
411            message: if entities == 0 && documents == 0 {
412                "Graph is completely empty".to_string()
413            } else {
414                "Graph contains content".to_string()
415            },
416        });
417
418        // Check 2: Reasonable entity-to-chunk ratio
419        if chunks > 0 {
420            let entities_per_chunk = entities as f64 / chunks as f64;
421            let reasonable = entities_per_chunk >= 0.1 && entities_per_chunk <= 10.0;
422
423            checks.push(ValidationCheck {
424                name: "entity_chunk_ratio_reasonable".to_string(),
425                passed: reasonable,
426                expected: "0.1 to 10 entities per chunk".to_string(),
427                actual: format!("{:.2} entities/chunk", entities_per_chunk),
428                message: if !reasonable {
429                    format!("Unusual entity-to-chunk ratio: {:.2}", entities_per_chunk)
430                } else {
431                    "Entity density looks reasonable".to_string()
432                },
433            });
434
435            metrics.insert("entities_per_chunk".to_string(), entities_per_chunk);
436
437            if entities_per_chunk < 0.5 {
438                warnings.push("Low entity density - may need better entity extraction".to_string());
439            }
440            if entities_per_chunk > 5.0 {
441                warnings.push("High entity density - may have duplicate extractions".to_string());
442            }
443        }
444
445        // Check 3: Graph connectivity
446        if entities > 1 {
447            let connectivity = relationships as f64 / entities as f64;
448            let is_connected = connectivity > 0.1; // At least 10% connectivity
449
450            checks.push(ValidationCheck {
451                name: "graph_connectivity".to_string(),
452                passed: is_connected,
453                expected: ">0.1 relationships per entity".to_string(),
454                actual: format!("{:.2} rels/entity", connectivity),
455                message: if !is_connected {
456                    "Graph is sparsely connected".to_string()
457                } else {
458                    "Graph has reasonable connectivity".to_string()
459                },
460            });
461
462            metrics.insert("connectivity".to_string(), connectivity);
463
464            if connectivity < 0.5 {
465                warnings.push("Graph is sparsely connected - entities may be isolated".to_string());
466            }
467        }
468
469        // Metrics
470        metrics.insert("documents".to_string(), documents as f64);
471        metrics.insert("chunks".to_string(), chunks as f64);
472        metrics.insert("entities".to_string(), entities as f64);
473        metrics.insert("relationships".to_string(), relationships as f64);
474
475        let passed = checks.iter().all(|c| c.passed);
476
477        PhaseValidation {
478            phase_name: "Graph Construction".to_string(),
479            passed,
480            checks,
481            warnings,
482            metrics,
483        }
484    }
485}
486
487/// Complete pipeline validation report
488#[derive(Debug, Clone, Serialize, Deserialize)]
489pub struct PipelineValidationReport {
490    /// Validation results for each phase
491    pub phases: Vec<PhaseValidation>,
492    /// Overall validation status
493    pub overall_passed: bool,
494    /// Total checks performed
495    pub total_checks: usize,
496    /// Number of passed checks
497    pub passed_checks: usize,
498    /// Summary message
499    pub summary: String,
500}
501
502impl PipelineValidationReport {
503    /// Create a report from phase validations
504    pub fn from_phases(phases: Vec<PhaseValidation>) -> Self {
505        let overall_passed = phases.iter().all(|p| p.passed);
506        let total_checks = phases.iter().map(|p| p.checks.len()).sum();
507        let passed_checks = phases
508            .iter()
509            .flat_map(|p| &p.checks)
510            .filter(|c| c.passed)
511            .count();
512
513        let summary = if overall_passed {
514            format!("✅ All pipeline phases validated successfully ({}/{} checks passed)",
515                passed_checks, total_checks)
516        } else {
517            let failed_phases: Vec<_> = phases
518                .iter()
519                .filter(|p| !p.passed)
520                .map(|p| p.phase_name.as_str())
521                .collect();
522            format!("❌ Pipeline validation failed in: {} ({}/{} checks passed)",
523                failed_phases.join(", "), passed_checks, total_checks)
524        };
525
526        Self {
527            phases,
528            overall_passed,
529            total_checks,
530            passed_checks,
531            summary,
532        }
533    }
534
535    /// Generate a detailed report string
536    pub fn detailed_report(&self) -> String {
537        let mut report = String::new();
538        report.push_str(&format!("# Pipeline Validation Report\n\n"));
539        report.push_str(&format!("{}\n\n", self.summary));
540        report.push_str(&format!("**Total Checks**: {}/{} passed\n\n",
541            self.passed_checks, self.total_checks));
542
543        for phase in &self.phases {
544            report.push_str(&format!("## Phase: {}\n", phase.phase_name));
545            report.push_str(&format!("**Status**: {}\n\n",
546                if phase.passed { "✅ PASSED" } else { "❌ FAILED" }));
547
548            // Checks
549            report.push_str("### Checks\n");
550            for check in &phase.checks {
551                let icon = if check.passed { "✅" } else { "❌" };
552                report.push_str(&format!("{} **{}**: {}\n", icon, check.name, check.message));
553                report.push_str(&format!("   - Expected: {}\n", check.expected));
554                report.push_str(&format!("   - Actual: {}\n\n", check.actual));
555            }
556
557            // Warnings
558            if !phase.warnings.is_empty() {
559                report.push_str("### Warnings\n");
560                for warning in &phase.warnings {
561                    report.push_str(&format!("⚠️  {}\n", warning));
562                }
563                report.push_str("\n");
564            }
565
566            // Metrics
567            if !phase.metrics.is_empty() {
568                report.push_str("### Metrics\n");
569                for (key, value) in &phase.metrics {
570                    report.push_str(&format!("- {}: {:.2}\n", key, value));
571                }
572                report.push_str("\n");
573            }
574
575            report.push_str("---\n\n");
576        }
577
578        report
579    }
580
581    /// Get all warnings across all phases
582    pub fn all_warnings(&self) -> Vec<String> {
583        self.phases
584            .iter()
585            .flat_map(|p| p.warnings.clone())
586            .collect()
587    }
588
589    /// Get failed phases
590    pub fn failed_phases(&self) -> Vec<&PhaseValidation> {
591        self.phases.iter().filter(|p| !p.passed).collect()
592    }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DocumentId, ChunkId, EntityId};

    /// Build a minimal check with the given name and outcome.
    fn check(name: &str, passed: bool) -> ValidationCheck {
        ValidationCheck {
            name: name.to_string(),
            passed,
            expected: "pass".to_string(),
            actual: if passed { "pass" } else { "fail" }.to_string(),
            message: if passed { "OK" } else { "Failed" }.to_string(),
        }
    }

    /// Build a phase holding a single check; the phase passes iff the check does.
    fn phase(name: &str, passed: bool, warnings: &[&str]) -> PhaseValidation {
        PhaseValidation {
            phase_name: name.to_string(),
            passed,
            checks: vec![check(&format!("{}_check", name), passed)],
            warnings: warnings.iter().map(|w| w.to_string()).collect(),
            metrics: HashMap::new(),
        }
    }

    #[test]
    fn test_document_processing_validation() {
        let doc = Document::new(
            DocumentId::new("test".to_string()),
            "Test".to_string(),
            "This is test content with multiple words.".to_string(),
        );

        let chunks = vec![
            TextChunk::new(
                ChunkId::new("c1".to_string()),
                doc.id.clone(),
                "This is test".to_string(),
                0,
                12,
            ),
            TextChunk::new(
                ChunkId::new("c2".to_string()),
                doc.id.clone(),
                "content with multiple words.".to_string(),
                13,
                41,
            ),
        ];

        let validation = DocumentProcessingValidator::validate(&doc, &chunks);
        assert!(validation.passed);
        assert!(validation.checks.iter().all(|c| c.passed));
    }

    #[test]
    fn test_entity_extraction_validation() {
        let chunks = vec![
            TextChunk::new(
                ChunkId::new("c1".to_string()),
                DocumentId::new("test".to_string()),
                "Alice works at Stanford".to_string(),
                0,
                23,
            ),
        ];

        let entities = vec![
            Entity {
                id: EntityId::new("e1".to_string()),
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                confidence: 0.9,
                mentions: vec![],
                embedding: None,
            },
        ];

        let validation = EntityExtractionValidator::validate(&chunks, &entities);
        assert!(validation.passed);
    }

    #[test]
    fn test_graph_construction_validation_healthy() {
        // 20 entities over 10 chunks (2.0/chunk), 15 relationships (0.75/entity).
        let validation = GraphConstructionValidator::validate(1, 10, 20, 15);
        assert!(validation.passed);
        assert_eq!(validation.checks.len(), 3);
        assert!(validation.warnings.is_empty());
        assert_eq!(validation.metrics.get("entities_per_chunk"), Some(&2.0));
        assert_eq!(validation.metrics.get("connectivity"), Some(&0.75));
        assert_eq!(validation.metrics.get("documents"), Some(&1.0));
    }

    #[test]
    fn test_graph_construction_validation_empty_and_sparse() {
        // An empty graph only runs the non-empty check, which fails.
        let empty = GraphConstructionValidator::validate(0, 0, 0, 0);
        assert!(!empty.passed);
        assert_eq!(empty.checks.len(), 1);

        // 0.2 entities/chunk is within range but low; zero relationships fails
        // connectivity. Both conditions warn.
        let sparse = GraphConstructionValidator::validate(1, 10, 2, 0);
        assert!(!sparse.passed);
        assert_eq!(sparse.warnings.len(), 2);
    }

    #[test]
    fn test_pipeline_report() {
        let doc_validation = PhaseValidation {
            phase_name: "Test Phase".to_string(),
            passed: true,
            checks: vec![
                ValidationCheck {
                    name: "test_check".to_string(),
                    passed: true,
                    expected: "pass".to_string(),
                    actual: "pass".to_string(),
                    message: "OK".to_string(),
                },
            ],
            warnings: vec![],
            metrics: HashMap::new(),
        };

        let report = PipelineValidationReport::from_phases(vec![doc_validation]);
        assert!(report.overall_passed);
        assert_eq!(report.total_checks, 1);
        assert_eq!(report.passed_checks, 1);
    }

    #[test]
    fn test_pipeline_report_with_failure() {
        let report = PipelineValidationReport::from_phases(vec![
            phase("Good", true, &["w1"]),
            phase("Bad", false, &["w2"]),
        ]);

        assert!(!report.overall_passed);
        assert_eq!(report.total_checks, 2);
        assert_eq!(report.passed_checks, 1);
        assert!(report.summary.contains("Bad"));

        let failed = report.failed_phases();
        assert_eq!(failed.len(), 1);
        assert_eq!(failed[0].phase_name, "Bad");

        assert_eq!(report.all_warnings(), vec!["w1".to_string(), "w2".to_string()]);
        assert!(report.detailed_report().contains("❌ FAILED"));
    }
}