use serde::{Deserialize, Serialize};
/// Serializable configuration values for GraphRAG.
///
/// Every field carries a serde default, so any subset of fields may be omitted
/// from the serialized input. Unless stated otherwise, the fallback used by
/// serde is the same value produced by `GraphRAGConfig::default()`.
///
/// NOTE(review): the per-field meanings below are inferred from the field
/// names; confirm them against the code that consumes this configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphRAGConfig {
/// Falls back to 20 (`default_top_k`) when absent.
/// Presumably the number of top-ranked results to keep — TODO confirm.
#[serde(default = "default_top_k")]
pub top_k: usize,
/// Falls back to 10 (`default_max_seeds`) when absent.
/// Presumably the cap on seed nodes used to start graph expansion — TODO confirm.
#[serde(default = "default_max_seeds")]
pub max_seeds: usize,
/// Falls back to 2 (`default_expansion_hops`) when absent.
/// Presumably the number of hops to expand from each seed — TODO confirm.
#[serde(default = "default_expansion_hops")]
pub expansion_hops: usize,
/// Falls back to 500 (`default_max_subgraph_size`) when absent.
/// Presumably an upper bound on the expanded subgraph; unit (nodes vs.
/// triples) is not visible here — TODO confirm.
#[serde(default = "default_max_subgraph_size")]
pub max_subgraph_size: usize,
/// Falls back to 100 (`default_max_context_triples`) when absent.
/// Presumably the cap on triples placed into the context — TODO confirm.
#[serde(default = "default_max_context_triples")]
pub max_context_triples: usize,
/// Falls back to `true` (`default_enable_communities`) when absent.
#[serde(default = "default_enable_communities")]
pub enable_communities: bool,
/// Falls back to `CommunityAlgorithm::default()`, i.e. `Louvain`, when absent.
#[serde(default)]
pub community_algorithm: CommunityAlgorithm,
/// Falls back to `FusionStrategy::default()`, i.e. `ReciprocalRankFusion`,
/// when absent.
#[serde(default)]
pub fusion_strategy: FusionStrategy,
/// Falls back to 0.7 (`default_vector_weight`) when absent.
/// Presumably the weight given to vector-search scores during fusion — TODO confirm.
#[serde(default = "default_vector_weight")]
pub vector_weight: f32,
/// Falls back to 0.3 (`default_keyword_weight`) when absent.
/// Presumably the weight given to keyword-search scores during fusion — TODO confirm.
#[serde(default = "default_keyword_weight")]
pub keyword_weight: f32,
/// Falls back to an empty list when absent.
/// NOTE(review): the pattern syntax is not visible in this file.
#[serde(default)]
pub path_patterns: Vec<String>,
/// Falls back to 0.7 (`default_similarity_threshold`) when absent.
#[serde(default = "default_similarity_threshold")]
pub similarity_threshold: f32,
/// Falls back to `None` when absent from the serialized input.
///
/// NOTE(review): `GraphRAGConfig::default()` sets this to `Some(1000)`, so a
/// config deserialized from `{}` differs from `Default::default()` on this
/// field only. Confirm which of the two values is intended.
#[serde(default)]
pub cache_size: Option<usize>,
/// Falls back to `CacheConfiguration::default()` when absent.
#[serde(default)]
pub cache_config: CacheConfiguration,
/// Falls back to `false` when absent.
#[serde(default)]
pub enable_query_expansion: bool,
/// Falls back to `false` when absent.
#[serde(default)]
pub enable_hierarchical_summary: bool,
/// Falls back to 3 (`default_max_community_levels`) when absent.
#[serde(default = "default_max_community_levels")]
pub max_community_levels: usize,
/// Falls back to `None` when absent.
/// Presumably the identifier of the LLM to call — TODO confirm.
#[serde(default)]
pub llm_model: Option<String>,
/// Falls back to 0.7 (`default_temperature`) when absent.
/// Presumably the LLM sampling temperature — TODO confirm.
#[serde(default = "default_temperature")]
pub temperature: f32,
/// Falls back to 2048 (`default_max_tokens`) when absent.
/// Presumably the LLM generation token limit — TODO confirm.
#[serde(default = "default_max_tokens")]
pub max_tokens: usize,
}
impl Default for GraphRAGConfig {
fn default() -> Self {
Self {
top_k: default_top_k(),
max_seeds: default_max_seeds(),
expansion_hops: default_expansion_hops(),
max_subgraph_size: default_max_subgraph_size(),
max_context_triples: default_max_context_triples(),
enable_communities: default_enable_communities(),
community_algorithm: CommunityAlgorithm::default(),
fusion_strategy: FusionStrategy::default(),
vector_weight: default_vector_weight(),
keyword_weight: default_keyword_weight(),
path_patterns: vec![],
similarity_threshold: default_similarity_threshold(),
cache_size: Some(1000),
cache_config: CacheConfiguration::default(),
enable_query_expansion: false,
enable_hierarchical_summary: false,
max_community_levels: default_max_community_levels(),
llm_model: None,
temperature: default_temperature(),
max_tokens: default_max_tokens(),
}
}
}
/// Selects the algorithm named by `GraphRAGConfig::community_algorithm`.
///
/// No serde rename attribute is present, so each variant is (de)serialized
/// under its Rust identifier. `Louvain` is the `Default` value.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum CommunityAlgorithm {
/// Default variant.
#[default]
Louvain,
/// Leiden variant.
Leiden,
/// Label-propagation variant.
LabelPropagation,
/// Connected-components variant.
ConnectedComponents,
}
/// Selects the strategy named by `GraphRAGConfig::fusion_strategy`.
///
/// No serde rename attribute is present, so each variant is (de)serialized
/// under its Rust identifier. `ReciprocalRankFusion` is the `Default` value.
///
/// NOTE(review): how each strategy combines scores is implemented elsewhere;
/// the `vector_weight` / `keyword_weight` settings presumably feed
/// `LinearCombination` — confirm against the fusion code.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
pub enum FusionStrategy {
/// Default variant.
#[default]
ReciprocalRankFusion,
/// Linear-combination variant.
LinearCombination,
/// Highest-score variant.
HighestScore,
/// Learning-to-rank variant.
LearningToRank,
}
/// Cache time-to-live settings, embedded in `GraphRAGConfig::cache_config`.
///
/// Every field has a serde default, so any subset may be omitted from the
/// serialized input; the fallbacks match `CacheConfiguration::default()`.
///
/// NOTE(review): nothing in this file enforces
/// `min_ttl_seconds <= base_ttl_seconds <= max_ttl_seconds`; confirm whether
/// the consumer validates or clamps these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfiguration {
/// Falls back to 3600 (`default_base_ttl_seconds`) when absent.
#[serde(default = "default_base_ttl_seconds")]
pub base_ttl_seconds: u64,
/// Falls back to 300 (`default_min_ttl_seconds`) when absent.
#[serde(default = "default_min_ttl_seconds")]
pub min_ttl_seconds: u64,
/// Falls back to 86400 (`default_max_ttl_seconds`) when absent.
#[serde(default = "default_max_ttl_seconds")]
pub max_ttl_seconds: u64,
/// Falls back to `true` (`default_adaptive_ttl`) when absent.
/// Presumably enables adjusting the TTL between the min and max bounds — TODO confirm.
#[serde(default = "default_adaptive_ttl")]
pub adaptive: bool,
}
impl Default for CacheConfiguration {
fn default() -> Self {
Self {
base_ttl_seconds: default_base_ttl_seconds(),
min_ttl_seconds: default_min_ttl_seconds(),
max_ttl_seconds: default_max_ttl_seconds(),
adaptive: default_adaptive_ttl(),
}
}
}
/// Fallback for `GraphRAGConfig::top_k`: 20.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_top_k() -> usize {
    const TOP_K: usize = 20;
    TOP_K
}
/// Fallback for `GraphRAGConfig::max_seeds`: 10.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_max_seeds() -> usize {
    const MAX_SEEDS: usize = 10;
    MAX_SEEDS
}
/// Fallback for `GraphRAGConfig::expansion_hops`: 2.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_expansion_hops() -> usize {
    const EXPANSION_HOPS: usize = 2;
    EXPANSION_HOPS
}
/// Fallback for `GraphRAGConfig::max_subgraph_size`: 500.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_max_subgraph_size() -> usize {
    const MAX_SUBGRAPH_SIZE: usize = 500;
    MAX_SUBGRAPH_SIZE
}
/// Fallback for `GraphRAGConfig::max_context_triples`: 100.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_max_context_triples() -> usize {
    const MAX_CONTEXT_TRIPLES: usize = 100;
    MAX_CONTEXT_TRIPLES
}
/// Fallback for `GraphRAGConfig::enable_communities`: `true`.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`. A helper is needed because the bare
/// `#[serde(default)]` for `bool` would yield `false`.
fn default_enable_communities() -> bool {
    const ENABLE_COMMUNITIES: bool = true;
    ENABLE_COMMUNITIES
}
/// Fallback for `GraphRAGConfig::vector_weight`: 0.7.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_vector_weight() -> f32 {
    const VECTOR_WEIGHT: f32 = 0.7;
    VECTOR_WEIGHT
}
/// Fallback for `GraphRAGConfig::keyword_weight`: 0.3.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_keyword_weight() -> f32 {
    const KEYWORD_WEIGHT: f32 = 0.3;
    KEYWORD_WEIGHT
}
/// Fallback for `GraphRAGConfig::similarity_threshold`: 0.7.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_similarity_threshold() -> f32 {
    const SIMILARITY_THRESHOLD: f32 = 0.7;
    SIMILARITY_THRESHOLD
}
/// Fallback for `GraphRAGConfig::max_community_levels`: 3.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_max_community_levels() -> usize {
    const MAX_COMMUNITY_LEVELS: usize = 3;
    MAX_COMMUNITY_LEVELS
}
/// Fallback for `GraphRAGConfig::temperature`: 0.7.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_temperature() -> f32 {
    const TEMPERATURE: f32 = 0.7;
    TEMPERATURE
}
/// Fallback for `GraphRAGConfig::max_tokens`: 2048.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `GraphRAGConfig::default()`.
fn default_max_tokens() -> usize {
    const MAX_TOKENS: usize = 2_048;
    MAX_TOKENS
}
/// Fallback for `CacheConfiguration::base_ttl_seconds`: 3600 (one hour).
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `CacheConfiguration::default()`.
fn default_base_ttl_seconds() -> u64 {
    // 60 minutes of 60 seconds each.
    const ONE_HOUR_SECS: u64 = 60 * 60;
    ONE_HOUR_SECS
}
/// Fallback for `CacheConfiguration::min_ttl_seconds`: 300 (five minutes).
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `CacheConfiguration::default()`.
fn default_min_ttl_seconds() -> u64 {
    // 5 minutes of 60 seconds each.
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    FIVE_MINUTES_SECS
}
/// Fallback for `CacheConfiguration::max_ttl_seconds`: 86400 (one day).
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `CacheConfiguration::default()`.
fn default_max_ttl_seconds() -> u64 {
    // 24 hours of 60 minutes of 60 seconds each.
    const ONE_DAY_SECS: u64 = 24 * 60 * 60;
    ONE_DAY_SECS
}
/// Fallback for `CacheConfiguration::adaptive`: `true`.
///
/// Named by the field's `#[serde(default = "...")]` attribute and called by
/// `CacheConfiguration::default()`. A helper is needed because the bare
/// `#[serde(default)]` for `bool` would yield `false`.
fn default_adaptive_ttl() -> bool {
    const ADAPTIVE_TTL: bool = true;
    ADAPTIVE_TTL
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `GraphRAGConfig::default()` exposes the expected baseline values.
    #[test]
    fn test_default_config() {
        let defaults = GraphRAGConfig::default();

        assert_eq!(defaults.top_k, 20);
        assert_eq!(defaults.expansion_hops, 2);
        assert!(defaults.enable_communities);
        assert_eq!(
            defaults.fusion_strategy,
            FusionStrategy::ReciprocalRankFusion
        );
    }

    /// A default configuration survives a JSON round trip (checked via `top_k`).
    #[test]
    fn test_config_serialization() {
        let original = GraphRAGConfig::default();

        let encoded = serde_json::to_string(&original).expect("should succeed");
        let decoded: GraphRAGConfig =
            serde_json::from_str(&encoded).expect("should succeed");

        assert_eq!(decoded.top_k, original.top_k);
    }
}