use crate::error::Result;
use crate::model::config::ModelConfig;
use crate::nn::VarBuilder;
use crate::ops::traits::architecture::moe::MoEOps;
use crate::ops::traits::position::alibi::AlibiOps;
use crate::ops::traits::{FlashAttentionOps, KvCacheOps, PagedAttentionOps, RoPEOps};
use crate::quant::traits::{DequantOps, QuantMatmulOps};
use numr::autograd::Var;
use numr::ops::{
ActivationOps, BinaryOps, CompareOps, ConditionalOps, IndexingOps, NormalizationOps, ReduceOps,
ScalarOps, ShapeOps, TensorOps, UnaryOps,
};
use numr::runtime::{Runtime, RuntimeClient};
/// Umbrella trait collecting every operation trait a client has to provide
/// in order to be used as a model client.
///
/// The trait declares no items of its own; it only names the combination of
/// supertraits below so that generic code can write a single
/// `C: ModelClient<R>` bound instead of repeating the whole list.
pub trait ModelClient<R: Runtime>:
    RuntimeClient<R>
    // Operation traits imported from `numr::ops`.
    + TensorOps<R>
    + ScalarOps<R>
    + BinaryOps<R>
    + UnaryOps<R>
    + ReduceOps<R>
    + CompareOps<R>
    + ConditionalOps<R>
    + IndexingOps<R>
    + ShapeOps<R>
    + ActivationOps<R>
    + NormalizationOps<R>
    // Operation traits declared in this crate (`crate::ops`, `crate::quant`).
    + RoPEOps<R>
    + AlibiOps<R>
    + FlashAttentionOps<R>
    + PagedAttentionOps<R>
    + KvCacheOps<R>
    + QuantMatmulOps<R>
    + MoEOps<R>
{
}
/// Blanket implementation: any type `C` that satisfies every bound listed
/// here is a [`ModelClient`] for the runtime `R`, with no manual `impl`
/// required.
///
/// The bound list must stay in sync with the supertraits of [`ModelClient`].
impl<R: Runtime, C> ModelClient<R> for C where
    C: RuntimeClient<R>
        // Operation traits imported from `numr::ops`.
        + TensorOps<R>
        + ScalarOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + ReduceOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + IndexingOps<R>
        + ShapeOps<R>
        + ActivationOps<R>
        + NormalizationOps<R>
        // Operation traits declared in this crate (`crate::ops`, `crate::quant`).
        + RoPEOps<R>
        + AlibiOps<R>
        + FlashAttentionOps<R>
        + PagedAttentionOps<R>
        + KvCacheOps<R>
        + QuantMatmulOps<R>
        + MoEOps<R>
{
}
pub trait Model<R: Runtime>: Sized {
fn from_config(config: &ModelConfig, device: &R::Device) -> Result<Self>;
fn from_varbuilder(vb: &mut VarBuilder<R>, config: &ModelConfig) -> Result<Self>
where
R::Client: DequantOps<R> + numr::ops::TypeConversionOps<R>,
{
let _ = (vb, config);
Err(crate::error::Error::ModelError {
reason: "from_varbuilder not implemented for this model".into(),
})
}
fn forward<C>(&self, client: &C, input_ids: &Var<R>) -> Result<Var<R>>
where
C: ModelClient<R>,
R::Client: TensorOps<R>
+ ScalarOps<R>
+ ReduceOps<R>
+ IndexingOps<R>
+ ShapeOps<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ CompareOps<R>
+ ConditionalOps<R>;
fn config(&self) -> &ModelConfig;
}