use std::sync::{Arc, OnceLock};
use super::{
CaptureProfile, ConfidenceRetryStrategy, ExtractionSchema, HtmlDiffMode, ModelEndpoint,
ModelPolicy, PlanningModeConfig, PromptUrlGate, ReasoningEffort, RemoteMultimodalConfig,
RetryPolicy, SelfHealingConfig, SynthesisConfig, ToolCallingMode, VisionRouteMode,
};
/// serde default for `chrome_ai_max_user_chars`: the cap on user-message
/// characters forwarded to Chrome's on-device AI.
fn default_chrome_ai_max_user_chars() -> usize {
    6_000
}
/// Configuration for driving a remote multimodal (vision + text) LLM endpoint,
/// with optional dual-model routing, a model pool, and a Chrome on-device AI
/// fallback.
///
/// Runtime-only coordination state (`semaphore`, `relevance_credits`,
/// `url_prefilter_cache`) is `#[serde(skip)]`-ed and excluded from the manual
/// `PartialEq` below.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct RemoteMultimodalConfigs {
    /// Chat-completions endpoint URL (e.g. an OpenAI-compatible API).
    pub api_url: String,
    /// Bearer/API key for `api_url`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Model identifier sent to the endpoint.
    pub model_name: String,
    /// Full system-prompt override.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
    /// Extra text appended to the (default or overridden) system prompt.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt_extra: Option<String>,
    /// Extra text appended to each user message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_message_extra: Option<String>,
    /// Fine-grained request/behavior settings (temperature, retries, etc.).
    pub cfg: RemoteMultimodalConfig,
    /// Optional gate deciding which URLs get prompted at all.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_url_gate: Option<PromptUrlGate>,
    /// Max concurrent in-flight requests; `None`/`Some(0)` means unlimited
    /// (see `get_or_init_semaphore`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub concurrency_limit: Option<usize>,
    /// Dedicated endpoint for vision rounds (dual-model routing).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vision_model: Option<ModelEndpoint>,
    /// Dedicated endpoint for text-only rounds (dual-model routing).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text_model: Option<ModelEndpoint>,
    /// Strategy for choosing vision vs. text model per round.
    #[serde(default)]
    pub vision_route_mode: VisionRouteMode,
    /// Optional pool of additional model endpoints.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub model_pool: Vec<ModelEndpoint>,
    /// Force use of Chrome's on-device AI (see `should_use_chrome_ai`).
    #[serde(default)]
    pub use_chrome_ai: bool,
    /// Character cap for user messages sent to Chrome AI (default 6000).
    #[serde(default = "default_chrome_ai_max_user_chars")]
    pub chrome_ai_max_user_chars: usize,
    /// In-process skill registry; never serialized.
    #[cfg(feature = "skills")]
    #[serde(skip)]
    pub skill_registry: Option<super::skills::SkillRegistry>,
    /// Optional S3 source from which skills are loaded.
    #[cfg(feature = "skills_s3")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub s3_skill_source: Option<super::skills::S3SkillSource>,
    /// Lazily-created concurrency semaphore; runtime-only.
    #[serde(skip, default = "RemoteMultimodalConfigs::default_semaphore")]
    pub semaphore: OnceLock<Arc<tokio::sync::Semaphore>>,
    /// Shared relevance-gate credit counter; runtime-only.
    #[serde(skip)]
    pub relevance_credits: Arc<std::sync::atomic::AtomicU32>,
    /// Cache of URL-prefilter verdicts keyed by URL; runtime-only.
    #[serde(skip)]
    pub url_prefilter_cache: Arc<dashmap::DashMap<String, bool>>,
    /// Optional proxy URLs for outbound requests.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub proxies: Option<Vec<String>>,
}
// Manual `PartialEq`: compares only the configuration-bearing fields.
// Runtime-only state (`semaphore`, `relevance_credits`, `url_prefilter_cache`)
// and the feature-gated `skill_registry` are deliberately excluded so that two
// logically identical configs compare equal regardless of runtime activity.
// NOTE(review): `s3_skill_source` (feature "skills_s3") is serialized but not
// compared here — confirm that omission is intentional.
impl PartialEq for RemoteMultimodalConfigs {
    fn eq(&self, other: &Self) -> bool {
        self.api_url == other.api_url
            && self.api_key == other.api_key
            && self.model_name == other.model_name
            && self.system_prompt == other.system_prompt
            && self.system_prompt_extra == other.system_prompt_extra
            && self.user_message_extra == other.user_message_extra
            && self.cfg == other.cfg
            && self.prompt_url_gate == other.prompt_url_gate
            && self.concurrency_limit == other.concurrency_limit
            && self.vision_model == other.vision_model
            && self.text_model == other.text_model
            && self.vision_route_mode == other.vision_route_mode
            && self.model_pool == other.model_pool
            && self.use_chrome_ai == other.use_chrome_ai
            && self.chrome_ai_max_user_chars == other.chrome_ai_max_user_chars
            && self.proxies == other.proxies
    }
}
// Marker impl: relies on the compared fields providing reflexive equality.
// NOTE(review): if `cfg` contains floats (e.g. temperature), `Eq` assumes NaN
// never appears in practice — verify against `RemoteMultimodalConfig`.
impl Eq for RemoteMultimodalConfigs {}
impl Default for RemoteMultimodalConfigs {
    /// Empty endpoint/model with library defaults. With the "skills" feature
    /// enabled, the built-in web-challenge skill registry is pre-loaded so
    /// fresh configs can handle common challenges out of the box.
    fn default() -> Self {
        Self {
            api_url: String::new(),
            api_key: None,
            model_name: String::new(),
            system_prompt: None,
            system_prompt_extra: None,
            user_message_extra: None,
            cfg: RemoteMultimodalConfig::default(),
            prompt_url_gate: None,
            concurrency_limit: None,
            vision_model: None,
            text_model: None,
            vision_route_mode: VisionRouteMode::default(),
            model_pool: Vec::new(),
            use_chrome_ai: false,
            chrome_ai_max_user_chars: default_chrome_ai_max_user_chars(),
            // Auto-load built-in skills (see tests below asserting this).
            #[cfg(feature = "skills")]
            skill_registry: Some(super::skills::builtin_web_challenges()),
            #[cfg(feature = "skills_s3")]
            s3_skill_source: None,
            // Uninitialized until `get_or_init_semaphore` is first called.
            semaphore: Self::default_semaphore(),
            relevance_credits: Arc::new(std::sync::atomic::AtomicU32::new(0)),
            url_prefilter_cache: Arc::new(dashmap::DashMap::new()),
            proxies: None,
        }
    }
}
impl RemoteMultimodalConfigs {
    /// Creates a config targeting `api_url` with `model_name`; every other
    /// field takes its `Default` value (including built-in skills when the
    /// "skills" feature is enabled).
    pub fn new(api_url: impl Into<String>, model_name: impl Into<String>) -> Self {
        Self {
            api_url: api_url.into(),
            model_name: model_name.into(),
            ..Default::default()
        }
    }

    /// serde default for the lazily-initialized semaphore slot.
    fn default_semaphore() -> OnceLock<Arc<tokio::sync::Semaphore>> {
        OnceLock::new()
    }

    /// Returns the shared concurrency semaphore, creating it on first use.
    ///
    /// Returns `None` when `concurrency_limit` is unset or zero (unlimited).
    /// NOTE: the permit count is frozen at the limit seen on the first call —
    /// later changes to `concurrency_limit` do not resize an already
    /// initialized semaphore (`OnceLock` semantics).
    pub fn get_or_init_semaphore(&self) -> Option<Arc<tokio::sync::Semaphore>> {
        let n = self.concurrency_limit?;
        if n == 0 {
            return None;
        }
        Some(
            self.semaphore
                .get_or_init(|| Arc::new(tokio::sync::Semaphore::new(n)))
                .clone(),
        )
    }

    /// Sets the API key (builder).
    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Replaces the full system prompt (builder).
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Appends extra text to the system prompt (builder).
    pub fn with_system_prompt_extra(mut self, extra: impl Into<String>) -> Self {
        self.system_prompt_extra = Some(extra.into());
        self
    }

    /// Appends extra text to each user message (builder).
    pub fn with_user_message_extra(mut self, extra: impl Into<String>) -> Self {
        self.user_message_extra = Some(extra.into());
        self
    }

    /// Replaces the nested request/behavior config wholesale (builder).
    pub fn with_cfg(mut self, cfg: RemoteMultimodalConfig) -> Self {
        self.cfg = cfg;
        self
    }

    /// Sets the prompt URL gate (builder).
    pub fn with_prompt_url_gate(mut self, gate: PromptUrlGate) -> Self {
        self.prompt_url_gate = Some(gate);
        self
    }

    /// Sets the max concurrent in-flight requests (builder); 0 = unlimited.
    pub fn with_concurrency_limit(mut self, limit: usize) -> Self {
        self.concurrency_limit = Some(limit);
        self
    }

    /// Sets (or clears, with `None`) the outbound proxy list (builder).
    pub fn with_proxies(mut self, proxies: Option<Vec<String>>) -> Self {
        self.proxies = proxies;
        self
    }

    /// Toggles extra AI metadata in responses (builder).
    pub fn with_extra_ai_data(mut self, enabled: bool) -> Self {
        self.cfg.extra_ai_data = enabled;
        self
    }

    /// Sets the structured-extraction prompt (builder).
    pub fn with_extraction_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.cfg.extraction_prompt = Some(prompt.into());
        self
    }

    /// Toggles screenshot capture (builder).
    pub fn with_screenshot(mut self, enabled: bool) -> Self {
        self.cfg.screenshot = enabled;
        self
    }

    /// Sets the extraction schema (builder).
    pub fn with_extraction_schema(mut self, schema: ExtractionSchema) -> Self {
        self.cfg.extraction_schema = Some(schema);
        self
    }

    /// Whether `model_name` is known to accept image input.
    pub fn model_supports_vision(&self) -> bool {
        super::supports_vision(&self.model_name)
    }

    /// Whether a screenshot should be attached to the request: an explicit
    /// `cfg.include_screenshot` override wins; otherwise fall back to the
    /// model's known vision capability.
    pub fn should_include_screenshot(&self) -> bool {
        match self.cfg.include_screenshot {
            Some(explicit) => explicit,
            None => self.model_supports_vision(),
        }
    }

    /// Passes `screenshot` through only when it should be included.
    pub fn filter_screenshot<'a>(&self, screenshot: Option<&'a str>) -> Option<&'a str> {
        if self.should_include_screenshot() {
            screenshot
        } else {
            None
        }
    }

    /// Sets the dedicated vision-round endpoint (builder).
    pub fn with_vision_model(mut self, endpoint: ModelEndpoint) -> Self {
        self.vision_model = Some(endpoint);
        self
    }

    /// Sets the dedicated text-round endpoint (builder).
    pub fn with_text_model(mut self, endpoint: ModelEndpoint) -> Self {
        self.text_model = Some(endpoint);
        self
    }

    /// Sets the per-round routing strategy (builder).
    pub fn with_vision_route_mode(mut self, mode: VisionRouteMode) -> Self {
        self.vision_route_mode = mode;
        self
    }

    /// Sets both routed endpoints at once (builder).
    pub fn with_dual_models(mut self, vision: ModelEndpoint, text: ModelEndpoint) -> Self {
        self.vision_model = Some(vision);
        self.text_model = Some(text);
        self
    }

    /// Replaces the model pool (builder).
    pub fn with_model_pool(mut self, pool: Vec<ModelEndpoint>) -> Self {
        self.model_pool = pool;
        self
    }

    /// Sets the S3 skill source (builder; feature "skills_s3").
    #[cfg(feature = "skills_s3")]
    pub fn with_s3_skill_source(mut self, source: super::skills::S3SkillSource) -> Self {
        self.s3_skill_source = Some(source);
        self
    }

    /// Enables the relevance gate with an optional custom prompt (builder).
    pub fn with_relevance_gate(mut self, prompt: Option<String>) -> Self {
        self.cfg.relevance_gate = true;
        self.cfg.relevance_prompt = prompt;
        self
    }

    /// Enables URL prefiltering, optionally overriding batch size (builder).
    pub fn with_url_prefilter(mut self, batch_size: Option<usize>) -> Self {
        self.cfg.url_prefilter = true;
        if let Some(bs) = batch_size {
            self.cfg.url_prefilter_batch_size = bs;
        }
        self
    }

    /// Forces Chrome on-device AI on/off (builder).
    pub fn with_chrome_ai(mut self, enabled: bool) -> Self {
        self.use_chrome_ai = enabled;
        self
    }

    /// Sets the Chrome AI user-message character cap (builder).
    pub fn with_chrome_ai_max_user_chars(mut self, chars: usize) -> Self {
        self.chrome_ai_max_user_chars = chars;
        self
    }

    /// True when Chrome's on-device AI should be used: either explicitly
    /// enabled, or no remote endpoint is configured at all (empty URL and no
    /// API key).
    pub fn should_use_chrome_ai(&self) -> bool {
        self.use_chrome_ai || (self.api_url.is_empty() && self.api_key.is_none())
    }

    /// Sets the automation timeout in milliseconds (builder).
    pub fn with_automation_timeout_ms(mut self, ms: u64) -> Self {
        self.cfg.automation_timeout_ms = Some(ms);
        self
    }

    /// Replaces the endpoint URL (builder).
    pub fn with_api_url(mut self, url: impl Into<String>) -> Self {
        self.api_url = url.into();
        self
    }

    /// Replaces the primary model name (builder).
    pub fn with_model_name(mut self, name: impl Into<String>) -> Self {
        self.model_name = name.into();
        self
    }

    /// Replaces the skill registry (builder; feature "skills").
    #[cfg(feature = "skills")]
    pub fn with_skill_registry(mut self, registry: super::skills::SkillRegistry) -> Self {
        self.skill_registry = Some(registry);
        self
    }

    /// Toggles inclusion of page HTML in prompts (builder).
    pub fn with_include_html(mut self, include: bool) -> Self {
        self.cfg.include_html = include;
        self
    }

    /// Caps the HTML payload size in bytes (builder).
    pub fn with_html_max_bytes(mut self, bytes: usize) -> Self {
        self.cfg.html_max_bytes = bytes;
        self
    }

    /// Toggles inclusion of the page URL in prompts (builder).
    pub fn with_include_url(mut self, include: bool) -> Self {
        self.cfg.include_url = include;
        self
    }

    /// Toggles inclusion of the page title in prompts (builder).
    pub fn with_include_title(mut self, include: bool) -> Self {
        self.cfg.include_title = include;
        self
    }

    /// Explicitly overrides screenshot inclusion; `None` defers to the
    /// model's vision capability (builder).
    pub fn with_include_screenshot(mut self, include: Option<bool>) -> Self {
        self.cfg.include_screenshot = include;
        self
    }

    /// Sets the sampling temperature (builder).
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.cfg.temperature = temp;
        self
    }

    /// Sets the completion token cap (builder).
    pub fn with_max_tokens(mut self, tokens: u16) -> Self {
        self.cfg.max_tokens = tokens;
        self
    }

    /// Toggles requesting JSON-object response format (builder).
    pub fn with_request_json_object(mut self, enabled: bool) -> Self {
        self.cfg.request_json_object = enabled;
        self
    }

    /// Toggles best-effort JSON extraction from free-form output (builder).
    pub fn with_best_effort_json_extract(mut self, enabled: bool) -> Self {
        self.cfg.best_effort_json_extract = enabled;
        self
    }

    /// Sets the reasoning effort hint (builder).
    pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self {
        self.cfg.reasoning_effort = Some(effort);
        self
    }

    /// Sets the thinking token budget (builder).
    pub fn with_thinking_budget(mut self, budget: u32) -> Self {
        self.cfg.thinking_budget = Some(budget);
        self
    }

    /// Sets the max agent rounds (builder).
    pub fn with_max_rounds(mut self, rounds: usize) -> Self {
        self.cfg.max_rounds = rounds;
        self
    }

    /// Replaces the retry policy (builder).
    pub fn with_retry(mut self, retry: RetryPolicy) -> Self {
        self.cfg.retry = retry;
        self
    }

    /// Appends a capture profile (builder; additive, not replacing).
    pub fn with_capture_profile(mut self, profile: CaptureProfile) -> Self {
        self.cfg.capture_profiles.push(profile);
        self
    }

    /// Replaces the model policy (builder).
    pub fn with_model_policy(mut self, policy: ModelPolicy) -> Self {
        self.cfg.model_policy = policy;
        self
    }

    /// Sets the post-plan wait in milliseconds (builder).
    pub fn with_post_plan_wait_ms(mut self, ms: u64) -> Self {
        self.cfg.post_plan_wait_ms = ms;
        self
    }

    /// Sets the max in-flight request count (builder).
    pub fn with_max_inflight_requests(mut self, max: usize) -> Self {
        self.cfg.max_inflight_requests = Some(max);
        self
    }

    /// Sets the tool-calling mode (builder).
    pub fn with_tool_calling_mode(mut self, mode: ToolCallingMode) -> Self {
        self.cfg.tool_calling_mode = mode;
        self
    }

    /// Sets the HTML diff mode (builder).
    pub fn with_html_diff_mode(mut self, mode: HtmlDiffMode) -> Self {
        self.cfg.html_diff_mode = mode;
        self
    }

    /// Enables planning mode with the given config (builder).
    pub fn with_planning_mode(mut self, config: PlanningModeConfig) -> Self {
        self.cfg.planning_mode = Some(config);
        self
    }

    /// Enables answer synthesis with the given config (builder).
    pub fn with_synthesis_config(mut self, config: SynthesisConfig) -> Self {
        self.cfg.synthesis_config = Some(config);
        self
    }

    /// Sets the confidence-based retry strategy (builder).
    pub fn with_confidence_strategy(mut self, strategy: ConfidenceRetryStrategy) -> Self {
        self.cfg.confidence_strategy = Some(strategy);
        self
    }

    /// Enables self-healing with the given config (builder).
    pub fn with_self_healing(mut self, config: SelfHealingConfig) -> Self {
        self.cfg.self_healing = Some(config);
        self
    }

    /// Toggles concurrent skill execution (builder).
    pub fn with_concurrent_execution(mut self, enabled: bool) -> Self {
        self.cfg.concurrent_execution = enabled;
        self
    }

    /// Sets the max skills executed per round (builder).
    pub fn with_max_skills_per_round(mut self, max: usize) -> Self {
        self.cfg.max_skills_per_round = max;
        self
    }

    /// Sets the max characters of skill context per request (builder).
    pub fn with_max_skill_context_chars(mut self, max: usize) -> Self {
        self.cfg.max_skill_context_chars = max;
        self
    }

    /// The automation timeout as a `Duration`, if configured.
    pub fn automation_timeout(&self) -> Option<std::time::Duration> {
        self.cfg
            .automation_timeout_ms
            .map(std::time::Duration::from_millis)
    }

    /// True when at least one routed endpoint (vision or text) is configured.
    pub fn has_dual_model_routing(&self) -> bool {
        self.vision_model.is_some() || self.text_model.is_some()
    }

    /// Resolves `(api_url, model_name, api_key)` for a round.
    ///
    /// When the routed endpoint omits a URL or key, the top-level `api_url` /
    /// `api_key` are used as fallbacks; with no routed endpoint at all, the
    /// top-level triple is returned unchanged.
    pub fn resolve_model_for_round(&self, use_vision: bool) -> (&str, &str, Option<&str>) {
        let endpoint = if use_vision {
            self.vision_model.as_ref()
        } else {
            self.text_model.as_ref()
        };
        match endpoint {
            Some(ep) => {
                let url = ep.api_url.as_deref().unwrap_or(&self.api_url);
                let key = ep.api_key.as_deref().or(self.api_key.as_deref());
                (url, &ep.model_name, key)
            }
            None => (&self.api_url, &self.model_name, self.api_key.as_deref()),
        }
    }

    /// Decides whether the vision endpoint should serve this round.
    ///
    /// Without dual-model routing there is nothing to choose, so the primary
    /// model is always used (`true`). `force_vision` overrides every mode.
    /// In `TextFirst`/`VisionFirst`, stagnation or 3+ stuck rounds escalate
    /// the round to the vision model.
    pub fn should_use_vision_this_round(
        &self,
        round_idx: usize,
        stagnated: bool,
        action_stuck_rounds: usize,
        force_vision: bool,
    ) -> bool {
        if !self.has_dual_model_routing() {
            return true;
        }
        if force_vision {
            return true;
        }
        match self.vision_route_mode {
            VisionRouteMode::AlwaysPrimary => true,
            VisionRouteMode::TextFirst => round_idx == 0 || stagnated || action_stuck_rounds >= 3,
            VisionRouteMode::VisionFirst => round_idx < 2 || stagnated || action_stuck_rounds >= 3,
            VisionRouteMode::AgentDriven => false,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RemoteMultimodalConfigs`: construction, builders,
    //! vision detection, dual-model routing/resolution, serde round-trips,
    //! model pool, and built-in skill auto-loading.
    use super::*;

    #[test]
    fn test_remote_multimodal_configs_new() {
        let configs = RemoteMultimodalConfigs::new(
            "http://localhost:11434/v1/chat/completions",
            "qwen2.5-vl",
        );
        assert_eq!(
            configs.api_url,
            "http://localhost:11434/v1/chat/completions"
        );
        assert_eq!(configs.model_name, "qwen2.5-vl");
        assert!(configs.api_key.is_none());
        assert!(configs.system_prompt.is_none());
    }

    #[test]
    fn test_remote_multimodal_configs_builder() {
        let configs =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o")
                .with_api_key("sk-test")
                .with_system_prompt("You are a helpful assistant.")
                .with_concurrency_limit(5)
                .with_screenshot(true);
        assert_eq!(configs.api_key, Some("sk-test".to_string()));
        assert_eq!(
            configs.system_prompt,
            Some("You are a helpful assistant.".to_string())
        );
        assert_eq!(configs.concurrency_limit, Some(5));
        assert!(configs.cfg.screenshot);
    }

    #[test]
    fn test_remote_multimodal_configs_vision_detection() {
        // Vision-capable model: screenshot included by default.
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        assert!(cfg.model_supports_vision());
        assert!(cfg.should_include_screenshot());
        // Text-only model: screenshot excluded by default.
        let cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-3.5-turbo",
        );
        assert!(!cfg.model_supports_vision());
        assert!(!cfg.should_include_screenshot());
        // Explicit override wins over capability, in both directions.
        let mut cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-3.5-turbo",
        );
        cfg.cfg.include_screenshot = Some(true);
        assert!(cfg.should_include_screenshot());
        let mut cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        cfg.cfg.include_screenshot = Some(false);
        assert!(!cfg.should_include_screenshot());
    }

    #[test]
    fn test_filter_screenshot() {
        let screenshot = "base64data...";
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        assert_eq!(cfg.filter_screenshot(Some(screenshot)), Some(screenshot));
        let cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            "gpt-3.5-turbo",
        );
        assert_eq!(cfg.filter_screenshot(Some(screenshot)), None);
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o");
        assert_eq!(cfg.filter_screenshot(None), None);
    }

    #[test]
    fn test_has_dual_model_routing() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o");
        assert!(!cfg.has_dual_model_routing());
        // Either routed endpoint alone enables dual routing.
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_vision_model(ModelEndpoint::new("gpt-4o"));
        assert!(cfg.has_dual_model_routing());
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_text_model(ModelEndpoint::new("gpt-4o-mini"));
        assert!(cfg.has_dual_model_routing());
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            );
        assert!(cfg.has_dual_model_routing());
    }

    #[test]
    fn test_resolve_model_for_round_no_routing() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-parent");
        // Without routing, both modes resolve to the top-level triple.
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(url, "https://api.example.com");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-parent"));
        let (url, model, key) = cfg.resolve_model_for_round(false);
        assert_eq!(url, "https://api.example.com");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-parent"));
    }

    #[test]
    fn test_resolve_model_for_round_dual() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-parent")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            );
        // Routed endpoints without their own URL/key inherit the parent's.
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(model, "gpt-4o");
        assert_eq!(url, "https://api.example.com");
        assert_eq!(key, Some("sk-parent"));
        let (url, model, key) = cfg.resolve_model_for_round(false);
        assert_eq!(model, "gpt-4o-mini");
        assert_eq!(url, "https://api.example.com");
        assert_eq!(key, Some("sk-parent"));
    }

    #[test]
    fn test_resolve_model_cross_provider() {
        let cfg =
            RemoteMultimodalConfigs::new("https://api.openai.com/v1/chat/completions", "gpt-4o")
                .with_api_key("sk-openai")
                .with_vision_model(ModelEndpoint::new("gpt-4o"))
                .with_text_model(
                    ModelEndpoint::new("llama-3.3-70b-versatile")
                        .with_api_url("https://api.groq.com/openai/v1/chat/completions")
                        .with_api_key("gsk-groq"),
                );
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(url, "https://api.openai.com/v1/chat/completions");
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-openai"));
        // Endpoint-local URL/key take precedence over the parent's.
        let (url, model, key) = cfg.resolve_model_for_round(false);
        assert_eq!(url, "https://api.groq.com/openai/v1/chat/completions");
        assert_eq!(model, "llama-3.3-70b-versatile");
        assert_eq!(key, Some("gsk-groq"));
    }

    #[test]
    fn test_vision_route_mode_always_primary() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            )
            .with_vision_route_mode(VisionRouteMode::AlwaysPrimary);
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(cfg.should_use_vision_this_round(5, false, 0, false));
        assert!(cfg.should_use_vision_this_round(10, false, 0, false));
    }

    #[test]
    fn test_vision_route_mode_text_first() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);
        // Round 0 uses vision; subsequent calm rounds use text.
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(!cfg.should_use_vision_this_round(1, false, 0, false));
        assert!(!cfg.should_use_vision_this_round(5, false, 0, false));
        // Stagnation, stuck actions, or force escalate to vision.
        assert!(cfg.should_use_vision_this_round(3, true, 0, false));
        assert!(cfg.should_use_vision_this_round(5, false, 3, false));
        assert!(cfg.should_use_vision_this_round(5, false, 0, true));
    }

    #[test]
    fn test_vision_route_mode_vision_first() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
            )
            .with_vision_route_mode(VisionRouteMode::VisionFirst);
        // First two rounds use vision; later calm rounds use text.
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(cfg.should_use_vision_this_round(1, false, 0, false));
        assert!(!cfg.should_use_vision_this_round(2, false, 0, false));
        assert!(!cfg.should_use_vision_this_round(5, false, 0, false));
        assert!(cfg.should_use_vision_this_round(5, true, 0, false));
        assert!(cfg.should_use_vision_this_round(5, false, 3, false));
    }

    #[test]
    fn test_no_dual_routing_always_returns_true() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o");
        assert!(!cfg.has_dual_model_routing());
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        assert!(cfg.should_use_vision_this_round(5, false, 0, false));
        assert!(cfg.should_use_vision_this_round(99, false, 0, false));
    }

    #[test]
    fn test_with_dual_models_builder() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "primary")
            .with_dual_models(
                ModelEndpoint::new("vision-model"),
                ModelEndpoint::new("text-model"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);
        assert!(cfg.has_dual_model_routing());
        assert_eq!(
            cfg.vision_model.as_ref().unwrap().model_name,
            "vision-model"
        );
        assert_eq!(cfg.text_model.as_ref().unwrap().model_name, "text-model");
        assert_eq!(cfg.vision_route_mode, VisionRouteMode::TextFirst);
    }

    #[test]
    fn test_configs_serde_with_dual_models() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-test")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini")
                    .with_api_url("https://other.api.com")
                    .with_api_key("sk-other"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.model_name, "gpt-4o");
        assert!(deserialized.has_dual_model_routing());
        assert_eq!(
            deserialized.vision_model.as_ref().unwrap().model_name,
            "gpt-4o"
        );
        assert_eq!(
            deserialized.text_model.as_ref().unwrap().model_name,
            "gpt-4o-mini"
        );
        assert_eq!(
            deserialized.text_model.as_ref().unwrap().api_url.as_deref(),
            Some("https://other.api.com")
        );
        assert_eq!(deserialized.vision_route_mode, VisionRouteMode::TextFirst);
    }

    #[cfg(feature = "skills")]
    #[test]
    fn test_default_configs_auto_load_builtin_skills() {
        let cfg = RemoteMultimodalConfigs::default();
        let registry = cfg
            .skill_registry
            .as_ref()
            .expect("default config should auto-load built-in skills");
        assert!(
            registry.get("image-grid-selection").is_some(),
            "expected image-grid-selection built-in skill"
        );
        assert!(
            registry.get("tic-tac-toe").is_some(),
            "expected tic-tac-toe built-in skill"
        );
    }

    #[cfg(feature = "skills")]
    #[test]
    fn test_new_configs_auto_load_builtin_skills() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "model");
        let registry = cfg
            .skill_registry
            .as_ref()
            .expect("new config should auto-load built-in skills");
        assert!(
            registry.get("word-search").is_some(),
            "expected word-search built-in skill"
        );
    }

    #[test]
    fn test_selector_to_dual_model_config() {
        use super::super::router::{ModelRequirements, ModelSelector, SelectionStrategy};
        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let vision_reqs = ModelRequirements::default().with_vision();
        let vision_pick = selector
            .select(&vision_reqs)
            .expect("should find a vision model");
        selector.set_strategy(SelectionStrategy::CheapestFirst);
        let text_reqs = ModelRequirements::default();
        let text_pick = selector
            .select(&text_reqs)
            .expect("should find a text model");
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", &vision_pick.name)
            .with_dual_models(
                ModelEndpoint::new(&vision_pick.name),
                ModelEndpoint::new(&text_pick.name),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);
        let (_, model, _) = cfg.resolve_model_for_round(true);
        assert_eq!(
            model, vision_pick.name,
            "vision round should use vision pick"
        );
        let (_, model, _) = cfg.resolve_model_for_round(false);
        assert_eq!(model, text_pick.name, "text round should use text pick");
    }

    #[test]
    fn test_auto_policy_to_configs_round_trip() {
        use super::super::router::auto_policy;
        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", &policy.large)
            .with_dual_models(
                ModelEndpoint::new(&policy.large),
                ModelEndpoint::new(&policy.small),
            );
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();
        let (_, vision_model, _) = deserialized.resolve_model_for_round(true);
        let (_, text_model, _) = deserialized.resolve_model_for_round(false);
        assert_eq!(vision_model, policy.large);
        assert_eq!(text_model, policy.small);
    }

    #[test]
    fn test_vision_routing_with_real_capabilities() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-3.5-turbo"),
            )
            .with_vision_route_mode(VisionRouteMode::TextFirst);
        assert!(cfg.should_use_vision_this_round(0, false, 0, false));
        let (_, model, _) = cfg.resolve_model_for_round(true);
        assert_eq!(model, "gpt-4o");
        assert!(
            llm_models_spider::supports_vision(model),
            "vision-round model should support vision"
        );
        assert!(!cfg.should_use_vision_this_round(3, false, 0, false));
        let (_, model, _) = cfg.resolve_model_for_round(false);
        assert_eq!(model, "gpt-3.5-turbo");
        assert!(
            !llm_models_spider::supports_vision(model),
            "text-round model should NOT support vision"
        );
    }

    #[test]
    fn test_single_model_config_e2e() {
        use super::super::router::auto_policy;
        let policy = auto_policy(&["gpt-4o"]);
        let cfg = RemoteMultimodalConfigs::new(
            "https://api.openai.com/v1/chat/completions",
            &policy.large,
        )
        .with_api_key("sk-test");
        assert!(!cfg.has_dual_model_routing());
        let (url, model, key) = cfg.resolve_model_for_round(true);
        assert_eq!(model, "gpt-4o");
        assert_eq!(key, Some("sk-test"));
        let (url2, model2, key2) = cfg.resolve_model_for_round(false);
        assert_eq!(url, url2, "single model: same URL for both modes");
        assert_eq!(model, model2, "single model: same model for both modes");
        assert_eq!(key, key2, "single model: same key for both modes");
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();
        assert!(!deserialized.has_dual_model_routing());
        let (_, m, _) = deserialized.resolve_model_for_round(true);
        assert_eq!(m, "gpt-4o");
    }

    #[test]
    fn test_model_resolution_consistency() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_key("sk-test")
            .with_dual_models(
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini")
                    .with_api_url("https://other.api.com")
                    .with_api_key("sk-other"),
            );
        // Resolution must be pure: identical answers on repeated calls.
        for _ in 0..100 {
            let (url, model, key) = cfg.resolve_model_for_round(true);
            assert_eq!(url, "https://api.example.com");
            assert_eq!(model, "gpt-4o");
            assert_eq!(key, Some("sk-test"));
            let (url, model, key) = cfg.resolve_model_for_round(false);
            assert_eq!(url, "https://other.api.com");
            assert_eq!(model, "gpt-4o-mini");
            assert_eq!(key, Some("sk-other"));
        }
    }

    #[test]
    fn test_model_pool_default_empty() {
        let cfg = RemoteMultimodalConfigs::default();
        assert!(cfg.model_pool.is_empty());
    }

    #[test]
    fn test_model_pool_builder() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
                ModelEndpoint::new("deepseek-chat")
                    .with_api_url("https://api.deepseek.com/v1/chat/completions")
                    .with_api_key("sk-ds"),
            ]);
        assert_eq!(cfg.model_pool.len(), 3);
        assert_eq!(cfg.model_pool[2].model_name, "deepseek-chat");
        assert_eq!(
            cfg.model_pool[2].api_url.as_deref(),
            Some("https://api.deepseek.com/v1/chat/completions")
        );
    }

    #[test]
    fn test_model_pool_serde_round_trip() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![
                ModelEndpoint::new("gpt-4o"),
                ModelEndpoint::new("gpt-4o-mini"),
                ModelEndpoint::new("deepseek-chat"),
            ]);
        let json = serde_json::to_string(&cfg).unwrap();
        assert!(json.contains("model_pool"));
        let deserialized: RemoteMultimodalConfigs = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.model_pool.len(), 3);
        assert_eq!(deserialized.model_pool[0].model_name, "gpt-4o");
        assert_eq!(deserialized.model_pool[1].model_name, "gpt-4o-mini");
        assert_eq!(deserialized.model_pool[2].model_name, "deepseek-chat");
    }

    #[test]
    fn test_model_pool_empty_omitted_from_json() {
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o");
        let json = serde_json::to_string(&cfg).unwrap();
        assert!(
            !json.contains("model_pool"),
            "empty model_pool should be omitted from JSON"
        );
    }

    #[test]
    fn test_cfg_convenience_builders() {
        use super::super::{ReasoningEffort, ToolCallingMode};
        let cfg = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_api_url("https://other.api.com")
            .with_model_name("gpt-4o-mini")
            .with_include_html(true)
            .with_html_max_bytes(10_000)
            .with_include_url(true)
            .with_include_title(true)
            .with_include_screenshot(Some(false))
            .with_temperature(0.5)
            .with_max_tokens(2048)
            .with_request_json_object(true)
            .with_best_effort_json_extract(true)
            .with_reasoning_effort(ReasoningEffort::High)
            .with_thinking_budget(4096)
            .with_max_rounds(10)
            .with_post_plan_wait_ms(500)
            .with_max_inflight_requests(8)
            .with_tool_calling_mode(ToolCallingMode::Auto)
            .with_concurrent_execution(true)
            .with_max_skills_per_round(5)
            .with_max_skill_context_chars(8000);
        assert_eq!(cfg.api_url, "https://other.api.com");
        assert_eq!(cfg.model_name, "gpt-4o-mini");
        assert!(cfg.cfg.include_html);
        assert_eq!(cfg.cfg.html_max_bytes, 10_000);
        assert!(cfg.cfg.include_url);
        assert!(cfg.cfg.include_title);
        assert_eq!(cfg.cfg.include_screenshot, Some(false));
        assert!((cfg.cfg.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(cfg.cfg.max_tokens, 2048);
        assert!(cfg.cfg.request_json_object);
        assert!(cfg.cfg.best_effort_json_extract);
        assert_eq!(cfg.cfg.reasoning_effort, Some(ReasoningEffort::High));
        assert_eq!(cfg.cfg.thinking_budget, Some(4096));
        assert_eq!(cfg.cfg.max_rounds, 10);
        assert_eq!(cfg.cfg.post_plan_wait_ms, 500);
        assert_eq!(cfg.cfg.max_inflight_requests, Some(8));
        assert_eq!(cfg.cfg.tool_calling_mode, ToolCallingMode::Auto);
        assert!(cfg.cfg.concurrent_execution);
        assert_eq!(cfg.cfg.max_skills_per_round, 5);
        assert_eq!(cfg.cfg.max_skill_context_chars, 8000);
    }

    #[test]
    fn test_model_pool_equality() {
        let a = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![ModelEndpoint::new("gpt-4o")]);
        let b = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![ModelEndpoint::new("gpt-4o")]);
        let c = RemoteMultimodalConfigs::new("https://api.example.com", "gpt-4o")
            .with_model_pool(vec![ModelEndpoint::new("gpt-4o-mini")]);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }
}