use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::error::Result as DFResult;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
};
use futures::TryStreamExt;
use parking_lot::RwLock;
use uni_algo::algo::AlgorithmRegistry;
use uni_locy::{ClassifierRegistry, ModelInvocation, ModelInvocationCache};
use uni_store::runtime::L0Manager;
use uni_store::runtime::property_manager::PropertyManager;
use uni_store::storage::manager::StorageManager;
use uni_xervo::runtime::ModelRuntime;
use super::locy_fixpoint::apply_model_invocations;
/// Cloneable bundle of optional shared handles passed alongside model invocations.
///
/// Every field is optional; `Default` yields a handle with all of them unset.
/// Only `registry` and `storage` take part in the "configured" check used by the
/// `Debug` impl and `is_configured`.
#[derive(Clone, Default)]
pub struct GraphAlgoHandle {
/// Shared algorithm registry, if one was supplied.
pub(crate) registry: Option<Arc<AlgorithmRegistry>>,
/// Shared storage manager, if one was supplied.
pub(crate) storage: Option<Arc<StorageManager>>,
/// Shared L0 manager, if one was supplied.
pub(crate) l0_manager: Option<Arc<L0Manager>>,
/// Shared property manager, if one was supplied.
pub(crate) property_manager: Option<Arc<PropertyManager>>,
/// Set of L0 buffer handles (current / transaction / pending flush), if supplied.
pub(crate) l0_buffers: Option<L0Buffers>,
}
/// Group of shared, lock-protected `L0Buffer` handles.
///
/// Cloning copies the `Arc`s (and the `Vec` holding them), not the buffers.
// NOTE(review): the exact role of each buffer is inferred from the field names only —
// confirm against `uni_store::runtime::l0`.
#[derive(Clone)]
pub(crate) struct L0Buffers {
/// The buffer that is always present.
pub(crate) current: Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>,
/// Optional additional buffer; presumably the one owned by an open transaction.
pub(crate) transaction: Option<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
/// Zero or more further buffers; presumably those awaiting a flush.
pub(crate) pending_flush: Vec<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
}
impl std::fmt::Debug for GraphAlgoHandle {
    /// Prints only whether the handle is usable, never the wrapped objects.
    ///
    /// The handle is reported as configured only when both the registry and the
    /// storage manager are present; any other combination prints `<none>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_both = self.registry.is_some() && self.storage.is_some();
        let label = if has_both {
            "GraphAlgoHandle(<configured>)"
        } else {
            "GraphAlgoHandle(<none>)"
        };
        f.write_str(label)
    }
}
impl GraphAlgoHandle {
    /// Returns `true` when both the algorithm registry and the storage manager are set.
    ///
    /// The remaining optional fields do not influence the result.
    pub fn is_configured(&self) -> bool {
        matches!((&self.registry, &self.storage), (Some(_), Some(_)))
    }
}
/// Cloneable wrapper around an optional shared [`ModelRuntime`].
///
/// `Default` produces the unconfigured state (`None`).
#[derive(Clone, Default)]
pub struct XervoRuntimeHandle(pub Option<Arc<ModelRuntime>>);
impl std::fmt::Debug for XervoRuntimeHandle {
    /// Prints only whether a runtime is attached, never the runtime itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = if self.0.is_some() {
            "XervoRuntimeHandle(<configured>)"
        } else {
            "XervoRuntimeHandle(<none>)"
        };
        f.write_str(label)
    }
}
impl XervoRuntimeHandle {
    /// Borrows the wrapped runtime, or returns `None` when no runtime is attached.
    pub fn as_ref(&self) -> Option<&Arc<ModelRuntime>> {
        // Fully qualified so it is obvious this is `Option::as_ref`, not a recursive call.
        Option::as_ref(&self.0)
    }
}
/// Shared, lock-protected collection of record batches together with their schema.
///
/// Cloning copies the `Arc`, so all clones observe the same batch vector.
#[derive(Debug, Clone)]
pub struct PathContextHandle {
/// Name of the associated rule.
// NOTE(review): presumably the rule that produced `data` — confirm with the code that builds this handle.
pub source_rule: String,
/// Batches shared behind a `parking_lot::RwLock`.
pub data: Arc<RwLock<Vec<RecordBatch>>>,
/// Arrow schema stored alongside `data`; presumably the schema of those batches — TODO confirm.
pub schema: SchemaRef,
}
/// Physical plan node that runs `apply_model_invocations` over the output of a
/// single child plan.
///
/// All fields other than `schema` and `plan_properties` are supplied by the caller
/// of `new`; those two are derived from the child plan and the invocation list.
#[derive(Debug)]
pub struct LocyModelInvokeExec {
/// The single child plan whose batches are fed to the model invocations.
input: Arc<dyn ExecutionPlan>,
/// Invocations to apply; each names a model and an output column.
invocations: Vec<ModelInvocation>,
/// Classifier registry forwarded to `apply_model_invocations`.
registry: Arc<ClassifierRegistry>,
/// Optional invocation cache forwarded to `apply_model_invocations`.
cache: Option<Arc<ModelInvocationCache>>,
/// Path-context handles keyed by string, forwarded to `apply_model_invocations`.
path_context_handles: HashMap<String, PathContextHandle>,
/// Optional model runtime handle forwarded to `apply_model_invocations`.
xervo_runtime: XervoRuntimeHandle,
/// Graph-algorithm handles forwarded to `apply_model_invocations`.
graph_algo: GraphAlgoHandle,
/// Optional provenance store forwarded to `apply_model_invocations`.
provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
/// Output schema: the child schema with one nullable `Float64` column per invocation.
schema: SchemaRef,
/// Cached plan properties returned from `ExecutionPlan::properties`.
plan_properties: Arc<PlanProperties>,
}
impl LocyModelInvokeExec {
    /// Builds the operator on top of `input`.
    ///
    /// The output schema is derived from the child schema and `invocations`, and the
    /// plan properties are derived from the child and that schema; every other
    /// argument is stored unchanged for use at execution time.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        input: Arc<dyn ExecutionPlan>,
        invocations: Vec<ModelInvocation>,
        registry: Arc<ClassifierRegistry>,
        cache: Option<Arc<ModelInvocationCache>>,
        provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
        path_context_handles: HashMap<String, PathContextHandle>,
        xervo_runtime: XervoRuntimeHandle,
        graph_algo: GraphAlgoHandle,
    ) -> Self {
        // Both derived values must be computed before `input` is moved into the struct.
        let output_schema = compute_output_schema(input.schema(), &invocations);
        let properties = compute_plan_properties(&input, Arc::clone(&output_schema));

        Self {
            // Caller-supplied state.
            input,
            invocations,
            registry,
            cache,
            path_context_handles,
            xervo_runtime,
            graph_algo,
            provenance_store,
            // Derived state.
            schema: output_schema,
            plan_properties: properties,
        }
    }
}
/// Derives the operator's output schema from the child schema.
///
/// With no invocations the child schema is returned untouched. Otherwise each
/// invocation contributes a nullable `Float64` field named after its output column:
/// the field replaces the same-named field of the child schema in place when one
/// exists, and is appended otherwise. The name lookup is always performed against
/// the child schema, not against fields added by earlier invocations.
fn compute_output_schema(input_schema: SchemaRef, invocations: &[ModelInvocation]) -> SchemaRef {
    use arrow_schema::{DataType, Field, Schema};

    if invocations.is_empty() {
        return input_schema;
    }

    let input_fields = input_schema.fields();
    let mut out_fields: Vec<Arc<Field>> = input_fields.iter().cloned().collect();

    for inv in invocations {
        let column = &inv.output_column;
        let float_field = Arc::new(Field::new(column, DataType::Float64, true));
        let existing_slot = input_fields
            .iter()
            .position(|existing| existing.name() == column);
        match existing_slot {
            Some(slot) => out_fields[slot] = float_field,
            None => out_fields.push(float_field),
        }
    }

    Arc::new(Schema::new(out_fields))
}
/// Builds the plan properties advertised by the operator.
///
/// Equivalence properties are created fresh from `schema`, the output partitioning
/// is copied from `input`, and the node is declared as bounded with final emission.
fn compute_plan_properties(
    input: &Arc<dyn ExecutionPlan>,
    schema: SchemaRef,
) -> Arc<PlanProperties> {
    use datafusion::physical_expr::EquivalenceProperties;
    use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};

    let equivalences = EquivalenceProperties::new(schema);
    let partitioning = input.properties().output_partitioning().clone();
    let properties = PlanProperties::new(
        equivalences,
        partitioning,
        EmissionType::Final,
        Boundedness::Bounded,
    );
    Arc::new(properties)
}
impl DisplayAs for LocyModelInvokeExec {
    /// Renders the node as `LocyModelInvokeExec: invocations=[model→column, ...]`.
    ///
    /// The same text is produced for every display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut labels: Vec<String> = Vec::with_capacity(self.invocations.len());
        for inv in &self.invocations {
            labels.push(format!("{}→{}", inv.model_name, inv.output_column));
        }
        write!(
            f,
            "LocyModelInvokeExec: invocations=[{}]",
            labels.join(", ")
        )
    }
}
impl ExecutionPlan for LocyModelInvokeExec {
    /// Static operator name.
    fn name(&self) -> &str {
        "LocyModelInvokeExec"
    }

    /// Exposes the node as `Any` so it can be downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Plan properties computed once in `new`.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.plan_properties
    }

    /// The operator has exactly one child: its input plan.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds the node on top of a replacement child.
    ///
    /// Schema and plan properties are recomputed from the new child because the
    /// node is reconstructed through `new`.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Internal` when `children` does not contain exactly
    /// one plan.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Converting to a one-element array checks the length and extracts the
        // child in a single step, so no `unwrap` is needed afterwards.
        let child = match <[Arc<dyn ExecutionPlan>; 1]>::try_from(children) {
            Ok([child]) => child,
            Err(children) => {
                return Err(datafusion::error::DataFusionError::Internal(format!(
                    "LocyModelInvokeExec expects exactly 1 child, got {}",
                    children.len()
                )));
            }
        };
        Ok(Arc::new(Self::new(
            child,
            self.invocations.clone(),
            Arc::clone(&self.registry),
            self.cache.as_ref().map(Arc::clone),
            self.provenance_store.as_ref().map(Arc::clone),
            self.path_context_handles.clone(),
            self.xervo_runtime.clone(),
            self.graph_algo.clone(),
        )))
    }

    /// Executes one partition of the child, applies the model invocations, and
    /// streams the resulting batches.
    ///
    /// All input batches of the partition are collected before
    /// `apply_model_invocations` is called, so no output is produced until the
    /// child stream is exhausted. The work is lazy: nothing runs until the returned
    /// stream is polled.
    ///
    /// # Errors
    ///
    /// Fails immediately if the child cannot be executed. Errors from reading the
    /// child stream or from `apply_model_invocations` surface as an error item on
    /// the returned stream.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;

        // Clone everything the future needs so it does not borrow `self`.
        let invocations = self.invocations.clone();
        let registry = Arc::clone(&self.registry);
        let cache = self.cache.as_ref().map(Arc::clone);
        let provenance_store = self.provenance_store.as_ref().map(Arc::clone);
        let path_context_handles = self.path_context_handles.clone();
        let xervo_runtime = self.xervo_runtime.clone();
        let graph_algo = self.graph_algo.clone();
        let schema = self.schema.clone();

        let fut = async move {
            let batches: Vec<RecordBatch> = input_stream.try_collect::<Vec<_>>().await?;
            let out = apply_model_invocations(
                batches,
                &invocations,
                &registry,
                cache.as_ref(),
                provenance_store.as_ref(),
                &path_context_handles,
                &xervo_runtime,
                &graph_algo,
            )
            .await?;
            Ok::<_, datafusion::error::DataFusionError>(futures::stream::iter(
                out.into_iter().map(Ok),
            ))
        };

        // A one-shot stream yielding a stream of batches, flattened into a batch stream.
        let stream = futures::stream::once(fut).try_flatten();
        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
    }
}