use axonml_autograd::Variable;
use super::preprocess::preprocess_frame;
use super::{CaptureBackend, CaptureConfig, CaptureError};
use std::time::Instant;
/// A detection model that maps a preprocessed frame tensor to an output.
///
/// Implementors declare their expected input resolution via [`DetectionModel::input_size`];
/// the pipeline uses it to size preprocessing before calling [`DetectionModel::detect`].
pub trait DetectionModel {
    /// Per-frame detection result type (implementation-defined).
    type Output;
    /// Runs detection on one preprocessed input tensor.
    /// Takes `&mut self` so stateful models (e.g. with internal buffers) are allowed.
    fn detect(&mut self, input: &Variable) -> Self::Output;
    /// Expected (width, height) in pixels of the model's input.
    fn input_size(&self) -> (u32, u32);
}
/// Running counters and timings accumulated while the pipeline processes frames.
#[derive(Debug, Clone)]
pub struct PipelineStats {
    /// Number of frames that completed both preprocessing and inference.
    pub frames_processed: u64,
    /// Cumulative preprocessing time, in seconds.
    pub total_preprocess_time: f64,
    /// Cumulative inference time, in seconds.
    pub total_inference_time: f64,
    // Captured at construction; basis for fps() and elapsed_secs().
    start_time: Instant,
}
impl PipelineStats {
    /// Creates a zeroed tracker with the wall clock started now.
    fn new() -> Self {
        Self {
            frames_processed: 0,
            total_preprocess_time: 0.0,
            total_inference_time: 0.0,
            start_time: Instant::now(),
        }
    }

    /// Frames processed per wall-clock second since construction.
    /// Returns 0.0 when no measurable time has elapsed yet.
    pub fn fps(&self) -> f64 {
        match self.start_time.elapsed().as_secs_f64() {
            secs if secs > 0.0 => self.frames_processed as f64 / secs,
            _ => 0.0,
        }
    }

    /// Mean preprocessing time per frame, in milliseconds (0.0 before the first frame).
    pub fn avg_preprocess_ms(&self) -> f64 {
        self.per_frame_ms(self.total_preprocess_time)
    }

    /// Mean inference time per frame, in milliseconds (0.0 before the first frame).
    pub fn avg_inference_ms(&self) -> f64 {
        self.per_frame_ms(self.total_inference_time)
    }

    /// Wall-clock seconds since this tracker was created.
    pub fn elapsed_secs(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }

    // Converts a cumulative seconds total into a per-frame millisecond average,
    // guarding against division by zero when no frames have been processed.
    fn per_frame_ms(&self, total_secs: f64) -> f64 {
        if self.frames_processed == 0 {
            0.0
        } else {
            (total_secs / self.frames_processed as f64) * 1000.0
        }
    }
}
/// Couples a detection model `M` with a capture backend `B` and tracks
/// per-frame throughput statistics across `step()` calls.
pub struct InferencePipeline<M, B> {
    // The detection model invoked once per frame.
    model: M,
    // Source of raw frames (camera, file, synthetic, ...).
    backend: B,
    // Timing/throughput counters; reset on each start().
    stats: PipelineStats,
}
impl<M, B> InferencePipeline<M, B>
where
    M: DetectionModel,
    B: CaptureBackend,
{
    /// Builds a pipeline from a model and a capture backend; stats start zeroed.
    pub fn new(model: M, backend: B) -> Self {
        let stats = PipelineStats::new();
        Self {
            model,
            backend,
            stats,
        }
    }

    /// Opens the capture backend with `config` and resets all statistics.
    ///
    /// # Errors
    /// Propagates any [`CaptureError`] from the backend's `open`.
    pub fn start(&mut self, config: &CaptureConfig) -> Result<(), CaptureError> {
        self.backend.open(config)?;
        self.stats = PipelineStats::new();
        Ok(())
    }

    /// Grabs one frame, preprocesses it to the model's input size, runs
    /// detection, and folds both stage timings into the stats.
    ///
    /// # Errors
    /// Propagates any [`CaptureError`] from the backend's frame grab.
    pub fn step(&mut self) -> Result<M::Output, CaptureError> {
        let frame = self.backend.grab_frame()?;
        let (width, height) = self.model.input_size();

        let preprocess_start = Instant::now();
        let input = preprocess_frame(&frame, width, height);
        let preprocess_secs = preprocess_start.elapsed().as_secs_f64();

        let inference_start = Instant::now();
        let detections = self.model.detect(&input);
        let inference_secs = inference_start.elapsed().as_secs_f64();

        self.stats.frames_processed += 1;
        self.stats.total_preprocess_time += preprocess_secs;
        self.stats.total_inference_time += inference_secs;
        Ok(detections)
    }

    /// Runs exactly `n` steps, handing each output and the current stats to
    /// `callback`. Stops early on the first capture error.
    ///
    /// # Errors
    /// Propagates the first [`CaptureError`] returned by `step`.
    pub fn run_n<F>(&mut self, n: usize, mut callback: F) -> Result<(), CaptureError>
    where
        F: FnMut(M::Output, &PipelineStats),
    {
        for _ in 0..n {
            let detections = self.step()?;
            callback(detections, &self.stats);
        }
        Ok(())
    }

    /// Closes the capture backend. Stats are preserved until the next `start`.
    pub fn stop(&mut self) {
        self.backend.close();
    }

    /// Read-only view of the accumulated statistics.
    pub fn stats(&self) -> &PipelineStats {
        &self.stats
    }

    /// Mutable access to the wrapped model (e.g. to tweak thresholds).
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// Shared access to the wrapped model.
    pub fn model(&self) -> &M {
        &self.model
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::camera::file::FileBackend;

    /// Stub model that counts `detect` calls and returns the running count.
    struct DummyDetector {
        call_count: usize,
    }

    impl DummyDetector {
        fn new() -> Self {
            Self { call_count: 0 }
        }
    }

    impl DetectionModel for DummyDetector {
        type Output = usize;

        fn detect(&mut self, input: &Variable) -> usize {
            // Preprocessing must hand us a batch of one 3-channel image.
            assert_eq!(input.shape()[0], 1);
            assert_eq!(input.shape()[1], 3);
            self.call_count += 1;
            self.call_count
        }

        fn input_size(&self) -> (u32, u32) {
            (32, 32)
        }
    }

    #[test]
    fn test_pipeline_basic() {
        let source = FileBackend::synthetic(64, 48, 128);
        let detector = DummyDetector::new();
        let mut pipeline = InferencePipeline::new(detector, source);

        pipeline.start(&CaptureConfig::default()).unwrap();
        let first = pipeline.step().unwrap();

        assert_eq!(first, 1);
        assert_eq!(pipeline.stats().frames_processed, 1);
        assert!(pipeline.stats().avg_preprocess_ms() > 0.0);
        pipeline.stop();
    }

    #[test]
    fn test_pipeline_run_n() {
        let source = FileBackend::synthetic_sequence(32, 32, 10);
        let detector = DummyDetector::new();
        let mut pipeline = InferencePipeline::new(detector, source);

        pipeline.start(&CaptureConfig::default()).unwrap();
        let mut collected = Vec::new();
        pipeline
            .run_n(5, |output, _stats| collected.push(output))
            .unwrap();

        // DummyDetector returns its 1-based call count, so outputs are 1..=5.
        assert_eq!(collected.len(), 5);
        assert_eq!(collected, vec![1, 2, 3, 4, 5]);
        assert_eq!(pipeline.stats().frames_processed, 5);
        pipeline.stop();
    }

    #[test]
    fn test_pipeline_fps_tracking() {
        let source = FileBackend::synthetic_sequence(16, 16, 10);
        let detector = DummyDetector::new();
        let mut pipeline = InferencePipeline::new(detector, source);

        pipeline.start(&CaptureConfig::default()).unwrap();
        (0..10).for_each(|_| {
            pipeline.step().unwrap();
        });

        assert_eq!(pipeline.stats().frames_processed, 10);
        assert!(pipeline.stats().fps() > 0.0);
        assert!(pipeline.stats().elapsed_secs() > 0.0);
        pipeline.stop();
    }

    #[test]
    fn test_pipeline_stats_fresh() {
        // A brand-new tracker reports zero everywhere rather than NaN/inf.
        let fresh = PipelineStats::new();
        assert_eq!(fresh.frames_processed, 0);
        assert_eq!(fresh.fps(), 0.0);
        assert_eq!(fresh.avg_preprocess_ms(), 0.0);
        assert_eq!(fresh.avg_inference_ms(), 0.0);
    }
}