use super::Provider;
use crate::error::AppError;
use crate::telemetry::provider_metrics::{MetricsExtractor, ProviderMetrics};
use async_trait::async_trait;
use axum::http::HeaderMap;
use serde_json::Value;
use std::time::Duration;
use tracing::{debug, error};
/// Provider implementation for Groq's OpenAI-compatible API.
pub struct GroqProvider {
    // Base endpoint; the OpenAI-compatible path prefix is part of the URL.
    base_url: String,
}

impl GroqProvider {
    /// Creates a provider pointing at the public Groq endpoint.
    pub fn new() -> Self {
        Self {
            base_url: "https://api.groq.com/openai".to_string(),
        }
    }
}

impl Default for GroqProvider {
    /// Defers to [`GroqProvider::new`] (satisfies clippy's `new_without_default`).
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl Provider for GroqProvider {
    /// Returns the configured Groq endpoint.
    fn base_url(&self) -> String {
        self.base_url.clone()
    }

    /// Stable provider identifier used for routing and telemetry.
    fn name(&self) -> &str {
        "groq"
    }

    /// Builds the outbound header set: a JSON content type plus the caller's
    /// `Authorization` header, which is required for every Groq request.
    ///
    /// Returns `AppError::MissingApiKey` when no authorization header is
    /// present and `AppError::InvalidHeader` when its value is malformed.
    fn process_headers(&self, original_headers: &HeaderMap) -> Result<HeaderMap, AppError> {
        debug!("Processing Groq request headers");

        // The client must supply its own Groq API key; there is no fallback.
        let auth = original_headers
            .get("authorization")
            .and_then(|h| h.to_str().ok())
            .ok_or_else(|| {
                error!("No authorization header found for Groq request");
                AppError::MissingApiKey
            })?;

        debug!("Using provided authorization header for Groq");
        let auth_value = http::header::HeaderValue::from_str(auth).map_err(|_| {
            error!("Failed to process Groq authorization header");
            AppError::InvalidHeader
        })?;

        let mut headers = HeaderMap::new();
        headers.insert(
            http::header::CONTENT_TYPE,
            http::header::HeaderValue::from_static("application/json"),
        );
        headers.insert(http::header::AUTHORIZATION, auth_value);
        Ok(headers)
    }
}
/// Extracts usage, latency, model and cost telemetry from Groq response
/// bodies and streaming chunks.
pub struct GroqMetricsExtractor;
impl MetricsExtractor for GroqMetricsExtractor {
fn extract_metrics(&self, response_body: &Value) -> ProviderMetrics {
debug!("Extracting Groq metrics from response: {}", response_body);
let mut metrics = ProviderMetrics::default();
if let Some(x_groq) = response_body.get("x_groq") {
if let Some(usage) = x_groq.get("usage") {
debug!("Found Groq usage data in x_groq: {:?}", usage);
metrics.input_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.output_tokens = usage.get("completion_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
if let Some(total_time) = usage.get("total_time").and_then(|v| v.as_f64()) {
metrics.provider_latency = Duration::from_secs_f64(total_time);
}
}
}
if metrics.total_tokens.is_none() {
if let Some(usage) = response_body.get("usage") {
debug!("Found Groq usage data at root level: {:?}", usage);
metrics.input_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.output_tokens = usage.get("completion_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
if let Some(total_time) = usage.get("total_time").and_then(|v| v.as_f64()) {
metrics.provider_latency = Duration::from_secs_f64(total_time);
}
}
}
if let Some(model) = response_body.get("model").and_then(|v| v.as_str()) {
debug!("Found Groq model: {}", model);
metrics.model = model.to_string();
}
if let (Some(total_tokens), Some(model)) = (metrics.total_tokens, response_body.get("model")) {
metrics.cost = Some(calculate_groq_cost(model.as_str().unwrap_or(""), total_tokens));
debug!("Calculated Groq cost: {:?} for model {} and {} tokens",
metrics.cost, metrics.model, total_tokens);
}
debug!("Final extracted Groq metrics: {:?}", metrics);
metrics
}
fn try_extract_provider_specific_streaming_metrics(&self, chunk: &str) -> Option<ProviderMetrics> {
debug!("Attempting to extract metrics from Groq streaming chunk");
if let Ok(json) = serde_json::from_str::<Value>(chunk) {
if let Some(x_groq) = json.get("x_groq") {
debug!("Found x_groq field in direct JSON: {}", json);
if let Some(usage) = x_groq.get("usage") {
let mut metrics = ProviderMetrics::default();
metrics.input_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.output_tokens = usage.get("completion_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
if let Some(total_time) = usage.get("total_time").and_then(|v| v.as_f64()) {
metrics.provider_latency = Duration::from_secs_f64(total_time);
}
if let Some(model) = json.get("model").and_then(|v| v.as_str()) {
metrics.model = model.to_string();
}
if let Some(total_tokens) = metrics.total_tokens {
metrics.cost = Some(calculate_groq_cost(&metrics.model, total_tokens));
debug!("Calculated Groq cost: {:?} for model {} and {} tokens",
metrics.cost, metrics.model, total_tokens);
}
debug!("Extracted complete Groq metrics from streaming chunk: {:?}", metrics);
return Some(metrics);
}
}
let is_final_chunk = json.get("choices")
.and_then(|c| c.as_array())
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("finish_reason"))
.and_then(|f| f.as_str())
.map(|reason| reason == "stop")
.unwrap_or(false);
let model = json.get("model").and_then(|v| v.as_str()).unwrap_or("llama").to_string();
let is_groq_response =
model.contains("llama") ||
model.contains("gemma") ||
json.get("object").and_then(|o| o.as_str()).map(|obj| obj == "chat.completion.chunk").unwrap_or(false);
if is_groq_response {
debug!("Groq streaming chunk detected for model: {}, is_final: {}", model, is_final_chunk);
return Some(ProviderMetrics {
model,
provider_latency: Duration::from_millis(0),
..Default::default()
});
}
}
for line in chunk.lines() {
if !line.starts_with("data: ") {
continue;
}
let json_str = line.trim_start_matches("data: ");
if json_str == "[DONE]" {
continue;
}
if let Ok(json) = serde_json::from_str::<Value>(json_str) {
if let Some(x_groq) = json.get("x_groq") {
debug!("Found x_groq field in SSE data: {}", json_str);
if let Some(usage) = x_groq.get("usage") {
let mut metrics = ProviderMetrics::default();
metrics.input_tokens = usage.get("prompt_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.output_tokens = usage.get("completion_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
metrics.total_tokens = usage.get("total_tokens").and_then(|v| v.as_u64()).map(|v| v as u32);
if let Some(total_time) = usage.get("total_time").and_then(|v| v.as_f64()) {
metrics.provider_latency = Duration::from_secs_f64(total_time);
}
if let Some(model) = json.get("model").and_then(|v| v.as_str()) {
metrics.model = model.to_string();
}
if let Some(total_tokens) = metrics.total_tokens {
metrics.cost = Some(calculate_groq_cost(&metrics.model, total_tokens));
}
debug!("Extracted complete Groq metrics from SSE streaming chunk: {:?}", metrics);
return Some(metrics);
}
}
let model = json.get("model").and_then(|v| v.as_str()).unwrap_or("llama").to_string();
let is_groq_response =
model.contains("llama") ||
model.contains("gemma") ||
json.get("object").and_then(|o| o.as_str()).map(|obj| obj == "chat.completion.chunk").unwrap_or(false);
if is_groq_response {
debug!("Groq SSE streaming chunk detected for model: {}", model);
return Some(ProviderMetrics {
model,
provider_latency: Duration::from_millis(0),
..Default::default()
});
}
}
}
debug!("No usage data found in Groq streaming chunk");
None
}
}
/// Estimates the cost of a Groq request from the model name and total token
/// count.
///
/// Arms must be ordered most-specific first because matching uses
/// `str::contains`: e.g. `"27b".contains("7b")` is true, so the 27B arm has to
/// precede the 7B arm. The previous ordering priced every Gemma 27B model at
/// the 7B rate; it also carried unreachable `llama-3.1` arms (shadowed by the
/// `llama-3` arms above them, with identical rates), removed here.
///
/// NOTE(review): the rates below are multiplied by the raw token count —
/// confirm the intended unit (per token vs. per 1K tokens) against Groq's
/// published pricing.
fn calculate_groq_cost(model: &str, total_tokens: u32) -> f64 {
    let tokens = total_tokens as f64;
    match model {
        // llama-3.x family ("llama-3.1…" also contains "llama-3").
        m if m.contains("llama-3") && m.contains("70b") => tokens * 0.0009,
        m if m.contains("llama-3") && m.contains("8b") => tokens * 0.0001,
        m if m.contains("llama-2") && m.contains("70b") => tokens * 0.0007,
        m if m.contains("llama-2") && m.contains("13b") => tokens * 0.0002,
        m if m.contains("llama-2") && m.contains("7b") => tokens * 0.0001,
        m if m.contains("mixtral-8x7b") => tokens * 0.0002,
        m if m.contains("mixtral-8x22b") => tokens * 0.0006,
        // 27B before 7B — see the ordering note in the doc comment.
        m if m.contains("gemma") && m.contains("27b") => tokens * 0.0004,
        m if m.contains("gemma") && m.contains("7b") => tokens * 0.0001,
        // Family-level fallbacks for unrecognized size variants.
        m if m.contains("mixtral") => tokens * 0.0002,
        m if m.contains("llama") => tokens * 0.0001,
        m if m.contains("gemma") => tokens * 0.0001,
        _ => {
            debug!("Unknown Groq model for cost calculation: {}", model);
            tokens * 0.0001
        },
    }
}