1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct GraphRAGConfig {
8 #[serde(default = "default_top_k")]
10 pub top_k: usize,
11
12 #[serde(default = "default_max_seeds")]
14 pub max_seeds: usize,
15
16 #[serde(default = "default_expansion_hops")]
18 pub expansion_hops: usize,
19
20 #[serde(default = "default_max_subgraph_size")]
22 pub max_subgraph_size: usize,
23
24 #[serde(default = "default_max_context_triples")]
26 pub max_context_triples: usize,
27
28 #[serde(default = "default_enable_communities")]
30 pub enable_communities: bool,
31
32 #[serde(default)]
34 pub community_algorithm: CommunityAlgorithm,
35
36 #[serde(default)]
38 pub fusion_strategy: FusionStrategy,
39
40 #[serde(default = "default_vector_weight")]
42 pub vector_weight: f32,
43
44 #[serde(default = "default_keyword_weight")]
46 pub keyword_weight: f32,
47
48 #[serde(default)]
50 pub path_patterns: Vec<String>,
51
52 #[serde(default = "default_similarity_threshold")]
54 pub similarity_threshold: f32,
55
56 #[serde(default)]
58 pub cache_size: Option<usize>,
59
60 #[serde(default)]
62 pub cache_config: CacheConfiguration,
63
64 #[serde(default)]
66 pub enable_query_expansion: bool,
67
68 #[serde(default)]
70 pub enable_hierarchical_summary: bool,
71
72 #[serde(default = "default_max_community_levels")]
74 pub max_community_levels: usize,
75
76 #[serde(default)]
78 pub llm_model: Option<String>,
79
80 #[serde(default = "default_temperature")]
82 pub temperature: f32,
83
84 #[serde(default = "default_max_tokens")]
86 pub max_tokens: usize,
87}
88
89impl Default for GraphRAGConfig {
90 fn default() -> Self {
91 Self {
92 top_k: default_top_k(),
93 max_seeds: default_max_seeds(),
94 expansion_hops: default_expansion_hops(),
95 max_subgraph_size: default_max_subgraph_size(),
96 max_context_triples: default_max_context_triples(),
97 enable_communities: default_enable_communities(),
98 community_algorithm: CommunityAlgorithm::default(),
99 fusion_strategy: FusionStrategy::default(),
100 vector_weight: default_vector_weight(),
101 keyword_weight: default_keyword_weight(),
102 path_patterns: vec![],
103 similarity_threshold: default_similarity_threshold(),
104 cache_size: Some(1000),
105 cache_config: CacheConfiguration::default(),
106 enable_query_expansion: false,
107 enable_hierarchical_summary: false,
108 max_community_levels: default_max_community_levels(),
109 llm_model: None,
110 temperature: default_temperature(),
111 max_tokens: default_max_tokens(),
112 }
113 }
114}
115
116#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
118pub enum CommunityAlgorithm {
119 #[default]
121 Louvain,
122 Leiden,
124 LabelPropagation,
126 ConnectedComponents,
128}
129
130#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
132pub enum FusionStrategy {
133 #[default]
135 ReciprocalRankFusion,
136 LinearCombination,
138 HighestScore,
140 LearningToRank,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct CacheConfiguration {
147 #[serde(default = "default_base_ttl_seconds")]
149 pub base_ttl_seconds: u64,
150 #[serde(default = "default_min_ttl_seconds")]
152 pub min_ttl_seconds: u64,
153 #[serde(default = "default_max_ttl_seconds")]
155 pub max_ttl_seconds: u64,
156 #[serde(default = "default_adaptive_ttl")]
158 pub adaptive: bool,
159}
160
161impl Default for CacheConfiguration {
162 fn default() -> Self {
163 Self {
164 base_ttl_seconds: default_base_ttl_seconds(),
165 min_ttl_seconds: default_min_ttl_seconds(),
166 max_ttl_seconds: default_max_ttl_seconds(),
167 adaptive: default_adaptive_ttl(),
168 }
169 }
170}
171
// Serde/`Default` fallbacks for `GraphRAGConfig` fields. Each helper returns
// a constant so serde's `default = "..."` attributes and the `Default` impl
// can share a single source of truth.

/// Fallback for `top_k`.
fn default_top_k() -> usize {
    20
}

/// Fallback for `max_seeds`.
fn default_max_seeds() -> usize {
    10
}

/// Fallback for `expansion_hops`.
fn default_expansion_hops() -> usize {
    2
}

/// Fallback for `max_subgraph_size`.
fn default_max_subgraph_size() -> usize {
    500
}

/// Fallback for `max_context_triples`.
fn default_max_context_triples() -> usize {
    100
}

/// Fallback for `enable_communities` (enabled).
fn default_enable_communities() -> bool {
    true
}

/// Fallback for `vector_weight`; together with `default_keyword_weight`
/// the two weights sum to 1.0.
fn default_vector_weight() -> f32 {
    0.7
}

/// Fallback for `keyword_weight`.
fn default_keyword_weight() -> f32 {
    0.3
}

/// Fallback for `similarity_threshold`.
fn default_similarity_threshold() -> f32 {
    0.7
}

/// Fallback for `max_community_levels`.
fn default_max_community_levels() -> usize {
    3
}

/// Fallback for `temperature`.
fn default_temperature() -> f32 {
    0.7
}

/// Fallback for `max_tokens`.
fn default_max_tokens() -> usize {
    2_048
}
// Serde/`Default` fallbacks for `CacheConfiguration`. Durations are spelled
// out as unit arithmetic so the intent (minutes/hours/days) is visible.

/// Fallback for `base_ttl_seconds`: one hour.
fn default_base_ttl_seconds() -> u64 {
    60 * 60
}

/// Fallback for `min_ttl_seconds`: five minutes.
fn default_min_ttl_seconds() -> u64 {
    5 * 60
}

/// Fallback for `max_ttl_seconds`: one day.
fn default_max_ttl_seconds() -> u64 {
    24 * 60 * 60
}

/// Fallback for `adaptive` (enabled).
fn default_adaptive_ttl() -> bool {
    true
}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = GraphRAGConfig::default();
        assert_eq!(config.top_k, 20);
        assert_eq!(config.expansion_hops, 2);
        assert!(config.enable_communities);
        assert_eq!(config.fusion_strategy, FusionStrategy::ReciprocalRankFusion);
    }

    #[test]
    fn test_config_serialization() {
        let config = GraphRAGConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: GraphRAGConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.top_k, config.top_k);
        // `GraphRAGConfig` has no `PartialEq`, so compare the JSON trees to
        // verify that *every* field survives the round trip, not just `top_k`.
        assert_eq!(
            serde_json::to_value(&parsed).unwrap(),
            serde_json::to_value(&config).unwrap()
        );
    }

    #[test]
    fn test_empty_json_uses_field_defaults() {
        // Missing keys must fall back to the same values `Default` produces.
        let parsed: GraphRAGConfig = serde_json::from_str("{}").unwrap();
        let defaults = GraphRAGConfig::default();

        assert_eq!(parsed.top_k, defaults.top_k);
        assert_eq!(parsed.max_seeds, defaults.max_seeds);
        assert_eq!(parsed.expansion_hops, defaults.expansion_hops);
        assert_eq!(parsed.max_subgraph_size, defaults.max_subgraph_size);
        assert_eq!(parsed.max_context_triples, defaults.max_context_triples);
        assert_eq!(parsed.enable_communities, defaults.enable_communities);
        assert_eq!(parsed.community_algorithm, defaults.community_algorithm);
        assert_eq!(parsed.fusion_strategy, defaults.fusion_strategy);
        assert_eq!(parsed.vector_weight, defaults.vector_weight);
        assert_eq!(parsed.keyword_weight, defaults.keyword_weight);
        assert!(parsed.path_patterns.is_empty());
        assert_eq!(parsed.similarity_threshold, defaults.similarity_threshold);
        assert_eq!(parsed.enable_query_expansion, defaults.enable_query_expansion);
        assert_eq!(
            parsed.enable_hierarchical_summary,
            defaults.enable_hierarchical_summary
        );
        assert_eq!(parsed.max_community_levels, defaults.max_community_levels);
        assert_eq!(parsed.llm_model, defaults.llm_model);
        assert_eq!(parsed.temperature, defaults.temperature);
        assert_eq!(parsed.max_tokens, defaults.max_tokens);
        assert_eq!(
            parsed.cache_config.base_ttl_seconds,
            defaults.cache_config.base_ttl_seconds
        );
        assert_eq!(parsed.cache_config.adaptive, defaults.cache_config.adaptive);
    }

    #[test]
    fn test_default_cache_config() {
        let cache = CacheConfiguration::default();
        assert_eq!(cache.base_ttl_seconds, 3600);
        assert_eq!(cache.min_ttl_seconds, 300);
        assert_eq!(cache.max_ttl_seconds, 86400);
        assert!(cache.adaptive);
        assert!(cache.min_ttl_seconds <= cache.base_ttl_seconds);
        assert!(cache.base_ttl_seconds <= cache.max_ttl_seconds);
    }
}
242}