use ndarray::{Array3, ArrayView2, ArrayView3, Axis};
use crate::embedding::EmbeddingError;
/// Settings that determine how a batch is split into chunks before inference.
///
/// The chunk size actually used is derived from these fields: `max_chunk_size` is
/// multiplied by `safety_factor` and truncated to an integer, raised to at least
/// `min_batch_size`, and finally capped at the number of rows being processed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InferenceManagerConfig {
    /// Largest number of rows per chunk before `safety_factor` is applied.
    pub max_chunk_size: usize,
    /// Lower bound for the scaled chunk size (applied before capping at the row count).
    pub min_batch_size: usize,
    /// Multiplier applied to `max_chunk_size`; the product is truncated toward zero.
    pub safety_factor: f32,
    /// When `true`, `InferenceManager::run_chunked` logs each chunk's row range to stderr.
    pub verbose: bool,
}
impl Default for InferenceManagerConfig {
fn default() -> Self {
Self {
max_chunk_size: 16,
min_batch_size: 1,
safety_factor: 0.8,
verbose: false,
}
}
}
/// Runs a caller-supplied forward function over an input in bounded-size chunks.
///
/// The chunking policy is taken from [`InferenceManagerConfig`].
#[derive(Clone, Debug)]
pub struct InferenceManager {
    /// Chunking settings consulted on every run.
    pub config: InferenceManagerConfig,
}
impl InferenceManager {
    /// Creates a manager that chunks work according to `config`.
    pub fn new(config: InferenceManagerConfig) -> Self {
        Self { config }
    }

    /// Applies `forward_fn` to `x`, splitting `x` along axis 0 when it holds more rows than
    /// the effective chunk size.
    ///
    /// When the whole input fits in one chunk (including an input with zero rows),
    /// `forward_fn` is called exactly once with `x` and its result is returned as-is.
    /// Otherwise it is called once per consecutive row range, in order, and the outputs are
    /// concatenated along axis 0. If `config.verbose` is set, each range is logged to stderr
    /// before it is processed.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `forward_fn`; later chunks are not processed.
    ///
    /// # Panics
    ///
    /// Panics if the per-chunk outputs cannot be concatenated along axis 0, i.e. they
    /// disagree on axes 1 and 2.
    pub fn run_chunked<F>(
        &self,
        x: ArrayView3<f32>,
        mut forward_fn: F,
    ) -> Result<Array3<f32>, EmbeddingError>
    where
        F: FnMut(ArrayView3<f32>) -> Result<Array3<f32>, EmbeddingError>,
    {
        let b = x.shape()[0];
        let chunk_size = self.effective_chunk_size(b);
        if chunk_size >= b {
            return forward_fn(x);
        }

        let mut outputs: Vec<Array3<f32>> = Vec::new();
        for (start, end) in Self::chunk_ranges(b, chunk_size) {
            if self.config.verbose {
                eprintln!("[InferenceManager] processing chunk [{start}..{end}) of {b}");
            }
            outputs.push(forward_fn(x.slice(ndarray::s![start..end, .., ..]))?);
        }

        let views: Vec<_> = outputs.iter().map(|a| a.view()).collect();
        Ok(ndarray::concatenate(Axis(0), &views)
            .expect("chunk outputs must agree on axes 1 and 2"))
    }

    /// Applies `forward_fn` to the rows of `x_test`, splitting them into chunks when there
    /// are more rows than the effective chunk size.
    ///
    /// When all test rows fit in one chunk, `forward_fn` is called once with `x_test` and its
    /// result is returned as-is. Otherwise it is called once per consecutive row range, in
    /// order, and the outputs are concatenated along axis 1.
    ///
    /// `x_train` is part of the signature but is not read by this implementation.
    ///
    /// NOTE(review): concatenating on axis 1 presumes `forward_fn` places the test-row
    /// dimension on axis 1 of its output — confirm against the callers.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `forward_fn`; later chunks are not processed.
    ///
    /// # Panics
    ///
    /// Panics if the per-chunk outputs cannot be concatenated along axis 1.
    pub fn run_predict_chunked<F>(
        &self,
        x_train: ArrayView2<f32>,
        x_test: ArrayView2<f32>,
        mut forward_fn: F,
    ) -> Result<Array3<f32>, EmbeddingError>
    where
        F: FnMut(ArrayView2<f32>) -> Result<Array3<f32>, EmbeddingError>,
    {
        // Explicitly discard the training view so the unused parameter stays warning-free.
        let _ = x_train;

        let n_test = x_test.shape()[0];
        let chunk_size = self.effective_chunk_size(n_test);
        if chunk_size >= n_test {
            return forward_fn(x_test);
        }

        let mut outputs: Vec<Array3<f32>> = Vec::new();
        for (start, end) in Self::chunk_ranges(n_test, chunk_size) {
            outputs.push(forward_fn(x_test.slice(ndarray::s![start..end, ..]))?);
        }

        let views: Vec<_> = outputs.iter().map(|a| a.view()).collect();
        Ok(ndarray::concatenate(Axis(1), &views)
            .expect("predict chunk outputs must agree on shapes"))
    }

    /// Computes how many rows each chunk may hold when processing `n` rows.
    ///
    /// `max_chunk_size` is scaled by `safety_factor`, truncated, raised to at least
    /// `min_batch_size`, and capped at `n`. For `n > 0` the result is always at least 1;
    /// for `n == 0` it is 0.
    fn effective_chunk_size(&self, n: usize) -> usize {
        // Float-to-integer `as` casts saturate: a negative or NaN product becomes 0.
        let scaled = (self.config.max_chunk_size as f32 * self.config.safety_factor) as usize;
        // Without the floor of 1, `min_batch_size == 0` combined with a product that truncates
        // to 0 would give a zero chunk size, and the chunk loops would never advance.
        scaled.max(self.config.min_batch_size).max(1).min(n)
    }

    /// Yields consecutive half-open `(start, end)` ranges covering `0..n` in steps of
    /// `chunk_size`; the last range is shortened so that it ends at `n`.
    fn chunk_ranges(n: usize, chunk_size: usize) -> impl Iterator<Item = (usize, usize)> {
        // `step_by` panics on a zero step, so clamp defensively.
        let step = chunk_size.max(1);
        (0..n)
            .step_by(step)
            .map(move |start| (start, start.saturating_add(step).min(n)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Builds a non-verbose manager from the three numeric settings.
    fn quiet_manager(
        max_chunk_size: usize,
        safety_factor: f32,
        min_batch_size: usize,
    ) -> InferenceManager {
        InferenceManager::new(InferenceManagerConfig {
            max_chunk_size,
            min_batch_size,
            safety_factor,
            verbose: false,
        })
    }

    /// The default config keeps a minimum batch of 1 and a safety factor of 0.8.
    #[test]
    fn defaults_match_python_signature() {
        let InferenceManagerConfig {
            min_batch_size,
            safety_factor,
            ..
        } = InferenceManagerConfig::default();
        assert_eq!(min_batch_size, 1);
        assert!((safety_factor - 0.8).abs() < 1e-6);
    }

    /// A batch smaller than the chunk size is forwarded in a single call.
    #[test]
    fn single_chunk_when_batch_fits() {
        let config = InferenceManagerConfig {
            max_chunk_size: 100,
            ..Default::default()
        };
        let mgr = InferenceManager::new(config);
        let x = Array::<f32, _>::zeros((3, 4, 2));
        let mut calls = 0usize;
        let _ = mgr
            .run_chunked(x.view(), |chunk| {
                calls += 1;
                let rows = chunk.shape()[0];
                // Row index goes in column 0 of the last axis; everything else stays zero.
                Ok(Array3::from_shape_fn((rows, 1, 2), |(bi, _, k)| {
                    if k == 0 {
                        bi as f32
                    } else {
                        0.0
                    }
                }))
            })
            .unwrap();
        assert_eq!(calls, 1);
    }

    /// Seven rows with a chunk size of three are processed as 3 + 3 + 1, in order.
    #[test]
    fn multi_chunk_when_batch_exceeds_limit() {
        let mgr = quiet_manager(3, 1.0, 1);
        let x = Array::<f32, _>::from_shape_fn((7, 4, 2), |(b, _, _)| b as f32);
        let mut calls = 0usize;
        let mut chunks_seen: Vec<usize> = Vec::new();
        let out = mgr
            .run_chunked(x.view(), |chunk| {
                calls += 1;
                let rows = chunk.shape()[0];
                chunks_seen.push(rows);
                // Echo each row's first element so the original row order can be checked.
                Ok(Array3::from_shape_fn((rows, 1, 2), |(bi, _, k)| {
                    if k == 0 {
                        chunk[(bi, 0, 0)]
                    } else {
                        0.0
                    }
                }))
            })
            .unwrap();
        assert_eq!(chunks_seen, vec![3, 3, 1]);
        assert_eq!(calls, 3);
        assert_eq!(out.shape(), &[7, 1, 2]);
        for (bi, row) in out.outer_iter().enumerate() {
            assert_eq!(row[(0, 0)], bi as f32);
        }
    }

    /// A safety factor of 0.5 halves the maximum chunk size.
    #[test]
    fn safety_factor_shrinks_chunk_size() {
        let mgr = quiet_manager(10, 0.5, 1);
        assert_eq!(mgr.effective_chunk_size(20), 5);
    }

    /// The minimum batch size wins when the scaled maximum falls below it.
    #[test]
    fn min_batch_size_floor_honored() {
        let mgr = quiet_manager(1, 0.1, 4);
        assert_eq!(mgr.effective_chunk_size(100), 4);
    }
}