1use futures::stream::{self, BoxStream, Stream, StreamExt};
8use std::collections::BTreeMap;
9use std::time::{Duration, Instant};
10use tokio::sync::mpsc;
11
/// Longest gap the SSE reader tasks tolerate between upstream events before
/// failing the stream with an "sse idle timeout" error.
const SSE_IDLE_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Capacity of the bounded channel between an SSE reader task and the
/// consumer-facing stream.
const SSE_CHUNK_BUFFER: usize = 128;
22
23pub fn ensure_event_stream(resp: reqwest::Response) -> anyhow::Result<reqwest::Response> {
31 if let Some(ct) = resp.headers().get(reqwest::header::CONTENT_TYPE) {
32 if let Ok(s) = ct.to_str() {
33 let s_lower = s.to_ascii_lowercase();
34 if !s_lower.contains("text/event-stream") {
35 anyhow::bail!(
36 "expected SSE response (text/event-stream), got content-type `{s}` — upstream is likely an error page"
37 );
38 }
39 }
40 }
41 Ok(resp)
42}
43
44use crate::client::LlmClient;
45use crate::rate_limiter::RateLimiter;
46use crate::telemetry::{inc_stream_chunks_total, observe_stream_ttft_ms};
47use crate::types::{
48 ChatRequest, ChatResponse, FinishReason, ResponseContent, TokenUsage, ToolCall,
49};
50use std::sync::Arc;
51
52pub fn record_usage_tap<S>(
57 stream: S,
58 rate_limiter: Arc<RateLimiter>,
59) -> BoxStream<'static, anyhow::Result<StreamChunk>>
60where
61 S: Stream<Item = anyhow::Result<StreamChunk>> + Send + 'static,
62{
63 stream
64 .inspect(move |item| {
65 if let Ok(StreamChunk::Usage(u)) = item {
66 if let Some(t) = rate_limiter.quota_tracker() {
67 t.record_usage(u.prompt_tokens, u.completion_tokens);
68 }
69 }
70 })
71 .boxed()
72}
73
74pub fn stream_metrics_tap<S>(
80 stream: S,
81 provider: &str,
82) -> BoxStream<'static, anyhow::Result<StreamChunk>>
83where
84 S: Stream<Item = anyhow::Result<StreamChunk>> + Send + 'static,
85{
86 let provider = provider.to_string();
87 let started = Instant::now();
88 let mut observed_ttft = false;
89 stream
90 .inspect(move |item| {
91 if let Ok(chunk) = item {
92 inc_stream_chunks_total(&provider, chunk.kind_label());
93 if !observed_ttft && chunk.is_contentful() {
94 observed_ttft = true;
95 let elapsed_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
96 observe_stream_ttft_ms(&provider, elapsed_ms);
97 }
98 }
99 })
100 .boxed()
101}
102
103#[derive(Debug, Clone)]
112pub enum StreamChunk {
113 TextDelta { delta: String },
114 ToolCallStart { id: String, name: String },
115 ToolCallArgsDelta { id: String, delta: String },
116 ToolCallEnd { id: String },
117 Usage(TokenUsage),
118 End { finish_reason: FinishReason },
119}
120
121impl StreamChunk {
122 pub fn kind_label(&self) -> &'static str {
123 match self {
124 StreamChunk::TextDelta { .. } => "text_delta",
125 StreamChunk::ToolCallStart { .. } => "tool_call_start",
126 StreamChunk::ToolCallArgsDelta { .. } => "tool_call_args_delta",
127 StreamChunk::ToolCallEnd { .. } => "tool_call_end",
128 StreamChunk::Usage(_) => "usage",
129 StreamChunk::End { .. } => "end",
130 }
131 }
132
133 fn is_contentful(&self) -> bool {
134 matches!(
135 self,
136 StreamChunk::TextDelta { .. }
137 | StreamChunk::ToolCallStart { .. }
138 | StreamChunk::ToolCallArgsDelta { .. }
139 | StreamChunk::ToolCallEnd { .. }
140 )
141 }
142}
143
/// Cap, in bytes, on the assistant text `collect_stream` will buffer (8 MiB).
const MAX_TEXT_BYTES: usize = 8 << 20;
/// Cap, in bytes, on one tool call's buffered argument text in
/// `collect_stream` (4 MiB).
const MAX_TOOL_ARGS_BYTES: usize = 4 << 20;
153
154fn receiver_stream(
155 rx: mpsc::Receiver<anyhow::Result<StreamChunk>>,
156) -> BoxStream<'static, anyhow::Result<StreamChunk>> {
157 futures::stream::unfold(rx, |mut rx| async move {
158 rx.recv().await.map(|item| (item, rx))
159 })
160 .boxed()
161}
162
163pub async fn collect_stream<S>(mut s: S) -> anyhow::Result<ChatResponse>
170where
171 S: Stream<Item = anyhow::Result<StreamChunk>> + Unpin,
172{
173 let mut text = String::new();
174 let mut tool_order: Vec<String> = Vec::new();
176 let mut tool_buf: BTreeMap<String, (String, String)> = BTreeMap::new(); let mut usage = TokenUsage::default();
178 let mut finish: Option<FinishReason> = None;
179
180 while let Some(item) = s.next().await {
181 match item? {
182 StreamChunk::TextDelta { delta } => {
183 if text.len().saturating_add(delta.len()) > MAX_TEXT_BYTES {
184 anyhow::bail!(
185 "stream text exceeded {} bytes — refusing to buffer further",
186 MAX_TEXT_BYTES
187 );
188 }
189 text.push_str(&delta);
190 }
191 StreamChunk::ToolCallStart { id, name } => {
192 if !tool_buf.contains_key(&id) {
193 tool_order.push(id.clone());
194 }
195 tool_buf.insert(id, (name, String::new()));
196 }
197 StreamChunk::ToolCallArgsDelta { id, delta } => {
198 let entry = tool_buf
199 .entry(id.clone())
200 .or_insert_with(|| (String::new(), String::new()));
201 if entry.1.len().saturating_add(delta.len()) > MAX_TOOL_ARGS_BYTES {
202 anyhow::bail!(
203 "tool `{}` args exceeded {} bytes — refusing to buffer further",
204 entry.0,
205 MAX_TOOL_ARGS_BYTES
206 );
207 }
208 entry.1.push_str(&delta);
209 if !tool_order.iter().any(|x| x == &id) {
210 tool_order.push(id);
211 }
212 }
213 StreamChunk::ToolCallEnd { .. } => {}
214 StreamChunk::Usage(u) => usage = u,
215 StreamChunk::End { finish_reason } => {
216 finish = Some(finish_reason);
217 break;
218 }
219 }
220 }
221
222 let finish_reason = finish.ok_or_else(|| anyhow::anyhow!("stream ended without End chunk"))?;
223
224 let content = if !tool_order.is_empty() {
225 let calls: Vec<ToolCall> = tool_order
226 .into_iter()
227 .filter_map(|id| {
228 tool_buf.remove(&id).map(|(name, args)| {
229 let arguments = if args.trim().is_empty() {
230 serde_json::json!({})
231 } else {
232 serde_json::from_str(&args)
233 .unwrap_or_else(|_| serde_json::Value::String(args.clone()))
234 };
235 ToolCall {
236 id,
237 name,
238 arguments,
239 }
240 })
241 })
242 .collect();
243 ResponseContent::ToolCalls(calls)
244 } else {
245 ResponseContent::Text(text)
246 };
247
248 Ok(ChatResponse {
249 content,
250 usage,
251 finish_reason,
252
253 cache_usage: None,
254 })
255}
256
257pub async fn default_stream_from_chat<'a, C>(
262 client: &'a C,
263 req: ChatRequest,
264) -> anyhow::Result<BoxStream<'a, anyhow::Result<StreamChunk>>>
265where
266 C: LlmClient + ?Sized,
267{
268 let resp = client.chat(req).await?;
269 Ok(stream_metrics_tap(
270 synth_chunks_from_response(resp),
271 client.provider(),
272 ))
273}
274
275fn synth_chunks_from_response(
276 resp: ChatResponse,
277) -> impl Stream<Item = anyhow::Result<StreamChunk>> + Send + 'static {
278 let ChatResponse {
279 content,
280 usage,
281 finish_reason,
282 cache_usage: _,
283 } = resp;
284 let mut chunks: Vec<anyhow::Result<StreamChunk>> = Vec::new();
285 match content {
286 ResponseContent::Text(t) => {
287 if !t.is_empty() {
288 chunks.push(Ok(StreamChunk::TextDelta { delta: t }));
289 }
290 }
291 ResponseContent::ToolCalls(calls) => {
292 for c in calls {
293 chunks.push(Ok(StreamChunk::ToolCallStart {
294 id: c.id.clone(),
295 name: c.name.clone(),
296 }));
297 let args = serde_json::to_string(&c.arguments).unwrap_or_else(|_| "{}".into());
298 chunks.push(Ok(StreamChunk::ToolCallArgsDelta {
299 id: c.id.clone(),
300 delta: args,
301 }));
302 chunks.push(Ok(StreamChunk::ToolCallEnd { id: c.id }));
303 }
304 }
305 }
306 chunks.push(Ok(StreamChunk::Usage(usage)));
307 chunks.push(Ok(StreamChunk::End { finish_reason }));
308 stream::iter(chunks)
309}
310
// Tests for chunk collection, the non-streaming fallback and the metrics tap.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChatMessage, ToolCall};
    use async_trait::async_trait;
    use futures::stream::iter;

    // Builds a boxed stream that yields each chunk wrapped in `Ok`.
    fn ok_chunks(v: Vec<StreamChunk>) -> BoxStream<'static, anyhow::Result<StreamChunk>> {
        iter(v.into_iter().map(Ok)).boxed()
    }

    // Text deltas are concatenated in arrival order; usage and finish reason
    // are carried through to the response.
    #[tokio::test]
    async fn collect_text_only() {
        let s = ok_chunks(vec![
            StreamChunk::TextDelta {
                delta: "hola ".into(),
            },
            StreamChunk::TextDelta {
                delta: "mundo".into(),
            },
            StreamChunk::Usage(TokenUsage {
                prompt_tokens: 3,
                completion_tokens: 2,
            }),
            StreamChunk::End {
                finish_reason: FinishReason::Stop,
            },
        ]);
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::Text(t) => assert_eq!(t, "hola mundo"),
            _ => panic!("expected text"),
        }
        assert_eq!(r.usage.prompt_tokens, 3);
        assert_eq!(r.finish_reason, FinishReason::Stop);
    }

    // Argument fragments for the same call id are joined and parsed as JSON.
    #[tokio::test]
    async fn collect_tool_calls() {
        let s = ok_chunks(vec![
            StreamChunk::ToolCallStart {
                id: "call_1".into(),
                name: "weather".into(),
            },
            StreamChunk::ToolCallArgsDelta {
                id: "call_1".into(),
                delta: "{\"city\":".into(),
            },
            StreamChunk::ToolCallArgsDelta {
                id: "call_1".into(),
                delta: "\"Bogota\"}".into(),
            },
            StreamChunk::ToolCallEnd {
                id: "call_1".into(),
            },
            StreamChunk::Usage(TokenUsage::default()),
            StreamChunk::End {
                finish_reason: FinishReason::ToolUse,
            },
        ]);
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "weather");
                assert_eq!(calls[0].arguments["city"], "Bogota");
            }
            _ => panic!("expected tool calls"),
        }
    }

    // An `Err` item aborts collection even after content was received.
    #[tokio::test]
    async fn collect_propagates_err() {
        let s: BoxStream<'static, anyhow::Result<StreamChunk>> = iter(vec![
            Ok(StreamChunk::TextDelta { delta: "x".into() }),
            Err(anyhow::anyhow!("boom")),
        ])
        .boxed();
        let r = collect_stream(s).await;
        assert!(r.is_err());
    }

    // A stream that finishes without an `End` chunk is reported as an error.
    #[tokio::test]
    async fn collect_missing_end_fails() {
        let s = ok_chunks(vec![StreamChunk::TextDelta { delta: "x".into() }]);
        assert!(collect_stream(s).await.is_err());
    }

    // Client stub whose `chat` returns a clone of a canned response.
    struct FakeClient {
        resp: ChatResponse,
    }

    #[async_trait]
    impl LlmClient for FakeClient {
        async fn chat(&self, _req: ChatRequest) -> anyhow::Result<ChatResponse> {
            Ok(self.resp.clone())
        }
        fn model_id(&self) -> &str {
            "fake"
        }
        fn provider(&self) -> &str {
            "fake"
        }
    }

    // A text response replayed through `default_stream_from_chat` collects
    // back to the same text and usage.
    #[tokio::test]
    async fn default_stream_synthesizes_text() {
        let client = FakeClient {
            resp: ChatResponse {
                content: ResponseContent::Text("hi".into()),
                usage: TokenUsage {
                    prompt_tokens: 1,
                    completion_tokens: 2,
                },
                finish_reason: FinishReason::Stop,

                cache_usage: None,
            },
        };
        let stream = default_stream_from_chat(
            &client,
            ChatRequest::new("fake", vec![ChatMessage::user("hola")]),
        )
        .await
        .unwrap();
        let collected = collect_stream(stream).await.unwrap();
        match collected.content {
            ResponseContent::Text(t) => assert_eq!(t, "hi"),
            _ => panic!(),
        }
        assert_eq!(collected.usage.completion_tokens, 2);
    }

    // Tool calls replayed through `default_stream_from_chat` keep their JSON
    // arguments after the serialize/parse round trip.
    #[tokio::test]
    async fn default_stream_synthesizes_tool_calls() {
        let client = FakeClient {
            resp: ChatResponse {
                content: ResponseContent::ToolCalls(vec![ToolCall {
                    id: "c1".into(),
                    name: "search".into(),
                    arguments: serde_json::json!({"q":"rust"}),
                }]),
                usage: TokenUsage::default(),
                finish_reason: FinishReason::ToolUse,

                cache_usage: None,
            },
        };
        let stream = default_stream_from_chat(
            &client,
            ChatRequest::new("fake", vec![ChatMessage::user("x")]),
        )
        .await
        .unwrap();
        let collected = collect_stream(stream).await.unwrap();
        match collected.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls[0].arguments["q"], "rust");
            }
            _ => panic!(),
        }
    }

    // The metrics tap counts chunks per kind and records exactly one TTFT
    // sample. `TEST_LOCK` (presumably serialising tests that share the global
    // telemetry state) is held across awaits, hence the clippy allow. A
    // poisoned lock is recovered rather than failing the test.
    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn metrics_tap_records_ttft_and_chunk_kinds() {
        let _guard = crate::telemetry::TEST_LOCK
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        crate::telemetry::reset_for_test();
        let provider = "zz_stream_metrics_probe";
        let stream = stream_metrics_tap(
            ok_chunks(vec![
                StreamChunk::TextDelta {
                    delta: "hola".into(),
                },
                StreamChunk::Usage(TokenUsage::default()),
                StreamChunk::End {
                    finish_reason: FinishReason::Stop,
                },
            ]),
            provider,
        );
        let _ = collect_stream(stream).await.unwrap();
        let body = crate::telemetry::render_prometheus();
        assert!(body.contains(
            "nexo_llm_stream_chunks_total{provider=\"zz_stream_metrics_probe\",kind=\"text_delta\"} 1"
        ));
        assert!(body.contains(
            "nexo_llm_stream_chunks_total{provider=\"zz_stream_metrics_probe\",kind=\"usage\"} 1"
        ));
        assert!(body.contains(
            "nexo_llm_stream_ttft_seconds_count{provider=\"zz_stream_metrics_probe\"} 1"
        ));
    }
}
516
517use futures::Stream as FStream;
525use serde_json::Value;
526
527pub(crate) fn parse_openai_line(
531 line: &str,
532 acc: &mut OpenAiAcc,
533 out: &mut Vec<anyhow::Result<StreamChunk>>,
534) {
535 if line.trim() == "[DONE]" {
536 return;
538 }
539 let v: Value = match serde_json::from_str(line) {
540 Ok(v) => v,
541 Err(e) => {
542 tracing::warn!(error = %e, line = %line, "openai SSE: skip malformed data");
543 return;
544 }
545 };
546
547 if let Some(u) = v.get("usage") {
549 acc.usage = Some(TokenUsage {
550 prompt_tokens: u.get("prompt_tokens").and_then(Value::as_u64).unwrap_or(0) as u32,
551 completion_tokens: u
552 .get("completion_tokens")
553 .and_then(Value::as_u64)
554 .unwrap_or(0) as u32,
555 });
556 }
557
558 let choice = match v.get("choices").and_then(|c| c.get(0)) {
559 Some(c) => c,
560 None => return,
561 };
562 let delta = choice.get("delta").cloned().unwrap_or(Value::Null);
563
564 if let Some(content) = delta.get("content").and_then(Value::as_str) {
565 if !content.is_empty() {
566 out.push(Ok(StreamChunk::TextDelta {
567 delta: content.to_string(),
568 }));
569 }
570 }
571
572 if let Some(tcs) = delta.get("tool_calls").and_then(Value::as_array) {
573 for tc in tcs {
574 let index = tc.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
575 let id_opt = tc.get("id").and_then(Value::as_str).map(str::to_string);
576 let name_opt = tc
577 .get("function")
578 .and_then(|f| f.get("name"))
579 .and_then(Value::as_str)
580 .map(str::to_string);
581 let args_delta = tc
582 .get("function")
583 .and_then(|f| f.get("arguments"))
584 .and_then(Value::as_str)
585 .unwrap_or("");
586
587 let slot = acc.tool_by_index.entry(index).or_default();
588 if let Some(id) = id_opt {
589 if slot.id.is_empty() {
590 slot.id = id;
591 }
592 }
593 if let Some(name) = name_opt {
594 if !name.is_empty() {
595 slot.name_buf.push_str(&name);
596 }
597 }
598 if !slot.started && !slot.id.is_empty() && !slot.name_buf.is_empty() {
599 slot.started = true;
600 out.push(Ok(StreamChunk::ToolCallStart {
601 id: slot.id.clone(),
602 name: slot.name_buf.clone(),
603 }));
604 }
605 if slot.started && !args_delta.is_empty() {
606 out.push(Ok(StreamChunk::ToolCallArgsDelta {
607 id: slot.id.clone(),
608 delta: args_delta.to_string(),
609 }));
610 } else if !args_delta.is_empty() {
611 slot.pending_args.push_str(args_delta);
612 }
613 }
614 }
615
616 if let Some(finish) = choice.get("finish_reason").and_then(Value::as_str) {
617 acc.finish_reason = Some(match finish {
618 "stop" => FinishReason::Stop,
619 "tool_calls" => FinishReason::ToolUse,
620 "length" => FinishReason::Length,
621 other => FinishReason::Other(other.to_string()),
622 });
623 for (_, slot) in acc.tool_by_index.iter_mut() {
625 if !slot.started && !slot.id.is_empty() && !slot.name_buf.is_empty() {
626 slot.started = true;
627 out.push(Ok(StreamChunk::ToolCallStart {
628 id: slot.id.clone(),
629 name: slot.name_buf.clone(),
630 }));
631 if !slot.pending_args.is_empty() {
632 out.push(Ok(StreamChunk::ToolCallArgsDelta {
633 id: slot.id.clone(),
634 delta: std::mem::take(&mut slot.pending_args),
635 }));
636 }
637 }
638 if slot.started && !slot.ended {
639 slot.ended = true;
640 out.push(Ok(StreamChunk::ToolCallEnd {
641 id: slot.id.clone(),
642 }));
643 }
644 }
645 }
646}
647
648#[derive(Default)]
649pub(crate) struct OpenAiAcc {
650 pub tool_by_index: BTreeMap<usize, OpenAiToolSlot>,
651 pub usage: Option<TokenUsage>,
652 pub finish_reason: Option<FinishReason>,
653}
654
/// Accumulated state of one streamed OpenAI tool call.
#[derive(Default)]
pub(crate) struct OpenAiToolSlot {
    /// Call id, taken from the first fragment that supplies a non-empty one.
    pub id: String,
    /// Function name, concatenated from every name fragment received.
    pub name_buf: String,
    /// Argument text received before the call could be started.
    pub pending_args: String,
    /// Whether `ToolCallStart` has been emitted.
    pub started: bool,
    /// Whether `ToolCallEnd` has been emitted.
    pub ended: bool,
}
663
664pub fn parse_openai_sse<S, E>(byte_stream: S) -> BoxStream<'static, anyhow::Result<StreamChunk>>
668where
669 S: FStream<Item = Result<bytes::Bytes, E>> + Send + 'static,
670 E: std::fmt::Display + Send + 'static,
671{
672 use eventsource_stream::Eventsource;
673 let mut events = Box::pin(
674 byte_stream
675 .map(|r| r.map_err(|e| std::io::Error::other(e.to_string())))
676 .eventsource(),
677 );
678 let (tx, rx) = mpsc::channel::<anyhow::Result<StreamChunk>>(SSE_CHUNK_BUFFER);
679 tokio::spawn(async move {
680 let mut acc = OpenAiAcc::default();
681 loop {
682 match tokio::time::timeout(SSE_IDLE_TIMEOUT, events.next()).await {
683 Ok(Some(Ok(ev))) => {
684 let mut out = Vec::<anyhow::Result<StreamChunk>>::new();
685 parse_openai_line(&ev.data, &mut acc, &mut out);
686 for chunk in out {
687 if tx.send(chunk).await.is_err() {
688 return;
689 }
690 }
691 }
692 Ok(Some(Err(e))) => {
693 let _ = tx.send(Err(anyhow::anyhow!("sse error: {e}"))).await;
694 return;
695 }
696 Ok(None) => {
697 if let Some(u) = acc.usage.take() {
698 if tx.send(Ok(StreamChunk::Usage(u))).await.is_err() {
699 return;
700 }
701 }
702 let finish = acc.finish_reason.take().unwrap_or(FinishReason::Stop);
703 let _ = tx
704 .send(Ok(StreamChunk::End {
705 finish_reason: finish,
706 }))
707 .await;
708 return;
709 }
710 Err(_) => {
711 let _ = tx
712 .send(Err(anyhow::anyhow!(
713 "sse idle timeout after {}s",
714 SSE_IDLE_TIMEOUT.as_secs()
715 )))
716 .await;
717 return;
718 }
719 }
720 }
721 });
722 receiver_stream(rx)
723}
724
725#[derive(Default)]
728pub(crate) struct AnthropicAcc {
729 pub blocks: BTreeMap<u64, AnthropicBlockSlot>,
731 pub usage: TokenUsage,
732 pub finish_reason: Option<FinishReason>,
733}
734
/// State of one Anthropic content block.
#[derive(Default)]
pub(crate) struct AnthropicBlockSlot {
    /// Tool-use id; only set for `tool_use` blocks.
    pub id: String,
    /// Tool name; only set for `tool_use` blocks.
    pub name: String,
    /// The block's `type` as given in `content_block_start`.
    pub kind: String,
    /// Whether `ToolCallStart` has been emitted for this block.
    pub started: bool,
}
742
743pub(crate) fn parse_anthropic_event(
744 event_type: &str,
745 data: &str,
746 acc: &mut AnthropicAcc,
747 out: &mut Vec<anyhow::Result<StreamChunk>>,
748) {
749 let v: Value = match serde_json::from_str(data) {
750 Ok(v) => v,
751 Err(e) => {
752 tracing::warn!(error = %e, event = %event_type, "anthropic SSE: skip malformed data");
753 return;
754 }
755 };
756
757 match event_type {
758 "message_start" => {
759 if let Some(u) = v.pointer("/message/usage") {
760 acc.usage.prompt_tokens =
761 u.get("input_tokens").and_then(Value::as_u64).unwrap_or(0) as u32;
762 }
763 }
764 "content_block_start" => {
765 let index = v.get("index").and_then(Value::as_u64).unwrap_or(0);
766 let block = v.get("content_block").cloned().unwrap_or(Value::Null);
767 let kind = block
768 .get("type")
769 .and_then(Value::as_str)
770 .unwrap_or("")
771 .to_string();
772 let slot = acc.blocks.entry(index).or_default();
773 slot.kind = kind.clone();
774 if kind == "tool_use" {
775 slot.id = block
776 .get("id")
777 .and_then(Value::as_str)
778 .unwrap_or("")
779 .to_string();
780 slot.name = block
781 .get("name")
782 .and_then(Value::as_str)
783 .unwrap_or("")
784 .to_string();
785 if !slot.id.is_empty() && !slot.name.is_empty() && !slot.started {
786 slot.started = true;
787 out.push(Ok(StreamChunk::ToolCallStart {
788 id: slot.id.clone(),
789 name: slot.name.clone(),
790 }));
791 }
792 }
793 }
794 "content_block_delta" => {
795 let index = v.get("index").and_then(Value::as_u64).unwrap_or(0);
796 let delta = v.get("delta").cloned().unwrap_or(Value::Null);
797 let dtype = delta.get("type").and_then(Value::as_str).unwrap_or("");
798 let slot = acc.blocks.entry(index).or_default();
799 match dtype {
800 "text_delta" => {
801 if let Some(t) = delta.get("text").and_then(Value::as_str) {
802 if !t.is_empty() {
803 out.push(Ok(StreamChunk::TextDelta {
804 delta: t.to_string(),
805 }));
806 }
807 }
808 }
809 "input_json_delta" => {
810 if let Some(t) = delta.get("partial_json").and_then(Value::as_str) {
811 if !t.is_empty() && slot.started {
812 out.push(Ok(StreamChunk::ToolCallArgsDelta {
813 id: slot.id.clone(),
814 delta: t.to_string(),
815 }));
816 }
817 }
818 }
819 _ => {}
820 }
821 }
822 "content_block_stop" => {
823 let index = v.get("index").and_then(Value::as_u64).unwrap_or(0);
824 if let Some(slot) = acc.blocks.get_mut(&index) {
825 if slot.kind == "tool_use" && slot.started {
826 out.push(Ok(StreamChunk::ToolCallEnd {
827 id: slot.id.clone(),
828 }));
829 }
830 }
831 }
832 "message_delta" => {
833 if let Some(stop) = v.pointer("/delta/stop_reason").and_then(Value::as_str) {
834 acc.finish_reason = Some(match stop {
835 "end_turn" => FinishReason::Stop,
836 "tool_use" => FinishReason::ToolUse,
837 "max_tokens" => FinishReason::Length,
838 other => FinishReason::Other(other.to_string()),
839 });
840 }
841 if let Some(u) = v.get("usage") {
842 if let Some(ot) = u.get("output_tokens").and_then(Value::as_u64) {
843 acc.usage.completion_tokens = ot as u32;
844 }
845 if let Some(it) = u.get("input_tokens").and_then(Value::as_u64) {
846 if acc.usage.prompt_tokens == 0 {
847 acc.usage.prompt_tokens = it as u32;
848 }
849 }
850 }
851 }
852 "message_stop" => {}
853 _ => {}
854 }
855}
856
857pub fn parse_anthropic_sse<S, E>(byte_stream: S) -> BoxStream<'static, anyhow::Result<StreamChunk>>
858where
859 S: FStream<Item = Result<bytes::Bytes, E>> + Send + Unpin + 'static,
860 E: std::fmt::Display + Send + 'static,
861{
862 use eventsource_stream::Eventsource;
863 let mut events = Box::pin(
864 byte_stream
865 .map(|r| r.map_err(|e| std::io::Error::other(e.to_string())))
866 .eventsource(),
867 );
868 let (tx, rx) = mpsc::channel::<anyhow::Result<StreamChunk>>(SSE_CHUNK_BUFFER);
869 tokio::spawn(async move {
870 let mut acc = AnthropicAcc::default();
871 loop {
872 match tokio::time::timeout(SSE_IDLE_TIMEOUT, events.next()).await {
873 Ok(Some(Ok(ev))) => {
874 let etype = if ev.event.is_empty() {
875 "message".to_string()
876 } else {
877 ev.event.clone()
878 };
879 let mut out = Vec::<anyhow::Result<StreamChunk>>::new();
880 parse_anthropic_event(&etype, &ev.data, &mut acc, &mut out);
881 for chunk in out {
882 if tx.send(chunk).await.is_err() {
883 return;
884 }
885 }
886 }
887 Ok(Some(Err(e))) => {
888 let _ = tx.send(Err(anyhow::anyhow!("sse error: {e}"))).await;
889 return;
890 }
891 Ok(None) => {
892 if tx
893 .send(Ok(StreamChunk::Usage(acc.usage.clone())))
894 .await
895 .is_err()
896 {
897 return;
898 }
899 let finish = acc.finish_reason.take().unwrap_or(FinishReason::Stop);
900 let _ = tx
901 .send(Ok(StreamChunk::End {
902 finish_reason: finish,
903 }))
904 .await;
905 return;
906 }
907 Err(_) => {
908 let _ = tx
909 .send(Err(anyhow::anyhow!(
910 "sse idle timeout after {}s",
911 SSE_IDLE_TIMEOUT.as_secs()
912 )))
913 .await;
914 return;
915 }
916 }
917 }
918 });
919 receiver_stream(rx)
920}
921
922#[derive(Default)]
932struct GeminiAcc {
933 usage: TokenUsage,
934 finish_reason: Option<FinishReason>,
935 tool_call_counter: usize,
936}
937
938fn parse_gemini_event(data: &str, acc: &mut GeminiAcc, out: &mut Vec<anyhow::Result<StreamChunk>>) {
939 let v: serde_json::Value = match serde_json::from_str(data) {
940 Ok(v) => v,
941 Err(e) => {
942 out.push(Err(anyhow::anyhow!("gemini sse json: {e}")));
943 return;
944 }
945 };
946 if let Some(cand) = v.pointer("/candidates/0") {
947 if let Some(parts) = cand.pointer("/content/parts").and_then(|p| p.as_array()) {
948 for part in parts {
949 if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
950 if !t.is_empty() {
951 out.push(Ok(StreamChunk::TextDelta {
952 delta: t.to_string(),
953 }));
954 }
955 }
956 if let Some(fc) = part.get("functionCall") {
957 let name = fc
958 .get("name")
959 .and_then(|n| n.as_str())
960 .unwrap_or("")
961 .to_string();
962 let args = fc.get("args").cloned().unwrap_or(serde_json::json!({}));
963 let id = format!("gemini_call_{}", acc.tool_call_counter);
964 acc.tool_call_counter += 1;
965 out.push(Ok(StreamChunk::ToolCallStart {
966 id: id.clone(),
967 name,
968 }));
969 out.push(Ok(StreamChunk::ToolCallArgsDelta {
970 id: id.clone(),
971 delta: serde_json::to_string(&args).unwrap_or_default(),
972 }));
973 out.push(Ok(StreamChunk::ToolCallEnd { id }));
974 }
975 }
976 }
977 if let Some(fr) = cand.get("finishReason").and_then(|f| f.as_str()) {
978 acc.finish_reason = Some(match fr {
979 "STOP" => FinishReason::Stop,
980 "MAX_TOKENS" => FinishReason::Length,
981 other => FinishReason::Other(other.to_string()),
982 });
983 }
984 }
985 if let Some(u) = v.get("usageMetadata") {
986 if let Some(p) = u.get("promptTokenCount").and_then(|v| v.as_u64()) {
987 acc.usage.prompt_tokens = p as u32;
988 }
989 if let Some(o) = u.get("candidatesTokenCount").and_then(|v| v.as_u64()) {
990 acc.usage.completion_tokens = o as u32;
991 }
992 }
993}
994
995pub fn parse_gemini_sse<S, E>(byte_stream: S) -> BoxStream<'static, anyhow::Result<StreamChunk>>
996where
997 S: FStream<Item = Result<bytes::Bytes, E>> + Send + Unpin + 'static,
998 E: std::fmt::Display + Send + 'static,
999{
1000 use eventsource_stream::Eventsource;
1001 let mut events = Box::pin(
1002 byte_stream
1003 .map(|r| r.map_err(|e| std::io::Error::other(e.to_string())))
1004 .eventsource(),
1005 );
1006 let (tx, rx) = mpsc::channel::<anyhow::Result<StreamChunk>>(SSE_CHUNK_BUFFER);
1007 tokio::spawn(async move {
1008 let mut acc = GeminiAcc::default();
1009 loop {
1010 match tokio::time::timeout(SSE_IDLE_TIMEOUT, events.next()).await {
1011 Ok(Some(Ok(ev))) => {
1012 if ev.data.trim().is_empty() {
1013 continue;
1014 }
1015 let mut out = Vec::<anyhow::Result<StreamChunk>>::new();
1016 parse_gemini_event(&ev.data, &mut acc, &mut out);
1017 for chunk in out {
1018 if tx.send(chunk).await.is_err() {
1019 return;
1020 }
1021 }
1022 }
1023 Ok(Some(Err(e))) => {
1024 let _ = tx.send(Err(anyhow::anyhow!("sse error: {e}"))).await;
1025 return;
1026 }
1027 Ok(None) => {
1028 if tx
1029 .send(Ok(StreamChunk::Usage(acc.usage.clone())))
1030 .await
1031 .is_err()
1032 {
1033 return;
1034 }
1035 let finish = acc.finish_reason.take().unwrap_or(FinishReason::Stop);
1036 let _ = tx
1037 .send(Ok(StreamChunk::End {
1038 finish_reason: finish,
1039 }))
1040 .await;
1041 return;
1042 }
1043 Err(_) => {
1044 let _ = tx
1045 .send(Err(anyhow::anyhow!(
1046 "sse idle timeout after {}s",
1047 SSE_IDLE_TIMEOUT.as_secs()
1048 )))
1049 .await;
1050 return;
1051 }
1052 }
1053 }
1054 });
1055 receiver_stream(rx)
1056}
1057
// End-to-end tests: raw SSE bytes -> provider parser -> `collect_stream`.
#[cfg(test)]
mod parser_tests {
    use super::*;
    use bytes::Bytes;
    use futures::stream;

    // Wraps static string slices as a successful byte stream. Each test passes
    // the whole SSE body as a single chunk.
    fn bstream(
        chunks: Vec<&'static str>,
    ) -> impl FStream<Item = Result<Bytes, std::io::Error>> + Send + 'static {
        stream::iter(
            chunks
                .into_iter()
                .map(|s| Ok(Bytes::from_static(s.as_bytes()))),
        )
    }

    // Content deltas concatenate; usage and `finish_reason` come from the
    // final data chunk.
    #[tokio::test]
    async fn openai_text_stream() {
        let raw = "data: {\"choices\":[{\"delta\":{\"content\":\"Hola \"}}]}\n\n\
data: {\"choices\":[{\"delta\":{\"content\":\"mundo\"}}]}\n\n\
data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2}}\n\n\
data: [DONE]\n\n";
        let s = parse_openai_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::Text(t) => assert_eq!(t, "Hola mundo"),
            _ => panic!(),
        }
        assert_eq!(r.usage.completion_tokens, 2);
        assert_eq!(r.finish_reason, FinishReason::Stop);
    }

    // id and name arrive on the first fragment only; later fragments carry
    // just `index` plus argument text.
    #[tokio::test]
    async fn openai_tool_call_stream() {
        let raw = "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"\"}}]}}]}\n\n\
data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"city\\\":\"}}]}}]}\n\n\
data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"Bogota\\\"}\"}}]}}]}\n\n\
data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n\
data: [DONE]\n\n";
        let s = parse_openai_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].id, "call_1");
                assert_eq!(calls[0].name, "weather");
                assert_eq!(calls[0].arguments["city"], "Bogota");
            }
            _ => panic!("expected tool calls"),
        }
        assert_eq!(r.finish_reason, FinishReason::ToolUse);
    }

    // A malformed data line is skipped without failing the stream.
    #[tokio::test]
    async fn openai_malformed_line_is_skipped() {
        let raw = "data: {broken\n\n\
data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":\"stop\"}]}\n\n\
data: [DONE]\n\n";
        let s = parse_openai_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::Text(t) => assert_eq!(t, "ok"),
            _ => panic!(),
        }
    }

    // Named events: input tokens from `message_start`, output tokens and the
    // stop reason from `message_delta`.
    #[tokio::test]
    async fn anthropic_text_stream() {
        let raw = "event: message_start\n\
data: {\"message\":{\"usage\":{\"input_tokens\":4}}}\n\n\
event: content_block_start\n\
data: {\"index\":0,\"content_block\":{\"type\":\"text\"}}\n\n\
event: content_block_delta\n\
data: {\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hola \"}}\n\n\
event: content_block_delta\n\
data: {\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"mundo\"}}\n\n\
event: content_block_stop\n\
data: {\"index\":0}\n\n\
event: message_delta\n\
data: {\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":2}}\n\n\
event: message_stop\n\
data: {}\n\n";
        let s = parse_anthropic_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::Text(t) => assert_eq!(t, "Hola mundo"),
            _ => panic!(),
        }
        assert_eq!(r.usage.prompt_tokens, 4);
        assert_eq!(r.usage.completion_tokens, 2);
        assert_eq!(r.finish_reason, FinishReason::Stop);
    }

    // A `tool_use` block streams its input as `input_json_delta` fragments.
    #[tokio::test]
    async fn anthropic_tool_use_stream() {
        let raw = "event: message_start\n\
data: {\"message\":{\"usage\":{\"input_tokens\":10}}}\n\n\
event: content_block_start\n\
data: {\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_01\",\"name\":\"weather\"}}\n\n\
event: content_block_delta\n\
data: {\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\"}}\n\n\
event: content_block_delta\n\
data: {\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"\\\"Bogota\\\"}\"}}\n\n\
event: content_block_stop\n\
data: {\"index\":0}\n\n\
event: message_delta\n\
data: {\"delta\":{\"stop_reason\":\"tool_use\"},\"usage\":{\"output_tokens\":7}}\n\n\
event: message_stop\n\
data: {}\n\n";
        let s = parse_anthropic_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls[0].id, "toolu_01");
                assert_eq!(calls[0].name, "weather");
                assert_eq!(calls[0].arguments["city"], "Bogota");
            }
            _ => panic!("expected tool calls"),
        }
        assert_eq!(r.finish_reason, FinishReason::ToolUse);
    }

    // Unnamed events; the stream ends without a sentinel, and the last
    // chunk's empty text part is ignored.
    #[tokio::test]
    async fn gemini_text_stream() {
        let raw = "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hola \"}]}}]}\n\n\
data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"mundo\"}]}}]}\n\n\
data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":4,\"candidatesTokenCount\":2}}\n\n";
        let s = parse_gemini_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::Text(t) => assert_eq!(t, "Hola mundo"),
            _ => panic!(),
        }
        assert_eq!(r.usage.prompt_tokens, 4);
        assert_eq!(r.usage.completion_tokens, 2);
        assert_eq!(r.finish_reason, FinishReason::Stop);
    }

    // A complete `functionCall` part arrives in one event; its id is
    // synthesized by the parser.
    #[tokio::test]
    async fn gemini_function_call_stream() {
        let raw = "data: {\"candidates\":[{\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"weather\",\"args\":{\"city\":\"Bogota\"}}}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":3}}\n\n";
        let s = parse_gemini_sse(bstream(vec![raw]));
        let r = collect_stream(s).await.unwrap();
        match r.content {
            ResponseContent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "weather");
                assert_eq!(calls[0].arguments["city"], "Bogota");
                assert!(calls[0].id.starts_with("gemini_call_"));
            }
            _ => panic!("expected tool calls"),
        }
        assert!(matches!(
            // The payload reports STOP for a function-call turn; either
            // mapping of that to a finish reason is accepted here.
            r.finish_reason,
            FinishReason::ToolUse | FinishReason::Stop
        ));
    }
}