use crate::{Scirs2Exec, Scirs2Tensor};
use tensorlogic_infer::{BatchResult, ExecutorError, TlAutodiff, TlBatchExecutor};
use tensorlogic_ir::EinsumGraph;
#[cfg(feature = "parallel")]
use scirs2_core::parallel_ops::*;
impl TlBatchExecutor for Scirs2Exec {
type Tensor = Scirs2Tensor;
type Error = ExecutorError;
fn execute_batch(
&mut self,
graph: &EinsumGraph,
batch_inputs: Vec<Vec<Self::Tensor>>,
) -> Result<BatchResult<Self::Tensor>, Self::Error> {
if batch_inputs.is_empty() {
return Err(ExecutorError::InvalidEinsumSpec(
"Empty batch provided".to_string(),
));
}
let mut outputs = Vec::with_capacity(batch_inputs.len());
for input_batch in batch_inputs {
for (idx, tensor) in input_batch.iter().enumerate() {
if idx < graph.tensors.len() {
self.add_tensor(graph.tensors[idx].clone(), tensor.clone());
}
}
let output = self.forward(graph)?;
outputs.push(output);
}
Ok(BatchResult::new(outputs))
}
fn execute_batch_parallel(
&mut self,
graph: &EinsumGraph,
batch_inputs: Vec<Vec<Self::Tensor>>,
num_threads: Option<usize>,
) -> Result<BatchResult<Self::Tensor>, Self::Error> {
#[cfg(feature = "parallel")]
{
if batch_inputs.is_empty() {
return Err(ExecutorError::InvalidEinsumSpec(
"Empty batch provided".to_string(),
));
}
if let Some(threads) = num_threads {
ThreadPoolBuilder::new()
.num_threads(threads)
.build_global()
.ok(); }
let results: Result<Vec<_>, _> = batch_inputs
.par_iter()
.map(|input_batch| {
let mut executor = self.clone();
for (idx, tensor) in input_batch.iter().enumerate() {
if idx < graph.tensors.len() {
executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
}
}
executor.forward(graph)
})
.collect();
let outputs = results?;
Ok(BatchResult::new(outputs))
}
#[cfg(not(feature = "parallel"))]
{
let _ = num_threads; self.execute_batch(graph, batch_inputs)
}
}
fn optimal_batch_size(&self) -> usize {
let num_cpus = std::thread::available_parallelism()
.map(|p| p.get())
.unwrap_or(4);
num_cpus * 2
}
}
pub struct ParallelBatchExecutor {
base: Scirs2Exec,
}
impl ParallelBatchExecutor {
pub fn new() -> Self {
ParallelBatchExecutor {
base: Scirs2Exec::new(),
}
}
pub fn with_memory_pool() -> Self {
ParallelBatchExecutor {
base: Scirs2Exec::with_memory_pool(),
}
}
pub fn execute_parallel(
&self,
graph: &EinsumGraph,
batch_inputs: Vec<Vec<Scirs2Tensor>>,
) -> Result<BatchResult<Scirs2Tensor>, ExecutorError> {
if batch_inputs.is_empty() {
return Err(ExecutorError::InvalidEinsumSpec(
"Empty batch provided".to_string(),
));
}
#[cfg(feature = "parallel")]
{
let results: Result<Vec<_>, _> = batch_inputs
.par_iter()
.map(|input_batch| {
let mut executor = self.base.clone();
for (idx, tensor) in input_batch.iter().enumerate() {
if idx < graph.tensors.len() {
executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
}
}
executor.forward(graph)
})
.collect();
let outputs = results?;
Ok(BatchResult::new(outputs))
}
#[cfg(not(feature = "parallel"))]
{
let mut outputs = Vec::with_capacity(batch_inputs.len());
for input_batch in batch_inputs {
let mut executor = self.base.clone();
for (idx, tensor) in input_batch.iter().enumerate() {
if idx < graph.tensors.len() {
executor.add_tensor(graph.tensors[idx].clone(), tensor.clone());
}
}
let output = executor.forward(graph)?;
outputs.push(output);
}
Ok(BatchResult::new(outputs))
}
}
}
impl Default for ParallelBatchExecutor {
fn default() -> Self {
Self::new()
}
}
impl Clone for Scirs2Exec {
fn clone(&self) -> Self {
Scirs2Exec {
tensors: self.tensors.clone(),
tape: self.tape.clone(),
pool: None, }
}
}
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
use super::*;
use scirs2_core::ndarray::ArrayD;
use tensorlogic_compiler::compile_to_einsum;
use tensorlogic_ir::{TLExpr, Term};
fn create_test_tensor(shape: &[usize], value: f64) -> ArrayD<f64> {
ArrayD::from_elem(shape.to_vec(), value)
}
#[test]
fn test_batch_executor_basic() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let y = TLExpr::pred("y", vec![Term::var("i")]);
let expr = TLExpr::add(x, y);
let graph = compile_to_einsum(&expr).expect("unwrap");
let mut executor = Scirs2Exec::new();
let batch_inputs = vec![
vec![create_test_tensor(&[5], 1.0), create_test_tensor(&[5], 2.0)],
vec![create_test_tensor(&[5], 3.0), create_test_tensor(&[5], 4.0)],
vec![create_test_tensor(&[5], 5.0), create_test_tensor(&[5], 6.0)],
];
let result = executor
.execute_batch(&graph, batch_inputs)
.expect("unwrap");
assert_eq!(result.len(), 3);
assert!((result.outputs[0][0] - 3.0).abs() < 1e-6); assert!((result.outputs[1][0] - 7.0).abs() < 1e-6); assert!((result.outputs[2][0] - 11.0).abs() < 1e-6); assert_eq!(result.batch_size, 3);
}
#[test]
fn test_optimal_batch_size() {
let executor = Scirs2Exec::new();
let batch_size = executor.optimal_batch_size();
assert!(batch_size > 0);
assert!(batch_size <= 128); }
#[test]
fn test_parallel_batch_executor() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let expr = TLExpr::mul(x.clone(), x);
let graph = compile_to_einsum(&expr).expect("unwrap");
let executor = ParallelBatchExecutor::new();
let batch_inputs = vec![
vec![create_test_tensor(&[3], 2.0)],
vec![create_test_tensor(&[3], 3.0)],
];
let result = executor
.execute_parallel(&graph, batch_inputs)
.expect("unwrap");
assert_eq!(result.len(), 2);
assert!((result.outputs[0][0] - 4.0).abs() < 1e-6); assert!((result.outputs[1][0] - 9.0).abs() < 1e-6); }
#[test]
fn test_empty_batch_error() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let graph = compile_to_einsum(&x).expect("unwrap");
let mut executor = Scirs2Exec::new();
let batch_inputs: Vec<Vec<ArrayD<f64>>> = vec![];
let result = executor.execute_batch(&graph, batch_inputs);
assert!(result.is_err());
}
#[test]
fn test_batch_parallel_same_as_sequential() {
let x = TLExpr::pred("x", vec![Term::var("i")]);
let y = TLExpr::pred("y", vec![Term::var("i")]);
let expr = TLExpr::add(x, y);
let graph = compile_to_einsum(&expr).expect("unwrap");
let batch_inputs = vec![
vec![create_test_tensor(&[3], 1.0), create_test_tensor(&[3], 2.0)],
vec![create_test_tensor(&[3], 3.0), create_test_tensor(&[3], 4.0)],
];
let mut executor = Scirs2Exec::new();
let result_seq = executor
.execute_batch(&graph, batch_inputs.clone())
.expect("unwrap");
let mut executor2 = Scirs2Exec::new();
let result_par = executor2
.execute_batch_parallel(&graph, batch_inputs, None)
.expect("unwrap");
assert_eq!(result_seq.len(), result_par.len());
for (seq, par) in result_seq.outputs.iter().zip(result_par.outputs.iter()) {
assert_eq!(seq.shape(), par.shape());
for (s, p) in seq.iter().zip(par.iter()) {
assert!((s - p).abs() < 1e-10);
}
}
}
}