Skip to main content

oxirs_graphrag/
config.rs

//! GraphRAG configuration
2
3use serde::{Deserialize, Serialize};
4
5/// GraphRAG configuration
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GraphRAGConfig {
8    /// Number of seed nodes from vector search
9    #[serde(default = "default_top_k")]
10    pub top_k: usize,
11
12    /// Maximum number of seeds after fusion
13    #[serde(default = "default_max_seeds")]
14    pub max_seeds: usize,
15
16    /// Graph expansion hops
17    #[serde(default = "default_expansion_hops")]
18    pub expansion_hops: usize,
19
20    /// Maximum subgraph size (triples)
21    #[serde(default = "default_max_subgraph_size")]
22    pub max_subgraph_size: usize,
23
24    /// Maximum triples to include in LLM context
25    #[serde(default = "default_max_context_triples")]
26    pub max_context_triples: usize,
27
28    /// Enable community detection
29    #[serde(default = "default_enable_communities")]
30    pub enable_communities: bool,
31
32    /// Community detection algorithm
33    #[serde(default)]
34    pub community_algorithm: CommunityAlgorithm,
35
36    /// Fusion strategy
37    #[serde(default)]
38    pub fusion_strategy: FusionStrategy,
39
40    /// Weight for vector similarity scores (0.0 - 1.0)
41    #[serde(default = "default_vector_weight")]
42    pub vector_weight: f32,
43
44    /// Weight for keyword/BM25 scores (0.0 - 1.0)
45    #[serde(default = "default_keyword_weight")]
46    pub keyword_weight: f32,
47
48    /// Path patterns for graph expansion (SPARQL property paths)
49    #[serde(default)]
50    pub path_patterns: Vec<String>,
51
52    /// Similarity threshold for vector search
53    #[serde(default = "default_similarity_threshold")]
54    pub similarity_threshold: f32,
55
56    /// Cache size for query results
57    #[serde(default)]
58    pub cache_size: Option<usize>,
59
60    /// Cache configuration
61    #[serde(default)]
62    pub cache_config: CacheConfiguration,
63
64    /// Enable query expansion
65    #[serde(default)]
66    pub enable_query_expansion: bool,
67
68    /// Enable hierarchical summarization
69    #[serde(default)]
70    pub enable_hierarchical_summary: bool,
71
72    /// Maximum community levels for hierarchical summarization
73    #[serde(default = "default_max_community_levels")]
74    pub max_community_levels: usize,
75
76    /// LLM model to use for generation
77    #[serde(default)]
78    pub llm_model: Option<String>,
79
80    /// Temperature for LLM generation
81    #[serde(default = "default_temperature")]
82    pub temperature: f32,
83
84    /// Maximum tokens for LLM response
85    #[serde(default = "default_max_tokens")]
86    pub max_tokens: usize,
87}
88
89impl Default for GraphRAGConfig {
90    fn default() -> Self {
91        Self {
92            top_k: default_top_k(),
93            max_seeds: default_max_seeds(),
94            expansion_hops: default_expansion_hops(),
95            max_subgraph_size: default_max_subgraph_size(),
96            max_context_triples: default_max_context_triples(),
97            enable_communities: default_enable_communities(),
98            community_algorithm: CommunityAlgorithm::default(),
99            fusion_strategy: FusionStrategy::default(),
100            vector_weight: default_vector_weight(),
101            keyword_weight: default_keyword_weight(),
102            path_patterns: vec![],
103            similarity_threshold: default_similarity_threshold(),
104            cache_size: Some(1000),
105            cache_config: CacheConfiguration::default(),
106            enable_query_expansion: false,
107            enable_hierarchical_summary: false,
108            max_community_levels: default_max_community_levels(),
109            llm_model: None,
110            temperature: default_temperature(),
111            max_tokens: default_max_tokens(),
112        }
113    }
114}
115
116/// Community detection algorithm
117#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
118pub enum CommunityAlgorithm {
119    /// Louvain algorithm (fast, good quality)
120    #[default]
121    Louvain,
122    /// Leiden algorithm (improved Louvain)
123    Leiden,
124    /// Label propagation (very fast, lower quality)
125    LabelPropagation,
126    /// Connected components (simplest)
127    ConnectedComponents,
128}
129
130/// Fusion strategy for combining retrieval results
131#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
132pub enum FusionStrategy {
133    /// Reciprocal Rank Fusion (default, robust)
134    #[default]
135    ReciprocalRankFusion,
136    /// Linear combination of scores
137    LinearCombination,
138    /// Take highest score per entity
139    HighestScore,
140    /// Learning-to-rank (requires model)
141    LearningToRank,
142}
143
144/// Cache configuration
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct CacheConfiguration {
147    /// Base TTL in seconds (default: 3600 = 1 hour)
148    #[serde(default = "default_base_ttl_seconds")]
149    pub base_ttl_seconds: u64,
150    /// Minimum TTL in seconds (default: 300 = 5 minutes)
151    #[serde(default = "default_min_ttl_seconds")]
152    pub min_ttl_seconds: u64,
153    /// Maximum TTL in seconds (default: 86400 = 24 hours)
154    #[serde(default = "default_max_ttl_seconds")]
155    pub max_ttl_seconds: u64,
156    /// Enable adaptive TTL based on update frequency
157    #[serde(default = "default_adaptive_ttl")]
158    pub adaptive: bool,
159}
160
161impl Default for CacheConfiguration {
162    fn default() -> Self {
163        Self {
164            base_ttl_seconds: default_base_ttl_seconds(),
165            min_ttl_seconds: default_min_ttl_seconds(),
166            max_ttl_seconds: default_max_ttl_seconds(),
167            adaptive: default_adaptive_ttl(),
168        }
169    }
170}
171
// Default value functions (referenced by the `#[serde(default = "...")]` attributes above)
/// Default for `GraphRAGConfig::top_k`: 20 seed nodes.
fn default_top_k() -> usize {
    20
}
/// Default for `GraphRAGConfig::max_seeds`: 10 seeds after fusion.
fn default_max_seeds() -> usize {
    10
}
/// Default for `GraphRAGConfig::expansion_hops`: 2 hops.
fn default_expansion_hops() -> usize {
    2
}
/// Default for `GraphRAGConfig::max_subgraph_size`: 500 triples.
fn default_max_subgraph_size() -> usize {
    500
}
/// Default for `GraphRAGConfig::max_context_triples`: 100 triples.
fn default_max_context_triples() -> usize {
    100
}
/// Default for `GraphRAGConfig::enable_communities`: enabled.
fn default_enable_communities() -> bool {
    true
}
/// Default for `GraphRAGConfig::vector_weight`: 0.7.
fn default_vector_weight() -> f32 {
    0.7
}
/// Default for `GraphRAGConfig::keyword_weight`: 0.3.
fn default_keyword_weight() -> f32 {
    0.3
}
/// Default for `GraphRAGConfig::similarity_threshold`: 0.7.
fn default_similarity_threshold() -> f32 {
    0.7
}
/// Default for `GraphRAGConfig::max_community_levels`: 3 levels.
fn default_max_community_levels() -> usize {
    3
}
/// Default for `GraphRAGConfig::temperature`: 0.7.
fn default_temperature() -> f32 {
    0.7
}
/// Default for `GraphRAGConfig::max_tokens`: 2048 tokens.
fn default_max_tokens() -> usize {
    2048
}
/// Default for `CacheConfiguration::base_ttl_seconds`: 3600 s (1 hour).
fn default_base_ttl_seconds() -> u64 {
    3600
}
/// Default for `CacheConfiguration::min_ttl_seconds`: 300 s (5 minutes).
fn default_min_ttl_seconds() -> u64 {
    300
}
/// Default for `CacheConfiguration::max_ttl_seconds`: 86400 s (24 hours).
fn default_max_ttl_seconds() -> u64 {
    86400
}
/// Default for `CacheConfiguration::adaptive`: adaptive TTL enabled.
fn default_adaptive_ttl() -> bool {
    true
}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` exposes the expected out-of-the-box values.
    #[test]
    fn test_default_config() {
        let config = GraphRAGConfig::default();
        assert_eq!(config.top_k, 20);
        assert_eq!(config.expansion_hops, 2);
        assert!(config.enable_communities);
        assert_eq!(config.fusion_strategy, FusionStrategy::ReciprocalRankFusion);
    }

    /// A default config survives a JSON round trip.
    #[test]
    fn test_config_serialization() {
        let config = GraphRAGConfig::default();
        let json = serde_json::to_string(&config).expect("default config serializes to JSON");
        let parsed: GraphRAGConfig =
            serde_json::from_str(&json).expect("serialized config parses back");
        assert_eq!(parsed.top_k, config.top_k);
        assert_eq!(parsed.max_tokens, config.max_tokens);
        assert_eq!(parsed.cache_size, config.cache_size);
        assert_eq!(parsed.fusion_strategy, config.fusion_strategy);
        assert_eq!(parsed.community_algorithm, config.community_algorithm);
    }

    /// Every field has a serde default, so an empty object must deserialize
    /// and pick up the per-field default values.
    #[test]
    fn test_empty_json_uses_field_defaults() {
        let parsed: GraphRAGConfig =
            serde_json::from_str("{}").expect("empty object falls back to field defaults");
        assert_eq!(parsed.top_k, 20);
        assert_eq!(parsed.max_seeds, 10);
        assert!(parsed.enable_communities);
        assert!(parsed.path_patterns.is_empty());
        assert_eq!(parsed.cache_config.base_ttl_seconds, 3600);
    }
}