1use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12use std::task::{Context, Poll};
13use std::time::Instant;
14
15use anyhow::Result;
16use futures::Stream;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use tokio::sync::Mutex;
20
21use crate::span::{SpanHandle, SpanLog, SpanSubmitter};
22use crate::types::{usage_metrics_to_map, UsageMetrics};
23
/// A tool invocation attached to a [`ChatMessage`].
///
/// NOTE(review): stream aggregation in this file always sets
/// `ChatMessage::tool_calls` to `None`, so `BraintrustStream` never produces
/// these values today.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Kind of call; serialized under the JSON key `type` (a Rust keyword).
    #[serde(rename = "type")]
    pub call_type: String,
    /// Function name and arguments for this call.
    pub function: FunctionCall,
}
32
/// Function name and arguments of a [`ToolCall`].
#[derive(Clone, Debug, Default, Serialize)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments kept as an unparsed string (presumably JSON-encoded — not
    /// validated here).
    pub arguments: String,
}
39
/// Message carried by an [`OutputChoice`].
///
/// Every field is optional and is omitted from the serialized JSON when `None`.
#[derive(Clone, Debug, Default, Serialize)]
pub struct ChatMessage {
    /// Author role of the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Text content of the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls requested by the message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
50
/// One entry of a finalized stream's `output` array.
#[derive(Clone, Debug, Serialize)]
pub struct OutputChoice {
    /// Position of this choice in the output.
    pub index: usize,
    /// The aggregated message for this choice.
    pub message: ChatMessage,
    /// Always serialized (there is no skip attribute). `()` carries no payload,
    /// so this is emitted as JSON `null` whether it is `None` or `Some(())`.
    pub logprobs: Option<()>,
    /// Why generation stopped; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
60
/// Metadata recorded for a finished stream.
#[derive(Clone, Debug, Default, Serialize)]
pub struct StreamMetadata {
    /// Model name; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Additional entries, flattened into the top-level JSON object; skipped
    /// entirely when the map is empty.
    #[serde(flatten, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, Value>,
}
70
71impl StreamMetadata {
72 pub fn is_empty(&self) -> bool {
74 self.model.is_none() && self.extra.is_empty()
75 }
76
77 pub fn to_map(&self) -> Option<serde_json::Map<String, Value>> {
79 if self.is_empty() {
80 return None;
81 }
82 match serde_json::to_value(self) {
84 Ok(Value::Object(map)) => Some(map),
85 _ => None,
86 }
87 }
88}
89
/// Result of aggregating every buffered chunk of a stream.
#[derive(Clone)]
pub struct FinalizedStream {
    /// Aggregated choices (aggregation in this file always produces exactly one).
    pub output: Vec<OutputChoice>,
    /// Usage taken from the last parsed chunk that carried a `usage` object.
    pub usage: Option<UsageMetrics>,
    /// Stream-level metadata such as the model name.
    pub metadata: StreamMetadata,
}
100
/// Buffers raw streamed JSON chunks and lazily aggregates them into a
/// [`FinalizedStream`].
#[derive(Clone, Default)]
pub struct BraintrustStream {
    /// Chunks in arrival order; `_keep_alive` markers are never stored.
    raw_chunks: Vec<Value>,
    /// Aggregation result cached by the first successful `final_value` call.
    /// NOTE(review): `push` does not clear this, so chunks pushed after
    /// finalization are not reflected in the cached result.
    finalized: Option<FinalizedStream>,
}
113
/// The subset of a streamed chunk that aggregation reads.
///
/// Every field falls back to its default when absent, and keys not listed here
/// are ignored during deserialization. NOTE(review): the shape matches
/// OpenAI-style chat-completion chunks (as used in this file's tests) — confirm
/// for other providers.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct StreamChunk {
    /// Model reported by the chunk, if any.
    #[serde(default)]
    model: Option<String>,
    /// Choices carried by this chunk; empty when absent.
    #[serde(default)]
    choices: Vec<StreamChoice>,
    /// Token usage, if this chunk reports it.
    #[serde(default)]
    usage: Option<StreamUsage>,
}
124
/// Incremental message fragment inside a [`StreamChoice`].
///
/// Only `role` and `content` are captured; any other delta keys (for example
/// tool-call fragments) are ignored during deserialization.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
struct StreamDelta {
    /// Role, usually present only on the first fragment.
    #[serde(default)]
    role: Option<String>,
    /// Text fragment to append to the message content.
    #[serde(default)]
    content: Option<String>,
}
133
/// One element of a chunk's `choices` array.
///
/// There is no `index` field: aggregation merges every choice of every chunk
/// into a single output choice.
#[derive(Debug, Clone, Deserialize, Serialize)]
struct StreamChoice {
    /// Incremental message fragment, if present.
    #[serde(default)]
    delta: Option<StreamDelta>,
    /// Stop reason; aggregation ignores empty strings.
    #[serde(default)]
    finish_reason: Option<String>,
}
141
142#[derive(Debug, Clone, Deserialize, Serialize)]
143struct StreamUsage {
144 #[serde(default)]
145 prompt_tokens: Option<i64>,
146 #[serde(default, alias = "input_tokens")]
147 completion_tokens: Option<i64>,
148 #[serde(default, alias = "cache_read_input_tokens")]
149 prompt_cached_tokens: Option<i64>,
150 #[serde(default, alias = "cache_creation_input_tokens")]
151 prompt_cache_creation_tokens: Option<i64>,
152}
153
154impl BraintrustStream {
155 pub fn new() -> Self {
157 Self {
158 raw_chunks: Vec::new(),
159 finalized: None,
160 }
161 }
162
163 pub fn push(&mut self, value: Value) {
169 if value.get("_keep_alive").is_some() {
171 return;
172 }
173 self.raw_chunks.push(value);
174 }
175
176 pub fn final_value(&mut self) -> Result<&FinalizedStream> {
181 if self.finalized.is_none() {
182 self.finalized = Some(self.aggregate()?);
183 }
184 Ok(self.finalized.as_ref().unwrap())
185 }
186
187 pub fn is_empty(&self) -> bool {
189 self.raw_chunks.is_empty()
190 }
191
192 fn aggregate(&self) -> Result<FinalizedStream> {
193 let mut usage: Option<UsageMetrics> = None;
194 let mut model: Option<String> = None;
195 let mut finish_reason: Option<String> = None;
196
197 let mut aggregated_content = String::new();
199 let mut role: Option<String> = None;
200
201 for raw in &self.raw_chunks {
202 let chunk: StreamChunk = match serde_json::from_value(raw.clone()) {
204 Ok(c) => c,
205 Err(_) => continue, };
207
208 if model.is_none() {
210 model = chunk.model;
211 }
212
213 if let Some(ref u) = chunk.usage {
215 usage = Some(UsageMetrics {
216 prompt_tokens: u.prompt_tokens.and_then(|v| u32::try_from(v).ok()),
217 completion_tokens: u.completion_tokens.and_then(|v| u32::try_from(v).ok()),
218 total_tokens: match (u.prompt_tokens, u.completion_tokens) {
219 (Some(p), Some(c)) => u32::try_from(p + c).ok(),
220 _ => None,
221 },
222 reasoning_tokens: None,
223 prompt_cached_tokens: u
224 .prompt_cached_tokens
225 .and_then(|v| u32::try_from(v).ok()),
226 prompt_cache_creation_tokens: u
227 .prompt_cache_creation_tokens
228 .and_then(|v| u32::try_from(v).ok()),
229 completion_reasoning_tokens: None,
230 prompt_tokens_details: None,
231 completion_tokens_details: None,
232 });
233 }
234
235 for choice in &chunk.choices {
237 if let Some(ref reason) = choice.finish_reason {
239 if !reason.is_empty() {
240 finish_reason = Some(reason.clone());
241 }
242 }
243
244 if let Some(ref delta) = choice.delta {
246 if role.is_none() {
248 role = delta.role.clone();
249 }
250
251 if let Some(ref content) = delta.content {
253 aggregated_content.push_str(content);
254 }
255 }
256 }
257 }
258
259 let metadata = StreamMetadata {
261 model,
262 extra: HashMap::new(),
263 };
264
265 let message = ChatMessage {
267 role: Some(role.unwrap_or_else(|| "assistant".to_string())),
268 content: Some(aggregated_content),
269 tool_calls: None, };
271
272 let choice = OutputChoice {
273 index: 0,
274 message,
275 logprobs: None,
276 finish_reason,
277 };
278
279 Ok(FinalizedStream {
280 output: vec![choice],
281 usage,
282 metadata,
283 })
284 }
285}
286
287#[allow(private_bounds)]
300pub fn wrap_stream_with_span<S, E, Sub>(
301 stream: S,
302 span: SpanHandle<Sub>,
303) -> Pin<Box<dyn Stream<Item = std::result::Result<Value, E>> + Send>>
304where
305 S: Stream<Item = std::result::Result<Value, E>> + Send + Unpin + 'static,
306 E: Send + 'static,
307 Sub: SpanSubmitter + 'static,
308{
309 use futures::StreamExt;
310
311 let start_time = Instant::now();
312 let ttft_recorded = Arc::new(AtomicBool::new(false));
313 let aggregator = Arc::new(Mutex::new(BraintrustStream::new()));
314 let span_for_complete = span.clone();
315 let aggregator_for_complete = Arc::clone(&aggregator);
316
317 let logged_stream = stream.then(move |result| {
318 let span = span.clone();
319 let ttft_recorded = ttft_recorded.clone();
320 let aggregator = aggregator.clone();
321 async move {
322 if let Ok(ref value) = result {
323 if value.get("_keep_alive").is_none() {
325 if !ttft_recorded.swap(true, Ordering::SeqCst) && value_has_content(value) {
327 let ttft_secs = start_time.elapsed().as_secs_f64();
328 span.log(SpanLog {
329 metrics: Some(
330 [("time_to_first_token".to_string(), ttft_secs)]
331 .into_iter()
332 .collect(),
333 ),
334 ..Default::default()
335 })
336 .await;
337 }
338 aggregator.lock().await.push(value.clone());
340 }
341 }
342 result
343 }
344 });
345
346 Box::pin(SpanCompleteWrapper {
348 inner: Box::pin(logged_stream),
349 span: Some(span_for_complete),
350 aggregator: Some(aggregator_for_complete),
351 finalize_state: FinalizeState::Idle,
352 })
353}
354
355fn value_has_content(value: &Value) -> bool {
357 if let Some(choices) = value.get("choices").and_then(|c| c.as_array()) {
359 if !choices.is_empty() {
360 return true;
361 }
362 }
363 if let Some(usage) = value.get("usage").and_then(|u| u.as_object()) {
365 let has_tokens = usage
366 .get("completion_tokens")
367 .and_then(|v| v.as_i64())
368 .map(|t| t > 0)
369 .unwrap_or(false)
370 || usage
371 .get("prompt_tokens")
372 .and_then(|v| v.as_i64())
373 .map(|t| t > 0)
374 .unwrap_or(false);
375 if has_tokens {
376 return true;
377 }
378 }
379 false
380}
381
/// Boxed, pinned future used to hold in-progress span finalization.
type FinalizeFuture = Pin<Box<dyn std::future::Future<Output = ()> + Send>>;

/// Phase of the stream wrapper relative to the end of its inner stream.
enum FinalizeState {
    /// The inner stream is still being forwarded.
    Idle,
    /// The inner stream has ended and finalization is being driven.
    Finalizing(FinalizeFuture),
    /// Finalization finished; only `None` is yielded from now on.
    Done,
}
391
392struct SpanCompleteWrapper<S, Sub: SpanSubmitter> {
394 inner: S,
395 span: Option<SpanHandle<Sub>>,
396 aggregator: Option<Arc<Mutex<BraintrustStream>>>,
397 finalize_state: FinalizeState,
398}
399
400async fn finalize_span<Sub: SpanSubmitter>(
402 span: SpanHandle<Sub>,
403 aggregator: Arc<Mutex<BraintrustStream>>,
404) {
405 let mut agg = aggregator.lock().await;
406 if !agg.is_empty() {
407 match agg.final_value() {
408 Ok(finalized) => {
409 let metrics = finalized
411 .usage
412 .as_ref()
413 .map(|u| usage_metrics_to_map(u.clone()));
414
415 let metadata = finalized.metadata.to_map();
417
418 let output = serde_json::to_value(&finalized.output).ok();
420
421 span.log(SpanLog {
422 output,
423 metadata,
424 metrics,
425 ..Default::default()
426 })
427 .await;
428 }
429 Err(e) => {
430 tracing::warn!("Failed to finalize stream: {}", e);
431 }
432 }
433 }
434 if let Err(e) = span.flush().await {
436 tracing::warn!("Failed to flush span: {}", e);
437 }
438}
439
440impl<S, E, Sub> Stream for SpanCompleteWrapper<S, Sub>
441where
442 S: Stream<Item = std::result::Result<Value, E>> + Unpin,
443 E: Send + 'static,
444 Sub: SpanSubmitter + 'static,
445{
446 type Item = std::result::Result<Value, E>;
447
448 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
449 let this = unsafe { self.get_unchecked_mut() };
451
452 match &mut this.finalize_state {
454 FinalizeState::Idle => {
455 }
457 FinalizeState::Finalizing(fut) => {
458 match fut.as_mut().poll(cx) {
460 Poll::Ready(()) => {
461 this.finalize_state = FinalizeState::Done;
463 return Poll::Ready(None);
464 }
465 Poll::Pending => {
466 return Poll::Pending;
468 }
469 }
470 }
471 FinalizeState::Done => {
472 return Poll::Ready(None);
473 }
474 }
475
476 let result = Pin::new(&mut this.inner).poll_next(cx);
478
479 if matches!(result, Poll::Ready(None)) {
481 if let (Some(span), Some(aggregator)) = (this.span.take(), this.aggregator.take()) {
482 let fut = Box::pin(finalize_span(span, aggregator));
484 this.finalize_state = FinalizeState::Finalizing(fut);
485
486 return unsafe { Pin::new_unchecked(this) }.poll_next(cx);
489 }
490 }
491
492 result
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn aggregates_content_from_streaming_values() {
        let chunks = vec![
            json!({
                "id": "chunk1",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "role": "assistant", "content": "Hello" }
                }],
                "created": 1
            }),
            json!({
                "id": "chunk2",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "content": " world" }
                }],
                "created": 1
            }),
            json!({
                "id": "chunk3",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "content": "!" },
                    "finish_reason": "stop"
                }],
                "created": 1
            }),
        ];

        let mut stream = BraintrustStream::new();
        for chunk in chunks {
            stream.push(chunk);
        }

        let finalized = stream.final_value().expect("should finalize");

        assert_eq!(finalized.output.len(), 1);

        let choice = &finalized.output[0];
        assert_eq!(choice.index, 0);
        assert_eq!(choice.message.role.as_deref(), Some("assistant"));
        assert_eq!(choice.message.content.as_deref(), Some("Hello world!"));
        assert_eq!(choice.finish_reason.as_deref(), Some("stop"));

        assert_eq!(finalized.metadata.model.as_deref(), Some("gpt-4"));
    }

    #[test]
    fn aggregates_usage_from_final_chunk() {
        let chunks = vec![
            json!({
                "id": "chunk1",
                "model": "gpt-4",
                "choices": [{
                    "index": 0,
                    "delta": { "role": "assistant", "content": "Hi" },
                    "finish_reason": "stop"
                }],
                "created": 1
            }),
            json!({
                "id": "chunk2",
                "model": "gpt-4",
                "choices": [],
                "created": 1,
                "usage": {
                    "prompt_tokens": 10,
                    "completion_tokens": 5
                }
            }),
        ];

        let mut stream = BraintrustStream::new();
        for chunk in chunks {
            stream.push(chunk);
        }

        let finalized = stream.final_value().expect("should finalize");

        let usage = finalized.usage.as_ref().expect("should have usage");
        assert_eq!(usage.prompt_tokens, Some(10));
        assert_eq!(usage.completion_tokens, Some(5));
        assert_eq!(usage.total_tokens, Some(15));
    }

    #[test]
    fn skips_keep_alive_markers() {
        let mut stream = BraintrustStream::new();

        stream.push(json!({"_keep_alive": true}));

        assert!(stream.is_empty());
    }

    #[test]
    fn caches_finalized_result() {
        let chunk = json!({
            "id": "chunk1",
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": { "role": "assistant", "content": "test" }
            }],
            "created": 1
        });

        let mut stream = BraintrustStream::new();
        stream.push(chunk);

        // Read the content in separate scopes so each `final_value` borrow ends.
        let first_content = {
            let first = stream.final_value().expect("should finalize");
            first.output.first().and_then(|c| c.message.content.clone())
        };

        let second_content = {
            let second = stream.final_value().expect("should finalize");
            second
                .output
                .first()
                .and_then(|c| c.message.content.clone())
        };

        assert_eq!(first_content, second_content);
        assert_eq!(first_content, Some("test".to_string()));
    }

    #[test]
    fn falls_back_to_assistant_role_and_skips_malformed_chunks() {
        let mut stream = BraintrustStream::new();
        // `choices` must be an array, so this chunk fails to parse and is skipped
        // entirely (its model is not used either).
        stream.push(json!({ "model": "ignored-model", "choices": "not-an-array" }));
        stream.push(json!({
            "model": "gpt-4",
            "choices": [{ "index": 0, "delta": { "content": "partial" } }]
        }));

        let finalized = stream.final_value().expect("should finalize");
        let choice = &finalized.output[0];
        assert_eq!(choice.message.role.as_deref(), Some("assistant"));
        assert_eq!(choice.message.content.as_deref(), Some("partial"));
        assert_eq!(choice.finish_reason, None);
        assert_eq!(finalized.metadata.model.as_deref(), Some("gpt-4"));
        assert!(finalized.usage.is_none());
    }

    #[test]
    fn detects_content_bearing_chunks() {
        assert!(value_has_content(&json!({ "choices": [{ "delta": {} }] })));
        assert!(value_has_content(
            &json!({ "choices": [], "usage": { "prompt_tokens": 3 } })
        ));
        assert!(value_has_content(&json!({ "usage": { "completion_tokens": 1 } })));
        assert!(!value_has_content(&json!({ "choices": [] })));
        assert!(!value_has_content(
            &json!({ "usage": { "prompt_tokens": 0, "completion_tokens": 0 } })
        ));
        assert!(!value_has_content(&json!({ "_keep_alive": true })));
    }

    #[test]
    fn metadata_map_is_none_when_empty() {
        assert!(StreamMetadata::default().to_map().is_none());

        let metadata = StreamMetadata {
            model: Some("gpt-4".to_string()),
            extra: HashMap::new(),
        };
        let map = metadata
            .to_map()
            .expect("non-empty metadata should serialize to an object");
        assert_eq!(map.get("model"), Some(&json!("gpt-4")));
    }
}