use anyhow::{Context, Result};
use indicatif::{ProgressBar, ProgressStyle};
use rayon::prelude::*;
use std::fs;
use std::path::Path;
use std::sync::{Arc, Mutex};
use tensorlogic_compiler::{compile_to_einsum_with_context, CompilerContext};
use tensorlogic_ir::validate_graph;
use crate::output::{print_error, print_success};
use crate::parser::parse_expression;
pub struct BatchProcessor {
context: CompilerContext,
validate: bool,
parallel: bool,
num_threads: Option<usize>,
}
impl BatchProcessor {
pub fn new(context: CompilerContext, validate: bool) -> Self {
Self {
context,
validate,
parallel: false,
num_threads: None,
}
}
#[allow(dead_code)]
pub fn with_parallelism(
context: CompilerContext,
validate: bool,
num_threads: Option<usize>,
) -> Self {
Self {
context,
validate,
parallel: true,
num_threads,
}
}
#[allow(dead_code)]
pub fn enable_parallel(&mut self, num_threads: Option<usize>) {
self.parallel = true;
self.num_threads = num_threads;
}
#[allow(dead_code)]
pub fn disable_parallel(&mut self) {
self.parallel = false;
self.num_threads = None;
}
pub fn process_file(&mut self, input_path: &Path) -> Result<BatchResult> {
let content = fs::read_to_string(input_path)
.with_context(|| format!("Failed to read file: {}", input_path.display()))?;
let expressions: Vec<String> = content
.lines()
.filter(|line| !line.trim().is_empty() && !line.trim().starts_with('#'))
.map(|s| s.to_string())
.collect();
self.process_expressions(&expressions)
}
pub fn process_expressions(&mut self, expressions: &[String]) -> Result<BatchResult> {
if self.parallel {
self.process_expressions_parallel(expressions)
} else {
self.process_expressions_sequential(expressions)
}
}
fn process_expressions_sequential(&mut self, expressions: &[String]) -> Result<BatchResult> {
let total = expressions.len();
let mut successes = 0;
let mut failures = Vec::new();
let pb = ProgressBar::new(total as u64);
pb.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
.unwrap()
.progress_chars("##-"),
);
for (i, expr_str) in expressions.iter().enumerate() {
pb.set_message(format!("Processing expression {}", i + 1));
match self.process_one(expr_str) {
Ok(_) => successes += 1,
Err(e) => failures.push((i + 1, expr_str.clone(), e.to_string())),
}
pb.inc(1);
}
pb.finish_with_message("Done");
Ok(BatchResult {
total,
successes,
failures,
})
}
fn process_expressions_parallel(&self, expressions: &[String]) -> Result<BatchResult> {
let total = expressions.len();
if let Some(num_threads) = self.num_threads {
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.build_global()
.ok(); }
let pb = Arc::new(ProgressBar::new(total as u64));
pb.set_style(
ProgressStyle::default_bar()
.template(
"[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg} (parallel)",
)
.unwrap()
.progress_chars("##-"),
);
let successes = Arc::new(Mutex::new(0usize));
let failures = Arc::new(Mutex::new(Vec::new()));
expressions
.par_iter()
.enumerate()
.for_each(|(i, expr_str)| {
let pb_clone = Arc::clone(&pb);
pb_clone.set_message(format!("Processing expression {}", i + 1));
let mut local_context = self.context.clone();
match self.process_one_with_context(expr_str, &mut local_context) {
Ok(_) => {
let mut succ = successes.lock().unwrap();
*succ += 1;
}
Err(e) => {
let mut fails = failures.lock().unwrap();
fails.push((i + 1, expr_str.clone(), e.to_string()));
}
}
pb_clone.inc(1);
});
pb.finish_with_message("Done");
let final_successes = *successes.lock().unwrap();
let final_failures = failures.lock().unwrap().clone();
Ok(BatchResult {
total,
successes: final_successes,
failures: final_failures,
})
}
fn process_one(&mut self, expr_str: &str) -> Result<()> {
let expr = parse_expression(expr_str)?;
let graph = compile_to_einsum_with_context(&expr, &mut self.context)?;
if self.validate {
let report = validate_graph(&graph);
if !report.is_valid() {
anyhow::bail!("Validation failed: {:?}", report.errors);
}
}
Ok(())
}
fn process_one_with_context(
&self,
expr_str: &str,
context: &mut CompilerContext,
) -> Result<()> {
let expr = parse_expression(expr_str)?;
let graph = compile_to_einsum_with_context(&expr, context)?;
if self.validate {
let report = validate_graph(&graph);
if !report.is_valid() {
anyhow::bail!("Validation failed: {:?}", report.errors);
}
}
Ok(())
}
}
pub struct BatchResult {
pub total: usize,
pub successes: usize,
pub failures: Vec<(usize, String, String)>,
}
impl BatchResult {
pub fn print_summary(&self) {
println!("\nBatch Processing Summary:");
println!(" Total: {}", self.total);
print_success(&format!(" Successes: {}", self.successes));
if !self.failures.is_empty() {
print_error(&format!(" Failures: {}", self.failures.len()));
println!("\nFailed expressions:");
for (line_num, expr, error) in &self.failures {
println!(" Line {}: {}", line_num, expr);
println!(" Error: {}", error);
}
}
}
#[allow(dead_code)]
pub fn all_succeeded(&self) -> bool {
self.failures.is_empty()
}
#[allow(dead_code)]
pub fn success_rate(&self) -> f64 {
if self.total == 0 {
0.0
} else {
(self.successes as f64 / self.total as f64) * 100.0
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tensorlogic_compiler::CompilationConfig;
#[test]
fn test_batch_processor_sequential() {
let context = CompilerContext::with_config(CompilationConfig::soft_differentiable());
let mut processor = BatchProcessor::new(context, false);
let expressions = vec!["pred(x, y)".to_string(), "AND(a, b)".to_string()];
let result = processor.process_expressions(&expressions).unwrap();
assert_eq!(result.total, 2);
assert!(result.successes > 0);
}
#[test]
fn test_batch_processor_parallel() {
let context = CompilerContext::with_config(CompilationConfig::soft_differentiable());
let mut processor = BatchProcessor::with_parallelism(context, false, Some(2));
let expressions = vec![
"pred(x, y)".to_string(),
"AND(a, b)".to_string(),
"OR(p, q)".to_string(),
];
let result = processor.process_expressions(&expressions).unwrap();
assert_eq!(result.total, 3);
assert!(result.successes > 0);
}
#[test]
fn test_batch_result_metrics() {
let result = BatchResult {
total: 10,
successes: 8,
failures: vec![(1, "bad".to_string(), "error".to_string())],
};
assert!(!result.all_succeeded());
assert_eq!(result.success_rate(), 80.0);
}
}