use reqwest::Client;
use serde::Deserialize;
use tracing::{debug, instrument};
use crate::key_pool::{KeyLease, KeyPool};
use crate::schema::{ApiProtocol, ModelSchema, ModelSource};
use crate::tasks::ContentBlock;
use crate::InferenceError;
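/// Backend for remote, HTTP-based model APIs.
///
/// Holds a shared `reqwest::Client` and a [`KeyPool`] that rotates API keys per
/// endpoint, with request outcomes reported back for key selection.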
pub struct RemoteBackend {
pub(crate) client: Client,
pub key_pool: KeyPool,
}
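/// Rough token count using the ~4-bytes-per-token heuristic; never returns 0.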
pub fn estimate_tokens(text: &str) -> usize {
std::cmp::max(1, text.len() / 4)
}
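/// Truncates `prompt` (keeping its tail) so that prompt + context + serialized
/// tools + the response budget fit within `context_window` tokens. Returns the
/// prompt unchanged when it already fits.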
fn truncate_prompt_to_fit(
prompt: &str,
context: Option<&str>,
tools_json: Option<&[serde_json::Value]>,
max_tokens: usize,
context_window: usize,
) -> String {
let context_tokens = context.map(|c| estimate_tokens(c)).unwrap_or(0);
let tools_tokens = tools_json
.map(|t| estimate_tokens(&serde_json::to_string(t).unwrap_or_default()))
.unwrap_or(0);
    // Reserve room for the context, serialized tools, the response budget, and a fixed overhead.
    let overhead = 100;
    let reserved = context_tokens + tools_tokens + max_tokens + overhead;
let available = context_window.saturating_sub(reserved);
let prompt_tokens = estimate_tokens(prompt);
if prompt_tokens <= available {
return prompt.to_string();
}
tracing::warn!(
prompt_tokens = prompt_tokens,
available_tokens = available,
context_window = context_window,
"truncating prompt to fit context window (prefer compaction via car-memgine)"
);
let chars_to_keep = available * 4;
if chars_to_keep >= prompt.len() {
return prompt.to_string();
}
    let start = prompt.len().saturating_sub(chars_to_keep);
    // Snap forward to the next char boundary so the slice never splits a
    // multi-byte code point (stable equivalent of `str::ceil_char_boundary`).
    let safe_start = (start..=prompt.len())
        .find(|&i| prompt.is_char_boundary(i))
        .unwrap_or(prompt.len());
    let truncated = &prompt[safe_start..];
    // Resume at the next line or word break so the kept text starts cleanly.
    let break_point = truncated
        .find('\n')
        .or_else(|| truncated.find(' '))
        .unwrap_or(0);
format!(
"[...truncated...]\n{}",
&truncated[break_point..].trim_start()
)
}
impl RemoteBackend {
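    /// Builds a backend with a 120s overall timeout, 10s connect timeout, and
    /// 90s read timeout; falls back to a default client if the builder fails.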
pub fn new() -> Self {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
.read_timeout(std::time::Duration::from_secs(90))
.build()
.unwrap_or_default();
Self {
client,
key_pool: KeyPool::new(),
}
}
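    /// Registers a remote model's API-key environment variables with the key
    /// pool. No-op for non-remote sources.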
pub async fn register_model_keys(&self, schema: &ModelSchema) {
if let ModelSource::RemoteApi { ref endpoint, .. } = schema.source {
let env_vars = schema.all_api_key_envs();
if !env_vars.is_empty() {
self.key_pool.register_endpoint(endpoint, env_vars).await;
}
}
}
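    /// Leases an API key, sends one request through the protocol handler for
    /// this schema, reports the outcome back to the key pool, and parses the
    /// response.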
#[instrument(
name = "inference.remote_call",
skip_all,
fields(
model = %schema.name,
provider = %schema.provider,
),
)]
async fn execute_request(
&self,
schema: &ModelSchema,
req: crate::protocol::ApiRequest,
) -> Result<crate::protocol::ApiResponse, InferenceError> {
let (endpoint, protocol) = extract_remote_endpoint(schema)?;
let lease = self.lease_key(schema, &endpoint).await?;
let handler = crate::protocol::handler_for(protocol);
let start = std::time::Instant::now();
let api_version = match &schema.source {
ModelSource::RemoteApi { api_version, .. } => api_version.clone(),
_ => None,
};
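        // Google expects the API key as a URL query parameter, and Azure OpenAI
        // uses a deployment-scoped path with an explicit api-version; everything
        // else posts to the handler's standard endpoint path.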
let url = if matches!(protocol, ApiProtocol::Google) {
crate::protocol::google_url(&endpoint, &req.model, &lease.api_key)
} else if matches!(protocol, ApiProtocol::AzureOpenAi) {
let version = api_version.as_deref().unwrap_or("2024-10-21");
format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
endpoint.trim_end_matches('/'),
req.model,
version
)
} else {
format_endpoint(&endpoint, handler.endpoint_path())
};
let headers = handler.auth_headers(&lease.api_key);
let body = handler.build_request_body(&req);
debug!(url = %url, model = %req.model, "protocol handler request");
let mut builder = self.client.post(&url);
for (name, value) in &headers {
builder = builder.header(name.as_str(), value.as_str());
}
let send_fut = builder.json(&body).send();
let resp = tokio::time::timeout(std::time::Duration::from_secs(150), send_fut)
.await
.map_err(|_| {
InferenceError::InferenceFailed(
"request timed out after 150s (tokio safety timeout)".to_string(),
)
})?
.map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
let status = resp.status();
let body_fut = resp.text();
let resp_text = tokio::time::timeout(std::time::Duration::from_secs(120), body_fut)
.await
.map_err(|_| {
InferenceError::InferenceFailed(
"response body read timed out after 120s".to_string(),
)
})?
.map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
let latency_ms = start.elapsed().as_millis() as u64;
if !status.is_success() {
            // Flag rate limiting via the HTTP status as well as provider-specific body markers.
            let is_rl = status.as_u16() == 429
                || resp_text.contains("429")
                || resp_text.contains("RESOURCE_EXHAUSTED");
self.key_pool
.report_failure(&endpoint, &lease.env_var, is_rl)
.await;
return Err(InferenceError::InferenceFailed(format!(
"API returned {status}: {resp_text}"
)));
}
let est_tokens = req
.messages
.iter()
.filter_map(|m| m.get("content").and_then(|c| c.as_str()))
.map(|s| s.len() / 4)
.sum::<usize>() as u64;
self.key_pool
.report_success(&endpoint, &lease.env_var, latency_ms, est_tokens, 0)
.await;
let mut response = handler.parse_response(&resp_text)?;
if let Some(ref mut usage) = response.usage {
usage.context_window = schema.context_length as u64;
}
Ok(response)
}
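    /// Convenience wrapper over [`Self::generate_with_tools_multi`] that
    /// returns only the generated text.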
pub async fn generate(
&self,
schema: &ModelSchema,
prompt: &str,
context: Option<&str>,
temperature: f64,
max_tokens: usize,
images: Option<&[ContentBlock]>,
) -> Result<String, InferenceError> {
let resp = self
.generate_with_tools_multi(
schema,
prompt,
context,
temperature,
max_tokens,
None,
images,
None,
None,
None,
0,
false,
None,
)
.await?;
Ok(resp.0)
}
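    /// Like [`Self::generate`], but also returns any tool calls emitted by the
    /// model.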
pub async fn generate_with_tools(
&self,
schema: &ModelSchema,
prompt: &str,
context: Option<&str>,
temperature: f64,
max_tokens: usize,
tools: Option<&[serde_json::Value]>,
images: Option<&[ContentBlock]>,
) -> Result<(String, Vec<crate::tasks::generate::ToolCall>), InferenceError> {
let (text, calls, _usage) = self
.generate_with_tools_multi(
schema,
prompt,
context,
temperature,
max_tokens,
tools,
images,
None,
None,
None,
0,
false,
None,
)
.await?;
Ok((text, calls))
}
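    /// Full generation entry point: multimodal messages, tools, tool choice,
    /// thinking budget, prompt caching, and structured output. Returns the
    /// generated text, tool calls, and token usage (when reported).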
pub async fn generate_with_tools_multi(
&self,
schema: &ModelSchema,
prompt: &str,
context: Option<&str>,
temperature: f64,
max_tokens: usize,
tools: Option<&[serde_json::Value]>,
images: Option<&[ContentBlock]>,
messages: Option<&[crate::tasks::generate::Message]>,
tool_choice: Option<&str>,
parallel_tool_calls: Option<bool>,
budget_tokens: usize,
cache_control: bool,
response_format: Option<&crate::tasks::generate::ResponseFormat>,
) -> Result<
(
String,
Vec<crate::tasks::generate::ToolCall>,
Option<crate::TokenUsage>,
),
InferenceError,
> {
let (_, protocol) = extract_remote_endpoint(schema)?;
let handler = crate::protocol::handler_for(protocol);
if !handler.supports_video() {
let has_video_in_images =
images.is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video));
let has_video_in_messages = messages.is_some_and(|msgs| {
msgs.iter().any(|msg| match msg {
crate::tasks::generate::Message::UserMultimodal { content } => {
content.iter().any(ContentBlock::is_video)
}
_ => false,
})
});
if has_video_in_images || has_video_in_messages {
return Err(InferenceError::UnsupportedMode {
mode: "video-content-block",
backend: handler.protocol_name(),
reason: "this remote protocol has no native video input path; route to \
a provider that implements ProtocolHandler::supports_video() (Gemini)",
});
}
}
if !handler.supports_audio() {
let has_audio_in_images =
images.is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio));
let has_audio_in_messages = messages.is_some_and(|msgs| {
msgs.iter().any(|msg| match msg {
crate::tasks::generate::Message::UserMultimodal { content } => {
content.iter().any(ContentBlock::is_audio)
}
_ => false,
})
});
if has_audio_in_images || has_audio_in_messages {
return Err(InferenceError::UnsupportedMode {
mode: "audio-content-block",
backend: handler.protocol_name(),
reason: "this remote protocol has no native audio input path; route to \
a provider that implements ProtocolHandler::supports_audio() (Gemini)",
});
}
}
let prompt = if schema.context_length > 0 {
truncate_prompt_to_fit(prompt, context, tools, max_tokens, schema.context_length)
} else {
prompt.to_string()
};
let (api_messages, system) =
handler.build_messages(messages.unwrap_or(&[]), &prompt, context, images);
let api_tools = tools.map(|t| handler.build_tools(t));
let req = crate::protocol::ApiRequest {
model: request_model_name(schema),
messages: api_messages,
system,
temperature,
max_tokens,
tools: api_tools,
tool_choice: tool_choice.map(str::to_string),
parallel_tool_calls,
stream: false,
budget_tokens,
cache_control,
response_format: response_format.cloned(),
};
let response = self.execute_request(schema, req).await?;
let text = response.text;
let mut calls = response.tool_calls;
let usage = response.usage;
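        // Some models emit a terse `done` tool call alongside richer free text;
        // when that happens, prefer the longer text as the final result.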
if !text.is_empty() {
for call in &mut calls {
if call.name == "done" {
let result_val = call
.arguments
.get("result")
.and_then(|v| v.as_str())
.unwrap_or("");
if result_val.len() < 50 && text.len() > result_val.len() {
call.arguments.insert(
"result".to_string(),
serde_json::Value::String(text.clone()),
);
}
}
}
}
Ok((text, calls, usage))
}
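    /// Embeds `texts` via an OpenAI-compatible `/v1/embeddings` endpoint.
    /// Other protocols return an error.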
pub async fn embed(
&self,
schema: &ModelSchema,
texts: &[String],
) -> Result<Vec<Vec<f32>>, InferenceError> {
let (endpoint, protocol) = extract_remote_endpoint(schema)?;
let lease = self.lease_key(schema, &endpoint).await?;
let start = std::time::Instant::now();
let result = match protocol {
ApiProtocol::OpenAiCompat => {
self.embed_openai(&endpoint, &lease.api_key, &schema.name, texts)
.await
}
_ => Err(InferenceError::InferenceFailed(format!(
"embedding not supported for {:?} protocol",
protocol
))),
};
let latency_ms = start.elapsed().as_millis() as u64;
match &result {
Ok(_) => {
let est_tokens = texts
.iter()
.map(|t| t.split_whitespace().count() as u64)
.sum();
self.key_pool
.report_success(&endpoint, &lease.env_var, latency_ms, est_tokens, 0)
.await;
}
Err(e) => {
let is_rl =
e.to_string().contains("429") || e.to_string().contains("RESOURCE_EXHAUSTED");
self.key_pool
.report_failure(&endpoint, &lease.env_var, is_rl)
.await;
}
}
result
}
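    /// Leases an API key for `endpoint`, falling back to the schema's
    /// configured `api_key_env`. Local sources (Ollama, vLLM/MLX) get an empty
    /// lease; other non-remote sources are an error.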
async fn lease_key(
&self,
schema: &ModelSchema,
endpoint: &str,
) -> Result<KeyLease, InferenceError> {
let fallback_env = match &schema.source {
ModelSource::RemoteApi { api_key_env, .. } => api_key_env.as_str(),
ModelSource::Ollama { .. } | ModelSource::VllmMlx { .. } => {
return Ok(KeyLease {
api_key: String::new(),
env_var: String::new(),
index: 0,
})
}
_ => {
return Err(InferenceError::InferenceFailed(format!(
"model {} is not remote",
schema.id
)))
}
};
self.register_model_keys(schema).await;
self.key_pool
.lease_or_env(endpoint, fallback_env)
.await
.ok_or_else(|| {
InferenceError::InferenceFailed(format!(
"no API keys available for endpoint {} (checked env vars: {:?})",
endpoint,
schema.all_api_key_envs()
))
})
}
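    /// Sends an OpenAI-style embeddings request and returns one vector per
    /// input text.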
async fn embed_openai(
&self,
endpoint: &str,
api_key: &str,
model: &str,
texts: &[String],
) -> Result<Vec<Vec<f32>>, InferenceError> {
let url = format_endpoint(endpoint, "/v1/embeddings");
let body = serde_json::json!({
"model": model,
"input": texts,
});
let resp = self
.client
.post(&url)
.header("Authorization", format!("Bearer {api_key}"))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
let status = resp.status();
let text = resp
.text()
.await
.map_err(|e| InferenceError::InferenceFailed(format!("read body: {e}")))?;
if !status.is_success() {
return Err(InferenceError::InferenceFailed(format!(
"API returned {status}: {text}"
)));
}
let parsed: OpenAiEmbedResponse = serde_json::from_str(&text)
.map_err(|e| InferenceError::InferenceFailed(format!("parse response: {e}")))?;
Ok(parsed.data.into_iter().map(|d| d.embedding).collect())
}
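    /// Streaming generation over SSE. Returns a channel of
    /// [`crate::stream::StreamEvent`]s fed by a background task;
    /// Google-protocol models and unsupported video/audio content are rejected
    /// up front.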
pub async fn generate_stream(
&self,
schema: &ModelSchema,
prompt: &str,
context: Option<&str>,
temperature: f64,
max_tokens: usize,
tools: Option<&[serde_json::Value]>,
images: Option<&[ContentBlock]>,
tool_choice: Option<&str>,
parallel_tool_calls: Option<bool>,
response_format: Option<&crate::tasks::generate::ResponseFormat>,
) -> Result<tokio::sync::mpsc::Receiver<crate::stream::StreamEvent>, InferenceError> {
let (endpoint, protocol) = extract_remote_endpoint(schema)?;
let lease = self.lease_key(schema, &endpoint).await?;
let api_key = lease.api_key;
let model = request_model_name(schema);
let handler = crate::protocol::handler_for(protocol);
if !handler.supports_video()
&& images.is_some_and(|blocks| blocks.iter().any(ContentBlock::is_video))
{
return Err(InferenceError::UnsupportedMode {
mode: "video-content-block",
backend: handler.protocol_name(),
reason: "this remote protocol has no native video input path; route to \
a provider that implements ProtocolHandler::supports_video() (Gemini)",
});
}
if !handler.supports_audio()
&& images.is_some_and(|blocks| blocks.iter().any(ContentBlock::is_audio))
{
return Err(InferenceError::UnsupportedMode {
mode: "audio-content-block",
backend: handler.protocol_name(),
reason: "this remote protocol has no native audio input path; route to \
a provider that implements ProtocolHandler::supports_audio() (Gemini)",
});
}
if matches!(protocol, ApiProtocol::Google) {
return Err(InferenceError::InferenceFailed(
"streaming not supported for Google protocol".to_string(),
));
}
let (messages, system) = handler.build_messages(&[], prompt, context, images);
let built_tools = tools.map(|t| handler.build_tools(t));
let req = crate::protocol::ApiRequest {
model: model.clone(),
messages,
system,
temperature,
max_tokens,
tools: built_tools,
tool_choice: tool_choice.map(str::to_string),
parallel_tool_calls,
stream: true,
budget_tokens: 0,
cache_control: false,
response_format: response_format.cloned(),
};
let body = handler.build_request_body(&req);
let url = if matches!(protocol, ApiProtocol::AzureOpenAi) {
let api_version = match &schema.source {
ModelSource::RemoteApi { api_version, .. } => api_version.clone(),
_ => None,
};
let version = api_version.as_deref().unwrap_or("2024-10-21");
format!(
"{}/openai/deployments/{}/chat/completions?api-version={}",
endpoint.trim_end_matches('/'),
model,
version
)
} else {
format_endpoint(&endpoint, handler.endpoint_path())
};
let mut headers = reqwest::header::HeaderMap::new();
for (name, value) in handler.auth_headers(&api_key) {
headers.insert(
reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
InferenceError::InferenceFailed(format!("auth header name: {e}"))
})?,
value.parse().map_err(|e| {
InferenceError::InferenceFailed(format!("auth header value: {e}"))
})?,
);
}
let send_fut = self.client.post(&url).headers(headers).json(&body).send();
let resp = tokio::time::timeout(std::time::Duration::from_secs(150), send_fut)
.await
.map_err(|_| {
InferenceError::InferenceFailed(
"stream request timed out after 150s (tokio safety timeout)".to_string(),
)
})?
.map_err(|e| InferenceError::InferenceFailed(format!("HTTP error: {e}")))?;
let status = resp.status();
if !status.is_success() {
let err_text = resp.text().await.unwrap_or_default();
return Err(InferenceError::InferenceFailed(format!(
"API returned {status}: {err_text}"
)));
}
let is_anthropic = matches!(protocol, ApiProtocol::Anthropic);
let (tx, rx) = tokio::sync::mpsc::channel::<crate::stream::StreamEvent>(64);
tokio::spawn(async move {
use futures::StreamExt;
let mut byte_stream = resp.bytes_stream();
let mut buffer = String::new();
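            // SSE events are separated by blank lines; accumulate raw chunks
            // and split on "\n\n" before handing each block to the parser.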
while let Some(chunk_result) = byte_stream.next().await {
let chunk = match chunk_result {
Ok(bytes) => String::from_utf8_lossy(&bytes).to_string(),
Err(_) => break,
};
buffer.push_str(&chunk);
while let Some(pos) = buffer.find("\n\n") {
let event_block = buffer[..pos].to_string();
buffer = buffer[pos + 2..].to_string();
let sse_events = crate::stream::parse_sse_lines(&event_block);
for (event_type, data) in sse_events {
if data == "[DONE]" {
continue;
}
let stream_events = if is_anthropic {
crate::stream::parse_anthropic_sse_line(&event_type, &data)
} else {
crate::stream::parse_openai_sse_line(&format!("data: {data}"))
};
for evt in stream_events {
if tx.send(evt).await.is_err() {
                                return;
                            }
}
}
}
}
});
Ok(rx)
}
}
impl Default for RemoteBackend {
fn default() -> Self {
Self::new()
}
}
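/// Resolves the HTTP endpoint and wire protocol for a schema. Ollama and
/// vLLM/MLX sources are treated as OpenAI-compatible; non-remote sources error.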
fn extract_remote_endpoint(schema: &ModelSchema) -> Result<(String, ApiProtocol), InferenceError> {
match &schema.source {
ModelSource::RemoteApi {
endpoint, protocol, ..
} => Ok((endpoint.clone(), *protocol)),
ModelSource::Ollama { host, .. } => Ok((host.clone(), ApiProtocol::OpenAiCompat)),
ModelSource::VllmMlx { endpoint, .. } => Ok((endpoint.clone(), ApiProtocol::OpenAiCompat)),
_ => Err(InferenceError::InferenceFailed(format!(
"model {} is not remote",
schema.id
))),
}
}
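/// Model identifier to send on the wire: vLLM/MLX uses the server-side model
/// name, everything else uses the schema's display name.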
fn request_model_name(schema: &ModelSchema) -> String {
match &schema.source {
ModelSource::VllmMlx { model_name, .. } => model_name.clone(),
_ => schema.name.clone(),
}
}
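/// Joins a base URL and endpoint path, avoiding a duplicated path segment when
/// the base already ends with it.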
fn format_endpoint(base: &str, path: &str) -> String {
let base = base.trim_end_matches('/');
if base.ends_with(path.trim_start_matches('/')) {
base.to_string()
} else {
format!("{}{}", base, path)
}
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedResponse {
data: Vec<OpenAiEmbedData>,
}
#[derive(Debug, Deserialize)]
struct OpenAiEmbedData {
embedding: Vec<f32>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn format_endpoint_no_dup() {
assert_eq!(
format_endpoint("https://api.openai.com", "/v1/chat/completions"),
"https://api.openai.com/v1/chat/completions"
);
assert_eq!(
format_endpoint(
"https://api.openai.com/v1/chat/completions",
"/v1/chat/completions"
),
"https://api.openai.com/v1/chat/completions"
);
assert_eq!(
format_endpoint("https://api.openai.com/", "/v1/chat/completions"),
"https://api.openai.com/v1/chat/completions"
);
}
#[test]
fn extract_endpoint_from_remote() {
let schema = ModelSchema {
id: "test/model:v1".into(),
name: "Test".into(),
provider: "test".into(),
family: "test".into(),
version: "1".into(),
capabilities: vec![],
context_length: 4096,
param_count: String::new(),
quantization: None,
performance: Default::default(),
cost: Default::default(),
source: ModelSource::RemoteApi {
endpoint: "https://api.test.com".into(),
api_key_env: "NONEXISTENT_TEST_KEY_12345".into(),
api_key_envs: vec![],
api_version: None,
protocol: ApiProtocol::OpenAiCompat,
},
tags: vec![],
supported_params: vec![],
public_benchmarks: vec![],
available: false,
};
let (endpoint, protocol) = extract_remote_endpoint(&schema).unwrap();
assert_eq!(endpoint, "https://api.test.com");
assert_eq!(protocol, ApiProtocol::OpenAiCompat);
}
#[test]
fn extract_endpoint_non_remote_fails() {
let schema = ModelSchema {
id: "local/model:v1".into(),
name: "Local".into(),
provider: "test".into(),
family: "test".into(),
version: "1".into(),
capabilities: vec![],
context_length: 4096,
param_count: String::new(),
quantization: None,
performance: Default::default(),
cost: Default::default(),
source: ModelSource::Local {
hf_repo: "test".into(),
hf_filename: "test".into(),
tokenizer_repo: "test".into(),
},
tags: vec![],
supported_params: vec![],
public_benchmarks: vec![],
available: false,
};
assert!(extract_remote_endpoint(&schema).is_err());
}
#[test]
fn parse_openai_embed_response() {
let json = r#"{"data":[{"embedding":[0.1,0.2,0.3]}]}"#;
let resp: OpenAiEmbedResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.data[0].embedding, vec![0.1, 0.2, 0.3]);
}
#[test]
fn request_model_name_uses_vllm_server_model() {
let schema = ModelSchema {
id: "vllm-mlx/test".into(),
name: "Display Name".into(),
provider: "test".into(),
family: "test".into(),
version: "1".into(),
capabilities: vec![],
context_length: 4096,
param_count: String::new(),
quantization: None,
performance: Default::default(),
cost: Default::default(),
source: ModelSource::VllmMlx {
endpoint: "http://localhost:8000".into(),
model_name: "mlx-community/Actual-Model".into(),
},
tags: vec![],
supported_params: vec![],
public_benchmarks: vec![],
available: true,
};
assert_eq!(request_model_name(&schema), "mlx-community/Actual-Model");
}
#[test]
fn truncate_prompt_fits_returns_unchanged() {
let prompt = "short prompt";
let result = truncate_prompt_to_fit(prompt, None, None, 16, 256);
assert_eq!(result, prompt);
}
#[test]
fn truncate_prompt_cjk_mid_codepoint_does_not_panic() {
let prompt: String = std::iter::repeat('\u{4E16}').take(200).collect();
let result = truncate_prompt_to_fit(&prompt, None, None, 20, 209);
assert!(result.starts_with("[...truncated...]"));
let kept = result.strip_prefix("[...truncated...]\n").unwrap();
assert!(!kept.is_empty());
}
#[test]
fn truncate_prompt_accounts_for_context_and_tools() {
let prompt = "line one\nline two\nline three\n".repeat(50);
let tools = vec![serde_json::json!({"name": "demo_tool"})];
let result = truncate_prompt_to_fit(&prompt, Some("ctx"), Some(&tools), 20, 240);
assert!(result.starts_with("[...truncated...]"));
}
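    // Extra sanity check: `estimate_tokens` follows the len/4 heuristic with a
    // floor of one token, so these expectations are derived directly from the code.
    #[test]
    fn estimate_tokens_four_bytes_per_token_with_floor_of_one() {
        assert_eq!(estimate_tokens(""), 1);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("hello world!"), 3);
    }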
}