1use crate::{
46 EmbeddingProvider, EmbeddingRequest, EmbeddingResponse, LlmProvider, LlmRequest, LlmResponse,
47 Result,
48};
49use async_trait::async_trait;
50use std::time::Instant;
51
/// Request-side attributes captured for a traced LLM or embedding call.
#[derive(Debug, Clone)]
pub struct SpanAttributes {
    // Name of the backing provider (e.g. "openai").
    pub provider: String,
    // Model identifier the request targets.
    pub model: String,
    // Prompt text (first input text for embeddings), truncated for tracing;
    // None when an embedding request had no texts.
    pub prompt: Option<String>,
    // Optional system prompt, truncated; always None for embedding requests.
    pub system_prompt: Option<String>,
    // Sampling temperature, if set on the request.
    pub temperature: Option<f64>,
    // Completion token cap, if set on the request.
    pub max_tokens: Option<u32>,
    // Number of tools attached to the request (0 for embeddings).
    pub tools_count: usize,
    // Number of images attached to the request (0 for embeddings).
    pub images_count: usize,
}
72
73impl SpanAttributes {
74 pub fn from_request(provider: &str, model: &str, request: &LlmRequest) -> Self {
76 Self {
77 provider: provider.to_string(),
78 model: model.to_string(),
79 prompt: Some(Self::truncate(&request.prompt, 500)),
80 system_prompt: request
81 .system_prompt
82 .as_ref()
83 .map(|s| Self::truncate(s, 500)),
84 temperature: request.temperature,
85 max_tokens: request.max_tokens,
86 tools_count: request.tools.len(),
87 images_count: request.images.len(),
88 }
89 }
90
91 pub fn from_embedding_request(provider: &str, model: &str, request: &EmbeddingRequest) -> Self {
93 Self {
94 provider: provider.to_string(),
95 model: model.to_string(),
96 prompt: request.texts.first().map(|t| Self::truncate(t, 500)),
97 system_prompt: None,
98 temperature: None,
99 max_tokens: None,
100 tools_count: 0,
101 images_count: 0,
102 }
103 }
104
105 fn truncate(s: &str, max_len: usize) -> String {
106 if s.len() <= max_len {
107 s.to_string()
108 } else {
109 format!("{}...", &s[..max_len])
110 }
111 }
112}
113
/// Response-side attributes captured for a traced LLM or embedding call,
/// covering both success and error outcomes.
#[derive(Debug, Clone)]
pub struct ResponseAttributes {
    // Truncated response content (a summary string for embeddings; empty
    // on error).
    pub content: String,
    // Prompt token count from usage, if the provider reported usage.
    pub prompt_tokens: Option<u32>,
    // Completion token count; None for embeddings and errors.
    pub completion_tokens: Option<u32>,
    // Total token count from usage, if reported.
    pub total_tokens: Option<u32>,
    // Number of tool calls in the response (0 for embeddings/errors).
    pub tool_calls_count: usize,
    // Wall-clock latency of the provider call in milliseconds.
    pub latency_ms: u64,
    // True when the underlying call succeeded.
    pub success: bool,
    // Error message when the call failed; None on success.
    pub error: Option<String>,
}
134
135impl ResponseAttributes {
136 pub fn from_response(response: &LlmResponse, latency_ms: u64) -> Self {
138 Self {
139 content: Self::truncate(&response.content, 500),
140 prompt_tokens: response.usage.as_ref().map(|u| u.prompt_tokens),
141 completion_tokens: response.usage.as_ref().map(|u| u.completion_tokens),
142 total_tokens: response.usage.as_ref().map(|u| u.total_tokens),
143 tool_calls_count: response.tool_calls.len(),
144 latency_ms,
145 success: true,
146 error: None,
147 }
148 }
149
150 pub fn from_embedding_response(response: &EmbeddingResponse, latency_ms: u64) -> Self {
152 Self {
153 content: format!("{} embeddings", response.embeddings.len()),
154 prompt_tokens: response.usage.as_ref().map(|u| u.prompt_tokens),
155 completion_tokens: None,
156 total_tokens: response.usage.as_ref().map(|u| u.total_tokens),
157 tool_calls_count: 0,
158 latency_ms,
159 success: true,
160 error: None,
161 }
162 }
163
164 pub fn from_error(error: &str, latency_ms: u64) -> Self {
166 Self {
167 content: String::new(),
168 prompt_tokens: None,
169 completion_tokens: None,
170 total_tokens: None,
171 tool_calls_count: 0,
172 latency_ms,
173 success: false,
174 error: Some(error.to_string()),
175 }
176 }
177
178 fn truncate(s: &str, max_len: usize) -> String {
179 if s.len() <= max_len {
180 s.to_string()
181 } else {
182 format!("{}...", &s[..max_len])
183 }
184 }
185}
186
/// A single trace record pairing request attributes with the (eventual)
/// response attributes for one provider call.
#[derive(Debug, Clone)]
pub struct TraceEvent {
    // Logical span name, e.g. "llm.complete" or "embedding.embed".
    pub span_name: String,
    // Attributes captured before the call was made.
    pub request_attrs: SpanAttributes,
    // Attributes captured after the call; None until attached via
    // `with_response`.
    pub response_attrs: Option<ResponseAttributes>,
}
197
198impl TraceEvent {
199 pub fn new(span_name: String, request_attrs: SpanAttributes) -> Self {
201 Self {
202 span_name,
203 request_attrs,
204 response_attrs: None,
205 }
206 }
207
208 pub fn with_response(mut self, response_attrs: ResponseAttributes) -> Self {
210 self.response_attrs = Some(response_attrs);
211 self
212 }
213}
214
/// Decorator around an [`LlmProvider`] that emits a trace event (via
/// `tracing` and an optional callback) for every completion call.
pub struct OtelProvider {
    // The wrapped provider that actually performs completions.
    inner: Box<dyn LlmProvider>,
    // Provider label recorded on every span (e.g. "openai").
    provider_name: String,
    // Model label recorded on every span.
    model_name: String,
    // Optional sink invoked with each emitted TraceEvent.
    trace_callback: Option<Box<dyn Fn(TraceEvent) + Send + Sync>>,
}
225
226impl OtelProvider {
227 pub fn new(inner: Box<dyn LlmProvider>, provider_name: String, model_name: String) -> Self {
229 Self {
230 inner,
231 provider_name,
232 model_name,
233 trace_callback: None,
234 }
235 }
236
237 pub fn with_trace_callback<F>(mut self, callback: F) -> Self
241 where
242 F: Fn(TraceEvent) + Send + Sync + 'static,
243 {
244 self.trace_callback = Some(Box::new(callback));
245 self
246 }
247
248 fn emit_trace(&self, event: TraceEvent) {
249 tracing::info!(
251 provider = %event.request_attrs.provider,
252 model = %event.request_attrs.model,
253 success = event.response_attrs.as_ref().map(|r| r.success).unwrap_or(false),
254 latency_ms = event.response_attrs.as_ref().map(|r| r.latency_ms).unwrap_or(0),
255 "LLM request trace"
256 );
257
258 if let Some(callback) = &self.trace_callback {
259 callback(event);
260 }
261 }
262}
263
264#[async_trait]
265impl LlmProvider for OtelProvider {
266 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
267 let start = Instant::now();
268 let span_attrs =
269 SpanAttributes::from_request(&self.provider_name, &self.model_name, &request);
270
271 match self.inner.complete(request).await {
272 Ok(response) => {
273 let latency_ms = start.elapsed().as_millis() as u64;
274 let response_attrs = ResponseAttributes::from_response(&response, latency_ms);
275
276 let trace_event = TraceEvent::new("llm.complete".to_string(), span_attrs)
277 .with_response(response_attrs);
278
279 self.emit_trace(trace_event);
280
281 Ok(response)
282 }
283 Err(e) => {
284 let latency_ms = start.elapsed().as_millis() as u64;
285 let response_attrs = ResponseAttributes::from_error(&e.to_string(), latency_ms);
286
287 let trace_event = TraceEvent::new("llm.complete".to_string(), span_attrs)
288 .with_response(response_attrs);
289
290 self.emit_trace(trace_event);
291
292 Err(e)
293 }
294 }
295 }
296}
297
/// Decorator around an [`EmbeddingProvider`] that emits a trace event (via
/// `tracing` and an optional callback) for every embed call.
pub struct OtelEmbeddingProvider {
    // The wrapped provider that actually computes embeddings.
    inner: Box<dyn EmbeddingProvider>,
    // Provider label recorded on every span.
    provider_name: String,
    // Model label recorded on every span.
    model_name: String,
    // Optional sink invoked with each emitted TraceEvent.
    trace_callback: Option<Box<dyn Fn(TraceEvent) + Send + Sync>>,
}
305
306impl OtelEmbeddingProvider {
307 pub fn new(
309 inner: Box<dyn EmbeddingProvider>,
310 provider_name: String,
311 model_name: String,
312 ) -> Self {
313 Self {
314 inner,
315 provider_name,
316 model_name,
317 trace_callback: None,
318 }
319 }
320
321 pub fn with_trace_callback<F>(mut self, callback: F) -> Self
323 where
324 F: Fn(TraceEvent) + Send + Sync + 'static,
325 {
326 self.trace_callback = Some(Box::new(callback));
327 self
328 }
329
330 fn emit_trace(&self, event: TraceEvent) {
331 tracing::info!(
332 provider = %event.request_attrs.provider,
333 model = %event.request_attrs.model,
334 success = event.response_attrs.as_ref().map(|r| r.success).unwrap_or(false),
335 latency_ms = event.response_attrs.as_ref().map(|r| r.latency_ms).unwrap_or(0),
336 "Embedding request trace"
337 );
338
339 if let Some(callback) = &self.trace_callback {
340 callback(event);
341 }
342 }
343}
344
345#[async_trait]
346impl EmbeddingProvider for OtelEmbeddingProvider {
347 async fn embed(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
348 let start = Instant::now();
349 let span_attrs =
350 SpanAttributes::from_embedding_request(&self.provider_name, &self.model_name, &request);
351
352 match self.inner.embed(request).await {
353 Ok(response) => {
354 let latency_ms = start.elapsed().as_millis() as u64;
355 let response_attrs =
356 ResponseAttributes::from_embedding_response(&response, latency_ms);
357
358 let trace_event = TraceEvent::new("embedding.embed".to_string(), span_attrs)
359 .with_response(response_attrs);
360
361 self.emit_trace(trace_event);
362
363 Ok(response)
364 }
365 Err(e) => {
366 let latency_ms = start.elapsed().as_millis() as u64;
367 let response_attrs = ResponseAttributes::from_error(&e.to_string(), latency_ms);
368
369 let trace_event = TraceEvent::new("embedding.embed".to_string(), span_attrs)
370 .with_response(response_attrs);
371
372 self.emit_trace(trace_event);
373
374 Err(e)
375 }
376 }
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OpenAIProvider, Usage};

    // Every request field should be copied/derived into the span
    // attributes unchanged when the prompts are short enough to escape
    // truncation.
    #[test]
    fn test_span_attributes_from_request() {
        let request = LlmRequest {
            prompt: "Test prompt".to_string(),
            system_prompt: Some("System prompt".to_string()),
            temperature: Some(0.7),
            max_tokens: Some(100),
            tools: vec![],
            images: vec![],
        };

        let attrs = SpanAttributes::from_request("openai", "gpt-4", &request);

        assert_eq!(attrs.provider, "openai");
        assert_eq!(attrs.model, "gpt-4");
        assert_eq!(attrs.prompt, Some("Test prompt".to_string()));
        assert_eq!(attrs.system_prompt, Some("System prompt".to_string()));
        assert_eq!(attrs.temperature, Some(0.7));
        assert_eq!(attrs.max_tokens, Some(100));
        assert_eq!(attrs.tools_count, 0);
        assert_eq!(attrs.images_count, 0);
    }

    // A 1000-char prompt must be cut to the 500-byte cap plus the "..."
    // marker — hence the 503 upper bound on the stored length.
    #[test]
    fn test_span_attributes_truncation() {
        let long_prompt = "a".repeat(1000);
        let request = LlmRequest {
            prompt: long_prompt,
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let attrs = SpanAttributes::from_request("openai", "gpt-4", &request);

        assert!(attrs.prompt.as_ref().unwrap().len() <= 503); assert!(attrs.prompt.as_ref().unwrap().ends_with("..."));
    }

    // Content, token usage and latency should be mirrored into the
    // response attributes, with success=true and no error recorded.
    #[test]
    fn test_response_attributes_from_response() {
        let response = LlmResponse {
            content: "Test response".to_string(),
            model: "gpt-4".to_string(),
            usage: Some(Usage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
            }),
            tool_calls: vec![],
        };

        let attrs = ResponseAttributes::from_response(&response, 100);

        assert_eq!(attrs.content, "Test response");
        assert_eq!(attrs.prompt_tokens, Some(10));
        assert_eq!(attrs.completion_tokens, Some(20));
        assert_eq!(attrs.total_tokens, Some(30));
        assert_eq!(attrs.latency_ms, 100);
        assert!(attrs.success);
        assert!(attrs.error.is_none());
    }

    // Error conversion: empty content, success=false, and the message
    // preserved verbatim.
    #[test]
    fn test_response_attributes_from_error() {
        let attrs = ResponseAttributes::from_error("Rate limited", 50);

        assert_eq!(attrs.content, "");
        assert_eq!(attrs.latency_ms, 50);
        assert!(!attrs.success);
        assert_eq!(attrs.error, Some("Rate limited".to_string()));
    }

    // A freshly created event carries the span name but no response
    // attributes until `with_response` is called.
    #[test]
    fn test_trace_event_creation() {
        let request = LlmRequest {
            prompt: "Test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let span_attrs = SpanAttributes::from_request("openai", "gpt-4", &request);
        let trace_event = TraceEvent::new("llm.complete".to_string(), span_attrs);

        assert_eq!(trace_event.span_name, "llm.complete");
        assert!(trace_event.response_attrs.is_none());
    }

    // NOTE(review): despite the "_success" name, this asserts `is_err()` —
    // the provider is built with a dummy API key and no network stubbing,
    // so the call is expected to fail. The test only checks that the
    // wrapper forwards the inner error; consider renaming it.
    #[tokio::test]
    async fn test_otel_provider_success() {
        let provider = OpenAIProvider::new("test_key".to_string(), "gpt-4".to_string());
        let otel_provider = OtelProvider::new(
            Box::new(provider),
            "openai".to_string(),
            "gpt-4".to_string(),
        );

        let request = LlmRequest {
            prompt: "Test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let result = otel_provider.complete(request).await;
        assert!(result.is_err());
    }

    // The trace callback must fire even when the inner call errors. The
    // callback pushes the event from a spawned task, hence the sleep
    // before asserting — NOTE(review): this is timing-dependent and could
    // flake under heavy load.
    #[tokio::test]
    async fn test_otel_provider_with_callback() {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        let provider = OpenAIProvider::new("test_key".to_string(), "gpt-4".to_string());
        let trace_events = Arc::new(Mutex::new(Vec::new()));
        let trace_events_clone = Arc::clone(&trace_events);

        let otel_provider = OtelProvider::new(
            Box::new(provider),
            "openai".to_string(),
            "gpt-4".to_string(),
        )
        .with_trace_callback(move |event| {
            let events = trace_events_clone.clone();
            tokio::spawn(async move {
                events.lock().await.push(event);
            });
        });

        let request = LlmRequest {
            prompt: "Test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: vec![],
            images: vec![],
        };

        let _ = otel_provider.complete(request).await;

        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let events = trace_events.lock().await;
        assert!(!events.is_empty());
    }
}