llm_base/model/
mod.rs

1//! Large language model traits and types
2
3use std::{
4    error::Error,
5    fmt::Debug,
6    io::{BufRead, Write},
7};
8
9use thiserror::Error;
10
11use crate::{
12    loader::TensorLoader, vocabulary::TokenId, InferenceParameters, InferenceSession,
13    InferenceSessionConfig, LoadError, Vocabulary,
14};
15
16/// Common functions for model evaluation
17pub mod common;
18
19/// Interfaces for creating and interacting with a large language model with a known type
20/// of [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)).
21pub trait KnownModel: Send + Sync {
22    /// Hyperparameters for the model
23    type Hyperparameters: Hyperparameters;
24
25    /// Creates a new model from the provided [ModelParameters] hyperparameters.
26    /// This function is called by the [load](crate::loader::load) function.
27    fn new<E: Error>(
28        hyperparameters: Self::Hyperparameters,
29        params: ModelParameters,
30        vocabulary: Vocabulary,
31        tensor_loader: impl TensorLoader<E>,
32    ) -> Result<Self, E>
33    where
34        Self: Sized;
35
36    /// Starts a new `InferenceSession` for this model.
37    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession;
38
39    /// This function is called by the provided [InferenceSession]; it will use this model
40    /// and the [InferenceParameters] to generate output by evaluating the `input_tokens`.
41    /// The [OutputRequest] is used to specify additional data to fetch from the
42    /// model.
43    fn evaluate(
44        &self,
45        session: &mut InferenceSession,
46        params: &InferenceParameters,
47        input_tokens: &[TokenId],
48        output_request: &mut OutputRequest,
49    );
50
51    /// Get the vocabulary (loaded from the GGML file) for this model.
52    fn vocabulary(&self) -> &Vocabulary;
53
54    /// Get the context size (configured with [ModelParameters::n_context_tokens]) used by
55    /// this model.
56    fn n_context_tokens(&self) -> usize;
57
58    /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers.
59    fn bot_token_id(&self) -> Option<TokenId>;
60
61    /// Get the end of text/end of string token ID. This value is defined by model implementers.
62    fn eot_token_id(&self) -> TokenId;
63
64    /// Get the default [InferenceParameters] for this model (used by
65    /// [InferenceSession::infer]). This value is configured through
66    /// [ModelParameters::inference_parameters].
67    fn inference_parameters(&self) -> &InferenceParameters;
68}
69
70/// A type-erased model to allow for interacting with a model without knowing
71/// its hyperparameters.
72pub trait Model: Send + Sync {
73    /// Starts a new `InferenceSession` for this model.
74    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession;
75
76    /// This function is called by the provided [InferenceSession]; it will use this model
77    /// and the [InferenceParameters] to generate output by evaluating the `input_tokens`.
78    /// The [OutputRequest] is used to specify additional data to fetch from the
79    /// model.
80    fn evaluate(
81        &self,
82        session: &mut InferenceSession,
83        params: &InferenceParameters,
84        input_tokens: &[TokenId],
85        output_request: &mut OutputRequest,
86    );
87
88    /// Get the vocabulary (loaded from the GGML file) for this model.
89    fn vocabulary(&self) -> &Vocabulary;
90
91    /// Get the context size (configured with [ModelParameters::n_context_tokens]) used by
92    /// this model.
93    fn n_context_tokens(&self) -> usize;
94
95    /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers.
96    fn bot_token_id(&self) -> Option<TokenId>;
97
98    /// Get the end of text/end of string token ID. This value is defined by model implementers.
99    fn eot_token_id(&self) -> TokenId;
100
101    /// Get the default [InferenceParameters] for this model (used by
102    /// [InferenceSession::infer]). This value is configured through
103    /// [ModelParameters::inference_parameters].
104    fn inference_parameters(&self) -> &InferenceParameters;
105}
106impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
107    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
108        KnownModel::start_session(self, config)
109    }
110
111    fn evaluate(
112        &self,
113        session: &mut InferenceSession,
114        params: &InferenceParameters,
115        input_tokens: &[TokenId],
116        output_request: &mut OutputRequest,
117    ) {
118        KnownModel::evaluate(self, session, params, input_tokens, output_request)
119    }
120
121    fn vocabulary(&self) -> &Vocabulary {
122        KnownModel::vocabulary(self)
123    }
124
125    fn n_context_tokens(&self) -> usize {
126        KnownModel::n_context_tokens(self)
127    }
128
129    fn bot_token_id(&self) -> Option<TokenId> {
130        KnownModel::bot_token_id(self)
131    }
132
133    fn eot_token_id(&self) -> TokenId {
134        KnownModel::eot_token_id(self)
135    }
136
137    fn inference_parameters(&self) -> &InferenceParameters {
138        KnownModel::inference_parameters(self)
139    }
140}
141
142/// Implemented by model hyperparameters for interacting with hyperparameters
143/// without knowing what they are, as well as writing/reading them as required.
144pub trait Hyperparameters: Sized + Default + Debug {
145    /// Read the parameters in GGML format from a reader.
146    fn read_ggml(reader: &mut dyn BufRead) -> Result<Self, LoadError>;
147
148    /// Write the parameters in GGML format to a writer.
149    fn write_ggml(&self, writer: &mut dyn Write) -> Result<(), HyperparametersWriteError>;
150
151    /// Get the number of tokens in the vocabulary.
152    fn n_vocabulary(&self) -> usize;
153}
154#[derive(Error, Debug)]
155/// Reported from functions that write
156pub enum HyperparametersWriteError {
157    #[error("non-specific I/O error")]
158    /// A non-specific IO error.
159    Io(#[from] std::io::Error),
160    #[error("invalid integer conversion")]
161    /// One of the integers encountered could not be converted to a more appropriate type.
162    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
163}
164
165/// Parameters for tuning model instances
166pub struct ModelParameters {
167    /// For [GGML formats](ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap)
168    /// is the default. Although mmap typically improves performance, setting this value to `false` may
169    /// be preferred in resource-constrained environments.
170    pub prefer_mmap: bool,
171    /// The context size ("memory") the model should use when evaluating a prompt. A larger context
172    /// consumes more resources, but produces more consistent and coherent responses.
173    pub n_context_tokens: usize,
174    /// Default InferenceParameters to use when [evaluating](Model::evaluate) a prompt with this model.
175    pub inference_parameters: InferenceParameters,
176}
177
178impl Default for ModelParameters {
179    fn default() -> Self {
180        Self {
181            prefer_mmap: true,
182            n_context_tokens: 2048,
183            inference_parameters: Default::default(),
184        }
185    }
186}
187
/// Passed to [Model::evaluate] or [InferenceSession::infer] to request extra
/// information from the model. For each field that is `Some`, the contained
/// `Vec` will be cleared, resized, and filled with the corresponding data.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OutputRequest {
    /// Requests all the logits for the evaluation. A logit represents the
    /// likelihood that a given token will be generated, based on the tokens
    /// evaluated or generated so far. The output shape is
    /// `n_batch * n_vocab`.
    pub all_logits: Option<Vec<f32>>,
    /// Requests all the embeddings for the evaluation. An embedding is a
    /// vector that measures the relatedness of text strings. The output shape
    /// is `n_batch * n_embd`.
    pub embeddings: Option<Vec<f32>>,
}