use reqwest::Client;
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::Semaphore;
use super::{
best_effort_parse_json_object, effective_thinking_payload, extract_assistant_content,
extract_thinking_content, extract_usage, is_anthropic_endpoint, reasoning_payload,
truncate_utf8_tail, AutomationResult, AutomationUsage, ContentAnalysis, EngineError,
EngineResult, ExtractionSchema, PromptUrlGate, RemoteMultimodalConfig, DEFAULT_SYSTEM_PROMPT,
EXTRACTION_ONLY_SYSTEM_PROMPT,
};
// Process-wide default HTTP client, built lazily on first use.
// `http_client()` falls back to this when no per-engine client is configured.
static CLIENT: std::sync::LazyLock<Client> = std::sync::LazyLock::new(Client::new);
/// Engine that drives a remote multimodal (text + vision) chat-completions
/// style LLM endpoint.
///
/// Supports per-URL prompt/config gating, separate vision/text endpoints,
/// complexity-based model routing over a pool, and optional concurrency
/// limiting. Constructed via [`RemoteMultimodalEngine::new`] plus the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RemoteMultimodalEngine {
    /// Inference endpoint URL requests are POSTed to.
    pub api_url: String,
    /// Optional API key, sent as a bearer token on each request.
    pub api_key: Option<String>,
    /// Default model identifier placed in the request body.
    pub model_name: String,
    /// Optional configured system prompt, layered on the built-in prompt by
    /// `system_prompt_compiled`.
    pub system_prompt: Option<String>,
    /// Optional additional instructions appended after the system prompt.
    pub system_prompt_extra: Option<String>,
    /// Optional extra text appended to each user message.
    pub user_message_extra: Option<String>,
    /// Automation/extraction configuration (token limits, schema, gates, ...).
    pub cfg: RemoteMultimodalConfig,
    /// Optional per-URL gate that can reject URLs or override prompts/config.
    pub prompt_url_gate: Option<PromptUrlGate>,
    /// Optional limiter on concurrent in-flight LLM requests.
    pub semaphore: Option<Arc<Semaphore>>,
    /// Optional dedicated endpoint used for vision rounds.
    pub vision_model: Option<super::ModelEndpoint>,
    /// Optional dedicated endpoint used for text-only rounds.
    pub text_model: Option<super::ModelEndpoint>,
    /// Strategy deciding when a round uses the vision endpoint.
    pub vision_route_mode: super::VisionRouteMode,
    /// Optional skill registry (only with the `skills` feature).
    #[cfg(feature = "skills")]
    pub skill_registry: Option<super::skills::SkillRegistry>,
    /// Optional shared long-term experience memory (only with `memvid`).
    #[cfg(feature = "memvid")]
    pub experience_memory:
        Option<std::sync::Arc<tokio::sync::RwLock<super::long_term_memory::ExperienceMemory>>>,
    /// Prefer the Chrome AI path (see `should_use_chrome_ai`).
    pub use_chrome_ai: bool,
    /// Maximum user-prompt characters when using Chrome AI.
    pub chrome_ai_max_user_chars: usize,
    /// Optional complexity-based router picking a model per round.
    pub model_router: Option<super::router::ModelRouter>,
    /// Candidate endpoints the router/fallback logic can choose from.
    pub model_pool: Vec<super::ModelEndpoint>,
    /// Optional dedicated HTTP client (e.g. proxied); `None` means the
    /// shared process-wide default client is used.
    pub client: Option<Client>,
}
impl RemoteMultimodalEngine {
    /// Creates an engine for `api_url`/`model_name` with default config and
    /// every optional feature disabled.
    ///
    /// `system_prompt` (when `Some`) is layered on top of the built-in system
    /// prompt by `system_prompt_compiled`, not used verbatim.
    pub fn new<S: Into<String>>(api_url: S, model_name: S, system_prompt: Option<String>) -> Self {
        Self {
            api_url: api_url.into(),
            api_key: None,
            model_name: model_name.into(),
            system_prompt,
            system_prompt_extra: None,
            user_message_extra: None,
            cfg: RemoteMultimodalConfig::default(),
            prompt_url_gate: None,
            semaphore: None,
            vision_model: None,
            text_model: None,
            vision_route_mode: super::VisionRouteMode::default(),
            #[cfg(feature = "skills")]
            skill_registry: None,
            #[cfg(feature = "memvid")]
            experience_memory: None,
            use_chrome_ai: false,
            // Default character budget for Chrome AI user prompts.
            chrome_ai_max_user_chars: 6000,
            model_router: None,
            model_pool: Vec::new(),
            client: None,
        }
    }
    /// Sets the bearer API key (`None` clears it). Consuming builder style.
    pub fn with_api_key(mut self, key: Option<&str>) -> Self {
        self.api_key = key.map(|k| k.to_string());
        self
    }
    /// Replaces the whole automation config. Consuming builder style.
    pub fn with_config(mut self, cfg: RemoteMultimodalConfig) -> Self {
        self.cfg = cfg;
        self
    }
    /// Caps concurrent LLM requests at `n`; `n == 0` removes the limit.
    pub fn with_max_inflight_requests(&mut self, n: usize) -> &mut Self {
        if n > 0 {
            self.semaphore = Some(Arc::new(Semaphore::new(n)));
        } else {
            self.semaphore = None;
        }
        self
    }
    /// Installs a (possibly shared across engines) concurrency semaphore.
    pub fn with_semaphore(&mut self, sem: Option<Arc<Semaphore>>) -> &mut Self {
        self.semaphore = sem;
        self
    }
    /// Sets additional text appended to the compiled system prompt.
    pub fn with_system_prompt_extra(&mut self, extra: Option<&str>) -> &mut Self {
        self.system_prompt_extra = extra.map(|s| s.to_string());
        self
    }
    /// Sets additional text appended to each user message.
    pub fn with_user_message_extra(&mut self, extra: Option<&str>) -> &mut Self {
        self.user_message_extra = extra.map(|s| s.to_string());
        self
    }
    /// Sets the per-URL prompt/config gate (`None` disables gating).
    pub fn with_prompt_url_gate(&mut self, gate: Option<PromptUrlGate>) -> &mut Self {
        self.prompt_url_gate = gate;
        self
    }
    /// Sets the skill registry (only with the `skills` feature).
    #[cfg(feature = "skills")]
    pub fn with_skill_registry(
        &mut self,
        registry: Option<super::skills::SkillRegistry>,
    ) -> &mut Self {
        self.skill_registry = registry;
        self
    }
    /// Sets the shared experience memory (only with the `memvid` feature).
    #[cfg(feature = "memvid")]
    pub fn with_experience_memory(
        &mut self,
        memory: Option<
            std::sync::Arc<tokio::sync::RwLock<super::long_term_memory::ExperienceMemory>>,
        >,
    ) -> &mut Self {
        self.experience_memory = memory;
        self
    }
    /// Explicitly enables/disables the Chrome AI path.
    pub fn with_chrome_ai(&mut self, enabled: bool) -> &mut Self {
        self.use_chrome_ai = enabled;
        self
    }
    /// Sets the maximum user-prompt characters for Chrome AI.
    pub fn with_chrome_ai_max_user_chars(&mut self, chars: usize) -> &mut Self {
        self.chrome_ai_max_user_chars = chars;
        self
    }
    /// Whether inference should go through Chrome AI: either explicitly
    /// enabled, or implied because neither an endpoint URL nor an API key
    /// is configured.
    pub fn should_use_chrome_ai(&self) -> bool {
        self.use_chrome_ai || (self.api_url.is_empty() && self.api_key.is_none())
    }
    /// Installs a dedicated HTTP client; `None` reverts to the shared default.
    pub fn with_client(&mut self, client: Option<Client>) -> &mut Self {
        self.client = client;
        self
    }
    /// Builds and installs a dedicated HTTP client that routes through
    /// `proxies`.
    ///
    /// `None` or an empty list resets `self.client` to `None`, i.e. back to
    /// the shared default client.
    ///
    /// NOTE(review): proxy URLs that fail to parse are skipped silently, and
    /// if the client fails to build the engine also silently falls back to
    /// the shared (un-proxied) default client. The 120s timeout is hard-coded
    /// here and applies only to the proxied client.
    pub fn with_proxies(&mut self, proxies: Option<&[String]>) -> &mut Self {
        self.client = proxies.and_then(|urls| {
            if urls.is_empty() {
                return None;
            }
            let mut builder = Client::builder().timeout(std::time::Duration::from_secs(120));
            for url in urls {
                if let Ok(proxy) = reqwest::Proxy::all(url) {
                    builder = builder.proxy(proxy);
                }
            }
            builder.build().ok()
        });
        self
    }
    /// The HTTP client to use for requests: the per-engine client when set,
    /// otherwise the shared process-wide default.
    pub(crate) fn http_client(&self) -> &Client {
        self.client.as_ref().unwrap_or(&CLIENT)
    }
    /// Replaces the automation config (mutable-builder twin of `with_config`).
    pub fn with_remote_multimodal_config(&mut self, cfg: RemoteMultimodalConfig) -> &mut Self {
        self.cfg = cfg;
        self
    }
    /// Toggles structured-data extraction ("extracted" field) in responses.
    pub fn with_extra_ai_data(&mut self, enabled: bool) -> &mut Self {
        self.cfg.extra_ai_data = enabled;
        self
    }
    /// Sets the extraction instructions passed to the model.
    pub fn with_extraction_prompt(&mut self, prompt: Option<&str>) -> &mut Self {
        self.cfg.extraction_prompt = prompt.map(|s| s.to_string());
        self
    }
    /// Toggles screenshot capture in the automation config.
    pub fn with_screenshot(&mut self, enabled: bool) -> &mut Self {
        self.cfg.screenshot = enabled;
        self
    }
    /// Sets the JSON schema extracted data must conform to.
    pub fn with_extraction_schema(&mut self, schema: Option<ExtractionSchema>) -> &mut Self {
        self.cfg.extraction_schema = schema;
        self
    }
    /// Read-only access to the automation config.
    pub fn config(&self) -> &RemoteMultimodalConfig {
        &self.cfg
    }
    /// Read-only access to the per-URL gate, if any.
    pub fn prompt_url_gate(&self) -> Option<&PromptUrlGate> {
        self.prompt_url_gate.as_ref()
    }
pub fn clone_with_cfg(&self, cfg: RemoteMultimodalConfig) -> Self {
Self {
api_url: self.api_url.clone(),
api_key: self.api_key.clone(),
model_name: self.model_name.clone(),
system_prompt: self.system_prompt.clone(),
system_prompt_extra: self.system_prompt_extra.clone(),
user_message_extra: self.user_message_extra.clone(),
cfg,
prompt_url_gate: self.prompt_url_gate.clone(),
semaphore: self.semaphore.clone(),
vision_model: self.vision_model.clone(),
text_model: self.text_model.clone(),
vision_route_mode: self.vision_route_mode,
#[cfg(feature = "skills")]
skill_registry: self.skill_registry.clone(),
#[cfg(feature = "memvid")]
experience_memory: self.experience_memory.clone(),
use_chrome_ai: self.use_chrome_ai,
chrome_ai_max_user_chars: self.chrome_ai_max_user_chars,
model_router: self.model_router.clone(),
model_pool: self.model_pool.clone(),
client: self.client.clone(),
}
}
pub async fn acquire_llm_permit(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
match &self.semaphore {
Some(sem) => Some(sem.clone().acquire_owned().await.ok()?),
None => None,
}
}
    /// Runs full content analysis over raw HTML.
    pub fn analyze_content(&self, html: &str) -> ContentAnalysis {
        ContentAnalysis::analyze(html)
    }
    /// Cheap heuristic: does this HTML likely need a screenshot to be
    /// understood (delegates to `ContentAnalysis::quick_needs_screenshot`)?
    pub fn needs_screenshot(&self, html: &str) -> bool {
        ContentAnalysis::quick_needs_screenshot(html)
    }
    /// Computes the effective config and prompt overrides for `url`.
    ///
    /// Starts from the engine-level values and, when a prompt URL gate is
    /// configured, merges in any per-URL overrides the gate yields. Returns
    /// `None` when the gate rejects the URL (callers map this to a synthetic
    /// `url_not_allowed` result). Returns, in order:
    /// `(config, system_prompt, system_prompt_extra, user_message_extra)`.
    #[allow(clippy::type_complexity)]
    fn resolve_runtime_for_url(
        &self,
        url: &str,
    ) -> Option<(
        RemoteMultimodalConfig,
        Option<String>,
        Option<String>,
        Option<String>,
    )> {
        let mut effective_cfg = self.cfg.clone();
        let mut effective_system_prompt = self.system_prompt.clone();
        let mut effective_system_prompt_extra = self.system_prompt_extra.clone();
        let mut effective_user_message_extra = self.user_message_extra.clone();
        if let Some(gate) = &self.prompt_url_gate {
            // `match_url` returning `None` means "URL not allowed"; propagate.
            let gate_match = gate.match_url(url)?;
            if let Some(override_cfg) = gate_match {
                // Only apply overrides that differ from AutomationConfig's
                // defaults, so unset gate fields don't clobber engine-level
                // settings.
                let defaults = super::AutomationConfig::default();
                if override_cfg.max_steps != defaults.max_steps {
                    effective_cfg.max_rounds = override_cfg.max_steps.max(1);
                }
                if override_cfg.max_retries != defaults.max_retries {
                    effective_cfg.retry.max_attempts = override_cfg.max_retries.max(1);
                }
                if override_cfg.capture_screenshots != defaults.capture_screenshots {
                    effective_cfg.screenshot = override_cfg.capture_screenshots;
                }
                if override_cfg.capture_profile != defaults.capture_profile {
                    effective_cfg.capture_profiles = vec![override_cfg.capture_profile.clone()];
                }
                // Asking for extraction implies enabling extra AI data.
                if override_cfg.extract_on_success || override_cfg.extraction_prompt.is_some() {
                    effective_cfg.extra_ai_data = true;
                }
                // Non-blank string overrides replace the engine-level values;
                // blank/whitespace-only overrides are ignored.
                if let Some(extraction_prompt) = &override_cfg.extraction_prompt {
                    if !extraction_prompt.trim().is_empty() {
                        effective_cfg.extraction_prompt = Some(extraction_prompt.clone());
                    }
                }
                if let Some(system_prompt) = &override_cfg.system_prompt {
                    if !system_prompt.trim().is_empty() {
                        effective_system_prompt = Some(system_prompt.clone());
                    }
                }
                if let Some(system_prompt_extra) = &override_cfg.system_prompt_extra {
                    if !system_prompt_extra.trim().is_empty() {
                        effective_system_prompt_extra = Some(system_prompt_extra.clone());
                    }
                }
                if let Some(user_message_extra) = &override_cfg.user_message_extra {
                    if !user_message_extra.trim().is_empty() {
                        effective_user_message_extra = Some(user_message_extra.clone());
                    }
                }
            }
        }
        Some((
            effective_cfg,
            effective_system_prompt,
            effective_system_prompt_extra,
            effective_user_message_extra,
        ))
    }
    /// Compiles the complete system prompt for a request.
    ///
    /// Layers, in order: the built-in base prompt (extraction-only or the
    /// default automation prompt), the configured system prompt, any extra
    /// instructions, an extraction-mode section (schema, instructions,
    /// example), a relevance-gate section, and finally a read-only dump of
    /// the effective runtime config values.
    pub fn system_prompt_compiled(&self, effective_cfg: &RemoteMultimodalConfig) -> String {
        let mut s = if effective_cfg.is_extraction_only() {
            EXTRACTION_ONLY_SYSTEM_PROMPT.to_string()
        } else {
            DEFAULT_SYSTEM_PROMPT.to_string()
        };
        // Configured base prompt is appended, not substituted.
        if let Some(base) = &self.system_prompt {
            if !base.trim().is_empty() {
                s.push_str("\n\n---\nCONFIGURED SYSTEM INSTRUCTIONS:\n");
                s.push_str(base.trim());
            }
        }
        if let Some(extra) = &self.system_prompt_extra {
            if !extra.trim().is_empty() {
                s.push_str("\n\n---\nADDITIONAL INSTRUCTIONS:\n");
                s.push_str(extra.trim());
            }
        }
        // Extraction mode: describe the "extracted" field, optional schema,
        // optional instructions, and show an example response.
        if effective_cfg.extra_ai_data {
            s.push_str("\n\n---\nEXTRACTION MODE ENABLED:\n");
            s.push_str("Include an \"extracted\" field in your JSON response containing structured data extracted from the page.\n");
            if let Some(schema) = &effective_cfg.extraction_schema {
                s.push_str("\nExtraction Schema: ");
                s.push_str(&schema.name);
                s.push('\n');
                if let Some(desc) = &schema.description {
                    s.push_str("Description: ");
                    s.push_str(desc.trim());
                    s.push('\n');
                }
                s.push_str("\nThe \"extracted\" field MUST conform to this JSON Schema:\n");
                s.push_str(&schema.schema);
                s.push('\n');
                if schema.strict {
                    s.push_str("\nSTRICT MODE: You MUST follow the schema exactly. Do not add extra fields or omit required fields.\n");
                }
            } else {
                s.push_str("The \"extracted\" field should be a JSON object or array with the relevant data.\n");
            }
            if let Some(extraction_prompt) = &effective_cfg.extraction_prompt {
                s.push_str("\nExtraction instructions: ");
                s.push_str(extraction_prompt.trim());
                s.push('\n');
            }
            s.push_str("\nExample response with extraction:\n");
            s.push_str("{\n \"label\": \"extract_products\",\n \"done\": true,\n \"steps\": [],\n \"extracted\": {\"products\": [{\"name\": \"Product A\", \"price\": 19.99}]}\n}\n");
        }
        // Relevance gate: ask for a boolean "relevant" field, judged against
        // the relevance prompt, else the extraction prompt as fallback.
        if effective_cfg.relevance_gate {
            s.push_str("\n\n---\nRELEVANCE GATE ENABLED:\n");
            s.push_str("Include a \"relevant\": true|false field in your JSON response.\n");
            s.push_str("Set true if the page content is relevant to the extraction/crawl goals.\n");
            s.push_str(
                "Set false if the page is off-topic, a 404, login wall, or otherwise not useful.\n",
            );
            if let Some(prompt) = &effective_cfg.relevance_prompt {
                s.push_str("\nRelevance criteria: ");
                s.push_str(prompt.trim());
                s.push('\n');
            } else if let Some(ep) = &effective_cfg.extraction_prompt {
                s.push_str("\nJudge relevance against: ");
                s.push_str(ep.trim());
                s.push('\n');
            }
        }
        s.push_str("\n\n---\nRUNTIME CONFIG (read-only):\n");
        s.push_str(&format!(
            "- include_url: {}\n- include_title: {}\n- include_html: {}\n- html_max_bytes: {}\n- temperature: {}\n- max_tokens: {}\n- request_json_object: {}\n- best_effort_json_extract: {}\n- max_rounds: {}\n- extra_ai_data: {}\n- relevance_gate: {}\n",
            effective_cfg.include_url,
            effective_cfg.include_title,
            effective_cfg.include_html,
            effective_cfg.html_max_bytes,
            effective_cfg.temperature,
            effective_cfg.max_tokens,
            effective_cfg.request_json_object,
            effective_cfg.best_effort_json_extract,
            effective_cfg.max_rounds,
            effective_cfg.extra_ai_data,
            effective_cfg.relevance_gate,
        ));
        s
    }
    /// Sets a dedicated endpoint for vision rounds.
    pub fn with_vision_model(&mut self, endpoint: Option<super::ModelEndpoint>) -> &mut Self {
        self.vision_model = endpoint;
        self
    }
    /// Sets a dedicated endpoint for text-only rounds.
    pub fn with_text_model(&mut self, endpoint: Option<super::ModelEndpoint>) -> &mut Self {
        self.text_model = endpoint;
        self
    }
    /// Sets the strategy that decides when vision rounds are used.
    pub fn with_vision_route_mode(&mut self, mode: super::VisionRouteMode) -> &mut Self {
        self.vision_route_mode = mode;
        self
    }
    /// True when at least one dedicated vision/text endpoint is configured.
    pub fn has_dual_model_routing(&self) -> bool {
        self.vision_model.is_some() || self.text_model.is_some()
    }
pub fn resolve_model_for_round(&self, use_vision: bool) -> (&str, &str, Option<&str>) {
let endpoint = if use_vision {
self.vision_model.as_ref()
} else {
self.text_model.as_ref()
};
match endpoint {
Some(ep) => {
let url = ep.api_url.as_deref().unwrap_or(&self.api_url);
let key = ep.api_key.as_deref().or(self.api_key.as_deref());
(url, &ep.model_name, key)
}
None => (&self.api_url, &self.model_name, self.api_key.as_deref()),
}
}
    /// Resolves `(api_url, model_name, api_key)` for a round using the
    /// complexity-based model router, when one is configured.
    ///
    /// Classifies the round (prompt, HTML size, round index, stagnation),
    /// lets the router pick a model, and looks it up in the model pool.
    /// Falls back to `resolve_model_for_round` when there is no router or the
    /// routed model is not present in the pool.
    pub fn resolve_model_for_round_with_complexity(
        &self,
        use_vision: bool,
        user_prompt: &str,
        html_len: usize,
        round_idx: usize,
        stagnated: bool,
    ) -> (&str, &str, Option<&str>) {
        let router = match &self.model_router {
            Some(r) => r,
            None => return self.resolve_model_for_round(use_vision),
        };
        let analysis =
            super::router::classify_round_complexity(user_prompt, html_len, round_idx, stagnated);
        let decision = router.route(&analysis);
        let chosen_model = &decision.model;
        if let Some(ep) = self
            .model_pool
            .iter()
            .find(|ep| ep.model_name == *chosen_model)
        {
            // The routed model may not support images; for a vision round,
            // try to substitute a vision-capable endpoint from the pool.
            if use_vision && !super::supports_vision(&ep.model_name) {
                if let Some(fallback) = self.find_vision_fallback_in_pool(&decision.tier) {
                    let url = fallback.api_url.as_deref().unwrap_or(&self.api_url);
                    let key = fallback.api_key.as_deref().or(self.api_key.as_deref());
                    return (url, &fallback.model_name, key);
                }
            }
            // Endpoint URL/key fall back to the engine-level values.
            let url = ep.api_url.as_deref().unwrap_or(&self.api_url);
            let key = ep.api_key.as_deref().or(self.api_key.as_deref());
            return (url, &ep.model_name, key);
        }
        // Routed model not in the pool: use static routing instead.
        self.resolve_model_for_round(use_vision)
    }
    /// Finds a vision-capable endpoint in the model pool, escalating cost
    /// tiers above `starting_tier` first.
    ///
    /// Tries the router policy's model for each higher tier in order; if none
    /// of those is vision-capable, falls back to the first vision-capable
    /// endpoint anywhere in the pool. Returns `None` when no router is
    /// configured or nothing in the pool supports vision.
    fn find_vision_fallback_in_pool(
        &self,
        starting_tier: &super::CostTier,
    ) -> Option<&super::ModelEndpoint> {
        let router = self.model_router.as_ref()?;
        let policy = router.policy();
        // Only escalate upward in cost; from High there is nowhere to go.
        let tiers_to_try: &[super::CostTier] = match starting_tier {
            super::CostTier::Low => &[super::CostTier::Medium, super::CostTier::High],
            super::CostTier::Medium => &[super::CostTier::High],
            super::CostTier::High => &[],
        };
        for &tier in tiers_to_try {
            let model_name = policy.model_for_tier(tier);
            if let Some(ep) = self
                .model_pool
                .iter()
                .find(|ep| ep.model_name == model_name)
            {
                if super::supports_vision(&ep.model_name) {
                    return Some(ep);
                }
            }
        }
        // Last resort: any vision-capable endpoint in the pool.
        self.model_pool
            .iter()
            .find(|ep| super::supports_vision(&ep.model_name))
    }
pub fn pick_fallback_model(
&self,
tried: &[String],
use_vision: bool,
) -> Option<(String, String, Option<String>)> {
let mut candidates: Vec<_> = self
.model_pool
.iter()
.filter(|ep| !tried.iter().any(|t| t == &ep.model_name))
.collect();
if candidates.is_empty() {
return None;
}
if use_vision {
candidates.sort_by_key(|ep| {
if super::supports_vision(&ep.model_name) {
0
} else {
1
}
});
}
let ep = candidates.first()?;
let url = ep.api_url.as_deref().unwrap_or(&self.api_url).to_string();
let key = ep
.api_key
.as_deref()
.or(self.api_key.as_deref())
.map(|s| s.to_string());
Some((url, ep.model_name.clone(), key))
}
pub fn should_use_vision_this_round(
&self,
round_idx: usize,
stagnated: bool,
action_stuck_rounds: usize,
force_vision: bool,
) -> bool {
if !self.has_dual_model_routing() {
return true;
}
if force_vision {
return true;
}
match self.vision_route_mode {
super::VisionRouteMode::AlwaysPrimary => true,
super::VisionRouteMode::TextFirst => {
round_idx == 0 || stagnated || action_stuck_rounds >= 3
}
super::VisionRouteMode::VisionFirst => {
round_idx < 2 || stagnated || action_stuck_rounds >= 3
}
super::VisionRouteMode::AgentDriven => false,
}
}
    /// One-shot structured extraction from raw HTML (no browser automation).
    ///
    /// Builds a single chat-completion request containing the page URL/title
    /// and the (byte-truncated) HTML, asks for a plan-shaped JSON response,
    /// and maps it into an `AutomationResult` with `steps_executed: 0`.
    /// Returns a synthetic `url_not_allowed` success result when the prompt
    /// URL gate rejects `url`.
    ///
    /// # Errors
    /// `EngineError` on transport failure, non-success HTTP status, missing
    /// assistant content, or unparseable response JSON.
    pub async fn extract_from_html(
        &self,
        html: &str,
        url: &str,
        title: Option<&str>,
    ) -> EngineResult<AutomationResult> {
        // Request-wire types, local to this call. `response_format`/`system`
        // etc. are skipped when `None` so the same shape serves both
        // OpenAI-style and Anthropic-style endpoints.
        #[derive(Serialize)]
        struct ContentBlock {
            #[serde(rename = "type")]
            content_type: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            text: Option<String>,
        }
        #[derive(Serialize)]
        struct Message {
            role: String,
            content: Vec<ContentBlock>,
        }
        #[derive(Serialize)]
        struct ResponseFormat {
            #[serde(rename = "type")]
            format_type: String,
        }
        #[derive(Serialize)]
        struct InferenceRequest {
            model: String,
            messages: Vec<Message>,
            #[serde(skip_serializing_if = "Option::is_none")]
            temperature: Option<f32>,
            max_tokens: u32,
            #[serde(skip_serializing_if = "Option::is_none")]
            #[serde(rename = "response_format")]
            response_format: Option<ResponseFormat>,
            #[serde(skip_serializing_if = "Option::is_none")]
            reasoning: Option<serde_json::Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            thinking: Option<serde_json::Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            system: Option<String>,
        }
        // Per-URL config/prompt overrides; `None` means the gate rejected
        // the URL, which is reported as a successful no-op result.
        let Some((
            effective_cfg,
            effective_system_prompt,
            effective_system_prompt_extra,
            effective_user_message_extra,
        )) = self.resolve_runtime_for_url(url)
        else {
            return Ok(AutomationResult {
                label: "url_not_allowed".into(),
                steps_executed: 0,
                success: true,
                error: None,
                usage: AutomationUsage::default(),
                extracted: None,
                screenshot: None,
                spawn_pages: Vec::new(),
                relevant: None,
                reasoning: None,
            });
        };
        // Clone the engine with overridden prompts so that
        // `system_prompt_compiled` picks up the per-URL values.
        let mut prompt_engine = self.clone();
        prompt_engine.system_prompt = effective_system_prompt;
        prompt_engine.system_prompt_extra = effective_system_prompt_extra;
        let is_anthropic = is_anthropic_endpoint(&self.api_url);
        // Build the user message: context, truncated HTML, task description.
        let mut user_text =
            String::with_capacity(256 + html.len().min(effective_cfg.html_max_bytes));
        user_text.push_str("EXTRACTION CONTEXT:\n");
        user_text.push_str("- url: ");
        user_text.push_str(url);
        user_text.push('\n');
        if let Some(t) = title {
            user_text.push_str("- title: ");
            user_text.push_str(t);
            user_text.push('\n');
        }
        user_text.push_str("\nHTML CONTENT:\n");
        // Truncate on a UTF-8 boundary to respect the byte budget.
        let html_truncated = truncate_utf8_tail(html, effective_cfg.html_max_bytes);
        user_text.push_str(&html_truncated);
        user_text.push_str(
            "\n\nTASK:\nExtract structured data from the HTML above. Return a JSON object with:\n",
        );
        user_text.push_str("- \"label\": short description of what was extracted\n");
        user_text.push_str("- \"done\": true\n");
        user_text.push_str("- \"steps\": [] (empty, no browser automation)\n");
        user_text.push_str("- \"extracted\": the structured data extracted from the page\n");
        if effective_cfg.relevance_gate {
            user_text.push_str(
                "- \"relevant\": true if page is relevant to the goal, false otherwise\n",
            );
        }
        if let Some(extra) = &effective_user_message_extra {
            if !extra.trim().is_empty() {
                user_text.push_str("\n---\nUSER INSTRUCTIONS:\n");
                user_text.push_str(extra.trim());
                user_text.push('\n');
            }
        }
        let system_text = prompt_engine.system_prompt_compiled(&effective_cfg);
        // Extended-thinking payload is only sent to Anthropic endpoints.
        let thinking_pl = if is_anthropic {
            effective_thinking_payload(&effective_cfg)
        } else {
            None
        };
        let has_thinking = thinking_pl.is_some();
        // Anthropic takes the system prompt as a top-level field, not a
        // "system" role message.
        let messages = if is_anthropic {
            vec![Message {
                role: "user".into(),
                content: vec![ContentBlock {
                    content_type: "text".into(),
                    text: Some(user_text),
                }],
            }]
        } else {
            vec![
                Message {
                    role: "system".into(),
                    content: vec![ContentBlock {
                        content_type: "text".into(),
                        text: Some(system_text.clone()),
                    }],
                },
                Message {
                    role: "user".into(),
                    content: vec![ContentBlock {
                        content_type: "text".into(),
                        text: Some(user_text),
                    }],
                },
            ]
        };
        // When thinking is enabled, its token budget is added on top of the
        // configured completion budget.
        let max_tokens = if let Some(budget) = thinking_pl
            .as_ref()
            .and_then(|v| v.get("budget_tokens"))
            .and_then(|v| v.as_u64())
        {
            effective_cfg.max_tokens as u32 + budget as u32
        } else {
            effective_cfg.max_tokens as u32
        };
        let request_body = InferenceRequest {
            model: self.model_name.clone(),
            messages,
            // Thinking mode omits temperature.
            temperature: if has_thinking {
                None
            } else {
                Some(effective_cfg.temperature)
            },
            max_tokens,
            // JSON-object response format is OpenAI-style only, and is
            // incompatible with thinking mode.
            response_format: if is_anthropic || has_thinking {
                None
            } else if effective_cfg.request_json_object {
                Some(ResponseFormat {
                    format_type: "json_object".into(),
                })
            } else {
                None
            },
            reasoning: if is_anthropic {
                None
            } else {
                reasoning_payload(&effective_cfg)
            },
            thinking: thinking_pl,
            system: if is_anthropic {
                Some(system_text)
            } else {
                None
            },
        };
        // Concurrency permit is held for the duration of the HTTP call.
        let _permit = self.acquire_llm_permit().await;
        let mut req = self.http_client().post(&self.api_url).json(&request_body);
        if let Some(key) = &self.api_key {
            req = req.bearer_auth(key);
        }
        let start = std::time::Instant::now();
        let http_resp = req.send().await?;
        let status = http_resp.status();
        let raw_body = http_resp.text().await?;
        log::debug!(
            "remote_multimodal extract_from_html: status={} latency={:?} body_len={}",
            status,
            start.elapsed(),
            raw_body.len()
        );
        if !status.is_success() {
            return Err(EngineError::RemoteStatus(
                status.as_u16(),
                format!("non-success status {status}: {raw_body}"),
            ));
        }
        let root: serde_json::Value = serde_json::from_str(&raw_body)
            .map_err(|e| EngineError::Remote(format!("JSON parse error: {e}")))?;
        let content = extract_assistant_content(&root)
            .ok_or(EngineError::MissingField("assistant content"))?;
        let usage = extract_usage(&root);
        // Parse the model's JSON plan, optionally tolerating surrounding
        // prose via best-effort extraction.
        let plan_value = if effective_cfg.best_effort_json_extract {
            best_effort_parse_json_object(&content)?
        } else {
            serde_json::from_str::<serde_json::Value>(&content)
                .map_err(|e| EngineError::Remote(format!("JSON parse error: {e}")))?
        };
        let label = plan_value
            .get("label")
            .and_then(|v| v.as_str())
            .unwrap_or("extraction")
            .to_string();
        // With the gate on, a missing/invalid "relevant" defaults to true.
        let relevant = if effective_cfg.relevance_gate {
            Some(
                plan_value
                    .get("relevant")
                    .and_then(|v| v.as_bool())
                    .unwrap_or(true),
            )
        } else {
            None
        };
        // Prefer native thinking content; fall back to a "reasoning" field
        // in the plan (trimmed string, or JSON-stringified non-null value).
        let reasoning = extract_thinking_content(&root).or_else(|| {
            plan_value.get("reasoning").and_then(|v| {
                if let Some(s) = v.as_str() {
                    let trimmed = s.trim();
                    return if trimmed.is_empty() {
                        None
                    } else {
                        Some(trimmed.to_string())
                    };
                }
                if v.is_null() {
                    None
                } else {
                    Some(v.to_string())
                }
            })
        });
        // Recover extracted data when the model didn't use the "extracted"
        // field: either the whole object (no plan keys at all), or every
        // non-plan key collected into a new object.
        let extracted = plan_value.get("extracted").cloned().or_else(|| {
            if plan_value.get("label").is_none()
                && plan_value.get("done").is_none()
                && plan_value.get("steps").is_none()
            {
                Some(plan_value.clone())
            } else {
                let mut extracted_data = serde_json::Map::new();
                if let Some(obj) = plan_value.as_object() {
                    for (key, value) in obj {
                        if !matches!(
                            key.as_str(),
                            "label"
                                | "done"
                                | "steps"
                                | "memory_ops"
                                | "extracted"
                                | "relevant"
                                | "reasoning"
                        ) {
                            extracted_data.insert(key.clone(), value.clone());
                        }
                    }
                }
                if !extracted_data.is_empty() {
                    Some(serde_json::Value::Object(extracted_data))
                } else {
                    None
                }
            }
        });
        Ok(AutomationResult {
            label,
            steps_executed: 0,
            success: true,
            error: None,
            usage,
            extracted,
            screenshot: None,
            spawn_pages: Vec::new(),
            relevant,
            reasoning,
        })
    }
    /// Structured extraction from HTML plus an optional page screenshot.
    ///
    /// Same flow as `extract_from_html`, but the user message can carry an
    /// `image_url` content block (data URI) so the model can read visual
    /// content that is absent from the HTML. Returns a synthetic
    /// `url_not_allowed` success result when the prompt URL gate rejects
    /// `url`.
    ///
    /// # Errors
    /// `EngineError` on transport failure, non-success HTTP status, missing
    /// assistant content, or unparseable response JSON.
    pub async fn extract_with_screenshot(
        &self,
        html: &str,
        url: &str,
        title: Option<&str>,
        screenshot_base64: Option<&str>,
    ) -> EngineResult<AutomationResult> {
        // Request-wire types; unlike `extract_from_html`, content blocks may
        // carry an image URL.
        #[derive(Serialize)]
        struct ContentBlock {
            #[serde(rename = "type")]
            content_type: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            text: Option<String>,
            #[serde(skip_serializing_if = "Option::is_none")]
            image_url: Option<ImageUrlBlock>,
        }
        #[derive(Serialize)]
        struct ImageUrlBlock {
            url: String,
        }
        #[derive(Serialize)]
        struct Message {
            role: String,
            content: Vec<ContentBlock>,
        }
        #[derive(Serialize)]
        struct ResponseFormat {
            #[serde(rename = "type")]
            format_type: String,
        }
        #[derive(Serialize)]
        struct InferenceRequest {
            model: String,
            messages: Vec<Message>,
            #[serde(skip_serializing_if = "Option::is_none")]
            temperature: Option<f32>,
            max_tokens: u32,
            #[serde(skip_serializing_if = "Option::is_none")]
            response_format: Option<ResponseFormat>,
            #[serde(skip_serializing_if = "Option::is_none")]
            reasoning: Option<serde_json::Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            thinking: Option<serde_json::Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            system: Option<String>,
        }
        // Per-URL config/prompt overrides; `None` means the gate rejected
        // the URL, reported as a successful no-op result.
        let Some((
            effective_cfg,
            effective_system_prompt,
            effective_system_prompt_extra,
            effective_user_message_extra,
        )) = self.resolve_runtime_for_url(url)
        else {
            return Ok(AutomationResult {
                label: "url_not_allowed".into(),
                steps_executed: 0,
                success: true,
                error: None,
                usage: AutomationUsage::default(),
                extracted: None,
                screenshot: None,
                spawn_pages: Vec::new(),
                relevant: None,
                reasoning: None,
            });
        };
        // Clone with overridden prompts so `system_prompt_compiled` sees the
        // per-URL values.
        let mut prompt_engine = self.clone();
        prompt_engine.system_prompt = effective_system_prompt;
        prompt_engine.system_prompt_extra = effective_system_prompt_extra;
        let is_anthropic = is_anthropic_endpoint(&self.api_url);
        // Build the textual part of the user message.
        let mut user_text =
            String::with_capacity(256 + html.len().min(effective_cfg.html_max_bytes));
        user_text.push_str("EXTRACTION CONTEXT:\n");
        user_text.push_str("- url: ");
        user_text.push_str(url);
        user_text.push('\n');
        if let Some(t) = title {
            user_text.push_str("- title: ");
            user_text.push_str(t);
            user_text.push('\n');
        }
        if screenshot_base64.is_some() {
            user_text.push_str("- screenshot: provided (use for visual content not in HTML)\n");
        }
        user_text.push_str("\nHTML CONTENT:\n");
        // Truncate on a UTF-8 boundary to respect the byte budget.
        let html_truncated = truncate_utf8_tail(html, effective_cfg.html_max_bytes);
        user_text.push_str(&html_truncated);
        user_text.push_str("\n\nTASK:\nExtract structured data from the page. Use both the HTML and screenshot (if provided) to extract information. Return a JSON object with:\n");
        user_text.push_str("- \"label\": short description of what was extracted\n");
        user_text.push_str("- \"done\": true\n");
        user_text.push_str("- \"steps\": [] (empty, no browser automation)\n");
        user_text.push_str("- \"extracted\": the structured data extracted from the page\n");
        if effective_cfg.relevance_gate {
            user_text.push_str(
                "- \"relevant\": true if page is relevant to the goal, false otherwise\n",
            );
        }
        if screenshot_base64.is_some() {
            user_text.push_str("\nIMPORTANT: The screenshot may contain visual information not present in the HTML (iframe content, videos, canvas drawings, dynamically rendered content). Examine the screenshot carefully.\n");
        }
        if let Some(extra) = &effective_user_message_extra {
            if !extra.trim().is_empty() {
                user_text.push_str("\n---\nUSER INSTRUCTIONS:\n");
                user_text.push_str(extra.trim());
                user_text.push('\n');
            }
        }
        // Text block first, then (optionally) the screenshot as a data URI.
        let mut user_content = vec![ContentBlock {
            content_type: "text".into(),
            text: Some(user_text),
            image_url: None,
        }];
        if let Some(screenshot) = screenshot_base64 {
            // Accept either a ready-made data URI or raw base64 PNG bytes.
            let image_url = if screenshot.starts_with("data:") {
                screenshot.to_string()
            } else {
                format!("data:image/png;base64,{}", screenshot)
            };
            user_content.push(ContentBlock {
                content_type: "image_url".into(),
                text: None,
                image_url: Some(ImageUrlBlock { url: image_url }),
            });
        }
        let system_text = prompt_engine.system_prompt_compiled(&effective_cfg);
        // Extended-thinking payload is only sent to Anthropic endpoints.
        let thinking_pl = if is_anthropic {
            effective_thinking_payload(&effective_cfg)
        } else {
            None
        };
        let has_thinking = thinking_pl.is_some();
        // Anthropic takes the system prompt as a top-level field, not a
        // "system" role message.
        let messages = if is_anthropic {
            vec![Message {
                role: "user".into(),
                content: user_content,
            }]
        } else {
            vec![
                Message {
                    role: "system".into(),
                    content: vec![ContentBlock {
                        content_type: "text".into(),
                        text: Some(system_text.clone()),
                        image_url: None,
                    }],
                },
                Message {
                    role: "user".into(),
                    content: user_content,
                },
            ]
        };
        // Thinking budget is added on top of the completion budget.
        let max_tokens = if let Some(budget) = thinking_pl
            .as_ref()
            .and_then(|v| v.get("budget_tokens"))
            .and_then(|v| v.as_u64())
        {
            effective_cfg.max_tokens as u32 + budget as u32
        } else {
            effective_cfg.max_tokens as u32
        };
        let request_body = InferenceRequest {
            model: self.model_name.clone(),
            messages,
            // Thinking mode omits temperature.
            temperature: if has_thinking {
                None
            } else {
                Some(effective_cfg.temperature)
            },
            max_tokens,
            // JSON-object response format is OpenAI-style only and is
            // incompatible with thinking mode.
            response_format: if is_anthropic || has_thinking {
                None
            } else if effective_cfg.request_json_object {
                Some(ResponseFormat {
                    format_type: "json_object".into(),
                })
            } else {
                None
            },
            reasoning: if is_anthropic {
                None
            } else {
                reasoning_payload(&effective_cfg)
            },
            thinking: thinking_pl,
            system: if is_anthropic {
                Some(system_text)
            } else {
                None
            },
        };
        // Concurrency permit is held for the duration of the HTTP call.
        let _permit = self.acquire_llm_permit().await;
        let mut req = self.http_client().post(&self.api_url).json(&request_body);
        if let Some(key) = &self.api_key {
            req = req.bearer_auth(key);
        }
        let start = std::time::Instant::now();
        let http_resp = req.send().await?;
        let status = http_resp.status();
        let raw_body = http_resp.text().await?;
        log::debug!(
            "remote_multimodal extract_with_screenshot: status={} latency={:?} body_len={}",
            status,
            start.elapsed(),
            raw_body.len()
        );
        if !status.is_success() {
            return Err(EngineError::RemoteStatus(
                status.as_u16(),
                format!("non-success status {status}: {raw_body}"),
            ));
        }
        let root: serde_json::Value = serde_json::from_str(&raw_body)
            .map_err(|e| EngineError::Remote(format!("JSON parse error: {e}")))?;
        let content = extract_assistant_content(&root)
            .ok_or(EngineError::MissingField("assistant content"))?;
        let usage = extract_usage(&root);
        // Parse the model's JSON plan, optionally tolerating surrounding
        // prose via best-effort extraction.
        let plan_value = if effective_cfg.best_effort_json_extract {
            best_effort_parse_json_object(&content)?
        } else {
            serde_json::from_str::<serde_json::Value>(&content)
                .map_err(|e| EngineError::Remote(format!("JSON parse error: {e}")))?
        };
        let label = plan_value
            .get("label")
            .and_then(|v| v.as_str())
            .unwrap_or("extraction")
            .to_string();
        // With the gate on, a missing/invalid "relevant" defaults to true.
        let relevant = if effective_cfg.relevance_gate {
            Some(
                plan_value
                    .get("relevant")
                    .and_then(|v| v.as_bool())
                    .unwrap_or(true),
            )
        } else {
            None
        };
        // Prefer native thinking content; fall back to a "reasoning" field
        // in the plan (trimmed string, or JSON-stringified non-null value).
        let reasoning = extract_thinking_content(&root).or_else(|| {
            plan_value.get("reasoning").and_then(|v| {
                if let Some(s) = v.as_str() {
                    let trimmed = s.trim();
                    return if trimmed.is_empty() {
                        None
                    } else {
                        Some(trimmed.to_string())
                    };
                }
                if v.is_null() {
                    None
                } else {
                    Some(v.to_string())
                }
            })
        });
        // Recover extracted data when the model didn't use the "extracted"
        // field: either the whole object (no plan keys at all), or every
        // non-plan key collected into a new object.
        let extracted = plan_value.get("extracted").cloned().or_else(|| {
            if plan_value.get("label").is_none()
                && plan_value.get("done").is_none()
                && plan_value.get("steps").is_none()
            {
                Some(plan_value.clone())
            } else {
                let mut extracted_data = serde_json::Map::new();
                if let Some(obj) = plan_value.as_object() {
                    for (key, value) in obj {
                        if !matches!(
                            key.as_str(),
                            "label"
                                | "done"
                                | "steps"
                                | "memory_ops"
                                | "extracted"
                                | "relevant"
                                | "reasoning"
                        ) {
                            extracted_data.insert(key.clone(), value.clone());
                        }
                    }
                }
                if !extracted_data.is_empty() {
                    Some(serde_json::Value::Object(extracted_data))
                } else {
                    None
                }
            }
        });
        Ok(AutomationResult {
            label,
            steps_executed: 0,
            success: true,
            error: None,
            usage,
            extracted,
            screenshot: None,
            spawn_pages: Vec::new(),
            relevant,
            reasoning,
        })
    }
    /// Plain chat completion: sends `system_prompt` + `user_message` and
    /// returns the raw assistant content with token usage.
    ///
    /// No URL gating or plan parsing happens here; the engine-level config
    /// still drives temperature, token limits, JSON mode and thinking. Used
    /// for auxiliary text-only calls.
    ///
    /// # Errors
    /// `EngineError` on transport failure, non-success HTTP status, missing
    /// assistant content, or unparseable response JSON.
    pub async fn chat_completion(
        &self,
        system_prompt: &str,
        user_message: &str,
    ) -> EngineResult<(String, AutomationUsage)> {
        // Request-wire types; messages here carry plain string content
        // (no multimodal blocks).
        #[derive(Serialize)]
        struct Message {
            role: String,
            content: String,
        }
        #[derive(Serialize)]
        struct ResponseFormat {
            #[serde(rename = "type")]
            format_type: String,
        }
        #[derive(Serialize)]
        struct InferenceRequest {
            model: String,
            messages: Vec<Message>,
            #[serde(skip_serializing_if = "Option::is_none")]
            temperature: Option<f32>,
            max_tokens: u32,
            #[serde(skip_serializing_if = "Option::is_none")]
            response_format: Option<ResponseFormat>,
            #[serde(skip_serializing_if = "Option::is_none")]
            reasoning: Option<serde_json::Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            thinking: Option<serde_json::Value>,
            #[serde(skip_serializing_if = "Option::is_none")]
            system: Option<String>,
        }
        let is_anthropic = is_anthropic_endpoint(&self.api_url);
        // Extended-thinking payload is only sent to Anthropic endpoints.
        let thinking_pl = if is_anthropic {
            effective_thinking_payload(&self.cfg)
        } else {
            None
        };
        let has_thinking = thinking_pl.is_some();
        // Anthropic takes the system prompt as a top-level field, not a
        // "system" role message.
        let messages = if is_anthropic {
            vec![Message {
                role: "user".into(),
                content: user_message.to_string(),
            }]
        } else {
            vec![
                Message {
                    role: "system".into(),
                    content: system_prompt.to_string(),
                },
                Message {
                    role: "user".into(),
                    content: user_message.to_string(),
                },
            ]
        };
        // Thinking budget is added on top of the completion budget.
        let max_tokens = if let Some(budget) = thinking_pl
            .as_ref()
            .and_then(|v| v.get("budget_tokens"))
            .and_then(|v| v.as_u64())
        {
            self.cfg.max_tokens as u32 + budget as u32
        } else {
            self.cfg.max_tokens as u32
        };
        let request_body = InferenceRequest {
            model: self.model_name.clone(),
            messages,
            // Thinking mode omits temperature.
            temperature: if has_thinking {
                None
            } else {
                Some(self.cfg.temperature)
            },
            max_tokens,
            // JSON-object response format is OpenAI-style only and is
            // incompatible with thinking mode.
            response_format: if is_anthropic || has_thinking {
                None
            } else if self.cfg.request_json_object {
                Some(ResponseFormat {
                    format_type: "json_object".into(),
                })
            } else {
                None
            },
            reasoning: if is_anthropic {
                None
            } else {
                reasoning_payload(&self.cfg)
            },
            thinking: thinking_pl,
            system: if is_anthropic {
                Some(system_prompt.to_string())
            } else {
                None
            },
        };
        // Concurrency permit is held for the duration of the HTTP call.
        let _permit = self.acquire_llm_permit().await;
        let mut req = self.http_client().post(&self.api_url).json(&request_body);
        if let Some(key) = &self.api_key {
            req = req.bearer_auth(key);
        }
        let http_resp = req.send().await?;
        let status = http_resp.status();
        let raw_body = http_resp.text().await?;
        if !status.is_success() {
            return Err(EngineError::RemoteStatus(
                status.as_u16(),
                format!("non-success status {status}: {raw_body}"),
            ));
        }
        let root: serde_json::Value = serde_json::from_str(&raw_body)
            .map_err(|e| EngineError::Remote(format!("JSON parse error: {e}")))?;
        let content = extract_assistant_content(&root)
            .ok_or(EngineError::MissingField("assistant content"))?;
        let usage = extract_usage(&root);
        Ok((content, usage))
    }
/// Classify a batch of URLs as relevant/irrelevant to the crawl goal in a
/// single LLM call.
///
/// The classifier criteria come from `relevance_prompt`, falling back to
/// `extraction_prompt`, then to a generic goal. The model is instructed to
/// reply with a JSON array of 1s and 0s, one entry per URL, which is decoded
/// by [`parse_url_classifications`].
///
/// This is a best-effort prefilter: on transport errors, non-success status,
/// unparseable bodies, or missing assistant content it logs a warning (where
/// useful) and returns `vec![true; urls.len()]` ("assume all relevant")
/// instead of propagating an error.
pub async fn classify_urls(
    &self,
    urls: &[&str],
    relevance_prompt: Option<&str>,
    extraction_prompt: Option<&str>,
    max_tokens: u16,
) -> EngineResult<Vec<bool>> {
    use std::fmt::Write as _;
    if urls.is_empty() {
        return Ok(Vec::new());
    }
    let criteria = relevance_prompt
        .or(extraction_prompt)
        .unwrap_or("General web crawling");
    let system = format!(
        "You are a URL relevance classifier. Given a list of URLs, determine which are relevant to the crawl goal.\nGoal: {}\n\nRespond ONLY with a JSON array of 1s and 0s, one per URL. 1=relevant, 0=irrelevant.\nExample: [1,0,1,1,0]",
        criteria
    );
    let mut user_msg = String::with_capacity(urls.len() * 80);
    user_msg.push_str("Classify these URLs:\n");
    for (i, url) in urls.iter().enumerate() {
        // Write straight into the buffer instead of allocating a temporary
        // String per URL (clippy::format_push_string). Writing to a String
        // is infallible, so the Result is deliberately ignored.
        let _ = writeln!(user_msg, "{}. {}", i + 1, url);
    }
    // Text-only task: never route this call to the vision model.
    let (api_url, model_name, api_key) = self.resolve_model_for_round(false);
    let is_anthropic = is_anthropic_endpoint(api_url);
    // Minimal request shapes; `skip_serializing_if` keeps the JSON body
    // compatible with both OpenAI-style and Anthropic-style endpoints.
    #[derive(Serialize)]
    struct Message {
        role: String,
        content: String,
    }
    #[derive(Serialize)]
    struct InferenceRequest {
        model: String,
        messages: Vec<Message>,
        #[serde(skip_serializing_if = "Option::is_none")]
        temperature: Option<f32>,
        max_tokens: u32,
        #[serde(skip_serializing_if = "Option::is_none")]
        reasoning: Option<serde_json::Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        thinking: Option<serde_json::Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        system: Option<String>,
    }
    // Extended-thinking payload is only produced for Anthropic endpoints,
    // mirroring the main inference request builder in this file.
    let thinking_pl = if is_anthropic {
        effective_thinking_payload(&self.cfg)
    } else {
        None
    };
    let has_thinking = thinking_pl.is_some();
    // When a thinking budget is set, add it on top of the caller's cap so the
    // visible answer still has headroom after the thinking tokens.
    let classify_max_tokens = if let Some(budget) = thinking_pl
        .as_ref()
        .and_then(|v| v.get("budget_tokens"))
        .and_then(|v| v.as_u64())
    {
        max_tokens as u32 + budget as u32
    } else {
        max_tokens as u32
    };
    // Anthropic takes the system prompt as a top-level field, not a message.
    let messages = if is_anthropic {
        vec![Message {
            role: "user".into(),
            content: user_msg,
        }]
    } else {
        vec![
            Message {
                role: "system".into(),
                content: system.clone(),
            },
            Message {
                role: "user".into(),
                content: user_msg,
            },
        ]
    };
    let request_body = InferenceRequest {
        model: model_name.to_string(),
        messages,
        // Omit temperature when thinking is enabled (matches the main
        // inference path); otherwise pin 0.0 for deterministic classification.
        temperature: if has_thinking { None } else { Some(0.0) },
        max_tokens: classify_max_tokens,
        reasoning: if is_anthropic {
            None
        } else {
            reasoning_payload(&self.cfg)
        },
        thinking: thinking_pl,
        system: if is_anthropic { Some(system) } else { None },
    };
    // Bound concurrent LLM calls; the permit is held until end of scope.
    let _permit = self.acquire_llm_permit().await;
    let mut req = self.http_client().post(api_url).json(&request_body);
    if let Some(key) = api_key {
        req = req.bearer_auth(key);
    }
    // From here on, every failure degrades to "all relevant" so a flaky
    // prefilter never blocks the crawl.
    let http_resp = match req.send().await {
        Ok(r) => r,
        Err(e) => {
            log::warn!("url_prefilter: HTTP error, assuming all relevant: {e}");
            return Ok(vec![true; urls.len()]);
        }
    };
    if !http_resp.status().is_success() {
        log::warn!(
            "url_prefilter: non-success status {}, assuming all relevant",
            http_resp.status()
        );
        return Ok(vec![true; urls.len()]);
    }
    let raw_body = http_resp.text().await.unwrap_or_default();
    let root: serde_json::Value = match serde_json::from_str(&raw_body) {
        Ok(v) => v,
        Err(_) => return Ok(vec![true; urls.len()]),
    };
    let content = match extract_assistant_content(&root) {
        Some(c) => c,
        None => return Ok(vec![true; urls.len()]),
    };
    Ok(parse_url_classifications(&content, urls.len()))
}
/// Generate a schema from concrete example values.
///
/// Builds a non-strict [`super::SchemaGenerationRequest`] from `examples`,
/// with optional `name` and `description` metadata, and delegates to
/// [`super::generate_schema`].
pub fn generate_schema_from_examples(
    &self,
    examples: &[serde_json::Value],
    name: Option<&str>,
    description: Option<&str>,
) -> super::GeneratedSchema {
    let request = super::SchemaGenerationRequest {
        examples: examples.to_vec(),
        description: description.map(|s| s.to_string()),
        // Non-strict mode is hard-coded for this convenience entry point.
        strict: false,
        name: name.map(|s| s.to_string()),
    };
    super::generate_schema(&request)
}
/// Infer a JSON schema for a single example value.
///
/// Thin wrapper around [`super::infer_schema`].
pub fn infer_schema(&self, example: &serde_json::Value) -> serde_json::Value {
    super::infer_schema(example)
}
/// Build the prompt text that asks a model to generate a schema for
/// `examples` (non-strict, unnamed).
///
/// Delegates to [`super::build_schema_generation_prompt`].
pub fn build_schema_prompt(
    &self,
    examples: &[serde_json::Value],
    description: Option<&str>,
) -> String {
    let request = super::SchemaGenerationRequest {
        examples: examples.to_vec(),
        description: description.map(|s| s.to_string()),
        strict: false,
        name: None,
    };
    super::build_schema_generation_prompt(&request)
}
/// Extract tool calls from a raw model response value.
///
/// Thin wrapper around [`super::parse_tool_calls`].
pub fn parse_tool_calls(&self, response: &serde_json::Value) -> Vec<super::ToolCall> {
    super::parse_tool_calls(response)
}
/// Convert parsed tool calls into step JSON values.
///
/// Thin wrapper around [`super::tool_calls_to_steps`].
pub fn tool_calls_to_steps(&self, calls: &[super::ToolCall]) -> Vec<serde_json::Value> {
    super::tool_calls_to_steps(calls)
}
/// Return every built-in action tool definition.
///
/// Thin wrapper around [`super::ActionToolSchemas::all`].
pub fn action_tool_schemas(&self) -> Vec<super::ToolDefinition> {
    super::ActionToolSchemas::all()
}
/// Produce a condensed context string from raw HTML with a byte budget.
///
/// Thin wrapper around [`super::extract_html_context`]; the exact truncation
/// strategy is defined by that helper.
pub fn extract_html_context(&self, html: &str, max_bytes: usize) -> String {
    super::extract_html_context(html, max_bytes)
}
/// Build a [`super::DependencyGraph`] from the given steps.
///
/// Delegates to [`super::DependencyGraph::new`]; on failure the `Err` payload
/// is a human-readable description from that constructor.
pub fn create_dependency_graph(
    &self,
    steps: Vec<super::DependentStep>,
) -> Result<super::DependencyGraph, String> {
    super::DependencyGraph::new(steps)
}
/// Execute the steps of a dependency graph, driving each step through
/// `executor`.
///
/// `executor` must be cloneable and `'static` because the underlying runner
/// may spawn step futures. Delegates to [`super::execute_graph`] with the
/// provided `config`.
pub async fn execute_dependency_graph<F, Fut>(
    &self,
    graph: &mut super::DependencyGraph,
    config: &super::ConcurrentChainConfig,
    executor: F,
) -> super::ConcurrentChainResult
where
    F: Fn(super::DependentStep) -> Fut + Clone + Send + Sync + 'static,
    Fut: std::future::Future<Output = super::StepResult> + Send + 'static,
{
    super::execute_graph(graph, config, executor).await
}
}
/// Decode the classifier's reply into per-URL relevance flags.
///
/// Extracts the span from the first `[` to the last `]` (tolerating chatter
/// around the array), parses it as a JSON array, and maps integers
/// (non-zero => relevant) or booleans onto `bool`s. Any malformed reply, or
/// a length mismatch against `expected_len`, conservatively yields
/// `vec![true; expected_len]`.
fn parse_url_classifications(response: &str, expected_len: usize) -> Vec<bool> {
    let assume_all_relevant = || vec![true; expected_len];
    let trimmed = response.trim();
    let (Some(open), Some(close)) = (trimmed.find('['), trimmed.rfind(']')) else {
        return assume_all_relevant();
    };
    let Ok(entries) = serde_json::from_str::<Vec<serde_json::Value>>(&trimmed[open..=close])
    else {
        return assume_all_relevant();
    };
    if entries.len() != expected_len {
        log::warn!(
            "url_prefilter: classification length mismatch (got {}, expected {}), assuming all relevant",
            entries.len(),
            expected_len
        );
        return assume_all_relevant();
    }
    entries
        .iter()
        .map(|entry| match (entry.as_i64(), entry.as_bool()) {
            (Some(n), _) => n != 0,
            (_, Some(flag)) => flag,
            // Anything that is neither numeric nor boolean stays relevant.
            _ => true,
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_url_classifications ------------------------------------

    // Plain 1/0 integers map positionally onto bools.
    #[test]
    fn test_parse_url_classifications_valid() {
        assert_eq!(
            parse_url_classifications("[1,0,1]", 3),
            vec![true, false, true]
        );
    }

    // JSON booleans are accepted as well as integers.
    #[test]
    fn test_parse_url_classifications_booleans() {
        assert_eq!(
            parse_url_classifications("[true,false,true]", 3),
            vec![true, false, true]
        );
    }

    // Wrong array length falls back to "all relevant".
    #[test]
    fn test_parse_url_classifications_length_mismatch() {
        assert_eq!(
            parse_url_classifications("[1,0]", 3),
            vec![true, true, true]
        );
    }

    // Unparseable replies fall back to "all relevant".
    #[test]
    fn test_parse_url_classifications_invalid_json() {
        assert_eq!(parse_url_classifications("not json", 2), vec![true, true]);
    }

    // The array may be embedded in surrounding prose.
    #[test]
    fn test_parse_url_classifications_embedded_array() {
        assert_eq!(
            parse_url_classifications("Here are the results: [1,0,1,0]", 4),
            vec![true, false, true, false]
        );
    }

    // Zero URLs -> zero flags.
    #[test]
    fn test_parse_url_classifications_empty() {
        assert_eq!(parse_url_classifications("[]", 0), Vec::<bool>::new());
    }

    // ---- engine construction & prompt compilation ----------------------

    #[test]
    fn test_engine_new() {
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        );
        assert_eq!(engine.api_url, "https://api.openai.com/v1/chat/completions");
        assert_eq!(engine.model_name, "gpt-4o");
        assert!(engine.api_key.is_none());
        assert!(engine.system_prompt.is_none());
    }

    #[test]
    fn test_engine_with_api_key() {
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        )
        .with_api_key(Some("sk-test"));
        assert_eq!(engine.api_key, Some("sk-test".to_string()));
    }

    // Extra instructions are appended after the default system prompt.
    #[test]
    fn test_engine_system_prompt_compiled() {
        let mut engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        );
        engine.with_system_prompt_extra(Some("Custom instructions"));
        let compiled = engine.system_prompt_compiled(&RemoteMultimodalConfig::default());
        assert!(compiled.starts_with(super::DEFAULT_SYSTEM_PROMPT));
        assert!(compiled.contains("Custom instructions"));
        assert!(compiled.contains("RUNTIME CONFIG"));
    }

    // Extraction schema details surface in the compiled prompt.
    #[test]
    fn test_engine_system_prompt_with_extraction() {
        let cfg = RemoteMultimodalConfig {
            extra_ai_data: true,
            extraction_schema: Some(ExtractionSchema::new("products", r#"{"type":"array"}"#)),
            ..Default::default()
        };
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        );
        let compiled = engine.system_prompt_compiled(&cfg);
        assert!(compiled.contains("EXTRACTION MODE ENABLED"));
        assert!(compiled.contains("products"));
    }

    // Single-round extraction uses the extraction-only prompt, which omits
    // the action vocabulary.
    #[test]
    fn test_engine_system_prompt_extraction_only() {
        let cfg = RemoteMultimodalConfig::new()
            .with_extraction(true)
            .with_max_rounds(1);
        assert!(cfg.is_extraction_only());
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        );
        let compiled = engine.system_prompt_compiled(&cfg);
        assert!(compiled.contains("data extraction assistant"));
        assert!(!compiled.contains("ClickPoint"));
        assert!(!compiled.contains("SetViewport"));
    }

    // ---- content analysis ----------------------------------------------

    #[test]
    fn test_engine_analyze_content() {
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        );
        let html = "<html><body><p>Test content</p></body></html>";
        let analysis = engine.analyze_content(html);
        assert!(!analysis.has_visual_elements);
    }

    #[test]
    fn test_engine_needs_screenshot() {
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        );
        assert!(engine.needs_screenshot("<iframe src='x'></iframe>"));
        assert!(!engine.needs_screenshot(&"a".repeat(2000)));
    }

    // clone_with_cfg keeps credentials but swaps the config.
    #[test]
    fn test_engine_clone_with_cfg() {
        let engine = RemoteMultimodalEngine::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-4o",
            None,
        )
        .with_api_key(Some("sk-test"));
        let new_cfg = RemoteMultimodalConfig {
            max_rounds: 10,
            ..Default::default()
        };
        let cloned = engine.clone_with_cfg(new_cfg);
        assert_eq!(cloned.api_key, Some("sk-test".to_string()));
        assert_eq!(cloned.cfg.max_rounds, 10);
    }

    // ---- dual-model (vision/text) routing -------------------------------

    #[test]
    fn test_engine_dual_model_routing_setup() {
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        assert!(!engine.has_dual_model_routing());
        engine.with_vision_model(Some(crate::automation::ModelEndpoint::new("gpt-4o")));
        engine.with_text_model(Some(crate::automation::ModelEndpoint::new("gpt-4o-mini")));
        assert!(engine.has_dual_model_routing());
    }

    // Endpoints without their own URL/key inherit the parent engine's.
    #[test]
    fn test_engine_resolve_model_for_round() {
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "primary", None);
        engine.api_key = Some("sk-parent".to_string());
        engine.with_vision_model(Some(crate::automation::ModelEndpoint::new("vision-model")));
        engine.with_text_model(Some(
            crate::automation::ModelEndpoint::new("text-model")
                .with_api_url("https://text.api.com")
                .with_api_key("sk-text"),
        ));
        let (url, model, key) = engine.resolve_model_for_round(true);
        assert_eq!(model, "vision-model");
        assert_eq!(url, "https://api.example.com");
        assert_eq!(key, Some("sk-parent"));
        let (url, model, key) = engine.resolve_model_for_round(false);
        assert_eq!(model, "text-model");
        assert_eq!(url, "https://text.api.com");
        assert_eq!(key, Some("sk-text"));
    }

    // TextFirst mode: vision on round 0, on visual pages, periodically, and
    // when progress has stagnated.
    #[test]
    fn test_engine_should_use_vision_this_round() {
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.with_vision_model(Some(crate::automation::ModelEndpoint::new("gpt-4o")));
        engine.with_text_model(Some(crate::automation::ModelEndpoint::new("gpt-4o-mini")));
        engine.with_vision_route_mode(crate::automation::VisionRouteMode::TextFirst);
        assert!(engine.should_use_vision_this_round(0, false, 0, false));
        assert!(!engine.should_use_vision_this_round(1, false, 0, false));
        assert!(engine.should_use_vision_this_round(3, true, 0, false));
        assert!(engine.should_use_vision_this_round(3, false, 3, false));
        assert!(engine.should_use_vision_this_round(5, false, 0, true));
    }

    // Without dual-model routing every round uses the (vision) primary.
    #[test]
    fn test_engine_no_routing_always_vision() {
        let engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        assert!(!engine.has_dual_model_routing());
        assert!(engine.should_use_vision_this_round(0, false, 0, false));
        assert!(engine.should_use_vision_this_round(99, false, 0, false));
    }

    // Extraction-only prompt still carries schema and custom prompt text.
    #[test]
    fn test_engine_system_prompt_extraction_only_with_schema() {
        let mut cfg = RemoteMultimodalConfig::new()
            .with_extraction(true)
            .with_max_rounds(1);
        cfg.extraction_schema = Some(ExtractionSchema::new("products", r#"{"type":"array"}"#));
        cfg.extraction_prompt = Some("Extract all products".to_string());
        let engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        let compiled = engine.system_prompt_compiled(&cfg);
        assert!(compiled.contains("data extraction assistant"));
        assert!(compiled.contains("EXTRACTION MODE ENABLED"));
        assert!(compiled.contains("products"));
        assert!(compiled.contains("Extract all products"));
    }

    // Multi-round extraction keeps the default (action-capable) prompt.
    #[test]
    fn test_engine_system_prompt_multi_round_extraction_uses_default() {
        let cfg = RemoteMultimodalConfig {
            extra_ai_data: true,
            ..Default::default()
        };
        assert!(!cfg.is_extraction_only());
        let engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        let compiled = engine.system_prompt_compiled(&cfg);
        assert!(compiled.contains("ClickPoint"));
        assert!(compiled.contains("SetViewport"));
        assert!(compiled.contains("EXTRACTION MODE ENABLED"));
    }

    // ---- per-URL prompt gate --------------------------------------------

    // A gated URL overrides config and all prompt fields.
    #[test]
    fn test_engine_resolve_runtime_for_url_override() {
        let mut url_map = std::collections::HashMap::new();
        let override_cfg = crate::automation::AutomationConfig::new("override goal")
            .with_max_steps(2)
            .with_retries(5)
            .with_system_prompt("override system")
            .with_system_prompt_extra("override extra")
            .with_user_message_extra("override user")
            .with_extraction("extract fields");
        url_map.insert("https://example.com".to_string(), Box::new(override_cfg));
        let gate = crate::automation::PromptUrlGate::with_map(url_map);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.with_prompt_url_gate(Some(gate));
        let resolved = engine
            .resolve_runtime_for_url("https://example.com")
            .expect("url should be allowed");
        let (cfg, system_prompt, system_prompt_extra, user_message_extra) = resolved;
        assert_eq!(cfg.max_rounds, 2);
        assert_eq!(cfg.retry.max_attempts, 5);
        assert!(cfg.extra_ai_data);
        assert_eq!(cfg.extraction_prompt.as_deref(), Some("extract fields"));
        assert_eq!(system_prompt.as_deref(), Some("override system"));
        assert_eq!(system_prompt_extra.as_deref(), Some("override extra"));
        assert_eq!(user_message_extra.as_deref(), Some("override user"));
    }

    // URLs absent from the gate map are blocked entirely.
    #[test]
    fn test_engine_resolve_runtime_for_url_blocked() {
        let mut url_map = std::collections::HashMap::new();
        url_map.insert(
            "https://allowed.com".to_string(),
            Box::new(crate::automation::AutomationConfig::new("allowed")),
        );
        let gate = crate::automation::PromptUrlGate::with_map(url_map);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.with_prompt_url_gate(Some(gate));
        assert!(engine
            .resolve_runtime_for_url("https://blocked.com")
            .is_none());
    }

    // ---- model-pool complexity routing ----------------------------------

    // No router installed: complexity routing delegates to the primary.
    #[test]
    fn test_pool_routing_no_router_delegates() {
        let engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None)
            .with_api_key(Some("sk-test"));
        assert!(engine.model_router.is_none());
        let (url, model, key) =
            engine.resolve_model_for_round_with_complexity(true, "click button", 500, 3, false);
        assert_eq!(url, "https://api.example.com");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-test"));
    }

    #[test]
    fn test_pool_routing_picks_cheap_for_simple() {
        use crate::automation::router::auto_policy;
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "deepseek-chat"]);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.model_router = Some(crate::automation::router::ModelRouter::with_policy(
            policy.clone(),
        ));
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("gpt-4o"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
            crate::automation::ModelEndpoint::new("deepseek-chat"),
        ];
        let (_, model, _) =
            engine.resolve_model_for_round_with_complexity(false, "click button", 500, 3, false);
        assert_eq!(model, policy.small, "simple round should use cheap model");
    }

    #[test]
    fn test_pool_routing_picks_expensive_for_complex() {
        use crate::automation::router::auto_policy;
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "deepseek-chat"]);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.model_router = Some(crate::automation::router::ModelRouter::with_policy(
            policy.clone(),
        ));
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("gpt-4o"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
            crate::automation::ModelEndpoint::new("deepseek-chat"),
        ];
        // Long prompt + large page + stagnation should push complexity up.
        let long_prompt = "a]".repeat(9000);
        let (_, model, _) = engine.resolve_model_for_round_with_complexity(
            false,
            &format!("analyze and implement code to fix: {long_prompt}"),
            60_000,
            0,
            true,
        );
        assert_eq!(
            model, policy.large,
            "complex round should use powerful model"
        );
    }

    // Stagnation alone must at least rule out the cheapest model.
    #[test]
    fn test_pool_routing_stagnated_upgrades() {
        use crate::automation::router::auto_policy;
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "deepseek-chat"]);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.model_router = Some(crate::automation::router::ModelRouter::with_policy(
            policy.clone(),
        ));
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("gpt-4o"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
            crate::automation::ModelEndpoint::new("deepseek-chat"),
        ];
        let (_, model, _) =
            engine.resolve_model_for_round_with_complexity(false, "click button", 500, 5, true);
        assert_ne!(
            model, policy.small,
            "stagnated round should not use cheapest model"
        );
    }

    // Vision rounds must always land on a vision-capable model.
    #[test]
    fn test_pool_routing_vision_fallback() {
        use crate::automation::router::auto_policy;
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "deepseek-chat"]);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.model_router = Some(crate::automation::router::ModelRouter::with_policy(policy));
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("gpt-4o"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
            crate::automation::ModelEndpoint::new("deepseek-chat"),
        ];
        let (_, model, _) =
            engine.resolve_model_for_round_with_complexity(true, "click button", 500, 3, false);
        assert!(
            llm_models_spider::supports_vision(model),
            "vision round should resolve to a vision-capable model, got {model}"
        );
    }

    // Pool endpoints with their own URL/key win over the parent's.
    #[test]
    fn test_pool_routing_inherits_endpoint_keys() {
        use crate::automation::router::auto_policy;
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "deepseek-chat"]);
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.api_key = Some("sk-parent".to_string());
        engine.model_router = Some(crate::automation::router::ModelRouter::with_policy(policy));
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("gpt-4o"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
            crate::automation::ModelEndpoint::new("deepseek-chat")
                .with_api_url("https://api.deepseek.com/v1/chat/completions")
                .with_api_key("sk-ds"),
        ];
        let (url, model, key) = engine.resolve_model_for_round_with_complexity(
            false,
            "analyze complex page with code",
            60_000,
            0,
            false,
        );
        if model == "deepseek-chat" {
            assert_eq!(url, "https://api.deepseek.com/v1/chat/completions");
            assert_eq!(key, Some("sk-ds"));
        } else {
            assert_eq!(url, "https://api.example.com");
            assert_eq!(key, Some("sk-parent"));
        }
    }

    // ---- fallback model selection ---------------------------------------

    #[test]
    fn test_pick_fallback_model_skips_tried() {
        let mut engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        engine.api_key = Some("sk-test".to_string());
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("gpt-4o"),
            crate::automation::ModelEndpoint::new("claude-sonnet-4-20250514")
                .with_api_url("https://api.anthropic.com/v1/messages")
                .with_api_key("sk-ant"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
        ];
        let tried = vec!["gpt-4o".to_string()];
        let fallback = engine.pick_fallback_model(&tried, false);
        assert!(fallback.is_some());
        let (_, model, _) = fallback.unwrap();
        assert_ne!(model, "gpt-4o", "should not re-pick already-tried model");
        // Exhausting the pool yields no fallback at all.
        let tried_all = vec![
            "gpt-4o".to_string(),
            "claude-sonnet-4-20250514".to_string(),
            "gpt-4o-mini".to_string(),
        ];
        assert!(engine.pick_fallback_model(&tried_all, false).is_none());
    }

    #[test]
    fn test_pick_fallback_model_prefers_vision_when_needed() {
        let mut engine =
            RemoteMultimodalEngine::new("https://api.example.com", "deepseek-chat", None);
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("deepseek-chat"),
            crate::automation::ModelEndpoint::new("gpt-4o-mini"),
            crate::automation::ModelEndpoint::new("gpt-4o"),
        ];
        let tried = vec!["deepseek-chat".to_string()];
        let fallback = engine.pick_fallback_model(&tried, true);
        assert!(fallback.is_some());
        let (_, model, _) = fallback.unwrap();
        assert!(
            llm_models_spider::supports_vision(&model),
            "should pick vision-capable model for vision round, got {model}"
        );
    }

    #[test]
    fn test_pick_fallback_model_inherits_endpoint_config() {
        let mut engine = RemoteMultimodalEngine::new("https://api.default.com", "model-a", None);
        engine.api_key = Some("sk-default".to_string());
        engine.model_pool = vec![
            crate::automation::ModelEndpoint::new("model-a"),
            crate::automation::ModelEndpoint::new("model-b")
                .with_api_url("https://api.custom.com")
                .with_api_key("sk-custom"),
        ];
        let tried = vec!["model-a".to_string()];
        let fallback = engine.pick_fallback_model(&tried, false);
        let (url, model, key) = fallback.unwrap();
        assert_eq!(model, "model-b");
        assert_eq!(url, "https://api.custom.com");
        assert_eq!(key, Some("sk-custom".to_string()));
    }

    #[test]
    fn test_pick_fallback_model_empty_pool() {
        let engine = RemoteMultimodalEngine::new("https://api.example.com", "gpt-4o", None);
        assert!(engine.pick_fallback_model(&[], false).is_none());
    }

    // ---- error retryability ---------------------------------------------

    // 5xx and 429 are worth retrying on a different model; 4xx client errors
    // and structural errors are not.
    #[test]
    fn test_engine_error_retryable_status_codes() {
        use crate::automation::EngineError;
        assert!(
            EngineError::RemoteStatus(502, "bad gateway".into()).is_retryable_on_different_model()
        );
        assert!(
            EngineError::RemoteStatus(503, "unavailable".into()).is_retryable_on_different_model()
        );
        assert!(
            EngineError::RemoteStatus(429, "rate limit".into()).is_retryable_on_different_model()
        );
        assert!(EngineError::RemoteStatus(500, "internal".into()).is_retryable_on_different_model());
        assert!(EngineError::RemoteStatus(504, "timeout".into()).is_retryable_on_different_model());
        assert!(
            !EngineError::RemoteStatus(400, "bad request".into()).is_retryable_on_different_model()
        );
        assert!(!EngineError::RemoteStatus(401, "unauthorized".into())
            .is_retryable_on_different_model());
        assert!(
            !EngineError::RemoteStatus(403, "forbidden".into()).is_retryable_on_different_model()
        );
        assert!(
            !EngineError::RemoteStatus(404, "not found".into()).is_retryable_on_different_model()
        );
        assert!(!EngineError::MissingField("test").is_retryable_on_different_model());
        assert!(!EngineError::InvalidField("test").is_retryable_on_different_model());
    }
}