mod types;
#[cfg(feature = "anthropic")]
mod anthropic_compat;
#[cfg(feature = "google")]
mod google_gemini;
#[cfg(any(
feature = "openai",
feature = "aliyun",
feature = "ollama",
feature = "zhipu"
))]
mod openai_compat;
#[cfg(feature = "anthropic")]
use anthropic_compat::AnthropicCompatChat;
use async_trait::async_trait;
#[cfg(feature = "google")]
use google_gemini::GoogleGeminiChat;
#[cfg(any(
feature = "openai",
feature = "aliyun",
feature = "ollama",
feature = "zhipu"
))]
use openai_compat::OpenaiCompatChat;
pub use types::{
ChatEvent, ChatMessage, ChatRequest, ChatResponse, FinishReason, FunctionCallResult,
FunctionDefinition, RequestPreset, ResponseFormat, Role, ToolCall, ToolCallDelta, ToolChoice,
ToolDefinition,
};
use std::pin::Pin;
use futures::Stream;
use crate::config::Provider;
use crate::config::ProviderConfig;
use crate::error::{Error, Result};
/// Stream of incremental [`ChatEvent`]s produced by the streaming methods of
/// [`ChatProvider`].
///
/// The stream is boxed as a trait object, so each provider can return its own concrete
/// stream type. It is pinned so it can be polled directly, and it is `Send` so it can be moved
/// between tasks/threads. Every item is a `Result`, so an error can be yielded as an
/// individual stream item rather than only when the stream is created.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatEvent>> + Send>>;
#[async_trait]
pub trait ChatProvider: Send + Sync {
async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse>;
async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream>;
async fn chat(&self, prompt: &str) -> Result<String> {
let resp = self.complete(&ChatRequest::single_user(prompt)).await?;
resp.content
.filter(|s| !s.is_empty())
.ok_or(Error::MissingField("response content"))
}
async fn chat_stream(&self, prompt: &str) -> Result<ChatStream> {
self.complete_stream(&ChatRequest::single_user(prompt))
.await
}
}
pub(crate) fn create(config: &ProviderConfig) -> Result<Box<dyn ChatProvider>> {
match config.provider {
#[cfg(feature = "openai")]
Provider::OpenAI => Ok(Box::new(OpenaiCompatChat::new(config)?)),
#[cfg(not(feature = "openai"))]
Provider::OpenAI => Err(crate::error::Error::ProviderDisabled("openai".to_string())),
#[cfg(feature = "aliyun")]
Provider::Aliyun => Ok(Box::new(OpenaiCompatChat::new(config)?)),
#[cfg(not(feature = "aliyun"))]
Provider::Aliyun => Err(crate::error::Error::ProviderDisabled("aliyun".to_string())),
#[cfg(feature = "anthropic")]
Provider::Anthropic => Ok(Box::new(AnthropicCompatChat::new(config)?)),
#[cfg(not(feature = "anthropic"))]
Provider::Anthropic => Err(crate::error::Error::ProviderDisabled(
"anthropic".to_string(),
)),
#[cfg(feature = "google")]
Provider::Google => Ok(Box::new(GoogleGeminiChat::new(config)?)),
#[cfg(not(feature = "google"))]
Provider::Google => Err(crate::error::Error::ProviderDisabled("google".to_string())),
#[cfg(feature = "ollama")]
Provider::Ollama => Ok(Box::new(OpenaiCompatChat::new(config)?)),
#[cfg(not(feature = "ollama"))]
Provider::Ollama => Err(crate::error::Error::ProviderDisabled("ollama".to_string())),
#[cfg(feature = "zhipu")]
Provider::Zhipu => Ok(Box::new(OpenaiCompatChat::new(config)?)),
#[cfg(not(feature = "zhipu"))]
Provider::Zhipu => Err(crate::error::Error::ProviderDisabled("zhipu".to_string())),
}
}
pub fn merge_tool_call_deltas(deltas: &[ToolCallDelta]) -> Vec<ToolCall> {
use std::collections::BTreeMap;
let mut map: BTreeMap<u32, (Option<String>, Option<String>, String)> = BTreeMap::new();
for d in deltas {
let entry = map.entry(d.index).or_insert((None, None, String::new()));
if entry.0.is_none() {
entry.0 = d.id.clone();
}
if entry.1.is_none() {
entry.1 = d.function_name.clone();
}
if let Some(args) = &d.function_arguments {
entry.2.push_str(args);
}
}
map.into_values()
.filter_map(|(id, name, args)| {
let name = name?;
Some(ToolCall {
id: id.unwrap_or_else(|| format!("tool_call_{}", name)),
function: crate::chat::FunctionCallResult {
name,
arguments: args,
},
})
})
.collect()
}