Skip to main content

Model

Trait Model 

Source
/// User-facing Contract trait for an ML model.
///
/// Every operation is expressed in terms of one associated tensor storage type
/// ([`Model::Tensor`]) and one library-defined error type ([`Model::Error`]).
///
/// Each required method receives `ctx`, the per-dispatch runtime surface, and a
/// `completion` handle. It returns a `ContractResponse` whose result and error
/// types match that handle's.
///
/// NOTE(review): how returning a `ContractResponse` relates to signalling via
/// `completion` (e.g. immediate vs. deferred completion) is not visible here —
/// TODO confirm against the `ContractResponse` / `CompletionHandle` docs.
pub trait Model: Send + Sync {
    /// Tensor storage type.
    ///
    /// One associated type covers input, output, params, grad and delta.
    /// Implement as `[f32]` for flat `f32` tensors. Because the type may be
    /// unsized (`?Sized`), owned tensors are returned as `Box<Self::Tensor>`.
    type Tensor: ?Sized + Storage;
    /// Library-maker-defined error type.
    ///
    /// Used as the error type of every `CompletionHandle` and
    /// `ContractResponse` in this trait.
    type Error: Error + Display + Send + Sync + 'static;

    // Required methods
    /// Forward pass: `input` → output.
    ///
    /// `ctx` is the per-dispatch runtime surface; impls reach their declared
    /// `#[depends(...)]` siblings through `RuntimeResourceRef::dependency`.
    /// The output tensor is owned (`Box<Self::Tensor>`).
    fn forward(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        input: &Self::Tensor,
        completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
    ) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
    /// Load parameters wholesale from `params`.
    ///
    /// Produces no value on success (`()`).
    fn load_parameters(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        params: &Self::Tensor,
        completion: CompletionHandle<(), Self::Error>,
    ) -> ContractResponse<(), Self::Error>;
    /// Backward pass: accumulate gradients given the upstream gradient `grad`.
    ///
    /// Produces no value on success (`()`).
    fn backward(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        grad: &Self::Tensor,
        completion: CompletionHandle<(), Self::Error>,
    ) -> ContractResponse<(), Self::Error>;
    /// Apply a parameter delta in-place.
    ///
    /// Produces no value on success (`()`).
    fn apply_delta(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        delta: &Self::Tensor,
        completion: CompletionHandle<(), Self::Error>,
    ) -> ContractResponse<(), Self::Error>;
    /// Compute loss: (`input`, `target`) → scalar score.
    ///
    /// Returns `f32` regardless of the tensor element type — loss is always a
    /// framework-fixed scalar.
    fn compute_loss(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        input: &Self::Tensor,
        target: &Self::Tensor,
        completion: CompletionHandle<f32, Self::Error>,
    ) -> ContractResponse<f32, Self::Error>;
    /// Snapshot the current parameter tensor.
    ///
    /// The snapshot is owned (`Box<Self::Tensor>`) because async serialization
    /// needs owned values. This is the only required method taking `&self`
    /// rather than `&mut self`.
    fn params(
        &self,
        ctx: &mut RuntimeResourceRef<'_>,
        completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
    ) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
}
Expand description

User-facing Contract trait for an ML model.

Required Associated Types

Source

type Tensor: ?Sized + Storage

Tensor storage type. A single associated type covers inputs, outputs, parameters, gradients, and deltas. Use `[f32]` for flat `f32` tensors.

Source

type Error: Error + Display + Send + Sync + 'static

Error type defined by the library author implementing this trait.

Required Methods

Source

fn forward( &mut self, ctx: &mut RuntimeResourceRef<'_>, input: &Self::Tensor, completion: CompletionHandle<Box<Self::Tensor>, Self::Error>, ) -> ContractResponse<Box<Self::Tensor>, Self::Error>

Forward pass: `input` → output. `ctx` is the per-dispatch runtime surface; implementations reach their declared `#[depends(...)]` siblings through `RuntimeResourceRef::dependency`.

Source

fn load_parameters( &mut self, ctx: &mut RuntimeResourceRef<'_>, params: &Self::Tensor, completion: CompletionHandle<(), Self::Error>, ) -> ContractResponse<(), Self::Error>

Load the entire parameter set from `params` in a single call.

Source

fn backward( &mut self, ctx: &mut RuntimeResourceRef<'_>, grad: &Self::Tensor, completion: CompletionHandle<(), Self::Error>, ) -> ContractResponse<(), Self::Error>

Backward pass: accumulate gradients given the upstream gradient `grad`.

Source

fn apply_delta( &mut self, ctx: &mut RuntimeResourceRef<'_>, delta: &Self::Tensor, completion: CompletionHandle<(), Self::Error>, ) -> ContractResponse<(), Self::Error>

Apply a parameter delta in-place.

Source

fn compute_loss( &mut self, ctx: &mut RuntimeResourceRef<'_>, input: &Self::Tensor, target: &Self::Tensor, completion: CompletionHandle<f32, Self::Error>, ) -> ContractResponse<f32, Self::Error>

Compute the loss: (`input`, `target`) → scalar score. Returns `f32` regardless of the tensor element type, because loss is always a framework-fixed scalar.

Source

fn params( &self, ctx: &mut RuntimeResourceRef<'_>, completion: CompletionHandle<Box<Self::Tensor>, Self::Error>, ) -> ContractResponse<Box<Self::Tensor>, Self::Error>

Snapshot the current parameter tensor (owned — async serialization needs owned values).

Dyn Compatibility

This trait is dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors§