llm_chain/
options.rs

1use lazy_static::lazy_static;
2use paste::paste;
3use std::{collections::HashMap, env::VarError, ffi::OsStr};
4
5use serde::{Deserialize, Serialize};
6use strum_macros::EnumDiscriminants;
7
8use crate::tokens::Token;
9
10/// A collection of options that can be used to configure a model.
11#[derive(Default, Debug, Clone, Serialize, Deserialize)]
12/// `Options` is the struct that represents a set of options for a large language model.
13/// It provides methods for creating, adding, and retrieving options.
14///
15/// The 'Options' struct is mainly created using the `OptionsBuilder` to allow for
16/// flexibility in setting options.
17pub struct Options {
18    /// The actual options, stored as a vector.
19    opts: Vec<Opt>,
20}
21
22#[derive(thiserror::Error, Debug)]
23/// An error indicating that a required option is not set.
24#[error("Option not set")]
25struct OptionNotSetError;
26
27lazy_static! {
28    /// An empty set of options, useful as a default.
29    static ref EMPTY_OPTIONS: Options = Options::builder().build();
30}
31
32impl Options {
33    /// Constructs a new `OptionsBuilder` for creating an `Options` instance.
34    ///
35    /// This function serves as an entry point for using the builder pattern to create `Options`.
36    ///
37    /// # Returns
38    ///
39    /// An `OptionsBuilder` instance.
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// # use llm_chain::options::*;
45    /// let builder = Options::builder();
46    /// ```
47    pub fn builder() -> OptionsBuilder {
48        OptionsBuilder::new()
49    }
50
51    /// Returns a reference to an empty set of options.
52    ///
53    /// This function provides a static reference to an empty `Options` instance,
54    /// which can be useful as a default value.
55    ///
56    /// # Returns
57    ///
58    /// A reference to an empty `Options`.
59    ///
60    /// # Example
61    ///
62    /// ```rust
63    /// # use llm_chain::options::*;
64    /// let empty_options = Options::empty();
65    /// ```
66    pub fn empty() -> &'static Self {
67        &EMPTY_OPTIONS
68    }
69    /// Gets the value of an option from this set of options.
70    ///
71    /// This function finds the first option in `opts` that matches the provided `OptDiscriminants`.
72    ///
73    /// # Arguments
74    ///
75    /// * `opt_discriminant` - An `OptDiscriminants` value representing the discriminant of the desired `Opt`.
76    ///
77    /// # Returns
78    ///
79    /// An `Option` that contains a reference to the `Opt` if found, or `None` if not found.
80    ///
81    /// # Example
82    ///
83    /// ```rust
84    /// # use llm_chain::options::*;
85    /// let mut builder = Options::builder();
86    /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
87    /// let options = builder.build();
88    /// let model_option = options.get(OptDiscriminants::Model);
89    /// ```
90    pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
91        self.opts
92            .iter()
93            .find(|opt| OptDiscriminants::from(*opt) == opt_discriminant)
94    }
95}
96
/// `options!` is a declarative macro that builds an `Options` instance.
///
/// # Usage
///
/// This macro constructs an `Options` value with a more readable syntax than calling
/// `OptionsBuilder::add_option` repeatedly:
///
/// ```ignore
/// options!{
///     OptionName1: value1,
///     OptionName2: value2,
///     ...
/// }
/// ```
///
/// Each `OptionNameN` is the name of an `Opt` variant. Each `valueN` is converted with
/// `.into()` into that variant's payload type. A trailing comma after the last entry is
/// accepted. The options are added in the order written.
///
/// # Example
///
/// ```ignore
/// use llm_chain::options;
/// use llm_chain::options::ModelRef;
///
/// let options = options!{
///     Model: ModelRef::from_model_name("gpt-3.5-turbo"),
///     ApiKey: "my-api-key",
///     MaxTokens: 256_usize,
/// };
/// ```
///
/// This example creates an `Options` instance with three options: `Model`, `ApiKey`
/// (the `&str` is converted into a `String`) and `MaxTokens`.
///
/// # Notes
///
/// - The option identifier (`OptionNameN`) must match an enum variant in `Opt`. If the identifier
///   does not match any of the `Opt` variants, a compilation error will occur.
///
/// - The value (`valueN`) must be convertible with `Into` into the payload type of the
///   corresponding option. If it is not, a compilation error will occur.
///
#[macro_export]
macro_rules! options {
    ( $( $opt_name:ident : $opt_value:expr ),* $(,)? ) => {
        {
            // With zero entries nothing is added, so the `mut` would otherwise be unused.
            #[allow(unused_mut)]
            let mut _opts = $crate::options::Options::builder();
            $(
                _opts.add_option($crate::options::Opt::$opt_name($opt_value.into()));
            )*
            _opts.build()
        }
    };
}
147
148/// `OptionsBuilder` is a helper structure used to construct `Options` in a flexible way.
149///
150/// `OptionsBuilder` follows the builder pattern, providing a fluent interface to add options
151/// and finally, build an `Options` instance. This pattern is used to handle cases where the `Options`
152/// instance may require complex configuration or optional fields.
153///
154///
155/// # Example
156///
157/// ```rust
158/// # use llm_chain::options::*;
159/// let mut builder = OptionsBuilder::new();
160/// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
161/// let options = builder.build();
162/// ```
163#[derive(Default, Debug, Clone, Serialize, Deserialize)]
164pub struct OptionsBuilder {
165    /// A Vec<Opt> field that holds the options to be added to the `Options` instance.
166    opts: Vec<Opt>,
167}
168
169impl OptionsBuilder {
170    /// Constructs a new, empty `OptionsBuilder`.
171    ///
172    /// Returns an `OptionsBuilder` instance with an empty `opts` field.
173    ///
174    /// # Example
175    ///
176    /// ```rust
177    /// # use llm_chain::options::*;
178    /// let builder = OptionsBuilder::new();
179    /// ```
180    pub fn new() -> Self {
181        OptionsBuilder { opts: Vec::new() }
182    }
183
184    /// Adds an option to the `OptionsBuilder`.
185    ///
186    /// This function takes an `Opt` instance and pushes it to the `opts` field.
187    ///
188    /// # Arguments
189    ///
190    /// * `opt` - An `Opt` instance to be added to the `OptionsBuilder`.
191    ///
192    /// # Example
193    ///
194    /// ```rust
195    /// # use llm_chain::options::*;
196    /// let mut builder = OptionsBuilder::new();
197    /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
198    /// ```
199    pub fn add_option(&mut self, opt: Opt) {
200        self.opts.push(opt);
201    }
202
203    /// Consumes the `OptionsBuilder`, returning an `Options` instance.
204    ///
205    /// This function consumes the `OptionsBuilder`, moving its `opts` field to a new `Options` instance.
206    ///
207    /// # Returns
208    ///
209    /// An `Options` instance with the options added through the builder.
210    ///
211    /// # Example
212    ///
213    /// ```rust
214    /// # use llm_chain::options::*;
215    /// let mut builder = OptionsBuilder::new();
216    /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
217    /// let options = builder.build();
218    /// ```
219    pub fn build(self) -> Options {
220        Options { opts: self.opts }
221    }
222}
223
224/// A cascade of option sets.
225///
226/// Options added later in the cascade override earlier options.
227pub struct OptionsCascade<'a> {
228    /// The sets of options, in the order they were added.
229    cascades: Vec<&'a Options>,
230}
231
232impl<'a> OptionsCascade<'a> {
233    /// Creates a new, empty cascade of options.
234    pub fn new() -> Self {
235        OptionsCascade::from_vec(Vec::new())
236    }
237
238    /// Setups a typical options cascade, with model_defaults, environment defaults a model config and possibly a specific config.
239    pub fn new_typical(
240        model_default: &'a Options,
241        env_defaults: &'a Options,
242        model_config: &'a Options,
243        specific_config: Option<&'a Options>,
244    ) -> Self {
245        let mut v = vec![model_default, env_defaults, model_config];
246        if let Some(specific_config) = specific_config {
247            v.push(specific_config);
248        }
249        Self::from_vec(v)
250    }
251
252    /// Creates a new cascade of options from a vector of option sets.
253    pub fn from_vec(cascades: Vec<&'a Options>) -> Self {
254        OptionsCascade { cascades }
255    }
256
257    /// Returns a new cascade of options with the given set of options added.
258    pub fn with_options(mut self, options: &'a Options) -> Self {
259        self.cascades.push(options);
260        self
261    }
262
263    /// Gets the value of an option from this cascade.
264    ///
265    /// Returns `None` if the option is not present in any set in this cascade.
266    /// If the option is present in multiple sets, the value from the most
267    /// recently added set is returned.
268    pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
269        for options in self.cascades.iter().rev() {
270            if let Some(opt) = options.get(opt_discriminant) {
271                return Some(opt);
272            }
273        }
274        None
275    }
276
277    /// Returns a boolean indicating if options indicate that requests should be streamed or not.
278    pub fn is_streaming(&self) -> bool {
279        let Some(Opt::Stream(val)) = self.get(OptDiscriminants::Stream) else {
280            return false;
281        };
282        *val
283    }
284}
285
286impl<'a> Default for OptionsCascade<'a> {
287    /// Returns a new, empty cascade of options.
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
293#[derive(Clone, Debug, Serialize, Deserialize)]
294/// A reference to a model name or path
295/// Useful for
296pub struct ModelRef(String);
297
298impl ModelRef {
299    pub fn from_path<S: Into<String>>(p: S) -> Self {
300        Self(p.into())
301    }
302    pub fn from_model_name<S: Into<String>>(model_name: S) -> Self {
303        Self(model_name.into())
304    }
305    /// Returns the path for this model reference
306    pub fn to_path(&self) -> String {
307        self.0.clone()
308    }
309    /// Returns the name of the model
310    pub fn to_name(&self) -> String {
311        self.0.clone()
312    }
313}
314
315/// A list of tokens to bias during the process of inferencing.
316#[derive(Serialize, Deserialize, Debug, Clone)]
317pub struct TokenBias(Vec<(Token, f32)>); // TODO: Serialize to a JSON object of str(F32) =>
318
319impl TokenBias {
320    /// Returns the token bias as a hashmap where the keys are i32 and the value f32. If the type doesn't match returns None
321    pub fn as_i32_f32_hashmap(&self) -> Option<HashMap<i32, f32>> {
322        let mut map = HashMap::new();
323        for (token, value) in &self.0 {
324            map.insert(token.to_i32()?, *value);
325        }
326        Some(map)
327    }
328}
329
/// A single configuration option for a model.
///
/// Option sets ([`Options`]) are lists of `Opt` values. Most variants are documented below
/// with the backends they are intended for.
///
/// `#[derive(EnumDiscriminants)]` generates a companion fieldless enum, `OptDiscriminants`,
/// with one variant per `Opt` variant. It is used as the lookup key in `Options::get` and
/// `OptionsCascade::get`, so an option can be requested by kind without supplying a value.
#[derive(EnumDiscriminants, Clone, Debug, Serialize, Deserialize)]
pub enum Opt {
    /// The name or path of the model used.
    Model(ModelRef),
    /// The API key for the model service.
    ApiKey(String),
    /// The number of threads to use for parallel processing.
    /// This is common to all models.
    NThreads(usize),
    /// The maximum number of tokens that the model will generate.
    /// This is common to all models.
    MaxTokens(usize),
    /// The maximum context size of the model.
    MaxContextSize(usize),
    /// The sequences that, when encountered, will cause the model to stop generating further tokens.
    /// OpenAI models allow up to four stop sequences.
    StopSequence(Vec<String>),
    /// Whether or not to use streaming mode.
    /// This is common to all models.
    Stream(bool),

    /// The penalty to apply for using frequent tokens.
    /// This is used by OpenAI and llama models.
    FrequencyPenalty(f32),
    /// The penalty to apply for using novel tokens.
    /// This is used by OpenAI and llama models.
    PresencePenalty(f32),

    /// A bias to apply to certain tokens during the inference process.
    /// This is known as logit bias in OpenAI and is also used in llm-chain-local.
    TokenBias(TokenBias),

    /// The maximum number of tokens to consider for each step of generation.
    /// This is common to all models, but is not used by OpenAI.
    TopK(i32),
    /// The cumulative probability threshold for token selection.
    /// This is common to all models.
    TopP(f32),
    /// The temperature to use for token selection. Higher values result in more random output.
    /// This is common to all models.
    Temperature(f32),
    /// The penalty to apply for repeated tokens.
    /// This is common to all models.
    RepeatPenalty(f32),
    /// The number of most recent tokens to consider when applying the repeat penalty.
    /// This is common to all models.
    RepeatPenaltyLastN(usize),

    /// The TfsZ parameter for llm-chain-llama.
    TfsZ(f32),
    /// The TypicalP parameter for llm-chain-llama.
    TypicalP(f32),
    /// The Mirostat parameter for llm-chain-llama.
    Mirostat(i32),
    /// The MirostatTau parameter for llm-chain-llama.
    MirostatTau(f32),
    /// The MirostatEta parameter for llm-chain-llama.
    MirostatEta(f32),
    /// Whether or not to penalize newline characters for llm-chain-llama.
    PenalizeNl(bool),

    /// The batch size for llm-chain-local.
    NBatch(usize),
    /// The username for llm-chain-openai.
    User(String),
    /// The type of the model.
    ModelType(String),
}
398
399// Helper function to extract environment variables
400fn option_from_env<K, F>(opts: &mut OptionsBuilder, key: K, f: F) -> Result<(), VarError>
401where
402    K: AsRef<OsStr>,
403    F: FnOnce(String) -> Option<Opt>,
404{
405    match std::env::var(key) {
406        Ok(v) => {
407            if let Some(x) = f(v) {
408                opts.add_option(x);
409            }
410            Ok(())
411        }
412        Err(VarError::NotPresent) => Ok(()),
413        Err(e) => Err(e),
414    }
415}
416
417// Conversion functions for each Opt variant
418fn model_from_string(s: String) -> Option<Opt> {
419    Some(Opt::Model(ModelRef::from_path(s)))
420}
421
422fn api_key_from_string(s: String) -> Option<Opt> {
423    Some(Opt::ApiKey(s))
424}
425
426macro_rules! opt_parse_str {
427    ($v:ident) => {
428        paste! {
429            fn [< $v:snake:lower _from_string >] (s: String) -> Option<Opt> {
430                        Some(Opt::$v(s.parse().ok()?))
431            }
432        }
433    };
434}
435
436opt_parse_str!(NThreads);
437opt_parse_str!(MaxTokens);
438opt_parse_str!(MaxContextSize);
439// Skip stop sequence?
440// Skip stream?
441
442opt_parse_str!(FrequencyPenalty);
443opt_parse_str!(PresencePenalty);
444// Skip TokenBias for now
445opt_parse_str!(TopK);
446opt_parse_str!(TopP);
447opt_parse_str!(Temperature);
448opt_parse_str!(RepeatPenalty);
449opt_parse_str!(RepeatPenaltyLastN);
450opt_parse_str!(TfsZ);
451opt_parse_str!(PenalizeNl);
452opt_parse_str!(NBatch);
453
454macro_rules! opt_from_env {
455    ($opt:ident, $v:ident) => {
456        paste! {
457            option_from_env(&mut $opt, stringify!([<
458                LLM_CHAIN_ $v:snake:upper
459                >]), [< $v:snake:lower _from_string >])?;
460        }
461    };
462}
463
464macro_rules! opts_from_env {
465    ($opt:ident, $($v:ident),*) => {
466        $(
467            opt_from_env!($opt, $v);
468        )*
469    };
470}
471
472/// Loads options from environment variables.
473/// Every option that can be easily understood from a string is avaliable the name
474/// of the option will be in upper snake case, that means that the option `Opt::ApiKey` has the environment variable
475/// `LLM_CHAIN_API_KEY`
476pub fn options_from_env() -> Result<Options, VarError> {
477    let mut opts = OptionsBuilder::new();
478    opts_from_env!(
479        opts,
480        Model,
481        ApiKey,
482        NThreads,
483        MaxTokens,
484        MaxContextSize,
485        FrequencyPenalty,
486        PresencePenalty,
487        TopK,
488        TopP,
489        Temperature,
490        RepeatPenalty,
491        RepeatPenaltyLastN,
492        TfsZ,
493        PenalizeNl,
494        NBatch
495    );
496    Ok(opts.build())
497}
498
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::ffi::OsString;

    /// Sets environment variables for the duration of a test.
    ///
    /// When dropped, each variable is restored to its previous value, or removed if it was
    /// previously unset. Because this happens in `Drop`, it also runs when an assertion
    /// panics, so values do not leak into other tests.
    struct EnvVarGuard {
        saved: Vec<(&'static str, Option<OsString>)>,
    }

    impl EnvVarGuard {
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let mut saved = Vec::with_capacity(vars.len());
            for &(key, value) in vars {
                saved.push((key, env::var_os(key)));
                env::set_var(key, value);
            }
            EnvVarGuard { saved }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // Restore in reverse so a key listed twice ends up with its original value.
            for (key, previous) in self.saved.drain(..).rev() {
                match previous {
                    Some(value) => env::set_var(key, value),
                    None => env::remove_var(key),
                }
            }
        }
    }

    // Loading options from `LLM_CHAIN_*` environment variables.
    #[test]
    fn test_options_from_env() {
        let orig_model = "/123/123.bin";
        let orig_nbatch = 1_usize;
        let orig_api_key = "!asd";
        let nbatch_str = orig_nbatch.to_string();
        let _env = EnvVarGuard::set(&[
            ("LLM_CHAIN_MODEL", orig_model),
            ("LLM_CHAIN_N_BATCH", nbatch_str.as_str()),
            ("LLM_CHAIN_API_KEY", orig_api_key),
        ]);

        let opts = options_from_env().unwrap();

        let Some(Opt::Model(model_path)) = opts.get(OptDiscriminants::Model) else {
            panic!("expected Opt::Model to be loaded from LLM_CHAIN_MODEL");
        };
        let Some(Opt::NBatch(nbatch)) = opts.get(OptDiscriminants::NBatch) else {
            panic!("expected Opt::NBatch to be loaded from LLM_CHAIN_N_BATCH");
        };
        let Some(Opt::ApiKey(api_key)) = opts.get(OptDiscriminants::ApiKey) else {
            panic!("expected Opt::ApiKey to be loaded from LLM_CHAIN_API_KEY");
        };

        assert_eq!(model_path.to_path(), orig_model);
        assert_eq!(*nbatch, orig_nbatch);
        assert_eq!(api_key, orig_api_key);
    }
}