1use std::sync::Arc;
13use std::time::{Duration, Instant};
14
15use async_trait::async_trait;
16
17use crate::error::{ProviderError, Result};
18use crate::language_model::{
19 BoxStream, CallOptions, FinishReason, GenerateResult, LanguageModel, Prompt, StreamPart,
20 StreamResult, Usage,
21};
22
23use super::language_model::{CallKind, LanguageModelMiddleware};
24
25pub trait Logger: Send + Sync + std::fmt::Debug {
32 fn log_call_start(&self, event: &LogCallStart<'_>);
35
36 fn log_call_end(&self, event: &LogCallEnd<'_>);
42
43 fn log_call_error(&self, event: &LogCallError<'_>);
45
46 fn log_stream_part(&self, _event: &LogStreamPart<'_>) {}
51}
52
53#[derive(Debug, Clone, Copy)]
55pub struct LogContext<'a> {
56 pub provider: &'a str,
58 pub model_id: &'a str,
60 pub call_kind: CallKind,
62}
63
64#[derive(Debug, Clone, Copy)]
66pub struct LogCallStart<'a> {
67 pub context: LogContext<'a>,
69 pub prompt: Option<&'a Prompt>,
71}
72
73#[derive(Debug, Clone, Copy)]
75pub struct LogCallEnd<'a> {
76 pub context: LogContext<'a>,
78 pub elapsed: Duration,
80 pub usage: Option<&'a Usage>,
83 pub finish_reason: Option<&'a FinishReason>,
85}
86
87#[derive(Debug, Clone, Copy)]
89pub struct LogCallError<'a> {
90 pub context: LogContext<'a>,
92 pub elapsed: Duration,
94 pub error: &'a ProviderError,
96}
97
98#[derive(Debug, Clone, Copy)]
100pub struct LogStreamPart<'a> {
101 pub context: LogContext<'a>,
103 pub elapsed: Duration,
105 pub item: std::result::Result<&'a StreamPart, &'a ProviderError>,
107 pub index: usize,
109}
110
111#[derive(Debug, Clone)]
116pub struct LoggingMiddleware {
117 logger: Arc<dyn Logger>,
118 log_prompt: bool,
119 log_stream_parts: bool,
120}
121
122impl LoggingMiddleware {
123 #[must_use]
125 pub fn new(logger: Arc<dyn Logger>) -> Self {
126 Self {
127 logger,
128 log_prompt: false,
129 log_stream_parts: false,
130 }
131 }
132
133 #[must_use]
136 pub fn with_prompt(mut self, include: bool) -> Self {
137 self.log_prompt = include;
138 self
139 }
140
141 #[must_use]
144 pub fn with_stream_parts(mut self, include: bool) -> Self {
145 self.log_stream_parts = include;
146 self
147 }
148}
149
150#[async_trait]
151impl LanguageModelMiddleware for LoggingMiddleware {
152 async fn wrap_generate(
153 &self,
154 next: &dyn LanguageModel,
155 params: CallOptions,
156 ) -> Result<GenerateResult> {
157 let context = LogContext {
158 provider: next.provider(),
159 model_id: next.model_id(),
160 call_kind: CallKind::Generate,
161 };
162 let started = Instant::now();
163 self.logger.log_call_start(&LogCallStart {
164 context,
165 prompt: self.log_prompt.then_some(¶ms.prompt),
166 });
167 match next.do_generate(params).await {
168 Ok(result) => {
169 self.logger.log_call_end(&LogCallEnd {
170 context,
171 elapsed: started.elapsed(),
172 usage: Some(&result.usage),
173 finish_reason: Some(&result.finish_reason),
174 });
175 Ok(result)
176 }
177 Err(err) => {
178 self.logger.log_call_error(&LogCallError {
179 context,
180 elapsed: started.elapsed(),
181 error: &err,
182 });
183 Err(err)
184 }
185 }
186 }
187
188 async fn wrap_stream(
189 &self,
190 next: &dyn LanguageModel,
191 params: CallOptions,
192 ) -> Result<StreamResult> {
193 let context = LogContext {
194 provider: next.provider(),
195 model_id: next.model_id(),
196 call_kind: CallKind::Stream,
197 };
198 let started = Instant::now();
199 self.logger.log_call_start(&LogCallStart {
200 context,
201 prompt: self.log_prompt.then_some(¶ms.prompt),
202 });
203 match next.do_stream(params).await {
204 Ok(result) => {
205 self.logger.log_call_end(&LogCallEnd {
206 context,
207 elapsed: started.elapsed(),
208 usage: None,
209 finish_reason: None,
210 });
211 if self.log_stream_parts {
212 let StreamResult {
213 stream,
214 request,
215 response,
216 } = result;
217 let provider = context.provider.to_owned();
218 let model_id = context.model_id.to_owned();
219 let wrapped = wrap_stream_with_logger(
220 stream,
221 Arc::clone(&self.logger),
222 provider,
223 model_id,
224 started,
225 );
226 return Ok(StreamResult {
227 stream: wrapped,
228 request,
229 response,
230 });
231 }
232 Ok(result)
233 }
234 Err(err) => {
235 self.logger.log_call_error(&LogCallError {
236 context,
237 elapsed: started.elapsed(),
238 error: &err,
239 });
240 Err(err)
241 }
242 }
243 }
244}
245
246fn wrap_stream_with_logger(
249 inner: BoxStream<Result<StreamPart>>,
250 logger: Arc<dyn Logger>,
251 provider: String,
252 model_id: String,
253 started: Instant,
254) -> BoxStream<Result<StreamPart>> {
255 let stream = futures::stream::unfold(
256 (inner, 0_usize, logger, provider, model_id, started),
257 |(mut inner, idx, logger, provider, model_id, started)| async move {
258 use futures::StreamExt as _;
259 match inner.next().await {
260 None => None,
261 Some(item) => {
262 let ctx = LogContext {
263 provider: &provider,
264 model_id: &model_id,
265 call_kind: CallKind::Stream,
266 };
267 let event = LogStreamPart {
268 context: ctx,
269 elapsed: started.elapsed(),
270 item: item.as_ref(),
271 index: idx,
272 };
273 logger.log_stream_part(&event);
274 Some((item, (inner, idx + 1, logger, provider, model_id, started)))
275 }
276 }
277 },
278 );
279 Box::pin(stream)
280}
281
282#[derive(Debug, Default)]
287pub struct StderrLogger;
288
289impl Logger for StderrLogger {
290 fn log_call_start(&self, event: &LogCallStart<'_>) {
291 eprintln!(
292 "[llmsdk:start] provider={} model={} kind={:?}",
293 event.context.provider, event.context.model_id, event.context.call_kind,
294 );
295 }
296
297 fn log_call_end(&self, event: &LogCallEnd<'_>) {
298 eprintln!(
299 "[llmsdk:end] provider={} model={} kind={:?} elapsed_ms={} finish={:?}",
300 event.context.provider,
301 event.context.model_id,
302 event.context.call_kind,
303 event.elapsed.as_millis(),
304 event.finish_reason.map(|r| r.unified),
305 );
306 }
307
308 fn log_call_error(&self, event: &LogCallError<'_>) {
309 eprintln!(
310 "[llmsdk:error] provider={} model={} kind={:?} elapsed_ms={} error={}",
311 event.context.provider,
312 event.context.model_id,
313 event.context.call_kind,
314 event.elapsed.as_millis(),
315 event.error,
316 );
317 }
318}
319
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use crate::language_model::FinishReasonKind;

    use super::*;

    /// Logger double that stores a compact tuple for every event it receives.
    #[derive(Debug, Default)]
    struct RecordingLogger {
        /// (provider, model id, call kind, prompt present)
        starts: Mutex<Vec<(String, String, CallKind, bool)>>,
        /// (provider, call kind, usage present, finish reason present)
        ends: Mutex<Vec<(String, CallKind, bool, bool)>>,
        /// (provider, call kind, error text)
        errors: Mutex<Vec<(String, CallKind, String)>>,
        /// (provider, part index, item is `Ok`)
        parts: Mutex<Vec<(String, usize, bool)>>,
    }

    impl Logger for RecordingLogger {
        fn log_call_start(&self, event: &LogCallStart<'_>) {
            let ctx = event.context;
            let record = (
                ctx.provider.to_owned(),
                ctx.model_id.to_owned(),
                ctx.call_kind,
                event.prompt.is_some(),
            );
            self.starts
                .lock()
                .expect("starts mutex poisoned")
                .push(record);
        }

        fn log_call_end(&self, event: &LogCallEnd<'_>) {
            let ctx = event.context;
            let record = (
                ctx.provider.to_owned(),
                ctx.call_kind,
                event.usage.is_some(),
                event.finish_reason.is_some(),
            );
            self.ends.lock().expect("ends mutex poisoned").push(record);
        }

        fn log_call_error(&self, event: &LogCallError<'_>) {
            let ctx = event.context;
            let record = (
                ctx.provider.to_owned(),
                ctx.call_kind,
                event.error.to_string(),
            );
            self.errors
                .lock()
                .expect("errors mutex poisoned")
                .push(record);
        }

        fn log_stream_part(&self, event: &LogStreamPart<'_>) {
            let record = (
                event.context.provider.to_owned(),
                event.index,
                event.item.is_ok(),
            );
            self.parts
                .lock()
                .expect("parts mutex poisoned")
                .push(record);
        }
    }

    /// Model double whose calls either all succeed with empty results or all
    /// fail with an invalid-prompt error.
    #[derive(Debug)]
    struct StubModel {
        provider: String,
        model_id: String,
        should_fail: bool,
    }

    #[async_trait]
    impl LanguageModel for StubModel {
        fn provider(&self) -> &str {
            self.provider.as_str()
        }

        fn model_id(&self) -> &str {
            self.model_id.as_str()
        }

        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            if self.should_fail {
                Err(ProviderError::invalid_prompt("nope"))
            } else {
                Ok(GenerateResult {
                    content: vec![],
                    finish_reason: FinishReason::new(FinishReasonKind::Stop),
                    usage: Usage::default(),
                    provider_metadata: None,
                    request: None,
                    response: None,
                    warnings: vec![],
                })
            }
        }

        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            if self.should_fail {
                Err(ProviderError::invalid_prompt("nope"))
            } else {
                Ok(StreamResult {
                    stream: Box::pin(futures::stream::iter(Vec::new())),
                    request: None,
                    response: None,
                })
            }
        }
    }

    /// Fresh, empty recording logger behind an `Arc`.
    fn recording_logger() -> Arc<RecordingLogger> {
        Arc::new(RecordingLogger::default())
    }

    /// Default-configured middleware that reports to `logger`.
    fn middleware_for(logger: &Arc<RecordingLogger>) -> LoggingMiddleware {
        LoggingMiddleware::new(Arc::clone(logger) as Arc<dyn Logger>)
    }

    /// The "openai" / "gpt-foo" stub used by most tests.
    fn stub_model(should_fail: bool) -> StubModel {
        StubModel {
            provider: "openai".to_owned(),
            model_id: "gpt-foo".to_owned(),
            should_fail,
        }
    }

    #[tokio::test]
    async fn success_emits_start_and_end_and_skips_prompt_by_default() {
        let logger = recording_logger();
        let mw = middleware_for(&logger);
        let model = stub_model(false);

        mw.wrap_generate(&model, CallOptions::default())
            .await
            .expect("ok");

        let starts = logger.starts.lock().expect("starts mutex poisoned");
        assert_eq!(starts.len(), 1);
        let (provider, model_id, kind, has_prompt) = &starts[0];
        assert_eq!(provider.as_str(), "openai");
        assert_eq!(model_id.as_str(), "gpt-foo");
        assert_eq!(*kind, CallKind::Generate);
        assert!(!*has_prompt, "prompt suppressed by default");

        let ends = logger.ends.lock().expect("ends mutex poisoned");
        assert_eq!(ends.len(), 1);
        let (_, _, has_usage, has_finish_reason) = &ends[0];
        assert!(*has_usage, "usage attached for generate");
        assert!(*has_finish_reason, "finish_reason attached for generate");

        let errors = logger.errors.lock().expect("errors mutex poisoned");
        assert!(errors.is_empty(), "no error event on success");
    }

    #[tokio::test]
    async fn with_prompt_attaches_prompt_to_start_event() {
        let logger = recording_logger();
        let mw = middleware_for(&logger).with_prompt(true);
        let model = stub_model(false);

        mw.wrap_generate(&model, CallOptions::default())
            .await
            .expect("ok");

        let starts = logger.starts.lock().expect("starts mutex poisoned");
        assert!(starts[0].3, "prompt attached when opt-in");
    }

    #[tokio::test]
    async fn failure_emits_start_and_error_and_propagates() {
        let logger = recording_logger();
        let mw = middleware_for(&logger);
        let model = stub_model(true);

        let err = mw
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect_err("propagates");

        assert!(err.to_string().contains("nope"));
        let error_count = logger.errors.lock().expect("errors mutex poisoned").len();
        assert_eq!(error_count, 1);
        assert!(logger.ends.lock().expect("ends mutex poisoned").is_empty());
    }

    /// Model double whose stream yields exactly three successful parts.
    #[derive(Debug)]
    struct ThreePartStream;

    #[async_trait]
    impl LanguageModel for ThreePartStream {
        fn provider(&self) -> &'static str {
            "openai"
        }

        fn model_id(&self) -> &'static str {
            "gpt-foo"
        }

        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            unimplemented!()
        }

        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            let opening = StreamPart::StreamStart { warnings: vec![] };
            let text = StreamPart::TextStart {
                id: "b".into(),
                provider_metadata: None,
            };
            let closing = StreamPart::Finish {
                usage: Usage::default(),
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                provider_metadata: None,
            };
            let parts: Vec<Result<StreamPart>> = vec![Ok(opening), Ok(text), Ok(closing)];
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(parts)),
                request: None,
                response: None,
            })
        }
    }

    #[tokio::test]
    async fn stream_parts_opt_in_emits_one_event_per_frame() {
        use futures::StreamExt as _;

        let logger = recording_logger();
        let mw = middleware_for(&logger).with_stream_parts(true);

        let result = mw
            .wrap_stream(&ThreePartStream, CallOptions::default())
            .await
            .expect("opens");
        // Drain the stream so that every part passes through the logger.
        let _drained = result.stream.count().await;

        let parts = logger.parts.lock().expect("parts mutex").clone();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].1, 0);
        assert_eq!(parts[2].1, 2);
        assert!(parts.iter().all(|part| part.2));
    }

    #[tokio::test]
    async fn stream_success_attaches_no_usage_or_finish_reason() {
        let logger = recording_logger();
        let mw = middleware_for(&logger);
        let model = stub_model(false);

        mw.wrap_stream(&model, CallOptions::default())
            .await
            .expect("ok");

        let ends = logger.ends.lock().expect("ends mutex poisoned");
        let (_, kind, has_usage, has_finish_reason) = &ends[0];
        assert_eq!(*kind, CallKind::Stream);
        assert!(!*has_usage, "usage is None for stream");
        assert!(!*has_finish_reason, "finish_reason is None for stream");
    }
}