ai_chain/
options.rs

1use lazy_static::lazy_static;
2use paste::paste;
3use std::{collections::HashMap, env::VarError, ffi::OsStr};
4
5use serde::{Deserialize, Serialize};
6use strum_macros::EnumDiscriminants;
7
8use crate::tokens::Token;
9
10/// A collection of options that can be used to configure a model.
11#[derive(Default, Debug, Clone, Serialize, Deserialize)]
12/// `Options` is the struct that represents a set of options for a large language model.
13/// It provides methods for creating, adding, and retrieving options.
14///
15/// The 'Options' struct is mainly created using the `OptionsBuilder` to allow for
16/// flexibility in setting options.
17pub struct Options {
18    /// The actual options, stored as a vector.
19    opts: Vec<Opt>,
20}
21
22#[derive(thiserror::Error, Debug)]
23/// An error indicating that a required option is not set.
24#[error("Option not set")]
25struct OptionNotSetError;
26
27lazy_static! {
28    /// An empty set of options, useful as a default.
29    static ref EMPTY_OPTIONS: Options = Options::builder().build();
30}
31
32impl Options {
33    /// Constructs a new `OptionsBuilder` for creating an `Options` instance.
34    ///
35    /// This function serves as an entry point for using the builder pattern to create `Options`.
36    ///
37    /// # Returns
38    ///
39    /// An `OptionsBuilder` instance.
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// # use ai_chain::options::*;
45    /// let builder = Options::builder();
46    /// ```
47    pub fn builder() -> OptionsBuilder {
48        OptionsBuilder::new()
49    }
50
51    /// Returns a reference to an empty set of options.
52    ///
53    /// This function provides a static reference to an empty `Options` instance,
54    /// which can be useful as a default value.
55    ///
56    /// # Returns
57    ///
58    /// A reference to an empty `Options`.
59    ///
60    /// # Example
61    ///
62    /// ```rust
63    /// # use ai_chain::options::*;
64    /// let empty_options = Options::empty();
65    /// ```
66    pub fn empty() -> &'static Self {
67        &EMPTY_OPTIONS
68    }
69    /// Gets the value of an option from this set of options.
70    ///
71    /// This function finds the first option in `opts` that matches the provided `OptDiscriminants`.
72    ///
73    /// # Arguments
74    ///
75    /// * `opt_discriminant` - An `OptDiscriminants` value representing the discriminant of the desired `Opt`.
76    ///
77    /// # Returns
78    ///
79    /// An `Option` that contains a reference to the `Opt` if found, or `None` if not found.
80    ///
81    /// # Example
82    ///
83    /// ```rust
84    /// # use ai_chain::options::*;
85    /// let mut builder = Options::builder();
86    /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
87    /// let options = builder.build();
88    /// let model_option = options.get(OptDiscriminants::Model);
89    /// ```
90    pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
91        self.opts
92            .iter()
93            .find(|opt| OptDiscriminants::from(*opt) == opt_discriminant)
94    }
95}
96
/// `options!` is a declarative macro that facilitates the creation of an `Options` instance.
///
/// # Usage
///
/// The macro takes a comma-separated list of `OptionName: value` pairs; a trailing
/// comma after the last pair is accepted:
///
/// ```ignore
/// options!{
///     OptionName1: value1,
///     OptionName2: value2,
///     ...
/// }
/// ```
///
/// Here, `OptionNameN` is the identifier of the option you want to set, and `valueN` is the
/// value you want to assign to that option. Each value is passed through `.into()` before
/// being wrapped in the corresponding `Opt` variant, and the pairs are added to the
/// resulting `Options` in the order they are written.
///
/// # Example
///
/// ```ignore
/// let options = options!{
///     FooBar: "lol",
///     SomeReadyMadeOption: "another_value",
/// };
/// ```
///
/// In this example, an instance of `Options` is being created with two options: `FooBar` and
/// `SomeReadyMadeOption`, which are set to `"lol"` and `"another_value"`, respectively.
///
/// # Notes
///
/// - The option identifier (`OptionNameN`) must match an enum variant in `Opt`. If the
///   identifier does not match any of the `Opt` variants, a compilation error will occur.
///
/// - The value (`valueN`) must be convertible (via `Into`) to the payload type of the
///   corresponding variant. If it is not, a compilation error will occur.
///
#[macro_export]
macro_rules! options {
    // `$(,)?` lets callers end the list with a trailing comma.
    ( $( $opt_name:ident : $opt_value:expr ),* $(,)? ) => {
        {
            // The leading underscore keeps an empty invocation free of
            // unused-variable / unused-mut warnings.
            let mut _opts = $crate::options::Options::builder();
            $(
                _opts.add_option($crate::options::Opt::$opt_name($opt_value.into()));
            )*
            _opts.build()
        }
    };
}
147
148/// `OptionsBuilder` is a helper structure used to construct `Options` in a flexible way.
149///
150/// `OptionsBuilder` follows the builder pattern, providing a fluent interface to add options
151/// and finally, build an `Options` instance. This pattern is used to handle cases where the `Options`
152/// instance may require complex configuration or optional fields.
153///
154///
155/// # Example
156///
157/// ```rust
158/// # use ai_chain::options::*;
159/// let mut builder = OptionsBuilder::new();
160/// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
161/// let options = builder.build();
162/// ```
163#[derive(Default, Debug, Clone, Serialize, Deserialize)]
164pub struct OptionsBuilder {
165    /// A Vec<Opt> field that holds the options to be added to the `Options` instance.
166    opts: Vec<Opt>,
167}
168
169impl OptionsBuilder {
170    /// Constructs a new, empty `OptionsBuilder`.
171    ///
172    /// Returns an `OptionsBuilder` instance with an empty `opts` field.
173    ///
174    /// # Example
175    ///
176    /// ```rust
177    /// # use ai_chain::options::*;
178    /// let builder = OptionsBuilder::new();
179    /// ```
180    pub fn new() -> Self {
181        OptionsBuilder { opts: Vec::new() }
182    }
183
184    /// Adds an option to the `OptionsBuilder`.
185    ///
186    /// This function takes an `Opt` instance and pushes it to the `opts` field.
187    ///
188    /// # Arguments
189    ///
190    /// * `opt` - An `Opt` instance to be added to the `OptionsBuilder`.
191    ///
192    /// # Example
193    ///
194    /// ```rust
195    /// # use ai_chain::options::*;
196    /// let mut builder = OptionsBuilder::new();
197    /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
198    /// ```
199    pub fn add_option(&mut self, opt: Opt) {
200        self.opts.push(opt);
201    }
202
203    /// Consumes the `OptionsBuilder`, returning an `Options` instance.
204    ///
205    /// This function consumes the `OptionsBuilder`, moving its `opts` field to a new `Options` instance.
206    ///
207    /// # Returns
208    ///
209    /// An `Options` instance with the options added through the builder.
210    ///
211    /// # Example
212    ///
213    /// ```rust
214    /// # use ai_chain::options::*;
215    /// let mut builder = OptionsBuilder::new();
216    /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
217    /// let options = builder.build();
218    /// ```
219    pub fn build(self) -> Options {
220        Options { opts: self.opts }
221    }
222}
223
224/// A cascade of option sets.
225///
226/// Options added later in the cascade override earlier options.
227pub struct OptionsCascade<'a> {
228    /// The sets of options, in the order they were added.
229    cascades: Vec<&'a Options>,
230}
231
232impl<'a> OptionsCascade<'a> {
233    /// Creates a new, empty cascade of options.
234    pub fn new() -> Self {
235        OptionsCascade::from_vec(Vec::new())
236    }
237
238    /// Setups a typical options cascade, with model_defaults, environment defaults a model config and possibly a specific config.
239    pub fn new_typical(
240        model_default: &'a Options,
241        env_defaults: &'a Options,
242        model_config: &'a Options,
243        specific_config: Option<&'a Options>,
244    ) -> Self {
245        let mut v = vec![model_default, env_defaults, model_config];
246        if let Some(specific_config) = specific_config {
247            v.push(specific_config);
248        }
249        Self::from_vec(v)
250    }
251
252    /// Creates a new cascade of options from a vector of option sets.
253    pub fn from_vec(cascades: Vec<&'a Options>) -> Self {
254        OptionsCascade { cascades }
255    }
256
257    /// Returns a new cascade of options with the given set of options added.
258    pub fn with_options(mut self, options: &'a Options) -> Self {
259        self.cascades.push(options);
260        self
261    }
262
263    /// Gets the value of an option from this cascade.
264    ///
265    /// Returns `None` if the option is not present in any set in this cascade.
266    /// If the option is present in multiple sets, the value from the most
267    /// recently added set is returned.
268    pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
269        for options in self.cascades.iter().rev() {
270            if let Some(opt) = options.get(opt_discriminant) {
271                return Some(opt);
272            }
273        }
274        None
275    }
276
277    /// Returns a boolean indicating if options indicate that requests should be streamed or not.
278    pub fn is_streaming(&self) -> bool {
279        let Some(Opt::Stream(val)) = self.get(OptDiscriminants::Stream) else {
280            return false;
281        };
282        *val
283    }
284}
285
286impl<'a> Default for OptionsCascade<'a> {
287    /// Returns a new, empty cascade of options.
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
293#[derive(Clone, Debug, Serialize, Deserialize)]
294/// A reference to a model name or path
295/// Useful for
296pub struct ModelRef(String);
297
298impl ModelRef {
299    pub fn from_path<S: Into<String>>(p: S) -> Self {
300        Self(p.into())
301    }
302    pub fn from_model_name<S: Into<String>>(model_name: S) -> Self {
303        Self(model_name.into())
304    }
305    /// Returns the path for this model reference
306    pub fn to_path(&self) -> String {
307        self.0.clone()
308    }
309    /// Returns the name of the model
310    pub fn to_name(&self) -> String {
311        self.0.clone()
312    }
313}
314
315/// A list of tokens to bias during the process of inferencing.
316#[derive(Serialize, Deserialize, Debug, Clone)]
317pub struct TokenBias(Vec<(Token, f32)>); // TODO: Serialize to a JSON object of str(F32) =>
318
319impl TokenBias {
320    /// Returns the token bias as a hashmap where the keys are i32 and the value f32. If the type doesn't match returns None
321    pub fn as_i32_f32_hashmap(&self) -> Option<HashMap<i32, f32>> {
322        let mut map = HashMap::new();
323        for (token, value) in &self.0 {
324            map.insert(token.to_i32()?, *value);
325        }
326        Some(map)
327    }
328}
329
330#[derive(EnumDiscriminants, Clone, Debug, Serialize, Deserialize)]
331pub enum Opt {
332    /// The name or path of the model used.
333    Model(ModelRef),
334    /// The API key for the model service.
335    ApiKey(String),
336    /// The number of threads to use for parallel processing.
337    /// This is common to all models.
338    NThreads(usize),
339    /// The maximum number of tokens that the model will generate.
340    /// This is common to all models.
341    MaxTokens(usize),
342    /// The maximum context size of the model.
343    MaxContextSize(usize),
344    /// The maximum batch size of the model.
345    /// This is used by llama models.
346    MaxBatchSize(usize),
347    /// The sequences that, when encountered, will cause the model to stop generating further tokens.
348    /// OpenAI models allow up to four stop sequences.
349    StopSequence(Vec<String>),
350    /// Whether or not to use streaming mode.
351    /// This is common to all models.
352    Stream(bool),
353
354    /// The penalty to apply for using frequent tokens.
355    /// This is used by OpenAI and llama models.
356    FrequencyPenalty(f32),
357    /// The penalty to apply for using novel tokens.
358    /// This is used by OpenAI and llama models.
359    PresencePenalty(f32),
360
361    /// A bias to apply to certain tokens during the inference process.
362    /// This is known as logit bias in OpenAI and is also used in ai-chain-local.
363    TokenBias(TokenBias),
364
365    /// The maximum number of tokens to consider for each step of generation.
366    /// This is common to all models, but is not used by OpenAI.
367    TopK(i32),
368    /// The cumulative probability threshold for token selection.
369    /// This is common to all models.
370    TopP(f32),
371    /// The temperature to use for token selection. Higher values result in more random output.
372    /// This is common to all models.
373    Temperature(f32),
374    /// The penalty to apply for repeated tokens.
375    /// This is common to all models.
376    RepeatPenalty(f32),
377    /// The number of most recent tokens to consider when applying the repeat penalty.
378    /// This is common to all models.
379    RepeatPenaltyLastN(usize),
380
381    /// The TfsZ parameter for ai-chain-llama.
382    TfsZ(f32),
383    /// The TypicalP parameter for ai-chain-llama.
384    TypicalP(f32),
385    /// The Mirostat parameter for ai-chain-llama.
386    Mirostat(i32),
387    /// The MirostatTau parameter for ai-chain-llama.
388    MirostatTau(f32),
389    /// The MirostatEta parameter for ai-chain-llama.
390    MirostatEta(f32),
391    /// Whether or not to penalize newline characters for ai-chain-llama.
392    PenalizeNl(bool),
393
394    /// The batch size for ai-chain-local.
395    NBatch(usize),
396    /// The username for ai-chain-openai.
397    User(String),
398    /// The type of the model.
399    ModelType(String),
400
401    // The number of layers to be stored in GPU VRAM for ai-chain-llama.
402    NGpuLayers(i32),
403    // The GPU that should be used for scratch and small tensors for ai-chain-llama.
404    MainGpu(i32),
405    // How the layers should be split accross the available GPUs for ai-chain-llama.
406    TensorSplit(Option<Vec<f32>>),
407    // Only load the vocabulary for ai-chain-llama, no weights will be loaded.
408    VocabOnly(bool),
409    // Use memory mapped files for ai-chain-llama where possible.
410    UseMmap(bool),
411    // Force the system to keep the model in memory for ai-chain-llama.
412    UseMlock(bool),
413}
414
415// Helper function to extract environment variables
416fn option_from_env<K, F>(opts: &mut OptionsBuilder, key: K, f: F) -> Result<(), VarError>
417where
418    K: AsRef<OsStr>,
419    F: FnOnce(String) -> Option<Opt>,
420{
421    match std::env::var(key) {
422        Ok(v) => {
423            if let Some(x) = f(v) {
424                opts.add_option(x);
425            }
426            Ok(())
427        }
428        Err(VarError::NotPresent) => Ok(()),
429        Err(e) => Err(e),
430    }
431}
432
433// Conversion functions for each Opt variant
434fn model_from_string(s: String) -> Option<Opt> {
435    Some(Opt::Model(ModelRef::from_path(s)))
436}
437
438fn api_key_from_string(s: String) -> Option<Opt> {
439    Some(Opt::ApiKey(s))
440}
441
// Generates `fn <variant_in_snake_case>_from_string(String) -> Option<Opt>` for the given
// `Opt` variant. The generated function parses the string with `str::parse` into the
// variant's payload type and returns `None` when parsing fails.
macro_rules! opt_parse_str {
    ($variant:ident) => {
        paste! {
            fn [< $variant:snake:lower _from_string >](raw: String) -> Option<Opt> {
                // The target type of `parse` is inferred from the variant's payload.
                raw.parse().ok().map(Opt::$variant)
            }
        }
    };
}
451
452opt_parse_str!(NThreads);
453opt_parse_str!(MaxTokens);
454opt_parse_str!(MaxContextSize);
455// Skip stop sequence?
456// Skip stream?
457
458opt_parse_str!(FrequencyPenalty);
459opt_parse_str!(PresencePenalty);
460// Skip TokenBias for now
461opt_parse_str!(TopK);
462opt_parse_str!(TopP);
463opt_parse_str!(Temperature);
464opt_parse_str!(RepeatPenalty);
465opt_parse_str!(RepeatPenaltyLastN);
466opt_parse_str!(TfsZ);
467opt_parse_str!(PenalizeNl);
468opt_parse_str!(NBatch);
469
// Loads one option from the environment into the builder named by `$builder`.
//
// For a variant `FooBar` this reads the variable `ai_chain_FOO_BAR` and converts it with
// `foo_bar_from_string`. The expansion ends with `?`, so it can only be used inside a
// function returning `Result<_, VarError>`.
macro_rules! opt_from_env {
    ($builder:ident, $variant:ident) => {
        paste! {
            option_from_env(
                &mut $builder,
                stringify!([< ai_chain_ $variant:snake:upper >]),
                [< $variant:snake:lower _from_string >]
            )?;
        }
    };
}
479
// Applies `opt_from_env!` to each listed variant, in the order given.
macro_rules! opts_from_env {
    ($builder:ident, $($variant:ident),*) => {
        $( opt_from_env!($builder, $variant); )*
    };
}
487
488/// Loads options from environment variables.
489/// Every option that can be easily understood from a string is avaliable the name
490/// of the option will be in upper snake case, that means that the option `Opt::ApiKey` has the environment variable
491/// `ai_chain_API_KEY`
492pub fn options_from_env() -> Result<Options, VarError> {
493    let mut opts = OptionsBuilder::new();
494    opts_from_env!(
495        opts,
496        Model,
497        ApiKey,
498        NThreads,
499        MaxTokens,
500        MaxContextSize,
501        FrequencyPenalty,
502        PresencePenalty,
503        TopK,
504        TopP,
505        Temperature,
506        RepeatPenalty,
507        RepeatPenaltyLastN,
508        TfsZ,
509        PenalizeNl,
510        NBatch
511    );
512    Ok(opts.build())
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets three `ai_chain_*` variables, loads the options from the environment and
    /// checks that each value comes back in the matching `Opt` variant.
    #[test]
    fn test_options_from_env() {
        use std::env;

        let orig_model = "/123/123.bin";
        let orig_nbatch = 1_usize;
        let orig_api_key = "!asd";

        // Remember whatever the variables held before, so the test does not leave the
        // process environment permanently modified for other tests.
        let keys = ["ai_chain_MODEL", "ai_chain_N_BATCH", "ai_chain_API_KEY"];
        let saved: Vec<_> = keys.iter().map(|key| env::var_os(key)).collect();

        env::set_var("ai_chain_MODEL", orig_model);
        env::set_var("ai_chain_N_BATCH", orig_nbatch.to_string());
        env::set_var("ai_chain_API_KEY", orig_api_key);

        let loaded = options_from_env();

        // Restore the environment before asserting, so it is restored even when an
        // assertion below fails.
        for (key, previous) in keys.iter().zip(saved) {
            match previous {
                Some(value) => env::set_var(key, value),
                None => env::remove_var(key),
            }
        }

        let opts = loaded.unwrap();

        let Some(Opt::Model(model_path)) = opts.get(OptDiscriminants::Model) else {
            panic!("Model option was not loaded from the environment");
        };
        let Some(Opt::NBatch(nbatch)) = opts.get(OptDiscriminants::NBatch) else {
            panic!("NBatch option was not loaded from the environment");
        };
        let Some(Opt::ApiKey(api_key)) = opts.get(OptDiscriminants::ApiKey) else {
            panic!("ApiKey option was not loaded from the environment");
        };

        assert_eq!(model_path.to_path(), orig_model);
        assert_eq!(*nbatch, orig_nbatch);
        assert_eq!(api_key, orig_api_key);
    }
}