use crate::completion::{CompletionHandle, ContractResponse};
use crate::runtime::RuntimeResourceRef;
/// A model whose operations are all driven through the runtime's completion contract.
///
/// Every operation follows the same shape:
/// - it borrows the runtime resources mutably through `ctx`;
/// - it receives a [`CompletionHandle`] typed with the operation's result and
///   `Self::Error`;
/// - it returns a [`ContractResponse`] with those same result and error types.
///
/// NOTE(review): the split between the returned `ContractResponse` and the
/// `completion` handle presumably means "answer now vs. answer later via the handle".
/// Confirm against `crate::completion` before relying on either path.
///
/// Implementors must be `Send + Sync`.
pub trait Model: Send + Sync {
    /// Tensor/storage type that the model consumes and produces.
    ///
    /// The type may be unsized (`?Sized`). That is why tensors are always taken by
    /// reference and returned as `Box<Self::Tensor>` rather than by value.
    type Tensor: ?Sized + bb_ir::types::Storage;

    /// Error type reported by every operation of this trait.
    ///
    /// `std::error::Error` already has `Debug + Display` as supertraits, so no separate
    /// `Display` bound is required; `Self::Error: Display` still holds for all users.
    /// `Send + Sync + 'static` make the error usable across threads (e.g. inside a
    /// `Box<dyn std::error::Error + Send + Sync>`).
    type Error: std::error::Error + Send + Sync + 'static;

    /// Forward operation on `input`.
    ///
    /// The result is an owned, boxed `Self::Tensor`. It is presumably the model output;
    /// that is inferred from the method name, so confirm it.
    fn forward(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        input: &Self::Tensor,
        completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
    ) -> ContractResponse<Box<Self::Tensor>, Self::Error>;

    /// Loads `params` into the model. The operation yields `()` on success.
    ///
    /// NOTE(review): whether this overwrites or merges existing parameters is not
    /// visible here. Presumably `params` has the same layout that [`Model::params`]
    /// returns — confirm.
    fn load_parameters(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        params: &Self::Tensor,
        completion: CompletionHandle<(), Self::Error>,
    ) -> ContractResponse<(), Self::Error>;

    /// Backward operation driven by `grad`. The operation yields `()` on success.
    ///
    /// Nothing is returned, so any result is presumably kept in the model's own
    /// state (e.g. accumulated gradients). Confirm with implementors.
    fn backward(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        grad: &Self::Tensor,
        completion: CompletionHandle<(), Self::Error>,
    ) -> ContractResponse<(), Self::Error>;

    /// Applies `delta` to the model. The operation yields `()` on success.
    ///
    /// NOTE(review): presumably `delta` matches the layout returned by
    /// [`Model::params`] — confirm.
    fn apply_delta(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        delta: &Self::Tensor,
        completion: CompletionHandle<(), Self::Error>,
    ) -> ContractResponse<(), Self::Error>;

    /// Computes a scalar `f32` loss from `input` and `target`.
    ///
    /// This takes `&mut self`, so implementations may update model state while computing.
    fn compute_loss(
        &mut self,
        ctx: &mut RuntimeResourceRef<'_>,
        input: &Self::Tensor,
        target: &Self::Tensor,
        completion: CompletionHandle<f32, Self::Error>,
    ) -> ContractResponse<f32, Self::Error>;

    /// Returns the model's parameters as an owned, boxed `Self::Tensor`.
    ///
    /// This takes `&self`, so implementations cannot mutate the model through `self`
    /// (barring interior mutability). The runtime `ctx` is still borrowed mutably.
    fn params(
        &self,
        ctx: &mut RuntimeResourceRef<'_>,
        completion: CompletionHandle<Box<Self::Tensor>, Self::Error>,
    ) -> ContractResponse<Box<Self::Tensor>, Self::Error>;
}