1use std::sync::Arc;
10
11use async_trait::async_trait;
12use datafusion::execution::SessionState;
13use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
14use datafusion_common::Result;
15use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
16use datafusion_physical_plan::ExecutionPlan;
17
18use hirn_query::compiler::plan_compiler::{ActivationRepr, HirnOp, HirnPlanNode};
19
20use crate::extensions::HirnSessionExt;
21use crate::operators::{
22 AbaReconsolidationExec, ActivationMode, CausalChainExec, CausalDiscoveryConfig,
23 CausalDiscoveryExec, CausalQueryReadExec, CausalReadKind, ContextAssemblyExec,
24 ContextBudgetExec, GlobalSearchExec, GlobalSearchParams, GraphActivationExec,
25 GraphTraverseExec, HebbianBufferExec, HybridSearchParams, InterferenceConfig,
26 InterferenceDetectorExec, IterativeConfig, IterativeRetrievalExec, LanceHybridSearchExec,
27 McfaConfig, McfaDefenseExec, NliConfig, NliContradictionExec, PolicyQueryReadExec,
28 PolicyReadKind, ProspectiveConfig, ProspectiveIndexingExec, QualityGateConfig, QualityGateExec,
29 RaptorSearchExec, RaptorSearchParams, RecallMergeExec, RpeConfig, RpeScoreExec,
30 SemanticHistoryScanExec, SvoConfig, SvoEventScanExec, SvoExtractionExec, TargetedQueryReadExec,
31 TargetedReadKind,
32};
33use crate::rules::{DEFAULT_PROSPECTIVE_THRESHOLD, ProspectiveShortCircuitExec};
34
/// DataFusion [`ExtensionPlanner`] that lowers Hirn logical extension nodes
/// ([`HirnPlanNode`]) into Hirn physical operators.
///
/// The planner holds no state: everything it needs per query is read from the
/// `SessionState` handed to [`ExtensionPlanner::plan_extension`].
#[derive(Debug, Default, Clone, Copy)]
pub struct HirnExtensionPlanner;
41
42#[async_trait]
43impl ExtensionPlanner for HirnExtensionPlanner {
44 async fn plan_extension(
45 &self,
46 _planner: &dyn PhysicalPlanner,
47 node: &dyn UserDefinedLogicalNode,
48 _logical_inputs: &[&LogicalPlan],
49 physical_inputs: &[Arc<dyn ExecutionPlan>],
50 session_state: &SessionState,
51 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
52 let Some(hirn_node) = node.as_any().downcast_ref::<HirnPlanNode>() else {
53 return Ok(None);
55 };
56
57 let hirnconfig = session_state
60 .config()
61 .options()
62 .extensions
63 .get::<HirnSessionExt>()
64 .map(|ext| Arc::clone(&ext.config));
65
66 let plan: Arc<dyn ExecutionPlan> = match &hirn_node.op {
67 HirnOp::HybridSearch {
73 query,
74 layers,
75 limit,
76 hybrid_mode,
77 namespace_filter,
78 ..
79 } => {
80 let schema = hirn_node.schema.as_ref().inner().clone();
81 let datasets = layers
82 .iter()
83 .map(|layer| match layer {
84 hirn_core::types::Layer::Episodic => "episodic",
85 hirn_core::types::Layer::Semantic => "semantic",
86 hirn_core::types::Layer::Procedural => "procedural",
87 hirn_core::types::Layer::Working => "working",
88 })
89 .map(ToString::to_string)
90 .collect::<Vec<_>>();
91
92 let ns_filter = if namespace_filter.is_empty() {
93 None
94 } else {
95 tracing::debug!(namespace_filter = %namespace_filter, "HybridSearch: namespace pushdown applied");
96 Some(namespace_filter.clone())
97 };
98 let metric = session_distance_metric(session_state)?;
99
100 Arc::new(LanceHybridSearchExec::new(
101 schema,
102 HybridSearchParams {
103 datasets,
104 vector_column: "embedding".to_string(),
105 query_vector: Vec::new(),
106 hybrid_mode: *hybrid_mode,
107 fts_columns: vec!["content".to_string()],
108 fts_query: query.clone(),
109 limit: *limit,
110 metric,
111 filter: ns_filter,
112 numeric_filters: Vec::new(),
113 temporal_start_ms: None,
114 temporal_end_ms: None,
115 temporal_expansion: false,
116 temporal_boost: 1.25,
117 },
118 ))
119 }
120
121 HirnOp::GlobalSearch {
122 query,
123 namespace_filter,
124 max_communities,
125 community_threshold,
126 max_members_per_community,
127 } => {
128 let global_ns_filter = if namespace_filter.is_empty() {
129 None
130 } else {
131 tracing::debug!(namespace_filter = %namespace_filter, "GlobalSearch: namespace pushdown applied");
132 Some(namespace_filter.clone())
133 };
134 let schema = hirn_node.schema.as_ref().inner().clone();
135 Arc::new(GlobalSearchExec::new(
136 schema,
137 GlobalSearchParams {
138 query: query.clone(),
139 query_vector: Vec::new(),
140 filter: global_ns_filter,
141 limit: max_communities
142 .saturating_mul(max_members_per_community.saturating_add(1)),
143 max_communities: *max_communities,
144 community_threshold: *community_threshold as f32 / 1000.0,
145 max_members_per_community: *max_members_per_community,
146 },
147 ))
148 }
149
150 HirnOp::RaptorSearch {
151 query,
152 namespace_filter,
153 max_per_level,
154 similarity_threshold,
155 max_depth,
156 } => {
157 let raptor_ns_filter = if namespace_filter.is_empty() {
158 None
159 } else {
160 tracing::debug!(namespace_filter = %namespace_filter, "RaptorSearch: namespace pushdown applied");
161 Some(namespace_filter.clone())
162 };
163 let schema = hirn_node.schema.as_ref().inner().clone();
164 Arc::new(RaptorSearchExec::new(
165 schema,
166 RaptorSearchParams {
167 query: query.clone(),
168 query_vector: Vec::new(),
169 filter: raptor_ns_filter,
170 limit: max_per_level.saturating_mul(max_depth.saturating_add(1)),
171 max_per_level: *max_per_level,
172 similarity_threshold: *similarity_threshold as f32 / 1000.0,
173 max_depth: *max_depth,
174 },
175 ))
176 }
177
178 HirnOp::RecallMerge => {
179 if physical_inputs.len() < 2 {
180 return Err(datafusion_common::DataFusionError::Plan(
181 "HirnRecallMerge requires at least two inputs".to_string(),
182 ));
183 }
184 Arc::new(RecallMergeExec::new(
185 hirn_node.schema.as_ref().inner().clone(),
186 physical_inputs.to_vec(),
187 ))
188 }
189
190 HirnOp::QueryComplexity { .. } => {
194 if let Some(child) = physical_inputs.first() {
195 Arc::clone(child)
196 } else {
197 let schema = hirn_node.schema.as_ref().inner().clone();
198 Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
199 }
200 }
201
202 HirnOp::GraphActivation {
204 seed_limit,
205 depth,
206 min_weight: _,
207 activation,
208 } => {
209 let input = require_single_input(physical_inputs, "GraphActivation")?;
210 let mode = match activation {
211 ActivationRepr::Static => ActivationMode::Static,
212 ActivationRepr::Spreading => ActivationMode::Spreading,
213 ActivationRepr::Ppr => ActivationMode::Ppr,
214 ActivationRepr::None => {
215 return Ok(Some(input));
217 }
218 };
219 let epsilon = hirnconfig
220 .as_ref()
221 .map(|c| c.graph_activation_epsilon)
222 .unwrap_or(0.001_f32);
223 let inhibition_mu = hirnconfig
224 .as_ref()
225 .map(|c| c.graph_activation_inhibition_mu)
226 .unwrap_or(0.5_f32);
227 Arc::new(GraphActivationExec::new(
228 input,
229 *seed_limit,
230 mode,
231 *depth,
232 epsilon,
233 inhibition_mu,
234 )?)
235 }
236
237 HirnOp::CausalChain { depth } => {
238 let input = require_single_input(physical_inputs, "CausalChain")?;
239 let min_confidence = hirnconfig
240 .as_ref()
241 .map(|c| c.causal_min_confidence)
242 .unwrap_or(0.3_f32);
243 Arc::new(CausalChainExec::new(input, *depth, min_confidence))
244 }
245
246 HirnOp::HebbianBuffer => {
247 let input = require_single_input(physical_inputs, "HebbianBuffer")?;
248 let queue = Arc::new(crossbeam_queue::SegQueue::new());
251 Arc::new(HebbianBufferExec::new(input, queue))
252 }
253
254 HirnOp::ContextBudget { budget } => {
255 let input = require_single_input(physical_inputs, "ContextBudget")?;
256 Arc::new(ContextBudgetExec::new(input, *budget as u32))
257 }
258
259 HirnOp::QualityGate { threshold } => {
260 let input = require_single_input(physical_inputs, "QualityGate")?;
261 let config = QualityGateConfig {
262 threshold: *threshold as f32 / 1000.0,
263 ..QualityGateConfig::default()
264 };
265 let token_budget = hirnconfig
266 .as_ref()
267 .map(|c| c.default_token_budget)
268 .unwrap_or(4096_usize);
269 Arc::new(QualityGateExec::new(input, config, token_budget))
270 }
271
272 HirnOp::IterativeRetrieval { max_hops } => {
273 let input = require_single_input(physical_inputs, "IterativeRetrieval")?;
274 let config = IterativeConfig {
275 max_rounds: *max_hops as u32,
276 ..IterativeConfig::default()
277 };
278 Arc::new(IterativeRetrievalExec::new(input, config))
279 }
280
281 HirnOp::RpeScore => {
283 let input = require_single_input(physical_inputs, "RpeScore")?;
284 Arc::new(RpeScoreExec::new(input, RpeConfig::default()))
285 }
286
287 HirnOp::ProspectiveIndexing => {
288 let input = require_single_input(physical_inputs, "ProspectiveIndexing")?;
289 Arc::new(ProspectiveIndexingExec::new(
290 input,
291 ProspectiveConfig::default(),
292 ))
293 }
294
295 HirnOp::SvoExtraction => {
296 let input = require_single_input(physical_inputs, "SvoExtraction")?;
297 Arc::new(SvoExtractionExec::new(input, SvoConfig::default()))
298 }
299
300 HirnOp::InterferenceDetector => {
301 let input = require_single_input(physical_inputs, "InterferenceDetector")?;
302 let classifier = session_state
305 .config()
306 .options()
307 .extensions
308 .get::<HirnSessionExt>()
309 .and_then(|ext| ext.nli_classifier());
310 match classifier {
311 Some(clf) => Arc::new(InterferenceDetectorExec::with_nli_classifier(
312 input,
313 InterferenceConfig::default(),
314 clf,
315 )),
316 None => Arc::new(InterferenceDetectorExec::new(
317 input,
318 InterferenceConfig::default(),
319 )),
320 }
321 }
322
323 HirnOp::McfaDefense => {
324 let input = require_single_input(physical_inputs, "McfaDefense")?;
325 Arc::new(McfaDefenseExec::new(input, McfaConfig::default(), None))
326 }
327
328 HirnOp::ImperativeBoundary { .. } => {
332 if let Some(child) = physical_inputs.first() {
333 Arc::clone(child)
334 } else {
335 let schema = hirn_node.schema.as_ref().inner().clone();
336 Arc::new(datafusion_physical_plan::empty::EmptyExec::new(schema))
337 }
338 }
339
340 HirnOp::ProspectiveSearch { .. } => {
342 let input = require_single_input(physical_inputs, "ProspectiveSearch")?;
343 Arc::new(ProspectiveShortCircuitExec::new(
344 input,
345 DEFAULT_PROSPECTIVE_THRESHOLD,
346 )?)
347 }
348
349 HirnOp::SvoEventScan {
351 namespace,
352 filter,
353 limit,
354 } => {
355 let schema = hirn_node.schema.as_ref().inner().clone();
356 Arc::new(SvoEventScanExec::new(
357 schema,
358 namespace.clone(),
359 filter.clone(),
360 *limit,
361 ))
362 }
363
364 HirnOp::SemanticHistoryScan {
365 target,
366 target_kind,
367 namespace,
368 } => {
369 let schema = hirn_node.schema.as_ref().inner().clone();
370 Arc::new(SemanticHistoryScanExec::new(
371 schema,
372 target.clone(),
373 *target_kind,
374 namespace.clone(),
375 ))
376 }
377
378 HirnOp::InspectScan {
379 target,
380 target_kind,
381 } => {
382 let schema = hirn_node.schema.as_ref().inner().clone();
383 Arc::new(TargetedQueryReadExec::new(
384 schema,
385 TargetedReadKind::Inspect,
386 target.clone(),
387 *target_kind,
388 ))
389 }
390
391 HirnOp::TraceScan {
392 target,
393 target_kind,
394 } => {
395 let schema = hirn_node.schema.as_ref().inner().clone();
396 Arc::new(TargetedQueryReadExec::new(
397 schema,
398 TargetedReadKind::Trace,
399 target.clone(),
400 *target_kind,
401 ))
402 }
403
404 HirnOp::ExplainCausesScan {
405 query,
406 depth,
407 namespace,
408 } => {
409 let schema = hirn_node.schema.as_ref().inner().clone();
410 Arc::new(CausalQueryReadExec::new(
411 schema,
412 CausalReadKind::ExplainCauses,
413 query.clone(),
414 None,
415 *depth,
416 namespace.clone(),
417 ))
418 }
419
420 HirnOp::WhatIfScan {
421 intervention,
422 outcome,
423 namespace,
424 } => {
425 let schema = hirn_node.schema.as_ref().inner().clone();
426 Arc::new(CausalQueryReadExec::new(
427 schema,
428 CausalReadKind::WhatIf,
429 intervention.clone(),
430 Some(outcome.clone()),
431 0,
432 namespace.clone(),
433 ))
434 }
435
436 HirnOp::CounterfactualScan {
437 antecedent,
438 consequent,
439 namespace,
440 } => {
441 let schema = hirn_node.schema.as_ref().inner().clone();
442 Arc::new(CausalQueryReadExec::new(
443 schema,
444 CausalReadKind::Counterfactual,
445 antecedent.clone(),
446 Some(consequent.clone()),
447 0,
448 namespace.clone(),
449 ))
450 }
451
452 HirnOp::ShowPoliciesScan {
453 principal_kind,
454 principal_name,
455 } => {
456 let schema = hirn_node.schema.as_ref().inner().clone();
457 Arc::new(PolicyQueryReadExec::new(
458 schema,
459 PolicyReadKind::ShowPolicies,
460 principal_kind.clone(),
461 principal_name.clone(),
462 None,
463 None,
464 None,
465 ))
466 }
467
468 HirnOp::ExplainPolicyScan {
469 principal_kind,
470 principal_name,
471 resource_type,
472 resource_name,
473 action,
474 } => {
475 let schema = hirn_node.schema.as_ref().inner().clone();
476 Arc::new(PolicyQueryReadExec::new(
477 schema,
478 PolicyReadKind::ExplainPolicy,
479 Some(principal_kind.clone()),
480 Some(principal_name.clone()),
481 Some(resource_type.clone()),
482 Some(resource_name.clone()),
483 Some(action.clone()),
484 ))
485 }
486
487 HirnOp::TraverseGraph {
488 start_id,
489 relation_filter,
490 depth,
491 namespace,
492 } => {
493 let schema = hirn_node.schema.as_ref().inner().clone();
494 Arc::new(GraphTraverseExec::new(
495 schema,
496 start_id.clone(),
497 relation_filter.clone(),
498 *depth,
499 namespace.clone(),
500 ))
501 }
502
503 HirnOp::NliContradiction => {
505 let input = require_single_input(physical_inputs, "NliContradiction")?;
506 Arc::new(NliContradictionExec::new(input, NliConfig::default()))
507 }
508
509 HirnOp::AbaReconsolidation { namespace } => {
510 let input = require_single_input(physical_inputs, "AbaReconsolidation")?;
511 Arc::new(AbaReconsolidationExec::new(input, namespace.clone()))
512 }
513
514 HirnOp::CausalDiscovery { namespace } => {
515 let input = require_single_input(physical_inputs, "CausalDiscovery")?;
516 Arc::new(CausalDiscoveryExec::new(
517 input,
518 CausalDiscoveryConfig::default(),
519 namespace.clone(),
520 ))
521 }
522
523 HirnOp::ContextAssembly => {
525 let input = require_single_input(physical_inputs, "ContextAssembly")?;
526 Arc::new(ContextAssemblyExec::new(input))
527 }
528 };
529
530 Ok(Some(plan))
531 }
532}
533
534#[derive(Debug)]
538pub struct HirnQueryPlanner;
539
540#[async_trait]
541impl datafusion::execution::context::QueryPlanner for HirnQueryPlanner {
542 async fn create_physical_plan(
543 &self,
544 logical_plan: &LogicalPlan,
545 session_state: &SessionState,
546 ) -> Result<Arc<dyn ExecutionPlan>> {
547 use datafusion::physical_planner::DefaultPhysicalPlanner;
548 let planner =
549 DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(HirnExtensionPlanner)]);
550 planner
551 .create_physical_plan(logical_plan, session_state)
552 .await
553 }
554}
555
556fn require_single_input(
558 inputs: &[Arc<dyn ExecutionPlan>],
559 op_name: &str,
560) -> Result<Arc<dyn ExecutionPlan>> {
561 inputs.first().cloned().ok_or_else(|| {
562 datafusion_common::DataFusionError::Plan(format!(
563 "Hirn{op_name} requires exactly one input, got 0"
564 ))
565 })
566}
567
568fn session_distance_metric(
569 session_state: &SessionState,
570) -> Result<hirn_storage::store::DistanceMetric> {
571 let ext = session_state
572 .config()
573 .options()
574 .extensions
575 .get::<crate::extensions::HirnSessionExt>()
576 .ok_or_else(|| {
577 datafusion_common::DataFusionError::Configuration(
578 "HirnSessionExt must be registered before planning compiled search operators"
579 .to_string(),
580 )
581 })?;
582
583 Ok(ext.config.metric)
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::SessionContext;
    use hirn_core::HirnConfig;

    /// Builds a DataFusion session with default features and no Hirn extension registered.
    fn plain_session() -> SessionContext {
        let builder = SessionStateBuilder::new_with_default_features();
        SessionContext::new_with_state(builder.build())
    }

    #[test]
    fn session_distance_metric_uses_registered_config() {
        let ctx = plain_session();
        let hirn_config = HirnConfig::builder()
            .distance_metric(hirn_core::DistanceMetric::Cosine)
            .build()
            .expect("test config should build");
        crate::extensions::HirnSessionExt::new(Arc::new(0_u8), Arc::new(hirn_config), None)
            .register(&ctx)
            .unwrap();

        let snapshot = ctx.state();
        let metric = session_distance_metric(&snapshot).unwrap();
        assert_eq!(metric, hirn_storage::store::DistanceMetric::Cosine);
    }

    #[test]
    fn session_distance_metric_requires_session_extension() {
        let snapshot = plain_session().state();
        let message = session_distance_metric(&snapshot).unwrap_err().to_string();
        assert!(message.contains("HirnSessionExt must be registered"));
    }
}