llm_base/model/mod.rs
1//! Large language model traits and types
2
3use std::{
4 error::Error,
5 fmt::Debug,
6 io::{BufRead, Write},
7};
8
9use thiserror::Error;
10
11use crate::{
12 loader::TensorLoader, vocabulary::TokenId, InferenceParameters, InferenceSession,
13 InferenceSessionConfig, LoadError, Vocabulary,
14};
15
16/// Common functions for model evaluation
17pub mod common;
18
19/// Interfaces for creating and interacting with a large language model with a known type
20/// of [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)).
21pub trait KnownModel: Send + Sync {
22 /// Hyperparameters for the model
23 type Hyperparameters: Hyperparameters;
24
25 /// Creates a new model from the provided [ModelParameters] hyperparameters.
26 /// This function is called by the [load](crate::loader::load) function.
27 fn new<E: Error>(
28 hyperparameters: Self::Hyperparameters,
29 params: ModelParameters,
30 vocabulary: Vocabulary,
31 tensor_loader: impl TensorLoader<E>,
32 ) -> Result<Self, E>
33 where
34 Self: Sized;
35
36 /// Starts a new `InferenceSession` for this model.
37 fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession;
38
39 /// This function is called by the provided [InferenceSession]; it will use this model
40 /// and the [InferenceParameters] to generate output by evaluating the `input_tokens`.
41 /// The [OutputRequest] is used to specify additional data to fetch from the
42 /// model.
43 fn evaluate(
44 &self,
45 session: &mut InferenceSession,
46 params: &InferenceParameters,
47 input_tokens: &[TokenId],
48 output_request: &mut OutputRequest,
49 );
50
51 /// Get the vocabulary (loaded from the GGML file) for this model.
52 fn vocabulary(&self) -> &Vocabulary;
53
54 /// Get the context size (configured with [ModelParameters::n_context_tokens]) used by
55 /// this model.
56 fn n_context_tokens(&self) -> usize;
57
58 /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers.
59 fn bot_token_id(&self) -> Option<TokenId>;
60
61 /// Get the end of text/end of string token ID. This value is defined by model implementers.
62 fn eot_token_id(&self) -> TokenId;
63
64 /// Get the default [InferenceParameters] for this model (used by
65 /// [InferenceSession::infer]). This value is configured through
66 /// [ModelParameters::inference_parameters].
67 fn inference_parameters(&self) -> &InferenceParameters;
68}
69
70/// A type-erased model to allow for interacting with a model without knowing
71/// its hyperparameters.
72pub trait Model: Send + Sync {
73 /// Starts a new `InferenceSession` for this model.
74 fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession;
75
76 /// This function is called by the provided [InferenceSession]; it will use this model
77 /// and the [InferenceParameters] to generate output by evaluating the `input_tokens`.
78 /// The [OutputRequest] is used to specify additional data to fetch from the
79 /// model.
80 fn evaluate(
81 &self,
82 session: &mut InferenceSession,
83 params: &InferenceParameters,
84 input_tokens: &[TokenId],
85 output_request: &mut OutputRequest,
86 );
87
88 /// Get the vocabulary (loaded from the GGML file) for this model.
89 fn vocabulary(&self) -> &Vocabulary;
90
91 /// Get the context size (configured with [ModelParameters::n_context_tokens]) used by
92 /// this model.
93 fn n_context_tokens(&self) -> usize;
94
95 /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers.
96 fn bot_token_id(&self) -> Option<TokenId>;
97
98 /// Get the end of text/end of string token ID. This value is defined by model implementers.
99 fn eot_token_id(&self) -> TokenId;
100
101 /// Get the default [InferenceParameters] for this model (used by
102 /// [InferenceSession::infer]). This value is configured through
103 /// [ModelParameters::inference_parameters].
104 fn inference_parameters(&self) -> &InferenceParameters;
105}
106impl<H: Hyperparameters, M: KnownModel<Hyperparameters = H>> Model for M {
107 fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
108 KnownModel::start_session(self, config)
109 }
110
111 fn evaluate(
112 &self,
113 session: &mut InferenceSession,
114 params: &InferenceParameters,
115 input_tokens: &[TokenId],
116 output_request: &mut OutputRequest,
117 ) {
118 KnownModel::evaluate(self, session, params, input_tokens, output_request)
119 }
120
121 fn vocabulary(&self) -> &Vocabulary {
122 KnownModel::vocabulary(self)
123 }
124
125 fn n_context_tokens(&self) -> usize {
126 KnownModel::n_context_tokens(self)
127 }
128
129 fn bot_token_id(&self) -> Option<TokenId> {
130 KnownModel::bot_token_id(self)
131 }
132
133 fn eot_token_id(&self) -> TokenId {
134 KnownModel::eot_token_id(self)
135 }
136
137 fn inference_parameters(&self) -> &InferenceParameters {
138 KnownModel::inference_parameters(self)
139 }
140}
141
142/// Implemented by model hyperparameters for interacting with hyperparameters
143/// without knowing what they are, as well as writing/reading them as required.
144pub trait Hyperparameters: Sized + Default + Debug {
145 /// Read the parameters in GGML format from a reader.
146 fn read_ggml(reader: &mut dyn BufRead) -> Result<Self, LoadError>;
147
148 /// Write the parameters in GGML format to a writer.
149 fn write_ggml(&self, writer: &mut dyn Write) -> Result<(), HyperparametersWriteError>;
150
151 /// Get the number of tokens in the vocabulary.
152 fn n_vocabulary(&self) -> usize;
153}
154#[derive(Error, Debug)]
155/// Reported from functions that write
156pub enum HyperparametersWriteError {
157 #[error("non-specific I/O error")]
158 /// A non-specific IO error.
159 Io(#[from] std::io::Error),
160 #[error("invalid integer conversion")]
161 /// One of the integers encountered could not be converted to a more appropriate type.
162 InvalidIntegerConversion(#[from] std::num::TryFromIntError),
163}
164
165/// Parameters for tuning model instances
166pub struct ModelParameters {
167 /// For [GGML formats](ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap)
168 /// is the default. Although mmap typically improves performance, setting this value to `false` may
169 /// be preferred in resource-constrained environments.
170 pub prefer_mmap: bool,
171 /// The context size ("memory") the model should use when evaluating a prompt. A larger context
172 /// consumes more resources, but produces more consistent and coherent responses.
173 pub n_context_tokens: usize,
174 /// Default InferenceParameters to use when [evaluating](Model::evaluate) a prompt with this model.
175 pub inference_parameters: InferenceParameters,
176}
177
178impl Default for ModelParameters {
179 fn default() -> Self {
180 Self {
181 prefer_mmap: true,
182 n_context_tokens: 2048,
183 inference_parameters: Default::default(),
184 }
185 }
186}
187
/// Passed to [Model::evaluate] or [InferenceSession::infer] to ask the model
/// for additional information. For every field that is `Some`, the contained
/// `Vec` will be cleared, resized, and filled with the related data.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct OutputRequest {
    /// All the logits for the evaluation. A logit represents the likelihood
    /// that a given token will be generated, based on the tokens evaluated or
    /// generated so far. Output shape is `n_batch * n_vocab`.
    pub all_logits: Option<Vec<f32>>,
    /// All the embeddings for the evaluation. An embedding is a vector that
    /// measures the relatedness of text strings. Output shape is
    /// `n_batch * n_embd`.
    pub embeddings: Option<Vec<f32>>,
}