use crate::error::LlmError;
use crate::stream::ChatStream;
use crate::traits::*;
use crate::types::*;
/// Object-safe interface shared by every provider client.
///
/// Chat operations come from the [`ChatCapability`] supertrait. On top of that a client
/// reports its provider identity, advertised models and capabilities. It may expose optional
/// capability interfaces through the `as_*_capability` accessors, which all default to `None`.
pub trait LlmClient: ChatCapability + Send + Sync {
    /// Static identifier of the provider behind this client.
    fn provider_name(&self) -> &'static str;

    /// Provider kind; by default derived from [`LlmClient::provider_name`] via
    /// `ProviderType::from_name`.
    fn provider_type(&self) -> ProviderType {
        let name = self.provider_name();
        ProviderType::from_name(name)
    }

    /// Model identifiers this client reports as supported.
    fn supported_models(&self) -> Vec<String>;

    /// Capability descriptor reported by this client.
    fn capabilities(&self) -> ProviderCapabilities;

    /// Upcasts to [`std::any::Any`] so callers can attempt a downcast to a concrete type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Clones this client into a new box, allowing `Box<dyn LlmClient>` to be duplicated.
    fn clone_box(&self) -> Box<dyn LlmClient>;

    /// Embedding interface, if this client offers one. Defaults to `None`.
    fn as_embedding_capability(&self) -> Option<&dyn EmbeddingCapability> {
        None
    }

    /// Audio interface, if this client offers one. Defaults to `None`.
    fn as_audio_capability(&self) -> Option<&dyn AudioCapability> {
        None
    }

    /// Vision interface, if this client offers one. Defaults to `None`.
    fn as_vision_capability(&self) -> Option<&dyn VisionCapability> {
        None
    }

    /// Image-generation interface, if this client offers one. Defaults to `None`.
    fn as_image_generation_capability(&self) -> Option<&dyn ImageGenerationCapability> {
        None
    }
}
/// A boxed [`LlmClient`] tagged with the provider slot it was registered under.
///
/// The tag is chosen by whoever constructs the wrapper; it is not checked against the
/// boxed client's own `provider_name`.
pub enum ClientWrapper {
    /// Client registered under the OpenAI tag.
    OpenAi(Box<dyn LlmClient>),
    /// Client registered under the Anthropic tag.
    Anthropic(Box<dyn LlmClient>),
    /// Client registered under the Gemini tag.
    Gemini(Box<dyn LlmClient>),
    /// Client registered under the Groq tag.
    Groq(Box<dyn LlmClient>),
    /// Client registered under the xAI tag.
    XAI(Box<dyn LlmClient>),
    /// Client registered under the Ollama tag.
    Ollama(Box<dyn LlmClient>),
    /// Client that does not fit any of the built-in tags.
    Custom(Box<dyn LlmClient>),
}
impl Clone for ClientWrapper {
fn clone(&self) -> Self {
match self {
ClientWrapper::OpenAi(client) => ClientWrapper::OpenAi(client.clone_box()),
ClientWrapper::Anthropic(client) => ClientWrapper::Anthropic(client.clone_box()),
ClientWrapper::Gemini(client) => ClientWrapper::Gemini(client.clone_box()),
ClientWrapper::Groq(client) => ClientWrapper::Groq(client.clone_box()),
ClientWrapper::XAI(client) => ClientWrapper::XAI(client.clone_box()),
ClientWrapper::Ollama(client) => ClientWrapper::Ollama(client.clone_box()),
ClientWrapper::Custom(client) => ClientWrapper::Custom(client.clone_box()),
}
}
}
impl std::fmt::Debug for ClientWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ClientWrapper::OpenAi(_) => f
.debug_tuple("ClientWrapper::OpenAi")
.field(&"[LlmClient]")
.finish(),
ClientWrapper::Anthropic(_) => f
.debug_tuple("ClientWrapper::Anthropic")
.field(&"[LlmClient]")
.finish(),
ClientWrapper::Gemini(_) => f
.debug_tuple("ClientWrapper::Gemini")
.field(&"[LlmClient]")
.finish(),
ClientWrapper::Groq(_) => f
.debug_tuple("ClientWrapper::Groq")
.field(&"[LlmClient]")
.finish(),
ClientWrapper::XAI(_) => f
.debug_tuple("ClientWrapper::XAI")
.field(&"[LlmClient]")
.finish(),
ClientWrapper::Ollama(_) => f
.debug_tuple("ClientWrapper::Ollama")
.field(&"[LlmClient]")
.finish(),
ClientWrapper::Custom(_) => f
.debug_tuple("ClientWrapper::Custom")
.field(&"[LlmClient]")
.finish(),
}
}
}
impl ClientWrapper {
pub fn openai(client: Box<dyn LlmClient>) -> Self {
Self::OpenAi(client)
}
pub fn anthropic(client: Box<dyn LlmClient>) -> Self {
Self::Anthropic(client)
}
pub fn gemini(client: Box<dyn LlmClient>) -> Self {
Self::Gemini(client)
}
pub fn groq(client: Box<dyn LlmClient>) -> Self {
Self::Groq(client)
}
pub fn xai(client: Box<dyn LlmClient>) -> Self {
Self::XAI(client)
}
pub fn ollama(client: Box<dyn LlmClient>) -> Self {
Self::Ollama(client)
}
pub fn custom(client: Box<dyn LlmClient>) -> Self {
Self::Custom(client)
}
pub fn client(&self) -> &dyn LlmClient {
match self {
Self::OpenAi(client) => client.as_ref(),
Self::Anthropic(client) => client.as_ref(),
Self::Gemini(client) => client.as_ref(),
Self::Groq(client) => client.as_ref(),
Self::XAI(client) => client.as_ref(),
Self::Ollama(client) => client.as_ref(),
Self::Custom(client) => client.as_ref(),
}
}
pub fn provider_type(&self) -> ProviderType {
match self {
Self::OpenAi(_) => ProviderType::OpenAi,
Self::Anthropic(_) => ProviderType::Anthropic,
Self::Gemini(_) => ProviderType::Gemini,
Self::Groq(_) => ProviderType::Groq,
Self::XAI(_) => ProviderType::XAI,
Self::Ollama(_) => ProviderType::Ollama,
Self::Custom(_) => ProviderType::Custom("unknown".to_string()),
}
}
pub fn supports_capability(&self, capability: &str) -> bool {
self.client().capabilities().supports(capability)
}
pub fn get_capabilities(&self) -> ProviderCapabilities {
self.client().capabilities()
}
}
/// Chat operations are forwarded unchanged to the wrapped client.
#[async_trait::async_trait]
impl ChatCapability for ClientWrapper {
    /// Sends `messages` (with optional `tools`) through the wrapped client.
    async fn chat_with_tools(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatResponse, LlmError> {
        let inner = self.client();
        inner.chat_with_tools(messages, tools).await
    }

    /// Opens a streaming chat through the wrapped client.
    async fn chat_stream(
        &self,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
    ) -> Result<ChatStream, LlmError> {
        let inner = self.client();
        inner.chat_stream(messages, tools).await
    }
}
impl LlmClient for ClientWrapper {
fn provider_name(&self) -> &'static str {
self.client().provider_name()
}
fn supported_models(&self) -> Vec<String> {
self.client().supported_models()
}
fn capabilities(&self) -> ProviderCapabilities {
self.client().capabilities()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn clone_box(&self) -> Box<dyn LlmClient> {
Box::new(self.clone())
}
}
/// Settings used to configure a provider client.
///
/// `Debug` is implemented by hand so that `api_key` never appears in log lines or panic
/// messages that format the config with `{:?}`.
#[derive(Clone)]
pub struct ClientConfig {
    /// API key for the provider; redacted in `Debug` output.
    pub api_key: String,
    /// Base URL for the provider endpoint.
    pub base_url: String,
    pub http_config: HttpConfig,
    pub common_params: CommonParams,
    pub provider_params: ProviderParams,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            // Never print the credential itself.
            .field("api_key", &"[REDACTED]")
            .field("base_url", &self.base_url)
            .field("http_config", &self.http_config)
            .field("common_params", &self.common_params)
            .field("provider_params", &self.provider_params)
            .finish()
    }
}
impl ClientConfig {
    /// Creates a config from an API key and base URL. HTTP, common and provider parameters
    /// take their `Default` values.
    pub fn new(api_key: String, base_url: String) -> Self {
        let http_config = HttpConfig::default();
        let common_params = CommonParams::default();
        let provider_params = ProviderParams::default();
        Self {
            api_key,
            base_url,
            http_config,
            common_params,
            provider_params,
        }
    }

    /// Replaces the HTTP settings, returning the updated config.
    pub fn with_http_config(self, config: HttpConfig) -> Self {
        Self {
            http_config: config,
            ..self
        }
    }

    /// Replaces the provider-agnostic request parameters, returning the updated config.
    pub fn with_common_params(self, params: CommonParams) -> Self {
        Self {
            common_params: params,
            ..self
        }
    }

    /// Replaces the provider-specific parameters, returning the updated config.
    pub fn with_provider_params(self, params: ProviderParams) -> Self {
        Self {
            provider_params: params,
            ..self
        }
    }
}
/// Registry of [`ClientWrapper`]s keyed by a caller-chosen name.
pub struct ClientManager {
    /// Name -> client. At most one client per name.
    clients: std::collections::HashMap<String, ClientWrapper>,
}
impl ClientManager {
    /// Creates a manager with no registered clients.
    pub fn new() -> Self {
        Self {
            clients: std::collections::HashMap::new(),
        }
    }

    /// Registers `client` under `name`, replacing (and dropping) any client previously
    /// registered under the same name.
    pub fn add_client(&mut self, name: String, client: ClientWrapper) {
        self.clients.insert(name, client);
    }

    /// Looks up the client registered under `name`.
    pub fn get_client(&self, name: &str) -> Option<&ClientWrapper> {
        self.clients.get(name)
    }

    /// Unregisters and returns the client stored under `name`, if any.
    pub fn remove_client(&mut self, name: &str) -> Option<ClientWrapper> {
        self.clients.remove(name)
    }

    /// Names of all registered clients, sorted lexicographically.
    ///
    /// The backing `HashMap` has no defined iteration order. Sorting makes the output stable
    /// across runs.
    pub fn list_clients(&self) -> Vec<&String> {
        let mut names: Vec<&String> = self.clients.keys().collect();
        names.sort_unstable();
        names
    }

    /// The client whose name sorts first lexicographically, or `None` if none are registered.
    ///
    /// Picking by name, rather than taking the first `HashMap` entry, keeps the choice
    /// deterministic. `HashMap` iteration order is randomised per process.
    pub fn default_client(&self) -> Option<&ClientWrapper> {
        self.clients
            .iter()
            .min_by_key(|&(name, _)| name)
            .map(|(_, client)| client)
    }
}
impl Default for ClientManager {
fn default() -> Self {
Self::new()
}
}
/// Bounded stash of idle [`ClientWrapper`]s behind a mutex.
pub struct ClientPool {
    /// Idle clients; accessed only while holding the mutex.
    pool: std::sync::Arc<std::sync::Mutex<Vec<ClientWrapper>>>,
    /// Maximum number of idle clients kept; extra returned clients are not stored.
    max_size: usize,
}
impl ClientPool {
    /// Creates an empty pool that keeps at most `max_size` idle clients.
    pub fn new(max_size: usize) -> Self {
        Self {
            pool: std::sync::Arc::new(std::sync::Mutex::new(Vec::new())),
            max_size,
        }
    }

    /// Locks the idle-client list, recovering the guard if a previous holder panicked.
    ///
    /// Recovery is sound for the critical sections in this impl. Each is a single `Vec`
    /// push, pop or length read, so a poisoned lock does not imply a half-updated vector.
    /// Panicking on poison would instead make the pool permanently unusable.
    fn lock_pool(&self) -> std::sync::MutexGuard<'_, Vec<ClientWrapper>> {
        self.pool
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Takes the most recently returned idle client, or `None` if the pool is empty.
    pub fn get_client(&self) -> Option<ClientWrapper> {
        self.lock_pool().pop()
    }

    /// Puts `client` back into the pool if it has room.
    ///
    /// When the pool already holds `max_size` clients, `client` is dropped after the lock
    /// is released.
    pub fn return_client(&self, client: ClientWrapper) {
        let mut pool = self.lock_pool();
        if pool.len() < self.max_size {
            pool.push(client);
        }
    }

    /// Number of idle clients currently held.
    pub fn size(&self) -> usize {
        self.lock_pool().len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_client_manager() {
        let manager = ClientManager::new();
        assert_eq!(manager.list_clients().len(), 0);
        assert!(manager.default_client().is_none());
    }

    #[test]
    fn test_client_pool() {
        let pool = ClientPool::new(5);
        assert_eq!(pool.size(), 0);
        assert!(pool.get_client().is_none());
    }

    #[test]
    fn test_client_config() {
        let config = ClientConfig::new(
            "test-key".to_string(),
            "https://api.example.com".to_string(),
        );
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.base_url, "https://api.example.com");
    }

    /// Compile-time check: this test stops building if any listed type loses `Send` or
    /// `Sync`. The previous version built `Option<Arc<T>>::None`, which compiles for any `T`
    /// and so checked nothing.
    #[test]
    fn test_client_types_are_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ClientWrapper>();
        assert_send_sync::<ClientManager>();
        assert_send_sync::<ClientPool>();
    }

    #[tokio::test]
    async fn test_client_pool_multithreading() {
        use tokio::task;

        let pool = Arc::new(ClientPool::new(5));
        let handles: Vec<_> = (0..10usize)
            .map(|i| {
                let pool = Arc::clone(&pool);
                task::spawn(async move {
                    assert!(pool.get_client().is_none());
                    assert_eq!(pool.size(), 0);
                    i
                })
            })
            .collect();

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.expect("spawned pool task panicked"));
        }
        assert_eq!(results, (0..10usize).collect::<Vec<_>>());
    }
}