1use crate::{config::GraphRAGConfig, GraphRAGResult};
4use serde::{Deserialize, Serialize};
5
6use super::parser::ParsedQuery;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct QueryPlan {
11 pub stages: Vec<PlanStage>,
13 pub estimated_cost: f64,
15 pub parallel: bool,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct PlanStage {
22 pub name: String,
24 pub stage_type: StageType,
26 pub params: StageParams,
28 pub depends_on: Vec<usize>,
30}
31
32#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
34pub enum StageType {
35 Embed,
37 VectorSearch,
39 KeywordSearch,
41 Fusion,
43 GraphExpansion,
45 CommunityDetection,
47 ContextBuild,
49 Generation,
51}
52
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
55pub struct StageParams {
56 pub top_k: Option<usize>,
58 pub threshold: Option<f32>,
60 pub sparql_template: Option<String>,
62 pub max_results: Option<usize>,
64 #[serde(default)]
66 pub extra: std::collections::HashMap<String, String>,
67}
68
69pub struct QueryPlanner {
71 config: GraphRAGConfig,
72}
73
74impl QueryPlanner {
75 pub fn new(config: GraphRAGConfig) -> Self {
76 Self { config }
77 }
78
79 pub fn plan(&self, parsed: &ParsedQuery) -> GraphRAGResult<QueryPlan> {
81 let mut stages = Vec::new();
82 let mut stage_idx = 0;
83
84 stages.push(PlanStage {
86 name: "embed_query".to_string(),
87 stage_type: StageType::Embed,
88 params: StageParams::default(),
89 depends_on: vec![],
90 });
91 let embed_stage = stage_idx;
92 stage_idx += 1;
93
94 stages.push(PlanStage {
96 name: "vector_search".to_string(),
97 stage_type: StageType::VectorSearch,
98 params: StageParams {
99 top_k: Some(self.config.top_k),
100 threshold: Some(self.config.similarity_threshold),
101 ..Default::default()
102 },
103 depends_on: vec![embed_stage],
104 });
105 let vector_stage = stage_idx;
106 stage_idx += 1;
107
108 stages.push(PlanStage {
110 name: "keyword_search".to_string(),
111 stage_type: StageType::KeywordSearch,
112 params: StageParams {
113 top_k: Some(self.config.top_k),
114 extra: parsed
115 .keywords
116 .iter()
117 .enumerate()
118 .map(|(i, k)| (format!("keyword_{}", i), k.clone()))
119 .collect(),
120 ..Default::default()
121 },
122 depends_on: vec![], });
124 let keyword_stage = stage_idx;
125 stage_idx += 1;
126
127 stages.push(PlanStage {
129 name: "fusion".to_string(),
130 stage_type: StageType::Fusion,
131 params: StageParams {
132 max_results: Some(self.config.max_seeds),
133 ..Default::default()
134 },
135 depends_on: vec![vector_stage, keyword_stage],
136 });
137 let fusion_stage = stage_idx;
138 stage_idx += 1;
139
140 stages.push(PlanStage {
142 name: "graph_expansion".to_string(),
143 stage_type: StageType::GraphExpansion,
144 params: StageParams {
145 max_results: Some(self.config.max_subgraph_size),
146 extra: [("hops".to_string(), self.config.expansion_hops.to_string())]
147 .into_iter()
148 .collect(),
149 ..Default::default()
150 },
151 depends_on: vec![fusion_stage],
152 });
153 let expansion_stage = stage_idx;
154 stage_idx += 1;
155
156 let community_stage = if self.config.enable_communities {
158 stages.push(PlanStage {
159 name: "community_detection".to_string(),
160 stage_type: StageType::CommunityDetection,
161 params: StageParams {
162 extra: [(
163 "algorithm".to_string(),
164 format!("{:?}", self.config.community_algorithm),
165 )]
166 .into_iter()
167 .collect(),
168 ..Default::default()
169 },
170 depends_on: vec![expansion_stage],
171 });
172 let idx = stage_idx;
173 stage_idx += 1;
174 Some(idx)
175 } else {
176 None
177 };
178
179 let context_deps = if let Some(comm_stage) = community_stage {
181 vec![expansion_stage, comm_stage]
182 } else {
183 vec![expansion_stage]
184 };
185 stages.push(PlanStage {
186 name: "context_build".to_string(),
187 stage_type: StageType::ContextBuild,
188 params: StageParams {
189 max_results: Some(self.config.max_context_triples),
190 ..Default::default()
191 },
192 depends_on: context_deps,
193 });
194 let context_stage = stage_idx;
195 stage_idx += 1;
196
197 stages.push(PlanStage {
199 name: "generation".to_string(),
200 stage_type: StageType::Generation,
201 params: StageParams {
202 extra: [
203 (
204 "temperature".to_string(),
205 self.config.temperature.to_string(),
206 ),
207 ("max_tokens".to_string(), self.config.max_tokens.to_string()),
208 ]
209 .into_iter()
210 .collect(),
211 ..Default::default()
212 },
213 depends_on: vec![context_stage],
214 });
215 let _generation_stage = stage_idx;
216
217 let estimated_cost = self.estimate_cost(&stages);
219
220 Ok(QueryPlan {
221 stages,
222 estimated_cost,
223 parallel: true, })
225 }
226
227 fn estimate_cost(&self, stages: &[PlanStage]) -> f64 {
229 let mut cost = 0.0;
230
231 for stage in stages {
232 cost += match stage.stage_type {
233 StageType::Embed => 10.0,
234 StageType::VectorSearch => 50.0 * (stage.params.top_k.unwrap_or(20) as f64 / 20.0),
235 StageType::KeywordSearch => 30.0,
236 StageType::Fusion => 5.0,
237 StageType::GraphExpansion => 100.0 * (self.config.expansion_hops as f64),
238 StageType::CommunityDetection => 200.0,
239 StageType::ContextBuild => 10.0,
240 StageType::Generation => 500.0,
241 };
242 }
243
244 cost
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::parser::QueryParser;

    /// Smoke test: planning a simple question with the default configuration
    /// yields a non-empty plan containing the embed, vector-search and
    /// generation stages.
    #[test]
    fn test_plan_creation() {
        let planner = QueryPlanner::new(GraphRAGConfig::default());
        let parser = QueryParser::new();

        let parsed = parser.parse("What are battery safety issues?").unwrap();
        let plan = planner.plan(&parsed).unwrap();

        // Collect the stage kinds once and check membership against them.
        let kinds: Vec<StageType> = plan.stages.iter().map(|stage| stage.stage_type).collect();

        assert!(!kinds.is_empty());
        for required in [
            StageType::Embed,
            StageType::VectorSearch,
            StageType::Generation,
        ] {
            assert!(kinds.contains(&required));
        }
    }
}