//! Metrics middleware: wraps an [`LlmProvider`] and records call counts,
//! latency, token usage, error counts, and estimated USD cost, with optional
//! OpenTelemetry export behind the `otel` feature.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use async_trait::async_trait;
use tracing::{info, warn};
use crate::types::{
ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider, RunnerError,
};
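/// Rough fallback heuristic used when a provider omits usage data: assume
/// about four characters per token, a common rule of thumb for English text.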
const CHARS_PER_TOKEN_ESTIMATE: u32 = 4;
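/// Per-model token pricing, expressed in USD per 1,000 tokens.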
#[derive(Debug, Clone)]
pub struct TokenPricing {
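/// USD per 1,000 prompt (input) tokens.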
pub prompt_price_per_1k: f64,
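/// USD per 1,000 completion (output) tokens.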
pub completion_price_per_1k: f64,
}
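/// Maps a model identifier (or a recognizable substring of one) to pricing.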
pub type PricingTable = HashMap<String, TokenPricing>;
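/// Builds a pricing table for a handful of well-known models. The rates are
/// hard-coded snapshots and will drift over time; treat them as illustrative
/// defaults and supply your own table via [`MetricsProvider::with_pricing`]
/// when accuracy matters.
///
/// ```ignore
/// // Sketch: extend the defaults with a hypothetical custom model id.
/// let mut table = default_pricing_table();
/// table.insert(
///     "my-fine-tune".to_owned(),
///     TokenPricing {
///         prompt_price_per_1k: 0.002,
///         completion_price_per_1k: 0.008,
///     },
/// );
/// // `inner` is any Box<dyn LlmProvider> you already have.
/// let metered = MetricsProvider::new(inner).with_pricing(table);
/// ```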
pub fn default_pricing_table() -> PricingTable {
let mut table = PricingTable::new();
table.insert(
"opus".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.015,
completion_price_per_1k: 0.075,
},
);
table.insert(
"sonnet".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.003,
completion_price_per_1k: 0.015,
},
);
table.insert(
"haiku".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.00025,
completion_price_per_1k: 0.00125,
},
);
table.insert(
"gpt-5.4".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.005,
completion_price_per_1k: 0.015,
},
);
table.insert(
"gpt-4o".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.005,
completion_price_per_1k: 0.015,
},
);
table.insert(
"gemini-2.5-pro".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.00125,
completion_price_per_1k: 0.005,
},
);
table.insert(
"gemini-2.5-flash".to_owned(),
TokenPricing {
prompt_price_per_1k: 0.000_075,
completion_price_per_1k: 0.0003,
},
);
table
}
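/// Mutable counters guarded by a mutex; [`MetricsReport`] is the read-only
/// snapshot exposed to callers.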
#[derive(Debug, Default)]
struct MetricsState {
call_count: u64,
total_latency_ms: u64,
total_prompt_tokens: u64,
total_completion_tokens: u64,
total_tokens: u64,
errors_count: u64,
total_cost: f64,
}
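/// Point-in-time snapshot of accumulated metrics. Note that `avg_latency_ms`
/// uses integer division, so sub-millisecond averages round down to zero.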
#[derive(Debug, Clone)]
pub struct MetricsReport {
pub provider_name: String,
pub call_count: u64,
pub total_latency_ms: u64,
pub avg_latency_ms: u64,
pub total_prompt_tokens: u64,
pub total_completion_tokens: u64,
pub total_tokens: u64,
pub errors_count: u64,
pub total_cost: f64,
}
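/// OpenTelemetry instruments, built once per wrapper rather than per call.
/// Only compiled when the `otel` feature is enabled.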
#[cfg(feature = "otel")]
struct OtelInstruments {
requests_total: opentelemetry::metrics::Counter<u64>,
requests_duration_ms: opentelemetry::metrics::Histogram<f64>,
tokens_prompt: opentelemetry::metrics::Counter<u64>,
tokens_completion: opentelemetry::metrics::Counter<u64>,
errors_total: opentelemetry::metrics::Counter<u64>,
cost_total: opentelemetry::metrics::Counter<f64>,
}
#[cfg(feature = "otel")]
impl OtelInstruments {
fn new() -> Self {
let meter = opentelemetry::global::meter("embacle");
Self {
requests_total: meter
.u64_counter("embacle.requests.total")
.with_description("Total LLM requests")
.build(),
requests_duration_ms: meter
.f64_histogram("embacle.requests.duration_ms")
.with_description("Request duration in milliseconds")
.build(),
tokens_prompt: meter
.u64_counter("embacle.tokens.prompt")
.with_description("Total prompt tokens consumed")
.build(),
tokens_completion: meter
.u64_counter("embacle.tokens.completion")
.with_description("Total completion tokens generated")
.build(),
errors_total: meter
.u64_counter("embacle.errors.total")
.with_description("Total error count")
.build(),
cost_total: meter
.f64_counter("embacle.cost.total")
.with_description("Total cost in USD")
.build(),
}
}
}
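/// Decorator that wraps any [`LlmProvider`], delegating every call while
/// accumulating call counts, latency, token usage, errors, and cost.
///
/// ```ignore
/// // Illustrative sketch; `SomeProvider` is a hypothetical inner provider.
/// let metered = MetricsProvider::new(Box::new(SomeProvider::default()))
///     .with_default_pricing();
/// let response = metered.complete(&request).await?;
/// let report = metered.report();
/// println!("{} calls, ${:.4} total", report.call_count, report.total_cost);
/// ```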
pub struct MetricsProvider {
inner: Box<dyn LlmProvider>,
state: Arc<Mutex<MetricsState>>,
pricing: Option<PricingTable>,
#[cfg(feature = "otel")]
otel: OtelInstruments,
}
impl MetricsProvider {
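/// Wraps `inner` with zeroed counters and no pricing table, so cost stays
/// at 0.0 until pricing is attached.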
pub fn new(inner: Box<dyn LlmProvider>) -> Self {
Self {
inner,
state: Arc::new(Mutex::new(MetricsState::default())),
pricing: None,
#[cfg(feature = "otel")]
otel: OtelInstruments::new(),
}
}
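/// Sets the pricing table used for cost attribution.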
pub fn with_pricing(mut self, pricing: PricingTable) -> Self {
self.pricing = Some(pricing);
self
}
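/// Convenience for `with_pricing(default_pricing_table())`.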
pub fn with_default_pricing(self) -> Self {
self.with_pricing(default_pricing_table())
}
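/// Snapshots the counters. The `max(1)` divisor means a fresh wrapper
/// reports an average latency of zero instead of dividing by zero.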
pub fn report(&self) -> MetricsReport {
let state = self.state.lock().expect("metrics lock poisoned");
let divisor = state.call_count.max(1);
MetricsReport {
provider_name: self.inner.name().to_owned(),
call_count: state.call_count,
total_latency_ms: state.total_latency_ms,
avg_latency_ms: state.total_latency_ms / divisor,
total_prompt_tokens: state.total_prompt_tokens,
total_completion_tokens: state.total_completion_tokens,
total_tokens: state.total_tokens,
errors_count: state.errors_count,
total_cost: state.total_cost,
}
}
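/// Zeroes every counter, including accumulated cost.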
pub fn reset(&self) {
let mut state = self.state.lock().expect("metrics lock poisoned");
*state = MetricsState::default();
}
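/// Resolves pricing by exact model name first, then falls back to the first
/// table key that appears as a substring of the model name (so a model id
/// like "claude-3-opus-20240229" can match the "opus" entry). Because the
/// fallback iterates a `HashMap`, the winner is nondeterministic when more
/// than one key matches. Cost per side is `tokens * price_per_1k / 1000`;
/// e.g. 1000 prompt + 500 completion tokens at "opus" rates come to
/// 1000 * 0.015 / 1000 + 500 * 0.075 / 1000 = 0.0525 USD.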
fn compute_cost(&self, model: &str, prompt_tokens: u64, completion_tokens: u64) -> f64 {
let Some(table) = &self.pricing else {
return 0.0;
};
let pricing = table.get(model).or_else(|| {
table
.iter()
.find(|(key, _)| model.contains(key.as_str()))
.map(|(_, v)| v)
});
let Some(pricing) = pricing else {
return 0.0;
};
#[allow(clippy::cast_precision_loss)]
let cost = (prompt_tokens as f64 * pricing.prompt_price_per_1k / 1000.0)
+ (completion_tokens as f64 * pricing.completion_price_per_1k / 1000.0);
cost
}
}
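/// Fallback token estimate: byte length divided by the chars-per-token
/// heuristic (the `max(1)` guards against the constant being set to zero).
/// `len()` counts bytes, not characters, so multi-byte text skews high.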
fn estimate_tokens(text: &str) -> u32 {
#[allow(clippy::cast_possible_truncation)]
let len = text.len() as u32;
len / CHARS_PER_TOKEN_ESTIMATE.max(1)
}
#[async_trait]
impl LlmProvider for MetricsProvider {
fn name(&self) -> &'static str {
self.inner.name()
}
fn display_name(&self) -> &str {
self.inner.display_name()
}
fn capabilities(&self) -> LlmCapabilities {
self.inner.capabilities()
}
fn default_model(&self) -> &str {
self.inner.default_model()
}
fn available_models(&self) -> &[String] {
self.inner.available_models()
}
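/// Delegates to the inner provider, then records latency, token usage
/// (reported by the provider when available, estimated otherwise), errors,
/// and cost.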
async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
let start = Instant::now();
let result = self.inner.complete(request).await;
#[allow(clippy::cast_possible_truncation)]
let elapsed_ms = start.elapsed().as_millis() as u64;
let mut state = self.state.lock().expect("metrics lock poisoned");
state.call_count += 1;
state.total_latency_ms += elapsed_ms;
#[cfg(feature = "otel")]
let provider_attr = opentelemetry::KeyValue::new("provider", self.inner.name());
if let Ok(response) = &result {
let usage = response.usage.as_ref();
let prompt_tokens = u64::from(
usage.map_or_else(|| estimate_prompt_tokens(request), |u| u.prompt_tokens),
);
let completion_tokens = u64::from(usage.map_or_else(
|| estimate_tokens(&response.content),
|u| u.completion_tokens,
));
let total = prompt_tokens + completion_tokens;
state.total_prompt_tokens += prompt_tokens;
state.total_completion_tokens += completion_tokens;
state.total_tokens += total;
let cost = self.compute_cost(&response.model, prompt_tokens, completion_tokens);
state.total_cost += cost;
info!(
provider = self.inner.name(),
elapsed_ms, prompt_tokens, completion_tokens, cost, "metrics: complete() succeeded"
);
#[cfg(feature = "otel")]
{
let attrs = std::slice::from_ref(&provider_attr);
self.otel.requests_total.add(1, attrs);
#[allow(clippy::cast_precision_loss)]
self.otel
.requests_duration_ms
.record(elapsed_ms as f64, attrs);
self.otel.tokens_prompt.add(prompt_tokens, attrs);
self.otel.tokens_completion.add(completion_tokens, attrs);
if cost > 0.0 {
self.otel.cost_total.add(cost, attrs);
}
}
} else {
state.errors_count += 1;
warn!(
provider = self.inner.name(),
elapsed_ms, "metrics: complete() failed"
);
#[cfg(feature = "otel")]
{
self.otel
.errors_total
.add(1, std::slice::from_ref(&provider_attr));
}
}
drop(state);
result
}
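/// Streaming currently passes straight through without recording metrics;
/// usage is only knowable once the stream is drained, and this wrapper does
/// not intercept chunks.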
async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError> {
self.inner.complete_stream(request).await
}
async fn health_check(&self) -> Result<bool, RunnerError> {
self.inner.health_check().await
}
}
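/// Sums message content lengths and applies the same chars-per-token
/// heuristic as [`estimate_tokens`].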
fn estimate_prompt_tokens(request: &ChatRequest) -> u32 {
let total_chars: usize = request.messages.iter().map(|m| m.content.len()).sum();
#[allow(clippy::cast_possible_truncation)]
let len = total_chars as u32;
len / CHARS_PER_TOKEN_ESTIMATE.max(1)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{
ChatMessage, ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider,
RunnerError, TokenUsage,
};
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
struct TestProvider {
responses: Mutex<Vec<Result<ChatResponse, RunnerError>>>,
call_count: AtomicU32,
}
impl TestProvider {
fn new(responses: Vec<Result<ChatResponse, RunnerError>>) -> Self {
Self {
responses: Mutex::new(responses),
call_count: AtomicU32::new(0),
}
}
}
#[async_trait]
impl LlmProvider for TestProvider {
fn name(&self) -> &'static str {
"test"
}
fn display_name(&self) -> &str {
"Test Provider"
}
fn capabilities(&self) -> LlmCapabilities {
LlmCapabilities::text_only()
}
fn default_model(&self) -> &'static str {
"test-model"
}
fn available_models(&self) -> &[String] {
&[]
}
async fn complete(&self, _request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
self.call_count.fetch_add(1, Ordering::SeqCst);
let mut responses = self.responses.lock().expect("test lock");
if responses.is_empty() {
Ok(ChatResponse {
content: "default".to_owned(),
model: "test-model".to_owned(),
usage: None,
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})
} else {
responses.remove(0)
}
}
async fn complete_stream(&self, _request: &ChatRequest) -> Result<ChatStream, RunnerError> {
Err(RunnerError::internal("streaming not supported in test"))
}
async fn health_check(&self) -> Result<bool, RunnerError> {
Ok(true)
}
}
#[test]
fn fresh_report_is_zeroed() {
let provider = TestProvider::new(vec![]);
let metered = MetricsProvider::new(Box::new(provider));
let report = metered.report();
assert_eq!(report.call_count, 0);
assert_eq!(report.total_latency_ms, 0);
assert_eq!(report.avg_latency_ms, 0);
assert_eq!(report.total_prompt_tokens, 0);
assert_eq!(report.total_completion_tokens, 0);
assert_eq!(report.total_tokens, 0);
assert_eq!(report.errors_count, 0);
assert!(report.total_cost == 0.0);
assert_eq!(report.provider_name, "test");
}
#[tokio::test]
async fn call_count_increments() {
let provider = TestProvider::new(vec![
Ok(ChatResponse {
content: "hello world".to_owned(),
model: "test-model".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
}),
Ok(ChatResponse {
content: "second".to_owned(),
model: "test-model".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 8,
completion_tokens: 3,
total_tokens: 11,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
}),
]);
let metered = MetricsProvider::new(Box::new(provider));
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("first call");
metered.complete(&request).await.expect("second call");
let report = metered.report();
assert_eq!(report.call_count, 2);
assert_eq!(report.total_prompt_tokens, 18);
assert_eq!(report.total_completion_tokens, 8);
assert_eq!(report.total_tokens, 26);
assert_eq!(report.errors_count, 0);
}
#[tokio::test]
async fn errors_count_on_failure() {
let provider = TestProvider::new(vec![Err(RunnerError::external_service("test", "boom"))]);
let metered = MetricsProvider::new(Box::new(provider));
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
let result = metered.complete(&request).await;
assert!(result.is_err());
let report = metered.report();
assert_eq!(report.call_count, 1);
assert_eq!(report.errors_count, 1);
}
#[tokio::test]
async fn token_estimation_when_no_usage() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "abcdefghijklmnop".to_owned(), model: "test-model".to_owned(),
usage: None,
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider));
let request = ChatRequest::new(vec![ChatMessage::user("12345678")]);
metered.complete(&request).await.expect("call");
let report = metered.report();
assert_eq!(report.total_prompt_tokens, 2);
assert_eq!(report.total_completion_tokens, 4);
assert_eq!(report.total_tokens, 6);
}
#[test]
fn div_by_zero_guard_on_avg_latency() {
let provider = TestProvider::new(vec![]);
let metered = MetricsProvider::new(Box::new(provider));
let report = metered.report();
assert_eq!(report.avg_latency_ms, 0);
}
#[tokio::test]
async fn reset_zeroes_counters() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "hello".to_owned(),
model: "test-model".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 5,
completion_tokens: 2,
total_tokens: 7,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider));
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("call");
assert_eq!(metered.report().call_count, 1);
metered.reset();
let report = metered.report();
assert_eq!(report.call_count, 0);
assert_eq!(report.total_tokens, 0);
assert_eq!(report.errors_count, 0);
assert!(report.total_cost == 0.0);
}
#[tokio::test]
async fn cost_with_known_model() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "response".to_owned(),
model: "opus".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 1000,
completion_tokens: 500,
total_tokens: 1500,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider)).with_default_pricing();
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("call");
let report = metered.report();
assert!((report.total_cost - 0.0525).abs() < 1e-10);
}
#[tokio::test]
async fn cost_with_unknown_model() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "response".to_owned(),
model: "some-unknown-model".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 1000,
completion_tokens: 500,
total_tokens: 1500,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider)).with_default_pricing();
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("call");
let report = metered.report();
assert!(report.total_cost == 0.0);
}
#[tokio::test]
async fn cost_accumulates() {
let provider = TestProvider::new(vec![
Ok(ChatResponse {
content: "r1".to_owned(),
model: "opus".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 1000,
completion_tokens: 500,
total_tokens: 1500,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
}),
Ok(ChatResponse {
content: "r2".to_owned(),
model: "opus".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 2000,
completion_tokens: 1000,
total_tokens: 3000,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
}),
]);
let metered = MetricsProvider::new(Box::new(provider)).with_default_pricing();
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("call 1");
metered.complete(&request).await.expect("call 2");
let report = metered.report();
assert!((report.total_cost - 0.1575).abs() < 1e-10);
}
#[tokio::test]
async fn cost_without_pricing() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "response".to_owned(),
model: "opus".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 1000,
completion_tokens: 500,
total_tokens: 1500,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider));
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("call");
let report = metered.report();
assert!(report.total_cost == 0.0);
}
#[tokio::test]
async fn cost_with_estimated_tokens() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "abcdefghijklmnop".to_owned(), model: "opus".to_owned(),
usage: None,
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider)).with_default_pricing();
let request = ChatRequest::new(vec![ChatMessage::user("12345678")]);
metered.complete(&request).await.expect("call");
let report = metered.report();
assert!(report.total_cost > 0.0);
assert!((report.total_cost - 0.00033).abs() < 1e-10);
}
#[test]
fn default_pricing_populated() {
let table = default_pricing_table();
assert!(table.contains_key("opus"));
assert!(table.contains_key("sonnet"));
assert!(table.contains_key("haiku"));
assert!(table.contains_key("gpt-5.4"));
assert!(table.contains_key("gemini-2.5-pro"));
assert!(table.contains_key("gemini-2.5-flash"));
assert!(table.len() >= 7);
}
#[tokio::test]
async fn reset_zeroes_cost() {
let provider = TestProvider::new(vec![Ok(ChatResponse {
content: "response".to_owned(),
model: "opus".to_owned(),
usage: Some(TokenUsage {
prompt_tokens: 1000,
completion_tokens: 500,
total_tokens: 1500,
}),
finish_reason: Some("stop".to_owned()),
warnings: None,
tool_calls: None,
})]);
let metered = MetricsProvider::new(Box::new(provider)).with_default_pricing();
let request = ChatRequest::new(vec![ChatMessage::user("hi")]);
metered.complete(&request).await.expect("call");
assert!(metered.report().total_cost > 0.0);
metered.reset();
assert!(metered.report().total_cost == 0.0);
}
}