1#[path = "remote_stream.rs"]
13mod remote_stream;
14
15use async_trait::async_trait;
16
17use crate::agent::driver::{
18 CompletionRequest, CompletionResponse, LlmDriver, Message, StreamEvent, ToolCall,
19};
20use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
21use crate::serve::backends::PrivacyTier;
22
/// Wire protocol spoken by the remote endpoint; selects URL path, auth
/// headers, and request/response body shapes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiProvider {
    /// Anthropic Messages API (`POST /v1/messages`, `x-api-key` auth).
    Anthropic,
    /// OpenAI Chat Completions API (`POST /v1/chat/completions`, bearer auth).
    OpenAi,
}
31
32#[derive(Debug, Clone)]
34pub struct RemoteDriverConfig {
35 pub base_url: String,
37 pub api_key: String,
39 pub model: String,
41 pub provider: ApiProvider,
43 pub context_window: usize,
45}
46
/// `LlmDriver` implementation that proxies completion requests to a remote
/// HTTP API, speaking either the Anthropic or OpenAI wire format depending
/// on the configured provider.
pub struct RemoteDriver {
    // Endpoint, credentials, model, and protocol selection.
    config: RemoteDriverConfig,
}
56
57impl RemoteDriver {
58 pub fn new(config: RemoteDriverConfig) -> Self {
60 Self { config }
61 }
62
63 fn build_request(&self, request: &CompletionRequest) -> (String, serde_json::Value) {
65 match self.config.provider {
66 ApiProvider::Anthropic => {
67 let url = format!("{}/v1/messages", self.config.base_url);
68 (url, self.build_anthropic_body(request))
69 }
70 ApiProvider::OpenAi => {
71 let url = format!("{}/v1/chat/completions", self.config.base_url);
72 (url, self.build_openai_body(request))
73 }
74 }
75 }
76
77 async fn send_http(
81 &self,
82 url: &str,
83 body: &serde_json::Value,
84 ) -> Result<reqwest::Response, AgentError> {
85 let client = reqwest::Client::new();
86 let mut req = client.post(url);
87 req = match self.config.provider {
88 ApiProvider::Anthropic => req
89 .header("x-api-key", &self.config.api_key)
90 .header("anthropic-version", "2023-06-01")
91 .header("content-type", "application/json"),
92 ApiProvider::OpenAi => req
93 .header("authorization", format!("Bearer {}", self.config.api_key))
94 .header("content-type", "application/json"),
95 };
96
97 let response = req.json(body).send().await.map_err(|e| {
98 AgentError::Driver(DriverError::Network(format!("HTTP request failed: {e}")))
99 })?;
100
101 let status = response.status().as_u16();
102 if status == 429 {
103 return Err(AgentError::Driver(DriverError::RateLimited { retry_after_ms: 1000 }));
104 }
105 if status == 529 || status == 503 {
106 return Err(AgentError::Driver(DriverError::Overloaded { retry_after_ms: 2000 }));
107 }
108 if !response.status().is_success() {
109 let text = response.text().await.unwrap_or_default();
110 return Err(AgentError::Driver(DriverError::Network(format!("HTTP {status}: {text}"))));
111 }
112
113 Ok(response)
114 }
115
116 fn build_anthropic_body(&self, request: &CompletionRequest) -> serde_json::Value {
118 let messages: Vec<serde_json::Value> = request
119 .messages
120 .iter()
121 .filter_map(|m| match m {
122 Message::User(text) => Some(serde_json::json!({
123 "role": "user",
124 "content": text
125 })),
126 Message::Assistant(text) => Some(serde_json::json!({
127 "role": "assistant",
128 "content": text
129 })),
130 Message::AssistantToolUse(call) => Some(serde_json::json!({
131 "role": "assistant",
132 "content": [{
133 "type": "tool_use",
134 "id": call.id,
135 "name": call.name,
136 "input": call.input
137 }]
138 })),
139 Message::ToolResult(result) => Some(serde_json::json!({
140 "role": "user",
141 "content": [{
142 "type": "tool_result",
143 "tool_use_id": result.tool_use_id,
144 "content": result.content,
145 "is_error": result.is_error
146 }]
147 })),
148 Message::System(_) => None,
149 })
150 .collect();
151
152 let mut body = serde_json::json!({
153 "model": self.config.model,
154 "messages": messages,
155 "max_tokens": request.max_tokens,
156 "temperature": request.temperature
157 });
158
159 if let Some(ref system) = request.system {
160 body["system"] = serde_json::json!(system);
161 }
162
163 if !request.tools.is_empty() {
164 let tools: Vec<serde_json::Value> = request
165 .tools
166 .iter()
167 .map(|t| {
168 serde_json::json!({
169 "name": t.name,
170 "description": t.description,
171 "input_schema": t.input_schema
172 })
173 })
174 .collect();
175 body["tools"] = serde_json::json!(tools);
176 }
177
178 body
179 }
180
181 fn build_openai_body(&self, request: &CompletionRequest) -> serde_json::Value {
183 let mut messages: Vec<serde_json::Value> = Vec::new();
184
185 if let Some(ref system) = request.system {
186 messages.push(serde_json::json!({
187 "role": "system",
188 "content": system
189 }));
190 }
191
192 for m in &request.messages {
193 match m {
194 Message::System(text) => {
195 messages.push(serde_json::json!({
196 "role": "system",
197 "content": text
198 }));
199 }
200 Message::User(text) => {
201 messages.push(serde_json::json!({
202 "role": "user",
203 "content": text
204 }));
205 }
206 Message::Assistant(text) => {
207 messages.push(serde_json::json!({
208 "role": "assistant",
209 "content": text
210 }));
211 }
212 Message::AssistantToolUse(call) => {
213 messages.push(serde_json::json!({
214 "role": "assistant",
215 "content": null,
216 "tool_calls": [{
217 "id": call.id,
218 "type": "function",
219 "function": {
220 "name": call.name,
221 "arguments": call.input.to_string()
222 }
223 }]
224 }));
225 }
226 Message::ToolResult(result) => {
227 messages.push(serde_json::json!({
228 "role": "tool",
229 "tool_call_id": result.tool_use_id,
230 "content": result.content
231 }));
232 }
233 }
234 }
235
236 let mut body = serde_json::json!({
237 "model": self.config.model,
238 "messages": messages,
239 "max_tokens": request.max_tokens,
240 "temperature": request.temperature
241 });
242
243 if !request.tools.is_empty() {
244 let tools: Vec<serde_json::Value> = request
245 .tools
246 .iter()
247 .map(|t| {
248 serde_json::json!({
249 "type": "function",
250 "function": {
251 "name": t.name,
252 "description": t.description,
253 "parameters": t.input_schema
254 }
255 })
256 })
257 .collect();
258 body["tools"] = serde_json::json!(tools);
259 }
260
261 body
262 }
263}
264
265#[async_trait]
266impl LlmDriver for RemoteDriver {
267 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
268 let (url, body) = self.build_request(&request);
269 let response = self.send_http(&url, &body).await?;
270
271 let resp_body: serde_json::Value = response.json().await.map_err(|e| {
272 AgentError::Driver(DriverError::InferenceFailed(format!("JSON parse error: {e}")))
273 })?;
274
275 Ok(match self.config.provider {
276 ApiProvider::Anthropic => remote_stream::parse_anthropic_response(&resp_body),
277 ApiProvider::OpenAi => remote_stream::parse_openai_response(&resp_body),
278 })
279 }
280
281 async fn stream(
287 &self,
288 request: CompletionRequest,
289 tx: tokio::sync::mpsc::Sender<StreamEvent>,
290 ) -> Result<CompletionResponse, AgentError> {
291 use futures_util::StreamExt;
292
293 let (url, mut body) = self.build_request(&request);
294 body["stream"] = serde_json::json!(true);
295
296 let response = self.send_http(&url, &body).await?;
297
298 let mut full_text = String::new();
299 let mut tool_calls = Vec::new();
300 let mut usage = TokenUsage { input_tokens: 0, output_tokens: 0 };
301 let mut stop_reason = StopReason::EndTurn;
302 let mut current_tool: Option<(String, String, String)> = None; let mut stream = response.bytes_stream();
305 let mut buffer = String::new();
306
307 while let Some(chunk) = stream.next().await {
308 let bytes = chunk.map_err(|e| {
309 AgentError::Driver(DriverError::Network(format!("stream error: {e}")))
310 })?;
311 buffer.push_str(&String::from_utf8_lossy(&bytes));
312
313 while let Some(line_end) = buffer.find('\n') {
314 let line = buffer[..line_end].trim().to_string();
315 buffer = buffer[line_end + 1..].to_string();
316
317 if line.is_empty() || line.starts_with(':') {
318 continue;
319 }
320
321 let data = if let Some(stripped) = line.strip_prefix("data: ") {
322 stripped
323 } else {
324 continue;
325 };
326
327 if data == "[DONE]" {
328 break;
329 }
330
331 let Ok(event): Result<serde_json::Value, _> = serde_json::from_str(data) else {
332 continue;
333 };
334
335 match self.config.provider {
336 ApiProvider::Anthropic => {
337 remote_stream::process_anthropic_event(
338 &event,
339 &mut full_text,
340 &mut tool_calls,
341 &mut usage,
342 &mut stop_reason,
343 &mut current_tool,
344 &tx,
345 )
346 .await;
347 }
348 ApiProvider::OpenAi => {
349 remote_stream::process_openai_event(
350 &event,
351 &mut full_text,
352 &mut tool_calls,
353 &mut usage,
354 &mut stop_reason,
355 &tx,
356 )
357 .await;
358 }
359 }
360 }
361 }
362
363 let _ = tx
364 .send(StreamEvent::ContentComplete {
365 stop_reason: stop_reason.clone(),
366 usage: usage.clone(),
367 })
368 .await;
369
370 Ok(CompletionResponse { text: full_text, stop_reason, tool_calls, usage })
371 }
372
373 fn context_window(&self) -> usize {
374 self.config.context_window
375 }
376
377 fn privacy_tier(&self) -> PrivacyTier {
378 PrivacyTier::Standard }
380
381 #[allow(clippy::cast_precision_loss)] fn estimate_cost(&self, usage: &TokenUsage) -> f64 {
388 let input_cost = usage.input_tokens as f64 * 3.0 / 1_000_000.0;
389 let output_cost = usage.output_tokens as f64 * 15.0 / 1_000_000.0;
390 input_cost + output_cost
391 }
392}
393
394#[cfg(test)]
395#[path = "remote_tests.rs"]
396mod tests;
397
398#[cfg(test)]
399#[path = "remote_tests_body.rs"]
400mod tests_body;