// oxirs_graphrag/query/planner.rs

//! Query execution planning
2
3use crate::{config::GraphRAGConfig, GraphRAGResult};
4use serde::{Deserialize, Serialize};
5
6use super::parser::ParsedQuery;
7
8/// Query execution plan
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct QueryPlan {
11    /// Stages in execution order
12    pub stages: Vec<PlanStage>,
13    /// Estimated cost (arbitrary units)
14    pub estimated_cost: f64,
15    /// Whether to use parallel execution
16    pub parallel: bool,
17}
18
19/// A stage in the query plan
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct PlanStage {
22    /// Stage name
23    pub name: String,
24    /// Stage type
25    pub stage_type: StageType,
26    /// Parameters for this stage
27    pub params: StageParams,
28    /// Dependencies on other stages
29    pub depends_on: Vec<usize>,
30}
31
32/// Type of execution stage
33#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
34pub enum StageType {
35    /// Query embedding
36    Embed,
37    /// Vector similarity search
38    VectorSearch,
39    /// Keyword/BM25 search
40    KeywordSearch,
41    /// Result fusion
42    Fusion,
43    /// Graph expansion via SPARQL
44    GraphExpansion,
45    /// Community detection
46    CommunityDetection,
47    /// Context building
48    ContextBuild,
49    /// LLM generation
50    Generation,
51}
52
53/// Parameters for a stage
54#[derive(Debug, Clone, Serialize, Deserialize, Default)]
55pub struct StageParams {
56    /// Top-K for search stages
57    pub top_k: Option<usize>,
58    /// Threshold for filtering
59    pub threshold: Option<f32>,
60    /// SPARQL query template
61    pub sparql_template: Option<String>,
62    /// Maximum results
63    pub max_results: Option<usize>,
64    /// Other parameters
65    #[serde(default)]
66    pub extra: std::collections::HashMap<String, String>,
67}
68
69/// Query planner
70pub struct QueryPlanner {
71    config: GraphRAGConfig,
72}
73
74impl QueryPlanner {
75    pub fn new(config: GraphRAGConfig) -> Self {
76        Self { config }
77    }
78
79    /// Create execution plan for a parsed query
80    pub fn plan(&self, parsed: &ParsedQuery) -> GraphRAGResult<QueryPlan> {
81        let mut stages = Vec::new();
82        let mut stage_idx = 0;
83
84        // Stage 0: Embed query
85        stages.push(PlanStage {
86            name: "embed_query".to_string(),
87            stage_type: StageType::Embed,
88            params: StageParams::default(),
89            depends_on: vec![],
90        });
91        let embed_stage = stage_idx;
92        stage_idx += 1;
93
94        // Stage 1: Vector search (parallel with keyword)
95        stages.push(PlanStage {
96            name: "vector_search".to_string(),
97            stage_type: StageType::VectorSearch,
98            params: StageParams {
99                top_k: Some(self.config.top_k),
100                threshold: Some(self.config.similarity_threshold),
101                ..Default::default()
102            },
103            depends_on: vec![embed_stage],
104        });
105        let vector_stage = stage_idx;
106        stage_idx += 1;
107
108        // Stage 2: Keyword search (parallel with vector)
109        stages.push(PlanStage {
110            name: "keyword_search".to_string(),
111            stage_type: StageType::KeywordSearch,
112            params: StageParams {
113                top_k: Some(self.config.top_k),
114                extra: parsed
115                    .keywords
116                    .iter()
117                    .enumerate()
118                    .map(|(i, k)| (format!("keyword_{}", i), k.clone()))
119                    .collect(),
120                ..Default::default()
121            },
122            depends_on: vec![], // Can run in parallel with vector search
123        });
124        let keyword_stage = stage_idx;
125        stage_idx += 1;
126
127        // Stage 3: Fusion
128        stages.push(PlanStage {
129            name: "fusion".to_string(),
130            stage_type: StageType::Fusion,
131            params: StageParams {
132                max_results: Some(self.config.max_seeds),
133                ..Default::default()
134            },
135            depends_on: vec![vector_stage, keyword_stage],
136        });
137        let fusion_stage = stage_idx;
138        stage_idx += 1;
139
140        // Stage 4: Graph expansion
141        stages.push(PlanStage {
142            name: "graph_expansion".to_string(),
143            stage_type: StageType::GraphExpansion,
144            params: StageParams {
145                max_results: Some(self.config.max_subgraph_size),
146                extra: [("hops".to_string(), self.config.expansion_hops.to_string())]
147                    .into_iter()
148                    .collect(),
149                ..Default::default()
150            },
151            depends_on: vec![fusion_stage],
152        });
153        let expansion_stage = stage_idx;
154        stage_idx += 1;
155
156        // Stage 5: Community detection (optional)
157        let community_stage = if self.config.enable_communities {
158            stages.push(PlanStage {
159                name: "community_detection".to_string(),
160                stage_type: StageType::CommunityDetection,
161                params: StageParams {
162                    extra: [(
163                        "algorithm".to_string(),
164                        format!("{:?}", self.config.community_algorithm),
165                    )]
166                    .into_iter()
167                    .collect(),
168                    ..Default::default()
169                },
170                depends_on: vec![expansion_stage],
171            });
172            let idx = stage_idx;
173            stage_idx += 1;
174            Some(idx)
175        } else {
176            None
177        };
178
179        // Stage 6: Context building
180        let context_deps = if let Some(comm_stage) = community_stage {
181            vec![expansion_stage, comm_stage]
182        } else {
183            vec![expansion_stage]
184        };
185        stages.push(PlanStage {
186            name: "context_build".to_string(),
187            stage_type: StageType::ContextBuild,
188            params: StageParams {
189                max_results: Some(self.config.max_context_triples),
190                ..Default::default()
191            },
192            depends_on: context_deps,
193        });
194        let context_stage = stage_idx;
195        stage_idx += 1;
196
197        // Stage 7: LLM generation
198        stages.push(PlanStage {
199            name: "generation".to_string(),
200            stage_type: StageType::Generation,
201            params: StageParams {
202                extra: [
203                    (
204                        "temperature".to_string(),
205                        self.config.temperature.to_string(),
206                    ),
207                    ("max_tokens".to_string(), self.config.max_tokens.to_string()),
208                ]
209                .into_iter()
210                .collect(),
211                ..Default::default()
212            },
213            depends_on: vec![context_stage],
214        });
215        let _generation_stage = stage_idx;
216
217        // Calculate estimated cost
218        let estimated_cost = self.estimate_cost(&stages);
219
220        Ok(QueryPlan {
221            stages,
222            estimated_cost,
223            parallel: true, // Vector and keyword search can run in parallel
224        })
225    }
226
227    /// Estimate execution cost
228    fn estimate_cost(&self, stages: &[PlanStage]) -> f64 {
229        let mut cost = 0.0;
230
231        for stage in stages {
232            cost += match stage.stage_type {
233                StageType::Embed => 10.0,
234                StageType::VectorSearch => 50.0 * (stage.params.top_k.unwrap_or(20) as f64 / 20.0),
235                StageType::KeywordSearch => 30.0,
236                StageType::Fusion => 5.0,
237                StageType::GraphExpansion => 100.0 * (self.config.expansion_hops as f64),
238                StageType::CommunityDetection => 200.0,
239                StageType::ContextBuild => 10.0,
240                StageType::Generation => 500.0,
241            };
242        }
243
244        cost
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::parser::QueryParser;

    /// Plan a representative question with the default configuration.
    fn default_plan() -> QueryPlan {
        let planner = QueryPlanner::new(GraphRAGConfig::default());
        let parser = QueryParser::new();
        let parsed = parser.parse("What are battery safety issues?").unwrap();
        planner.plan(&parsed).unwrap()
    }

    #[test]
    fn test_plan_creation() {
        let plan = default_plan();

        assert!(!plan.stages.is_empty());
        assert!(plan.stages.iter().any(|s| s.stage_type == StageType::Embed));
        assert!(plan
            .stages
            .iter()
            .any(|s| s.stage_type == StageType::VectorSearch));
        assert!(plan
            .stages
            .iter()
            .any(|s| s.stage_type == StageType::Generation));
    }

    #[test]
    fn test_plan_dependencies_point_to_earlier_stages() {
        let plan = default_plan();

        for (idx, stage) in plan.stages.iter().enumerate() {
            for &dep in &stage.depends_on {
                assert!(
                    dep < idx,
                    "stage `{}` (index {}) depends on non-earlier stage {}",
                    stage.name,
                    idx,
                    dep
                );
            }
        }
    }

    #[test]
    fn test_fusion_depends_on_both_searches() {
        let plan = default_plan();

        let fusion = plan
            .stages
            .iter()
            .find(|s| s.stage_type == StageType::Fusion)
            .expect("plan must contain a fusion stage");
        let dep_types: Vec<StageType> = fusion
            .depends_on
            .iter()
            .map(|&i| plan.stages[i].stage_type)
            .collect();

        assert_eq!(
            dep_types,
            vec![StageType::VectorSearch, StageType::KeywordSearch]
        );
    }

    #[test]
    fn test_generation_is_last_and_cost_positive() {
        let plan = default_plan();

        assert_eq!(
            plan.stages.last().map(|s| s.stage_type),
            Some(StageType::Generation)
        );
        assert!(plan.estimated_cost > 0.0);
    }
}