pub trait Model: Send + Sync {
type Tensor: ?Sized + Storage;
type Error: Error + Display + Send + Sync + 'static;
// Required methods
fn forward(
&mut self,
ctx: &mut RuntimeResourceRef<'_>,
input: &Self::Tensor,
completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
fn load_parameters(
&mut self,
ctx: &mut RuntimeResourceRef<'_>,
params: &Self::Tensor,
completion: CompletionHandle<(), Self::Error>,
) -> ContractResponse<(), Self::Error>;
fn backward(
&mut self,
ctx: &mut RuntimeResourceRef<'_>,
grad: &Self::Tensor,
completion: CompletionHandle<(), Self::Error>,
) -> ContractResponse<(), Self::Error>;
fn apply_delta(
&mut self,
ctx: &mut RuntimeResourceRef<'_>,
delta: &Self::Tensor,
completion: CompletionHandle<(), Self::Error>,
) -> ContractResponse<(), Self::Error>;
fn compute_loss(
&mut self,
ctx: &mut RuntimeResourceRef<'_>,
input: &Self::Tensor,
target: &Self::Tensor,
completion: CompletionHandle<f32, Self::Error>,
) -> ContractResponse<f32, Self::Error>;
fn params(
&self,
ctx: &mut RuntimeResourceRef<'_>,
completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
}Expand description
User-facing Contract trait for an ML model.
Required Associated Types
type Tensor: ?Sized + Storage
type Error: Error + Display + Send + Sync + 'static
Required Methods
fn forward(
    &mut self,
    ctx: &mut RuntimeResourceRef<'_>,
    input: &Self::Tensor,
    completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
) -> ContractResponse<Box<Self::Tensor>, Self::Error>
Forward pass: input → output. ctx is the per-dispatch
runtime surface; impls reach their declared #[depends(...)]
siblings through RuntimeResourceRef::dependency.
fn load_parameters(
    &mut self,
    ctx: &mut RuntimeResourceRef<'_>,
    params: &Self::Tensor,
    completion: CompletionHandle<(), Self::Error>,
) -> ContractResponse<(), Self::Error>
Load parameters wholesale.
fn backward(
    &mut self,
    ctx: &mut RuntimeResourceRef<'_>,
    grad: &Self::Tensor,
    completion: CompletionHandle<(), Self::Error>,
) -> ContractResponse<(), Self::Error>
Backward pass: accumulate gradients given the upstream gradient.
fn apply_delta(
    &mut self,
    ctx: &mut RuntimeResourceRef<'_>,
    delta: &Self::Tensor,
    completion: CompletionHandle<(), Self::Error>,
) -> ContractResponse<(), Self::Error>
Apply a parameter delta in-place.
fn compute_loss(
    &mut self,
    ctx: &mut RuntimeResourceRef<'_>,
    input: &Self::Tensor,
    target: &Self::Tensor,
    completion: CompletionHandle<f32, Self::Error>,
) -> ContractResponse<f32, Self::Error>
Compute loss: (input, target) → scalar score. Returns f32
regardless of the tensor element type — loss is always a
framework-fixed scalar.
fn params(
    &self,
    ctx: &mut RuntimeResourceRef<'_>,
    completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
) -> ContractResponse<Box<Self::Tensor>, Self::Error>
Snapshot the current parameter tensor (owned — async serialization needs owned values).
Dyn Compatibility
This trait is dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".