ai_chain/options.rs
1use lazy_static::lazy_static;
2use paste::paste;
3use std::{collections::HashMap, env::VarError, ffi::OsStr};
4
5use serde::{Deserialize, Serialize};
6use strum_macros::EnumDiscriminants;
7
8use crate::tokens::Token;
9
10/// A collection of options that can be used to configure a model.
11#[derive(Default, Debug, Clone, Serialize, Deserialize)]
12/// `Options` is the struct that represents a set of options for a large language model.
13/// It provides methods for creating, adding, and retrieving options.
14///
15/// The 'Options' struct is mainly created using the `OptionsBuilder` to allow for
16/// flexibility in setting options.
17pub struct Options {
18 /// The actual options, stored as a vector.
19 opts: Vec<Opt>,
20}
21
22#[derive(thiserror::Error, Debug)]
23/// An error indicating that a required option is not set.
24#[error("Option not set")]
25struct OptionNotSetError;
26
27lazy_static! {
28 /// An empty set of options, useful as a default.
29 static ref EMPTY_OPTIONS: Options = Options::builder().build();
30}
31
32impl Options {
33 /// Constructs a new `OptionsBuilder` for creating an `Options` instance.
34 ///
35 /// This function serves as an entry point for using the builder pattern to create `Options`.
36 ///
37 /// # Returns
38 ///
39 /// An `OptionsBuilder` instance.
40 ///
41 /// # Example
42 ///
43 /// ```rust
44 /// # use ai_chain::options::*;
45 /// let builder = Options::builder();
46 /// ```
47 pub fn builder() -> OptionsBuilder {
48 OptionsBuilder::new()
49 }
50
51 /// Returns a reference to an empty set of options.
52 ///
53 /// This function provides a static reference to an empty `Options` instance,
54 /// which can be useful as a default value.
55 ///
56 /// # Returns
57 ///
58 /// A reference to an empty `Options`.
59 ///
60 /// # Example
61 ///
62 /// ```rust
63 /// # use ai_chain::options::*;
64 /// let empty_options = Options::empty();
65 /// ```
66 pub fn empty() -> &'static Self {
67 &EMPTY_OPTIONS
68 }
69 /// Gets the value of an option from this set of options.
70 ///
71 /// This function finds the first option in `opts` that matches the provided `OptDiscriminants`.
72 ///
73 /// # Arguments
74 ///
75 /// * `opt_discriminant` - An `OptDiscriminants` value representing the discriminant of the desired `Opt`.
76 ///
77 /// # Returns
78 ///
79 /// An `Option` that contains a reference to the `Opt` if found, or `None` if not found.
80 ///
81 /// # Example
82 ///
83 /// ```rust
84 /// # use ai_chain::options::*;
85 /// let mut builder = Options::builder();
86 /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
87 /// let options = builder.build();
88 /// let model_option = options.get(OptDiscriminants::Model);
89 /// ```
90 pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
91 self.opts
92 .iter()
93 .find(|opt| OptDiscriminants::from(*opt) == opt_discriminant)
94 }
95}
96
/// `options!` is a declarative macro that builds an `Options` instance from a
/// list of `Name: value` pairs.
///
/// # Usage
///
/// ```ignore
/// options!{
///    OptionName1: value1,
///    OptionName2: value2,
///    ...
/// }
/// ```
///
/// `OptionNameN` is the identifier of the option to set and `valueN` is the
/// value assigned to it. Pairs are separated by commas; a trailing comma after
/// the last pair is accepted, and an empty invocation yields an `Options`
/// without any options.
///
/// # Example
///
/// ```ignore
/// let options = options!{
///    FooBar: "lol",
///    SomeReadyMadeOption: "another_value",
/// };
/// ```
///
/// In this example, an instance of `Options` is created with two options,
/// `FooBar` and `SomeReadyMadeOption`, set to `"lol"` and `"another_value"`
/// respectively.
///
/// # Notes
///
/// - The option identifier (`OptionNameN`) must name a variant of `Opt`;
///   otherwise the invocation fails to compile.
///
/// - Each value is converted with `.into()` into the payload type of the
///   corresponding `Opt` variant; a value without such a conversion fails to
///   compile.
///
/// - Options are added to the builder in the order they are written.
#[macro_export]
macro_rules! options {
    // `$(,)?` permits an optional trailing comma after the last pair.
    ( $( $opt_name:ident : $opt_value:expr ),* $(,)? ) => {
        {
            let mut _opts = $crate::options::Options::builder();
            $(
                _opts.add_option($crate::options::Opt::$opt_name($opt_value.into()));
            )*
            _opts.build()
        }
    };
}
147
148/// `OptionsBuilder` is a helper structure used to construct `Options` in a flexible way.
149///
150/// `OptionsBuilder` follows the builder pattern, providing a fluent interface to add options
151/// and finally, build an `Options` instance. This pattern is used to handle cases where the `Options`
152/// instance may require complex configuration or optional fields.
153///
154///
155/// # Example
156///
157/// ```rust
158/// # use ai_chain::options::*;
159/// let mut builder = OptionsBuilder::new();
160/// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
161/// let options = builder.build();
162/// ```
163#[derive(Default, Debug, Clone, Serialize, Deserialize)]
164pub struct OptionsBuilder {
165 /// A Vec<Opt> field that holds the options to be added to the `Options` instance.
166 opts: Vec<Opt>,
167}
168
169impl OptionsBuilder {
170 /// Constructs a new, empty `OptionsBuilder`.
171 ///
172 /// Returns an `OptionsBuilder` instance with an empty `opts` field.
173 ///
174 /// # Example
175 ///
176 /// ```rust
177 /// # use ai_chain::options::*;
178 /// let builder = OptionsBuilder::new();
179 /// ```
180 pub fn new() -> Self {
181 OptionsBuilder { opts: Vec::new() }
182 }
183
184 /// Adds an option to the `OptionsBuilder`.
185 ///
186 /// This function takes an `Opt` instance and pushes it to the `opts` field.
187 ///
188 /// # Arguments
189 ///
190 /// * `opt` - An `Opt` instance to be added to the `OptionsBuilder`.
191 ///
192 /// # Example
193 ///
194 /// ```rust
195 /// # use ai_chain::options::*;
196 /// let mut builder = OptionsBuilder::new();
197 /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
198 /// ```
199 pub fn add_option(&mut self, opt: Opt) {
200 self.opts.push(opt);
201 }
202
203 /// Consumes the `OptionsBuilder`, returning an `Options` instance.
204 ///
205 /// This function consumes the `OptionsBuilder`, moving its `opts` field to a new `Options` instance.
206 ///
207 /// # Returns
208 ///
209 /// An `Options` instance with the options added through the builder.
210 ///
211 /// # Example
212 ///
213 /// ```rust
214 /// # use ai_chain::options::*;
215 /// let mut builder = OptionsBuilder::new();
216 /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
217 /// let options = builder.build();
218 /// ```
219 pub fn build(self) -> Options {
220 Options { opts: self.opts }
221 }
222}
223
224/// A cascade of option sets.
225///
226/// Options added later in the cascade override earlier options.
227pub struct OptionsCascade<'a> {
228 /// The sets of options, in the order they were added.
229 cascades: Vec<&'a Options>,
230}
231
232impl<'a> OptionsCascade<'a> {
233 /// Creates a new, empty cascade of options.
234 pub fn new() -> Self {
235 OptionsCascade::from_vec(Vec::new())
236 }
237
238 /// Setups a typical options cascade, with model_defaults, environment defaults a model config and possibly a specific config.
239 pub fn new_typical(
240 model_default: &'a Options,
241 env_defaults: &'a Options,
242 model_config: &'a Options,
243 specific_config: Option<&'a Options>,
244 ) -> Self {
245 let mut v = vec![model_default, env_defaults, model_config];
246 if let Some(specific_config) = specific_config {
247 v.push(specific_config);
248 }
249 Self::from_vec(v)
250 }
251
252 /// Creates a new cascade of options from a vector of option sets.
253 pub fn from_vec(cascades: Vec<&'a Options>) -> Self {
254 OptionsCascade { cascades }
255 }
256
257 /// Returns a new cascade of options with the given set of options added.
258 pub fn with_options(mut self, options: &'a Options) -> Self {
259 self.cascades.push(options);
260 self
261 }
262
263 /// Gets the value of an option from this cascade.
264 ///
265 /// Returns `None` if the option is not present in any set in this cascade.
266 /// If the option is present in multiple sets, the value from the most
267 /// recently added set is returned.
268 pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
269 for options in self.cascades.iter().rev() {
270 if let Some(opt) = options.get(opt_discriminant) {
271 return Some(opt);
272 }
273 }
274 None
275 }
276
277 /// Returns a boolean indicating if options indicate that requests should be streamed or not.
278 pub fn is_streaming(&self) -> bool {
279 let Some(Opt::Stream(val)) = self.get(OptDiscriminants::Stream) else {
280 return false;
281 };
282 *val
283 }
284}
285
286impl<'a> Default for OptionsCascade<'a> {
287 /// Returns a new, empty cascade of options.
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293#[derive(Clone, Debug, Serialize, Deserialize)]
294/// A reference to a model name or path
295/// Useful for
296pub struct ModelRef(String);
297
298impl ModelRef {
299 pub fn from_path<S: Into<String>>(p: S) -> Self {
300 Self(p.into())
301 }
302 pub fn from_model_name<S: Into<String>>(model_name: S) -> Self {
303 Self(model_name.into())
304 }
305 /// Returns the path for this model reference
306 pub fn to_path(&self) -> String {
307 self.0.clone()
308 }
309 /// Returns the name of the model
310 pub fn to_name(&self) -> String {
311 self.0.clone()
312 }
313}
314
315/// A list of tokens to bias during the process of inferencing.
316#[derive(Serialize, Deserialize, Debug, Clone)]
317pub struct TokenBias(Vec<(Token, f32)>); // TODO: Serialize to a JSON object of str(F32) =>
318
319impl TokenBias {
320 /// Returns the token bias as a hashmap where the keys are i32 and the value f32. If the type doesn't match returns None
321 pub fn as_i32_f32_hashmap(&self) -> Option<HashMap<i32, f32>> {
322 let mut map = HashMap::new();
323 for (token, value) in &self.0 {
324 map.insert(token.to_i32()?, *value);
325 }
326 Some(map)
327 }
328}
329
330#[derive(EnumDiscriminants, Clone, Debug, Serialize, Deserialize)]
331pub enum Opt {
332 /// The name or path of the model used.
333 Model(ModelRef),
334 /// The API key for the model service.
335 ApiKey(String),
336 /// The number of threads to use for parallel processing.
337 /// This is common to all models.
338 NThreads(usize),
339 /// The maximum number of tokens that the model will generate.
340 /// This is common to all models.
341 MaxTokens(usize),
342 /// The maximum context size of the model.
343 MaxContextSize(usize),
344 /// The maximum batch size of the model.
345 /// This is used by llama models.
346 MaxBatchSize(usize),
347 /// The sequences that, when encountered, will cause the model to stop generating further tokens.
348 /// OpenAI models allow up to four stop sequences.
349 StopSequence(Vec<String>),
350 /// Whether or not to use streaming mode.
351 /// This is common to all models.
352 Stream(bool),
353
354 /// The penalty to apply for using frequent tokens.
355 /// This is used by OpenAI and llama models.
356 FrequencyPenalty(f32),
357 /// The penalty to apply for using novel tokens.
358 /// This is used by OpenAI and llama models.
359 PresencePenalty(f32),
360
361 /// A bias to apply to certain tokens during the inference process.
362 /// This is known as logit bias in OpenAI and is also used in ai-chain-local.
363 TokenBias(TokenBias),
364
365 /// The maximum number of tokens to consider for each step of generation.
366 /// This is common to all models, but is not used by OpenAI.
367 TopK(i32),
368 /// The cumulative probability threshold for token selection.
369 /// This is common to all models.
370 TopP(f32),
371 /// The temperature to use for token selection. Higher values result in more random output.
372 /// This is common to all models.
373 Temperature(f32),
374 /// The penalty to apply for repeated tokens.
375 /// This is common to all models.
376 RepeatPenalty(f32),
377 /// The number of most recent tokens to consider when applying the repeat penalty.
378 /// This is common to all models.
379 RepeatPenaltyLastN(usize),
380
381 /// The TfsZ parameter for ai-chain-llama.
382 TfsZ(f32),
383 /// The TypicalP parameter for ai-chain-llama.
384 TypicalP(f32),
385 /// The Mirostat parameter for ai-chain-llama.
386 Mirostat(i32),
387 /// The MirostatTau parameter for ai-chain-llama.
388 MirostatTau(f32),
389 /// The MirostatEta parameter for ai-chain-llama.
390 MirostatEta(f32),
391 /// Whether or not to penalize newline characters for ai-chain-llama.
392 PenalizeNl(bool),
393
394 /// The batch size for ai-chain-local.
395 NBatch(usize),
396 /// The username for ai-chain-openai.
397 User(String),
398 /// The type of the model.
399 ModelType(String),
400
401 // The number of layers to be stored in GPU VRAM for ai-chain-llama.
402 NGpuLayers(i32),
403 // The GPU that should be used for scratch and small tensors for ai-chain-llama.
404 MainGpu(i32),
405 // How the layers should be split accross the available GPUs for ai-chain-llama.
406 TensorSplit(Option<Vec<f32>>),
407 // Only load the vocabulary for ai-chain-llama, no weights will be loaded.
408 VocabOnly(bool),
409 // Use memory mapped files for ai-chain-llama where possible.
410 UseMmap(bool),
411 // Force the system to keep the model in memory for ai-chain-llama.
412 UseMlock(bool),
413}
414
415// Helper function to extract environment variables
416fn option_from_env<K, F>(opts: &mut OptionsBuilder, key: K, f: F) -> Result<(), VarError>
417where
418 K: AsRef<OsStr>,
419 F: FnOnce(String) -> Option<Opt>,
420{
421 match std::env::var(key) {
422 Ok(v) => {
423 if let Some(x) = f(v) {
424 opts.add_option(x);
425 }
426 Ok(())
427 }
428 Err(VarError::NotPresent) => Ok(()),
429 Err(e) => Err(e),
430 }
431}
432
433// Conversion functions for each Opt variant
434fn model_from_string(s: String) -> Option<Opt> {
435 Some(Opt::Model(ModelRef::from_path(s)))
436}
437
438fn api_key_from_string(s: String) -> Option<Opt> {
439 Some(Opt::ApiKey(s))
440}
441
// Generates `fn <variant_in_snake_case>_from_string(s: String) -> Option<Opt>`
// for the given `Opt` variant. The generated function parses `s` into the
// variant's payload type and returns `None` when parsing fails.
macro_rules! opt_parse_str {
    ($variant:ident) => {
        paste! {
            fn [< $variant:snake:lower _from_string >] (s: String) -> Option<Opt> {
                match s.parse() {
                    Ok(parsed) => Some(Opt::$variant(parsed)),
                    Err(_) => None,
                }
            }
        }
    };
}
451
452opt_parse_str!(NThreads);
453opt_parse_str!(MaxTokens);
454opt_parse_str!(MaxContextSize);
455// Skip stop sequence?
456// Skip stream?
457
458opt_parse_str!(FrequencyPenalty);
459opt_parse_str!(PresencePenalty);
460// Skip TokenBias for now
461opt_parse_str!(TopK);
462opt_parse_str!(TopP);
463opt_parse_str!(Temperature);
464opt_parse_str!(RepeatPenalty);
465opt_parse_str!(RepeatPenaltyLastN);
466opt_parse_str!(TfsZ);
467opt_parse_str!(PenalizeNl);
468opt_parse_str!(NBatch);
469
// Loads one option from the environment into the builder `$builder`.
//
// For a variant `Foo`, the variable name is `ai_chain_` followed by the
// variant name in upper snake case, and the value is converted with the
// `foo_from_string` function. Expands to an expression ending in `?`, so it
// must be used inside a function returning `Result<_, VarError>`.
macro_rules! opt_from_env {
    ($builder:ident, $variant:ident) => {
        paste! {
            option_from_env(
                &mut $builder,
                stringify!([< ai_chain_ $variant:snake:upper >]),
                [< $variant:snake:lower _from_string >]
            )?;
        }
    };
}
479
// Applies `opt_from_env!` to every listed `Opt` variant, in the order given.
macro_rules! opts_from_env {
    ($builder:ident, $($variant:ident),*) => {
        $( opt_from_env!($builder, $variant); )*
    };
}
487
488/// Loads options from environment variables.
489/// Every option that can be easily understood from a string is avaliable the name
490/// of the option will be in upper snake case, that means that the option `Opt::ApiKey` has the environment variable
491/// `ai_chain_API_KEY`
492pub fn options_from_env() -> Result<Options, VarError> {
493 let mut opts = OptionsBuilder::new();
494 opts_from_env!(
495 opts,
496 Model,
497 ApiKey,
498 NThreads,
499 MaxTokens,
500 MaxContextSize,
501 FrequencyPenalty,
502 PresencePenalty,
503 TopK,
504 TopP,
505 Temperature,
506 RepeatPenalty,
507 RepeatPenaltyLastN,
508 TfsZ,
509 PenalizeNl,
510 NBatch
511 );
512 Ok(opts.build())
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    // Sets three `ai_chain_*` environment variables and checks that
    // `options_from_env` exposes each of them as the matching `Opt` variant.
    #[test]
    fn test_options_from_env() {
        let expected_model = "/123/123.bin";
        let expected_n_batch = 1_usize;
        let expected_api_key = "!asd";

        std::env::set_var("ai_chain_MODEL", expected_model);
        std::env::set_var("ai_chain_N_BATCH", expected_n_batch.to_string());
        std::env::set_var("ai_chain_API_KEY", expected_api_key);

        let loaded = options_from_env().unwrap();

        assert!(matches!(
            loaded.get(OptDiscriminants::Model),
            Some(Opt::Model(model)) if model.to_path() == expected_model
        ));
        assert!(matches!(
            loaded.get(OptDiscriminants::NBatch),
            Some(Opt::NBatch(n_batch)) if *n_batch == expected_n_batch
        ));
        assert!(matches!(
            loaded.get(OptDiscriminants::ApiKey),
            Some(Opt::ApiKey(api_key)) if api_key == expected_api_key
        ));
    }
}