Skip to main content

batuta/agent/driver/
remote_stream.rs

//! Response parsers for the Anthropic and OpenAI APIs: complete
//! (non-streaming) response bodies as well as individual Server-Sent Events
//! from their streaming endpoints.
//!
//! Extracted from `remote.rs` for QA-002 (≤500 lines).
//! Handles `content_block_start/delta/stop` (Anthropic) and
//! `choices[0].delta` (OpenAI) Server-Sent Event formats.
6
7use crate::agent::driver::{CompletionResponse, StreamEvent, ToolCall};
8use crate::agent::result::{StopReason, TokenUsage};
9
10/// Parse Anthropic Messages API response.
11pub(super) fn parse_anthropic_response(body: &serde_json::Value) -> CompletionResponse {
12    let stop_reason = match body["stop_reason"].as_str().unwrap_or("end_turn") {
13        "tool_use" => StopReason::ToolUse,
14        "max_tokens" => StopReason::MaxTokens,
15        "stop_sequence" => StopReason::StopSequence,
16        _ => StopReason::EndTurn,
17    };
18
19    let mut text = String::new();
20    let mut tool_calls = Vec::new();
21
22    if let Some(content) = body["content"].as_array() {
23        for block in content {
24            match block["type"].as_str() {
25                Some("text") => {
26                    if let Some(t) = block["text"].as_str() {
27                        text.push_str(t);
28                    }
29                }
30                Some("tool_use") => {
31                    tool_calls.push(ToolCall {
32                        id: block["id"].as_str().unwrap_or("unknown").to_string(),
33                        name: block["name"].as_str().unwrap_or("").to_string(),
34                        input: block["input"].clone(),
35                    });
36                }
37                _ => {}
38            }
39        }
40    }
41
42    let usage = TokenUsage {
43        input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
44        output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
45    };
46
47    CompletionResponse { text, stop_reason, tool_calls, usage }
48}
49
50/// Parse `OpenAI` Chat Completions response.
51pub(super) fn parse_openai_response(body: &serde_json::Value) -> CompletionResponse {
52    let choice = &body["choices"][0];
53    let message = &choice["message"];
54
55    let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
56        "tool_calls" => StopReason::ToolUse,
57        "length" => StopReason::MaxTokens,
58        _ => StopReason::EndTurn,
59    };
60
61    let text = message["content"].as_str().unwrap_or("").to_string();
62
63    let mut tool_calls = Vec::new();
64    if let Some(calls) = message["tool_calls"].as_array() {
65        for call in calls {
66            let input: serde_json::Value = call["function"]["arguments"]
67                .as_str()
68                .and_then(|s| serde_json::from_str(s).ok())
69                .unwrap_or(serde_json::json!({}));
70
71            tool_calls.push(ToolCall {
72                id: call["id"].as_str().unwrap_or("unknown").to_string(),
73                name: call["function"]["name"].as_str().unwrap_or("").to_string(),
74                input,
75            });
76        }
77    }
78
79    let usage = TokenUsage {
80        input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
81        output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
82    };
83
84    CompletionResponse { text, stop_reason, tool_calls, usage }
85}
86
87/// Process a single Anthropic SSE event.
88///
89/// Accumulates text deltas, tool calls (partial JSON), usage,
90/// and stop reason from the Anthropic Messages streaming API.
91pub(super) async fn process_anthropic_event(
92    event: &serde_json::Value,
93    full_text: &mut String,
94    tool_calls: &mut Vec<ToolCall>,
95    usage: &mut TokenUsage,
96    stop_reason: &mut StopReason,
97    current_tool: &mut Option<(String, String, String)>,
98    tx: &tokio::sync::mpsc::Sender<StreamEvent>,
99) {
100    let event_type = event["type"].as_str().unwrap_or("");
101    match event_type {
102        "content_block_start" => {
103            let block = &event["content_block"];
104            if block["type"].as_str() == Some("tool_use") {
105                let id = block["id"].as_str().unwrap_or("").to_string();
106                let name = block["name"].as_str().unwrap_or("").to_string();
107                *current_tool = Some((id, name, String::new()));
108            }
109        }
110        "content_block_delta" => {
111            let delta = &event["delta"];
112            if let Some(text) = delta["text"].as_str() {
113                full_text.push_str(text);
114                let _ = tx.send(StreamEvent::TextDelta { text: text.to_string() }).await;
115            }
116            if let Some(json) = delta["partial_json"].as_str() {
117                if let Some((_, _, ref mut accum)) = current_tool {
118                    accum.push_str(json);
119                }
120            }
121        }
122        "content_block_stop" => {
123            if let Some((id, name, json_str)) = current_tool.take() {
124                let input = serde_json::from_str(&json_str).unwrap_or(serde_json::json!({}));
125                tool_calls.push(ToolCall { id, name, input });
126            }
127        }
128        "message_delta" => {
129            if let Some(sr) = event["delta"]["stop_reason"].as_str() {
130                *stop_reason = match sr {
131                    "tool_use" => StopReason::ToolUse,
132                    "max_tokens" => StopReason::MaxTokens,
133                    "stop_sequence" => StopReason::StopSequence,
134                    _ => StopReason::EndTurn,
135                };
136            }
137            if let Some(out) = event["usage"]["output_tokens"].as_u64() {
138                usage.output_tokens = out;
139            }
140        }
141        "message_start" => {
142            if let Some(inp) = event["message"]["usage"]["input_tokens"].as_u64() {
143                usage.input_tokens = inp;
144            }
145        }
146        _ => {}
147    }
148}
149
150/// Process a single OpenAI SSE event.
151///
152/// Accumulates text deltas, tool calls (indexed partial JSON),
153/// usage, and stop reason from the OpenAI Chat Completions
154/// streaming API.
155pub(super) async fn process_openai_event(
156    event: &serde_json::Value,
157    full_text: &mut String,
158    tool_calls: &mut Vec<ToolCall>,
159    usage: &mut TokenUsage,
160    stop_reason: &mut StopReason,
161    tx: &tokio::sync::mpsc::Sender<StreamEvent>,
162) {
163    let choice = &event["choices"][0];
164    let delta = &choice["delta"];
165
166    if let Some(text) = delta["content"].as_str() {
167        full_text.push_str(text);
168        let _ = tx.send(StreamEvent::TextDelta { text: text.to_string() }).await;
169    }
170
171    if let Some(calls) = delta["tool_calls"].as_array() {
172        for call in calls {
173            accumulate_openai_tool_call(call, tool_calls);
174        }
175    }
176
177    if let Some(fr) = choice["finish_reason"].as_str() {
178        *stop_reason = match fr {
179            "tool_calls" => StopReason::ToolUse,
180            "length" => StopReason::MaxTokens,
181            _ => StopReason::EndTurn,
182        };
183    }
184
185    if let Some(u) = event.get("usage") {
186        if let Some(inp) = u["prompt_tokens"].as_u64() {
187            usage.input_tokens = inp;
188        }
189        if let Some(out) = u["completion_tokens"].as_u64() {
190            usage.output_tokens = out;
191        }
192    }
193}
194
195/// Accumulate a single OpenAI tool call delta into the tool list.
196fn accumulate_openai_tool_call(call: &serde_json::Value, tool_calls: &mut Vec<ToolCall>) {
197    let idx = call["index"].as_u64().unwrap_or(0) as usize;
198    while tool_calls.len() <= idx {
199        tool_calls.push(ToolCall {
200            id: String::new(),
201            name: String::new(),
202            input: serde_json::json!({}),
203        });
204    }
205    if let Some(id) = call["id"].as_str() {
206        tool_calls[idx].id = id.to_string();
207    }
208    if let Some(name) = call["function"]["name"].as_str() {
209        tool_calls[idx].name = name.to_string();
210    }
211    if let Some(args) = call["function"]["arguments"].as_str() {
212        let existing = tool_calls[idx].input.as_str().unwrap_or("");
213        let combined = format!("{existing}{args}");
214        tool_calls[idx].input =
215            serde_json::from_str(&combined).unwrap_or(serde_json::json!(combined));
216    }
217}
218
// Unit tests for this module live in the sibling file `remote_stream_tests.rs`
// and are compiled only for test builds.
#[cfg(test)]
#[path = "remote_stream_tests.rs"]
mod tests;