use crate::{
    language_models::{llm::LLM, model_parser::parse_model_string, options::CallOptions, LLMError},
    llm::{
        claude::Claude,
        deepseek::Deepseek,
        openai::{AzureConfig, OpenAI, OpenAIConfig},
        qwen::Qwen,
        HuggingFace,
    },
};
#[cfg(feature = "bedrock")]
use crate::llm::Bedrock;
#[cfg(feature = "gemini")]
use crate::llm::Gemini;
#[cfg(feature = "mistralai")]
use crate::llm::MistralAI;
#[cfg(feature = "ollama")]
use crate::llm::Ollama;
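/// Initialize a chat model from a model string, dispatching to the matching
/// provider.
///
/// The string is either a bare model name (e.g. `"gpt-4o-mini"`), whose
/// provider is inferred from its prefix, or an explicit `"provider:model"`
/// pair (e.g. `"openai:gpt-4o-mini"`). Names that match no hosted provider
/// fall back to Ollama when the `ollama` feature is enabled; otherwise an
/// error is returned.
///
/// A minimal usage sketch (assumes the provider reads its API key from the
/// environment when none is passed explicitly):
///
/// ```ignore
/// let llm = init_chat_model(
///     "openai:gpt-4o-mini",
///     Some(0.7),  // temperature
///     Some(1024), // max_tokens
///     None,       // timeout
///     None,       // max_retries
///     None,       // api_key
///     None,       // base_url
///     None,       // azure_deployment
/// )
/// .await?;
/// ```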
pub async fn init_chat_model(
model: &str,
temperature: Option<f32>,
max_tokens: Option<u32>,
timeout: Option<u64>,
max_retries: Option<u32>,
api_key: Option<String>,
base_url: Option<String>,
azure_deployment: Option<String>,
) -> Result<Box<dyn LLM>, LLMError> {
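    // A "provider:model" string (e.g. "openai:gpt-4o-mini") yields an explicit
    // provider; a bare model name leaves the provider as None, and the prefix
    // checks below infer one.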
let parsed = parse_model_string(model)?;
let provider = parsed.provider.as_deref();
let mut options = CallOptions::new();
if let Some(temp) = temperature {
options = options.with_temperature(temp);
}
if let Some(max) = max_tokens {
options = options.with_max_tokens(max);
}
if let Some(to) = timeout {
options = options.with_timeout(to);
}
if let Some(retries) = max_retries {
options = options.with_max_retries(retries);
}
if let Some(key) = api_key {
options = options.with_api_key(key);
}
if let Some(url) = base_url {
options = options.with_base_url(url);
}
let model_name = parsed.model.as_str();
let model_lower = model_name.to_lowercase();
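    // Azure OpenAI: explicit "azure_openai" provider, or a bare gpt-* model
    // name combined with an Azure deployment id.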
if provider == Some("azure_openai")
|| (provider.is_none() && model_lower.starts_with("gpt-") && azure_deployment.is_some())
{
let mut config = AzureConfig::default();
if let Some(key) = &options.api_key {
config = config.with_api_key(key.clone());
}
if let Some(url) = &options.base_url {
config = config.with_api_base(url.clone());
}
if let Some(deployment) = azure_deployment {
config = config.with_deployment_id(deployment);
}
let llm: OpenAI<AzureConfig> = OpenAI::new(config)
.with_model(model_name)
.with_options(options);
return Ok(Box::new(llm));
}
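    // Plain OpenAI: explicit "openai" provider or a bare gpt-* model name.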
if provider == Some("openai") || (provider.is_none() && model_lower.starts_with("gpt-")) {
let mut config = OpenAIConfig::default();
if let Some(key) = &options.api_key {
config = config.with_api_key(key.clone());
}
let llm: OpenAI<OpenAIConfig> = OpenAI::new(config)
.with_model(model_name)
.with_options(options);
return Ok(Box::new(llm));
}
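    // Anthropic: "anthropic" and "claude" are accepted as provider aliases.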
if provider == Some("anthropic")
|| provider == Some("claude")
|| (provider.is_none()
&& (model_lower.starts_with("claude-") || model_lower.starts_with("claude")))
{
let mut llm = Claude::default().with_model(model_name);
if let Some(key) = &options.api_key {
llm = llm.with_api_key(key.clone());
}
llm = llm.with_options(options);
return Ok(Box::new(llm));
}
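    // Qwen: explicit "qwen" provider or a qwen-prefixed model name.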
if provider == Some("qwen")
|| (provider.is_none() && (model_lower.starts_with("qwen-") || model_lower == "qwen"))
{
let mut llm = Qwen::default().with_model(model_name);
if let Some(key) = &options.api_key {
llm = llm.with_api_key(key.clone());
}
if let Some(url) = &options.base_url {
llm = llm.with_base_url(url.clone());
}
llm = llm.with_options(options);
return Ok(Box::new(llm));
}
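    // Deepseek: explicit "deepseek" provider or a deepseek-prefixed model name.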
if provider == Some("deepseek")
|| (provider.is_none()
&& (model_lower.starts_with("deepseek-") || model_lower == "deepseek"))
{
let mut llm = Deepseek::default().with_model(model_name);
if let Some(key) = &options.api_key {
llm = llm.with_api_key(key.clone());
}
if let Some(url) = &options.base_url {
llm = llm.with_base_url(url.clone());
}
llm = llm.with_options(options);
return Ok(Box::new(llm));
}
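    // MistralAI (feature-gated): also matches mixtral-* and pixtral-* model names.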
#[cfg(feature = "mistralai")]
{
if provider == Some("mistralai")
|| provider == Some("mistral")
|| (provider.is_none()
&& (model_lower.starts_with("mistral-")
|| model_lower.starts_with("mixtral-")
|| model_lower.starts_with("pixtral-")))
{
let mut llm = MistralAI::default().with_model(model_name);
if let Some(key) = &options.api_key {
llm = llm.with_api_key(key.clone());
}
if let Some(url) = &options.base_url {
llm = llm.with_base_url(url.clone());
}
llm = llm.with_options(options);
return Ok(Box::new(llm));
}
}
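    // Gemini (feature-gated): "gemini", "google", and "google_genai" all map here.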
#[cfg(feature = "gemini")]
{
if provider == Some("gemini")
|| provider == Some("google")
|| provider == Some("google_genai")
|| (provider.is_none()
&& (model_lower.starts_with("gemini-") || model_lower.starts_with("gemini")))
{
let mut llm = Gemini::default().with_model(model_name);
if let Some(key) = &options.api_key {
llm = llm.with_api_key(key.clone());
}
if let Some(url) = &options.base_url {
llm = llm.with_base_url(url.clone());
}
llm = llm.with_options(options);
return Ok(Box::new(llm));
}
}
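    // AWS Bedrock (feature-gated): matched by provider, or by fully qualified
    // Bedrock model ids such as "anthropic.claude-..." and "meta.llama...".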
#[cfg(feature = "bedrock")]
{
if provider == Some("bedrock")
|| provider == Some("aws_bedrock")
|| (provider.is_none()
&& (model_lower.contains("anthropic.claude")
|| model_lower.contains("meta.llama")
|| model_lower.contains("amazon.titan")))
{
let bedrock = Bedrock::new().await?;
let bedrock = bedrock.with_model(model_name);
let bedrock = bedrock.with_options(options);
return Ok(Box::new(bedrock));
}
}
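    // HuggingFace (always compiled in): explicit provider, or any bare
    // "org/model" path.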
    if provider == Some("huggingface")
        || provider == Some("hf")
        || (provider.is_none() && model_lower.contains('/'))
    {
        let mut llm = HuggingFace::default().with_model(model_name);
        if let Some(key) = &options.api_key {
            llm = llm.with_api_key(key.clone());
        }
        if let Some(url) = &options.base_url {
            llm = llm.with_base_url(url.clone());
        }
        llm = llm.with_options(options);
        return Ok(Box::new(llm));
    }
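    // Fallback: with the "ollama" feature enabled, any model name that matched
    // no hosted provider is treated as a local Ollama model, translating the
    // shared CallOptions into ollama_rs GenerationOptions.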
#[cfg(feature = "ollama")]
{
let mut gen_options = ollama_rs::generation::options::GenerationOptions::default();
if let Some(temp) = options.temperature {
gen_options = gen_options.temperature(temp);
}
if let Some(max_tokens) = options.max_tokens {
gen_options = gen_options.num_predict(max_tokens as i32);
}
if let Some(top_p) = options.top_p {
gen_options = gen_options.top_p(top_p);
}
if let Some(top_k) = options.top_k {
gen_options = gen_options.top_k(top_k as u32);
}
if let Some(seed) = options.seed {
gen_options = gen_options.seed(seed as i32);
}
let llm = Ollama::default()
.with_model(model_name)
.with_options(gen_options);
return Ok(Box::new(llm));
}
#[cfg(not(feature = "ollama"))]
{
#[cfg(feature = "mistralai")]
{
Err(LLMError::OtherError(format!(
"Unsupported model: {}. Supported providers: openai, azure_openai, anthropic, qwen, deepseek, mistralai. Enable 'ollama' feature for Ollama models.",
model
)))
}
#[cfg(not(feature = "mistralai"))]
{
Err(LLMError::OtherError(format!(
"Unsupported model: {}. Supported providers: openai, azure_openai, anthropic, qwen, deepseek. Enable 'mistralai' feature for MistralAI models. Enable 'ollama' feature for Ollama models.",
model
)))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_init_chat_model_openai() {
let result = init_chat_model("gpt-4o-mini", None, None, None, None, None, None, None).await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_init_chat_model_with_provider() {
let result = init_chat_model(
"openai:gpt-4o-mini",
None,
None,
None,
None,
None,
None,
None,
)
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_init_chat_model_with_params() {
let result = init_chat_model(
"gpt-4o-mini",
Some(0.7),
Some(1000),
Some(30),
Some(3),
None,
None,
None,
)
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_init_chat_model_claude() {
let result = init_chat_model(
"claude-3-5-sonnet-20240620",
None,
None,
None,
None,
None,
None,
None,
)
.await;
assert!(result.is_ok());
}
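    // Sketch of a provider-prefix dispatch test for Anthropic; like the tests
    // above, it assumes the client constructs without a live API key.
    #[tokio::test]
    async fn test_init_chat_model_anthropic_prefix() {
        let result = init_chat_model(
            "anthropic:claude-3-5-sonnet-20240620",
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
        .await;
        assert!(result.is_ok());
    }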
#[tokio::test]
#[ignore = "init_chat_model may accept unknown model names"]
async fn test_init_chat_model_invalid() {
let result = init_chat_model(
"invalid-model-xyz",
None,
None,
None,
None,
None,
None,
None,
)
.await;
assert!(result.is_err());
}
}