Skip to main content

braintrust_sdk_rust/
stream.rs

1//! Stream aggregation for LLM streaming responses.
2//!
3//! This module provides `BraintrustStream`, a wrapper that aggregates streaming
4//! chunks into a final response value, following the JS/Python SDK pattern.
5//!
6//! It also provides `wrap_stream_with_span` for wrapping streams with span logging.
7
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use std::time::Instant;
14
15use anyhow::Result;
16use futures::Stream;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use tokio::sync::Mutex;
20
21use crate::span::{SpanHandle, SpanLog, SpanSubmitter};
22use crate::types::{usage_metrics_to_map, UsageMetrics};
23
24/// A tool call in a chat message.
25#[derive(Clone, Debug, Default, Serialize)]
26pub struct ToolCall {
27    pub id: String,
28    #[serde(rename = "type")]
29    pub call_type: String, // Always "function"
30    pub function: FunctionCall,
31}
32
33/// Function details in a tool call.
34#[derive(Clone, Debug, Default, Serialize)]
35pub struct FunctionCall {
36    pub name: String,
37    pub arguments: String,
38}
39
40/// A chat message in the output.
41#[derive(Clone, Debug, Default, Serialize)]
42pub struct ChatMessage {
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub role: Option<String>,
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub content: Option<String>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub tool_calls: Option<Vec<ToolCall>>,
49}
50
51/// A choice in the output array (matches OpenAI response format).
52#[derive(Clone, Debug, Serialize)]
53pub struct OutputChoice {
54    pub index: usize,
55    pub message: ChatMessage,
56    pub logprobs: Option<()>, // Always None
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub finish_reason: Option<String>,
59}
60
61/// Stream metadata with typed known fields and passthrough for extras.
62#[derive(Clone, Debug, Default, Serialize)]
63pub struct StreamMetadata {
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub model: Option<String>,
66    /// Catch-all for additional fields (passthrough behavior).
67    #[serde(flatten, skip_serializing_if = "HashMap::is_empty")]
68    pub extra: HashMap<String, Value>,
69}
70
71impl StreamMetadata {
72    /// Returns true if the metadata has no content.
73    pub fn is_empty(&self) -> bool {
74        self.model.is_none() && self.extra.is_empty()
75    }
76
77    /// Convert to a serde_json Map for logging.
78    pub fn to_map(&self) -> Option<serde_json::Map<String, Value>> {
79        if self.is_empty() {
80            return None;
81        }
82        // Serialize to Value and extract the Map
83        match serde_json::to_value(self) {
84            Ok(Value::Object(map)) => Some(map),
85            _ => None,
86        }
87    }
88}
89
90/// Aggregated result from a streaming response.
91#[derive(Clone)]
92pub struct FinalizedStream {
93    /// The output choices (matches OpenAI response format)
94    pub output: Vec<OutputChoice>,
95    /// Usage metrics extracted from the stream
96    pub usage: Option<UsageMetrics>,
97    /// Metadata (model and any extras)
98    pub metadata: StreamMetadata,
99}
100
101/// A stream aggregator that collects streaming chunks and produces a final value.
102///
103/// This follows the JS/Python SDK pattern where streaming responses are
104/// collected and aggregated lazily when `final_value()` is called.
105///
106/// Raw chunks are stored as-is during streaming (non-blocking), and transformation
107/// to universal format happens during aggregation (which runs async in a spawned task).
108#[derive(Clone, Default)]
109pub struct BraintrustStream {
110    raw_chunks: Vec<Value>,
111    finalized: Option<FinalizedStream>,
112}
113
114/// OpenAI-style streaming chunk structure for deserialization.
115#[derive(Debug, Clone, Deserialize, Serialize)]
116struct StreamChunk {
117    #[serde(default)]
118    model: Option<String>,
119    #[serde(default)]
120    choices: Vec<StreamChoice>,
121    #[serde(default)]
122    usage: Option<StreamUsage>,
123}
124
125/// Delta from a streaming chunk (typed for role/content).
126#[derive(Debug, Clone, Default, Deserialize, Serialize)]
127struct StreamDelta {
128    #[serde(default)]
129    role: Option<String>,
130    #[serde(default)]
131    content: Option<String>,
132}
133
134#[derive(Debug, Clone, Deserialize, Serialize)]
135struct StreamChoice {
136    #[serde(default)]
137    delta: Option<StreamDelta>,
138    #[serde(default)]
139    finish_reason: Option<String>,
140}
141
142#[derive(Debug, Clone, Deserialize, Serialize)]
143struct StreamUsage {
144    #[serde(default)]
145    prompt_tokens: Option<i64>,
146    #[serde(default, alias = "input_tokens")]
147    completion_tokens: Option<i64>,
148    #[serde(default, alias = "cache_read_input_tokens")]
149    prompt_cached_tokens: Option<i64>,
150    #[serde(default, alias = "cache_creation_input_tokens")]
151    prompt_cache_creation_tokens: Option<i64>,
152}
153
154impl BraintrustStream {
155    /// Create a new empty stream.
156    pub fn new() -> Self {
157        Self {
158            raw_chunks: Vec::new(),
159            finalized: None,
160        }
161    }
162
163    /// Add a raw JSON value to the stream.
164    ///
165    /// Stores the raw chunk as-is for later aggregation. This is non-blocking
166    /// to avoid adding latency to the streaming hot path. Transformation to
167    /// universal format happens lazily in `aggregate()`.
168    pub fn push(&mut self, value: Value) {
169        // Skip keep-alive markers
170        if value.get("_keep_alive").is_some() {
171            return;
172        }
173        self.raw_chunks.push(value);
174    }
175
176    /// Get the final aggregated value.
177    ///
178    /// This aggregates all chunks into a final response. The result is cached,
179    /// so subsequent calls return the same value.
180    pub fn final_value(&mut self) -> Result<&FinalizedStream> {
181        if self.finalized.is_none() {
182            self.finalized = Some(self.aggregate()?);
183        }
184        Ok(self.finalized.as_ref().unwrap())
185    }
186
187    /// Check if the stream has any chunks.
188    pub fn is_empty(&self) -> bool {
189        self.raw_chunks.is_empty()
190    }
191
192    fn aggregate(&self) -> Result<FinalizedStream> {
193        let mut usage: Option<UsageMetrics> = None;
194        let mut model: Option<String> = None;
195        let mut finish_reason: Option<String> = None;
196
197        // Aggregate content from all chunks
198        let mut aggregated_content = String::new();
199        let mut role: Option<String> = None;
200
201        for raw in &self.raw_chunks {
202            // Try to parse as OpenAI-style streaming chunk
203            let chunk: StreamChunk = match serde_json::from_value(raw.clone()) {
204                Ok(c) => c,
205                Err(_) => continue, // Skip unparseable chunks
206            };
207
208            // Extract model (take first non-None)
209            if model.is_none() {
210                model = chunk.model;
211            }
212
213            // Extract usage (take last non-None)
214            if let Some(ref u) = chunk.usage {
215                usage = Some(UsageMetrics {
216                    prompt_tokens: u.prompt_tokens.and_then(|v| u32::try_from(v).ok()),
217                    completion_tokens: u.completion_tokens.and_then(|v| u32::try_from(v).ok()),
218                    total_tokens: match (u.prompt_tokens, u.completion_tokens) {
219                        (Some(p), Some(c)) => u32::try_from(p + c).ok(),
220                        _ => None,
221                    },
222                    reasoning_tokens: None,
223                    prompt_cached_tokens: u
224                        .prompt_cached_tokens
225                        .and_then(|v| u32::try_from(v).ok()),
226                    prompt_cache_creation_tokens: u
227                        .prompt_cache_creation_tokens
228                        .and_then(|v| u32::try_from(v).ok()),
229                    completion_reasoning_tokens: None,
230                    prompt_tokens_details: None,
231                    completion_tokens_details: None,
232                });
233            }
234
235            // Process choices
236            for choice in &chunk.choices {
237                // Extract finish_reason (take last non-None)
238                if let Some(ref reason) = choice.finish_reason {
239                    if !reason.is_empty() {
240                        finish_reason = Some(reason.clone());
241                    }
242                }
243
244                // Extract content from delta
245                if let Some(ref delta) = choice.delta {
246                    // Extract role (take first)
247                    if role.is_none() {
248                        role = delta.role.clone();
249                    }
250
251                    // Append content
252                    if let Some(ref content) = delta.content {
253                        aggregated_content.push_str(content);
254                    }
255                }
256            }
257        }
258
259        // Build metadata (finish_reason moved to OutputChoice)
260        let metadata = StreamMetadata {
261            model,
262            extra: HashMap::new(),
263        };
264
265        // Build typed output (matches OpenAI response format)
266        let message = ChatMessage {
267            role: Some(role.unwrap_or_else(|| "assistant".to_string())),
268            content: Some(aggregated_content),
269            tool_calls: None, // TODO: implement tool call aggregation
270        };
271
272        let choice = OutputChoice {
273            index: 0,
274            message,
275            logprobs: None,
276            finish_reason,
277        };
278
279        Ok(FinalizedStream {
280            output: vec![choice],
281            usage,
282            metadata,
283        })
284    }
285}
286
287/// Wrap a stream with span logging.
288///
289/// This creates a new stream that yields the same chunks as the original,
290/// but also:
291/// - Records time-to-first-token on first meaningful content
292/// - Accumulates chunks for aggregation
293/// - On stream completion, logs the aggregated output/usage/metadata via `span.log()`
294///
295/// # Type Parameters
296/// - `S`: The stream type yielding `Result<Value, E>`
297/// - `E`: The error type (allows use with any error type)
298/// - `Sub`: The span submitter type
299#[allow(private_bounds)]
300pub fn wrap_stream_with_span<S, E, Sub>(
301    stream: S,
302    span: SpanHandle<Sub>,
303) -> Pin<Box<dyn Stream<Item = std::result::Result<Value, E>> + Send>>
304where
305    S: Stream<Item = std::result::Result<Value, E>> + Send + Unpin + 'static,
306    E: Send + 'static,
307    Sub: SpanSubmitter + 'static,
308{
309    use futures::StreamExt;
310
311    let start_time = Instant::now();
312    let ttft_recorded = Arc::new(AtomicBool::new(false));
313    let aggregator = Arc::new(Mutex::new(BraintrustStream::new()));
314    let span_for_complete = span.clone();
315    let aggregator_for_complete = Arc::clone(&aggregator);
316
317    let logged_stream = stream.then(move |result| {
318        let span = span.clone();
319        let ttft_recorded = ttft_recorded.clone();
320        let aggregator = aggregator.clone();
321        async move {
322            if let Ok(ref value) = result {
323                // Skip keep-alive markers
324                if value.get("_keep_alive").is_none() {
325                    // Record TTFT on first meaningful chunk
326                    if !ttft_recorded.swap(true, Ordering::SeqCst) && value_has_content(value) {
327                        let ttft_secs = start_time.elapsed().as_secs_f64();
328                        span.log(SpanLog {
329                            metrics: Some(
330                                [("time_to_first_token".to_string(), ttft_secs)]
331                                    .into_iter()
332                                    .collect(),
333                            ),
334                            ..Default::default()
335                        })
336                        .await;
337                    }
338                    // Accumulate chunk for final aggregation
339                    aggregator.lock().await.push(value.clone());
340                }
341            }
342            result
343        }
344    });
345
346    // Wrap in a stream that finalizes on completion
347    Box::pin(SpanCompleteWrapper {
348        inner: Box::pin(logged_stream),
349        span: Some(span_for_complete),
350        aggregator: Some(aggregator_for_complete),
351        finalize_state: FinalizeState::Idle,
352    })
353}
354
355/// Check if a JSON value contains meaningful output (for TTFT detection).
356fn value_has_content(value: &Value) -> bool {
357    // Check for choices array with content
358    if let Some(choices) = value.get("choices").and_then(|c| c.as_array()) {
359        if !choices.is_empty() {
360            return true;
361        }
362    }
363    // Check for usage with tokens
364    if let Some(usage) = value.get("usage").and_then(|u| u.as_object()) {
365        let has_tokens = usage
366            .get("completion_tokens")
367            .and_then(|v| v.as_i64())
368            .map(|t| t > 0)
369            .unwrap_or(false)
370            || usage
371                .get("prompt_tokens")
372                .and_then(|v| v.as_i64())
373                .map(|t| t > 0)
374                .unwrap_or(false);
375        if has_tokens {
376            return true;
377        }
378    }
379    false
380}
381
/// Where a wrapped stream stands with respect to end-of-stream finalization.
enum FinalizeState {
    /// The inner stream is still being polled; finalization has not started.
    Idle,
    /// The inner stream ended and the boxed finalization future is being driven.
    Finalizing(Pin<Box<dyn std::future::Future<Output = ()> + Send>>),
    /// Finalization finished; the stream only yields `None` from here on.
    Done,
}
391
392/// A wrapper stream that logs aggregated output when the stream is exhausted.
393struct SpanCompleteWrapper<S, Sub: SpanSubmitter> {
394    inner: S,
395    span: Option<SpanHandle<Sub>>,
396    aggregator: Option<Arc<Mutex<BraintrustStream>>>,
397    finalize_state: FinalizeState,
398}
399
400/// Finalize the stream by logging and flushing the span.
401async fn finalize_span<Sub: SpanSubmitter>(
402    span: SpanHandle<Sub>,
403    aggregator: Arc<Mutex<BraintrustStream>>,
404) {
405    let mut agg = aggregator.lock().await;
406    if !agg.is_empty() {
407        match agg.final_value() {
408            Ok(finalized) => {
409                // Build metrics from usage
410                let metrics = finalized
411                    .usage
412                    .as_ref()
413                    .map(|u| usage_metrics_to_map(u.clone()));
414
415                // Convert StreamMetadata to Option<Map>
416                let metadata = finalized.metadata.to_map();
417
418                // Serialize typed output to Value for SpanLog
419                let output = serde_json::to_value(&finalized.output).ok();
420
421                span.log(SpanLog {
422                    output,
423                    metadata,
424                    metrics,
425                    ..Default::default()
426                })
427                .await;
428            }
429            Err(e) => {
430                tracing::warn!("Failed to finalize stream: {}", e);
431            }
432        }
433    }
434    // Flush span with aggregated output
435    if let Err(e) = span.flush().await {
436        tracing::warn!("Failed to flush span: {}", e);
437    }
438}
439
440impl<S, E, Sub> Stream for SpanCompleteWrapper<S, Sub>
441where
442    S: Stream<Item = std::result::Result<Value, E>> + Unpin,
443    E: Send + 'static,
444    Sub: SpanSubmitter + 'static,
445{
446    type Item = std::result::Result<Value, E>;
447
448    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
449        // SAFETY: We never move the inner stream or finalize future after pinning
450        let this = unsafe { self.get_unchecked_mut() };
451
452        // First, check if we're in the middle of finalizing
453        match &mut this.finalize_state {
454            FinalizeState::Idle => {
455                // Not finalizing yet, poll the inner stream
456            }
457            FinalizeState::Finalizing(fut) => {
458                // Poll the finalization future
459                match fut.as_mut().poll(cx) {
460                    Poll::Ready(()) => {
461                        // Finalization complete
462                        this.finalize_state = FinalizeState::Done;
463                        return Poll::Ready(None);
464                    }
465                    Poll::Pending => {
466                        // Still finalizing
467                        return Poll::Pending;
468                    }
469                }
470            }
471            FinalizeState::Done => {
472                return Poll::Ready(None);
473            }
474        }
475
476        // Poll the inner stream
477        let result = Pin::new(&mut this.inner).poll_next(cx);
478
479        // If stream is done, start finalization
480        if matches!(result, Poll::Ready(None)) {
481            if let (Some(span), Some(aggregator)) = (this.span.take(), this.aggregator.take()) {
482                // Create the finalization future
483                let fut = Box::pin(finalize_span(span, aggregator));
484                this.finalize_state = FinalizeState::Finalizing(fut);
485
486                // Poll it immediately by recursing
487                // SAFETY: self is still valid and pinned
488                return unsafe { Pin::new_unchecked(this) }.poll_next(cx);
489            }
490        }
491
492        result
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a stream pre-loaded with the given chunks, in order.
    fn stream_of(chunks: Vec<Value>) -> BraintrustStream {
        let mut stream = BraintrustStream::new();
        for chunk in chunks {
            stream.push(chunk);
        }
        stream
    }

    #[test]
    fn aggregates_content_from_streaming_values() {
        let mut stream = stream_of(vec![
            json!({
                "id": "chunk1",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "role": "assistant", "content": "Hello" }
                }],
                "created": 1
            }),
            json!({
                "id": "chunk2",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "content": " world" }
                }],
                "created": 1
            }),
            json!({
                "id": "chunk3",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "content": "!" },
                    "finish_reason": "stop"
                }],
                "created": 1
            }),
        ]);

        let finalized = stream.final_value().expect("should finalize");

        // The output is an array holding exactly one choice.
        assert_eq!(finalized.output.len(), 1);

        let only_choice = &finalized.output[0];
        assert_eq!(only_choice.index, 0);
        assert_eq!(only_choice.message.role.as_deref(), Some("assistant"));
        assert_eq!(only_choice.message.content.as_deref(), Some("Hello world!"));
        assert_eq!(only_choice.finish_reason.as_deref(), Some("stop"));

        // The model ends up in the metadata.
        assert_eq!(finalized.metadata.model.as_deref(), Some("gpt-4"));
    }

    #[test]
    fn aggregates_usage_from_final_chunk() {
        let mut stream = stream_of(vec![
            json!({
                "id": "chunk1",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "role": "assistant", "content": "Hi" },
                    "finish_reason": "stop"
                }],
                "created": 1
            }),
            json!({
                "id": "chunk2",
                "model": "gpt-4",
                "choices": [],
                "created": 1,
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 5
                }
            }),
        ]);

        let finalized = stream.final_value().expect("should finalize");
        let usage = finalized.usage.as_ref().expect("should have usage");

        assert_eq!(usage.prompt_tokens, Some(10));
        assert_eq!(usage.completion_tokens, Some(5));
        assert_eq!(usage.total_tokens, Some(15));
    }

    #[test]
    fn skips_keep_alive_markers() {
        // A keep-alive marker must not be stored as a chunk.
        let stream = stream_of(vec![json!({"_keep_alive": true})]);

        assert!(stream.is_empty());
    }

    #[test]
    fn caches_finalized_result() {
        let mut stream = stream_of(vec![json!({
            "id": "chunk1",
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": { "role": "assistant", "content": "test" }
            }],
            "created": 1
        })]);

        // Each call's borrow ends with its statement, so the stream can be
        // finalized a second time.
        let first_content = stream
            .final_value()
            .expect("should finalize")
            .output
            .first()
            .and_then(|c| c.message.content.clone());

        // The second call is served from the cache.
        let second_content = stream
            .final_value()
            .expect("should finalize")
            .output
            .first()
            .and_then(|c| c.message.content.clone());

        assert_eq!(first_content, second_content);
        assert_eq!(first_content, Some("test".to_string()));
    }
}