llmsdk_provider/middleware/builtin/
simulate_streaming.rs1use async_trait::async_trait;
13use futures::stream;
14
15use crate::error::Result;
16use crate::language_model::{
17 CallOptions, Content, LanguageModel, ResponseMetadata, StreamPart, StreamResult,
18};
19use crate::middleware::language_model::LanguageModelMiddleware;
20
/// Language-model middleware that answers streaming requests by calling the
/// wrapped model's non-streaming `do_generate` and replaying the finished
/// result as a sequence of stream parts.
///
/// It is a stateless unit struct, so it is `Copy` and costs nothing to create.
#[derive(Debug, Default, Clone, Copy)]
pub struct SimulateStreamingMiddleware;

impl SimulateStreamingMiddleware {
    /// Creates the middleware; identical to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
32
33#[async_trait]
34impl LanguageModelMiddleware for SimulateStreamingMiddleware {
35 async fn wrap_stream(
36 &self,
37 next: &dyn LanguageModel,
38 params: CallOptions,
39 ) -> Result<StreamResult> {
40 let result = next.do_generate(params).await?;
41
42 let mut parts: Vec<Result<StreamPart>> = Vec::new();
43 parts.push(Ok(StreamPart::StreamStart {
44 warnings: result.warnings.clone(),
45 }));
46 let resp_metadata = result
52 .response
53 .as_ref()
54 .map(|resp| ResponseMetadata {
55 id: resp.metadata.id.clone(),
56 timestamp: resp.metadata.timestamp.clone(),
57 model_id: resp.metadata.model_id.clone(),
58 headers: resp.metadata.headers.clone(),
59 })
60 .unwrap_or_default();
61 parts.push(Ok(StreamPart::ResponseMetadata(resp_metadata)));
62
63 for (idx, content) in result.content.iter().enumerate() {
64 let block_id = format!("sim-{idx}");
65 match content {
66 Content::Text(t) => {
67 if t.text.is_empty() {
72 continue;
73 }
74 parts.push(Ok(StreamPart::TextStart {
81 id: block_id.clone(),
82 provider_metadata: None,
83 }));
84 parts.push(Ok(StreamPart::TextDelta {
85 id: block_id.clone(),
86 delta: t.text.clone(),
87 provider_metadata: None,
88 }));
89 parts.push(Ok(StreamPart::TextEnd {
90 id: block_id,
91 provider_metadata: None,
92 }));
93 }
94 Content::Reasoning(r) => {
95 parts.push(Ok(StreamPart::ReasoningStart {
100 id: block_id.clone(),
101 provider_metadata: r.provider_options.clone().map(into_metadata),
102 }));
103 parts.push(Ok(StreamPart::ReasoningDelta {
104 id: block_id.clone(),
105 delta: r.text.clone(),
106 provider_metadata: None,
107 }));
108 parts.push(Ok(StreamPart::ReasoningEnd {
109 id: block_id,
110 provider_metadata: None,
111 }));
112 }
113 Content::ToolCall(tc) => {
114 parts.push(Ok(StreamPart::ToolCall(tc.clone())));
115 }
116 Content::ToolResult(tr) => {
117 parts.push(Ok(StreamPart::ToolResult(tr.clone())));
118 }
119 Content::ToolApprovalRequest(req) => {
120 parts.push(Ok(StreamPart::ToolApprovalRequest(req.clone())));
121 }
122 Content::Source(s) => {
123 parts.push(Ok(StreamPart::Source(s.clone())));
124 }
125 Content::File(_) | Content::ReasoningFile { .. } => {
126 parts.push(Ok(StreamPart::Custom {
128 kind: "llmsdk.simulate.file".into(),
129 provider_metadata: None,
130 }));
131 }
132 Content::Custom {
133 kind,
134 provider_options,
135 } => {
136 parts.push(Ok(StreamPart::Custom {
137 kind: kind.clone(),
138 provider_metadata: provider_options.clone().map(into_metadata),
139 }));
140 }
141 }
142 }
143
144 parts.push(Ok(StreamPart::Finish {
145 usage: result.usage,
146 finish_reason: result.finish_reason,
147 provider_metadata: result.provider_metadata,
148 }));
149
150 Ok(StreamResult {
151 stream: Box::pin(stream::iter(parts)),
152 request: result.request,
153 response: None,
154 })
155 }
156}
157
158fn into_metadata(opts: crate::shared::ProviderOptions) -> crate::shared::ProviderMetadata {
161 opts
162}
163
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use futures::StreamExt;

    use super::*;
    use crate::language_model::{FinishReason, FinishReasonKind, TextPart, Usage};
    use crate::middleware::wrap_language_model;

    // Model whose `do_generate` returns a single text block. `do_stream`
    // panics, so a passing test shows the middleware never reached it.
    #[derive(Debug)]
    struct Gen {
        text: String,
    }

    #[async_trait]
    impl LanguageModel for Gen {
        fn provider(&self) -> &'static str {
            "g"
        }
        fn model_id(&self) -> &'static str {
            "g"
        }
        async fn do_generate(
            &self,
            _opts: CallOptions,
        ) -> Result<crate::language_model::GenerateResult> {
            Ok(crate::language_model::GenerateResult {
                content: vec![Content::Text(TextPart {
                    text: self.text.clone(),
                    provider_options: None,
                })],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
            unimplemented!("middleware should bypass do_stream")
        }
    }

    /// Runs `inner` behind the simulate-streaming middleware and collects every
    /// streamed part, panicking if the call or any part fails.
    async fn simulated_parts(inner: Arc<dyn LanguageModel>) -> Vec<StreamPart> {
        let wrapped = wrap_language_model(
            inner,
            [Arc::new(SimulateStreamingMiddleware::new()) as Arc<dyn LanguageModelMiddleware>],
        );
        let mut s = wrapped.do_stream(CallOptions::default()).await.unwrap();
        let mut parts = Vec::new();
        while let Some(item) = s.stream.next().await {
            parts.push(item.unwrap());
        }
        parts
    }

    /// Short label for the part kinds these tests assert on; every other kind
    /// maps to `"other"`.
    fn tag(part: &StreamPart) -> &'static str {
        match part {
            StreamPart::StreamStart { .. } => "start",
            StreamPart::ResponseMetadata(_) => "response-metadata",
            StreamPart::TextStart { .. } => "text-start",
            StreamPart::TextDelta { .. } => "text-delta",
            StreamPart::TextEnd { .. } => "text-end",
            StreamPart::Finish { .. } => "finish",
            _ => "other",
        }
    }

    /// Simulated stream of a [`Gen`] model returning `text`, reduced to labels.
    async fn tags_for_text(text: &str) -> Vec<&'static str> {
        let parts = simulated_parts(Arc::new(Gen { text: text.into() })).await;
        parts.iter().map(tag).collect()
    }

    #[tokio::test]
    async fn emits_block_level_stream_from_generate() {
        assert_eq!(
            tags_for_text("hello").await,
            vec![
                "start",
                "response-metadata",
                "text-start",
                "text-delta",
                "text-end",
                "finish"
            ]
        );
    }

    #[tokio::test]
    async fn empty_text_block_is_skipped() {
        assert_eq!(
            tags_for_text("").await,
            vec!["start", "response-metadata", "finish"]
        );
    }

    #[tokio::test]
    async fn reasoning_provider_metadata_rides_on_start_not_delta() {
        use crate::language_model::ReasoningPart;
        use crate::shared::ProviderOptions;

        // Model whose `do_generate` returns one reasoning block carrying
        // provider options.
        #[derive(Debug)]
        struct ReasoningGen;

        #[async_trait]
        impl LanguageModel for ReasoningGen {
            fn provider(&self) -> &'static str {
                "r"
            }
            fn model_id(&self) -> &'static str {
                "r"
            }
            async fn do_generate(
                &self,
                _opts: CallOptions,
            ) -> Result<crate::language_model::GenerateResult> {
                let mut opts = ProviderOptions::new();
                opts.insert(
                    "anthropic".into(),
                    serde_json::json!({ "signature": "sig" })
                        .as_object()
                        .cloned()
                        .unwrap(),
                );
                Ok(crate::language_model::GenerateResult {
                    content: vec![Content::Reasoning(ReasoningPart {
                        text: "thinking…".into(),
                        provider_options: Some(opts),
                    })],
                    finish_reason: FinishReason::new(FinishReasonKind::Stop),
                    usage: Usage::default(),
                    provider_metadata: None,
                    request: None,
                    response: None,
                    warnings: vec![],
                })
            }
            async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResult> {
                unimplemented!()
            }
        }

        let mut start_meta: Option<crate::shared::ProviderMetadata> = None;
        let mut delta_meta: Option<crate::shared::ProviderMetadata> = None;
        let mut end_meta: Option<crate::shared::ProviderMetadata> = None;
        for part in simulated_parts(Arc::new(ReasoningGen)).await {
            match part {
                StreamPart::ReasoningStart {
                    provider_metadata, ..
                } => start_meta = provider_metadata,
                StreamPart::ReasoningDelta {
                    provider_metadata, ..
                } => delta_meta = provider_metadata,
                StreamPart::ReasoningEnd {
                    provider_metadata, ..
                } => end_meta = provider_metadata,
                _ => {}
            }
        }
        assert!(
            start_meta.is_some(),
            "reasoning-start must carry provider_metadata (upstream parity)"
        );
        assert!(
            delta_meta.is_none(),
            "reasoning-delta must NOT carry provider_metadata (upstream parity)"
        );
        assert!(
            end_meta.is_none(),
            "reasoning-end must NOT carry provider_metadata"
        );
    }
}