use std::any::Any;
use std::fmt;
use std::sync::Arc;
use arrow_array::{Array, Float32Array, RecordBatch, StringArray, UInt32Array};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion_common::Result;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
/// Thresholds that control which candidate cause/effect pairs
/// `CausalDiscoveryExec` emits.
#[derive(Debug, Clone)]
pub struct CausalDiscoveryConfig {
/// Minimum number of observed (earlier, later) record pairs a key pair must
/// accumulate; key pairs with fewer observations are dropped.
pub min_evidence: u32,
/// Minimum confidence score a key pair must reach to be emitted; the score is
/// derived from the observation count during execution.
pub min_confidence: f32,
/// Maximum distance between two records' `created_at` values for them to be
/// paired, in seconds. The executor scales this by 1000 before comparing it to
/// raw timestamp differences.
/// NOTE(review): that scaling implies `created_at` is in milliseconds — confirm.
pub max_time_gap_secs: u64,
}
impl Default for CausalDiscoveryConfig {
fn default() -> Self {
Self {
min_evidence: 3,
min_confidence: 0.4,
max_time_gap_secs: 3600,
}
}
}
/// Physical plan node that consumes its child's record batches and produces
/// rows in the schema returned by `CausalDiscoveryExec::output_schema`.
#[derive(Debug)]
pub struct CausalDiscoveryExec {
/// Child plan whose output stream is consumed during `execute`.
input: Arc<dyn ExecutionPlan>,
/// Output schema, created once in `new` from `output_schema()`.
schema: SchemaRef,
/// Cached plan properties built in `new` (single unknown partition, final
/// emission, bounded).
properties: PlanProperties,
/// Thresholds applied while pairing records and filtering results.
config: CausalDiscoveryConfig,
/// Namespace label; in this file it is only used in the plan's display text.
namespace: String,
}
impl CausalDiscoveryExec {
    /// Creates the operator on top of `input`.
    ///
    /// The output schema is fixed (see [`Self::output_schema`]). The node
    /// advertises a single partition of unknown partitioning, emits only once
    /// its input is exhausted, and is bounded.
    pub fn new(
        input: Arc<dyn ExecutionPlan>,
        config: CausalDiscoveryConfig,
        namespace: String,
    ) -> Self {
        let schema = Self::output_schema();

        let equivalences =
            datafusion_physical_expr::EquivalenceProperties::new(Arc::clone(&schema));
        let partitioning = datafusion_physical_plan::Partitioning::UnknownPartitioning(1);
        let properties = PlanProperties::new(
            equivalences,
            partitioning,
            EmissionType::Final,
            Boundedness::Bounded,
        );

        Self {
            input,
            schema,
            properties,
            config,
            namespace,
        }
    }

    /// Schema of the rows this operator emits.
    ///
    /// Columns, in order: `cause_id`, `effect_id` (non-null UTF-8), `strength`,
    /// `confidence` (non-null `Float32`), `evidence_count` (non-null `UInt32`),
    /// and `mechanism` (nullable UTF-8).
    pub fn output_schema() -> SchemaRef {
        let required = false;
        let optional = true;

        let fields = vec![
            Field::new("cause_id", DataType::Utf8, required),
            Field::new("effect_id", DataType::Utf8, required),
            Field::new("strength", DataType::Float32, required),
            Field::new("confidence", DataType::Float32, required),
            Field::new("evidence_count", DataType::UInt32, required),
            Field::new("mechanism", DataType::Utf8, optional),
        ];

        Arc::new(Schema::new(fields))
    }
}
impl DisplayAs for CausalDiscoveryExec {
    /// Writes the one-line plan description.
    ///
    /// The requested display format is ignored; the same text is produced for
    /// every variant.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let namespace = &self.namespace;
        let evidence_floor = self.config.min_evidence;
        let confidence_floor = self.config.min_confidence;

        f.write_fmt(format_args!(
            "CausalDiscoveryExec: ns={}, min_ev={}, min_conf={}",
            namespace, evidence_floor, confidence_floor
        ))
    }
}
impl ExecutionPlan for CausalDiscoveryExec {
fn name(&self) -> &str {
"CausalDiscoveryExec"
}
fn as_any(&self) -> &dyn Any {
self
}
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
fn properties(&self) -> &PlanProperties {
&self.properties
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(Self::new(
children[0].clone(),
self.config.clone(),
self.namespace.clone(),
)))
}
fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let input = self.input.execute(partition, context)?;
let schema = self.schema.clone();
let stream_schema = schema.clone();
let config = self.config.clone();
let fut = async move {
use futures::StreamExt;
use std::collections::HashMap;
let mut pair_counts: HashMap<(String, String), Vec<(String, String)>> = HashMap::new();
let mut stream = input;
let mut prev_records: Vec<(String, String, u64)> = Vec::new();
while let Some(batch) = stream.next().await {
let batch = batch?;
let id_col = batch.column_by_name("id");
let content_col = batch.column_by_name("content");
let ts_col = batch.column_by_name("created_at");
if let (Some(ids), Some(contents)) = (id_col, content_col) {
if let (Some(id_arr), Some(content_arr)) = (
ids.as_any().downcast_ref::<StringArray>(),
contents.as_any().downcast_ref::<StringArray>(),
) {
let timestamps: Vec<u64> = ts_col
.and_then(|c| {
c.as_any()
.downcast_ref::<arrow_array::UInt64Array>()
.map(|a| (0..a.len()).map(|i| a.value(i)).collect())
})
.unwrap_or_else(|| vec![0u64; id_arr.len()]);
for i in 0..id_arr.len() {
if id_arr.is_null(i) || content_arr.is_null(i) {
continue;
}
let id = id_arr.value(i).to_string();
let content = content_arr.value(i).to_string();
let ts = timestamps.get(i).copied().unwrap_or(0);
for (prev_id, prev_content, prev_ts) in &prev_records {
if ts > *prev_ts
&& (ts - prev_ts) <= config.max_time_gap_secs * 1000
{
let key_a = truncate_key(prev_content);
let key_b = truncate_key(&content);
if key_a != key_b {
pair_counts
.entry((key_a, key_b))
.or_default()
.push((prev_id.clone(), id.clone()));
}
}
}
prev_records.push((id, content, ts));
}
}
}
}
let mut cause_ids = Vec::new();
let mut effect_ids = Vec::new();
let mut strengths = Vec::new();
let mut confidences = Vec::new();
let mut evidence_counts = Vec::new();
let mut mechanisms: Vec<Option<String>> = Vec::new();
for ((_key_a, _key_b), pairs) in &pair_counts {
let count = pairs.len() as u32;
if count < config.min_evidence {
continue;
}
let strength = (count as f32 / 10.0).min(1.0);
let confidence = (0.3 + 0.7 * (1.0 - 1.0 / (1.0 + count as f32))).min(1.0);
if confidence < config.min_confidence {
continue;
}
if let Some((cause, effect)) = pairs.last() {
cause_ids.push(cause.clone());
effect_ids.push(effect.clone());
strengths.push(strength);
confidences.push(confidence);
evidence_counts.push(count);
mechanisms.push(Some("temporal_granger".to_string()));
}
}
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(cause_ids)),
Arc::new(StringArray::from(effect_ids)),
Arc::new(Float32Array::from(strengths)),
Arc::new(Float32Array::from(confidences)),
Arc::new(UInt32Array::from(evidence_counts)),
Arc::new(StringArray::from(mechanisms)),
],
)?;
Ok(batch)
};
let stream = futures::stream::once(fut);
Ok(Box::pin(RecordBatchStreamAdapter::new(
stream_schema,
stream,
)))
}
}
/// Builds the grouping key for a piece of content: its first 50 characters
/// (Unicode scalar values, not bytes), lowercased.
fn truncate_key(s: &str) -> String {
    // Byte offset where the 51st char starts, or the whole string if shorter.
    let cut = s.char_indices().nth(50).map_or(s.len(), |(offset, _)| offset);
    s[..cut].to_lowercase()
}