use std::{any::Any, collections::HashMap};
#[cfg(not(target_arch = "wasm32"))]
use futures::future::BoxFuture;
#[cfg(target_arch = "wasm32")]
use futures::future::LocalBoxFuture;
use half::f16;
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::wasm_bindgen;
use super::loader::{Lora, Reader, PAD_MAT};
use crate::{
context::{Context, ContextBuilder},
impl_deserialize_seed,
num::Scalar,
tensor::{kind::ReadWrite, shape::Shape, TensorCpu, TensorError, TensorGpu, TensorGpuView},
};
/// Architecture version of a model.
///
/// Exported to JavaScript via `wasm_bindgen`. The `V6` and `V7` versions carry
/// extra version-specific parameters in [`ModelCustomInfo`].
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelVersion {
/// Version 4 architecture.
V4,
/// Version 5 architecture.
V5,
/// Version 6 architecture (see [`ModelCustomInfo::V6`]).
V6,
/// Version 7 architecture (see [`ModelCustomInfo::V7`]).
V7,
}
/// Hyper-parameters describing a loaded model.
///
/// These values drive buffer sizing (see [`ModelInfo::max_non_head_buffer_size`] and
/// [`ModelInfo::head_buffer_size`]) and context limit selection
/// ([`ContextAutoLimits::auto_limits`]).
#[wasm_bindgen]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ModelInfo {
/// Architecture version of the model.
pub version: ModelVersion,
/// Number of layers (inferred from the name; confirm against the loader).
pub num_layer: usize,
/// Embedding width; multiplied by `num_hidden` / padded vocab when sizing buffers.
pub num_emb: usize,
/// Hidden width; together with `num_emb` it determines the largest non-head buffer.
pub num_hidden: usize,
/// Vocabulary size before padding (see [`ModelInfo::num_vocab_padded`]).
pub num_vocab: usize,
/// Number of heads (inferred from the name; confirm against the loader).
pub num_head: usize,
/// Version-specific extra parameters; not exposed to JavaScript.
#[wasm_bindgen(skip)]
pub custom: ModelCustomInfo,
}
impl ModelInfo {
    /// Floor for `max_buffer_size`: 256 MiB.
    ///
    /// [`ContextAutoLimits::auto_limits`] never requests less than this, even for tiny models.
    pub const BUFFER_SIZE: usize = 256 * 1024 * 1024;

    /// Floor for `max_storage_buffer_binding_size`: 128 MiB.
    ///
    /// [`ContextAutoLimits::auto_limits`] never requests less than this, even for tiny models.
    pub const STORAGE_BUFFER_BINDING_SIZE: usize = 128 * 1024 * 1024;
}
// Generates seeded-deserialization support for `ModelInfo` via the crate's
// `impl_deserialize_seed!` macro (presumably a `serde::de::DeserializeSeed` impl;
// confirm against the macro definition).
impl_deserialize_seed!(ModelInfo);
#[wasm_bindgen]
impl ModelInfo {
    /// Size of a `num_emb × num_hidden` matrix of `f16` elements.
    ///
    /// This is presumably the largest weight buffer outside the output head.
    /// `ContextAutoLimits::auto_limits` uses it as a lower bound for context limits.
    pub fn max_non_head_buffer_size(&self) -> usize {
        let elements = self.num_emb * self.num_hidden;
        elements * f16::size()
    }

    /// Size of a `num_emb × num_vocab_padded()` matrix of `f16` elements (the head buffer).
    pub fn head_buffer_size(&self) -> usize {
        let elements = self.num_emb * self.num_vocab_padded();
        elements * f16::size()
    }

    /// `num_vocab` rounded up to the next multiple of the loader's matrix padding `PAD_MAT[1]`.
    pub fn num_vocab_padded(&self) -> usize {
        let pad = PAD_MAT[1];
        self.num_vocab.next_multiple_of(pad)
    }
}
/// Version-specific extra model parameters.
///
/// Defaults to `None`. Versions with additional parameters store them in the
/// matching variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelCustomInfo {
/// No version-specific parameters (the default).
#[default]
None,
/// Extra parameters for [`ModelVersion::V6`] models.
V6(super::v6::CustomInfo),
/// Extra parameters for [`ModelVersion::V7`] models.
V7(super::v7::CustomInfo),
}
/// Upcast to `&dyn Any`.
///
/// Presumably this lets callers downcast a type-erased state back to its concrete
/// type; confirm at the use sites.
pub trait AsAny {
/// Returns `self` as a `&dyn Any`.
fn as_any(&self) -> &dyn Any;
}
/// Model state storage, addressed per batch and per layer.
///
/// Implementations expose GPU views/tensors (`att`, `ffn`, `read`, `write`) and
/// CPU round-trips (`init`, `load`, `back`). Exact per-version layouts are defined
/// by the implementors, not here.
pub trait State {
/// Number of batches held by this state (inferred from the name; confirm).
fn num_batch(&self) -> usize;
/// Shape of the initial state; presumably the shape of the tensor returned by `init`.
fn init_shape(&self) -> Shape;
/// Returns an initial state as a CPU tensor.
fn init(&self) -> TensorCpu<f32>;
/// GPU view of the attention part of the state for `layer` (presumably; confirm).
fn att(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
/// GPU view of the FFN part of the state for `layer` (presumably; confirm).
fn ffn(&self, layer: usize) -> Result<TensorGpuView<'_, f32>, TensorError>;
/// Loads a CPU tensor into the state slot for `batch`.
fn load(&self, tensor: TensorCpu<f32>, batch: usize) -> Result<(), TensorError>;
/// Asynchronously reads the state of `batch` back into a CPU tensor.
///
/// On native targets the returned future is `Send`.
#[cfg(not(target_arch = "wasm32"))]
fn back(&self, batch: usize) -> BoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
/// Asynchronously reads the state of `batch` back into a CPU tensor.
///
/// On wasm32 the returned future is not required to be `Send`.
#[cfg(target_arch = "wasm32")]
fn back(&self, batch: usize) -> LocalBoxFuture<'_, Result<TensorCpu<f32>, TensorError>>;
/// Writes a GPU tensor into the state slot for `batch`.
fn write(&self, tensor: TensorGpu<f32, ReadWrite>, batch: usize) -> Result<(), TensorError>;
/// Returns the state of `batch` as a GPU tensor.
fn read(&self, batch: usize) -> Result<TensorGpu<f32, ReadWrite>, TensorError>;
/// Derives a CPU tensor for `layer` from a state previously read back (`backed`).
///
/// NOTE(review): the exact semantics are implementation-defined and not visible
/// here; presumably this produces an embedding of that layer's state. Confirm.
fn embed(&self, layer: usize, backed: TensorCpu<f32>) -> Result<TensorCpu<f32>, TensorError>;
}
/// A built model together with the means to create its state.
///
/// On native targets the returned objects must be `Send + Sync`; on wasm32 those
/// bounds are dropped.
pub trait Bundle {
/// Hyper-parameters of the model.
fn info(&self) -> ModelInfo;
/// Returns a state object for this model (native: `Send + Sync`).
#[cfg(not(target_arch = "wasm32"))]
fn state(&self) -> impl State + AsAny + Send + Sync + 'static;
/// Returns a state object for this model (wasm32: no thread-safety bounds).
#[cfg(target_arch = "wasm32")]
fn state(&self) -> impl State + AsAny + 'static;
/// Returns a serializable handle to the model (native: `Send + Sync`).
#[cfg(not(target_arch = "wasm32"))]
fn model(&self) -> impl Serialize + Send + Sync + 'static;
/// Returns a serializable handle to the model (wasm32: no thread-safety bounds).
#[cfg(target_arch = "wasm32")]
fn model(&self) -> impl Serialize + 'static;
}
/// Weight quantization scheme selector (exported to JavaScript).
///
/// Presumably applied per layer through [`ModelBuilder::quant`]; confirm in the
/// version-specific builders.
#[wasm_bindgen]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quant {
/// No quantization (the default).
#[default]
None,
/// 8-bit integer quantization.
Int8,
/// `NF4` 4-bit quantization.
NF4,
/// `SF4` 4-bit quantization.
SF4,
}
/// Builder collecting the inputs needed to construct a model from a [`Reader`].
pub struct ModelBuilder<R: Reader> {
/// GPU context the model will be built on.
pub context: Context,
/// Source of the model weights.
pub model: R,
/// Optional rescale setting; `ModelBuilder::rescale(0)` stores `usize::MAX`.
pub rescale: Option<usize>,
/// Optional `sep` setting; `ModelBuilder::sep(0)` stores `usize::MAX`.
pub sep: Option<usize>,
/// LoRA adapters to apply, in insertion order.
pub lora: Vec<Lora<R>>,
/// Quantization per key; keys are presumably layer indices (confirm in the builders).
pub quant: HashMap<usize, Quant>,
}
impl<R: Reader> ModelBuilder<R> {
    /// Starts a builder for `model` that shares a clone of `context`.
    ///
    /// The builder starts with no `rescale`/`sep` override, no LoRA adapters and an
    /// empty quantization map.
    pub fn new(context: &Context, model: R) -> Self {
        let context = context.clone();
        Self {
            context,
            model,
            rescale: None,
            sep: None,
            lora: Vec::new(),
            quant: HashMap::new(),
        }
    }

    /// Sets the `rescale` option.
    ///
    /// A `value` of `0` is stored as `usize::MAX` (presumably meaning "never").
    /// Any other value is stored unchanged.
    pub fn rescale(mut self, value: usize) -> Self {
        let stored = if value == 0 { usize::MAX } else { value };
        self.rescale = Some(stored);
        self
    }

    /// Sets the `sep` option.
    ///
    /// A `value` of `0` is stored as `usize::MAX` (presumably meaning "never").
    /// Any other value is stored unchanged.
    pub fn sep(mut self, value: usize) -> Self {
        let stored = if value == 0 { usize::MAX } else { value };
        self.sep = Some(stored);
        self
    }

    /// Appends a LoRA adapter; adapters are kept in the order they are added.
    pub fn lora(mut self, value: Lora<R>) -> Self {
        self.lora.push(value);
        self
    }

    /// Replaces the whole quantization map with `value`.
    pub fn quant(mut self, value: HashMap<usize, Quant>) -> Self {
        self.quant = value;
        self
    }
}
/// Adjusts GPU context limits so that buffers sized from a [`ModelInfo`] fit.
pub trait ContextAutoLimits {
/// Returns `self` with limits derived from `info`.
fn auto_limits(self, info: &ModelInfo) -> Self;
}
impl ContextAutoLimits for ContextBuilder {
    /// Sets the buffer limits to fit `info`.
    ///
    /// `max_buffer_size` becomes the largest of:
    /// - `ModelInfo::BUFFER_SIZE`,
    /// - the head buffer size,
    /// - the largest non-head buffer size.
    ///
    /// `max_storage_buffer_binding_size` is computed the same way, using
    /// `ModelInfo::STORAGE_BUFFER_BINDING_SIZE` as the floor.
    ///
    /// Both fields are overwritten, not merged with any value already set on the
    /// builder.
    fn auto_limits(mut self, info: &ModelInfo) -> Self {
        // The biggest single buffer the model reports needing.
        let largest = info.max_non_head_buffer_size().max(info.head_buffer_size());

        let limits = &mut self.limits;
        limits.max_buffer_size = largest.max(ModelInfo::BUFFER_SIZE) as u64;
        limits.max_storage_buffer_binding_size =
            largest.max(ModelInfo::STORAGE_BUFFER_BINDING_SIZE) as u64;
        self
    }
}