use crate::error::MlResult;
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct PipelineInfo {
pub id: &'static str,
pub name: &'static str,
pub task: PipelineTask,
pub input_size: Option<(u32, u32)>,
}
/// The category of work a pipeline performs.
///
/// With the `serde` feature enabled, variants serialize as kebab-case strings
/// (for example `SceneClassification` becomes `"scene-classification"`).
///
/// Derives `Hash` in addition to `Eq` so tasks can be used as keys in
/// `HashMap`/`HashSet`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "kebab-case"))]
pub enum PipelineTask {
    /// Scene classification.
    SceneClassification,
    /// Shot-boundary detection.
    ShotBoundary,
    /// Detection.
    Detection,
    /// Segmentation.
    Segmentation,
    /// Aesthetic scoring.
    AestheticScoring,
    /// Face embedding.
    FaceEmbedding,
    /// A task not described by any of the named variants.
    Custom,
}
/// A pipeline whose input and output types are fixed at compile time.
///
/// The trait is object-safe: implementors can be boxed as
/// `Box<dyn TypedPipeline<Input = I, Output = O>>`.
pub trait TypedPipeline {
    /// The value [`run`](Self::run) consumes.
    type Input;

    /// The value [`run`](Self::run) yields on success.
    type Output;

    /// Describes this pipeline (identifier, name, task, optional input size).
    fn info(&self) -> PipelineInfo;

    /// Runs the pipeline, taking `input` by value.
    ///
    /// # Errors
    ///
    /// Returns the `Err` variant of `MlResult` when the pipeline fails.
    fn run(&self, input: Self::Input) -> MlResult<Self::Output>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal pipeline that doubles its integer input.
    struct DoublePipeline;

    impl TypedPipeline for DoublePipeline {
        type Input = i32;
        type Output = i32;

        fn run(&self, input: Self::Input) -> MlResult<Self::Output> {
            Ok(input * 2)
        }

        fn info(&self) -> PipelineInfo {
            PipelineInfo {
                id: "test/double",
                name: "Double",
                task: PipelineTask::Custom,
                input_size: None,
            }
        }
    }

    /// Invokes a pipeline through a generic bound to exercise static dispatch.
    fn run_generic<P: TypedPipeline>(pipeline: &P, input: P::Input) -> MlResult<P::Output> {
        pipeline.run(input)
    }

    #[test]
    fn trait_object_works() {
        let p: Box<dyn TypedPipeline<Input = i32, Output = i32>> = Box::new(DoublePipeline);
        let info = p.info();
        // Check every metadata field, not just the id.
        assert_eq!(info.id, "test/double");
        assert_eq!(info.name, "Double");
        assert_eq!(info.task, PipelineTask::Custom);
        assert_eq!(info.input_size, None);
        assert_eq!(p.run(21).expect("ok"), 42);
    }

    #[test]
    fn static_dispatch_works() {
        assert_eq!(run_generic(&DoublePipeline, 21).expect("ok"), 42);
        assert_eq!(run_generic(&DoublePipeline, 0).expect("ok"), 0);
        assert_eq!(run_generic(&DoublePipeline, -4).expect("ok"), -8);
    }
}