//! Provider-agnostic building blocks for talking to AI chat models:
//! client configuration, message and conversation types, the [`AiClient`]
//! trait, and helpers for constructing clients and fanning a request out
//! to several providers at once.
use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::Client;
use std::time::Duration;
use tokio::sync::mpsc;
pub mod clients;
pub mod error;
pub mod http;
pub mod metrics;
pub mod middleware;
pub mod observability;
pub mod utils;
mod sse;
#[cfg(feature = "orchestration")]
pub mod orchestration;
#[cfg(feature = "prompt-optimization")]
pub mod prompt_optimizer;
#[cfg(any(feature = "mock", test))]
pub mod mock;
pub use clients::*;
pub use error::*;
pub use http::{HttpConfig, get_provider_client, SHARED_CLIENT};
pub use metrics::{ClientMetrics, MetricsSnapshot, RequestTimer};
pub use utils::{execute_with_retry, RetryStrategy};
#[cfg(feature = "orchestration")]
pub use orchestration::{AiOrchestrator, FusedResponse, OrchestrationStrategy, ModelCapabilities};
#[cfg(feature = "prompt-optimization")]
pub use prompt_optimizer::{PromptOptimizer, OptimizedPrompt};
#[cfg(feature = "mock")]
pub use mock::MockClient;
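/// Configuration shared by every provider client: request timeout, retry
/// count and strategy, and optional generation parameters (temperature,
/// `max_tokens`, `top_p`, penalties, system message, base URL override).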
#[derive(Debug, Clone)]
pub struct ClientConfig {
pub timeout: Duration,
pub retries: u32,
pub temperature: Option<f32>,
pub max_tokens: Option<u32>,
pub top_p: Option<f32>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub system_message: Option<String>,
pub base_url: Option<String>,
pub retry_strategy: RetryStrategy,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
timeout: Duration::from_secs(30),
retries: 0,
temperature: None,
max_tokens: Some(1024),
top_p: None,
frequency_penalty: None,
presence_penalty: None,
system_message: None,
base_url: None,
retry_strategy: RetryStrategy::default(),
}
}
}
impl ClientConfig {
pub fn builder() -> ClientConfigBuilder {
ClientConfigBuilder::default()
}
}
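/// Fluent builder returned by [`ClientConfig::builder`]. Any field left
/// unset falls back to the same value used by [`ClientConfig::default`]
/// when [`build`](ClientConfigBuilder::build) is called.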
#[derive(Debug, Default)]
pub struct ClientConfigBuilder {
timeout: Option<Duration>,
retries: Option<u32>,
temperature: Option<f32>,
max_tokens: Option<u32>,
top_p: Option<f32>,
frequency_penalty: Option<f32>,
presence_penalty: Option<f32>,
system_message: Option<String>,
base_url: Option<String>,
retry_strategy: Option<RetryStrategy>,
}
impl ClientConfigBuilder {
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn retries(mut self, retries: u32) -> Self {
self.retries = Some(retries);
self
}
pub fn temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn top_p(mut self, top_p: f32) -> Self {
self.top_p = Some(top_p);
self
}
pub fn frequency_penalty(mut self, penalty: f32) -> Self {
self.frequency_penalty = Some(penalty);
self
}
pub fn presence_penalty(mut self, penalty: f32) -> Self {
self.presence_penalty = Some(penalty);
self
}
pub fn system_message<S: Into<String>>(mut self, message: S) -> Self {
self.system_message = Some(message.into());
self
}
pub fn base_url<S: Into<String>>(mut self, url: S) -> Self {
self.base_url = Some(url.into());
self
}
pub fn retry_strategy(mut self, strategy: RetryStrategy) -> Self {
self.retry_strategy = Some(strategy);
self
}
pub fn build(self) -> ClientConfig {
ClientConfig {
timeout: self.timeout.unwrap_or(Duration::from_secs(30)),
retries: self.retries.unwrap_or(0),
temperature: self.temperature,
max_tokens: self.max_tokens.or(Some(1024)),
top_p: self.top_p,
frequency_penalty: self.frequency_penalty,
presence_penalty: self.presence_penalty,
system_message: self.system_message,
base_url: self.base_url,
retry_strategy: self.retry_strategy.unwrap_or_default(),
}
}
}
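/// A single chat message: a `role` of `"system"`, `"user"`, or
/// `"assistant"` plus its text content.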
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
}
impl Message {
pub fn system<S: Into<String>>(content: S) -> Self {
Self {
role: "system".to_string(),
content: content.into(),
}
}
pub fn user<S: Into<String>>(content: S) -> Self {
Self {
role: "user".to_string(),
content: content.into(),
}
}
pub fn assistant<S: Into<String>>(content: S) -> Self {
Self {
role: "assistant".to_string(),
content: content.into(),
}
}
}
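/// An ordered list of [`Message`]s with helpers for appending system,
/// user, and assistant turns.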
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct Conversation {
pub messages: Vec<Message>,
}
impl Conversation {
pub fn new() -> Self {
Self::default()
}
pub fn with_system<S: Into<String>>(system_message: S) -> Self {
Self {
messages: vec![Message::system(system_message)],
}
}
pub fn add_message(&mut self, message: Message) {
self.messages.push(message);
}
pub fn add_user<S: Into<String>>(&mut self, content: S) {
self.add_message(Message::user(content));
}
pub fn add_assistant<S: Into<String>>(&mut self, content: S) {
self.add_message(Message::assistant(content));
}
pub fn last_message(&self) -> Option<&Message> {
self.messages.last()
}
pub fn clear(&mut self) {
self.messages.clear();
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
}
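/// Optional usage and diagnostic details reported by a provider alongside
/// a response (model, token counts, finish reason, request id, latency).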
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct ResponseMetadata {
pub model_used: Option<String>,
pub prompt_tokens: Option<u32>,
pub completion_tokens: Option<u32>,
pub total_tokens: Option<u32>,
pub finish_reason: Option<String>,
pub safety_ratings: Option<serde_json::Value>,
pub request_id: Option<String>,
pub latency_ms: Option<u64>,
}
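/// A completed response: the generated text plus any [`ResponseMetadata`]
/// the provider returned.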
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AiResponse {
pub content: String,
pub metadata: ResponseMetadata,
}
impl AiResponse {
pub fn new(content: String) -> Self {
Self {
content,
metadata: ResponseMetadata::default(),
}
}
pub fn with_metadata(content: String, metadata: ResponseMetadata) -> Self {
Self { content, metadata }
}
}
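/// One incremental piece of a streamed response; `finished` marks the
/// final chunk, which may also carry metadata.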
#[derive(Debug, Clone)]
pub struct StreamChunk {
pub content: String,
pub finished: bool,
pub metadata: Option<ResponseMetadata>,
}
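/// A stateful chat wrapper that owns a boxed [`AiClient`] and the running
/// [`Conversation`]; `send` records both the user message and the
/// assistant reply, so history accumulates across calls.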
pub struct ChatSession {
client: Box<dyn AiClient>,
conversation: Conversation,
}
impl ChatSession {
pub fn new(client: Box<dyn AiClient>) -> Self {
Self {
client,
conversation: Conversation::new(),
}
}
pub fn with_system_message<S: Into<String>>(client: Box<dyn AiClient>, message: S) -> Self {
Self {
client,
conversation: Conversation::with_system(message),
}
}
pub async fn send<S: Into<String>>(&mut self, message: S) -> Result<String, ClientError> {
let user_msg = message.into();
self.conversation.add_user(user_msg);
let response = self.client.send_conversation(&self.conversation).await?;
self.conversation.add_assistant(&response);
Ok(response)
}
pub async fn send_with_metadata<S: Into<String>>(
&mut self,
message: S,
) -> Result<AiResponse, ClientError> {
let user_msg = message.into();
self.conversation.add_user(user_msg);
let response = self
.client
.send_conversation_with_metadata(&self.conversation)
.await?;
self.conversation.add_assistant(&response.content);
Ok(response)
}
pub async fn stream<S: Into<String>>(
&mut self,
message: S,
) -> Result<BoxStream<'_, Result<StreamChunk, ClientError>>, ClientError> {
let user_msg = message.into();
self.conversation.add_user(user_msg);
self.client.stream_conversation(&self.conversation).await
}
pub fn add_message(&mut self, message: Message) {
self.conversation.add_message(message);
}
pub fn history(&self) -> &Conversation {
&self.conversation
}
pub fn history_mut(&mut self) -> &mut Conversation {
&mut self.conversation
}
pub fn clear(&mut self) {
self.conversation.clear();
}
pub fn reset_with_system<S: Into<String>>(&mut self, message: S) {
self.conversation = Conversation::with_system(message);
}
pub fn load_history(&mut self, messages: Vec<Message>) {
self.conversation.messages = messages;
}
}
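/// Common interface implemented by every provider client. Only
/// `send_prompt`, `name`, and `model` are required; the remaining methods
/// have conservative default implementations that fall back to
/// `send_prompt`.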
#[async_trait]
pub trait AiClient: Send + Sync {
async fn send_prompt(&self, prompt: &str) -> Result<String, ClientError>;
async fn send_prompt_with_metadata(&self, prompt: &str) -> Result<AiResponse, ClientError> {
let content = self.send_prompt(prompt).await?;
Ok(AiResponse::new(content))
}
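    /// Sends a conversation. The default implementation returns a config
    /// error for an empty conversation and otherwise forwards the most
    /// recent user message (or the last message if no user turn exists)
    /// to `send_prompt`, discarding earlier history.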
async fn send_conversation(&self, conversation: &Conversation) -> Result<String, ClientError> {
let prompt = if conversation.messages.is_empty() {
return Err(ClientError::config("Empty conversation", None));
} else if conversation.messages.len() == 1 {
&conversation.messages[0].content
} else {
conversation
.messages
.iter()
.rev()
.find(|m| m.role == "user")
.map(|m| m.content.as_str())
.unwrap_or(&conversation.messages.last().unwrap().content)
};
self.send_prompt(prompt).await
}
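    /// Streams a response over an unbounded channel. The default
    /// implementation awaits `send_prompt` and emits the whole response
    /// as a single finished chunk.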
async fn send_prompt_streaming(
&self,
prompt: &str,
tx: mpsc::UnboundedSender<StreamChunk>,
) -> Result<(), ClientError> {
let response = self.send_prompt(prompt).await?;
tx.send(StreamChunk {
content: response,
finished: true,
metadata: None,
}).map_err(|_| ClientError::Stream(crate::StreamError {
message: "Failed to send stream chunk".into(),
error_type: crate::StreamErrorType::Other,
}))?;
Ok(())
}
async fn send_conversation_with_metadata(
&self,
conversation: &Conversation,
) -> Result<AiResponse, ClientError> {
let content = self.send_conversation(conversation).await?;
Ok(AiResponse::new(content))
}
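    /// Streams a prompt as a `BoxStream`. The default implementation
    /// awaits `send_prompt` and yields a one-element stream containing
    /// the complete response marked as finished.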
async fn stream_prompt(
&self,
        prompt: &str,
    ) -> Result<BoxStream<'_, Result<StreamChunk, ClientError>>, ClientError> {
        let response = self.send_prompt(prompt).await?;
let chunk = StreamChunk {
content: response,
finished: true,
metadata: None,
};
Ok(Box::pin(futures::stream::once(async { Ok(chunk) })))
}
async fn stream_conversation(
&self,
conversation: &Conversation,
) -> Result<BoxStream<'_, Result<StreamChunk, ClientError>>, ClientError> {
let response = self.send_conversation(conversation).await?;
let chunk = StreamChunk {
content: response,
finished: true,
metadata: None,
};
Ok(Box::pin(futures::stream::once(async { Ok(chunk) })))
}
fn supports_streaming(&self) -> bool {
false
}
fn supports_conversations(&self) -> bool {
false
}
fn name(&self) -> &str;
fn model(&self) -> &str;
}
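/// Builds a provider client from a provider name, API key, and model.
/// Accepted provider aliases are `openai`/`gpt`/`chatgpt`, `google`/`gemini`,
/// and `anthropic`/`claude`; anything else is a configuration error. The
/// underlying HTTP client is created with the configured timeout.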
pub fn create_client(
provider: &str,
api_key: &str,
model: &str,
config: ClientConfig,
) -> Result<Box<dyn AiClient>, ClientError> {
let http_client = Client::builder()
.timeout(config.timeout)
.build()
.map_err(|e| ClientError::config(format!("Failed to create HTTP client: {e}"), None))?;
match provider.to_lowercase().as_str() {
"openai" | "gpt" | "chatgpt" => Ok(Box::new(ChatGpt::new(
http_client,
api_key.to_string(),
model.to_string(),
config,
))),
"google" | "gemini" => Ok(Box::new(Gemini::new(
http_client,
api_key.to_string(),
model.to_string(),
config,
))),
"anthropic" | "claude" => Ok(Box::new(Claude::new(
http_client,
api_key.to_string(),
model.to_string(),
config,
))),
_ => Err(ClientError::config(
format!("Unknown provider: {provider}. Supported providers: openai, google, anthropic"),
Some("provider".to_string()),
)),
}
}
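/// Sends the same prompt to every client concurrently and returns
/// `(client name, result)` pairs in the order the clients were given.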
pub async fn execute_parallel(
clients: Vec<Box<dyn AiClient>>,
prompt: &str,
) -> Vec<(String, Result<String, ClientError>)> {
use futures::future;
let futures: Vec<_> = clients
.iter()
.map(|client| {
let name = client.name().to_string();
let prompt = prompt.to_string();
async move {
let result = client.send_prompt(&prompt).await;
(name, result)
}
})
.collect();
future::join_all(futures).await
}
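/// Like [`execute_parallel`], but sends a full [`Conversation`] to each
/// client instead of a single prompt.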
pub async fn execute_parallel_conversation(
clients: Vec<Box<dyn AiClient>>,
conversation: &Conversation,
) -> Vec<(String, Result<String, ClientError>)> {
use futures::future;
let futures: Vec<_> = clients
.iter()
.map(|client| {
let name = client.name().to_string();
let conversation = conversation.clone();
async move {
let result = client.send_conversation(&conversation).await;
(name, result)
}
})
.collect();
future::join_all(futures).await
}
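/// Asks `client` to summarize a set of named responses, highlighting
/// their key differences and commonalities.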
pub async fn generate_summary(
client: &dyn AiClient,
responses: &[(String, String)],
) -> Result<String, ClientError> {
let mut summary_prompt = "Given these AI model responses:\n".to_string();
for (name, response) in responses {
summary_prompt.push_str(&format!("{name}:\n{response}\n---\n"));
}
summary_prompt.push_str("Summarize the key differences and commonalities.");
client.send_prompt(&summary_prompt).await
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mock::MockClient;
#[test]
fn test_client_config_default() {
let config = ClientConfig::default();
assert_eq!(config.timeout, Duration::from_secs(30));
assert_eq!(config.retries, 0);
assert_eq!(config.temperature, None);
assert_eq!(config.max_tokens, Some(1024));
}
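    // A minimal sketch of builder usage; the timeout, retry count, and
    // temperature values here are illustrative, not recommended settings.
    #[test]
    fn test_client_config_builder_defaults() {
        let config = ClientConfig::builder()
            .timeout(Duration::from_secs(10))
            .retries(2)
            .temperature(0.7)
            .system_message("You are terse.")
            .build();
        assert_eq!(config.timeout, Duration::from_secs(10));
        assert_eq!(config.retries, 2);
        assert_eq!(config.temperature, Some(0.7));
        // Fields left unset keep the ClientConfig::default values.
        assert_eq!(config.max_tokens, Some(1024));
        assert_eq!(config.system_message.as_deref(), Some("You are terse."));
    }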
#[tokio::test]
async fn test_execute_parallel() {
let clients: Vec<Box<dyn AiClient>> = vec![
Box::new(MockClient::new(
"client1",
vec![Ok("response1".to_string())],
)),
Box::new(MockClient::new(
"client2",
vec![Ok("response2".to_string())],
)),
];
let results = execute_parallel(clients, "test prompt").await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, "client1");
assert!(results[0].1.is_ok());
assert_eq!(results[1].0, "client2");
assert!(results[1].1.is_ok());
}
#[tokio::test]
async fn test_generate_summary() {
let client = MockClient::new("summarizer", vec![Ok("summary response".to_string())]);
let responses = vec![
("AI1".to_string(), "response1".to_string()),
("AI2".to_string(), "response2".to_string()),
];
let summary = generate_summary(&client, &responses).await;
assert!(summary.is_ok());
assert_eq!(summary.unwrap(), "summary response");
}
#[tokio::test]
async fn test_execute_parallel_conversation() {
let clients: Vec<Box<dyn AiClient>> = vec![
Box::new(MockClient::new(
"client1",
vec![Ok("conversation response1".to_string())],
)),
Box::new(MockClient::new(
"client2",
vec![Ok("conversation response2".to_string())],
)),
];
let mut conversation = Conversation::new();
conversation.add_user("Hello");
conversation.add_assistant("Hi there!");
conversation.add_user("How are you?");
let results = execute_parallel_conversation(clients, &conversation).await;
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, "client1");
assert!(results[0].1.is_ok());
assert_eq!(results[1].0, "client2");
assert!(results[1].1.is_ok());
}
#[tokio::test]
async fn test_mock_client_conversation_support() {
let client = MockClient::new("test", vec![Ok("conversation test".to_string())]);
assert!(client.supports_conversations());
assert!(!client.supports_streaming());
let mut conversation = Conversation::new();
conversation.add_user("Test message");
let result = client.send_conversation(&conversation).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), "conversation test");
}
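    // A minimal sketch of ChatSession over the mock backend; assumes, as in
    // the tests above, that MockClient replays its queued responses in order.
    #[tokio::test]
    async fn test_chat_session_records_history() {
        let client = MockClient::new("session", vec![Ok("hello back".to_string())]);
        let mut session = ChatSession::with_system_message(Box::new(client), "You are helpful.");
        let reply = session.send("Hello").await.expect("mock send should succeed");
        assert_eq!(reply, "hello back");
        // system + user + assistant messages are now all recorded
        assert_eq!(session.history().len(), 3);
        assert_eq!(
            session.history().last_message().map(|m| m.role.as_str()),
            Some("assistant")
        );
    }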
}