use std::{
error::Error,
fmt::Debug,
io::{BufRead, Write},
};
use thiserror::Error;
use crate::{
loader::TensorLoader, vocabulary::TokenId, InferenceParameters, InferenceSession,
InferenceSessionConfig, LoadError, Vocabulary,
};
pub mod common;
/// A model whose concrete type — and therefore its hyperparameter type — is known statically.
///
/// Implementors supply construction from loaded data plus the inference operations. Every
/// implementor is also a [`Model`] through the blanket impl in this module; that is the
/// type-erased interface callers can hold behind `dyn Model`.
pub trait KnownModel: Send + Sync {
    /// The hyperparameter type this model is built from.
    type Hyperparameters: Hyperparameters;

    /// Builds the model from its hyperparameters, model parameters, vocabulary and a tensor
    /// loader.
    ///
    /// `E` is the error type parameter shared with `tensor_loader`; any failure is reported as
    /// that same `E`. The `Self: Sized` bound keeps this constructor out of trait objects.
    fn new<E>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E>
    where
        E: Error,
        Self: Sized;

    /// Returns the model's vocabulary.
    fn vocabulary(&self) -> &Vocabulary;

    /// Context size in tokens (meaning inferred from the name — confirm with implementors).
    fn n_context_tokens(&self) -> usize;

    /// Beginning-of-text token ID, or `None` if there is none (meaning inferred from the
    /// name — confirm).
    fn bot_token_id(&self) -> Option<TokenId>;

    /// End-of-text token ID (meaning inferred from the name — confirm).
    fn eot_token_id(&self) -> TokenId;

    /// Returns the inference parameters held by the model.
    fn inference_parameters(&self) -> &InferenceParameters;

    /// Creates a new inference session configured by `config`.
    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession;

    /// Evaluates `input_tokens` within `session` using `params`.
    ///
    /// NOTE(review): presumably fills whichever fields of `output_request` are set to `Some`
    /// — confirm against implementors.
    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    );
}
/// Type-erased interface to a model, usable as `dyn Model`.
///
/// It mirrors the instance methods of [`KnownModel`] but has no associated type and no generic
/// methods, so it is object-safe. Every [`KnownModel`] implements it through the blanket impl in
/// this module.
pub trait Model: Send + Sync {
    /// Returns the model's vocabulary.
    fn vocabulary(&self) -> &Vocabulary;

    /// Context size in tokens (meaning inferred from the name — confirm with implementors).
    fn n_context_tokens(&self) -> usize;

    /// Beginning-of-text token ID, or `None` if there is none (meaning inferred from the
    /// name — confirm).
    fn bot_token_id(&self) -> Option<TokenId>;

    /// End-of-text token ID (meaning inferred from the name — confirm).
    fn eot_token_id(&self) -> TokenId;

    /// Returns the inference parameters held by the model.
    fn inference_parameters(&self) -> &InferenceParameters;

    /// Creates a new inference session configured by `config`.
    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession;

    /// Evaluates `input_tokens` within `session` using `params`.
    ///
    /// NOTE(review): presumably fills whichever fields of `output_request` are set to `Some`
    /// — confirm against implementors.
    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    );
}
/// Makes every [`KnownModel`] usable as a type-erased [`Model`].
///
/// Each method forwards unchanged to its [`KnownModel`] counterpart. The bound is plainly
/// `M: KnownModel`: the associated `Hyperparameters` type is already required to implement
/// [`Hyperparameters`] by `KnownModel` itself, so a separate `H` type parameter added nothing.
///
/// Calls are spelled `KnownModel::method(self, ..)` because `M` implements both traits with
/// identically named methods, so `self.method(..)` would be ambiguous.
impl<M: KnownModel> Model for M {
    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        KnownModel::start_session(self, config)
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        KnownModel::evaluate(self, session, params, input_tokens, output_request)
    }

    fn vocabulary(&self) -> &Vocabulary {
        KnownModel::vocabulary(self)
    }

    fn n_context_tokens(&self) -> usize {
        KnownModel::n_context_tokens(self)
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        KnownModel::bot_token_id(self)
    }

    fn eot_token_id(&self) -> TokenId {
        KnownModel::eot_token_id(self)
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        KnownModel::inference_parameters(self)
    }
}
/// Hyperparameters of a model that can be read from and written to a GGML stream (format
/// inferred from the method names).
///
/// Implementors must also be `Sized`, `Default` and `Debug`; these are declared as supertraits
/// through the `where Self: …` clause below.
pub trait Hyperparameters
where
    Self: Sized + Default + Debug,
{
    /// Vocabulary size described by these hyperparameters (meaning inferred from the name —
    /// confirm).
    fn n_vocabulary(&self) -> usize;

    /// Parses hyperparameters from `reader`.
    ///
    /// # Errors
    /// Returns a [`LoadError`] if parsing fails.
    fn read_ggml(reader: &mut dyn BufRead) -> Result<Self, LoadError>;

    /// Serializes these hyperparameters to `writer`.
    ///
    /// # Errors
    /// Returns [`HyperparametersWriteError`], which covers I/O failures and failed integer
    /// conversions.
    fn write_ggml(&self, writer: &mut dyn Write) -> Result<(), HyperparametersWriteError>;
}
/// Error returned by [`Hyperparameters::write_ggml`].
///
/// Both variants use `#[from]`, so `?` converts the wrapped error automatically, and the
/// wrapped error is reported as the source.
#[derive(Error, Debug)]
pub enum HyperparametersWriteError {
    /// Wraps a `std::io::Error` (presumably raised by the writer — confirm against implementors).
    #[error("non-specific I/O error")]
    Io(#[from] std::io::Error),
    /// Wraps a `TryFromIntError`, i.e. a value that did not fit the target integer type.
    #[error("invalid integer conversion")]
    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
}
/// Parameters supplied when constructing a model; passed to [`KnownModel::new`].
pub struct ModelParameters {
    /// Whether memory-mapping is preferred when loading (meaning inferred from the name —
    /// confirm with the loader).
    pub prefer_mmap: bool,
    /// Context size in tokens (meaning inferred from the name — confirm how implementors use it).
    pub n_context_tokens: usize,
    /// Inference parameters handed to the model. NOTE(review): presumably what
    /// `KnownModel::inference_parameters` later returns — confirm.
    pub inference_parameters: InferenceParameters,
}
impl Default for ModelParameters {
fn default() -> Self {
Self {
prefer_mmap: true,
n_context_tokens: 2048,
inference_parameters: Default::default(),
}
}
}
/// Optional extra outputs passed to [`Model::evaluate`] / [`KnownModel::evaluate`].
///
/// Both fields default to `None` (derived `Default`). NOTE(review): presumably setting a field
/// to `Some(..)` asks `evaluate` to fill it — confirm against implementors.
#[derive(Default, Debug, PartialEq, Clone)]
pub struct OutputRequest {
    /// Requested logits, if any (length/layout not established here — confirm).
    pub all_logits: Option<Vec<f32>>,
    /// Requested embeddings, if any (length/layout not established here — confirm).
    pub embeddings: Option<Vec<f32>>,
}