use std::pin::Pin;
use futures::stream::Stream;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use std::sync::Arc;
use std::collections::HashMap;
use crate::{UbiquityError, Result, retry::{with_retry, RetryConfig}, config::{LLMConfig, LLMProvider}};
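/// A provider-agnostic completion request: chat messages plus optional
/// sampling, stop, streaming, and provider-specific parameters.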
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMRequest {
pub messages: Vec<Message>,
pub temperature: Option<f32>,
pub max_tokens: Option<usize>,
pub stop_sequences: Option<Vec<String>>,
pub stream: bool,
pub system_prompt: Option<String>,
pub extra_params: Option<serde_json::Value>,
}
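/// A single chat message with its role and text content.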
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
}
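/// Chat roles, serialized in lowercase to match the wire formats of the
/// supported providers.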
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
}
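/// A completed (non-streaming) response, including token usage and the raw
/// provider payload in `metadata` when available.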
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMResponse {
pub content: String,
pub usage: Option<TokenUsage>,
pub model: String,
pub metadata: Option<serde_json::Value>,
}
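/// Token accounting reported by the provider for a single request.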
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
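/// An incremental piece of a streamed response; `usage` is only populated on
/// the final chunk, when the provider reports it.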
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
pub delta: String,
pub is_final: bool,
pub usage: Option<TokenUsage>,
}
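/// A pinned, boxed stream of response chunks returned by `LLMService::stream`.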
pub type LLMStream = Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send>>;
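/// Common interface implemented by every LLM backend (Claude, OpenAI, local,
/// and mock), covering one-shot completion, streaming, and health checks.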
#[async_trait]
pub trait LLMService: Send + Sync {
async fn complete(&self, request: LLMRequest) -> Result<LLMResponse>;
async fn stream(&self, request: LLMRequest) -> Result<LLMStream>;
fn provider(&self) -> LLMProvider;
async fn health_check(&self) -> Result<()>;
}
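/// Constructs the concrete `LLMService` implementation selected by
/// `LLMConfig::provider`.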
pub struct LLMServiceFactory;
impl LLMServiceFactory {
pub async fn create(config: &LLMConfig) -> Result<Arc<dyn LLMService>> {
match config.provider {
LLMProvider::Claude => {
let service = ClaudeLLMService::new(config.clone()).await?;
Ok(Arc::new(service))
}
LLMProvider::OpenAI => {
let service = OpenAILLMService::new(config.clone()).await?;
Ok(Arc::new(service))
}
LLMProvider::Local => {
let service = LocalLLMService::new(config.clone()).await?;
Ok(Arc::new(service))
}
LLMProvider::Mock => {
let service = MockLLMService::new(config.clone());
Ok(Arc::new(service))
}
}
}
}
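/// Backend for the Anthropic Messages API, with retrying HTTP calls.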
pub struct ClaudeLLMService {
config: LLMConfig,
client: reqwest::Client,
retry_config: RetryConfig,
}
impl ClaudeLLMService {
pub async fn new(config: LLMConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(config.timeout)
.build()
.map_err(|e| UbiquityError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
let retry_config = RetryConfig {
max_attempts: config.retry_attempts,
initial_delay: config.retry_delay,
..Default::default()
};
Ok(Self {
config,
client,
retry_config,
})
}
fn build_request_body(&self, request: &LLMRequest) -> serde_json::Value {
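        // NOTE: the Anthropic Messages API only accepts "user" and "assistant"
        // roles inside `messages`; system instructions should be passed via
        // `system_prompt`, which is mapped to the top-level `system` field below.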
let mut body = serde_json::json!({
"model": self.config.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content
})
}).collect::<Vec<_>>(),
"max_tokens": request.max_tokens.unwrap_or(self.config.max_tokens),
"temperature": request.temperature.unwrap_or(self.config.temperature),
"stream": request.stream
});
if let Some(stop) = &request.stop_sequences {
body["stop_sequences"] = serde_json::json!(stop);
}
if let Some(system) = &request.system_prompt {
body["system"] = serde_json::json!(system);
}
if let Some(extra) = &request.extra_params {
if let serde_json::Value::Object(map) = extra {
if let serde_json::Value::Object(body_map) = &mut body {
for (k, v) in map {
body_map.insert(k.clone(), v.clone());
}
}
}
}
body
}
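    /// Hypothetical sketch of how a single Anthropic SSE `data:` payload could be
    /// mapped to a `StreamChunk`. The event names below follow the documented
    /// `content_block_delta` / `message_stop` event types, but this helper is not
    /// wired into `stream` yet and is only illustrative.
    #[allow(dead_code)]
    fn parse_sse_data(data: &str) -> Option<StreamChunk> {
        let event: serde_json::Value = serde_json::from_str(data).ok()?;
        match event["type"].as_str()? {
            // Incremental text delta for the current content block.
            "content_block_delta" => Some(StreamChunk {
                delta: event["delta"]["text"].as_str().unwrap_or("").to_string(),
                is_final: false,
                usage: None,
            }),
            // End of the message: emit an empty final chunk.
            "message_stop" => Some(StreamChunk {
                delta: String::new(),
                is_final: true,
                usage: None,
            }),
            // Other event types (message_start, ping, ...) carry no text delta.
            _ => None,
        }
    }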
}
#[async_trait]
impl LLMService for ClaudeLLMService {
async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
let body = self.build_request_body(&request);
with_retry(&self.retry_config, "claude_complete", || async {
let response = self.client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", &self.config.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| UbiquityError::MeshError(format!("HTTP request failed: {}", e)))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(UbiquityError::MeshError(format!(
"Claude API error: {}",
error_text
)));
}
let response_body: serde_json::Value = response.json().await
.map_err(|e| UbiquityError::MeshError(format!("Failed to parse response: {}", e)))?;
let content = response_body["content"]
.as_array()
.and_then(|arr| arr.first())
.and_then(|c| c["text"].as_str())
.unwrap_or("")
.to_string();
            let usage = response_body["usage"].as_object().map(|u| {
                // `serde_json::Map` indexing panics on missing keys, so use `get`.
                let prompt = u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
                let completion = u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
                TokenUsage {
                    prompt_tokens: prompt as usize,
                    completion_tokens: completion as usize,
                    total_tokens: (prompt + completion) as usize,
                }
            });
Ok(LLMResponse {
content,
usage,
model: self.config.model.clone(),
metadata: Some(response_body),
})
}).await
}
async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream> {
request.stream = true;
let body = self.build_request_body(&request);
let response = self.client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", &self.config.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| UbiquityError::MeshError(format!("HTTP request failed: {}", e)))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(UbiquityError::MeshError(format!(
"Claude API error: {}",
error_text
)));
}
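        // NOTE: the request above only validates connectivity and credentials; the
        // SSE body is not parsed yet, so placeholder chunks are returned. See
        // `parse_sse_data` for a sketch of the event-to-chunk mapping.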
use futures::stream;
let mock_chunks = vec![
Ok(StreamChunk {
delta: "Streaming response implementation ".to_string(),
is_final: false,
usage: None,
}),
Ok(StreamChunk {
delta: "would parse SSE format here.".to_string(),
is_final: true,
usage: Some(TokenUsage {
prompt_tokens: 10,
completion_tokens: 8,
total_tokens: 18,
}),
}),
];
Ok(Box::pin(stream::iter(mock_chunks)))
}
fn provider(&self) -> LLMProvider {
LLMProvider::Claude
}
async fn health_check(&self) -> Result<()> {
if self.config.api_key.is_empty() {
return Err(UbiquityError::ConfigError("API key is empty".to_string()));
}
Ok(())
}
}
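/// Backend for the OpenAI Chat Completions API, with retrying HTTP calls.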
pub struct OpenAILLMService {
config: LLMConfig,
client: reqwest::Client,
retry_config: RetryConfig,
}
impl OpenAILLMService {
pub async fn new(config: LLMConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(config.timeout)
.build()
.map_err(|e| UbiquityError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
let retry_config = RetryConfig {
max_attempts: config.retry_attempts,
initial_delay: config.retry_delay,
..Default::default()
};
Ok(Self {
config,
client,
retry_config,
})
}
}
#[async_trait]
impl LLMService for OpenAILLMService {
async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
let body = serde_json::json!({
"model": self.config.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": match m.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
},
"content": m.content
})
}).collect::<Vec<_>>(),
"temperature": request.temperature.unwrap_or(self.config.temperature),
"max_tokens": request.max_tokens.unwrap_or(self.config.max_tokens),
"stream": request.stream
});
with_retry(&self.retry_config, "openai_complete", || async {
let response = self.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {}", self.config.api_key))
.json(&body)
.send()
.await
.map_err(|e| UbiquityError::MeshError(format!("HTTP request failed: {}", e)))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(UbiquityError::MeshError(format!(
"OpenAI API error: {}",
error_text
)));
}
let response_body: serde_json::Value = response.json().await
.map_err(|e| UbiquityError::MeshError(format!("Failed to parse response: {}", e)))?;
let content = response_body["choices"][0]["message"]["content"]
.as_str()
.unwrap_or("")
.to_string();
            let usage = response_body["usage"].as_object().map(|u| TokenUsage {
                // `serde_json::Map` indexing panics on missing keys, so use `get`.
                prompt_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
                completion_tokens: u.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
                total_tokens: u.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
            });
Ok(LLMResponse {
content,
usage,
model: self.config.model.clone(),
metadata: Some(response_body),
})
}).await
}
async fn stream(&self, mut request: LLMRequest) -> Result<LLMStream> {
request.stream = true;
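        // NOTE: streaming is not implemented for this backend yet; placeholder
        // chunks are returned instead of parsing the Chat Completions SSE stream.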
use futures::stream;
let mock_chunks = vec![
Ok(StreamChunk {
delta: "OpenAI streaming ".to_string(),
is_final: false,
usage: None,
}),
Ok(StreamChunk {
delta: "response.".to_string(),
is_final: true,
usage: Some(TokenUsage {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
}),
}),
];
Ok(Box::pin(stream::iter(mock_chunks)))
}
fn provider(&self) -> LLMProvider {
LLMProvider::OpenAI
}
async fn health_check(&self) -> Result<()> {
if self.config.api_key.is_empty() {
return Err(UbiquityError::ConfigError("API key is empty".to_string()));
}
Ok(())
}
}
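/// Backend for a locally hosted model served over an Ollama-style HTTP API
/// (`/api/generate`, `/api/tags`).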
pub struct LocalLLMService {
config: LLMConfig,
client: reqwest::Client,
base_url: String,
}
impl LocalLLMService {
pub async fn new(config: LLMConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(config.timeout)
.build()
.map_err(|e| UbiquityError::ConfigError(format!("Failed to create HTTP client: {}", e)))?;
        // For the local provider the `api_key` field is reused as the server base
        // URL; fall back to the default Ollama address when it is empty.
        let base_url = if config.api_key.is_empty() {
            "http://localhost:11434".to_string()
        } else {
            config.api_key.clone()
        };
Ok(Self {
config,
client,
base_url,
})
}
}
#[async_trait]
impl LLMService for LocalLLMService {
async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
let prompt = request.messages.iter()
.map(|m| format!("{}: {}",
match m.role {
MessageRole::System => "System",
MessageRole::User => "User",
MessageRole::Assistant => "Assistant",
},
m.content
))
.collect::<Vec<_>>()
.join("\n");
let body = serde_json::json!({
"model": self.config.model,
"prompt": prompt,
"temperature": request.temperature.unwrap_or(self.config.temperature),
"stream": false
});
let response = self.client
.post(format!("{}/api/generate", self.base_url))
.json(&body)
.send()
.await
.map_err(|e| UbiquityError::MeshError(format!("Local LLM request failed: {}", e)))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(UbiquityError::MeshError(format!(
"Local LLM error: {}",
error_text
)));
}
let response_body: serde_json::Value = response.json().await
.map_err(|e| UbiquityError::MeshError(format!("Failed to parse response: {}", e)))?;
let content = response_body["response"].as_str().unwrap_or("").to_string();
Ok(LLMResponse {
content,
            usage: None,
            model: self.config.model.clone(),
metadata: Some(response_body),
})
}
async fn stream(&self, _request: LLMRequest) -> Result<LLMStream> {
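        // NOTE: streaming is not implemented for this backend yet; placeholder
        // chunks are returned.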
use futures::stream;
let mock_chunks = vec![
Ok(StreamChunk {
delta: "Local model streaming ".to_string(),
is_final: false,
usage: None,
}),
Ok(StreamChunk {
delta: "response.".to_string(),
is_final: true,
                usage: None,
            }),
];
Ok(Box::pin(stream::iter(mock_chunks)))
}
fn provider(&self) -> LLMProvider {
LLMProvider::Local
}
async fn health_check(&self) -> Result<()> {
let response = self.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await
.map_err(|e| UbiquityError::MeshError(format!("Health check failed: {}", e)))?;
if !response.status().is_success() {
return Err(UbiquityError::MeshError("Local LLM server not available".to_string()));
}
Ok(())
}
}
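/// In-memory backend for tests: canned responses are looked up by the last
/// user message, with a generic fallback when no match is registered.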
pub struct MockLLMService {
config: LLMConfig,
responses: Arc<RwLock<HashMap<String, String>>>,
}
impl MockLLMService {
pub fn new(config: LLMConfig) -> Self {
Self {
config,
responses: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn add_response(&self, input: String, response: String) {
let mut responses = self.responses.write().await;
responses.insert(input, response);
}
}
#[async_trait]
impl LLMService for MockLLMService {
async fn complete(&self, request: LLMRequest) -> Result<LLMResponse> {
let last_message = request.messages.iter()
.filter(|m| m.role == MessageRole::User)
.last()
.map(|m| m.content.clone())
.unwrap_or_default();
let responses = self.responses.read().await;
let content = responses.get(&last_message)
.cloned()
.unwrap_or_else(|| format!("Mock response for: {}", last_message));
Ok(LLMResponse {
content,
usage: Some(TokenUsage {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
}),
model: "mock-model".to_string(),
metadata: None,
})
}
async fn stream(&self, request: LLMRequest) -> Result<LLMStream> {
use futures::stream;
let response = self.complete(request).await?;
let chunks: Vec<Result<StreamChunk>> = vec![
Ok(StreamChunk {
delta: response.content,
is_final: true,
usage: response.usage,
}),
];
Ok(Box::pin(stream::iter(chunks)))
}
fn provider(&self) -> LLMProvider {
LLMProvider::Mock
}
async fn health_check(&self) -> Result<()> {
Ok(())
}
}
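/// Registry of configured `LLMService` instances keyed by provider, with a
/// selectable default.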
pub struct LLMServiceManager {
services: HashMap<LLMProvider, Arc<dyn LLMService>>,
default_provider: LLMProvider,
}
impl LLMServiceManager {
pub fn new() -> Self {
Self {
services: HashMap::new(),
default_provider: LLMProvider::Claude,
}
}
pub fn add_service(&mut self, service: Arc<dyn LLMService>) {
let provider = service.provider();
self.services.insert(provider, service);
}
pub fn set_default_provider(&mut self, provider: LLMProvider) {
self.default_provider = provider;
}
pub fn get_service(&self, provider: LLMProvider) -> Option<Arc<dyn LLMService>> {
self.services.get(&provider).cloned()
}
pub fn get_default_service(&self) -> Option<Arc<dyn LLMService>> {
self.services.get(&self.default_provider).cloned()
}
pub async fn from_config(configs: Vec<LLMConfig>) -> Result<Self> {
let mut manager = Self::new();
for config in configs {
let service = LLMServiceFactory::create(&config).await?;
manager.add_service(service);
}
Ok(manager)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
#[tokio::test]
async fn test_mock_llm_service() {
let config = LLMConfig {
provider: LLMProvider::Mock,
api_key: String::new(),
model: "test-model".to_string(),
temperature: 0.7,
max_tokens: 100,
timeout: Duration::from_secs(30),
retry_attempts: 3,
retry_delay: Duration::from_secs(1),
};
let service = MockLLMService::new(config);
service.add_response("Hello".to_string(), "Hi there!".to_string()).await;
let request = LLMRequest {
messages: vec![
Message {
role: MessageRole::User,
content: "Hello".to_string(),
},
],
temperature: None,
max_tokens: None,
stop_sequences: None,
stream: false,
system_prompt: None,
extra_params: None,
};
let response = service.complete(request).await.unwrap();
assert_eq!(response.content, "Hi there!");
assert_eq!(response.model, "mock-model");
}
#[tokio::test]
async fn test_llm_service_factory() {
let config = LLMConfig {
provider: LLMProvider::Mock,
api_key: String::new(),
model: "test-model".to_string(),
temperature: 0.7,
max_tokens: 100,
timeout: Duration::from_secs(30),
retry_attempts: 3,
retry_delay: Duration::from_secs(1),
};
let service = LLMServiceFactory::create(&config).await.unwrap();
assert_eq!(service.provider(), LLMProvider::Mock);
}
}