llmsdk_provider/middleware/
context.rs1use serde::{Deserialize, Serialize};
16use serde_json::Value;
17
18use crate::language_model::CallOptions;
19use crate::shared::ProviderOptions;
20
/// Key of the [`ProviderOptions`] bucket reserved for llmsdk-owned settings,
/// including the serialized [`MiddlewareContext`].
pub const LLMSDK_OPTIONS_KEY: &str = "llmsdk";
23
24#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
29pub struct MiddlewareContext {
30 #[serde(default, skip_serializing_if = "Option::is_none")]
32 pub request_id: Option<String>,
33 #[serde(default, skip_serializing_if = "Option::is_none")]
35 pub trace_id: Option<String>,
36 #[serde(default, skip_serializing_if = "Option::is_none")]
38 pub parent_span_id: Option<String>,
39 #[serde(default, skip_serializing_if = "Option::is_none")]
41 pub operation: Option<String>,
42}
43
44impl MiddlewareContext {
45 #[must_use]
47 pub fn with_request_id(id: impl Into<String>) -> Self {
48 Self {
49 request_id: Some(id.into()),
50 ..Default::default()
51 }
52 }
53
54 #[must_use]
59 pub fn read(options: &CallOptions) -> Option<Self> {
60 Self::read_from(options.provider_options.as_ref()?)
61 }
62
63 #[must_use]
65 pub fn read_from(options: &ProviderOptions) -> Option<Self> {
66 let bucket = options.get(LLMSDK_OPTIONS_KEY)?;
67 serde_json::from_value::<Self>(Value::Object(bucket.clone())).ok()
68 }
69
70 pub fn write(&self, options: &mut CallOptions) {
73 let bucket = options
74 .provider_options
75 .get_or_insert_with(ProviderOptions::default)
76 .entry(LLMSDK_OPTIONS_KEY.to_owned())
77 .or_default();
78 let value = serde_json::to_value(self).unwrap_or(Value::Null);
79 if let Value::Object(map) = value {
80 for (k, v) in map {
81 bucket.insert(k, v);
82 }
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Writing a fully specified context and reading it back yields an equal value.
    #[test]
    fn roundtrips_through_call_options() {
        let original = MiddlewareContext {
            request_id: Some("req-123".into()),
            trace_id: Some("trace-abc".into()),
            parent_span_id: None,
            operation: Some("chat.completion".into()),
        };
        let mut call_opts = CallOptions::default();
        original.write(&mut call_opts);

        let recovered = MiddlewareContext::read(&call_opts);
        assert_eq!(recovered.as_ref(), Some(&original));
    }

    /// Keys already present in the llmsdk bucket survive a write alongside the new fields.
    #[test]
    fn write_preserves_existing_llmsdk_bucket_fields() {
        let mut seeded = serde_json::Map::new();
        seeded.insert("custom".into(), Value::String("value".into()));

        let mut provider_opts = ProviderOptions::default();
        provider_opts.insert(LLMSDK_OPTIONS_KEY.into(), seeded);

        let mut call_opts = CallOptions::default();
        call_opts.provider_options = Some(provider_opts);

        MiddlewareContext::with_request_id("req-1").write(&mut call_opts);

        let bucket = call_opts
            .provider_options
            .as_ref()
            .and_then(|po| po.get(LLMSDK_OPTIONS_KEY))
            .expect("llmsdk bucket exists after write");
        assert_eq!(bucket.get("custom").and_then(Value::as_str), Some("value"));
        assert_eq!(
            bucket.get("request_id").and_then(Value::as_str),
            Some("req-1")
        );
    }

    /// Default call options carry no context.
    #[test]
    fn read_returns_none_when_no_bucket() {
        let empty = CallOptions::default();
        assert_eq!(MiddlewareContext::read(&empty), None);
    }
}