Skip to main content

hirn_exec/
planner.rs

//! `HirnExtensionPlanner` — converts `HirnPlanNode` logical nodes into physical operators.
//!
//! Implements DataFusion's `ExtensionPlanner` trait to bridge the gap between
//! `hirn-query`'s compiled `LogicalPlan` (containing `HirnPlanNode` extension nodes)
//! and `hirn-exec`'s physical `ExecutionPlan` operators.
//!
//! This is Stage 6 of the 7-stage QueryPipeline.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use datafusion::execution::SessionState;
13use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
14use datafusion_common::Result;
15use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
16use datafusion_physical_plan::ExecutionPlan;
17
18use hirn_query::compiler::plan_compiler::{ActivationRepr, HirnOp, HirnPlanNode};
19
20use crate::extensions::HirnSessionExt;
21use crate::operators::{
22    AbaReconsolidationExec, ActivationMode, CausalChainExec, CausalDiscoveryConfig,
23    CausalDiscoveryExec, CausalQueryReadExec, CausalReadKind, ContextAssemblyExec,
24    ContextBudgetExec, GlobalSearchExec, GlobalSearchParams, GraphActivationExec,
25    GraphTraverseExec, HebbianBufferExec, HybridSearchParams, InterferenceConfig,
26    InterferenceDetectorExec, IterativeConfig, IterativeRetrievalExec, LanceHybridSearchExec,
27    McfaConfig, McfaDefenseExec, NliConfig, NliContradictionExec, PolicyQueryReadExec,
28    PolicyReadKind, ProspectiveConfig, ProspectiveIndexingExec, QualityGateConfig, QualityGateExec,
29    RaptorSearchExec, RaptorSearchParams, RecallMergeExec, RpeConfig, RpeScoreExec,
30    SemanticHistoryScanExec, SvoConfig, SvoEventScanExec, SvoExtractionExec, TargetedQueryReadExec,
31    TargetedReadKind,
32};
33use crate::rules::{DEFAULT_PROSPECTIVE_THRESHOLD, ProspectiveShortCircuitExec};
34
/// DataFusion extension planner that converts `HirnPlanNode` extension nodes
/// into physical `ExecutionPlan` operators.
///
/// The planner is a stateless unit type: everything it needs at planning time
/// is read from the `SessionState` handed to `plan_extension`.
///
/// It is handed to `DefaultPhysicalPlanner::with_extension_planners()` by
/// `HirnQueryPlanner::create_physical_plan`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HirnExtensionPlanner;
41
42#[async_trait]
43impl ExtensionPlanner for HirnExtensionPlanner {
44    async fn plan_extension(
45        &self,
46        _planner: &dyn PhysicalPlanner,
47        node: &dyn UserDefinedLogicalNode,
48        _logical_inputs: &[&LogicalPlan],
49        physical_inputs: &[Arc<dyn ExecutionPlan>],
50        session_state: &SessionState,
51    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
52        let Some(hirn_node) = node.as_any().downcast_ref::<HirnPlanNode>() else {
53            // Not a hirn node — delegate to other planners.
54            return Ok(None);
55        };
56
57        // N-M13: read tunable params from HirnConfig via HirnSessionExt so that
58        // operators respect the configuration rather than using compile-time literals.
59        let hirnconfig = session_state
60            .config()
61            .options()
62            .extensions
63            .get::<HirnSessionExt>()
64            .map(|ext| Arc::clone(&ext.config));
65
66        let plan: Arc<dyn ExecutionPlan> = match &hirn_node.op {
67            // ── Source operators (leaf nodes — no physical_inputs expected) ──
68
69            // HybridSearch is materialized with empty batches first; the engine
70            // can then replace the placeholder with pre-fetched batches for the
71            // supported DataFusion execution slice.
72            HirnOp::HybridSearch {
73                query,
74                layers,
75                limit,
76                hybrid_mode,
77                namespace_filter,
78                ..
79            } => {
80                let schema = hirn_node.schema.as_ref().inner().clone();
81                let datasets = layers
82                    .iter()
83                    .map(|layer| match layer {
84                        hirn_core::types::Layer::Episodic => "episodic",
85                        hirn_core::types::Layer::Semantic => "semantic",
86                        hirn_core::types::Layer::Procedural => "procedural",
87                        hirn_core::types::Layer::Working => "working",
88                    })
89                    .map(ToString::to_string)
90                    .collect::<Vec<_>>();
91
92                let ns_filter = if namespace_filter.is_empty() {
93                    None
94                } else {
95                    tracing::debug!(namespace_filter = %namespace_filter, "HybridSearch: namespace pushdown applied");
96                    Some(namespace_filter.clone())
97                };
98                let metric = session_distance_metric(session_state)?;
99
100                Arc::new(LanceHybridSearchExec::new(
101                    schema,
102                    HybridSearchParams {
103                        datasets,
104                        vector_column: "embedding".to_string(),
105                        query_vector: Vec::new(),
106                        hybrid_mode: *hybrid_mode,
107                        fts_columns: vec!["content".to_string()],
108                        fts_query: query.clone(),
109                        limit: *limit,
110                        metric,
111                        filter: ns_filter,
112                        numeric_filters: Vec::new(),
113                        temporal_start_ms: None,
114                        temporal_end_ms: None,
115                        temporal_expansion: false,
116                        temporal_boost: 1.25,
117                    },
118                ))
119            }
120
121            HirnOp::GlobalSearch {
122                query,
123                namespace_filter,
124                max_communities,
125                community_threshold,
126                max_members_per_community,
127            } => {
128                let global_ns_filter = if namespace_filter.is_empty() {
129                    None
130                } else {
131                    tracing::debug!(namespace_filter = %namespace_filter, "GlobalSearch: namespace pushdown applied");
132                    Some(namespace_filter.clone())
133                };
134                let schema = hirn_node.schema.as_ref().inner().clone();
135                Arc::new(GlobalSearchExec::new(
136                    schema,
137                    GlobalSearchParams {
138                        query: query.clone(),
139                        query_vector: Vec::new(),
140                        filter: global_ns_filter,
141                        limit: max_communities
142                            .saturating_mul(max_members_per_community.saturating_add(1)),
143                        max_communities: *max_communities,
144                        community_threshold: *community_threshold as f32 / 1000.0,
145                        max_members_per_community: *max_members_per_community,
146                    },
147                ))
148            }
149
150            HirnOp::RaptorSearch {
151                query,
152                namespace_filter,
153                max_per_level,
154                similarity_threshold,
155                max_depth,
156            } => {
157                let raptor_ns_filter = if namespace_filter.is_empty() {
158                    None
159                } else {
160                    tracing::debug!(namespace_filter = %namespace_filter, "RaptorSearch: namespace pushdown applied");
161                    Some(namespace_filter.clone())
162                };
163                let schema = hirn_node.schema.as_ref().inner().clone();
164                Arc::new(RaptorSearchExec::new(
165                    schema,
166                    RaptorSearchParams {
167                        query: query.clone(),
168                        query_vector: Vec::new(),
169                        filter: raptor_ns_filter,
170                        limit: max_per_level.saturating_mul(max_depth.saturating_add(1)),
171                        max_per_level: *max_per_level,
172                        similarity_threshold: *similarity_threshold as f32 / 1000.0,
173                        max_depth: *max_depth,
174                    },
175                ))
176            }
177
178            HirnOp::RecallMerge => {
179                if physical_inputs.len() < 2 {
180                    return Err(datafusion_common::DataFusionError::Plan(
181                        "HirnRecallMerge requires at least two inputs".to_string(),
182                    ));
183                }
184                Arc::new(RecallMergeExec::new(
185                    hirn_node.schema.as_ref().inner().clone(),
186                    physical_inputs.to_vec(),
187                ))
188            }
189
190            // QueryComplexity is a classification node that produces no data.
191            // It signals the depth scheduler; at physical level it's a pass-through
192            // of its child (if any) or an empty exec.
193            HirnOp::QueryComplexity { .. } => {
194                if let Some(child) = physical_inputs.first() {
195                    Arc::clone(child)
196                } else {
197                    let schema = hirn_node.schema.as_ref().inner().clone();
198                    Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
199                }
200            }
201
202            // ── Read-path operators ──
203            HirnOp::GraphActivation {
204                seed_limit,
205                depth,
206                min_weight: _,
207                activation,
208            } => {
209                let input = require_single_input(physical_inputs, "GraphActivation")?;
210                let mode = match activation {
211                    ActivationRepr::Static => ActivationMode::Static,
212                    ActivationRepr::Spreading => ActivationMode::Spreading,
213                    ActivationRepr::Ppr => ActivationMode::Ppr,
214                    ActivationRepr::None => {
215                        // No activation requested — pass through.
216                        return Ok(Some(input));
217                    }
218                };
219                let epsilon = hirnconfig
220                    .as_ref()
221                    .map(|c| c.graph_activation_epsilon)
222                    .unwrap_or(0.001_f32);
223                let inhibition_mu = hirnconfig
224                    .as_ref()
225                    .map(|c| c.graph_activation_inhibition_mu)
226                    .unwrap_or(0.5_f32);
227                Arc::new(GraphActivationExec::new(
228                    input,
229                    *seed_limit,
230                    mode,
231                    *depth,
232                    epsilon,
233                    inhibition_mu,
234                )?)
235            }
236
237            HirnOp::CausalChain { depth } => {
238                let input = require_single_input(physical_inputs, "CausalChain")?;
239                let min_confidence = hirnconfig
240                    .as_ref()
241                    .map(|c| c.causal_min_confidence)
242                    .unwrap_or(0.3_f32);
243                Arc::new(CausalChainExec::new(input, *depth, min_confidence))
244            }
245
246            HirnOp::HebbianBuffer => {
247                let input = require_single_input(physical_inputs, "HebbianBuffer")?;
248                // Create a shared co-retrieval queue. The engine drains this
249                // periodically to update Hebbian weights in the graph.
250                let queue = Arc::new(crossbeam_queue::SegQueue::new());
251                Arc::new(HebbianBufferExec::new(input, queue))
252            }
253
254            HirnOp::ContextBudget { budget } => {
255                let input = require_single_input(physical_inputs, "ContextBudget")?;
256                Arc::new(ContextBudgetExec::new(input, *budget as u32))
257            }
258
259            HirnOp::QualityGate { threshold } => {
260                let input = require_single_input(physical_inputs, "QualityGate")?;
261                let config = QualityGateConfig {
262                    threshold: *threshold as f32 / 1000.0,
263                    ..QualityGateConfig::default()
264                };
265                let token_budget = hirnconfig
266                    .as_ref()
267                    .map(|c| c.default_token_budget)
268                    .unwrap_or(4096_usize);
269                Arc::new(QualityGateExec::new(input, config, token_budget))
270            }
271
272            HirnOp::IterativeRetrieval { max_hops } => {
273                let input = require_single_input(physical_inputs, "IterativeRetrieval")?;
274                let config = IterativeConfig {
275                    max_rounds: *max_hops as u32,
276                    ..IterativeConfig::default()
277                };
278                Arc::new(IterativeRetrievalExec::new(input, config))
279            }
280
281            // ── Write-path operators ──
282            HirnOp::RpeScore => {
283                let input = require_single_input(physical_inputs, "RpeScore")?;
284                Arc::new(RpeScoreExec::new(input, RpeConfig::default()))
285            }
286
287            HirnOp::ProspectiveIndexing => {
288                let input = require_single_input(physical_inputs, "ProspectiveIndexing")?;
289                Arc::new(ProspectiveIndexingExec::new(
290                    input,
291                    ProspectiveConfig::default(),
292                ))
293            }
294
295            HirnOp::SvoExtraction => {
296                let input = require_single_input(physical_inputs, "SvoExtraction")?;
297                Arc::new(SvoExtractionExec::new(input, SvoConfig::default()))
298            }
299
300            HirnOp::InterferenceDetector => {
301                let input = require_single_input(physical_inputs, "InterferenceDetector")?;
302                // Session ext may carry an injected NLI classifier (e.g. DeBERTa-MNLI ONNX).
303                // If not present, `InterferenceDetectorExec::new` picks the heuristic default.
304                let classifier = session_state
305                    .config()
306                    .options()
307                    .extensions
308                    .get::<HirnSessionExt>()
309                    .and_then(|ext| ext.nli_classifier());
310                match classifier {
311                    Some(clf) => Arc::new(InterferenceDetectorExec::with_nli_classifier(
312                        input,
313                        InterferenceConfig::default(),
314                        clf,
315                    )),
316                    None => Arc::new(InterferenceDetectorExec::new(
317                        input,
318                        InterferenceConfig::default(),
319                    )),
320                }
321            }
322
323            HirnOp::McfaDefense => {
324                let input = require_single_input(physical_inputs, "McfaDefense")?;
325                Arc::new(McfaDefenseExec::new(input, McfaConfig::default(), None))
326            }
327
328            // ── Mutation operators (pass-through at physical level) ──
329            // The actual insert/delete/connect logic runs in the engine after
330            // collecting the physical plan's output batches.
331            HirnOp::ImperativeBoundary { .. } => {
332                if let Some(child) = physical_inputs.first() {
333                    Arc::clone(child)
334                } else {
335                    let schema = hirn_node.schema.as_ref().inner().clone();
336                    Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
337                }
338            }
339
340            // ── Prospective search (recall-path) ──
341            HirnOp::ProspectiveSearch { .. } => {
342                let input = require_single_input(physical_inputs, "ProspectiveSearch")?;
343                Arc::new(ProspectiveShortCircuitExec::new(
344                    input,
345                    DEFAULT_PROSPECTIVE_THRESHOLD,
346                )?)
347            }
348
349            // ── SVO event scan ──
350            HirnOp::SvoEventScan {
351                namespace,
352                filter,
353                limit,
354            } => {
355                let schema = hirn_node.schema.as_ref().inner().clone();
356                Arc::new(SvoEventScanExec::new(
357                    schema,
358                    namespace.clone(),
359                    filter.clone(),
360                    *limit,
361                ))
362            }
363
364            HirnOp::SemanticHistoryScan {
365                target,
366                target_kind,
367                namespace,
368            } => {
369                let schema = hirn_node.schema.as_ref().inner().clone();
370                Arc::new(SemanticHistoryScanExec::new(
371                    schema,
372                    target.clone(),
373                    *target_kind,
374                    namespace.clone(),
375                ))
376            }
377
378            HirnOp::InspectScan {
379                target,
380                target_kind,
381            } => {
382                let schema = hirn_node.schema.as_ref().inner().clone();
383                Arc::new(TargetedQueryReadExec::new(
384                    schema,
385                    TargetedReadKind::Inspect,
386                    target.clone(),
387                    *target_kind,
388                ))
389            }
390
391            HirnOp::TraceScan {
392                target,
393                target_kind,
394            } => {
395                let schema = hirn_node.schema.as_ref().inner().clone();
396                Arc::new(TargetedQueryReadExec::new(
397                    schema,
398                    TargetedReadKind::Trace,
399                    target.clone(),
400                    *target_kind,
401                ))
402            }
403
404            HirnOp::ExplainCausesScan {
405                query,
406                depth,
407                namespace,
408            } => {
409                let schema = hirn_node.schema.as_ref().inner().clone();
410                Arc::new(CausalQueryReadExec::new(
411                    schema,
412                    CausalReadKind::ExplainCauses,
413                    query.clone(),
414                    None,
415                    *depth,
416                    namespace.clone(),
417                ))
418            }
419
420            HirnOp::WhatIfScan {
421                intervention,
422                outcome,
423                namespace,
424            } => {
425                let schema = hirn_node.schema.as_ref().inner().clone();
426                Arc::new(CausalQueryReadExec::new(
427                    schema,
428                    CausalReadKind::WhatIf,
429                    intervention.clone(),
430                    Some(outcome.clone()),
431                    0,
432                    namespace.clone(),
433                ))
434            }
435
436            HirnOp::CounterfactualScan {
437                antecedent,
438                consequent,
439                namespace,
440            } => {
441                let schema = hirn_node.schema.as_ref().inner().clone();
442                Arc::new(CausalQueryReadExec::new(
443                    schema,
444                    CausalReadKind::Counterfactual,
445                    antecedent.clone(),
446                    Some(consequent.clone()),
447                    0,
448                    namespace.clone(),
449                ))
450            }
451
452            HirnOp::ShowPoliciesScan {
453                principal_kind,
454                principal_name,
455            } => {
456                let schema = hirn_node.schema.as_ref().inner().clone();
457                Arc::new(PolicyQueryReadExec::new(
458                    schema,
459                    PolicyReadKind::ShowPolicies,
460                    principal_kind.clone(),
461                    principal_name.clone(),
462                    None,
463                    None,
464                    None,
465                ))
466            }
467
468            HirnOp::ExplainPolicyScan {
469                principal_kind,
470                principal_name,
471                resource_type,
472                resource_name,
473                action,
474            } => {
475                let schema = hirn_node.schema.as_ref().inner().clone();
476                Arc::new(PolicyQueryReadExec::new(
477                    schema,
478                    PolicyReadKind::ExplainPolicy,
479                    Some(principal_kind.clone()),
480                    Some(principal_name.clone()),
481                    Some(resource_type.clone()),
482                    Some(resource_name.clone()),
483                    Some(action.clone()),
484                ))
485            }
486
487            HirnOp::TraverseGraph {
488                start_id,
489                relation_filter,
490                depth,
491                namespace,
492            } => {
493                let schema = hirn_node.schema.as_ref().inner().clone();
494                Arc::new(GraphTraverseExec::new(
495                    schema,
496                    start_id.clone(),
497                    relation_filter.clone(),
498                    *depth,
499                    namespace.clone(),
500                ))
501            }
502
503            // ── NLI + ABA + Causal Discovery (consolidation sub-operators) ──
504            HirnOp::NliContradiction => {
505                let input = require_single_input(physical_inputs, "NliContradiction")?;
506                Arc::new(NliContradictionExec::new(input, NliConfig::default()))
507            }
508
509            HirnOp::AbaReconsolidation { namespace } => {
510                let input = require_single_input(physical_inputs, "AbaReconsolidation")?;
511                Arc::new(AbaReconsolidationExec::new(input, namespace.clone()))
512            }
513
514            HirnOp::CausalDiscovery { namespace } => {
515                let input = require_single_input(physical_inputs, "CausalDiscovery")?;
516                Arc::new(CausalDiscoveryExec::new(
517                    input,
518                    CausalDiscoveryConfig::default(),
519                    namespace.clone(),
520                ))
521            }
522
523            // ── Context assembly (THINK terminal operator) ──
524            HirnOp::ContextAssembly => {
525                let input = require_single_input(physical_inputs, "ContextAssembly")?;
526                Arc::new(ContextAssemblyExec::new(input))
527            }
528        };
529
530        Ok(Some(plan))
531    }
532}
533
/// `QueryPlanner` implementation that makes DataFusion's
/// `DefaultPhysicalPlanner` aware of `HirnExtensionPlanner`.
///
/// Install it on a session with
/// `SessionStateBuilder::with_query_planner(Arc::new(HirnQueryPlanner))`.
#[derive(Debug)]
pub struct HirnQueryPlanner;
539
540#[async_trait]
541impl datafusion::execution::context::QueryPlanner for HirnQueryPlanner {
542    async fn create_physical_plan(
543        &self,
544        logical_plan: &LogicalPlan,
545        session_state: &SessionState,
546    ) -> Result<Arc<dyn ExecutionPlan>> {
547        use datafusion::physical_planner::DefaultPhysicalPlanner;
548        let planner =
549            DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(HirnExtensionPlanner)]);
550        planner
551            .create_physical_plan(logical_plan, session_state)
552            .await
553    }
554}
555
556/// Extract the single required child input from physical_inputs.
557fn require_single_input(
558    inputs: &[Arc<dyn ExecutionPlan>],
559    op_name: &str,
560) -> Result<Arc<dyn ExecutionPlan>> {
561    inputs.first().cloned().ok_or_else(|| {
562        datafusion_common::DataFusionError::Plan(format!(
563            "Hirn{op_name} requires exactly one input, got 0"
564        ))
565    })
566}
567
568fn session_distance_metric(
569    session_state: &SessionState,
570) -> Result<hirn_storage::store::DistanceMetric> {
571    let ext = session_state
572        .config()
573        .options()
574        .extensions
575        .get::<crate::extensions::HirnSessionExt>()
576        .ok_or_else(|| {
577            datafusion_common::DataFusionError::Configuration(
578                "HirnSessionExt must be registered before planning compiled search operators"
579                    .to_string(),
580            )
581        })?;
582
583    Ok(ext.config.metric)
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::SessionContext;
    use hirn_core::HirnConfig;

    /// Session context over a default-featured state with no Hirn extension registered.
    fn bare_session() -> SessionContext {
        SessionContext::new_with_state(SessionStateBuilder::new_with_default_features().build())
    }

    /// With a registered extension, the metric comes from its `HirnConfig`.
    #[test]
    fn session_distance_metric_uses_registered_config() {
        let session = bare_session();
        let config = HirnConfig::builder()
            .distance_metric(hirn_core::DistanceMetric::Cosine)
            .build()
            .expect("test config should build");
        let extension =
            crate::extensions::HirnSessionExt::new(Arc::new(0_u8), Arc::new(config), None);
        extension.register(&session).unwrap();

        let resolved = session_distance_metric(&session.state()).unwrap();
        assert_eq!(resolved, hirn_storage::store::DistanceMetric::Cosine);
    }

    /// Without a registered extension, the lookup fails with a configuration error.
    #[test]
    fn session_distance_metric_requires_session_extension() {
        let state = bare_session().state();

        let message = session_distance_metric(&state).unwrap_err().to_string();
        assert!(message.contains("HirnSessionExt must be registered"));
    }
}