batuta/agent/driver/
remote_stream.rs1use crate::agent::driver::{CompletionResponse, StreamEvent, ToolCall};
8use crate::agent::result::{StopReason, TokenUsage};
9
10pub(super) fn parse_anthropic_response(body: &serde_json::Value) -> CompletionResponse {
12 let stop_reason = match body["stop_reason"].as_str().unwrap_or("end_turn") {
13 "tool_use" => StopReason::ToolUse,
14 "max_tokens" => StopReason::MaxTokens,
15 "stop_sequence" => StopReason::StopSequence,
16 _ => StopReason::EndTurn,
17 };
18
19 let mut text = String::new();
20 let mut tool_calls = Vec::new();
21
22 if let Some(content) = body["content"].as_array() {
23 for block in content {
24 match block["type"].as_str() {
25 Some("text") => {
26 if let Some(t) = block["text"].as_str() {
27 text.push_str(t);
28 }
29 }
30 Some("tool_use") => {
31 tool_calls.push(ToolCall {
32 id: block["id"].as_str().unwrap_or("unknown").to_string(),
33 name: block["name"].as_str().unwrap_or("").to_string(),
34 input: block["input"].clone(),
35 });
36 }
37 _ => {}
38 }
39 }
40 }
41
42 let usage = TokenUsage {
43 input_tokens: body["usage"]["input_tokens"].as_u64().unwrap_or(0),
44 output_tokens: body["usage"]["output_tokens"].as_u64().unwrap_or(0),
45 };
46
47 CompletionResponse { text, stop_reason, tool_calls, usage }
48}
49
50pub(super) fn parse_openai_response(body: &serde_json::Value) -> CompletionResponse {
52 let choice = &body["choices"][0];
53 let message = &choice["message"];
54
55 let stop_reason = match choice["finish_reason"].as_str().unwrap_or("stop") {
56 "tool_calls" => StopReason::ToolUse,
57 "length" => StopReason::MaxTokens,
58 _ => StopReason::EndTurn,
59 };
60
61 let text = message["content"].as_str().unwrap_or("").to_string();
62
63 let mut tool_calls = Vec::new();
64 if let Some(calls) = message["tool_calls"].as_array() {
65 for call in calls {
66 let input: serde_json::Value = call["function"]["arguments"]
67 .as_str()
68 .and_then(|s| serde_json::from_str(s).ok())
69 .unwrap_or(serde_json::json!({}));
70
71 tool_calls.push(ToolCall {
72 id: call["id"].as_str().unwrap_or("unknown").to_string(),
73 name: call["function"]["name"].as_str().unwrap_or("").to_string(),
74 input,
75 });
76 }
77 }
78
79 let usage = TokenUsage {
80 input_tokens: body["usage"]["prompt_tokens"].as_u64().unwrap_or(0),
81 output_tokens: body["usage"]["completion_tokens"].as_u64().unwrap_or(0),
82 };
83
84 CompletionResponse { text, stop_reason, tool_calls, usage }
85}
86
87pub(super) async fn process_anthropic_event(
92 event: &serde_json::Value,
93 full_text: &mut String,
94 tool_calls: &mut Vec<ToolCall>,
95 usage: &mut TokenUsage,
96 stop_reason: &mut StopReason,
97 current_tool: &mut Option<(String, String, String)>,
98 tx: &tokio::sync::mpsc::Sender<StreamEvent>,
99) {
100 let event_type = event["type"].as_str().unwrap_or("");
101 match event_type {
102 "content_block_start" => {
103 let block = &event["content_block"];
104 if block["type"].as_str() == Some("tool_use") {
105 let id = block["id"].as_str().unwrap_or("").to_string();
106 let name = block["name"].as_str().unwrap_or("").to_string();
107 *current_tool = Some((id, name, String::new()));
108 }
109 }
110 "content_block_delta" => {
111 let delta = &event["delta"];
112 if let Some(text) = delta["text"].as_str() {
113 full_text.push_str(text);
114 let _ = tx.send(StreamEvent::TextDelta { text: text.to_string() }).await;
115 }
116 if let Some(json) = delta["partial_json"].as_str() {
117 if let Some((_, _, ref mut accum)) = current_tool {
118 accum.push_str(json);
119 }
120 }
121 }
122 "content_block_stop" => {
123 if let Some((id, name, json_str)) = current_tool.take() {
124 let input = serde_json::from_str(&json_str).unwrap_or(serde_json::json!({}));
125 tool_calls.push(ToolCall { id, name, input });
126 }
127 }
128 "message_delta" => {
129 if let Some(sr) = event["delta"]["stop_reason"].as_str() {
130 *stop_reason = match sr {
131 "tool_use" => StopReason::ToolUse,
132 "max_tokens" => StopReason::MaxTokens,
133 "stop_sequence" => StopReason::StopSequence,
134 _ => StopReason::EndTurn,
135 };
136 }
137 if let Some(out) = event["usage"]["output_tokens"].as_u64() {
138 usage.output_tokens = out;
139 }
140 }
141 "message_start" => {
142 if let Some(inp) = event["message"]["usage"]["input_tokens"].as_u64() {
143 usage.input_tokens = inp;
144 }
145 }
146 _ => {}
147 }
148}
149
150pub(super) async fn process_openai_event(
156 event: &serde_json::Value,
157 full_text: &mut String,
158 tool_calls: &mut Vec<ToolCall>,
159 usage: &mut TokenUsage,
160 stop_reason: &mut StopReason,
161 tx: &tokio::sync::mpsc::Sender<StreamEvent>,
162) {
163 let choice = &event["choices"][0];
164 let delta = &choice["delta"];
165
166 if let Some(text) = delta["content"].as_str() {
167 full_text.push_str(text);
168 let _ = tx.send(StreamEvent::TextDelta { text: text.to_string() }).await;
169 }
170
171 if let Some(calls) = delta["tool_calls"].as_array() {
172 for call in calls {
173 accumulate_openai_tool_call(call, tool_calls);
174 }
175 }
176
177 if let Some(fr) = choice["finish_reason"].as_str() {
178 *stop_reason = match fr {
179 "tool_calls" => StopReason::ToolUse,
180 "length" => StopReason::MaxTokens,
181 _ => StopReason::EndTurn,
182 };
183 }
184
185 if let Some(u) = event.get("usage") {
186 if let Some(inp) = u["prompt_tokens"].as_u64() {
187 usage.input_tokens = inp;
188 }
189 if let Some(out) = u["completion_tokens"].as_u64() {
190 usage.output_tokens = out;
191 }
192 }
193}
194
195fn accumulate_openai_tool_call(call: &serde_json::Value, tool_calls: &mut Vec<ToolCall>) {
197 let idx = call["index"].as_u64().unwrap_or(0) as usize;
198 while tool_calls.len() <= idx {
199 tool_calls.push(ToolCall {
200 id: String::new(),
201 name: String::new(),
202 input: serde_json::json!({}),
203 });
204 }
205 if let Some(id) = call["id"].as_str() {
206 tool_calls[idx].id = id.to_string();
207 }
208 if let Some(name) = call["function"]["name"].as_str() {
209 tool_calls[idx].name = name.to_string();
210 }
211 if let Some(args) = call["function"]["arguments"].as_str() {
212 let existing = tool_calls[idx].input.as_str().unwrap_or("");
213 let combined = format!("{existing}{args}");
214 tool_calls[idx].input =
215 serde_json::from_str(&combined).unwrap_or(serde_json::json!(combined));
216 }
217}
218
// Unit tests for this module are compiled only under `cfg(test)` and are
// loaded from the sibling file `remote_stream_tests.rs` via `#[path]`.
#[cfg(test)]
#[path = "remote_stream_tests.rs"]
mod tests;