// Streaming and response-parsing helpers live in a sibling file, loaded via `#[path]`.
#[path = "remote_stream.rs"]
mod remote_stream;
use async_trait::async_trait;
use crate::agent::driver::{
CompletionRequest, CompletionResponse, LlmDriver, Message, StreamEvent, ToolCall,
};
use crate::agent::result::{AgentError, DriverError, StopReason, TokenUsage};
use crate::serve::backends::PrivacyTier;
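
/// Wire protocol spoken by the remote endpoint.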
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiProvider {
Anthropic,
OpenAi,
}
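
/// Connection settings for a remote LLM API.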
#[derive(Debug, Clone)]
pub struct RemoteDriverConfig {
pub base_url: String,
pub api_key: String,
pub model: String,
pub provider: ApiProvider,
pub context_window: usize,
}
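
/// `LlmDriver` backed by a hosted HTTP API (Anthropic Messages or OpenAI
/// Chat Completions).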
pub struct RemoteDriver {
    config: RemoteDriverConfig,
    /// Shared HTTP client so connections are pooled across requests.
    client: reqwest::Client,
}
impl RemoteDriver {
    pub fn new(config: RemoteDriverConfig) -> Self {
        // Build the client once up front; reqwest pools connections per client.
        Self { config, client: reqwest::Client::new() }
    }
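
    /// Resolve the provider-specific endpoint URL and JSON request body.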
fn build_request(&self, request: &CompletionRequest) -> (String, serde_json::Value) {
match self.config.provider {
ApiProvider::Anthropic => {
let url = format!("{}/v1/messages", self.config.base_url);
(url, self.build_anthropic_body(request))
}
ApiProvider::OpenAi => {
let url = format!("{}/v1/chat/completions", self.config.base_url);
(url, self.build_openai_body(request))
}
}
}
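
    /// POST `body` with provider-appropriate auth headers, mapping rate-limit
    /// and overload statuses onto typed driver errors.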
async fn send_http(
&self,
url: &str,
body: &serde_json::Value,
) -> Result<reqwest::Response, AgentError> {
        let mut req = self.client.post(url);
req = match self.config.provider {
ApiProvider::Anthropic => req
.header("x-api-key", &self.config.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json"),
ApiProvider::OpenAi => req
.header("authorization", format!("Bearer {}", self.config.api_key))
.header("content-type", "application/json"),
};
let response = req.json(body).send().await.map_err(|e| {
AgentError::Driver(DriverError::Network(format!("HTTP request failed: {e}")))
})?;
let status = response.status().as_u16();
        if status == 429 {
            // Rate limited; hand the caller a fixed one-second backoff hint.
            return Err(AgentError::Driver(DriverError::RateLimited { retry_after_ms: 1000 }));
        }
        if status == 529 || status == 503 {
            // 529 is Anthropic's "overloaded" status; 503 is the generic
            // service-unavailable code.
            return Err(AgentError::Driver(DriverError::Overloaded { retry_after_ms: 2000 }));
        }
if !response.status().is_success() {
let text = response.text().await.unwrap_or_default();
return Err(AgentError::Driver(DriverError::Network(format!("HTTP {status}: {text}"))));
}
Ok(response)
}
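
    /// Translate the request into the Anthropic Messages API shape.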
fn build_anthropic_body(&self, request: &CompletionRequest) -> serde_json::Value {
let messages: Vec<serde_json::Value> = request
.messages
.iter()
.filter_map(|m| match m {
Message::User(text) => Some(serde_json::json!({
"role": "user",
"content": text
})),
Message::Assistant(text) => Some(serde_json::json!({
"role": "assistant",
"content": text
})),
Message::AssistantToolUse(call) => Some(serde_json::json!({
"role": "assistant",
"content": [{
"type": "tool_use",
"id": call.id,
"name": call.name,
"input": call.input
}]
})),
Message::ToolResult(result) => Some(serde_json::json!({
"role": "user",
"content": [{
"type": "tool_result",
"tool_use_id": result.tool_use_id,
"content": result.content,
"is_error": result.is_error
}]
})),
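                // Anthropic only accepts a top-level system prompt (set below),
                // so inline system messages have no equivalent here.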
Message::System(_) => None,
})
.collect();
let mut body = serde_json::json!({
"model": self.config.model,
"messages": messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature
});
if let Some(ref system) = request.system {
body["system"] = serde_json::json!(system);
}
if !request.tools.is_empty() {
let tools: Vec<serde_json::Value> = request
.tools
.iter()
.map(|t| {
serde_json::json!({
"name": t.name,
"description": t.description,
"input_schema": t.input_schema
})
})
.collect();
body["tools"] = serde_json::json!(tools);
}
body
}
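
    /// Translate the request into the OpenAI Chat Completions shape.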
fn build_openai_body(&self, request: &CompletionRequest) -> serde_json::Value {
let mut messages: Vec<serde_json::Value> = Vec::new();
if let Some(ref system) = request.system {
messages.push(serde_json::json!({
"role": "system",
"content": system
}));
}
for m in &request.messages {
match m {
Message::System(text) => {
messages.push(serde_json::json!({
"role": "system",
"content": text
}));
}
Message::User(text) => {
messages.push(serde_json::json!({
"role": "user",
"content": text
}));
}
Message::Assistant(text) => {
messages.push(serde_json::json!({
"role": "assistant",
"content": text
}));
}
Message::AssistantToolUse(call) => {
messages.push(serde_json::json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call.id,
"type": "function",
"function": {
"name": call.name,
"arguments": call.input.to_string()
}
}]
}));
}
Message::ToolResult(result) => {
messages.push(serde_json::json!({
"role": "tool",
"tool_call_id": result.tool_use_id,
"content": result.content
}));
}
}
}
let mut body = serde_json::json!({
"model": self.config.model,
"messages": messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature
});
if !request.tools.is_empty() {
let tools: Vec<serde_json::Value> = request
.tools
.iter()
.map(|t| {
serde_json::json!({
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.input_schema
}
})
})
.collect();
body["tools"] = serde_json::json!(tools);
}
body
}
}
#[async_trait]
impl LlmDriver for RemoteDriver {
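    /// Non-streaming completion: a single POST, parsing the full JSON response.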
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
let (url, body) = self.build_request(&request);
let response = self.send_http(&url, &body).await?;
let resp_body: serde_json::Value = response.json().await.map_err(|e| {
AgentError::Driver(DriverError::InferenceFailed(format!("JSON parse error: {e}")))
})?;
Ok(match self.config.provider {
ApiProvider::Anthropic => remote_stream::parse_anthropic_response(&resp_body),
ApiProvider::OpenAi => remote_stream::parse_openai_response(&resp_body),
})
}
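
    /// Streaming completion over server-sent events. Incremental deltas are
    /// forwarded on `tx`, and the accumulated response is returned at the end.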
async fn stream(
&self,
request: CompletionRequest,
tx: tokio::sync::mpsc::Sender<StreamEvent>,
) -> Result<CompletionResponse, AgentError> {
use futures_util::StreamExt;
let (url, mut body) = self.build_request(&request);
body["stream"] = serde_json::json!(true);
let response = self.send_http(&url, &body).await?;
let mut full_text = String::new();
let mut tool_calls = Vec::new();
let mut usage = TokenUsage { input_tokens: 0, output_tokens: 0 };
let mut stop_reason = StopReason::EndTurn;
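        // In-flight tool call (id, name, accumulated JSON input) assembled from
        // Anthropic's incremental tool-use events.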
let mut current_tool: Option<(String, String, String)> = None;
let mut stream = response.bytes_stream();
let mut buffer = String::new();
while let Some(chunk) = stream.next().await {
let bytes = chunk.map_err(|e| {
AgentError::Driver(DriverError::Network(format!("stream error: {e}")))
})?;
buffer.push_str(&String::from_utf8_lossy(&bytes));
            while let Some(line_end) = buffer.find('\n') {
                // Pop one complete line (newline included) off the front of the buffer.
                let line: String = buffer.drain(..=line_end).collect();
                let line = line.trim();
                // Blank lines delimit SSE events; lines starting with ':' are keep-alive comments.
                if line.is_empty() || line.starts_with(':') {
                    continue;
                }
                let Some(data) = line.strip_prefix("data: ") else {
                    continue;
                };
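                // OpenAI ends the stream with a literal [DONE] sentinel; the
                // server closes the connection right after it.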
if data == "[DONE]" {
break;
}
                let Ok(event) = serde_json::from_str::<serde_json::Value>(data) else {
                    // Skip malformed payloads rather than aborting the stream.
                    continue;
                };
match self.config.provider {
ApiProvider::Anthropic => {
remote_stream::process_anthropic_event(
&event,
&mut full_text,
&mut tool_calls,
&mut usage,
&mut stop_reason,
&mut current_tool,
&tx,
)
.await;
}
ApiProvider::OpenAi => {
remote_stream::process_openai_event(
&event,
&mut full_text,
&mut tool_calls,
&mut usage,
&mut stop_reason,
&tx,
)
.await;
}
}
}
}
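        // Emit the terminal event; the receiver may already be gone, so a
        // failed send is deliberately ignored.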
let _ = tx
.send(StreamEvent::ContentComplete {
stop_reason: stop_reason.clone(),
usage: usage.clone(),
})
.await;
Ok(CompletionResponse { text: full_text, stop_reason, tool_calls, usage })
}
fn context_window(&self) -> usize {
self.config.context_window
}
    fn privacy_tier(&self) -> PrivacyTier {
        PrivacyTier::Standard
    }

    #[allow(clippy::cast_precision_loss)]
    fn estimate_cost(&self, usage: &TokenUsage) -> f64 {
        // Rough estimate at Claude Sonnet-class list pricing: $3 per million
        // input tokens and $15 per million output tokens.
        let input_cost = usage.input_tokens as f64 * 3.0 / 1_000_000.0;
        let output_cost = usage.output_tokens as f64 * 15.0 / 1_000_000.0;
        input_cost + output_cost
    }
}
#[cfg(test)]
#[path = "remote_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "remote_tests_body.rs"]
mod tests_body;