// llmsdk_provider/middleware/context.rs

//! Shared context that travels with a call across the middleware chain.
//!
//! Middleware stages can pass structured metadata (request id, trace id,
//! parent span id, operation name, ...) to later stages and to the provider
//! impl by stashing it in `CallOptions.provider_options["llmsdk"]`. The
//! `"llmsdk"` bucket is reserved for this purpose — provider crates ignore it
//! on the wire.
//!
//! Why not extend the trait? Adding a `&mut Context` argument to every method
//! would be a viral breaking change. Reusing the existing
//! [`crate::shared::ProviderOptions`] surface keeps the trait stable and gives
//! callers a place to drop their own fields too.
// Rust guideline compliant 2026-02-21
14
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17
18use crate::language_model::CallOptions;
19use crate::shared::ProviderOptions;
20
/// Reserved provider id for cross-middleware metadata.
///
/// This is the key of the `provider_options` bucket that
/// [`MiddlewareContext`] reads from and writes to. The bucket may also hold
/// caller-defined keys alongside the context fields.
pub const LLMSDK_OPTIONS_KEY: &str = "llmsdk";
23
24/// Structured fields carried across the middleware chain.
25///
26/// Round-trips through `serde_json` so the bag stays JSON-compatible with the
27/// rest of `provider_options`.
28#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
29pub struct MiddlewareContext {
30    /// Unique id for this end-to-end request.
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub request_id: Option<String>,
33    /// Trace id for distributed tracing (W3C trace-context style).
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    pub trace_id: Option<String>,
36    /// Parent span id within the trace.
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub parent_span_id: Option<String>,
39    /// Logical operation name (e.g. `"chat.completion"`, `"embed.query"`).
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub operation: Option<String>,
42}
43
44impl MiddlewareContext {
45    /// Build with a fresh request id.
46    #[must_use]
47    pub fn with_request_id(id: impl Into<String>) -> Self {
48        Self {
49            request_id: Some(id.into()),
50            ..Default::default()
51        }
52    }
53
54    /// Read the context (if any) from a [`CallOptions`].
55    ///
56    /// Returns `None` when no `llmsdk` bucket exists or when its contents
57    /// don't deserialize as [`MiddlewareContext`].
58    #[must_use]
59    pub fn read(options: &CallOptions) -> Option<Self> {
60        Self::read_from(options.provider_options.as_ref()?)
61    }
62
63    /// Read from a raw `ProviderOptions` map.
64    #[must_use]
65    pub fn read_from(options: &ProviderOptions) -> Option<Self> {
66        let bucket = options.get(LLMSDK_OPTIONS_KEY)?;
67        serde_json::from_value::<Self>(Value::Object(bucket.clone())).ok()
68    }
69
70    /// Write `self` into a [`CallOptions`], merging onto any existing
71    /// `llmsdk` bucket (caller fields win).
72    pub fn write(&self, options: &mut CallOptions) {
73        let bucket = options
74            .provider_options
75            .get_or_insert_with(ProviderOptions::default)
76            .entry(LLMSDK_OPTIONS_KEY.to_owned())
77            .or_default();
78        let value = serde_json::to_value(self).unwrap_or(Value::Null);
79        if let Value::Object(map) = value {
80            for (k, v) in map {
81                bucket.insert(k, v);
82            }
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ProviderOptions` whose `llmsdk` bucket holds exactly `fields`.
    fn options_with_bucket(fields: &[(&str, Value)]) -> ProviderOptions {
        let bucket: serde_json::Map<String, Value> = fields
            .iter()
            .map(|(key, value)| ((*key).to_owned(), value.clone()))
            .collect();
        let mut po = ProviderOptions::default();
        po.insert(LLMSDK_OPTIONS_KEY.into(), bucket);
        po
    }

    #[test]
    fn roundtrips_through_call_options() {
        let ctx = MiddlewareContext {
            request_id: Some("req-123".into()),
            trace_id: Some("trace-abc".into()),
            parent_span_id: None,
            operation: Some("chat.completion".into()),
        };
        let mut opts = CallOptions::default();
        ctx.write(&mut opts);

        let read = MiddlewareContext::read(&opts).expect("present");
        assert_eq!(read, ctx);
    }

    #[test]
    fn write_preserves_existing_llmsdk_bucket_fields() {
        let mut opts = CallOptions::default();
        opts.provider_options = Some(options_with_bucket(&[(
            "custom",
            Value::String("value".into()),
        )]));

        MiddlewareContext::with_request_id("req-1").write(&mut opts);

        let bucket = opts
            .provider_options
            .as_ref()
            .and_then(|po| po.get(LLMSDK_OPTIONS_KEY))
            .expect("bucket written");
        assert_eq!(bucket.get("custom"), Some(&Value::String("value".into())));
        assert_eq!(
            bucket.get("request_id"),
            Some(&Value::String("req-1".into()))
        );
    }

    #[test]
    fn write_does_not_clear_fields_left_as_none() {
        let mut opts = CallOptions::default();
        MiddlewareContext {
            request_id: Some("req-1".into()),
            trace_id: Some("trace-1".into()),
            ..MiddlewareContext::default()
        }
        .write(&mut opts);

        // Only `request_id` is set on the second write; `trace_id` survives.
        MiddlewareContext::with_request_id("req-2").write(&mut opts);

        let read = MiddlewareContext::read(&opts).expect("present");
        assert_eq!(read.request_id.as_deref(), Some("req-2"));
        assert_eq!(read.trace_id.as_deref(), Some("trace-1"));
    }

    #[test]
    fn read_ignores_unknown_bucket_keys() {
        let po = options_with_bucket(&[
            ("custom", Value::String("value".into())),
            ("request_id", Value::String("req-9".into())),
        ]);

        let read = MiddlewareContext::read_from(&po).expect("present");
        assert_eq!(read, MiddlewareContext::with_request_id("req-9"));
    }

    #[test]
    fn read_returns_none_when_bucket_is_malformed() {
        let po = options_with_bucket(&[("request_id", Value::from(42))]);
        assert!(MiddlewareContext::read_from(&po).is_none());
    }

    #[test]
    fn read_returns_none_when_no_bucket() {
        let opts = CallOptions::default();
        assert!(MiddlewareContext::read(&opts).is_none());
    }

    #[test]
    fn read_returns_none_when_provider_options_lack_bucket() {
        let mut opts = CallOptions::default();
        opts.provider_options = Some(ProviderOptions::default());
        assert!(MiddlewareContext::read(&opts).is_none());
    }
}