use crate::errors::InferenceError;
use crate::memory::MemoryClass;
use crate::message::payload::Payload;
use crate::prelude::Batch;
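
/// Capabilities advertised by a [`ComputeBackend`] before any model is loaded.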
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendCapabilities {
device_streams: bool,
max_batch: Option<usize>,
dtype_mask: u64,
}

impl BackendCapabilities {
    #[inline]
    pub fn new(device_streams: bool, max_batch: Option<usize>, dtype_mask: u64) -> Self {
        Self {
            device_streams,
            max_batch,
            dtype_mask,
        }
    }

    /// Whether the backend exposes device streams.
    #[inline]
    pub fn device_streams(&self) -> bool {
        self.device_streams
    }

    /// Largest batch size the backend accepts, if it imposes a limit.
    #[inline]
    pub fn max_batch(&self) -> Option<usize> {
        self.max_batch
    }

    /// Bitmask of the dtypes the backend supports.
    #[inline]
    pub fn dtype_mask(&self) -> u64 {
        self.dtype_mask
    }
}
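
/// Placement and size hints a model reports about its payloads.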
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelMetadata {
preferred_input: MemoryClass,
preferred_output: MemoryClass,
max_input_bytes: Option<usize>,
max_output_bytes: Option<usize>,
}

impl ModelMetadata {
    #[inline]
    pub fn new(
        preferred_input: MemoryClass,
        preferred_output: MemoryClass,
        max_input_bytes: Option<usize>,
        max_output_bytes: Option<usize>,
    ) -> Self {
        Self {
            preferred_input,
            preferred_output,
            max_input_bytes,
            max_output_bytes,
        }
    }

    /// Memory class the model prefers for its input payloads.
    #[inline]
    pub fn preferred_input(&self) -> MemoryClass {
        self.preferred_input
    }

    /// Memory class the model prefers for its output payloads.
    #[inline]
    pub fn preferred_output(&self) -> MemoryClass {
        self.preferred_output
    }

    /// Upper bound on input payload size in bytes, if any.
    #[inline]
    pub fn max_input_bytes(&self) -> Option<usize> {
        self.max_input_bytes
    }

    /// Upper bound on output payload size in bytes, if any.
    #[inline]
    pub fn max_output_bytes(&self) -> Option<usize> {
        self.max_output_bytes
    }
}
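
/// A loaded model that can run inference one input/output payload pair at a
/// time, and optionally in batches.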
pub trait ComputeModel<InP: Payload, OutP: Payload> {
    /// Initializes the model before any inference is run.
    fn init(&mut self) -> Result<(), InferenceError>;

    /// Runs inference on a single input payload, writing the result into `out`.
    fn infer_one(&mut self, inp: &InP, out: &mut OutP) -> Result<(), InferenceError>;
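
    /// Runs inference over a batch, pairing each input with the output slot at
    /// the same index.
    ///
    /// The default implementation calls `infer_one` per pair; because it zips
    /// the two sequences, extra items on the longer side are silently skipped.
    /// Backends with a native batched path can override it.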
#[inline]
fn infer_batch(
&mut self,
inps: Batch<'_, InP>,
outs: &mut [OutP],
) -> Result<(), InferenceError> {
for (m, o) in inps.messages().iter().zip(outs.iter_mut()) {
self.infer_one(m.payload(), o)?;
}
Ok(())
}

    /// Drains any pending or in-flight work.
    fn drain(&mut self) -> Result<(), InferenceError>;

    /// Resets the model's internal state.
    fn reset(&mut self) -> Result<(), InferenceError>;

    /// Placement and size hints for this model's payloads.
    fn metadata(&self) -> ModelMetadata;
}
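
/// A compute backend: reports its capabilities and loads models from
/// backend-specific descriptors.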
pub trait ComputeBackend<InP: Payload, OutP: Payload> {
    /// The model type this backend loads.
    type Model: ComputeModel<InP, OutP>;

    /// Error produced when loading a model fails.
    type Error;

    /// Backend-specific description of the model to load; the lifetime allows
    /// descriptors that borrow their data.
    type ModelDescriptor<'desc>
    where
        Self: 'desc;

    /// Capabilities this backend advertises.
    fn capabilities(&self) -> BackendCapabilities;

    /// Loads a model described by `desc`.
    fn load_model<'desc>(
        &self,
        desc: Self::ModelDescriptor<'desc>,
    ) -> Result<Self::Model, Self::Error>;
}
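
/// A model blob held in memory, with an optional human-readable label.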
#[cfg(feature = "std")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ModelArtifact {
bytes: std::sync::Arc<Vec<u8>>,
label: Option<String>,
}

#[cfg(feature = "std")]
impl ModelArtifact {
    #[inline]
    pub fn new(bytes: std::sync::Arc<Vec<u8>>, label: Option<String>) -> Self {
        Self { bytes, label }
    }

    /// Wraps an in-memory byte buffer as an unlabeled artifact.
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self {
            bytes: std::sync::Arc::new(bytes),
            label: None,
        }
    }

    /// Reads an artifact's bytes from a file on disk.
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
        let bytes = std::fs::read(path)?;
        Ok(Self::from_bytes(bytes))
    }

    /// Shared handle to the raw model bytes; cloning the `Arc` only bumps a
    /// reference count.
    #[inline]
    pub fn bytes(&self) -> std::sync::Arc<Vec<u8>> {
        self.bytes.clone()
    }

    /// Optional human-readable label.
    #[inline]
    pub fn label(&self) -> Option<&str> {
        self.label.as_deref()
    }
}
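
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    // A minimal sketch exercising the plain-data types above; the dtype-mask
    // bits and artifact bytes are illustrative values, not anything this API
    // prescribes.
    #[test]
    fn capabilities_round_trip() {
        let caps = BackendCapabilities::new(true, Some(32), 0b0101);
        assert!(caps.device_streams());
        assert_eq!(caps.max_batch(), Some(32));
        assert_eq!(caps.dtype_mask(), 0b0101);
    }

    #[test]
    fn artifact_from_bytes_is_unlabeled() {
        let artifact = ModelArtifact::from_bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(artifact.bytes().as_slice(), &[0xDE, 0xAD, 0xBE, 0xEF][..]);
        assert!(artifact.label().is_none());
    }
}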