use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::client::LlmClient;
use crate::error::LlmError;
use crate::stream::ChatStream;
use crate::traits::*;
use crate::types::*;
pub mod guide;
/// Connection and request settings for a user-defined LLM provider.
///
/// Built with [`CustomProviderConfig::new`] plus the `with_*` builder methods.
/// `Debug` is implemented by hand (instead of derived) so that the API key and
/// header values — which commonly carry credentials such as `Authorization` —
/// are never written to logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct CustomProviderConfig {
    /// Provider name; checked by `CustomProvider::validate_config`.
    pub name: String,
    /// Base URL of the provider's API; checked by `CustomProvider::validate_config`.
    pub base_url: String,
    /// Credential for the provider. Redacted from `Debug` output.
    pub api_key: String,
    /// Model to request. When `None`, `CustomProviderClient` falls back to the
    /// provider's first supported model, then to `"default"`.
    pub model: Option<String>,
    /// Additional headers added via `with_header`. Only the names appear in
    /// `Debug` output.
    pub headers: HashMap<String, String>,
    /// Timeout in seconds applied to the HTTP client built by
    /// `CustomProviderClient::new`; `None` leaves the client's default.
    pub timeout: Option<u64>,
    /// Provider-specific parameters copied into every request's `params`.
    pub custom_params: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for CustomProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values (e.g. `Authorization: Bearer …`) are secrets; show names only.
        let header_names: Vec<&String> = self.headers.keys().collect();
        f.debug_struct("CustomProviderConfig")
            .field("name", &self.name)
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("headers", &header_names)
            .field("timeout", &self.timeout)
            .field("custom_params", &self.custom_params)
            .finish()
    }
}
impl CustomProviderConfig {
    /// Timeout, in seconds, that `new` installs by default.
    const DEFAULT_TIMEOUT_SECS: u64 = 30;

    /// Creates a config for `name` at `base_url`, authenticated with `api_key`.
    ///
    /// No model, headers or custom parameters are set, and the timeout starts
    /// at 30 seconds.
    pub fn new<S: Into<String>>(name: S, base_url: S, api_key: S) -> Self {
        let (name, base_url, api_key) = (name.into(), base_url.into(), api_key.into());
        Self {
            name,
            base_url,
            api_key,
            model: None,
            headers: HashMap::default(),
            timeout: Some(Self::DEFAULT_TIMEOUT_SECS),
            custom_params: HashMap::default(),
        }
    }

    /// Sets the model to request, replacing any previous choice.
    pub fn with_model<S: Into<String>>(self, model: S) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Adds (or overwrites) a header entry.
    pub fn with_header<K: Into<String>, V: Into<String>>(mut self, key: K, value: V) -> Self {
        let (name, val) = (key.into(), value.into());
        self.headers.insert(name, val);
        self
    }

    /// Sets the timeout, in seconds.
    pub const fn with_timeout(mut self, timeout_seconds: u64) -> Self {
        self.timeout = Some(timeout_seconds);
        self
    }

    /// Adds a provider-specific parameter, serialized to JSON.
    ///
    /// Best-effort: when `value` cannot be serialized, the parameter is skipped
    /// and the config is returned unchanged.
    pub fn with_param<K: Into<String>, V: Serialize>(mut self, key: K, value: V) -> Self {
        if let Ok(json) = serde_json::to_value(value) {
            self.custom_params.insert(key.into(), json);
        }
        self
    }
}
/// A user-implemented chat backend that [`CustomProviderClient`] can wrap.
///
/// Implementors supply model listing, capabilities and the two chat calls.
/// Configuration validation and the request/response hooks have default
/// implementations that may be overridden.
#[async_trait]
pub trait CustomProvider: Send + Sync {
    /// Returns this provider's name.
    fn name(&self) -> &str;

    /// Models this provider can serve. `CustomProviderClient` uses the first
    /// entry when its config does not specify a model.
    fn supported_models(&self) -> Vec<String>;

    /// Capabilities forwarded by `CustomProviderClient` through `LlmClient::capabilities`.
    fn capabilities(&self) -> ProviderCapabilities;

    /// Performs a single, non-streaming chat completion.
    async fn chat(&self, request: CustomChatRequest) -> Result<CustomChatResponse, LlmError>;

    /// Starts a streaming chat completion.
    async fn chat_stream(&self, request: CustomChatRequest) -> Result<ChatStream, LlmError>;

    /// Checks that `config` is usable by this provider.
    ///
    /// The default implementation requires `name`, `base_url` and `api_key` to
    /// each contain at least one non-whitespace character.
    ///
    /// # Errors
    /// Returns `LlmError::InvalidParameter` describing the first offending field.
    fn validate_config(&self, config: &CustomProviderConfig) -> Result<(), LlmError> {
        let required: [(&str, &str); 3] = [
            (config.name.as_str(), "Provider name cannot be empty"),
            (config.base_url.as_str(), "Base URL cannot be empty"),
            (config.api_key.as_str(), "API key cannot be empty"),
        ];
        for (value, message) in required {
            // A whitespace-only value is as unusable as an empty one, so reject both.
            if value.trim().is_empty() {
                return Err(LlmError::InvalidParameter(message.to_string()));
            }
        }
        Ok(())
    }

    /// Hook for adjusting an outgoing request. The default does nothing.
    fn transform_request(&self, _request: &mut CustomChatRequest) -> Result<(), LlmError> {
        Ok(())
    }

    /// Hook for adjusting a response produced by [`CustomProvider::chat`].
    /// The default does nothing.
    fn transform_response(&self, _response: &mut CustomChatResponse) -> Result<(), LlmError> {
        Ok(())
    }
}
/// Request handed to [`CustomProvider::chat`] and [`CustomProvider::chat_stream`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomChatRequest {
/// Chat messages to send.
pub messages: Vec<ChatMessage>,
/// Model identifier to request.
pub model: String,
/// Tool definitions, if any were supplied.
pub tools: Option<Vec<Tool>>,
/// Whether a streaming response is requested; `CustomProviderClient` sets
/// this to `true` on its streaming path.
pub stream: bool,
/// Extra provider-specific parameters as JSON values (the client copies
/// `CustomProviderConfig::custom_params` in here).
pub params: HashMap<String, serde_json::Value>,
}
impl CustomChatRequest {
    /// Creates a non-streaming request for `model` with no tools and no extra
    /// parameters.
    pub fn new(messages: Vec<ChatMessage>, model: String) -> Self {
        Self {
            model,
            messages,
            stream: false,
            tools: None,
            params: HashMap::default(),
        }
    }

    /// Attaches tool definitions, replacing any previously set.
    pub fn with_tools(self, tools: Vec<Tool>) -> Self {
        Self {
            tools: Some(tools),
            ..self
        }
    }

    /// Sets whether a streaming response is requested.
    pub const fn with_stream(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    /// Adds one extra parameter, serialized to JSON.
    ///
    /// Best-effort: when `value` cannot be serialized, the parameter is skipped
    /// and the request is returned unchanged.
    pub fn with_param<K: Into<String>, V: Serialize>(mut self, key: K, value: V) -> Self {
        if let Ok(json) = serde_json::to_value(value) {
            self.params.insert(key.into(), json);
        }
        self
    }
}
/// Result returned by [`CustomProvider::chat`]; converted to the crate-wide
/// `ChatResponse` by `to_chat_response`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomChatResponse {
/// Generated text.
pub content: String,
/// Tool calls returned by the provider, if any.
pub tool_calls: Option<Vec<ToolCall>>,
/// Usage statistics, when the provider reports them.
pub usage: Option<Usage>,
/// Raw finish-reason string. `to_chat_response` maps "stop", "length",
/// "tool_calls" and "content_filter" to dedicated `FinishReason` variants and
/// anything else to `FinishReason::Other`.
pub finish_reason: Option<String>,
/// Arbitrary JSON metadata, copied into `ChatResponse::metadata`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl CustomChatResponse {
pub fn new(content: String) -> Self {
Self {
content,
tool_calls: None,
usage: None,
finish_reason: None,
metadata: HashMap::new(),
}
}
pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
self.tool_calls = Some(tool_calls);
self
}
pub const fn with_usage(mut self, usage: Usage) -> Self {
self.usage = Some(usage);
self
}
pub fn with_finish_reason<S: Into<String>>(mut self, reason: S) -> Self {
self.finish_reason = Some(reason.into());
self
}
pub fn with_metadata<K: Into<String>, V: Serialize>(mut self, key: K, value: V) -> Self {
if let Ok(json_value) = serde_json::to_value(value) {
self.metadata.insert(key.into(), json_value);
}
self
}
pub fn to_chat_response(&self, _provider_name: &str) -> ChatResponse {
ChatResponse {
id: None,
content: MessageContent::Text(self.content.clone()),
model: None,
usage: self.usage.clone(),
finish_reason: self.finish_reason.as_ref().map(|r| match r.as_str() {
"stop" => FinishReason::Stop,
"length" => FinishReason::Length,
"tool_calls" => FinishReason::ToolCalls,
"content_filter" => FinishReason::ContentFilter,
_ => FinishReason::Other(r.clone()),
}),
tool_calls: self.tool_calls.clone(),
thinking: None,
metadata: self.metadata.clone(),
}
}
}
/// Adapter exposing a [`CustomProvider`] through the crate's `ChatCapability`
/// and `LlmClient` traits.
pub struct CustomProviderClient {
/// The user-supplied provider that performs the chat calls.
provider: Box<dyn CustomProvider>,
/// Configuration validated by `provider` when the client is constructed.
config: CustomProviderConfig,
/// HTTP client either built from `config.timeout` or supplied by the caller.
/// Exposed via `http_client()`; the chat paths in this module do not use it.
http_client: reqwest::Client,
}
impl CustomProviderClient {
    /// Creates a client, building a `reqwest::Client` that applies
    /// `config.timeout` (in seconds) when one is set.
    ///
    /// # Errors
    /// - Any error returned by `provider.validate_config`.
    /// - `LlmError::ConfigurationError` if the HTTP client cannot be built.
    pub fn new(
        provider: Box<dyn CustomProvider>,
        config: CustomProviderConfig,
    ) -> Result<Self, LlmError> {
        // Reject bad configs before paying for HTTP client construction.
        provider.validate_config(&config)?;
        let builder = reqwest::Client::builder();
        let builder = match config.timeout {
            Some(secs) => builder.timeout(std::time::Duration::from_secs(secs)),
            None => builder,
        };
        let http_client = builder.build().map_err(|e| {
            LlmError::ConfigurationError(format!("Failed to create HTTP client: {e}"))
        })?;
        Self::with_http_client(provider, config, http_client)
    }

    /// Creates a client around a caller-supplied `http_client`.
    ///
    /// # Errors
    /// Any error returned by `provider.validate_config`.
    pub fn with_http_client(
        provider: Box<dyn CustomProvider>,
        config: CustomProviderConfig,
        http_client: reqwest::Client,
    ) -> Result<Self, LlmError> {
        provider.validate_config(&config).map(|()| Self {
            provider,
            config,
            http_client,
        })
    }

    /// Borrows the wrapped provider.
    pub fn provider(&self) -> &dyn CustomProvider {
        &*self.provider
    }

    /// Borrows the configuration this client was built with.
    pub const fn config(&self) -> &CustomProviderConfig {
        &self.config
    }

    /// Borrows the underlying HTTP client.
    pub const fn http_client(&self) -> &reqwest::Client {
        &self.http_client
    }
}
#[async_trait]
impl ChatCapability for CustomProviderClient {
async fn chat_with_tools(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<Tool>>,
) -> Result<ChatResponse, LlmError> {
let model = self
.config
.model
.clone()
.or_else(|| self.provider.supported_models().first().cloned())
.unwrap_or_else(|| "default".to_string());
let mut request = CustomChatRequest::new(messages, model);
if let Some(tools) = tools {
request = request.with_tools(tools);
}
for (key, value) in &self.config.custom_params {
request.params.insert(key.clone(), value.clone());
}
let mut response = self.provider.chat(request).await?;
self.provider.transform_response(&mut response)?;
Ok(response.to_chat_response(&self.config.name))
}
async fn chat_stream(
&self,
messages: Vec<ChatMessage>,
tools: Option<Vec<Tool>>,
) -> Result<ChatStream, LlmError> {
let model = self
.config
.model
.clone()
.or_else(|| self.provider.supported_models().first().cloned())
.unwrap_or_else(|| "default".to_string());
let mut request = CustomChatRequest::new(messages, model).with_stream(true);
if let Some(tools) = tools {
request = request.with_tools(tools);
}
for (key, value) in &self.config.custom_params {
request.params.insert(key.clone(), value.clone());
}
self.provider.chat_stream(request).await
}
}
impl LlmClient for CustomProviderClient {
fn provider_name(&self) -> &'static str {
"custom"
}
fn supported_models(&self) -> Vec<String> {
self.provider.supported_models()
}
fn capabilities(&self) -> ProviderCapabilities {
self.provider.capabilities()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn clone_box(&self) -> Box<dyn LlmClient> {
panic!("Custom provider cloning not implemented")
}
}
/// Factory that consumes itself to produce a boxed [`CustomProvider`].
pub trait CustomProviderBuilder {
/// Consumes the builder and returns the provider.
///
/// # Errors
/// Returns an `LlmError` if the provider cannot be constructed.
fn build(self) -> Result<Box<dyn CustomProvider>, LlmError>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config shared by the tests below, before any builder calls.
    fn base_config() -> CustomProviderConfig {
        CustomProviderConfig::new("test-provider", "https://api.test.com", "test-key")
    }

    #[test]
    fn test_custom_provider_config_creation() {
        let cfg = base_config();
        assert_eq!(cfg.name, "test-provider");
        assert_eq!(cfg.base_url, "https://api.test.com");
        assert_eq!(cfg.api_key, "test-key");
        assert!(cfg.model.is_none());
        assert!(cfg.headers.is_empty());
        assert_eq!(cfg.timeout, Some(30));
        assert!(cfg.custom_params.is_empty());
    }

    #[test]
    fn test_custom_provider_config_with_model() {
        let cfg = base_config()
            .with_model("test-model-v1")
            .with_header("Authorization", "Bearer token")
            .with_timeout(60)
            .with_param("temperature", 0.7);
        assert_eq!(cfg.model.as_deref(), Some("test-model-v1"));
        assert_eq!(
            cfg.headers.get("Authorization").map(String::as_str),
            Some("Bearer token")
        );
        assert_eq!(cfg.timeout, Some(60));
        assert!(cfg.custom_params.contains_key("temperature"));
    }
}