llm_chain/options.rs
1use lazy_static::lazy_static;
2use paste::paste;
3use std::{collections::HashMap, env::VarError, ffi::OsStr};
4
5use serde::{Deserialize, Serialize};
6use strum_macros::EnumDiscriminants;
7
8use crate::tokens::Token;
9
10/// A collection of options that can be used to configure a model.
11#[derive(Default, Debug, Clone, Serialize, Deserialize)]
12/// `Options` is the struct that represents a set of options for a large language model.
13/// It provides methods for creating, adding, and retrieving options.
14///
15/// The 'Options' struct is mainly created using the `OptionsBuilder` to allow for
16/// flexibility in setting options.
17pub struct Options {
18 /// The actual options, stored as a vector.
19 opts: Vec<Opt>,
20}
21
22#[derive(thiserror::Error, Debug)]
23/// An error indicating that a required option is not set.
24#[error("Option not set")]
25struct OptionNotSetError;
26
27lazy_static! {
28 /// An empty set of options, useful as a default.
29 static ref EMPTY_OPTIONS: Options = Options::builder().build();
30}
31
32impl Options {
33 /// Constructs a new `OptionsBuilder` for creating an `Options` instance.
34 ///
35 /// This function serves as an entry point for using the builder pattern to create `Options`.
36 ///
37 /// # Returns
38 ///
39 /// An `OptionsBuilder` instance.
40 ///
41 /// # Example
42 ///
43 /// ```rust
44 /// # use llm_chain::options::*;
45 /// let builder = Options::builder();
46 /// ```
47 pub fn builder() -> OptionsBuilder {
48 OptionsBuilder::new()
49 }
50
51 /// Returns a reference to an empty set of options.
52 ///
53 /// This function provides a static reference to an empty `Options` instance,
54 /// which can be useful as a default value.
55 ///
56 /// # Returns
57 ///
58 /// A reference to an empty `Options`.
59 ///
60 /// # Example
61 ///
62 /// ```rust
63 /// # use llm_chain::options::*;
64 /// let empty_options = Options::empty();
65 /// ```
66 pub fn empty() -> &'static Self {
67 &EMPTY_OPTIONS
68 }
69 /// Gets the value of an option from this set of options.
70 ///
71 /// This function finds the first option in `opts` that matches the provided `OptDiscriminants`.
72 ///
73 /// # Arguments
74 ///
75 /// * `opt_discriminant` - An `OptDiscriminants` value representing the discriminant of the desired `Opt`.
76 ///
77 /// # Returns
78 ///
79 /// An `Option` that contains a reference to the `Opt` if found, or `None` if not found.
80 ///
81 /// # Example
82 ///
83 /// ```rust
84 /// # use llm_chain::options::*;
85 /// let mut builder = Options::builder();
86 /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
87 /// let options = builder.build();
88 /// let model_option = options.get(OptDiscriminants::Model);
89 /// ```
90 pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
91 self.opts
92 .iter()
93 .find(|opt| OptDiscriminants::from(*opt) == opt_discriminant)
94 }
95}
96
/// `options!` is a declarative macro that builds an `Options` instance from a list of
/// `Name: value` pairs.
///
/// # Usage
///
/// ```ignore
/// options!{
///     OptionName1: value1,
///     OptionName2: value2,
///     ...
/// }
/// ```
///
/// `OptionNameN` is the identifier of an `Opt` variant and `valueN` is the value to
/// store in it. Each value is converted with `.into()` before it is wrapped in the
/// variant. Pairs are added to the builder in the order written. A trailing comma
/// after the last pair is accepted, and an invocation without any pairs yields an
/// `Options` with no options in it.
///
/// # Example
///
/// ```ignore
/// let options = options!{
///     FooBar: "lol",
///     SomeReadyMadeOption: "another_value",
/// };
/// ```
///
/// In this example an `Options` instance is created with two options, `FooBar` and
/// `SomeReadyMadeOption`, set to `"lol"` and `"another_value"` respectively.
///
/// # Notes
///
/// - The option identifier (`OptionNameN`) must match an enum variant of `Opt`; any
///   other identifier is a compilation error.
///
/// - The value (`valueN`) must be convertible (via `Into`) into the payload type of the
///   corresponding variant; otherwise compilation fails.
#[macro_export]
macro_rules! options {
    ( $( $opt_name:ident : $opt_value:expr ),* $(,)? ) => {
        {
            // An invocation without any pairs never mutates the builder, so the
            // `mut` would otherwise be reported as unused in that expansion.
            #[allow(unused_mut)]
            let mut _opts = $crate::options::Options::builder();
            $(
                _opts.add_option($crate::options::Opt::$opt_name($opt_value.into()));
            )*
            _opts.build()
        }
    };
}
147
148/// `OptionsBuilder` is a helper structure used to construct `Options` in a flexible way.
149///
150/// `OptionsBuilder` follows the builder pattern, providing a fluent interface to add options
151/// and finally, build an `Options` instance. This pattern is used to handle cases where the `Options`
152/// instance may require complex configuration or optional fields.
153///
154///
155/// # Example
156///
157/// ```rust
158/// # use llm_chain::options::*;
159/// let mut builder = OptionsBuilder::new();
160/// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
161/// let options = builder.build();
162/// ```
163#[derive(Default, Debug, Clone, Serialize, Deserialize)]
164pub struct OptionsBuilder {
165 /// A Vec<Opt> field that holds the options to be added to the `Options` instance.
166 opts: Vec<Opt>,
167}
168
169impl OptionsBuilder {
170 /// Constructs a new, empty `OptionsBuilder`.
171 ///
172 /// Returns an `OptionsBuilder` instance with an empty `opts` field.
173 ///
174 /// # Example
175 ///
176 /// ```rust
177 /// # use llm_chain::options::*;
178 /// let builder = OptionsBuilder::new();
179 /// ```
180 pub fn new() -> Self {
181 OptionsBuilder { opts: Vec::new() }
182 }
183
184 /// Adds an option to the `OptionsBuilder`.
185 ///
186 /// This function takes an `Opt` instance and pushes it to the `opts` field.
187 ///
188 /// # Arguments
189 ///
190 /// * `opt` - An `Opt` instance to be added to the `OptionsBuilder`.
191 ///
192 /// # Example
193 ///
194 /// ```rust
195 /// # use llm_chain::options::*;
196 /// let mut builder = OptionsBuilder::new();
197 /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
198 /// ```
199 pub fn add_option(&mut self, opt: Opt) {
200 self.opts.push(opt);
201 }
202
203 /// Consumes the `OptionsBuilder`, returning an `Options` instance.
204 ///
205 /// This function consumes the `OptionsBuilder`, moving its `opts` field to a new `Options` instance.
206 ///
207 /// # Returns
208 ///
209 /// An `Options` instance with the options added through the builder.
210 ///
211 /// # Example
212 ///
213 /// ```rust
214 /// # use llm_chain::options::*;
215 /// let mut builder = OptionsBuilder::new();
216 /// builder.add_option(Opt::Model(ModelRef::from_path("path_to_model")));
217 /// let options = builder.build();
218 /// ```
219 pub fn build(self) -> Options {
220 Options { opts: self.opts }
221 }
222}
223
224/// A cascade of option sets.
225///
226/// Options added later in the cascade override earlier options.
227pub struct OptionsCascade<'a> {
228 /// The sets of options, in the order they were added.
229 cascades: Vec<&'a Options>,
230}
231
232impl<'a> OptionsCascade<'a> {
233 /// Creates a new, empty cascade of options.
234 pub fn new() -> Self {
235 OptionsCascade::from_vec(Vec::new())
236 }
237
238 /// Setups a typical options cascade, with model_defaults, environment defaults a model config and possibly a specific config.
239 pub fn new_typical(
240 model_default: &'a Options,
241 env_defaults: &'a Options,
242 model_config: &'a Options,
243 specific_config: Option<&'a Options>,
244 ) -> Self {
245 let mut v = vec![model_default, env_defaults, model_config];
246 if let Some(specific_config) = specific_config {
247 v.push(specific_config);
248 }
249 Self::from_vec(v)
250 }
251
252 /// Creates a new cascade of options from a vector of option sets.
253 pub fn from_vec(cascades: Vec<&'a Options>) -> Self {
254 OptionsCascade { cascades }
255 }
256
257 /// Returns a new cascade of options with the given set of options added.
258 pub fn with_options(mut self, options: &'a Options) -> Self {
259 self.cascades.push(options);
260 self
261 }
262
263 /// Gets the value of an option from this cascade.
264 ///
265 /// Returns `None` if the option is not present in any set in this cascade.
266 /// If the option is present in multiple sets, the value from the most
267 /// recently added set is returned.
268 pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
269 for options in self.cascades.iter().rev() {
270 if let Some(opt) = options.get(opt_discriminant) {
271 return Some(opt);
272 }
273 }
274 None
275 }
276
277 /// Returns a boolean indicating if options indicate that requests should be streamed or not.
278 pub fn is_streaming(&self) -> bool {
279 let Some(Opt::Stream(val)) = self.get(OptDiscriminants::Stream) else {
280 return false;
281 };
282 *val
283 }
284}
285
286impl<'a> Default for OptionsCascade<'a> {
287 /// Returns a new, empty cascade of options.
288 fn default() -> Self {
289 Self::new()
290 }
291}
292
293#[derive(Clone, Debug, Serialize, Deserialize)]
294/// A reference to a model name or path
295/// Useful for
296pub struct ModelRef(String);
297
298impl ModelRef {
299 pub fn from_path<S: Into<String>>(p: S) -> Self {
300 Self(p.into())
301 }
302 pub fn from_model_name<S: Into<String>>(model_name: S) -> Self {
303 Self(model_name.into())
304 }
305 /// Returns the path for this model reference
306 pub fn to_path(&self) -> String {
307 self.0.clone()
308 }
309 /// Returns the name of the model
310 pub fn to_name(&self) -> String {
311 self.0.clone()
312 }
313}
314
315/// A list of tokens to bias during the process of inferencing.
316#[derive(Serialize, Deserialize, Debug, Clone)]
317pub struct TokenBias(Vec<(Token, f32)>); // TODO: Serialize to a JSON object of str(F32) =>
318
319impl TokenBias {
320 /// Returns the token bias as a hashmap where the keys are i32 and the value f32. If the type doesn't match returns None
321 pub fn as_i32_f32_hashmap(&self) -> Option<HashMap<i32, f32>> {
322 let mut map = HashMap::new();
323 for (token, value) in &self.0 {
324 map.insert(token.to_i32()?, *value);
325 }
326 Some(map)
327 }
328}
329
330#[derive(EnumDiscriminants, Clone, Debug, Serialize, Deserialize)]
331pub enum Opt {
332 /// The name or path of the model used.
333 Model(ModelRef),
334 /// The API key for the model service.
335 ApiKey(String),
336 /// The number of threads to use for parallel processing.
337 /// This is common to all models.
338 NThreads(usize),
339 /// The maximum number of tokens that the model will generate.
340 /// This is common to all models.
341 MaxTokens(usize),
342 /// The maximum context size of the model.
343 MaxContextSize(usize),
344 /// The sequences that, when encountered, will cause the model to stop generating further tokens.
345 /// OpenAI models allow up to four stop sequences.
346 StopSequence(Vec<String>),
347 /// Whether or not to use streaming mode.
348 /// This is common to all models.
349 Stream(bool),
350
351 /// The penalty to apply for using frequent tokens.
352 /// This is used by OpenAI and llama models.
353 FrequencyPenalty(f32),
354 /// The penalty to apply for using novel tokens.
355 /// This is used by OpenAI and llama models.
356 PresencePenalty(f32),
357
358 /// A bias to apply to certain tokens during the inference process.
359 /// This is known as logit bias in OpenAI and is also used in llm-chain-local.
360 TokenBias(TokenBias),
361
362 /// The maximum number of tokens to consider for each step of generation.
363 /// This is common to all models, but is not used by OpenAI.
364 TopK(i32),
365 /// The cumulative probability threshold for token selection.
366 /// This is common to all models.
367 TopP(f32),
368 /// The temperature to use for token selection. Higher values result in more random output.
369 /// This is common to all models.
370 Temperature(f32),
371 /// The penalty to apply for repeated tokens.
372 /// This is common to all models.
373 RepeatPenalty(f32),
374 /// The number of most recent tokens to consider when applying the repeat penalty.
375 /// This is common to all models.
376 RepeatPenaltyLastN(usize),
377
378 /// The TfsZ parameter for llm-chain-llama.
379 TfsZ(f32),
380 /// The TypicalP parameter for llm-chain-llama.
381 TypicalP(f32),
382 /// The Mirostat parameter for llm-chain-llama.
383 Mirostat(i32),
384 /// The MirostatTau parameter for llm-chain-llama.
385 MirostatTau(f32),
386 /// The MirostatEta parameter for llm-chain-llama.
387 MirostatEta(f32),
388 /// Whether or not to penalize newline characters for llm-chain-llama.
389 PenalizeNl(bool),
390
391 /// The batch size for llm-chain-local.
392 NBatch(usize),
393 /// The username for llm-chain-openai.
394 User(String),
395 /// The type of the model.
396 ModelType(String),
397}
398
399// Helper function to extract environment variables
400fn option_from_env<K, F>(opts: &mut OptionsBuilder, key: K, f: F) -> Result<(), VarError>
401where
402 K: AsRef<OsStr>,
403 F: FnOnce(String) -> Option<Opt>,
404{
405 match std::env::var(key) {
406 Ok(v) => {
407 if let Some(x) = f(v) {
408 opts.add_option(x);
409 }
410 Ok(())
411 }
412 Err(VarError::NotPresent) => Ok(()),
413 Err(e) => Err(e),
414 }
415}
416
417// Conversion functions for each Opt variant
418fn model_from_string(s: String) -> Option<Opt> {
419 Some(Opt::Model(ModelRef::from_path(s)))
420}
421
422fn api_key_from_string(s: String) -> Option<Opt> {
423 Some(Opt::ApiKey(s))
424}
425
/// Generates `<variant in snake case>_from_string(s: String) -> Option<Opt>` for the
/// given `Opt` variant.
///
/// The generated function parses `s` with `str::parse` into the variant's payload type
/// and returns `None` when parsing fails.
macro_rules! opt_parse_str {
    ($variant:ident) => {
        paste! {
            fn [< $variant:snake:lower _from_string >] (s: String) -> Option<Opt> {
                match s.parse() {
                    Ok(value) => Some(Opt::$variant(value)),
                    Err(_) => None,
                }
            }
        }
    };
}
435
436opt_parse_str!(NThreads);
437opt_parse_str!(MaxTokens);
438opt_parse_str!(MaxContextSize);
439// Skip stop sequence?
440// Skip stream?
441
442opt_parse_str!(FrequencyPenalty);
443opt_parse_str!(PresencePenalty);
444// Skip TokenBias for now
445opt_parse_str!(TopK);
446opt_parse_str!(TopP);
447opt_parse_str!(Temperature);
448opt_parse_str!(RepeatPenalty);
449opt_parse_str!(RepeatPenaltyLastN);
450opt_parse_str!(TfsZ);
451opt_parse_str!(PenalizeNl);
452opt_parse_str!(NBatch);
453
/// Loads one option from the environment into the builder named by `$builder`.
///
/// For a variant `FooBar`, the variable read is `LLM_CHAIN_FOO_BAR` and the value is
/// converted with `foo_bar_from_string`. The expansion uses `?`, so it must be used
/// inside a function whose error type accepts `VarError`.
macro_rules! opt_from_env {
    ($builder:ident, $variant:ident) => {
        paste! {
            {
                let env_key = stringify!([< LLM_CHAIN_ $variant:snake:upper >]);
                option_from_env(
                    &mut $builder,
                    env_key,
                    [< $variant:snake:lower _from_string >],
                )?;
            }
        }
    };
}
463
/// Applies `opt_from_env!` to each listed `Opt` variant, in the order given.
macro_rules! opts_from_env {
    ($builder:ident, $($variant:ident),*) => {
        $(
            opt_from_env!($builder, $variant);
        )*
    };
}
471
472/// Loads options from environment variables.
473/// Every option that can be easily understood from a string is avaliable the name
474/// of the option will be in upper snake case, that means that the option `Opt::ApiKey` has the environment variable
475/// `LLM_CHAIN_API_KEY`
476pub fn options_from_env() -> Result<Options, VarError> {
477 let mut opts = OptionsBuilder::new();
478 opts_from_env!(
479 opts,
480 Model,
481 ApiKey,
482 NThreads,
483 MaxTokens,
484 MaxContextSize,
485 FrequencyPenalty,
486 PresencePenalty,
487 TopK,
488 TopP,
489 Temperature,
490 RepeatPenalty,
491 RepeatPenaltyLastN,
492 TfsZ,
493 PenalizeNl,
494 NBatch
495 );
496 Ok(opts.build())
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Values exported through `LLM_CHAIN_*` variables are picked up by
    /// `options_from_env` and stored under the matching `Opt` variants.
    #[test]
    fn test_options_from_env() {
        let orig_model = "/123/123.bin";
        let orig_nbatch = 1_usize;
        let orig_api_key = "!asd";

        std::env::set_var("LLM_CHAIN_MODEL", orig_model);
        std::env::set_var("LLM_CHAIN_N_BATCH", orig_nbatch.to_string());
        std::env::set_var("LLM_CHAIN_API_KEY", orig_api_key);

        let opts = options_from_env().unwrap();

        let Some(Opt::Model(model_path)) = opts.get(OptDiscriminants::Model) else {
            panic!("model option was not loaded");
        };
        let Some(Opt::NBatch(nbatch)) = opts.get(OptDiscriminants::NBatch) else {
            panic!("n_batch option was not loaded");
        };
        let Some(Opt::ApiKey(api_key)) = opts.get(OptDiscriminants::ApiKey) else {
            panic!("api key option was not loaded");
        };

        assert_eq!(model_path.to_path(), orig_model);
        assert_eq!(*nbatch, orig_nbatch);
        assert_eq!(api_key, orig_api_key);
    }
}