rig_core/telemetry/
mod.rs1use crate::completion::GetTokenUsage;
7use serde::Serialize;
8
9pub trait ProviderRequestExt {
11 type InputMessage: Serialize;
13
14 fn get_input_messages(&self) -> Vec<Self::InputMessage>;
16 fn get_system_prompt(&self) -> Option<String>;
18 fn get_model_name(&self) -> String;
20 fn get_prompt(&self) -> Option<String>;
22}
23
24pub trait ProviderResponseExt {
26 type OutputMessage: Serialize;
28 type Usage: Serialize;
30
31 fn get_response_id(&self) -> Option<String>;
33
34 fn get_response_model_name(&self) -> Option<String>;
36
37 fn get_output_messages(&self) -> Vec<Self::OutputMessage>;
39
40 fn get_text_response(&self) -> Option<String>;
42
43 fn get_usage(&self) -> Option<Self::Usage>;
45}
46
47pub trait SpanCombinator {
50 fn record_token_usage<U>(&self, usage: &U)
52 where
53 U: GetTokenUsage;
54
55 fn record_response_metadata<R>(&self, response: &R)
57 where
58 R: ProviderResponseExt;
59
60 fn record_model_input<T>(&self, messages: &T)
62 where
63 T: Serialize;
64
65 fn record_model_output<T>(&self, messages: &T)
67 where
68 T: Serialize;
69}
70
71impl SpanCombinator for tracing::Span {
72 fn record_token_usage<U>(&self, usage: &U)
73 where
74 U: GetTokenUsage,
75 {
76 if self.is_disabled() {
77 return;
78 }
79
80 if let Some(usage) = usage.token_usage() {
81 self.record("gen_ai.usage.input_tokens", usage.input_tokens);
82 self.record("gen_ai.usage.output_tokens", usage.output_tokens);
83 self.record(
84 "gen_ai.usage.cache_read.input_tokens",
85 usage.cached_input_tokens,
86 );
87 self.record(
88 "gen_ai.usage.cache_creation.input_tokens",
89 usage.cache_creation_input_tokens,
90 );
91 self.record(
92 "gen_ai.usage.tool_use_prompt_tokens",
93 usage.tool_use_prompt_tokens,
94 );
95 self.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);
96 }
97 }
98
99 fn record_response_metadata<R>(&self, response: &R)
100 where
101 R: ProviderResponseExt,
102 {
103 if self.is_disabled() {
104 return;
105 }
106
107 if let Some(id) = response.get_response_id() {
108 self.record("gen_ai.response.id", id);
109 }
110
111 if let Some(model_name) = response.get_response_model_name() {
112 self.record("gen_ai.response.model", model_name);
113 }
114 }
115
116 fn record_model_input<T>(&self, input: &T)
117 where
118 T: Serialize,
119 {
120 if self.is_disabled() {
121 return;
122 }
123
124 if let Ok(input_as_json_string) = serde_json::to_string(input) {
125 self.record("gen_ai.input.messages", input_as_json_string);
126 }
127 }
128
129 fn record_model_output<T>(&self, output: &T)
130 where
131 T: Serialize,
132 {
133 if self.is_disabled() {
134 return;
135 }
136
137 if let Ok(output_as_json_string) = serde_json::to_string(output) {
138 self.record("gen_ai.output.messages", output_as_json_string);
139 }
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::*;
    use crate::completion::{GetTokenUsage, Usage};
    use std::sync::{Arc, Mutex};
    use tracing::field::{Field, Visit};
    use tracing::{Id, Subscriber};
    use tracing_subscriber::layer::{Context, SubscriberExt};
    use tracing_subscriber::{Layer, Registry, registry::LookupSpan};

    /// Wraps a fixed [`Usage`] so it can be passed to `record_token_usage`.
    #[derive(Clone)]
    struct TestUsage(Usage);

    impl GetTokenUsage for TestUsage {
        fn token_usage(&self) -> Option<Usage> {
            Some(self.0)
        }
    }

    /// Shared log of `(field name, u64 value)` pairs recorded on spans.
    #[derive(Clone, Default)]
    struct CapturedFields(Arc<Mutex<Vec<(String, u64)>>>);

    impl CapturedFields {
        fn push(&self, name: &str, value: u64) {
            if let Ok(mut fields) = self.0.lock() {
                fields.push((name.to_string(), value));
            }
        }

        /// Returns `true` if `name` was recorded with exactly `value`.
        fn contains(&self, name: &str, value: u64) -> bool {
            self.0.lock().is_ok_and(|fields| {
                // Compare in place instead of allocating a String per entry.
                fields.iter().any(|(recorded_name, recorded_value)| {
                    recorded_name == name && *recorded_value == value
                })
            })
        }
    }

    /// Layer that forwards every `Span::record` call to a [`FieldCaptureVisitor`].
    struct FieldCaptureLayer {
        fields: CapturedFields,
    }

    impl<S> Layer<S> for FieldCaptureLayer
    where
        S: Subscriber,
        S: for<'lookup> LookupSpan<'lookup>,
    {
        fn on_record(&self, _span: &Id, values: &tracing::span::Record<'_>, _ctx: Context<'_, S>) {
            values.record(&mut FieldCaptureVisitor {
                fields: self.fields.clone(),
            });
        }
    }

    /// Visitor that keeps only `u64` values; everything else is ignored.
    struct FieldCaptureVisitor {
        fields: CapturedFields,
    }

    impl Visit for FieldCaptureVisitor {
        fn record_u64(&mut self, field: &Field, value: u64) {
            self.fields.push(field.name(), value);
        }

        fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {}
    }

    /// Usage with a distinct value in every counter so each attribute can be checked.
    fn sample_usage() -> TestUsage {
        TestUsage(Usage {
            input_tokens: 1,
            output_tokens: 2,
            total_tokens: 15,
            cached_input_tokens: 3,
            cache_creation_input_tokens: 4,
            tool_use_prompt_tokens: 12,
            reasoning_tokens: 5,
        })
    }

    /// Calls `record_token_usage` on a span declaring every usage attribute and
    /// returns the values captured by the test layer.
    fn capture_token_usage(usage: &TestUsage) -> CapturedFields {
        let fields = CapturedFields::default();
        let subscriber = Registry::default().with(FieldCaptureLayer {
            fields: fields.clone(),
        });

        tracing::subscriber::with_default(subscriber, || {
            let span = tracing::info_span!(
                "usage_recording",
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.tool_use_prompt_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            );

            span.record_token_usage(usage);
        });

        fields
    }

    #[test]
    fn record_token_usage_records_tool_use_prompt_tokens() {
        let fields = capture_token_usage(&sample_usage());

        assert!(fields.contains("gen_ai.usage.tool_use_prompt_tokens", 12));
    }

    #[test]
    fn record_token_usage_records_every_usage_field() {
        let fields = capture_token_usage(&sample_usage());

        for (name, expected) in [
            ("gen_ai.usage.input_tokens", 1),
            ("gen_ai.usage.output_tokens", 2),
            ("gen_ai.usage.cache_read.input_tokens", 3),
            ("gen_ai.usage.cache_creation.input_tokens", 4),
            ("gen_ai.usage.tool_use_prompt_tokens", 12),
            ("gen_ai.usage.reasoning_tokens", 5),
        ] {
            assert!(
                fields.contains(name, expected),
                "expected {name} = {expected} to be recorded"
            );
        }
    }
}