use crate::language_models::{llm::LLM, options::CallOptions};
/// Read/write access to an LLM client's core configuration.
///
/// The `Send + Sync` bounds allow a config to be shared across threads
/// (e.g. behind an `Arc`).
pub trait LLMConfig: Send + Sync {
/// Returns the configured model identifier.
fn model(&self) -> &str;
/// Returns the call options applied to requests.
fn options(&self) -> &CallOptions;
/// Replaces the model identifier.
fn set_model(&mut self, model: String);
/// Replaces the call options.
fn set_options(&mut self, options: CallOptions);
}
/// Fluent, consuming builder interface for constructing LLM clients.
///
/// Each `with_*` method takes `self` by value and returns it, so calls
/// can be chained; `Sized` is required for that by-value style.
pub trait LLMBuilder: Sized {
/// Creates a builder with default settings.
fn new() -> Self;
/// Sets the model identifier.
fn with_model<S: Into<String>>(self, model: S) -> Self;
/// Sets the call options for requests made by the built client.
fn with_options(self, options: CallOptions) -> Self;
}
/// Stateless helper functions shared by LLM implementations.
pub struct LLMHelpers;

impl LLMHelpers {
    /// Validates a model identifier.
    ///
    /// # Errors
    ///
    /// Returns a message when the name is empty or longer than
    /// 256 characters.
    pub fn validate_model_name(model: &str) -> Result<(), String> {
        if model.is_empty() {
            return Err("Model name cannot be empty".to_string());
        }
        // Count Unicode scalar values rather than bytes so the limit
        // matches the "256 characters" stated in the error message;
        // byte length would wrongly reject shorter multibyte names.
        if model.chars().count() > 256 {
            return Err("Model name too long (max 256 characters)".to_string());
        }
        Ok(())
    }

    /// Reads an API key from `env_var`, falling back to `default` when
    /// the variable is unset or not valid Unicode.
    pub fn get_api_key_from_env(env_var: &str, default: &str) -> String {
        std::env::var(env_var).unwrap_or_else(|_| default.to_string())
    }

    /// Merges two sets of call options.
    ///
    /// NOTE(review): the override currently replaces the base wholesale;
    /// field-level merging would require knowledge of `CallOptions`
    /// internals, so callers should not rely on `base` contributing
    /// anything to the result.
    pub fn merge_options(_base: CallOptions, override_opts: CallOptions) -> CallOptions {
        override_opts
    }

    /// Returns the default call options.
    pub fn default_options() -> CallOptions {
        CallOptions::default()
    }
}
/// Initialization parameters shared by LLM client constructors.
///
/// Every field is optional; unset fields fall back to each
/// implementation's own defaults.
//
// All fields are `Option`s, so the derived `Default` (all `None`)
// replaces the previous hand-written `Default` impl and constructor body.
#[derive(Debug, Clone, Default)]
pub struct LLMInitConfig {
    /// Model identifier, e.g. `"gpt-4"`.
    pub model: Option<String>,
    /// API key; may instead be sourced from the environment by callers.
    pub api_key: Option<String>,
    /// Override for the provider's default endpoint URL.
    pub base_url: Option<String>,
    /// Per-call options applied to requests made with this config.
    pub options: Option<CallOptions>,
}

impl LLMInitConfig {
    /// Creates an empty config with every field unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the model identifier.
    pub fn with_model<S: Into<String>>(mut self, model: S) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the API key.
    pub fn with_api_key<S: Into<String>>(mut self, api_key: S) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    /// Sets the base URL.
    pub fn with_base_url<S: Into<String>>(mut self, base_url: S) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Sets the call options.
    pub fn with_options(mut self, options: CallOptions) -> Self {
        self.options = Some(options);
        self
    }
}
/// Extension trait for [`LLM`] implementations that can stream responses.
pub trait StreamingLLM: LLM {
/// Whether this implementation supports streaming.
///
/// Defaults to `true`; implementors that only conditionally support
/// streaming should override this.
fn supports_streaming(&self) -> bool {
true }
/// Options used when no streaming-specific options are supplied.
fn default_streaming_config(&self) -> CallOptions {
CallOptions::default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid names pass; empty and over-length names are rejected,
    /// including the exact 256/257 boundary.
    #[test]
    fn test_validate_model_name() {
        assert!(LLMHelpers::validate_model_name("gpt-4").is_ok());
        assert!(LLMHelpers::validate_model_name("").is_err());
        // Boundary: exactly 256 characters is accepted, 257 is not.
        assert!(LLMHelpers::validate_model_name(&"a".repeat(256)).is_ok());
        assert!(LLMHelpers::validate_model_name(&"a".repeat(257)).is_err());
    }

    /// Builder methods set their field and leave the others unset.
    #[test]
    fn test_llm_init_config() {
        let config = LLMInitConfig::new()
            .with_model("gpt-4")
            .with_api_key("test-key");
        assert_eq!(config.model, Some("gpt-4".to_string()));
        assert_eq!(config.api_key, Some("test-key".to_string()));
        assert!(config.base_url.is_none());
        assert!(config.options.is_none());
    }

    /// `with_base_url` stores the URL verbatim.
    #[test]
    fn test_with_base_url() {
        let config = LLMInitConfig::new().with_base_url("http://localhost:8080");
        assert_eq!(config.base_url, Some("http://localhost:8080".to_string()));
    }
}