use crate::completion::GetTokenUsage;
use serde::Serialize;
/// Read accessors exposed by a provider-specific completion request.
///
/// NOTE(review): no caller is visible in this module; presumably these feed
/// request data into telemetry alongside [`SpanCombinator`] — confirm.
pub trait ProviderRequestExt {
    /// Provider-specific message type returned by
    /// [`get_input_messages`](Self::get_input_messages); must be serializable.
    type InputMessage: Serialize;

    /// Returns the model name associated with this request.
    fn get_model_name(&self) -> String;

    /// Returns the system prompt, or `None` when the request has none.
    fn get_system_prompt(&self) -> Option<String>;

    /// Returns the prompt, or `None` when the request has none.
    fn get_prompt(&self) -> Option<String>;

    /// Returns the request's input messages in provider-specific form.
    fn get_input_messages(&self) -> Vec<Self::InputMessage>;
}
/// Read accessors exposed by a provider-specific completion response.
///
/// [`SpanCombinator::record_response_metadata`] on `tracing::Span` reads the
/// response id and model name through this trait.
pub trait ProviderResponseExt {
    /// Provider-specific message type returned by
    /// [`get_output_messages`](Self::get_output_messages); must be serializable.
    type OutputMessage: Serialize;

    /// Provider-specific usage type returned by [`get_usage`](Self::get_usage);
    /// must be serializable.
    type Usage: Serialize;

    /// Returns the response identifier, if the provider supplied one.
    /// Recorded as `gen_ai.response.id` by the `tracing::Span` implementation.
    fn get_response_id(&self) -> Option<String>;

    /// Returns the model name reported in the response, if any.
    /// Recorded as `gen_ai.response.model` by the `tracing::Span` implementation.
    fn get_response_model_name(&self) -> Option<String>;

    /// Returns the response's output messages in provider-specific form.
    fn get_output_messages(&self) -> Vec<Self::OutputMessage>;

    /// Returns the response text, or `None` when there is none.
    fn get_text_response(&self) -> Option<String>;

    /// Returns the provider's usage report, or `None` when absent.
    fn get_usage(&self) -> Option<Self::Usage>;
}
/// Extension methods for attaching GenAI request/response telemetry to a span.
pub trait SpanCombinator {
    /// Records the token counts reported by `usage` on this span.
    fn record_token_usage<U: GetTokenUsage>(&self, usage: &U);

    /// Records identifying metadata (response id, response model name) taken
    /// from `response` on this span.
    fn record_response_metadata<R: ProviderResponseExt>(&self, response: &R);

    /// Records `messages` as the model input.
    ///
    /// The `tracing::Span` implementation stores them as a JSON string.
    fn record_model_input<T: Serialize>(&self, messages: &T);

    /// Records `messages` as the model output.
    ///
    /// The `tracing::Span` implementation stores them as a JSON string.
    fn record_model_output<T: Serialize>(&self, messages: &T);
}
/// [`SpanCombinator`] for `tracing` spans.
///
/// Every method first checks [`tracing::Span::is_disabled`] and returns without
/// reading its argument when the span is disabled, so callers can invoke these
/// helpers unconditionally.
impl SpanCombinator for tracing::Span {
    fn record_token_usage<U>(&self, usage: &U)
    where
        U: GetTokenUsage,
    {
        if self.is_disabled() {
            return;
        }
        let Some(counts) = usage.token_usage() else {
            return;
        };

        // `total_tokens` is not recorded on the span; only the counts below are.
        self.record("gen_ai.usage.input_tokens", counts.input_tokens);
        self.record("gen_ai.usage.output_tokens", counts.output_tokens);
        self.record(
            "gen_ai.usage.cache_read.input_tokens",
            counts.cached_input_tokens,
        );
        self.record(
            "gen_ai.usage.cache_creation.input_tokens",
            counts.cache_creation_input_tokens,
        );
        self.record(
            "gen_ai.usage.tool_use_prompt_tokens",
            counts.tool_use_prompt_tokens,
        );
        self.record("gen_ai.usage.reasoning_tokens", counts.reasoning_tokens);
    }

    fn record_response_metadata<R>(&self, response: &R)
    where
        R: ProviderResponseExt,
    {
        if !self.is_disabled() {
            if let Some(response_id) = response.get_response_id() {
                self.record("gen_ai.response.id", response_id);
            }
            if let Some(model) = response.get_response_model_name() {
                self.record("gen_ai.response.model", model);
            }
        }
    }

    fn record_model_input<T>(&self, input: &T)
    where
        T: Serialize,
    {
        record_json_field(self, "gen_ai.input.messages", input);
    }

    fn record_model_output<T>(&self, output: &T)
    where
        T: Serialize,
    {
        record_json_field(self, "gen_ai.output.messages", output);
    }
}

/// Serializes `value` with `serde_json` and records the resulting string on
/// `span` under `field`.
///
/// Returns before serializing when `span` is disabled. A serialization error
/// is discarded and nothing is recorded.
fn record_json_field<T>(span: &tracing::Span, field: &str, value: &T)
where
    T: Serialize,
{
    if span.is_disabled() {
        return;
    }
    let Ok(json) = serde_json::to_string(value) else {
        return;
    };
    span.record(field, json);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::completion::{GetTokenUsage, Usage};
    use std::sync::{Arc, Mutex};
    use tracing::field::{Field, Visit};
    use tracing::{Id, Subscriber};
    use tracing_subscriber::layer::{Context, SubscriberExt};
    use tracing_subscriber::{Layer, Registry, registry::LookupSpan};

    /// Wraps a fixed [`Usage`] so it can be passed to `record_token_usage`.
    #[derive(Clone)]
    struct TestUsage(Usage);

    impl GetTokenUsage for TestUsage {
        fn token_usage(&self) -> Option<Usage> {
            Some(self.0)
        }
    }

    /// Shared log of `(field name, value)` pairs observed by [`FieldCaptureLayer`].
    #[derive(Clone, Default)]
    struct CapturedFields(Arc<Mutex<Vec<(String, u64)>>>);

    impl CapturedFields {
        /// Appends one recorded field; if the lock is poisoned the entry is dropped.
        fn push(&self, name: &str, value: u64) {
            if let Ok(mut fields) = self.0.lock() {
                fields.push((name.to_string(), value));
            }
        }

        /// Returns `true` if `name` was recorded with exactly `value`.
        fn contains(&self, name: &str, value: u64) -> bool {
            // Compare against the borrowed `name` rather than allocating a
            // fresh `String` for every captured entry.
            self.0.lock().is_ok_and(|fields| {
                fields
                    .iter()
                    .any(|(recorded, recorded_value)| recorded == name && *recorded_value == value)
            })
        }
    }

    /// Layer that forwards every `Span::record` call to [`FieldCaptureVisitor`].
    struct FieldCaptureLayer {
        fields: CapturedFields,
    }

    impl<S> Layer<S> for FieldCaptureLayer
    where
        S: Subscriber,
        S: for<'lookup> LookupSpan<'lookup>,
    {
        fn on_record(&self, _span: &Id, values: &tracing::span::Record<'_>, _ctx: Context<'_, S>) {
            values.record(&mut FieldCaptureVisitor {
                fields: self.fields.clone(),
            });
        }
    }

    /// Visitor that keeps integer-valued fields and ignores everything else.
    struct FieldCaptureVisitor {
        fields: CapturedFields,
    }

    impl Visit for FieldCaptureVisitor {
        fn record_u64(&mut self, field: &Field, value: u64) {
            self.fields.push(field.name(), value);
        }

        // Capture non-negative signed values too, so assertions do not depend
        // on whether a token count is stored as a signed or unsigned integer.
        fn record_i64(&mut self, field: &Field, value: i64) {
            if let Ok(value) = u64::try_from(value) {
                self.fields.push(field.name(), value);
            }
        }

        fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
    }

    /// Usage with a distinct value in every field so each recorded field can be
    /// told apart.
    fn sample_usage() -> TestUsage {
        TestUsage(Usage {
            input_tokens: 1,
            output_tokens: 2,
            total_tokens: 15,
            cached_input_tokens: 3,
            cache_creation_input_tokens: 4,
            tool_use_prompt_tokens: 12,
            reasoning_tokens: 5,
        })
    }

    /// Calls `record_token_usage` on a span that declares every
    /// `gen_ai.usage.*` field and returns the values the capture layer observed.
    fn capture_token_usage(usage: &TestUsage) -> CapturedFields {
        let fields = CapturedFields::default();
        let subscriber = Registry::default().with(FieldCaptureLayer {
            fields: fields.clone(),
        });
        tracing::subscriber::with_default(subscriber, || {
            let span = tracing::info_span!(
                "usage_recording",
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            );
            span.record_token_usage(usage);
        });
        fields
    }

    #[test]
    fn record_token_usage_records_tool_use_prompt_tokens() {
        let fields = capture_token_usage(&sample_usage());
        assert!(fields.contains("gen_ai.usage.tool_use_prompt_tokens", 12));
    }

    #[test]
    fn record_token_usage_records_every_usage_field() {
        let fields = capture_token_usage(&sample_usage());
        let expected = [
            ("gen_ai.usage.input_tokens", 1),
            ("gen_ai.usage.output_tokens", 2),
            ("gen_ai.usage.cache_read.input_tokens", 3),
            ("gen_ai.usage.cache_creation.input_tokens", 4),
            ("gen_ai.usage.tool_use_prompt_tokens", 12),
            ("gen_ai.usage.reasoning_tokens", 5),
        ];
        for (name, value) in expected {
            assert!(
                fields.contains(name, value),
                "expected `{name}` to be recorded as {value}"
            );
        }
    }
}