//! Configuration validation for `graphrag_core` (config/validation.rs).

1use crate::config::{Config, SetConfig};
2use crate::{GraphRAGError, Result};
3use std::path::Path;
4
/// Result of configuration validation.
///
/// A freshly-created result is considered valid; calling
/// [`add_error`](Self::add_error) flips `is_valid` to `false`, while
/// warnings and suggestions never affect validity.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether the configuration is valid (i.e. no errors were recorded)
    pub is_valid: bool,
    /// List of validation errors
    pub errors: Vec<String>,
    /// List of validation warnings
    pub warnings: Vec<String>,
    /// List of optimization suggestions
    pub suggestions: Vec<String>,
}

impl Default for ValidationResult {
    /// An empty result with no errors is valid.
    ///
    /// NOTE: the previous `#[derive(Default)]` initialized `is_valid` to
    /// `false`, so a configuration that produced zero errors was still
    /// reported as invalid (the validators only ever set it to `false`).
    fn default() -> Self {
        Self {
            is_valid: true,
            errors: Vec::new(),
            warnings: Vec::new(),
            suggestions: Vec::new(),
        }
    }
}

impl ValidationResult {
    /// Create a new (valid, empty) validation result
    pub fn new() -> Self {
        Self::default()
    }

    /// Add an error and mark validation as failed
    pub fn add_error(&mut self, error: String) {
        self.errors.push(error);
        self.is_valid = false;
    }

    /// Add a warning (doesn't affect validity)
    pub fn add_warning(&mut self, warning: String) {
        self.warnings.push(warning);
    }

    /// Add an optimization suggestion
    pub fn add_suggestion(&mut self, suggestion: String) {
        self.suggestions.push(suggestion);
    }
}
40
/// Trait for configuration validation.
///
/// Implementors collect all findings into a single [`ValidationResult`]
/// instead of failing fast, so every error, warning, and suggestion is
/// reported in one pass.
pub trait Validatable {
    /// Validate configuration with standard checks
    fn validate(&self) -> ValidationResult;
    /// Validate configuration with strict checks (includes warnings and suggestions)
    fn validate_strict(&self) -> ValidationResult;
}
48
49impl Validatable for Config {
50    fn validate(&self) -> ValidationResult {
51        let mut result = ValidationResult::new();
52
53        // Validate output directory
54        if self.output_dir.is_empty() {
55            result.add_error("Output directory cannot be empty".to_string());
56        }
57
58        // Validate chunk size
59        if self.chunk_size == 0 {
60            result.add_error("Chunk size must be greater than 0".to_string());
61        } else if self.chunk_size < 100 {
62            result.add_warning(
63                "Chunk size is very small (<100), this may affect performance".to_string(),
64            );
65        } else if self.chunk_size > 10000 {
66            result.add_warning(
67                "Chunk size is very large (>10000), this may affect quality".to_string(),
68            );
69        } else {
70            // Chunk size is in acceptable range (100-10000)
71        }
72
73        // Validate chunk overlap
74        if self.chunk_overlap >= self.chunk_size {
75            result.add_error("Chunk overlap must be less than chunk size".to_string());
76        } else if self.chunk_overlap > self.chunk_size / 2 {
77            result.add_warning(
78                "Chunk overlap is more than 50% of chunk size, this may be inefficient".to_string(),
79            );
80        } else {
81            // Chunk overlap is in acceptable range
82        }
83
84        // Validate entity extraction settings
85        if let Some(max_entities) = self.max_entities_per_chunk {
86            if max_entities == 0 {
87                result.add_error("Max entities per chunk must be greater than 0".to_string());
88            } else if max_entities > 100 {
89                result.add_warning("Max entities per chunk is very high (>100)".to_string());
90            } else {
91                // Max entities is in acceptable range
92            }
93        }
94
95        // Validate retrieval settings
96        if let Some(top_k) = self.top_k_results {
97            if top_k == 0 {
98                result.add_error("Top-k results must be greater than 0".to_string());
99            } else if top_k > 100 {
100                result.add_warning(
101                    "Top-k results is very high (>100), this may affect performance".to_string(),
102                );
103            } else {
104                // Top-k is in acceptable range
105            }
106        }
107
108        // Validate similarity threshold
109        if let Some(threshold) = self.similarity_threshold {
110            if !(0.0..=1.0).contains(&threshold) {
111                result.add_error("Similarity threshold must be between 0.0 and 1.0".to_string());
112            } else if threshold < 0.1 {
113                result.add_warning(
114                    "Similarity threshold is very low (<0.1), this may return irrelevant results"
115                        .to_string(),
116                );
117            } else if threshold > 0.9 {
118                result.add_warning(
119                    "Similarity threshold is very high (>0.9), this may return too few results"
120                        .to_string(),
121                );
122            } else {
123                // Similarity threshold is in acceptable range (0.1-0.9)
124            }
125        }
126
127        // Add suggestions based on configuration
128        if self.chunk_size > 1000 && self.chunk_overlap < 100 {
129            result.add_suggestion("Consider increasing chunk overlap for better context preservation with large chunks".to_string());
130        }
131
132        result
133    }
134
135    fn validate_strict(&self) -> ValidationResult {
136        let mut result = self.validate();
137
138        // Additional strict validations
139
140        // Ensure all paths exist
141        let output_path = Path::new(&self.output_dir);
142        if !output_path.exists() {
143            result.add_warning(format!(
144                "Output directory does not exist: {}",
145                self.output_dir
146            ));
147            result.add_suggestion("Directory will be created automatically".to_string());
148        }
149
150        // Validate feature compatibility
151        #[cfg(not(feature = "ollama"))]
152        {
153            result.add_warning(
154                "Ollama feature is not enabled, local LLM support unavailable".to_string(),
155            );
156        }
157
158        #[cfg(not(feature = "parallel-processing"))]
159        {
160            result.add_warning(
161                "Parallel processing is not enabled, performance may be reduced".to_string(),
162            );
163        }
164
165        // Check for optimal settings
166        let optimal_chunk_size = 800;
167        let optimal_overlap = 200;
168
169        if (self.chunk_size as i32 - optimal_chunk_size).abs() > 300 {
170            result.add_suggestion(format!(
171                "Consider using chunk size around {} for optimal performance",
172                optimal_chunk_size
173            ));
174        }
175
176        if (self.chunk_overlap as i32 - optimal_overlap).abs() > 100 {
177            result.add_suggestion(format!(
178                "Consider using chunk overlap around {} for optimal context preservation",
179                optimal_overlap
180            ));
181        }
182
183        result
184    }
185}
186
187/// Validate pipeline approach configuration (semantic/algorithmic/hybrid)
188fn validate_pipeline_approach(config: &SetConfig, result: &mut ValidationResult) {
189    let approach = &config.mode.approach;
190
191    // Validate approach value
192    match approach.as_str() {
193        "semantic" | "algorithmic" | "hybrid" => {},
194        invalid => {
195            result.add_error(format!(
196                "Invalid pipeline approach: '{}'. Must be 'semantic', 'algorithmic', or 'hybrid'",
197                invalid
198            ));
199            return;
200        },
201    }
202
203    // Validate semantic pipeline
204    if approach == "semantic" {
205        match &config.semantic {
206            None => {
207                result.add_error(
208                    "Semantic pipeline approach selected but [semantic] configuration is missing"
209                        .to_string(),
210                );
211            },
212            Some(semantic) => {
213                if !semantic.enabled {
214                    result.add_error(
215                        "Semantic pipeline approach selected but semantic.enabled = false"
216                            .to_string(),
217                    );
218                }
219
220                // Validate semantic embeddings
221                let valid_backends = [
222                    "huggingface",
223                    "openai",
224                    "voyage",
225                    "cohere",
226                    "jina",
227                    "mistral",
228                    "together",
229                    "ollama",
230                ];
231                if !valid_backends.contains(&semantic.embeddings.backend.as_str()) {
232                    result.add_error(format!(
233                        "Invalid semantic embedding backend: '{}'. Must be one of: {}",
234                        semantic.embeddings.backend,
235                        valid_backends.join(", ")
236                    ));
237                }
238
239                if semantic.embeddings.dimension == 0 {
240                    result.add_error(
241                        "Semantic embedding dimension must be greater than 0".to_string(),
242                    );
243                }
244
245                // Validate semantic entity extraction
246                if semantic.entity_extraction.confidence_threshold < 0.0
247                    || semantic.entity_extraction.confidence_threshold > 1.0
248                {
249                    result.add_error("Semantic entity extraction confidence threshold must be between 0.0 and 1.0".to_string());
250                }
251
252                if semantic.entity_extraction.temperature < 0.0
253                    || semantic.entity_extraction.temperature > 2.0
254                {
255                    result.add_error(
256                        "Semantic entity extraction temperature must be between 0.0 and 2.0"
257                            .to_string(),
258                    );
259                }
260
261                // Validate semantic retrieval
262                if semantic.retrieval.similarity_threshold < 0.0
263                    || semantic.retrieval.similarity_threshold > 1.0
264                {
265                    result.add_error(
266                        "Semantic retrieval similarity threshold must be between 0.0 and 1.0"
267                            .to_string(),
268                    );
269                }
270
271                if semantic.retrieval.top_k == 0 {
272                    result.add_error("Semantic retrieval top_k must be greater than 0".to_string());
273                }
274            },
275        }
276    }
277
278    // Validate algorithmic pipeline
279    if approach == "algorithmic" {
280        match &config.algorithmic {
281            None => {
282                result.add_error("Algorithmic pipeline approach selected but [algorithmic] configuration is missing".to_string());
283            },
284            Some(algorithmic) => {
285                if !algorithmic.enabled {
286                    result.add_error(
287                        "Algorithmic pipeline approach selected but algorithmic.enabled = false"
288                            .to_string(),
289                    );
290                }
291
292                // Validate algorithmic embeddings
293                if algorithmic.embeddings.backend != "hash" {
294                    result.add_warning(format!(
295                        "Algorithmic pipeline typically uses 'hash' backend, but '{}' is configured",
296                        algorithmic.embeddings.backend
297                    ));
298                }
299
300                if algorithmic.embeddings.dimension == 0 {
301                    result.add_error(
302                        "Algorithmic embedding dimension must be greater than 0".to_string(),
303                    );
304                }
305
306                if algorithmic.embeddings.max_document_frequency < 0.0
307                    || algorithmic.embeddings.max_document_frequency > 1.0
308                {
309                    result.add_error(
310                        "Algorithmic max_document_frequency must be between 0.0 and 1.0"
311                            .to_string(),
312                    );
313                }
314
315                // Validate algorithmic entity extraction
316                if algorithmic.entity_extraction.confidence_threshold < 0.0
317                    || algorithmic.entity_extraction.confidence_threshold > 1.0
318                {
319                    result.add_error("Algorithmic entity extraction confidence threshold must be between 0.0 and 1.0".to_string());
320                }
321
322                // Validate algorithmic retrieval (BM25 parameters)
323                if algorithmic.retrieval.k1 < 0.0 {
324                    result.add_error("BM25 k1 parameter must be non-negative".to_string());
325                }
326
327                if algorithmic.retrieval.b < 0.0 || algorithmic.retrieval.b > 1.0 {
328                    result.add_error("BM25 b parameter must be between 0.0 and 1.0".to_string());
329                }
330
331                if algorithmic.retrieval.top_k == 0 {
332                    result.add_error(
333                        "Algorithmic retrieval top_k must be greater than 0".to_string(),
334                    );
335                }
336            },
337        }
338    }
339
340    // Validate hybrid pipeline
341    if approach == "hybrid" {
342        match &config.hybrid {
343            None => {
344                result.add_error(
345                    "Hybrid pipeline approach selected but [hybrid] configuration is missing"
346                        .to_string(),
347                );
348            },
349            Some(hybrid) => {
350                if !hybrid.enabled {
351                    result.add_error(
352                        "Hybrid pipeline approach selected but hybrid.enabled = false".to_string(),
353                    );
354                }
355
356                // Validate hybrid weights
357                let weight_sum = hybrid.weights.semantic_weight + hybrid.weights.algorithmic_weight;
358                if (weight_sum - 1.0).abs() > 0.01 {
359                    result.add_warning(format!(
360                        "Hybrid weights should sum to 1.0 (currently: {:.2})",
361                        weight_sum
362                    ));
363                }
364
365                if hybrid.weights.semantic_weight < 0.0 || hybrid.weights.semantic_weight > 1.0 {
366                    result.add_error(
367                        "Hybrid semantic_weight must be between 0.0 and 1.0".to_string(),
368                    );
369                }
370
371                if hybrid.weights.algorithmic_weight < 0.0
372                    || hybrid.weights.algorithmic_weight > 1.0
373                {
374                    result.add_error(
375                        "Hybrid algorithmic_weight must be between 0.0 and 1.0".to_string(),
376                    );
377                }
378
379                // Validate hybrid entity extraction weights
380                let entity_weight_sum =
381                    hybrid.entity_extraction.llm_weight + hybrid.entity_extraction.pattern_weight;
382                if (entity_weight_sum - 1.0).abs() > 0.01 {
383                    result.add_warning(format!(
384                        "Hybrid entity extraction weights should sum to 1.0 (currently: {:.2})",
385                        entity_weight_sum
386                    ));
387                }
388
389                // Validate hybrid retrieval weights
390                let retrieval_weight_sum =
391                    hybrid.retrieval.vector_weight + hybrid.retrieval.bm25_weight;
392                if (retrieval_weight_sum - 1.0).abs() > 0.01 {
393                    result.add_warning(format!(
394                        "Hybrid retrieval weights should sum to 1.0 (currently: {:.2})",
395                        retrieval_weight_sum
396                    ));
397                }
398
399                if hybrid.retrieval.rrf_constant == 0 {
400                    result.add_error(
401                        "Hybrid RRF constant must be greater than 0 (typically 60)".to_string(),
402                    );
403                }
404
405                // Validate confidence boost
406                if hybrid.entity_extraction.confidence_boost < 0.0
407                    || hybrid.entity_extraction.confidence_boost > 1.0
408                {
409                    result.add_warning(
410                        "Hybrid confidence_boost should typically be between 0.0 and 1.0"
411                            .to_string(),
412                    );
413                }
414            },
415        }
416    }
417
418    // Add suggestions based on approach
419    match approach.as_str() {
420        "semantic" => {
421            result.add_suggestion("Semantic pipeline uses neural embeddings and LLM-based extraction for high-quality results".to_string());
422            if config.ollama.enabled {
423                result.add_suggestion(
424                    "Consider using 'llama3.1:8b' for entity extraction with gleaning enabled"
425                        .to_string(),
426                );
427            }
428        },
429        "algorithmic" => {
430            result.add_suggestion("Algorithmic pipeline uses pattern matching and TF-IDF for fast, resource-efficient processing".to_string());
431            result.add_suggestion("Algorithmic pipeline works well for structured documents and doesn't require an LLM".to_string());
432        },
433        "hybrid" => {
434            result.add_suggestion("Hybrid pipeline combines semantic and algorithmic approaches for balanced quality and performance".to_string());
435            result.add_suggestion(
436                "Fine-tune hybrid weights based on your specific use case and evaluation metrics"
437                    .to_string(),
438            );
439        },
440        _ => {},
441    }
442}
443
444impl Validatable for SetConfig {
445    fn validate(&self) -> ValidationResult {
446        let mut result = ValidationResult::new();
447
448        // Validate pipeline approach configuration
449        validate_pipeline_approach(self, &mut result);
450
451        // Validate general settings
452        if let Some(input_path) = &self.general.input_document_path {
453            if input_path.is_empty() {
454                result.add_error("Input document path cannot be empty".to_string());
455            } else {
456                let path = Path::new(input_path);
457                if !path.exists() {
458                    result.add_error(format!("Input document not found: {}", input_path));
459                } else if !path.is_file() {
460                    result.add_error(format!("Input path is not a file: {}", input_path));
461                } else {
462                    // Input path exists and is a valid file
463                }
464            }
465        } else {
466            result.add_error("Input document path is required".to_string());
467        }
468
469        if self.general.output_dir.is_empty() {
470            result.add_error("Output directory cannot be empty".to_string());
471        }
472
473        // Validate pipeline settings
474        let pipeline = &self.pipeline;
475        if pipeline.text_extraction.chunk_size == 0 {
476            result.add_error("Chunk size must be greater than 0".to_string());
477        }
478
479        if pipeline.text_extraction.chunk_overlap >= pipeline.text_extraction.chunk_size {
480            result.add_error("Chunk overlap must be less than chunk size".to_string());
481        }
482
483        // Validate Ollama settings if enabled
484        let ollama = &self.ollama;
485        if ollama.enabled {
486            if ollama.host.is_empty() {
487                result.add_error("Ollama host cannot be empty when enabled".to_string());
488            }
489
490            if ollama.port == 0 {
491                result.add_error("Ollama port must be valid".to_string());
492            }
493
494            if ollama.chat_model.is_empty() {
495                result.add_error("Ollama chat model must be specified".to_string());
496            }
497
498            if ollama.embedding_model.is_empty() {
499                result.add_error("Ollama embedding model must be specified".to_string());
500            }
501
502            // Suggest common models if using defaults
503            if ollama.chat_model == "llama2" {
504                result.add_suggestion(
505                    "Consider using 'llama3.1:8b' for better performance".to_string(),
506                );
507            }
508        }
509
510        // Validate storage settings
511        let storage = &self.storage;
512        match storage.database_type.as_str() {
513            "memory" | "file" | "sqlite" | "postgresql" | "neo4j" => {},
514            db_type => {
515                result.add_error(format!("Unknown database type: {}", db_type));
516                result.add_suggestion(
517                    "Supported types: memory, file, sqlite, postgresql, neo4j".to_string(),
518                );
519            },
520        }
521
522        result
523    }
524
525    fn validate_strict(&self) -> ValidationResult {
526        let mut result = self.validate();
527
528        // Additional strict checks
529        if !self.ollama.enabled {
530            result.add_warning("Ollama is not enabled, will use mock LLM".to_string());
531        }
532
533        result
534    }
535}
536
537/// Validate a TOML configuration file
538pub fn validate_config_file(path: &Path, strict: bool) -> Result<ValidationResult> {
539    let config_str = std::fs::read_to_string(path)?;
540    let config: SetConfig = toml::from_str(&config_str).map_err(|e| GraphRAGError::Config {
541        message: format!("Failed to parse TOML config: {}", e),
542    })?;
543
544    let result = if strict {
545        config.validate_strict()
546    } else {
547        config.validate()
548    };
549
550    Ok(result)
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    // A zero chunk size must fail validation with at least one error.
    #[test]
    fn test_config_validation() {
        let cfg = Config {
            chunk_size: 0,
            ..Default::default()
        };

        let outcome = cfg.validate();
        assert!(!outcome.is_valid);
        assert!(!outcome.errors.is_empty());
    }

    // Overlap larger than the chunk size is rejected with an explicit message.
    #[test]
    fn test_chunk_overlap_validation() {
        let cfg = Config {
            chunk_size: 100,
            chunk_overlap: 150,
            ..Default::default()
        };

        let outcome = cfg.validate();
        assert!(!outcome.is_valid);
        assert!(outcome.errors.iter().any(|msg| msg.contains("overlap")));
    }

    // Large chunks combined with tiny overlap should yield a suggestion.
    #[test]
    fn test_suggestions() {
        let cfg = Config {
            chunk_size: 2000,
            chunk_overlap: 50,
            ..Default::default()
        };

        let outcome = cfg.validate();
        assert!(!outcome.suggestions.is_empty());
    }
}