Skip to main content

llmsdk_provider/middleware/
logging.rs

1//! Logging middleware that emits structured events around each model call.
2//!
3//! The middleware does not depend on any logging crate: callers implement
4//! [`Logger`] and route events wherever they like (`tracing`, `log`, an
5//! in-process channel, ...). A minimal [`StderrLogger`] is bundled as a
6//! quick-start option and as a test hook.
7//!
8//! By default the prompt is **not** included in any event (PII / size
9//! reasons); opt in with [`LoggingMiddleware::with_prompt`].
10// Rust guideline compliant 2026-02-21
11
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14
15use async_trait::async_trait;
16
17use crate::error::{ProviderError, Result};
18use crate::language_model::{
19    BoxStream, CallOptions, FinishReason, GenerateResult, LanguageModel, Prompt, StreamPart,
20    StreamResult, Usage,
21};
22
23use super::language_model::{CallKind, LanguageModelMiddleware};
24
25/// Sink for [`LoggingMiddleware`] events.
26///
27/// Implement this trait to forward middleware events to your logging system.
28/// All methods are synchronous on purpose: emitting a log line should never
29/// be the bottleneck on a model call. Buffer / async-dispatch in the impl
30/// if you need to.
31pub trait Logger: Send + Sync + std::fmt::Debug {
32    /// Called once per model call, after `transform_params` and before the
33    /// inner model runs.
34    fn log_call_start(&self, event: &LogCallStart<'_>);
35
36    /// Called once per *successful* model call.
37    ///
38    /// For streams this fires when the stream is **opened** (not when it
39    /// finishes); per-frame instrumentation is out of scope for the first
40    /// iteration.
41    fn log_call_end(&self, event: &LogCallEnd<'_>);
42
43    /// Called once per *failed* model call.
44    fn log_call_error(&self, event: &LogCallError<'_>);
45
46    /// Called once per emitted [`StreamPart`] when the middleware is built
47    /// with [`LoggingMiddleware::with_stream_parts`].
48    ///
49    /// Default no-op so existing [`Logger`] implementations keep working.
50    fn log_stream_part(&self, _event: &LogStreamPart<'_>) {}
51}
52
53/// Identity + call shape, common to every event.
54#[derive(Debug, Clone, Copy)]
55pub struct LogContext<'a> {
56    /// Wrapped model's provider id.
57    pub provider: &'a str,
58    /// Wrapped model's model id.
59    pub model_id: &'a str,
60    /// Whether this call is a generate or a stream.
61    pub call_kind: CallKind,
62}
63
64/// Event emitted before the inner model runs.
65#[derive(Debug, Clone, Copy)]
66pub struct LogCallStart<'a> {
67    /// Common identity / call shape.
68    pub context: LogContext<'a>,
69    /// Prompt — present only when [`LoggingMiddleware::with_prompt`] is set.
70    pub prompt: Option<&'a Prompt>,
71}
72
73/// Event emitted on success.
74#[derive(Debug, Clone, Copy)]
75pub struct LogCallEnd<'a> {
76    /// Common identity / call shape.
77    pub context: LogContext<'a>,
78    /// Wall-clock duration from `log_call_start` to call return.
79    pub elapsed: Duration,
80    /// Token usage — only meaningful for [`CallKind::Generate`] (stream
81    /// totals are not available until the stream drains).
82    pub usage: Option<&'a Usage>,
83    /// Why the model stopped — only meaningful for [`CallKind::Generate`].
84    pub finish_reason: Option<&'a FinishReason>,
85}
86
87/// Event emitted on failure.
88#[derive(Debug, Clone, Copy)]
89pub struct LogCallError<'a> {
90    /// Common identity / call shape.
91    pub context: LogContext<'a>,
92    /// Wall-clock duration from `log_call_start` to error return.
93    pub elapsed: Duration,
94    /// The error that was returned.
95    pub error: &'a ProviderError,
96}
97
98/// One per-frame event emitted while a stream is alive.
99#[derive(Debug, Clone, Copy)]
100pub struct LogStreamPart<'a> {
101    /// Common identity / call shape.
102    pub context: LogContext<'a>,
103    /// Wall-clock duration since `log_call_start`.
104    pub elapsed: Duration,
105    /// Either the part itself (`Ok`) or the per-frame transport error (`Err`).
106    pub item: std::result::Result<&'a StreamPart, &'a ProviderError>,
107    /// Zero-based index of the part within the stream.
108    pub index: usize,
109}
110
111/// Middleware that emits [`Logger`] events around every call.
112///
113/// Cheap to clone (just an `Arc` to the logger); safe to stack on top of
114/// retry / cache.
115#[derive(Debug, Clone)]
116pub struct LoggingMiddleware {
117    logger: Arc<dyn Logger>,
118    log_prompt: bool,
119    log_stream_parts: bool,
120}
121
122impl LoggingMiddleware {
123    /// Build a middleware that forwards events to `logger`.
124    #[must_use]
125    pub fn new(logger: Arc<dyn Logger>) -> Self {
126        Self {
127            logger,
128            log_prompt: false,
129            log_stream_parts: false,
130        }
131    }
132
133    /// Include the [`Prompt`] in [`LogCallStart`]. Off by default to avoid
134    /// accidentally logging PII or large payloads.
135    #[must_use]
136    pub fn with_prompt(mut self, include: bool) -> Self {
137        self.log_prompt = include;
138        self
139    }
140
141    /// Emit a [`Logger::log_stream_part`] event for every part yielded by
142    /// `do_stream`. Off by default — turning it on can be noisy.
143    #[must_use]
144    pub fn with_stream_parts(mut self, include: bool) -> Self {
145        self.log_stream_parts = include;
146        self
147    }
148}
149
150#[async_trait]
151impl LanguageModelMiddleware for LoggingMiddleware {
152    async fn wrap_generate(
153        &self,
154        next: &dyn LanguageModel,
155        params: CallOptions,
156    ) -> Result<GenerateResult> {
157        let context = LogContext {
158            provider: next.provider(),
159            model_id: next.model_id(),
160            call_kind: CallKind::Generate,
161        };
162        let started = Instant::now();
163        self.logger.log_call_start(&LogCallStart {
164            context,
165            prompt: self.log_prompt.then_some(&params.prompt),
166        });
167        match next.do_generate(params).await {
168            Ok(result) => {
169                self.logger.log_call_end(&LogCallEnd {
170                    context,
171                    elapsed: started.elapsed(),
172                    usage: Some(&result.usage),
173                    finish_reason: Some(&result.finish_reason),
174                });
175                Ok(result)
176            }
177            Err(err) => {
178                self.logger.log_call_error(&LogCallError {
179                    context,
180                    elapsed: started.elapsed(),
181                    error: &err,
182                });
183                Err(err)
184            }
185        }
186    }
187
188    async fn wrap_stream(
189        &self,
190        next: &dyn LanguageModel,
191        params: CallOptions,
192    ) -> Result<StreamResult> {
193        let context = LogContext {
194            provider: next.provider(),
195            model_id: next.model_id(),
196            call_kind: CallKind::Stream,
197        };
198        let started = Instant::now();
199        self.logger.log_call_start(&LogCallStart {
200            context,
201            prompt: self.log_prompt.then_some(&params.prompt),
202        });
203        match next.do_stream(params).await {
204            Ok(result) => {
205                self.logger.log_call_end(&LogCallEnd {
206                    context,
207                    elapsed: started.elapsed(),
208                    usage: None,
209                    finish_reason: None,
210                });
211                if self.log_stream_parts {
212                    let StreamResult {
213                        stream,
214                        request,
215                        response,
216                    } = result;
217                    let provider = context.provider.to_owned();
218                    let model_id = context.model_id.to_owned();
219                    let wrapped = wrap_stream_with_logger(
220                        stream,
221                        Arc::clone(&self.logger),
222                        provider,
223                        model_id,
224                        started,
225                    );
226                    return Ok(StreamResult {
227                        stream: wrapped,
228                        request,
229                        response,
230                    });
231                }
232                Ok(result)
233            }
234            Err(err) => {
235                self.logger.log_call_error(&LogCallError {
236                    context,
237                    elapsed: started.elapsed(),
238                    error: &err,
239                });
240                Err(err)
241            }
242        }
243    }
244}
245
246/// Wrap `inner` so every yielded `Result<StreamPart>` triggers
247/// [`Logger::log_stream_part`] before being forwarded.
248fn wrap_stream_with_logger(
249    inner: BoxStream<Result<StreamPart>>,
250    logger: Arc<dyn Logger>,
251    provider: String,
252    model_id: String,
253    started: Instant,
254) -> BoxStream<Result<StreamPart>> {
255    let stream = futures::stream::unfold(
256        (inner, 0_usize, logger, provider, model_id, started),
257        |(mut inner, idx, logger, provider, model_id, started)| async move {
258            use futures::StreamExt as _;
259            match inner.next().await {
260                None => None,
261                Some(item) => {
262                    let ctx = LogContext {
263                        provider: &provider,
264                        model_id: &model_id,
265                        call_kind: CallKind::Stream,
266                    };
267                    let event = LogStreamPart {
268                        context: ctx,
269                        elapsed: started.elapsed(),
270                        item: item.as_ref(),
271                        index: idx,
272                    };
273                    logger.log_stream_part(&event);
274                    Some((item, (inner, idx + 1, logger, provider, model_id, started)))
275                }
276            }
277        },
278    );
279    Box::pin(stream)
280}
281
282/// Minimal [`Logger`] that writes one line per event to stderr.
283///
284/// Useful as a quick-start and as a smoke-test hook. Not optimized for
285/// production throughput; route to `tracing` / `log` for real workloads.
286#[derive(Debug, Default)]
287pub struct StderrLogger;
288
289impl Logger for StderrLogger {
290    fn log_call_start(&self, event: &LogCallStart<'_>) {
291        eprintln!(
292            "[llmsdk:start] provider={} model={} kind={:?}",
293            event.context.provider, event.context.model_id, event.context.call_kind,
294        );
295    }
296
297    fn log_call_end(&self, event: &LogCallEnd<'_>) {
298        eprintln!(
299            "[llmsdk:end]   provider={} model={} kind={:?} elapsed_ms={} finish={:?}",
300            event.context.provider,
301            event.context.model_id,
302            event.context.call_kind,
303            event.elapsed.as_millis(),
304            event.finish_reason.map(|r| r.unified),
305        );
306    }
307
308    fn log_call_error(&self, event: &LogCallError<'_>) {
309        eprintln!(
310            "[llmsdk:error] provider={} model={} kind={:?} elapsed_ms={} error={}",
311            event.context.provider,
312            event.context.model_id,
313            event.context.call_kind,
314            event.elapsed.as_millis(),
315            event.error,
316        );
317    }
318}
319
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use crate::language_model::FinishReasonKind;

    use super::*;

    /// [`Logger`] that stores a small summary of every event so tests can
    /// assert on what the middleware emitted.
    #[derive(Debug, Default)]
    struct RecordingLogger {
        /// `(provider, model_id, call_kind, prompt_present)` per start event.
        starts: Mutex<Vec<(String, String, CallKind, bool)>>,
        /// `(provider, call_kind, usage_present, finish_reason_present)` per
        /// end event.
        ends: Mutex<Vec<(String, CallKind, bool, bool)>>,
        /// `(provider, call_kind, error_message)` per error event.
        errors: Mutex<Vec<(String, CallKind, String)>>,
        /// `(provider, index, item_is_ok)` per stream-part event.
        parts: Mutex<Vec<(String, usize, bool)>>,
    }

    impl Logger for RecordingLogger {
        fn log_call_start(&self, event: &LogCallStart<'_>) {
            self.starts.lock().expect("starts mutex poisoned").push((
                event.context.provider.to_owned(),
                event.context.model_id.to_owned(),
                event.context.call_kind,
                event.prompt.is_some(),
            ));
        }

        fn log_call_end(&self, event: &LogCallEnd<'_>) {
            self.ends.lock().expect("ends mutex poisoned").push((
                event.context.provider.to_owned(),
                event.context.call_kind,
                event.usage.is_some(),
                event.finish_reason.is_some(),
            ));
        }

        fn log_call_error(&self, event: &LogCallError<'_>) {
            self.errors.lock().expect("errors mutex poisoned").push((
                event.context.provider.to_owned(),
                event.context.call_kind,
                event.error.to_string(),
            ));
        }

        fn log_stream_part(&self, event: &LogStreamPart<'_>) {
            self.parts.lock().expect("parts mutex poisoned").push((
                event.context.provider.to_owned(),
                event.index,
                event.item.is_ok(),
            ));
        }
    }

    /// Model whose calls either succeed with empty results or fail with an
    /// `invalid_prompt("nope")` error, depending on `should_fail`.
    #[derive(Debug)]
    struct StubModel {
        provider: String,
        model_id: String,
        should_fail: bool,
    }

    #[async_trait]
    impl LanguageModel for StubModel {
        fn provider(&self) -> &str {
            &self.provider
        }
        fn model_id(&self) -> &str {
            &self.model_id
        }
        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            if self.should_fail {
                return Err(ProviderError::invalid_prompt("nope"));
            }
            Ok(GenerateResult {
                content: vec![],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            if self.should_fail {
                return Err(ProviderError::invalid_prompt("nope"));
            }
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(Vec::new())),
                request: None,
                response: None,
            })
        }
    }

    /// Build the `openai` / `gpt-foo` stub shared by most tests.
    fn stub_model(should_fail: bool) -> StubModel {
        StubModel {
            provider: "openai".to_owned(),
            model_id: "gpt-foo".to_owned(),
            should_fail,
        }
    }

    /// Build a default-configured middleware together with the recording
    /// logger it writes to.
    fn recording_middleware() -> (Arc<RecordingLogger>, LoggingMiddleware) {
        let logger = Arc::new(RecordingLogger::default());
        let middleware = LoggingMiddleware::new(Arc::clone(&logger) as Arc<dyn Logger>);
        (logger, middleware)
    }

    #[tokio::test]
    async fn success_emits_start_and_end_and_skips_prompt_by_default() {
        let (logger, mw) = recording_middleware();
        let model = stub_model(false);
        mw.wrap_generate(&model, CallOptions::default())
            .await
            .expect("ok");
        let starts = logger.starts.lock().expect("starts mutex poisoned");
        assert_eq!(starts.len(), 1);
        assert_eq!(starts[0].0, "openai");
        assert_eq!(starts[0].1, "gpt-foo");
        assert_eq!(starts[0].2, CallKind::Generate);
        assert!(!starts[0].3, "prompt suppressed by default");
        let ends = logger.ends.lock().expect("ends mutex poisoned");
        assert_eq!(ends.len(), 1);
        assert!(ends[0].2, "usage attached for generate");
        assert!(ends[0].3, "finish_reason attached for generate");
        assert!(
            logger
                .errors
                .lock()
                .expect("errors mutex poisoned")
                .is_empty(),
            "no error event on success"
        );
    }

    #[tokio::test]
    async fn with_prompt_attaches_prompt_to_start_event() {
        let (logger, mw) = recording_middleware();
        let mw = mw.with_prompt(true);
        let model = stub_model(false);
        mw.wrap_generate(&model, CallOptions::default())
            .await
            .expect("ok");
        assert!(
            logger.starts.lock().expect("starts mutex poisoned")[0].3,
            "prompt attached when opt-in"
        );
    }

    #[tokio::test]
    async fn failure_emits_start_and_error_and_propagates() {
        let (logger, mw) = recording_middleware();
        let model = stub_model(true);
        let err = mw
            .wrap_generate(&model, CallOptions::default())
            .await
            .expect_err("propagates");
        assert!(err.to_string().contains("nope"));
        assert_eq!(
            logger.errors.lock().expect("errors mutex poisoned").len(),
            1
        );
        assert!(logger.ends.lock().expect("ends mutex poisoned").is_empty());
    }

    /// Model whose stream yields exactly three successful parts.
    #[derive(Debug)]
    struct ThreePartStream;

    #[async_trait]
    impl LanguageModel for ThreePartStream {
        fn provider(&self) -> &'static str {
            "openai"
        }
        fn model_id(&self) -> &'static str {
            "gpt-foo"
        }
        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            unimplemented!()
        }
        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            let parts: Vec<Result<StreamPart>> = vec![
                Ok(StreamPart::StreamStart { warnings: vec![] }),
                Ok(StreamPart::TextStart {
                    id: "b".into(),
                    provider_metadata: None,
                }),
                Ok(StreamPart::Finish {
                    usage: Usage::default(),
                    finish_reason: FinishReason::new(FinishReasonKind::Stop),
                    provider_metadata: None,
                }),
            ];
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(parts)),
                request: None,
                response: None,
            })
        }
    }

    #[tokio::test]
    async fn stream_parts_opt_in_emits_one_event_per_frame() {
        use futures::StreamExt as _;

        let (logger, mw) = recording_middleware();
        let mw = mw.with_stream_parts(true);

        let mut result = mw
            .wrap_stream(&ThreePartStream, CallOptions::default())
            .await
            .expect("opens");
        while result.stream.next().await.is_some() {}

        let parts = logger.parts.lock().expect("parts mutex poisoned").clone();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].1, 0);
        assert_eq!(parts[2].1, 2);
        assert!(parts.iter().all(|(_, _, ok)| *ok));
    }

    #[tokio::test]
    async fn stream_parts_are_not_logged_by_default() {
        use futures::StreamExt as _;

        let (logger, mw) = recording_middleware();

        let mut result = mw
            .wrap_stream(&ThreePartStream, CallOptions::default())
            .await
            .expect("opens");
        let mut yielded = 0_usize;
        while result.stream.next().await.is_some() {
            yielded += 1;
        }

        assert_eq!(yielded, 3, "stream is forwarded untouched");
        assert!(
            logger.parts.lock().expect("parts mutex poisoned").is_empty(),
            "no per-part events unless opted in"
        );
    }

    #[tokio::test]
    async fn stream_success_attaches_no_usage_or_finish_reason() {
        let (logger, mw) = recording_middleware();
        let model = stub_model(false);
        mw.wrap_stream(&model, CallOptions::default())
            .await
            .expect("ok");
        let ends = logger.ends.lock().expect("ends mutex poisoned");
        assert_eq!(ends[0].1, CallKind::Stream);
        assert!(!ends[0].2, "usage is None for stream");
        assert!(!ends[0].3, "finish_reason is None for stream");
    }

    #[tokio::test]
    async fn stream_open_failure_emits_error_and_propagates() {
        let (logger, mw) = recording_middleware();
        let model = stub_model(true);
        // `expect_err` would need `StreamResult: Debug`; go through `err()`
        // so the test does not depend on that.
        let err = mw
            .wrap_stream(&model, CallOptions::default())
            .await
            .err()
            .expect("open failure propagates");
        assert!(err.to_string().contains("nope"));
        let errors = logger.errors.lock().expect("errors mutex poisoned");
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].1, CallKind::Stream);
        assert!(logger.ends.lock().expect("ends mutex poisoned").is_empty());
    }
}
557}