use crate::load_balancer::tasks::TaskDefinition;
use crate::providers::types::{LlmRequest, LlmResponse, LlmStream, ProviderType, StreamChunk};
use crate::providers::anthropic::AnthropicInstance;
use crate::providers::openai::OpenAIInstance;
use crate::providers::ollama::OllamaInstance;
use crate::providers::google::GoogleInstance;
use crate::providers::mistral::MistralInstance;
use crate::providers::lmstudio::LMStudioInstance;
use crate::providers::groq::GroqInstance;
use crate::providers::cohere::CohereInstance;
use crate::providers::togetherai::TogetherAIInstance;
use crate::providers::perplexity::PerplexityInstance;
use crate::errors::{LlmResult};
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use std::time::Duration;
use reqwest::Client;
use futures::stream;
/// Common interface for every provider instance produced by `create_instance`.
#[async_trait]
pub trait LlmInstance {
    /// Sends `request` to the backing provider and returns the complete response.
    async fn generate(&self, request: &LlmRequest) -> LlmResult<LlmResponse>;

    /// Returns the response to `request` as a stream of chunks.
    ///
    /// The default implementation does not stream incrementally. It awaits
    /// [`LlmInstance::generate`] and yields the whole response as a single
    /// chunk marked `is_final`. Any error from `generate` is returned
    /// before a stream is created.
    async fn generate_stream(&self, request: &LlmRequest) -> LlmResult<LlmStream> {
        let full = self.generate(request).await?;

        // Move the response's pieces into the one-and-only chunk.
        let only_chunk = StreamChunk {
            is_final: true,
            model: Some(full.model),
            content: full.content,
            usage: full.usage,
        };

        let single = stream::once(async move { Ok(only_chunk) });
        Ok(Box::pin(single))
    }

    /// Reports whether this instance supports streaming. Defaults to `false`.
    ///
    /// NOTE(review): presumably implementors that override `generate_stream`
    /// with real incremental output return `true`; confirm in the provider impls.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Name of this instance.
    fn get_name(&self) -> &str;

    /// Model identifier this instance sends requests for.
    fn get_model(&self) -> &str;

    /// Tasks this instance can handle. Presumably keyed by task name, as
    /// `create_instance` builds it; verify for other constructors.
    fn get_supported_tasks(&self) -> &HashMap<String, TaskDefinition>;

    /// Whether this instance is enabled.
    fn is_enabled(&self) -> bool;
}
/// Shared state for a provider instance: name, HTTP client, API key, model
/// and task configuration.
///
/// NOTE(review): presumably embedded by the concrete provider instance types;
/// confirm against the provider modules.
pub struct BaseInstance {
    /// Name of the instance.
    name: String,
    /// HTTP client for provider requests, built with a total request timeout.
    client: Client,
    /// API key for the provider.
    api_key: String,
    /// Model identifier.
    model: String,
    /// Tasks this instance supports, as supplied by the caller.
    supported_tasks: HashMap<String, TaskDefinition>,
    /// Whether the instance is enabled.
    enabled: bool,
}

impl BaseInstance {
    /// Request timeout used by [`BaseInstance::new`].
    pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120);

    /// Creates an instance whose HTTP client uses [`BaseInstance::DEFAULT_TIMEOUT`].
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn new(
        name: String,
        api_key: String,
        model: String,
        supported_tasks: HashMap<String, TaskDefinition>,
        enabled: bool,
    ) -> Self {
        Self::with_timeout(
            name,
            api_key,
            model,
            supported_tasks,
            enabled,
            Self::DEFAULT_TIMEOUT,
        )
    }

    /// Creates an instance whose HTTP client uses `timeout` as its request timeout.
    ///
    /// reqwest applies this as a *total* timeout, from connecting until the
    /// response body has been fully read. Long generations or streamed
    /// responses therefore count against it, which is why it is configurable.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub fn with_timeout(
        name: String,
        api_key: String,
        model: String,
        supported_tasks: HashMap<String, TaskDefinition>,
        enabled: bool,
        timeout: Duration,
    ) -> Self {
        let client = Client::builder()
            .timeout(timeout)
            .build()
            .expect("Failed to create HTTP client");
        Self {
            name,
            client,
            api_key,
            model,
            supported_tasks,
            enabled,
        }
    }

    /// HTTP client shared by this instance's requests.
    pub fn client(&self) -> &Client {
        &self.client
    }

    /// API key supplied at construction.
    pub fn api_key(&self) -> &str {
        &self.api_key
    }

    /// Model identifier supplied at construction.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Whether the instance is enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Name supplied at construction.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Tasks supplied at construction.
    pub fn supported_tasks(&self) -> &HashMap<String, TaskDefinition> {
        &self.supported_tasks
    }
}
/// Builds a shared, thread-safe provider instance for `instance_type`.
///
/// `supported_tasks` is indexed by each task's `name`. If several tasks share
/// a name, the one appearing later in the vector replaces the earlier ones.
/// `endpoint_url` is forwarded only to the `Ollama` and `LMStudio` providers
/// and is ignored for every other provider type.
pub fn create_instance(
    instance_type: ProviderType,
    api_key: String,
    model: String,
    supported_tasks: Vec<TaskDefinition>,
    enabled: bool,
    endpoint_url: Option<String>,
) -> Arc<dyn LlmInstance + Send + Sync> {
    // Index tasks by name. Sequential inserts give the same last-wins result
    // as collecting into a HashMap.
    let mut tasks_by_name: HashMap<String, TaskDefinition> = HashMap::new();
    for task in supported_tasks {
        tasks_by_name.insert(task.name.clone(), task);
    }

    match instance_type {
        ProviderType::Anthropic => {
            Arc::new(AnthropicInstance::new(api_key, model, tasks_by_name, enabled))
        }
        ProviderType::OpenAI => {
            Arc::new(OpenAIInstance::new(api_key, model, tasks_by_name, enabled))
        }
        ProviderType::Mistral => {
            Arc::new(MistralInstance::new(api_key, model, tasks_by_name, enabled))
        }
        ProviderType::Google => {
            Arc::new(GoogleInstance::new(api_key, model, tasks_by_name, enabled))
        }
        // Local servers: the only providers that accept a custom endpoint.
        ProviderType::Ollama => Arc::new(OllamaInstance::new(
            api_key,
            model,
            tasks_by_name,
            enabled,
            endpoint_url,
        )),
        ProviderType::LMStudio => Arc::new(LMStudioInstance::new(
            api_key,
            model,
            tasks_by_name,
            enabled,
            endpoint_url,
        )),
        ProviderType::Groq => {
            Arc::new(GroqInstance::new(api_key, model, tasks_by_name, enabled))
        }
        ProviderType::Cohere => {
            Arc::new(CohereInstance::new(api_key, model, tasks_by_name, enabled))
        }
        ProviderType::TogetherAI => {
            Arc::new(TogetherAIInstance::new(api_key, model, tasks_by_name, enabled))
        }
        ProviderType::Perplexity => {
            Arc::new(PerplexityInstance::new(api_key, model, tasks_by_name, enabled))
        }
    }
}