//! Configuration types for the ChatGPT API client: model parameters and available engine versions.
use std::{fmt::Display, str::FromStr};
use std::time::Duration;
#[cfg(feature = "functions")]
use crate::functions::FunctionValidationStrategy;
use derive_builder::Builder;
use serde::Serialize;
/// The struct containing main configuration for the ChatGPT API.
///
/// Construct via [`ModelConfigurationBuilder`] (generated by `derive_builder`)
/// or fall back to [`Default`], which targets the official OpenAI endpoint.
#[derive(Debug, Clone, PartialEq, PartialOrd, Builder)]
#[builder(default, setter(into))]
pub struct ModelConfiguration {
/// The GPT version used.
pub engine: ChatGPTEngine,
/// Controls randomness of the output. Higher values mean more random output.
pub temperature: f32,
/// Controls diversity via nucleus sampling; not recommended to use together with temperature.
pub top_p: f32,
/// Controls the maximum number of tokens to generate in the completion.
/// `None` leaves the limit up to the API.
pub max_tokens: Option<u32>,
/// Determines how much to penalize new tokens based on their existing presence so far.
pub presence_penalty: f32,
/// Determines how much to penalize new tokens based on their existing frequency so far.
pub frequency_penalty: f32,
/// The maximum amount of replies.
pub reply_count: u32,
/// URL of the /v1/chat/completions endpoint. Can be used to set a proxy.
pub api_url: url::Url,
/// Timeout for the http requests sent to avoid potentially permanently hanging requests.
pub timeout: Duration,
/// Strategy for function validation. Whenever ChatGPT fails to call a function correctly, this strategy is applied.
#[cfg(feature = "functions")]
pub function_validation: FunctionValidationStrategy,
}
impl Default for ModelConfiguration {
/// Returns a configuration with sensible defaults: the default engine,
/// moderate temperature, the official OpenAI completions endpoint, and a
/// 10-second request timeout.
fn default() -> Self {
Self {
engine: Default::default(),
temperature: 0.5,
top_p: 1.0,
max_tokens: None,
presence_penalty: 0.0,
frequency_penalty: 0.0,
reply_count: 1,
// The endpoint is a hardcoded, statically valid URL, so a parse
// failure here is a programming error; `expect` documents that invariant.
api_url: url::Url::from_str("https://api.openai.com/v1/chat/completions")
.expect("hardcoded OpenAI API endpoint URL must be valid"),
timeout: Duration::from_secs(10),
#[cfg(feature = "functions")]
function_validation: FunctionValidationStrategy::default(),
}
}
}
/// The engine version for ChatGPT.
///
/// Each variant maps to the model identifier string sent to the API
/// (see the `AsRef<str>` implementation); `Custom` allows passing any
/// identifier not covered by the named variants.
#[derive(Serialize, Debug, Default, Copy, Clone, PartialEq, PartialOrd)]
#[allow(non_camel_case_types)]
pub enum ChatGPTEngine {
/// Standard engine: `gpt-3.5-turbo`
#[default]
Gpt35Turbo,
/// Different version of standard engine: `gpt-3.5-turbo-0301`
Gpt35Turbo_0301,
/// Base GPT-4 model: `gpt-4`
Gpt4,
/// Version of GPT-4, able to remember 32,000 tokens: `gpt-4-32k`
Gpt4_32k,
/// Different version of GPT-4: `gpt-4-0314`
Gpt4_0314,
/// Different version of GPT-4, able to remember 32,000 tokens: `gpt-4-32k-0314`
Gpt4_32k_0314,
/// Custom (or new/unimplemented) version of ChatGPT
Custom(&'static str),
}
impl Display for ChatGPTEngine {
/// Formats the engine as its API model identifier string,
/// delegating to the `AsRef<str>` implementation.
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "{}", self.as_ref())
}
}
impl AsRef<str> for ChatGPTEngine {
/// Returns the model identifier string sent to the API for this engine.
///
/// The match is exhaustive, so adding a new named variant will produce a
/// compile error here until its identifier is added.
fn as_ref(&self) -> &'static str {
match self {
ChatGPTEngine::Gpt35Turbo => "gpt-3.5-turbo",
ChatGPTEngine::Gpt35Turbo_0301 => "gpt-3.5-turbo-0301",
ChatGPTEngine::Gpt4 => "gpt-4",
ChatGPTEngine::Gpt4_32k => "gpt-4-32k",
ChatGPTEngine::Gpt4_0314 => "gpt-4-0314",
ChatGPTEngine::Gpt4_32k_0314 => "gpt-4-32k-0314",
// Custom carries its identifier directly.
ChatGPTEngine::Custom(custom) => custom,
}
}
}