use anyhow::{anyhow, Result};
use tensorlogic_ir::{PredicateSignature, TLExpr};
use super::diagnostics::{diagnose_expression, DiagnosticLevel};
use super::scope_analysis::{analyze_scopes, suggest_quantifiers};
use super::type_checking::TypeChecker;
/// Checks that every predicate in `expr` is used with a consistent arity.
///
/// This delegates to `TLExpr::validate_arity` and re-wraps any failure as an
/// [`anyhow::Error`] whose message is the `Display` text of the original error.
///
/// # Errors
///
/// Returns an error when `TLExpr::validate_arity` reports one.
pub fn validate_arity(expr: &TLExpr) -> Result<()> {
    if let Err(cause) = expr.validate_arity() {
        return Err(anyhow!("{}", cause));
    }
    Ok(())
}
/// Aggregated outcome of running the validation passes over an expression.
///
/// Instances are produced by [`validate_expression`] and
/// [`validate_expression_with_types`]. `PartialEq`/`Eq` are derived so results
/// can be compared directly (e.g. in tests). `Default` is intentionally not
/// derived: a defaulted value would have `passed == false` with zero errors,
/// which is an inconsistent state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Whether validation succeeded. The validators in this module set this to
    /// `true` exactly when no error was recorded.
    pub passed: bool,
    /// Number of error-level findings.
    pub error_count: usize,
    /// Number of warning-level findings.
    pub warning_count: usize,
    /// Human-readable messages for every finding, in the order they were
    /// produced. This includes errors, warnings, info/hint diagnostics and
    /// quantifier suggestions.
    pub diagnostics: Vec<String>,
}
impl ValidationResult {
    /// Reports whether validation succeeded, i.e. the `passed` flag.
    pub fn is_ok(&self) -> bool {
        self.passed
    }

    /// Reports whether at least one error-level finding was counted.
    pub fn has_errors(&self) -> bool {
        self.error_count != 0
    }

    /// Renders every collected diagnostic, one per line.
    ///
    /// Despite the name, this includes all entries of `diagnostics` (warnings,
    /// hints and suggestions too), not only errors. An empty result yields `""`.
    pub fn error_message(&self) -> String {
        self.diagnostics.as_slice().join("\n")
    }
}
/// Runs the structural validation passes over `expr` and collects the findings.
///
/// The passes run in this order:
/// 1. arity validation ([`validate_arity`]);
/// 2. scope analysis: each unbound variable is an error (followed by a
///    quantifier suggestion line when one is available), and each variable
///    type conflict is an error; a failure of the analysis itself is also an
///    error;
/// 3. the general diagnostics pass ([`diagnose_expression`]).
///
/// An error from pass 3 is skipped when an error recorded earlier already
/// contains its message text, so issues found by several passes are counted
/// once. Only earlier *errors* are consulted for this. Warnings, info/hint
/// lines and suggestions can never suppress an error, and an empty message
/// is never treated as a duplicate.
///
/// Warnings are counted in `warning_count`. Info and hint diagnostics are
/// recorded in `diagnostics` but not counted. `passed` is `true` exactly
/// when no error was recorded.
pub fn validate_expression(expr: &TLExpr) -> ValidationResult {
    /// Appends `message` to `diagnostics` and remembers its index as an error.
    fn push_error(
        diagnostics: &mut Vec<String>,
        error_positions: &mut Vec<usize>,
        message: String,
    ) {
        error_positions.push(diagnostics.len());
        diagnostics.push(message);
    }

    let mut diagnostics: Vec<String> = Vec::new();
    // Indices into `diagnostics` of the entries that are errors. Deduplication
    // consults only these; scanning every entry would let a warning or
    // suggestion line swallow a genuine error and flip `passed` to `true`.
    let mut error_positions: Vec<usize> = Vec::new();
    let mut warning_count = 0;

    // Pass 1: arity consistency.
    if let Err(e) = validate_arity(expr) {
        push_error(
            &mut diagnostics,
            &mut error_positions,
            format!("Arity error: {}", e),
        );
    }

    // Pass 2: scope analysis.
    match analyze_scopes(expr) {
        Ok(scope_result) => {
            if !scope_result.unbound_variables.is_empty() {
                for var in &scope_result.unbound_variables {
                    push_error(
                        &mut diagnostics,
                        &mut error_positions,
                        format!("Unbound variable: '{}'", var),
                    );
                }
                // A failure to compute suggestions is ignored; they are not errors.
                if let Ok(suggestions) = suggest_quantifiers(expr) {
                    if !suggestions.is_empty() {
                        diagnostics.push(format!("Suggestion: {}", suggestions.join(", ")));
                    }
                }
            }
            for conflict in &scope_result.type_conflicts {
                push_error(
                    &mut diagnostics,
                    &mut error_positions,
                    format!(
                        "Type conflict: variable '{}' has conflicting types '{}' and '{}'",
                        conflict.variable, conflict.type1, conflict.type2
                    ),
                );
            }
        }
        Err(e) => push_error(
            &mut diagnostics,
            &mut error_positions,
            format!("Scope analysis error: {}", e),
        ),
    }

    // Pass 3: general diagnostics.
    for diag in diagnose_expression(expr) {
        let formatted = diag.format();
        match diag.level {
            DiagnosticLevel::Error => {
                let already_reported = !diag.message.is_empty()
                    && error_positions
                        .iter()
                        .any(|&i| diagnostics[i].contains(&diag.message));
                if !already_reported {
                    push_error(&mut diagnostics, &mut error_positions, formatted);
                }
            }
            DiagnosticLevel::Warning => {
                diagnostics.push(formatted);
                warning_count += 1;
            }
            DiagnosticLevel::Info | DiagnosticLevel::Hint => {
                diagnostics.push(formatted);
            }
        }
    }

    let error_count = error_positions.len();
    ValidationResult {
        passed: error_count == 0,
        error_count,
        warning_count,
        diagnostics,
    }
}
/// Runs [`validate_expression`] and then type-checks `expr` against `signatures`.
///
/// Every signature is cloned into a fresh `SignatureRegistry` that backs a
/// `TypeChecker`. If `check_expr` fails, a `"Type error: …"` diagnostic is
/// appended, `error_count` is incremented and `passed` is cleared. Otherwise
/// the structural result is returned unchanged.
pub fn validate_expression_with_types(
    expr: &TLExpr,
    signatures: &[PredicateSignature],
) -> ValidationResult {
    use tensorlogic_ir::SignatureRegistry;

    let mut outcome = validate_expression(expr);

    let mut known_signatures = SignatureRegistry::new();
    signatures.iter().for_each(|signature| {
        known_signatures.register(signature.clone());
    });

    let type_checker = TypeChecker::new(known_signatures);
    match type_checker.check_expr(expr) {
        Ok(_) => {}
        Err(type_failure) => {
            outcome.diagnostics.push(format!("Type error: {}", type_failure));
            outcome.error_count += 1;
            outcome.passed = false;
        }
    }

    outcome
}
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_validate_expression_ok() {
        // Both variables are bound by enclosing quantifiers.
        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::exists(
                "y",
                "Person",
                TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
            ),
        );
        let result = validate_expression(&expr);
        assert!(
            result.is_ok(),
            "validation failed with errors: {:#?}",
            result.diagnostics
        );
        assert_eq!(result.error_count, 0);
    }

    #[test]
    fn test_validate_expression_partial_binding() {
        // `y` is bound by the quantifier, `x` is not.
        let expr = TLExpr::exists(
            "y",
            "Person",
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
        );
        let result = validate_expression(&expr);
        assert!(
            result.has_errors(),
            "expected an error; error_count: {}, diagnostics: {:?}",
            result.error_count,
            result.diagnostics
        );
        assert!(result.error_count >= 1);
    }

    #[test]
    fn test_validate_expression_unbound_vars() {
        // Neither `x` nor `y` is bound.
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let result = validate_expression(&expr);
        assert!(result.has_errors());
        assert!(
            result.error_count >= 2,
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_expression_arity_mismatch() {
        // `knows` is used with two arguments and then with one.
        let expr = TLExpr::and(
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
            TLExpr::pred("knows", vec![Term::var("z")]),
        );
        let result = validate_expression(&expr);
        assert!(result.has_errors());
        assert!(
            result.diagnostics.iter().any(|d| d.contains("Arity")),
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_expression_with_warnings() {
        // `x` is quantified but not referenced in the body; `y` is free.
        let expr = TLExpr::exists("x", "Person", TLExpr::pred("p", vec![Term::var("y")]));
        let result = validate_expression(&expr);
        assert!(
            result.warning_count > 0,
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_with_types() {
        use tensorlogic_ir::TypeAnnotation;
        let signatures = vec![PredicateSignature {
            name: "knows".to_string(),
            arity: 2,
            arg_types: vec![
                TypeAnnotation {
                    type_name: "Person".to_string(),
                },
                TypeAnnotation {
                    type_name: "Person".to_string(),
                },
            ],
            parametric_types: None,
        }];
        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::exists(
                "y",
                "Person",
                TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
            ),
        );
        let result = validate_expression_with_types(&expr, &signatures);
        assert!(
            result.is_ok(),
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validation_result_message() {
        let expr = TLExpr::pred("knows", vec![Term::var("x")]);
        let result = validate_expression(&expr);
        let message = result.error_message();
        assert!(!message.is_empty());
        assert!(message.contains("Unbound"), "message: {}", message);
    }

    #[test]
    fn test_validation_result_helpers() {
        let clean = ValidationResult {
            passed: true,
            error_count: 0,
            warning_count: 1,
            diagnostics: vec!["w".to_string()],
        };
        assert!(clean.is_ok());
        assert!(!clean.has_errors());
        assert_eq!(clean.error_message(), "w");

        let failed = ValidationResult {
            passed: false,
            error_count: 2,
            warning_count: 0,
            diagnostics: vec!["a".to_string(), "b".to_string()],
        };
        assert!(!failed.is_ok());
        assert!(failed.has_errors());
        assert_eq!(failed.error_message(), "a\nb");
    }
}