pub mod chat;
pub mod common_utils;

pub use chat::{LlamaChatHandler, LlamaChatTransformation};
pub use common_utils::{LlamaClient, LlamaConfig, LlamaUtils};

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;

use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{debug, info};

use crate::core::cost::CostCalculator;
use crate::core::cost::providers::generic::StubCostCalculator;
use crate::core::providers::base::HttpErrorMapper;
use crate::core::providers::unified_provider::ProviderError;
use crate::core::traits::{
    error_mapper::trait_def::ErrorMapper,
    provider::{llm_provider::trait_definition::LLMProvider, ProviderConfig},
};
use crate::core::types::{
    chat::ChatRequest,
    context::RequestContext,
    embedding::EmbeddingRequest,
    health::HealthStatus,
    image::ImageGenerationRequest,
    model::{ModelInfo, ProviderCapability},
    responses::{ChatChunk, ChatResponse, EmbeddingResponse, ImageGenerationResponse},
};
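/// Capabilities advertised by the Llama provider: chat completion (blocking
/// and streaming) plus tool calling.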
const LLAMA_CAPABILITIES: &[ProviderCapability] = &[
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
];
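/// Configuration for [`LlamaProvider`], deserializable from application
/// config. The optional cost calculator is skipped during (de)serialization
/// and falls back to a provider-scoped stub when absent.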
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlamaProviderConfig {
pub api_key: String,
pub api_base: Option<String>,
pub organization_id: Option<String>,
pub timeout: Option<u64>,
pub max_retries: Option<u32>,
pub headers: Option<HashMap<String, String>>,
pub supported_models: Vec<String>,
pub metadata: HashMap<String, String>,
#[serde(skip)]
pub cost_calculator: Option<StubCostCalculator>,
}
impl Default for LlamaProviderConfig {
fn default() -> Self {
Self {
api_key: String::new(),
api_base: Some("https://api.llama.com/compat/v1".to_string()),
organization_id: None,
timeout: Some(30),
max_retries: Some(3),
headers: Some(HashMap::new()),
supported_models: vec![
"llama4-scout".to_string(),
"llama4-maverick".to_string(),
"llama3.3-70b".to_string(),
"llama3.2-1b".to_string(),
"llama3.2-3b".to_string(),
"llama3.2-11b-vision".to_string(),
"llama3.2-90b-vision".to_string(),
"llama3.1-8b".to_string(),
"llama3.1-70b".to_string(),
"llama3.1-405b".to_string(),
],
metadata: HashMap::new(),
cost_calculator: None,
}
}
}
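/// Provider for Meta's Llama models, speaking the OpenAI-compatible API at
/// `https://api.llama.com/compat/v1` by default.
///
/// A minimal construction sketch (hypothetical caller code; a real API key
/// is required at runtime):
///
/// ```ignore
/// let config = LlamaProviderConfig {
///     api_key: "sk-...".to_string(),
///     ..Default::default()
/// };
/// let provider = LlamaProvider::new(config)?;
/// assert!(provider.is_model_supported("llama3.3-70b"));
/// ```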
#[derive(Debug, Clone)]
pub struct LlamaProvider {
config: Arc<LlamaProviderConfig>,
client: Arc<LlamaClient>,
chat_handler: Arc<LlamaChatHandler>,
cost_calculator: StubCostCalculator,
models: Vec<ModelInfo>,
}
impl LlamaProvider {
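    /// Builds the provider from `config`, constructing the HTTP client and
    /// chat handler, and falling back to a provider-scoped
    /// `StubCostCalculator` when no cost calculator is supplied.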
pub fn new(config: LlamaProviderConfig) -> Result<Self, ProviderError> {
let llama_config = LlamaConfig::from_provider_config(&config)?;
let client = LlamaClient::new(llama_config.clone())?;
let chat_handler = LlamaChatHandler::new(llama_config.clone())?;
let cost_calculator = config
.cost_calculator
.clone()
.unwrap_or_else(|| StubCostCalculator::new("meta_llama".to_string()));
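        // Static catalog surfaced through `models()`; covers a subset of the
        // ids listed in `supported_models`.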
let models = vec![
ModelInfo {
id: "llama4-scout".to_string(),
name: "Llama 4 Scout".to_string(),
provider: "meta".to_string(),
                max_context_length: 10_000_000,
                max_output_length: Some(128000),
                supports_streaming: true,
                supports_tools: true,
                supports_multimodal: true,
                input_cost_per_1k_tokens: Some(0.00008),
                output_cost_per_1k_tokens: Some(0.0003),
                currency: "USD".to_string(),
capabilities: vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "llama4-maverick".to_string(),
name: "Llama 4 Maverick".to_string(),
provider: "meta".to_string(),
                max_context_length: 1_000_000,
                max_output_length: Some(128000),
supports_streaming: true,
supports_tools: true,
supports_multimodal: true,
                input_cost_per_1k_tokens: Some(0.00020),
                output_cost_per_1k_tokens: Some(0.0006),
                currency: "USD".to_string(),
capabilities: vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "llama3.3-70b".to_string(),
name: "Llama 3.3 70B".to_string(),
provider: "meta".to_string(),
max_context_length: 128000,
max_output_length: Some(32000),
supports_streaming: true,
supports_tools: true,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.0006),
output_cost_per_1k_tokens: Some(0.0006),
currency: "USD".to_string(),
capabilities: vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "llama3.1-405b".to_string(),
name: "Llama 3.1 405B".to_string(),
provider: "meta".to_string(),
max_context_length: 128000,
max_output_length: None,
supports_streaming: true,
supports_tools: true,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.002),
output_cost_per_1k_tokens: Some(0.002),
currency: "USD".to_string(),
capabilities: vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "llama3.1-70b".to_string(),
name: "Llama 3.1 70B".to_string(),
provider: "meta".to_string(),
max_context_length: 128000,
max_output_length: None,
supports_streaming: true,
supports_tools: true,
supports_multimodal: false,
input_cost_per_1k_tokens: Some(0.001),
output_cost_per_1k_tokens: Some(0.001),
currency: "USD".to_string(),
capabilities: vec![
ProviderCapability::ChatCompletion,
ProviderCapability::ChatCompletionStream,
ProviderCapability::ToolCalling,
],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
];
Ok(Self {
config: Arc::new(config),
client: Arc::new(client),
chat_handler: Arc::new(chat_handler),
cost_calculator,
models,
})
}
pub fn get_api_base(&self) -> String {
self.config
.api_base
.clone()
.unwrap_or_else(|| "https://api.llama.com/compat/v1".to_string())
}
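    /// Returns true on an exact id match, or when the requested id contains
    /// a supported base id (e.g. suffixed variants).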
pub fn is_model_supported(&self, model: &str) -> bool {
self.config
.supported_models
.iter()
.any(|m| m == model || model.contains(m))
}
pub fn get_capabilities(&self) -> &'static [ProviderCapability] {
LLAMA_CAPABILITIES
}
}
impl LLMProvider for LlamaProvider {
fn name(&self) -> &'static str {
"meta"
}
async fn chat_completion(
&self,
request: ChatRequest,
_context: RequestContext,
) -> Result<ChatResponse, ProviderError> {
debug!("Llama chat completion request: model={}", request.model);
if !self.is_model_supported(&request.model) {
return Err(ProviderError::model_not_found(
"meta",
request.model.clone(),
));
}
let llama_request = self.chat_handler.transform_request(request)?;
let response = self
.client
.chat_completion(
llama_request,
Some(&self.config.api_key),
self.config.api_base.as_deref(),
self.config.headers.clone(),
)
.await?;
let chat_response = self.chat_handler.transform_response(response)?;
info!(
"Llama chat completion successful: model={}",
chat_response.model
);
Ok(chat_response)
}
async fn chat_completion_stream(
&self,
request: ChatRequest,
_context: RequestContext,
) -> Result<Pin<Box<dyn Stream<Item = Result<ChatChunk, ProviderError>> + Send>>, ProviderError>
{
debug!("Llama streaming chat request: model={}", request.model);
if !self.is_model_supported(&request.model) {
return Err(ProviderError::model_not_found(
"meta",
request.model.clone(),
));
}
let client = Arc::clone(&self.client);
let config = Arc::clone(&self.config);
let chat_handler = Arc::clone(&self.chat_handler);
let mut llama_request = chat_handler.transform_request(request)?;
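        // Force streaming mode on the outbound request body.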
if let serde_json::Value::Object(ref mut obj) = llama_request {
obj.insert("stream".to_string(), serde_json::Value::Bool(true));
}
let api_key = Some(config.api_key.clone());
let api_base = config.api_base.clone();
let headers = config.headers.clone();
let json_stream = client
.chat_completion_stream(
llama_request,
api_key.as_deref(),
api_base.as_deref(),
headers,
)
.await?;
use futures::stream::StreamExt;
let chunk_stream = json_stream.map(|result| {
result.map(|json| crate::core::types::responses::ChatChunk {
id: json
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
object: "chat.completion.chunk".to_string(),
created: json.get("created").and_then(|v| v.as_i64()).unwrap_or(0),
model: json
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string(),
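                // TODO: the raw `choices` array is dropped here, so streamed
                // token deltas never reach the caller; map `json["choices"]`
                // into typed chunk choices once the delta shape is settled.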
choices: vec![],
usage: None,
system_fingerprint: None,
})
});
Ok(Box::pin(chunk_stream))
}
async fn embeddings(
&self,
_request: EmbeddingRequest,
_context: RequestContext,
) -> Result<EmbeddingResponse, ProviderError> {
Err(ProviderError::not_implemented("meta", "embeddings"))
}
async fn image_generation(
&self,
_request: ImageGenerationRequest,
_context: RequestContext,
) -> Result<ImageGenerationResponse, ProviderError> {
Err(ProviderError::not_implemented("meta", "image generation"))
}
    async fn health_check(&self) -> HealthStatus {
        match self.client.check_health().await {
            Ok(status) => status,
            Err(e) => {
                debug!("Llama health check failed: {e}");
                HealthStatus::Unhealthy
            }
        }
    }
fn capabilities(&self) -> &'static [ProviderCapability] {
LLAMA_CAPABILITIES
}
fn models(&self) -> &[ModelInfo] {
&self.models
}
fn get_supported_openai_params(&self, _model: &str) -> &'static [&'static str] {
&[
"messages",
"model",
"max_tokens",
"temperature",
"top_p",
"n",
"stream",
"stop",
"presence_penalty",
"frequency_penalty",
"user",
"seed",
"response_format",
"tools",
"tool_choice",
]
}
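    /// The compat endpoint accepts OpenAI-style parameters unchanged, so
    /// this is a passthrough.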
async fn map_openai_params(
&self,
params: HashMap<String, Value>,
_model: &str,
) -> Result<HashMap<String, Value>, ProviderError> {
Ok(params)
}
async fn transform_request(
&self,
request: ChatRequest,
_context: RequestContext,
) -> Result<Value, ProviderError> {
self.chat_handler.transform_request(request)
}
async fn transform_response(
&self,
raw_response: &[u8],
_model: &str,
_request_id: &str,
) -> Result<ChatResponse, ProviderError> {
let response: Value = serde_json::from_slice(raw_response)?;
self.chat_handler
.transform_response(response)
.map_err(|e| ProviderError::serialization("meta", e.to_string()))
}
fn get_error_mapper(&self) -> Box<dyn ErrorMapper<ProviderError>> {
Box::new(UnifiedErrorMapper)
}
async fn calculate_cost(
&self,
_model: &str,
input_tokens: u32,
output_tokens: u32,
) -> Result<f64, ProviderError> {
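        // The stub calculator receives an empty model id; pricing here
        // depends only on the token counts.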
use crate::core::cost::types::UsageTokens;
let usage = UsageTokens::new(input_tokens, output_tokens);
let cost = self.cost_calculator.calculate_cost("", &usage).await?;
Ok(cost.total_cost)
}
}
impl ProviderConfig for LlamaProviderConfig {
fn validate(&self) -> Result<(), String> {
if self.api_key.is_empty() {
return Err("API key is required for Llama provider".to_string());
}
if let Some(timeout) = self.timeout
&& timeout == 0
{
return Err("Timeout must be greater than 0".to_string());
}
Ok(())
}
fn api_key(&self) -> Option<&str> {
if self.api_key.is_empty() {
None
} else {
Some(&self.api_key)
}
}
fn api_base(&self) -> Option<&str> {
self.api_base.as_deref()
}
fn timeout(&self) -> std::time::Duration {
std::time::Duration::from_secs(self.timeout.unwrap_or(30))
}
fn max_retries(&self) -> u32 {
self.max_retries.unwrap_or(3)
}
}
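/// Error mapper for the Llama provider: delegates HTTP status mapping to the
/// shared `HttpErrorMapper` and wraps JSON/network failures as provider
/// errors tagged "meta".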
#[derive(Debug, Clone)]
pub struct UnifiedErrorMapper;
impl ErrorMapper<ProviderError> for UnifiedErrorMapper {
fn map_http_error(&self, status_code: u16, response_body: &str) -> ProviderError {
HttpErrorMapper::map_status_code("meta", status_code, response_body)
}
    fn map_json_error(&self, error_response: &serde_json::Value) -> ProviderError {
        // Prefer the upstream message when the body carries an OpenAI-style
        // `error.message` field (an assumption for the compat endpoint).
        let msg = error_response["error"]["message"].as_str();
        ProviderError::response_parsing("meta", msg.unwrap_or("Failed to parse JSON error response"))
    }
fn map_network_error(&self, error: &dyn std::error::Error) -> ProviderError {
ProviderError::network("meta", error.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = LlamaProviderConfig::default();
assert_eq!(config.api_base.unwrap(), "https://api.llama.com/compat/v1");
assert_eq!(config.timeout.unwrap(), 30);
assert!(!config.supported_models.is_empty());
}
#[test]
fn test_model_support() {
let config = LlamaProviderConfig {
api_key: "test-api-key-1234567890123456".to_string(),
..Default::default()
};
let provider = LlamaProvider::new(config).unwrap();
assert!(provider.is_model_supported("llama4-scout"));
assert!(provider.is_model_supported("llama4-maverick"));
assert!(provider.is_model_supported("llama3.3-70b"));
assert!(provider.is_model_supported("llama3.1-8b"));
assert!(provider.is_model_supported("llama3.2-11b-vision"));
assert!(!provider.is_model_supported("gpt-4"));
}
#[test]
fn test_capabilities() {
let config = LlamaProviderConfig {
api_key: "test-api-key-1234567890123456".to_string(),
..Default::default()
};
let provider = LlamaProvider::new(config).unwrap();
let capabilities = provider.get_capabilities();
assert!(capabilities.contains(&ProviderCapability::ChatCompletion));
assert!(capabilities.contains(&ProviderCapability::ChatCompletionStream));
assert!(capabilities.contains(&ProviderCapability::ToolCalling));
assert!(!capabilities.contains(&ProviderCapability::ImageGeneration));
}
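    // Added sketch: exercises the documented api_base fallback in
    // `get_api_base` (assumes provider construction tolerates a `None`
    // api_base, as the fallback implies).
    #[test]
    fn test_api_base_fallback() {
        let config = LlamaProviderConfig {
            api_key: "test-api-key-1234567890123456".to_string(),
            api_base: None,
            ..Default::default()
        };
        let provider = LlamaProvider::new(config).unwrap();
        assert_eq!(provider.get_api_base(), "https://api.llama.com/compat/v1");
    }
    #[test]
    fn test_config_validation() {
        // Empty API key and zero timeout are the two rejection paths in
        // `ProviderConfig::validate`.
        let mut config = LlamaProviderConfig::default();
        assert!(config.validate().is_err());
        config.api_key = "test-api-key-1234567890123456".to_string();
        assert!(config.validate().is_ok());
        config.timeout = Some(0);
        assert!(config.validate().is_err());
    }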
}