use tensorlogic_ir::EinsumGraph;
/// Outputs produced by executing a graph over a batch of input sets.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchResult<T> {
    /// One output per batch entry, in the same order as the inputs.
    pub outputs: Vec<T>,
    /// Number of entries in the batch; equals `outputs.len()` at
    /// construction time (both fields are public, so keep them in sync
    /// if you mutate `outputs` directly).
    pub batch_size: usize,
}

impl<T> BatchResult<T> {
    /// Wraps `outputs`, recording the batch size from its length.
    pub fn new(outputs: Vec<T>) -> Self {
        let batch_size = outputs.len();
        BatchResult { outputs, batch_size }
    }

    /// Number of outputs in the batch.
    pub fn len(&self) -> usize {
        self.outputs.len()
    }

    /// Returns `true` when the batch produced no outputs.
    pub fn is_empty(&self) -> bool {
        self.outputs.is_empty()
    }

    /// Iterates over the outputs by reference.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.outputs.iter()
    }
}

// Manual impl instead of `#[derive(Default)]` so no `T: Default` bound
// is imposed (an empty `Vec<T>` needs none).
impl<T> Default for BatchResult<T> {
    /// An empty batch result.
    fn default() -> Self {
        Self::new(Vec::new())
    }
}

impl<T> From<Vec<T>> for BatchResult<T> {
    /// Converts a plain output vector into a `BatchResult`.
    fn from(outputs: Vec<T>) -> Self {
        Self::new(outputs)
    }
}
pub trait TlBatchExecutor {
type Tensor;
type Error;
fn execute_batch(
&mut self,
graph: &EinsumGraph,
batch_inputs: Vec<Vec<Self::Tensor>>,
) -> Result<BatchResult<Self::Tensor>, Self::Error>;
fn execute_batch_parallel(
&mut self,
graph: &EinsumGraph,
batch_inputs: Vec<Vec<Self::Tensor>>,
num_threads: Option<usize>,
) -> Result<BatchResult<Self::Tensor>, Self::Error>;
fn max_batch_size(&self) -> Option<usize> {
None }
fn optimal_batch_size(&self) -> usize {
32 }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A populated batch reports its size consistently across accessors.
    #[test]
    fn test_batch_result() {
        let result = BatchResult::new(vec![1, 2, 3, 4]);
        assert_eq!(result.len(), 4);
        assert_eq!(result.batch_size, 4);
        assert!(!result.is_empty());
        assert_eq!(result.iter().count(), 4);
    }

    /// An empty batch has length zero and is flagged empty.
    #[test]
    fn test_empty_batch_result() {
        let result = BatchResult::<i32>::new(Vec::new());
        assert_eq!(result.len(), 0);
        assert!(result.is_empty());
    }
}