use lazy_static::lazy_static;
use paste::paste;
use std::{collections::HashMap, env::VarError, ffi::OsStr};
use serde::{Deserialize, Serialize};
use strum_macros::EnumDiscriminants;
use crate::tokens::Token;
/// An ordered collection of [`Opt`] values.
///
/// Build one with [`Options::builder`] or the `options!` macro and query it with
/// [`Options::get`].
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Options {
    // Kept in insertion order; `get` returns the first entry whose variant matches.
    opts: Vec<Opt>,
}
/// Error meaning "a required option was not set" (displays as `Option not set`).
///
/// NOTE(review): this type is private and nothing in the visible code constructs it;
/// confirm it is still needed, or remove it to silence the `dead_code` lint.
#[derive(thiserror::Error, Debug)]
#[error("Option not set")]
struct OptionNotSetError;
lazy_static! {
    /// Shared instance handed out by `Options::empty`: an `Options` with no entries,
    /// initialised on first access.
    static ref EMPTY_OPTIONS: Options = OptionsBuilder::new().build();
}
impl Options {
    /// Starts building a new set of options.
    pub fn builder() -> OptionsBuilder {
        OptionsBuilder::new()
    }

    /// Returns a process-wide, shared `Options` containing no entries.
    pub fn empty() -> &'static Self {
        &*EMPTY_OPTIONS
    }

    /// Finds the entry whose variant is `opt_discriminant`.
    ///
    /// If several entries share that variant, the earliest-added one is returned;
    /// `None` means no entry of that variant exists.
    pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
        self.opts
            .iter()
            .find(|&candidate| OptDiscriminants::from(candidate) == opt_discriminant)
    }
}
/// Builds an `Options` value from `Variant: value` pairs.
///
/// Each `Variant` names an `Opt` variant, and its `value` is converted with `.into()` into
/// that variant's payload type. Entries are added in the order written. An empty
/// invocation yields an empty `Options`, and a trailing comma is accepted.
///
/// ```ignore
/// let opts = options!(Stream: true, MaxTokens: 256_usize,);
/// ```
#[macro_export]
macro_rules! options {
    ( $( $opt_name:ident : $opt_value:expr ),* $(,)? ) => {
        {
            // With zero pairs the builder is never mutated; don't warn about `mut` then.
            #[allow(unused_mut)]
            let mut _opts = $crate::options::Options::builder();
            $(
                _opts.add_option($crate::options::Opt::$opt_name($opt_value.into()));
            )*
            _opts.build()
        }
    };
}
/// Accumulates [`Opt`] entries and turns them into an [`Options`].
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct OptionsBuilder {
    opts: Vec<Opt>,
}

impl OptionsBuilder {
    /// Creates a builder holding no options.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `opt`. Insertion order is preserved and duplicates are kept as-is.
    pub fn add_option(&mut self, opt: Opt) {
        self.opts.push(opt);
    }

    /// Consumes the builder, returning the collected options in insertion order.
    pub fn build(self) -> Options {
        let Self { opts } = self;
        Options { opts }
    }
}
/// A stack of borrowed [`Options`] layers where later layers override earlier ones.
///
/// Lookups walk the layers from the most recently added to the first one and return the
/// first match.
pub struct OptionsCascade<'a> {
    cascades: Vec<&'a Options>,
}

impl<'a> OptionsCascade<'a> {
    /// Creates a cascade with no layers.
    pub fn new() -> Self {
        Self::from_vec(vec![])
    }

    /// Builds the usual stack.
    ///
    /// Layer order is `model_default`, `env_defaults`, `model_config` and, when given,
    /// `specific_config`. Each layer takes precedence over the ones before it.
    pub fn new_typical(
        model_default: &'a Options,
        env_defaults: &'a Options,
        model_config: &'a Options,
        specific_config: Option<&'a Options>,
    ) -> Self {
        let mut layers = vec![model_default, env_defaults, model_config];
        // `Option` is iterable: this appends the specific config only if it is present.
        layers.extend(specific_config);
        Self::from_vec(layers)
    }

    /// Wraps an already-ordered list of layers (lowest precedence first).
    pub fn from_vec(cascades: Vec<&'a Options>) -> Self {
        Self { cascades }
    }

    /// Pushes `options` on top of the stack, giving it the highest precedence.
    pub fn with_options(mut self, options: &'a Options) -> Self {
        self.cascades.push(options);
        self
    }

    /// Returns the value for `opt_discriminant` from the highest-precedence layer defining it.
    pub fn get(&self, opt_discriminant: OptDiscriminants) -> Option<&Opt> {
        self.cascades
            .iter()
            .rev()
            .find_map(|layer| layer.get(opt_discriminant))
    }

    /// Whether the effective `Stream` option is `true`; an unset option counts as `false`.
    pub fn is_streaming(&self) -> bool {
        matches!(
            self.get(OptDiscriminants::Stream),
            Some(Opt::Stream(true))
        )
    }
}

impl Default for OptionsCascade<'_> {
    /// Same as [`OptionsCascade::new`]: no layers.
    fn default() -> Self {
        Self::new()
    }
}
/// A reference to a model, given either as a filesystem path or as a model name.
///
/// Both constructors store the supplied string verbatim. Both accessors return a copy of
/// that same string; no validation or normalisation is performed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelRef(String);

impl ModelRef {
    /// Creates a reference from a path-like string.
    pub fn from_path<S: Into<String>>(p: S) -> Self {
        let raw: String = p.into();
        ModelRef(raw)
    }

    /// Creates a reference from a model name (stored identically to a path).
    pub fn from_model_name<S: Into<String>>(model_name: S) -> Self {
        Self::from_path(model_name)
    }

    /// Returns an owned copy of the stored string, interpreted as a path.
    pub fn to_path(&self) -> String {
        self.0.to_owned()
    }

    /// Returns an owned copy of the stored string, interpreted as a name.
    pub fn to_name(&self) -> String {
        self.to_path()
    }
}
/// A list of `(token, bias)` pairs.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TokenBias(Vec<(Token, f32)>);

impl TokenBias {
    /// Converts the pairs into a map from each token's `i32` id to its bias.
    ///
    /// Returns `None` as soon as any token's `to_i32` yields `None`. If the same id appears
    /// more than once, the pair listed last wins.
    pub fn as_i32_f32_hashmap(&self) -> Option<HashMap<i32, f32>> {
        self.0
            .iter()
            .map(|(token, bias)| token.to_i32().map(|id| (id, *bias)))
            .collect()
    }
}
/// A single configuration value, tagged by what it configures.
///
/// `#[derive(EnumDiscriminants)]` also generates a fieldless `OptDiscriminants` enum. That
/// enum is the lookup key used by `Options::get` and `OptionsCascade::get`. Some variants
/// can also be loaded from `LLM_CHAIN_*` environment variables by `options_from_env`.
///
/// NOTE(review): most variants are interpreted by model backends that are not visible here;
/// the per-variant notes below are hedged accordingly — confirm against each backend.
#[derive(EnumDiscriminants, Clone, Debug, Serialize, Deserialize)]
pub enum Opt {
    /// The model to use; `options_from_env` fills it via `ModelRef::from_path`.
    Model(ModelRef),
    /// Presumably a credential for a remote API.
    ApiKey(String),
    /// Presumably the number of worker threads for local inference.
    NThreads(usize),
    /// Presumably an upper bound on generated tokens.
    MaxTokens(usize),
    /// Presumably the maximum context window size.
    MaxContextSize(usize),
    /// Presumably strings at which generation should stop.
    StopSequence(Vec<String>),
    /// Whether to stream output; read by `OptionsCascade::is_streaming` (unset => `false`).
    Stream(bool),
    /// Presumably a sampling penalty on token frequency.
    FrequencyPenalty(f32),
    /// Presumably a sampling penalty on token presence.
    PresencePenalty(f32),
    /// Per-token biases; see `TokenBias::as_i32_f32_hashmap`.
    TokenBias(TokenBias),
    /// Presumably the top-k sampling cutoff.
    TopK(i32),
    /// Presumably the top-p (nucleus) sampling threshold.
    TopP(f32),
    /// Presumably the sampling temperature.
    Temperature(f32),
    /// Presumably a penalty applied to repeated tokens.
    RepeatPenalty(f32),
    /// Presumably how many recent tokens `RepeatPenalty` considers.
    RepeatPenaltyLastN(usize),
    /// Presumably the tail-free sampling `z` parameter.
    TfsZ(f32),
    /// Presumably the locally-typical sampling `p` parameter.
    TypicalP(f32),
    /// Presumably selects a Mirostat sampling mode.
    Mirostat(i32),
    /// Presumably the Mirostat target (tau).
    MirostatTau(f32),
    /// Presumably the Mirostat learning rate (eta).
    MirostatEta(f32),
    /// Presumably whether newline tokens are subject to repetition penalties.
    PenalizeNl(bool),
    /// Presumably the batch size used when processing the prompt.
    NBatch(usize),
    /// Presumably an end-user identifier forwarded to a remote API.
    User(String),
    /// Presumably a backend-specific model type/architecture name.
    ModelType(String),
}
/// Reads environment variable `key` and, when it is set, passes its value to `f`.
///
/// If `f` produces an option, it is appended to `opts`. An unset variable is not an error.
/// A value that `f` rejects (returns `None` for) is skipped without reporting.
///
/// # Errors
/// Returns [`VarError::NotUnicode`] if the variable is set but is not valid Unicode.
fn option_from_env<K, F>(opts: &mut OptionsBuilder, key: K, f: F) -> Result<(), VarError>
where
    K: AsRef<OsStr>,
    F: FnOnce(String) -> Option<Opt>,
{
    let raw = match std::env::var(key) {
        Ok(raw) => raw,
        Err(VarError::NotPresent) => return Ok(()),
        Err(other) => return Err(other),
    };
    if let Some(parsed) = f(raw) {
        opts.add_option(parsed);
    }
    Ok(())
}
/// Env-var parser for `Opt::Model`: treats the raw value as a model path (never fails).
fn model_from_string(s: String) -> Option<Opt> {
    Some(s).map(ModelRef::from_path).map(Opt::Model)
}
/// Env-var parser for `Opt::ApiKey`: uses the raw value unchanged (never fails).
fn api_key_from_string(s: String) -> Option<Opt> {
    Some(s).map(Opt::ApiKey)
}
/// Defines `<variant_in_snake_case>_from_string(s: String) -> Option<Opt>` for the `Opt`
/// variant `$v`.
///
/// The generated function parses `s` into the variant's payload type via `FromStr`.
/// Surrounding whitespace is trimmed first. Otherwise values such as `"4\n"` or `" 0.7"`
/// (easy to produce from env files or shell substitution) would fail to parse and the
/// option would be silently dropped. Returns `None` if the trimmed text does not parse.
macro_rules! opt_parse_str {
    ($v:ident) => {
        paste! {
            fn [< $v:snake:lower _from_string >] (s: String) -> Option<Opt> {
                Some(Opt::$v(s.trim().parse().ok()?))
            }
        }
    };
}
// Parsers used by `options_from_env`, grouped by payload type. Variants without an entry
// here (e.g. `Stream`, `StopSequence`, `TypicalP`) cannot be loaded from the environment.

// Integer payloads.
opt_parse_str!(NThreads);
opt_parse_str!(MaxTokens);
opt_parse_str!(MaxContextSize);
opt_parse_str!(RepeatPenaltyLastN);
opt_parse_str!(NBatch);
opt_parse_str!(TopK);

// Floating-point payloads.
opt_parse_str!(FrequencyPenalty);
opt_parse_str!(PresencePenalty);
opt_parse_str!(TopP);
opt_parse_str!(Temperature);
opt_parse_str!(RepeatPenalty);
opt_parse_str!(TfsZ);

// Boolean payloads.
opt_parse_str!(PenalizeNl);
/// Loads the `Opt` variant `$variant` into the builder `$builder`.
///
/// The value is read from the env var `LLM_CHAIN_` followed by the variant name in upper
/// snake case, and parsed with the matching `<variant>_from_string` function. The expansion
/// ends in `?`, so it must be used inside a function whose error type accepts `VarError`.
macro_rules! opt_from_env {
    ($builder:ident, $variant:ident) => {
        paste! {
            option_from_env(
                &mut $builder,
                stringify!([< LLM_CHAIN_ $variant:snake:upper >]),
                [< $variant:snake:lower _from_string >],
            )?;
        }
    };
}
/// Expands to one `opt_from_env!($builder, Variant)` per listed variant, in the order given.
macro_rules! opts_from_env {
    ($builder:ident, $($variant:ident),*) => {
        $( opt_from_env!($builder, $variant); )*
    };
}
pub fn options_from_env() -> Result<Options, VarError> {
let mut opts = OptionsBuilder::new();
opts_from_env!(
opts,
Model,
ApiKey,
NThreads,
MaxTokens,
MaxContextSize,
FrequencyPenalty,
PresencePenalty,
TopK,
TopP,
Temperature,
RepeatPenalty,
RepeatPenaltyLastN,
TfsZ,
PenalizeNl,
NBatch
);
Ok(opts.build())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `options_from_env` must pick up three kinds of option:
    /// a path-like one (`LLM_CHAIN_MODEL`), a plain string (`LLM_CHAIN_API_KEY`),
    /// and a parsed numeric one (`LLM_CHAIN_N_BATCH`).
    #[test]
    fn test_options_from_env() {
        use std::env;

        // Environment variables are process-global and visible to every other test in this
        // binary, so the ones set here are removed again right after being read.
        const KEYS: [&str; 3] = ["LLM_CHAIN_MODEL", "LLM_CHAIN_N_BATCH", "LLM_CHAIN_API_KEY"];

        let orig_model = "/123/123.bin";
        let orig_nbatch = 1_usize;
        let orig_api_key = "!asd";
        env::set_var("LLM_CHAIN_MODEL", orig_model);
        env::set_var("LLM_CHAIN_N_BATCH", orig_nbatch.to_string());
        env::set_var("LLM_CHAIN_API_KEY", orig_api_key);

        let result = options_from_env();
        // Clean up before anything below can panic, so a failure does not leak the
        // variables into other tests.
        for key in KEYS {
            env::remove_var(key);
        }
        let opts = result.unwrap();

        let model_path = opts
            .get(OptDiscriminants::Model)
            .and_then(|x| match x {
                Opt::Model(m) => Some(m),
                _ => None,
            })
            .unwrap();
        let nbatch = opts
            .get(OptDiscriminants::NBatch)
            .and_then(|x| match x {
                Opt::NBatch(m) => Some(m),
                _ => None,
            })
            .unwrap();
        let api_key = opts
            .get(OptDiscriminants::ApiKey)
            .and_then(|x| match x {
                Opt::ApiKey(m) => Some(m),
                _ => None,
            })
            .unwrap();

        assert_eq!(model_path.to_path(), orig_model);
        assert_eq!(*nbatch, orig_nbatch);
        assert_eq!(api_key, orig_api_key);
    }
}