use crate::compute::{BackendCapabilities, ComputeBackend, ComputeModel, ModelMetadata};
use crate::edge::Edge;
use crate::errors::{InferenceError, NodeError};
use crate::memory::PlacementAcceptance;
use crate::message::{payload::Payload, Message};
use crate::node::{Node, NodeCapabilities, NodeKind, ProcessResult, StepContext, StepResult};
use crate::policy::NodePolicy;
use crate::prelude::{MemoryManager, PlatformClock, Telemetry};
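/// Maps a backend [`InferenceError`] into a node-level [`NodeError`],
/// preserving the backend's error code.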
#[inline]
fn map_inference_err(e: InferenceError) -> NodeError {
    NodeError::execution_failed().with_code(*e.code())
}
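/// A graph node that wraps a [`ComputeBackend`] model and runs inference on
/// each incoming message, writing results into a reusable scratch buffer.
///
/// `MAX_BATCH` is the compile-time upper bound on batch size; the effective
/// batch size is further clamped by the node policy and the backend's
/// advertised capabilities.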
pub struct InferenceModel<B, InP, OutP, const MAX_BATCH: usize>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload,
    OutP: Payload + Default + Copy,
{
    #[allow(dead_code)]
    backend: B,
    model: B::Model,
    backend_caps: BackendCapabilities,
    model_meta: ModelMetadata,
    node_caps: NodeCapabilities,
    node_policy: NodePolicy,
    input_acceptance: [PlacementAcceptance; 1],
    output_acceptance: [PlacementAcceptance; 1],
    scratch_out: OutP,
    _pd: core::marker::PhantomData<InP>,
}
impl<B, InP, OutP, const MAX_BATCH: usize> InferenceModel<B, InP, OutP, MAX_BATCH>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload,
    OutP: Payload + Default + Copy,
{
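    /// Loads the model described by `desc` on `backend` and assembles the node.
    ///
    /// # Example
    ///
    /// A minimal construction sketch. `MyBackend`, `In`, `Out`, `descriptor`,
    /// and the `Default` impls used below are illustrative assumptions, not
    /// part of this crate's guaranteed API:
    ///
    /// ```ignore
    /// let node: InferenceModel<MyBackend, In, Out, 8> = InferenceModel::new(
    ///     MyBackend::new(),                 // hypothetical backend
    ///     descriptor,                       // a MyBackend::ModelDescriptor
    ///     NodePolicy::default(),            // assumes a Default impl
    ///     NodeCapabilities::default(),      // assumes a Default impl
    ///     [PlacementAcceptance::default()], // assumes a Default impl
    ///     [PlacementAcceptance::default()],
    /// )?;
    /// ```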
    pub fn new<'desc>(
        backend: B,
        desc: B::ModelDescriptor<'desc>,
        node_policy: NodePolicy,
        node_caps: NodeCapabilities,
        input_acceptance: [PlacementAcceptance; 1],
        output_acceptance: [PlacementAcceptance; 1],
    ) -> Result<Self, B::Error> {
        let backend_caps = backend.capabilities();
        let model = backend.load_model(desc)?;
        let model_meta = model.metadata();
        Ok(Self {
            backend,
            model,
            backend_caps,
            model_meta,
            node_caps,
            node_policy,
            input_acceptance,
            output_acceptance,
            scratch_out: OutP::default(),
            _pd: core::marker::PhantomData,
        })
    }
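    /// Capabilities reported by the backend when this node was constructed.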
    #[inline]
    pub fn backend_capabilities(&self) -> BackendCapabilities {
        self.backend_caps
    }
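    /// Metadata reported by the loaded model when this node was constructed.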
    #[inline]
    pub fn model_metadata(&self) -> ModelMetadata {
        self.model_meta
    }
}
impl<B, InP, OutP, const MAX_BATCH: usize> Node<1, 1, InP, OutP>
    for InferenceModel<B, InP, OutP, MAX_BATCH>
where
    B: ComputeBackend<InP, OutP>,
    InP: Payload + Default + Copy,
    OutP: Payload + Default + Copy,
{
    #[inline]
    fn describe_capabilities(&self) -> NodeCapabilities {
        self.node_caps
    }
    #[inline]
    fn input_acceptance(&self) -> [PlacementAcceptance; 1] {
        self.input_acceptance
    }
    #[inline]
    fn output_acceptance(&self) -> [PlacementAcceptance; 1] {
        self.output_acceptance
    }
    #[inline]
    fn policy(&self) -> NodePolicy {
        self.node_policy
    }
    #[cfg(any(test, feature = "bench"))]
    fn set_policy(&mut self, policy: NodePolicy) {
        self.node_policy = policy;
    }
    #[inline]
    fn node_kind(&self) -> NodeKind {
        NodeKind::Model
    }
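    /// Nothing to set up at this stage; backend model initialization is
    /// deferred to `start`.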
    #[inline]
    fn initialize<C, T>(&mut self, _clock: &C, _telemetry: &mut T) -> Result<(), NodeError>
    where
        T: Telemetry,
    {
        Ok(())
    }
    /// Initializes the backend model, mapping any backend failure into a
    /// [`NodeError`].
    #[inline]
    fn start<C, T>(&mut self, _clock: &C, _telemetry: &mut T) -> Result<(), NodeError>
    where
        T: Telemetry,
    {
        self.model.init().map_err(map_inference_err)
    }
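    /// Runs single-sample inference into the scratch buffer and emits an
    /// output message that reuses the input message's header.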
    #[inline]
    fn process_message<C>(
        &mut self,
        msg: &Message<InP>,
        _sys_clock: &C,
    ) -> Result<ProcessResult<OutP>, NodeError>
    where
        C: PlatformClock + Sized,
    {
        let inp: &InP = msg.payload();
        // Run one inference into the reusable scratch buffer.
        self.model
            .infer_one(inp, &mut self.scratch_out)
            .map_err(map_inference_err)?;
        let hdr = *msg.header();
        // `take` moves the result out and resets the scratch buffer to
        // `OutP::default()` for the next call.
        let out_msg = Message::new(hdr, core::mem::take(&mut self.scratch_out));
        Ok(ProcessResult::Output(out_msg))
    }
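    /// Pops a single message from input port 0 and runs it through
    /// `process_message`.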
    #[inline]
    fn step<'g, 't, 'c, InQ, OutQ, InM, OutM, C, Tel>(
        &mut self,
        ctx: &mut StepContext<'g, 't, 'c, 1, 1, InP, OutP, InQ, OutQ, InM, OutM, C, Tel>,
    ) -> Result<StepResult, NodeError>
    where
        InQ: Edge,
        OutQ: Edge,
        InM: MemoryManager<InP>,
        OutM: MemoryManager<OutP>,
        C: PlatformClock + Sized,
        Tel: Telemetry + Sized,
    {
        // Copy the clock handle out first: the closure must not borrow `ctx`
        // while `pop_and_process` holds it mutably.
        let clock = ctx.clock;
        ctx.pop_and_process(0, |msg| self.process_message(msg, clock))
    }
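    /// Batched variant of `step`: pops up to `nmax` messages at once, where
    /// `nmax` is the tightest of the policy's fixed batch size, the backend's
    /// maximum batch, and the compile-time `MAX_BATCH` bound. Falls back to
    /// `step` when the effective batch size is 1.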
    #[inline]
    fn step_batch<'g, 't, 'c, InQ, OutQ, InM, OutM, C, Tel>(
        &mut self,
        ctx: &mut StepContext<'g, 't, 'c, 1, 1, InP, OutP, InQ, OutQ, InM, OutM, C, Tel>,
    ) -> Result<StepResult, NodeError>
    where
        InQ: Edge,
        OutQ: Edge,
        InM: MemoryManager<InP>,
        OutM: MemoryManager<OutP>,
        C: PlatformClock + Sized,
        Tel: Telemetry + Sized,
    {
        // Clamp the policy's requested batch size to the backend's advertised
        // maximum and the compile-time MAX_BATCH bound.
        let want = self.node_policy.batching().fixed_n().unwrap_or(1);
        let backend_cap = self.backend_caps.max_batch().unwrap_or(usize::MAX);
        let nmax = core::cmp::min(core::cmp::min(want, backend_cap), MAX_BATCH);
        if nmax <= 1 {
            return self.step(ctx);
        }
        // Copy the policy and clock out of `self`/`ctx` so the closure can
        // borrow `self` mutably without a conflicting borrow.
        let node_policy = self.node_policy;
        let clock = ctx.clock;
        ctx.pop_batch_and_process(0, nmax, &node_policy, |msg| {
            self.process_message(msg, clock)
        })
    }
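    /// On a watchdog timeout, yields until `now + backoff` when the policy
    /// configures a watchdog backoff, and yields without delay otherwise.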
    #[inline]
    fn on_watchdog_timeout<C, Tel>(
        &mut self,
        clock: &C,
        _telemetry: &mut Tel,
    ) -> Result<StepResult, NodeError>
    where
        C: PlatformClock + Sized,
        Tel: Telemetry,
    {
        if let Some(backoff) = self.node_policy.budget().watchdog_ticks() {
            let until = clock.now_ticks().saturating_add(*backoff);
            Ok(StepResult::YieldUntil(until))
        } else {
            Ok(StepResult::YieldUntil(clock.now_ticks()))
        }
    }
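    /// Drains any in-flight work, then resets the model so the node can be
    /// restarted cleanly.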
    #[inline]
    fn stop<C, Tel>(&mut self, _clock: &C, _telemetry: &mut Tel) -> Result<(), NodeError>
    where
        Tel: Telemetry,
    {
        self.model.drain().map_err(map_inference_err)?;
        self.model.reset().map_err(map_inference_err)
    }
}