Skip to main content

agentrs_core/
streaming.rs

//! Streaming helpers: decoding server-sent event (SSE) chunks into unified stream chunks.
2
3use bytes::Bytes;
4
5use crate::{AgentError, Result, StopReason, StreamChunk, ToolCallDelta};
6
7/// Parses a single SSE data chunk into a unified [`StreamChunk`].
8pub fn parse_sse_chunk(bytes: Bytes) -> Result<Option<StreamChunk>> {
9    let text = std::str::from_utf8(&bytes).map_err(|_| AgentError::InvalidStream)?;
10
11    for line in text.lines() {
12        let Some(payload) = line.strip_prefix("data:") else {
13            continue;
14        };
15
16        let payload = payload.trim();
17        if payload.is_empty() {
18            continue;
19        }
20
21        if payload == "[DONE]" {
22            return Ok(None);
23        }
24
25        let value: serde_json::Value = serde_json::from_str(payload)?;
26        return Ok(Some(map_openai_like_chunk(&value)));
27    }
28
29    Ok(None)
30}
31
32fn map_openai_like_chunk(value: &serde_json::Value) -> StreamChunk {
33    let choice = value
34        .get("choices")
35        .and_then(|choices| choices.as_array())
36        .and_then(|choices| choices.first())
37        .cloned()
38        .unwrap_or_default();
39
40    let delta_value = choice.get("delta").cloned().unwrap_or_default();
41    let delta = delta_value
42        .get("content")
43        .and_then(serde_json::Value::as_str)
44        .unwrap_or_default()
45        .to_string();
46
47    let tool_call_delta = delta_value
48        .get("tool_calls")
49        .and_then(serde_json::Value::as_array)
50        .map(|items| {
51            items
52                .iter()
53                .map(|item| ToolCallDelta {
54                    index: item
55                        .get("index")
56                        .and_then(serde_json::Value::as_u64)
57                        .unwrap_or_default() as usize,
58                    id: item
59                        .get("id")
60                        .and_then(serde_json::Value::as_str)
61                        .map(ToOwned::to_owned),
62                    name: item
63                        .get("function")
64                        .and_then(|function| function.get("name"))
65                        .and_then(serde_json::Value::as_str)
66                        .map(ToOwned::to_owned),
67                    arguments_delta: item
68                        .get("function")
69                        .and_then(|function| function.get("arguments"))
70                        .and_then(serde_json::Value::as_str)
71                        .map(ToOwned::to_owned),
72                })
73                .collect::<Vec<_>>()
74        })
75        .filter(|items| !items.is_empty());
76
77    let finish_reason = choice
78        .get("finish_reason")
79        .and_then(serde_json::Value::as_str)
80        .map(map_stop_reason);
81
82    StreamChunk {
83        delta,
84        tool_call_delta,
85        finish_reason,
86    }
87}
88
89/// Maps a provider stop-reason string into the unified enum.
90pub fn map_stop_reason(value: &str) -> StopReason {
91    match value {
92        "stop" | "end_turn" => StopReason::Stop,
93        "tool_use" | "tool_calls" => StopReason::ToolUse,
94        "length" | "max_tokens" => StopReason::MaxTokens,
95        other => StopReason::Other(other.to_string()),
96    }
97}