use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::time::Instant;
use tensorlogic_compiler::compile_to_einsum;
use tensorlogic_infer::{ExecutorError, TlAutodiff};
use tensorlogic_ir::TLExpr;
use tensorlogic_scirs_backend::Scirs2Exec;
use crate::shacl::validation::{ValidationReport, ValidationResult, ValidationSeverity};
/// Tunable limits and formatting options for a [`ValidationExecutor`].
#[derive(Debug, Clone)]
pub struct ValidationExecutorConfig {
    /// Largest number of elements the output tensor may hold before
    /// `execute_rule` reports `TensorTooLarge`.
    pub max_tensor_size: usize,
    /// Number of fractional digits used when tensor min/max values are
    /// written into the RDF export.
    pub float_precision: usize,
    /// IRI prefix used for the `tl:` namespace and for the focus-node, shape
    /// and constraint-component IRIs of validation results.
    pub base_iri: String,
}

impl Default for ValidationExecutorConfig {
    /// Returns the stock configuration: a 65536-element limit, six fractional
    /// digits, and the `https://tensorlogic.local/` namespace.
    fn default() -> Self {
        let base_iri = String::from("https://tensorlogic.local/");
        ValidationExecutorConfig {
            max_tensor_size: 1 << 16,
            float_precision: 6,
            base_iri,
        }
    }
}
/// Timings and size counters recorded by `ValidationExecutor::execute_rule`.
#[derive(Debug, Clone)]
pub struct ExecutionStats {
/// Wall-clock time spent in `compile_to_einsum`, in microseconds.
pub compile_time_us: u64,
/// Wall-clock time spent registering placeholder inputs and running the
/// forward pass, in microseconds.
pub execute_time_us: u64,
/// Number of nodes in the compiled graph.
pub graph_node_count: usize,
/// Number of output tensors captured (`execute_rule` always records one).
pub output_tensor_count: usize,
/// Number of scalar elements copied out of the output tensor.
pub total_elements: usize,
}
/// A named tensor flattened into a plain vector of `f64` values.
#[derive(Debug, Clone)]
pub struct ExecutionTensor {
    /// Label used for this tensor in reports and RDF output.
    pub name: String,
    /// Dimension sizes of the tensor.
    pub shape: Vec<usize>,
    /// Element values in the order they were read from the source tensor.
    pub values: Vec<f64>,
}

impl ExecutionTensor {
    /// Returns `true` when at least one element is NaN.
    pub fn has_nan(&self) -> bool {
        self.values.iter().copied().any(f64::is_nan)
    }

    /// Returns `true` when at least one element is positive or negative infinity.
    pub fn has_inf(&self) -> bool {
        self.values.iter().copied().any(f64::is_infinite)
    }

    /// Returns `true` when no element is NaN or infinite (vacuously true for
    /// an empty tensor).
    pub fn all_finite(&self) -> bool {
        !self.values.iter().any(|x| !x.is_finite())
    }

    /// Smallest element according to `f64::min`, or `None` for an empty tensor.
    ///
    /// `f64::min` skips NaN operands, so the result is NaN only when every
    /// element is NaN.
    pub fn min_value(&self) -> Option<f64> {
        let mut rest = self.values.iter().copied();
        let first = rest.next()?;
        Some(rest.fold(first, f64::min))
    }

    /// Largest element according to `f64::max`, or `None` for an empty tensor.
    ///
    /// `f64::max` skips NaN operands, so the result is NaN only when every
    /// element is NaN.
    pub fn max_value(&self) -> Option<f64> {
        let mut rest = self.values.iter().copied();
        let first = rest.next()?;
        Some(rest.fold(first, f64::max))
    }

    /// Number of elements that are NaN or infinite.
    pub fn non_finite_count(&self) -> usize {
        self.values
            .iter()
            .map(|x| usize::from(!x.is_finite()))
            .sum()
    }
}
/// Outcome of compiling and executing a single rule expression.
#[derive(Debug, Clone)]
pub struct ExecutionResult {
/// `Debug` rendering of the expression that was executed.
pub expression_repr: String,
/// Number of nodes in the compiled graph.
pub graph_node_count: usize,
/// Tensors captured from the forward pass.
pub output_tensors: Vec<ExecutionTensor>,
/// Timings and size counters for this run.
pub stats: ExecutionStats,
}
/// Errors that `ValidationExecutor::execute_rule` can return.
#[derive(Debug)]
pub enum ValidationExecutorError {
/// `compile_to_einsum` failed for the expression.
Compile(anyhow::Error),
/// The backend's forward pass failed.
Execute(ExecutorError),
/// The output tensor holds more elements than `max_tensor_size` allows.
TensorTooLarge {
/// Name of the offending tensor.
name: String,
/// Number of elements the tensor holds.
size: usize,
/// Configured element limit that was exceeded.
max: usize,
},
/// Compilation succeeded but the resulting graph reports `is_empty()`.
EmptyGraph,
}
impl fmt::Display for ValidationExecutorError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyGraph => f.write_str("compiled graph is empty (no tensors or nodes)"),
            Self::TensorTooLarge { name, size, max } => write!(
                f,
                "output tensor '{name}' has {size} elements which exceeds the limit of {max}"
            ),
            Self::Execute(cause) => write!(f, "executor error: {cause}"),
            Self::Compile(cause) => write!(f, "compilation error: {cause}"),
        }
    }
}
impl std::error::Error for ValidationExecutorError {
    /// Returns the underlying cause for the wrapping variants.
    ///
    /// Both `Compile` and `Execute` expose their inner error so that error
    /// reporters walking the `source()` chain can reach the root cause.
    /// `TensorTooLarge` and `EmptyGraph` originate here and have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            // `anyhow::Error` does not implement `std::error::Error` itself,
            // but it exposes the wrapped error through `AsRef<dyn Error>`.
            Self::Compile(e) => Some(AsRef::<dyn std::error::Error>::as_ref(e)),
            Self::Execute(e) => Some(e),
            Self::TensorTooLarge { .. } | Self::EmptyGraph => None,
        }
    }
}
/// Compiles rule expressions, executes them on the SciRS2 backend, and turns
/// the outcome into validation reports or RDF text.
pub struct ValidationExecutor {
// Limits and formatting options applied by every method.
config: ValidationExecutorConfig,
}
impl ValidationExecutor {
    /// Creates an executor that applies `config` to every operation.
    pub fn new(config: ValidationExecutorConfig) -> Self {
        Self { config }
    }

    /// Returns the configuration this executor was built with.
    pub fn config(&self) -> &ValidationExecutorConfig {
        &self.config
    }

    /// Microseconds elapsed since `since`, saturating at `u64::MAX` instead of
    /// silently truncating the `u128` returned by `Duration::as_micros`.
    fn elapsed_micros(since: Instant) -> u64 {
        since.elapsed().as_micros().min(u128::from(u64::MAX)) as u64
    }

    /// Compiles `expr` to an einsum graph and runs one forward pass over it.
    ///
    /// Every graph tensor whose name does not start with `const_` is fed a
    /// single-element placeholder so the graph can be evaluated without real
    /// data. The placeholder value cycles through 0.1..=0.9 based on the
    /// tensor's index in the graph.
    ///
    /// # Errors
    ///
    /// * [`ValidationExecutorError::Compile`] if compilation fails.
    /// * [`ValidationExecutorError::EmptyGraph`] if the compiled graph is empty.
    /// * [`ValidationExecutorError::Execute`] if the forward pass fails.
    /// * [`ValidationExecutorError::TensorTooLarge`] if the output holds more
    ///   elements than `max_tensor_size`.
    pub fn execute_rule(&self, expr: &TLExpr) -> Result<ExecutionResult, ValidationExecutorError> {
        let t_compile_start = Instant::now();
        let graph = compile_to_einsum(expr).map_err(ValidationExecutorError::Compile)?;
        let compile_time_us = Self::elapsed_micros(t_compile_start);

        if graph.is_empty() {
            return Err(ValidationExecutorError::EmptyGraph);
        }

        let t_exec_start = Instant::now();
        let mut exec = Scirs2Exec::new();
        for (i, tensor_name) in graph.tensors.iter().enumerate() {
            // Constant tensors are skipped here. The text before any `[` is
            // a prefix of the full name, so checking the full name covers
            // both spellings.
            if tensor_name.starts_with("const_") {
                continue;
            }
            // `i` deliberately counts skipped constants too, so a tensor's
            // placeholder value depends only on its position in the graph.
            let val = 0.1 + 0.1 * (i % 9) as f64;
            let placeholder = scirs2_core::ndarray::Array1::from_vec(vec![val]).into_dyn();
            exec.add_tensor(tensor_name.clone(), placeholder);
        }

        let result_tensor = exec
            .forward(&graph)
            .map_err(ValidationExecutorError::Execute)?;
        let execute_time_us = Self::elapsed_micros(t_exec_start);

        let shape: Vec<usize> = result_tensor.shape().to_vec();
        let values: Vec<f64> = result_tensor.iter().copied().collect();
        let total_elements = values.len();
        if total_elements > self.config.max_tensor_size {
            return Err(ValidationExecutorError::TensorTooLarge {
                name: "output".to_string(),
                size: total_elements,
                max: self.config.max_tensor_size,
            });
        }

        let output_tensor = ExecutionTensor {
            name: "output".to_string(),
            shape,
            values,
        };
        let graph_node_count = graph.nodes.len();
        Ok(ExecutionResult {
            expression_repr: format!("{expr:?}"),
            graph_node_count,
            output_tensors: vec![output_tensor],
            stats: ExecutionStats {
                compile_time_us,
                execute_time_us,
                graph_node_count,
                output_tensor_count: 1,
                total_elements,
            },
        })
    }

    /// Builds a validation report with one `Violation` result for every
    /// output tensor that contains NaN or infinite values.
    ///
    /// Tensors whose values are all finite contribute nothing to the report.
    pub fn generate_validation_report(&self, result: &ExecutionResult) -> ValidationReport {
        let mut report = ValidationReport::new();
        for tensor in &result.output_tensors {
            if tensor.all_finite() {
                continue;
            }
            let non_finite = tensor.non_finite_count();
            let kind_desc = match (tensor.has_nan(), tensor.has_inf()) {
                (true, true) => "NaN and Inf values",
                (true, false) => "NaN values",
                (false, true) => "Inf values",
                // Not reachable for `f64` (non-finite means NaN or infinite);
                // kept so the match stays exhaustive.
                (false, false) => "non-finite values",
            };
            let message = format!(
                "Output tensor '{}' contains {} {} (shape: {:?})",
                tensor.name, non_finite, kind_desc, tensor.shape,
            );
            let base_iri = &self.config.base_iri;
            let focus_node = format!("{}tensor/{}", base_iri, tensor.name);
            let source_shape = format!("{}shape/FiniteValueConstraint", base_iri);
            let constraint_component =
                format!("{}constraint/FiniteValueConstraintComponent", base_iri);
            let vr = ValidationResult::new(focus_node, source_shape, constraint_component, message)
                .with_severity(ValidationSeverity::Violation)
                .with_value(format!("{non_finite} non-finite elements"));
            report.add_result(vr);
        }
        report
    }

    /// Serialises `result` as a Turtle document.
    ///
    /// The subject is `tl:exec_<hash>`, where the hash is derived from
    /// `expression_repr` with `DefaultHasher`. The standard library does not
    /// guarantee that algorithm across Rust releases, so the identifier
    /// should not be treated as stable. Integer and boolean objects are
    /// written in Turtle's bare shorthand (`42`, `true`), which denotes
    /// `xsd:integer` and `xsd:boolean` literals.
    ///
    /// `base_iri` is interpolated into the `@prefix` line unescaped, so it
    /// must already be a valid IRI.
    pub fn export_as_rdf(&self, result: &ExecutionResult) -> String {
        let base_iri = &self.config.base_iri;
        let prec = self.config.float_precision;
        let exec_hash = {
            let mut h = DefaultHasher::new();
            result.expression_repr.hash(&mut h);
            h.finish()
        };
        let escaped_repr = escape_turtle_literal(&result.expression_repr);
        let all_conforms = result.output_tensors.iter().all(|t| t.all_finite());

        // One entry per predicate/object pair. Joining them afterwards puts
        // `;` between pairs and leaves the final `.` to a single push.
        let mut properties = vec![
            format!(" tl:expressionRepr \"{escaped_repr}\""),
            format!(" tl:graphNodeCount {}", result.graph_node_count),
            format!(" tl:compileTimeUs {}", result.stats.compile_time_us),
            format!(" tl:executeTimeUs {}", result.stats.execute_time_us),
            format!(" tl:totalElements {}", result.stats.total_elements),
            format!(" tl:conforms {all_conforms}"),
        ];
        properties.extend(result.output_tensors.iter().map(|tensor| {
            format!(
                " tl:outputTensor {}",
                format_tensor_blank_node(tensor, prec)
            )
        }));

        let mut out = String::with_capacity(512);
        out.push_str(&format!("@prefix tl: <{base_iri}> .\n"));
        out.push_str("@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .\n");
        out.push('\n');
        out.push_str(&format!(
            "tl:exec_{exec_hash:016x} a tl:ExecutionResult ;\n"
        ));
        out.push_str(&properties.join(" ;\n"));
        out.push_str(" .\n");
        out
    }
}
/// Escapes `s` for use inside a double-quoted Turtle string literal.
///
/// Backslash, double quote, line feed, carriage return and tab are replaced
/// by their two-character escape sequences; every other character is copied
/// through unchanged.
fn escape_turtle_literal(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        let replacement = match c {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        match replacement {
            Some(sequence) => escaped.push_str(sequence),
            None => escaped.push(c),
        }
    }
    escaped
}
/// Renders `tensor` as an anonymous Turtle node (`[ ... ]`) describing its
/// name, shape, element count, finiteness and value range.
///
/// * `prec` is the number of fractional digits used for finite min/max values.
/// * Element count and the finiteness flag are written in Turtle's bare
///   integer/boolean shorthand, which denotes `xsd:integer` / `xsd:boolean`.
/// * `tl:minValue` / `tl:maxValue` are omitted for an empty tensor, because
///   there is no value to report and `"null"` is not a valid `xsd:decimal`.
/// * Non-finite bounds are typed `xsd:double` (`INF`, `-INF`, `NaN`), since
///   `xsd:decimal` cannot represent them.
fn format_tensor_blank_node(tensor: &ExecutionTensor, prec: usize) -> String {
    /// Formats one min/max bound as a typed Turtle literal.
    fn bound_literal(value: f64, prec: usize) -> String {
        if value.is_nan() {
            "\"NaN\"^^xsd:double".to_string()
        } else if value == f64::INFINITY {
            "\"INF\"^^xsd:double".to_string()
        } else if value == f64::NEG_INFINITY {
            "\"-INF\"^^xsd:double".to_string()
        } else {
            format!("\"{value:.prec$}\"^^xsd:decimal")
        }
    }

    let shape_literal = tensor
        .shape
        .iter()
        .map(|d| d.to_string())
        .collect::<Vec<_>>()
        .join(",");

    // One entry per predicate/object pair; joined with ` ;` below so the
    // last pair never carries a trailing separator.
    let mut properties = vec![
        " a tl:OutputTensor".to_string(),
        format!(
            " tl:tensorName \"{}\"",
            escape_turtle_literal(&tensor.name)
        ),
        format!(" tl:shape \"{shape_literal}\""),
        format!(" tl:elementCount {}", tensor.values.len()),
        format!(" tl:allFinite {}", tensor.all_finite()),
    ];
    if let Some(min) = tensor.min_value() {
        properties.push(format!(" tl:minValue {}", bound_literal(min, prec)));
    }
    if let Some(max) = tensor.max_value() {
        properties.push(format!(" tl:maxValue {}", bound_literal(max, prec)));
    }

    format!("[\n{}\n ]", properties.join(" ;\n"))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the validation executor: end-to-end execution of small
    //! expressions, RDF export smoke checks, and the pure helper functions.

    use tensorlogic_ir::{TLExpr, Term};

    use super::*;

    /// Executor built from the default configuration.
    fn default_executor() -> ValidationExecutor {
        ValidationExecutor::new(ValidationExecutorConfig::default())
    }

    /// A binary predicate compiles, runs, and yields an output tensor.
    #[test]
    fn test_simple_predicate_compiles_and_runs() {
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let executor = default_executor();
        let result = executor
            .execute_rule(&expr)
            .expect("execute simple predicate");
        assert!(
            !result.output_tensors.is_empty(),
            "expected at least one output tensor"
        );
    }

    /// Finite outputs produce a conforming report.
    #[test]
    fn test_finite_output_conforms() {
        let expr = TLExpr::pred("p", vec![Term::var("x")]);
        let executor = default_executor();
        let result = executor.execute_rule(&expr).expect("execute predicate");
        let report = executor.generate_validation_report(&result);
        assert!(report.conforms, "finite outputs should conform");
    }

    /// The RDF export declares both prefixes and the result type.
    #[test]
    fn test_export_rdf_contains_required_prefixes() {
        let expr = TLExpr::pred("q", vec![Term::var("a")]);
        let executor = default_executor();
        let result = executor.execute_rule(&expr).expect("execute");
        let rdf = executor.export_as_rdf(&result);
        assert!(rdf.contains("@prefix tl:"), "missing tl: prefix");
        assert!(rdf.contains("@prefix xsd:"), "missing xsd: prefix");
        assert!(
            rdf.contains("tl:ExecutionResult"),
            "missing ExecutionResult type"
        );
    }

    /// The RDF export always carries a `tl:conforms` boolean.
    #[test]
    fn test_export_rdf_conforms_field() {
        let expr = TLExpr::pred("r", vec![Term::var("x")]);
        let executor = default_executor();
        let result = executor.execute_rule(&expr).expect("execute");
        let rdf = executor.export_as_rdf(&result);
        assert!(
            rdf.contains("tl:conforms true") || rdf.contains("tl:conforms false"),
            "missing conforms field in RDF: {rdf}"
        );
    }

    /// A conjunction compiles to a graph with at least one node.
    #[test]
    fn test_execution_stats_recorded() {
        let p = TLExpr::pred("s", vec![Term::var("x")]);
        let q = TLExpr::pred("t", vec![Term::var("x")]);
        let expr = TLExpr::and(p, q);
        let executor = default_executor();
        let result = executor
            .execute_rule(&expr)
            .expect("execute AND expression");
        assert!(
            result.stats.graph_node_count > 0,
            "expected at least one graph node"
        );
    }

    /// With a zero element limit, execution must either fail or return an
    /// empty output; it must never hand back a non-empty tensor.
    #[test]
    fn test_max_tensor_size_zero_returns_error_or_empty() {
        let config = ValidationExecutorConfig {
            max_tensor_size: 0,
            ..Default::default()
        };
        let expr = TLExpr::pred("t", vec![Term::var("x")]);
        let executor = ValidationExecutor::new(config);
        if let Ok(result) = executor.execute_rule(&expr) {
            assert_eq!(
                result.stats.total_elements, 0,
                "a zero element limit must not allow a non-empty output"
            );
        }
    }

    /// Helper predicates and reductions on an all-finite tensor.
    #[test]
    fn test_execution_tensor_helpers_all_finite() {
        let t = ExecutionTensor {
            name: "test".to_string(),
            shape: vec![3],
            values: vec![1.0, 2.0, 3.0],
        };
        assert!(t.all_finite());
        assert!(!t.has_nan());
        assert!(!t.has_inf());
        assert_eq!(t.min_value(), Some(1.0));
        assert_eq!(t.max_value(), Some(3.0));
        assert_eq!(t.non_finite_count(), 0);
    }

    /// Helper predicates on a tensor containing a NaN.
    #[test]
    fn test_execution_tensor_helpers_with_nan() {
        let t = ExecutionTensor {
            name: "bad".to_string(),
            shape: vec![2],
            values: vec![f64::NAN, 1.0],
        };
        assert!(!t.all_finite());
        assert!(t.has_nan());
        assert_eq!(t.non_finite_count(), 1);
    }

    /// `EmptyGraph` mentions that the graph is empty.
    #[test]
    fn test_error_display_empty_graph() {
        let e = ValidationExecutorError::EmptyGraph;
        assert!(e.to_string().contains("empty"), "unexpected: {e}");
    }

    /// `TensorTooLarge` mentions the tensor name, its size and the limit.
    #[test]
    fn test_error_display_tensor_too_large() {
        let e = ValidationExecutorError::TensorTooLarge {
            name: "out".into(),
            size: 100,
            max: 50,
        };
        let s = e.to_string();
        assert!(s.contains("out"), "unexpected: {s}");
        assert!(s.contains("100"), "unexpected: {s}");
        assert!(s.contains("50"), "unexpected: {s}");
    }

    /// Newline, backslash and double quote are all escaped.
    #[test]
    fn test_escape_turtle_literal_special_chars() {
        let raw = "Hello\nworld\\foo\"bar";
        let escaped = escape_turtle_literal(raw);
        assert!(escaped.contains("\\n"), "newline not escaped");
        assert!(escaped.contains("\\\\"), "backslash not escaped");
        assert!(escaped.contains("\\\""), "quote not escaped");
    }

    /// A tensor holding an infinity yields a non-conforming report.
    #[test]
    fn test_validation_report_for_infinite_tensor() {
        let executor = default_executor();
        let result = ExecutionResult {
            expression_repr: "test".to_string(),
            graph_node_count: 1,
            output_tensors: vec![ExecutionTensor {
                name: "output".to_string(),
                shape: vec![1],
                values: vec![f64::INFINITY],
            }],
            stats: ExecutionStats {
                compile_time_us: 0,
                execute_time_us: 0,
                graph_node_count: 1,
                output_tensor_count: 1,
                total_elements: 1,
            },
        };
        let report = executor.generate_validation_report(&result);
        assert!(!report.conforms, "Inf tensor should not conform");
        assert!(
            !report.results.is_empty(),
            "expected at least one violation"
        );
    }
}