use llm_chain::traits;
use serde::{Deserialize, Serialize};
use strum_macros::EnumString;
/// The `Model` enum represents the available ChatGPT models that you can use through the OpenAI
/// API.
///
/// These models have different capabilities and performance characteristics, allowing you to choose
/// the one that best suits your needs. See <https://platform.openai.com/docs/models> for more
/// information.
///
/// # Example
///
/// ```
/// use llm_chain_openai::chatgpt::Model;
///
/// let turbo_model = Model::Gpt35Turbo;
/// let custom_model = Model::Other("your_custom_model_name".to_string());
/// ```
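///
/// Model names can also be parsed from strings (via the `EnumString` derive from `strum`);
/// unrecognized names fall back to the `Model::Other` variant:
///
/// ```
/// use std::str::FromStr;
/// use llm_chain_openai::chatgpt::Model;
///
/// assert_eq!(Model::from_str("gpt-4").unwrap(), Model::Gpt4);
/// assert_eq!(
///     Model::from_str("my-fine-tune").unwrap(),
///     Model::Other("my-fine-tune".to_string())
/// );
/// ```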
#[derive(Debug, Default, Clone, Serialize, Deserialize, EnumString, PartialEq, Eq)]
#[non_exhaustive]
pub enum Model {
    /// A high-performance and versatile model that offers a great balance of speed, quality, and
    /// affordability.
    #[default]
    #[strum(
        serialize = "gpt-3.5-turbo",
        serialize = "gpt-35-turbo",
        serialize = "gpt3.5",
        serialize = "gpt35"
    )]
    Gpt35Turbo,
    /// Snapshot of gpt-3.5-turbo from March 1st, 2023. Unlike gpt-3.5-turbo, this model will not
    /// receive updates, and will be deprecated 3 months after a new version is released.
    #[strum(serialize = "gpt-3.5-turbo-0301")]
    Gpt35Turbo0301,
    /// A high-performance model that offers the best quality, but is slower and more expensive
    /// than the `Gpt35Turbo` model.
    #[strum(serialize = "gpt-4", serialize = "gpt4")]
    Gpt4,
    /// Snapshot of gpt-4 from March 14th, 2023. Unlike gpt-4, this model will not receive
    /// updates, and will be deprecated 3 months after a new version is released.
    #[strum(serialize = "gpt-4-0314")]
    Gpt4_0314,
    /// Same capabilities as the base gpt-4 model but with 4x the context length. Will be updated
    /// with OpenAI's latest model iteration.
    #[strum(serialize = "gpt-4-32k")]
    Gpt4_32k,
    /// Snapshot of gpt-4-32k from March 14th, 2023. Unlike gpt-4-32k, this model will not receive
    /// updates, and will be deprecated 3 months after a new version is released.
    #[strum(serialize = "gpt-4-32k-0314")]
    Gpt4_32k0314,
    /// A variant that allows you to specify a custom model name as a string, in case new models
    /// are introduced or you have access to specialized models.
    #[strum(default)]
    Other(String),
}
impl Model {
    /// Included for backwards compatibility.
    #[deprecated(note = "Use `Model::Gpt35Turbo` instead")]
    #[allow(non_upper_case_globals)]
    pub const ChatGPT3_5Turbo: Model = Model::Gpt35Turbo;
    /// Included for backwards compatibility.
    #[deprecated(note = "Use `Model::Gpt4` instead")]
    pub const GPT4: Model = Model::Gpt4;
}
/// The `Model` enum implements `std::fmt::Display` (and therefore `to_string()`), converting each
/// variant into the model name expected by the OpenAI API. Implementing `Display` rather than
/// `ToString` directly is the idiomatic choice; `ToString` comes for free via a blanket impl.
impl std::fmt::Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Model::Gpt35Turbo => "gpt-3.5-turbo",
            Model::Gpt35Turbo0301 => "gpt-3.5-turbo-0301",
            Model::Gpt4 => "gpt-4",
            Model::Gpt4_0314 => "gpt-4-0314",
            Model::Gpt4_32k => "gpt-4-32k",
            Model::Gpt4_32k0314 => "gpt-4-32k-0314",
            Model::Other(model) => model.as_str(),
        };
        write!(f, "{}", name)
    }
}
/// The `PerInvocation` struct contains options that can be specified for each ChatGPT invocation.
/// Currently, it only supports specifying a `Model`.
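///
/// # Example
///
/// A minimal sketch of selecting a model for a single invocation (assuming the crate's public
/// `llm_chain_openai::chatgpt` path used in the `Model` example above):
///
/// ```
/// use llm_chain_openai::chatgpt::{Model, PerInvocation};
///
/// let options = PerInvocation::new().for_model(Model::Gpt4);
/// ```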
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerInvocation {
    pub(crate) model: Option<Model>,
}
impl PerInvocation {
    /// Creates a new `PerInvocation` struct with default values.
    pub fn new() -> Self {
        Self::default()
    }
    /// Sets the `Model` for the `PerInvocation` struct.
    pub fn for_model(self, model: Model) -> Self {
        Self { model: Some(model) }
    }
}
impl traits::Options for PerInvocation {}
/// The `PerExecutor` struct contains options that can be specified for the ChatGPT `Executor`.
/// Currently, it only supports specifying an `api_key`.
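///
/// # Example
///
/// A minimal sketch of supplying an API key (the key shown is a placeholder, not a real
/// credential):
///
/// ```
/// use llm_chain_openai::chatgpt::PerExecutor;
///
/// let executor_options = PerExecutor {
///     api_key: Some("sk-your-key-here".to_string()),
/// };
/// ```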
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PerExecutor {
    pub api_key: Option<String>,
}
impl traits::Options for PerExecutor {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    // Tests for FromStr
    #[test]
    fn test_from_str() -> Result<(), Box<dyn std::error::Error>> {
        assert_eq!(Model::from_str("gpt-3.5-turbo")?, Model::Gpt35Turbo);
        assert_eq!(
            Model::from_str("gpt-3.5-turbo-0301")?,
            Model::Gpt35Turbo0301
        );
        assert_eq!(Model::from_str("gpt-4")?, Model::Gpt4);
        assert_eq!(Model::from_str("gpt-4-0314")?, Model::Gpt4_0314);
        assert_eq!(Model::from_str("gpt-4-32k")?, Model::Gpt4_32k);
        assert_eq!(Model::from_str("gpt-4-32k-0314")?, Model::Gpt4_32k0314);
        assert_eq!(
            Model::from_str("custom_model")?,
            Model::Other("custom_model".to_string())
        );
        Ok(())
    }

    // Tests for ToString
    #[test]
    fn test_to_string() {
        assert_eq!(Model::Gpt35Turbo.to_string(), "gpt-3.5-turbo");
        assert_eq!(Model::Gpt4.to_string(), "gpt-4");
        assert_eq!(Model::Gpt35Turbo0301.to_string(), "gpt-3.5-turbo-0301");
        assert_eq!(Model::Gpt4_0314.to_string(), "gpt-4-0314");
        assert_eq!(Model::Gpt4_32k.to_string(), "gpt-4-32k");
        assert_eq!(Model::Gpt4_32k0314.to_string(), "gpt-4-32k-0314");
        assert_eq!(
            Model::Other("custom_model".to_string()).to_string(),
            "custom_model"
        );
    }

    // The deprecated aliases should still stringify like their replacements.
    #[test]
    #[allow(deprecated)]
    fn test_to_string_deprecated() {
        assert_eq!(Model::ChatGPT3_5Turbo.to_string(), "gpt-3.5-turbo");
        assert_eq!(Model::GPT4.to_string(), "gpt-4");
    }
}