use std::net::ToSocketAddrs;
use std::path::Path;
use roboticus_core::config::RoboticusConfig;
/// Returns `true` when a non-local provider is configured with a loopback URL,
/// i.e. it presumably depends on a locally running proxy being reachable.
///
/// Returns `false` when:
/// * `cfg.is_local` is `Some(true)`;
/// * `name` contains `"ollama"` (ASCII case-insensitive);
/// * `cfg.url` (trimmed) does not parse as a URL;
/// * the host is anything other than `127.0.0.1`, `localhost` or `::1`.
pub(crate) fn provider_requires_internal_proxy(
    name: &str,
    cfg: &roboticus_core::config::ProviderConfig,
) -> bool {
    if cfg.is_local.unwrap_or(false) || name.to_ascii_lowercase().contains("ollama") {
        return false;
    }
    let Ok(parsed) = reqwest::Url::parse(cfg.url.trim()) else {
        return false;
    };
    let host_lower = parsed.host_str().unwrap_or_default().to_ascii_lowercase();
    // `host_str()` returns IPv6 literals wrapped in brackets (`[::1]`); strip them so
    // the `::1` arm below can actually match.
    let host = host_lower
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host_lower.as_str());
    matches!(host, "127.0.0.1" | "localhost" | "::1")
}
/// Probes `host:port` with a real TCP connect. Address resolution and the probe
/// timeout are delegated to `tcp_endpoint_reachable_with`.
///
/// Any successfully opened stream is dropped (closed) immediately.
fn tcp_endpoint_reachable(host: &str, port: u16) -> bool {
    let real_connect = |addr: std::net::SocketAddr, limit: std::time::Duration| {
        std::net::TcpStream::connect_timeout(&addr, limit).map(drop)
    };
    tcp_endpoint_reachable_with(host, port, real_connect)
}
/// Resolves `"{host}:{port}"` and hands the first resolved socket address, plus a
/// 400 ms timeout, to `connect`.
///
/// Returns `false` if resolution fails or yields no address; otherwise returns
/// whether `connect` returned `Ok`. Only the first resolved address is tried.
/// `connect` is injectable so tests can avoid real network I/O.
pub(crate) fn tcp_endpoint_reachable_with<F>(host: &str, port: u16, connect: F) -> bool
where
    F: FnOnce(std::net::SocketAddr, std::time::Duration) -> std::io::Result<()>,
{
    const PROBE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(400);
    let target = format!("{host}:{port}");
    let first_addr = target
        .to_socket_addrs()
        .ok()
        .and_then(|mut candidates| candidates.next());
    match first_addr {
        Some(addr) => connect(addr, PROBE_TIMEOUT).is_ok(),
        None => false,
    }
}
/// One provider whose configured base URL was (or is to be) rewritten.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct ProviderUrlMigration {
    /// Provider key, i.e. the `<name>` in `[providers.<name>]`.
    pub provider: String,
    /// URL found in the configuration before the rewrite.
    pub from_url: String,
    /// URL that replaces it.
    pub to_url: String,
}
/// Policy for legacy loopback-proxy provider URLs such as
/// `http://127.0.0.1:8788/<provider>`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LegacyLoopbackMode {
    /// Chosen for versions before 0.8; validation accepts legacy URLs.
    MigrateDeprecated,
    /// Chosen for 0.8 and later; validation rejects legacy URLs.
    Unsupported,
}
/// Maps a dotted version string to a [`LegacyLoopbackMode`].
///
/// Only the first two components (major, minor) are considered. A missing or
/// non-numeric component counts as `0`, so malformed input falls into the
/// permissive pre-0.8 mode. Versions `>= 0.8` are `Unsupported`.
pub(crate) fn legacy_loopback_mode_for_version(version: &str) -> LegacyLoopbackMode {
    let mut components = version
        .split('.')
        .map(|part| part.parse::<u64>().unwrap_or(0));
    let major = components.next().unwrap_or(0);
    let minor = components.next().unwrap_or(0);
    if (major, minor) >= (0, 8) {
        LegacyLoopbackMode::Unsupported
    } else {
        LegacyLoopbackMode::MigrateDeprecated
    }
}
/// Chooses the legacy-loopback policy for the version of this crate as built
/// (`CARGO_PKG_VERSION`), via `legacy_loopback_mode_for_version`.
pub(crate) fn legacy_loopback_mode() -> LegacyLoopbackMode {
    const BUILD_VERSION: &str = env!("CARGO_PKG_VERSION");
    legacy_loopback_mode_for_version(BUILD_VERSION)
}
/// Returns the direct API base URL for a known provider name (ASCII
/// case-insensitive), or `None` for providers without a canonical base.
pub(crate) fn canonical_provider_base_url(provider_name: &str) -> Option<&'static str> {
    const BASES: [(&str, &str); 5] = [
        ("anthropic", "https://api.anthropic.com"),
        ("google", "https://generativelanguage.googleapis.com"),
        ("openai", "https://api.openai.com"),
        ("openrouter", "https://openrouter.ai/api"),
        ("moonshot", "https://api.moonshot.ai"),
    ];
    BASES
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(provider_name))
        .map(|&(_, base)| base)
}
/// Returns `true` if `url` has the shape of a legacy loopback-proxy URL for
/// `provider_name`, e.g. `http://127.0.0.1:8788/anthropic`.
///
/// All three conditions must hold:
/// * the host is `127.0.0.1`, `localhost` or `::1`;
/// * the port (explicit, or the scheme's known default) is 8788;
/// * the first path segment equals `provider_name` (ASCII case-insensitive).
///
/// URLs that do not parse yield `false`.
pub(crate) fn parse_legacy_proxy_url(provider_name: &str, url: &str) -> bool {
    const LEGACY_PROXY_PORT: u16 = 8788;
    let Ok(parsed) = reqwest::Url::parse(url.trim()) else {
        return false;
    };
    let host_lower = parsed.host_str().unwrap_or_default().to_ascii_lowercase();
    // `host_str()` wraps IPv6 literals in brackets (`[::1]`); compare the bare
    // address so IPv6 loopback URLs are detected too.
    let host = host_lower
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host_lower.as_str());
    if !matches!(host, "127.0.0.1" | "localhost" | "::1") {
        return false;
    }
    if parsed.port_or_known_default().unwrap_or(80) != LEGACY_PROXY_PORT {
        return false;
    }
    parsed
        .path_segments()
        .and_then(|mut segments| segments.next())
        .is_some_and(|first| first.eq_ignore_ascii_case(provider_name))
}
pub(crate) fn rewrite_provider_urls_in_toml(
original: &str,
migrations: &[ProviderUrlMigration],
) -> (String, bool) {
let mut migration_map = std::collections::HashMap::<String, String>::new();
for m in migrations {
migration_map.insert(m.provider.to_ascii_lowercase(), m.to_url.clone());
}
if migration_map.is_empty() {
return (original.to_string(), false);
}
let mut current_provider: Option<String> = None;
let mut changed = false;
let mut out = Vec::<String>::new();
for line in original.lines() {
let trimmed = line.trim();
if trimmed.starts_with('[') && trimmed.ends_with(']') {
let section = &trimmed[1..trimmed.len() - 1];
if let Some(rest) = section.strip_prefix("providers.") {
if !rest.contains('.') && !rest.is_empty() {
current_provider = Some(rest.to_ascii_lowercase());
} else {
current_provider = None;
}
} else {
current_provider = None;
}
out.push(line.to_string());
continue;
}
if let Some(provider) = current_provider.as_deref()
&& trimmed.starts_with("url")
&& trimmed.contains('=')
&& let Some(new_url) = migration_map.get(provider)
{
let indent: String = line.chars().take_while(|c| c.is_whitespace()).collect();
out.push(format!("{indent}url = \"{new_url}\""));
changed = true;
continue;
}
out.push(line.to_string());
}
let mut rewritten = out.join("\n");
if original.ends_with('\n') {
rewritten.push('\n');
}
(rewritten, changed)
}
/// Applies `migrations` to the TOML file at `config_path`, replacing it in place.
///
/// Does nothing when `migrations` is empty, the file does not exist, or
/// `rewrite_provider_urls_in_toml` reports no change. Otherwise it:
/// 1. copies the original to `config_path.with_extension("toml.bak")`, but only
///    if that backup does not exist yet, so the earliest copy is kept;
/// 2. writes the new text to `config_path.with_extension("toml.tmp")` with the
///    original file's permissions, syncs it to disk, and renames it over
///    `config_path`.
///
/// NOTE(review): if `config_path` is a symlink, the rename replaces the link itself
/// with a regular file.
///
/// # Errors
/// Propagates I/O errors from reading, backing up, writing or renaming. If the
/// write or rename fails, the temporary file is removed on a best-effort basis.
fn persist_provider_url_migrations(
    config_path: &Path,
    migrations: &[ProviderUrlMigration],
) -> Result<(), Box<dyn std::error::Error>> {
    /// Writes `contents` to `tmp` with `permissions`, syncs it, then renames it to `dest`.
    fn replace_via_temp(
        tmp: &Path,
        dest: &Path,
        contents: &str,
        permissions: std::fs::Permissions,
    ) -> std::io::Result<()> {
        use std::io::Write;
        let mut file = std::fs::File::create(tmp)?;
        // Apply the original mode before any content is written, so a config that
        // holds secrets is never readable under the default umask.
        file.set_permissions(permissions)?;
        file.write_all(contents.as_bytes())?;
        file.sync_all()?;
        drop(file);
        std::fs::rename(tmp, dest)
    }

    if migrations.is_empty() || !config_path.exists() {
        return Ok(());
    }
    let original = std::fs::read_to_string(config_path)?;
    let (rewritten, changed) = rewrite_provider_urls_in_toml(&original, migrations);
    if !changed {
        return Ok(());
    }
    let backup = config_path.with_extension("toml.bak");
    if !backup.exists() {
        std::fs::copy(config_path, &backup)?;
    }
    let permissions = std::fs::metadata(config_path)?.permissions();
    let tmp = config_path.with_extension("toml.tmp");
    if let Err(err) = replace_via_temp(&tmp, config_path, &rewritten, permissions) {
        // Best effort: don't leave a half-written temp file next to the config.
        let _ = std::fs::remove_file(&tmp);
        return Err(err.into());
    }
    Ok(())
}
/// Rewrites legacy loopback-proxy URLs in `config` to each provider's canonical base.
///
/// A provider is migrated when all of the following hold:
/// * it is not flagged `is_local`;
/// * its URL satisfies `parse_legacy_proxy_url`;
/// * `canonical_provider_base_url` knows its name;
/// * its URL does not already equal that base.
///
/// The in-memory `config` is updated first. If `config_path` is given, the same
/// rewrites are then persisted via `persist_provider_url_migrations`.
///
/// Returns the list of applied migrations.
///
/// # Errors
/// Returns any error from persisting the file. `config` has already been mutated
/// by that point.
pub(crate) fn migrate_legacy_proxy_urls(
    config: &mut RoboticusConfig,
    config_path: Option<&Path>,
) -> Result<Vec<ProviderUrlMigration>, Box<dyn std::error::Error>> {
    let mut applied = Vec::new();
    for (name, provider) in &mut config.providers {
        let eligible =
            !provider.is_local.unwrap_or(false) && parse_legacy_proxy_url(name, &provider.url);
        if !eligible {
            continue;
        }
        let canonical = match canonical_provider_base_url(name) {
            Some(base) => base,
            None => continue,
        };
        if provider.url.trim().eq_ignore_ascii_case(canonical) {
            continue;
        }
        let from_url = std::mem::replace(&mut provider.url, canonical.to_string());
        applied.push(ProviderUrlMigration {
            provider: name.clone(),
            from_url,
            to_url: canonical.to_string(),
        });
    }
    if let Some(path) = config_path {
        persist_provider_url_migrations(path, &applied)?;
    }
    Ok(applied)
}
/// Lists every non-local provider whose URL is a legacy loopback-proxy URL,
/// formatted as `providers.<name>.url=<url>`.
pub(crate) fn collect_legacy_loopback_providers(config: &RoboticusConfig) -> Vec<String> {
    config
        .providers
        .iter()
        .filter(|(name, provider)| {
            !provider.is_local.unwrap_or(false) && parse_legacy_proxy_url(name, &provider.url)
        })
        .map(|(name, provider)| format!("providers.{name}.url={}", provider.url))
        .collect()
}
/// Enforces the legacy-loopback policy for `mode`.
///
/// `MigrateDeprecated` always succeeds. `Unsupported` fails with a message listing
/// every offending `providers.<name>.url=<url>` entry.
pub(crate) fn validate_legacy_loopback_urls_for_mode(
    config: &RoboticusConfig,
    mode: LegacyLoopbackMode,
) -> Result<(), String> {
    match mode {
        LegacyLoopbackMode::MigrateDeprecated => Ok(()),
        LegacyLoopbackMode::Unsupported => {
            let offenders = collect_legacy_loopback_providers(config);
            if offenders.is_empty() {
                return Ok(());
            }
            Err(format!(
                "unsupported legacy provider URLs detected (replace with direct provider bases): {}",
                offenders.join(", ")
            ))
        }
    }
}
/// Probes every provider that `provider_requires_internal_proxy` flags, and reports
/// each one whose endpoint cannot be connected to.
///
/// Providers whose URL fails to parse or has no host are skipped. A missing port
/// falls back to the scheme default, or 80. Each report entry is formatted as
/// `<name> (<host>:<port>)`.
pub(crate) fn check_internal_proxy_reachability(config: &RoboticusConfig) -> Vec<String> {
    let mut unreachable = Vec::new();
    for (name, provider) in &config.providers {
        if !provider_requires_internal_proxy(name, provider) {
            continue;
        }
        let parsed = match reqwest::Url::parse(provider.url.trim()) {
            Ok(url) => url,
            Err(_) => continue,
        };
        let host = match parsed.host_str() {
            Some(h) => h,
            None => continue,
        };
        let port = parsed.port_or_known_default().unwrap_or(80);
        if !tcp_endpoint_reachable(host, port) {
            unreachable.push(format!("{name} ({host}:{port})"));
        }
    }
    unreachable
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a config made of a fixed minimal preamble (`[agent]`, `[server]`,
    /// `[database]`, `[models]`) followed by `providers_block`; panics if it fails.
    fn minimal_cfg_with_providers(providers_block: &str) -> RoboticusConfig {
        let cfg = format!(
            r#"
[agent]
name = "T"
id = "t"
[server]
bind = "127.0.0.1"
port = 18789
[database]
path = ":memory:"
[models]
primary = "moonshot/kimi-k2-turbo-preview"
{providers}
"#,
            providers = providers_block
        );
        RoboticusConfig::from_str(&cfg).expect("config parses")
    }

    // A non-local provider pointing at the loopback host needs the internal proxy.
    #[test]
    fn provider_requires_internal_proxy_true_for_non_local_loopback() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
"#,
        );
        let p = cfg.providers.get("anthropic").unwrap();
        assert!(provider_requires_internal_proxy("anthropic", p));
    }

    // NOTE(review): both exemptions apply here (name contains "ollama" AND
    // is_local = true), so this does not show that either one alone is sufficient.
    #[test]
    fn provider_requires_internal_proxy_false_for_ollama_and_local() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.ollama]
url = "http://127.0.0.1:11434"
tier = "T1"
is_local = true
"#,
        );
        let p = cfg.providers.get("ollama").unwrap();
        assert!(!provider_requires_internal_proxy("ollama", p));
    }

    // Despite the name, no real listener is opened. The injected connector records
    // the resolved address and timeout, then reports success.
    #[test]
    fn tcp_endpoint_reachable_detects_open_listener() {
        let mut observed = None;
        let reachable = tcp_endpoint_reachable_with("127.0.0.1", 8788, |sock, timeout| {
            observed = Some((sock, timeout));
            Ok(())
        });
        assert!(reachable);
        let (sock, timeout) = observed.expect("connect callback should receive resolved socket");
        assert_eq!(sock.ip().to_string(), "127.0.0.1");
        assert_eq!(sock.port(), 8788);
        assert_eq!(timeout, std::time::Duration::from_millis(400));
    }

    // Real network probe. 192.0.2.1 is a documentation-only address, so the connect
    // is expected to fail or hit the 400 ms timeout.
    // NOTE(review): depends on the test host's network behaviour.
    #[test]
    fn tcp_endpoint_reachable_detects_closed_port() {
        assert!(!tcp_endpoint_reachable("192.0.2.1", 1));
    }

    // Legacy shape: loopback host, port 8788, first path segment = provider name.
    #[test]
    fn parse_legacy_proxy_url_requires_loopback_8788_and_provider_prefix() {
        assert!(parse_legacy_proxy_url(
            "anthropic",
            "http://127.0.0.1:8788/anthropic"
        ));
        assert!(!parse_legacy_proxy_url(
            "anthropic",
            "http://127.0.0.1:8789/anthropic"
        ));
        assert!(!parse_legacy_proxy_url(
            "anthropic",
            "https://api.anthropic.com"
        ));
    }

    // Only the provider named in `migrations` has its `url` line replaced.
    #[test]
    fn rewrite_provider_urls_in_toml_updates_only_targeted_provider_blocks() {
        let source = r#"[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
[providers.google]
url = "http://127.0.0.1:8788/google"
tier = "T2"
"#;
        let migrations = vec![ProviderUrlMigration {
            provider: "anthropic".into(),
            from_url: "http://127.0.0.1:8788/anthropic".into(),
            to_url: "https://api.anthropic.com".into(),
        }];
        let (rewritten, changed) = rewrite_provider_urls_in_toml(source, &migrations);
        assert!(changed);
        assert!(rewritten.contains("url = \"https://api.anthropic.com\""));
        assert!(rewritten.contains("url = \"http://127.0.0.1:8788/google\""));
    }

    // End-to-end: migration updates the in-memory config and the file on disk
    // (written to a temp dir).
    #[test]
    fn migrate_legacy_proxy_urls_rewrites_config_and_persists_file() {
        let cfg = r#"
[agent]
name = "T"
id = "t"
[server]
bind = "127.0.0.1"
port = 18789
[database]
path = ":memory:"
[models]
primary = "anthropic/x"
[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
"#;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("roboticus.toml");
        std::fs::write(&path, cfg).unwrap();
        let mut parsed = RoboticusConfig::from_str(cfg).unwrap();
        let migrations = migrate_legacy_proxy_urls(&mut parsed, Some(&path)).unwrap();
        assert_eq!(migrations.len(), 1);
        assert_eq!(
            parsed.providers.get("anthropic").unwrap().url,
            "https://api.anthropic.com"
        );
        let persisted = std::fs::read_to_string(&path).unwrap();
        assert!(persisted.contains("url = \"https://api.anthropic.com\""));
    }

    // The policy flips from MigrateDeprecated to Unsupported at minor version 8.
    #[test]
    fn legacy_loopback_mode_for_version_changes_at_0_8() {
        assert_eq!(
            legacy_loopback_mode_for_version("0.7.1"),
            LegacyLoopbackMode::MigrateDeprecated
        );
        assert_eq!(
            legacy_loopback_mode_for_version("0.8.0"),
            LegacyLoopbackMode::Unsupported
        );
    }

    // Only the loopback-proxy provider is reported; the direct Google URL is not.
    #[test]
    fn collect_legacy_loopback_providers_finds_legacy_urls() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
[providers.google]
url = "https://generativelanguage.googleapis.com"
tier = "T2"
"#,
        );
        let legacy = collect_legacy_loopback_providers(&cfg);
        assert_eq!(legacy.len(), 1);
        assert!(legacy[0].contains("providers.anthropic.url"));
    }

    // In Unsupported mode the error names the offending config key.
    #[test]
    fn validate_legacy_loopback_urls_for_mode_rejects_in_0_8_mode() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
"#,
        );
        let err = validate_legacy_loopback_urls_for_mode(&cfg, LegacyLoopbackMode::Unsupported)
            .expect_err("v0.8 mode must reject legacy loopback");
        assert!(err.contains("providers.anthropic.url"));
    }

    #[test]
    fn canonical_provider_base_url_is_case_insensitive() {
        assert_eq!(
            canonical_provider_base_url("AnThRoPiC"),
            Some("https://api.anthropic.com")
        );
        assert_eq!(
            canonical_provider_base_url("GOOGLE"),
            Some("https://generativelanguage.googleapis.com")
        );
        assert_eq!(canonical_provider_base_url("unknown"), None);
    }

    // The first path segment must match the provider being checked.
    #[test]
    fn parse_legacy_proxy_url_rejects_wrong_path_prefix() {
        assert!(!parse_legacy_proxy_url(
            "anthropic",
            "http://127.0.0.1:8788/google"
        ));
    }

    // With no migrations, the input is returned untouched and unchanged.
    #[test]
    fn rewrite_provider_urls_in_toml_noop_without_migrations() {
        let source = r#"[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
"#;
        let (rewritten, changed) = rewrite_provider_urls_in_toml(source, &[]);
        assert!(!changed);
        assert_eq!(rewritten, source);
    }

    // Persisting writes the new URL and leaves `<name>.toml.bak` beside the config.
    #[test]
    fn persist_provider_url_migrations_writes_backup_and_new_url() {
        let dir = tempfile::tempdir().unwrap();
        let config_path = dir.path().join("roboticus.toml");
        std::fs::write(
            &config_path,
            r#"[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
"#,
        )
        .unwrap();
        let migrations = vec![ProviderUrlMigration {
            provider: "anthropic".into(),
            from_url: "http://127.0.0.1:8788/anthropic".into(),
            to_url: "https://api.anthropic.com".into(),
        }];
        persist_provider_url_migrations(&config_path, &migrations).unwrap();
        let updated = std::fs::read_to_string(&config_path).unwrap();
        assert!(updated.contains("https://api.anthropic.com"));
        assert!(config_path.with_extension("toml.bak").exists());
    }

    // A direct (non-loopback) provider is never probed, so nothing is reported.
    #[test]
    fn check_internal_proxy_reachability_skips_non_loopback_providers() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "https://api.anthropic.com"
tier = "T3"
"#,
        );
        let unreachable = check_internal_proxy_reachability(&cfg);
        assert!(unreachable.is_empty());
    }

    // "not-a-url" is expected not to parse as an absolute URL, which takes the
    // early `false` path.
    #[test]
    fn provider_requires_internal_proxy_false_for_invalid_url() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "not-a-url"
tier = "T3"
"#,
        );
        let p = cfg.providers.get("anthropic").unwrap();
        assert!(!provider_requires_internal_proxy("anthropic", p));
    }

    // Port and path match, but a private non-loopback host must not count as legacy.
    #[test]
    fn parse_legacy_proxy_url_rejects_non_loopback_hosts() {
        assert!(!parse_legacy_proxy_url(
            "anthropic",
            "http://10.0.0.1:8788/anthropic"
        ));
    }

    // A missing config file is a silent no-op and must not be created.
    #[test]
    fn persist_provider_url_migrations_is_noop_for_missing_config() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("missing.toml");
        let migrations = vec![ProviderUrlMigration {
            provider: "anthropic".into(),
            from_url: "http://127.0.0.1:8788/anthropic".into(),
            to_url: "https://api.anthropic.com".into(),
        }];
        persist_provider_url_migrations(&missing, &migrations).unwrap();
        assert!(!missing.exists());
    }

    // `is_local = true` exempts a provider even when its URL has the legacy shape.
    #[test]
    fn collect_legacy_loopback_providers_ignores_local_flagged_provider() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.ollama]
url = "http://127.0.0.1:8788/ollama"
tier = "T1"
is_local = true
"#,
        );
        let legacy = collect_legacy_loopback_providers(&cfg);
        assert!(legacy.is_empty());
    }

    // MigrateDeprecated mode accepts legacy URLs.
    #[test]
    fn validate_legacy_loopback_urls_for_mode_allows_pre_0_8_mode() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
"#,
        );
        let result =
            validate_legacy_loopback_urls_for_mode(&cfg, LegacyLoopbackMode::MigrateDeprecated);
        assert!(result.is_ok());
    }

    // Real connect to 127.0.0.1:9.
    // NOTE(review): assumes nothing is listening on that port on the test host.
    #[test]
    fn check_internal_proxy_reachability_reports_unreachable_loopback_proxy() {
        let cfg = minimal_cfg_with_providers(
            r#"
[providers.anthropic]
url = "http://127.0.0.1:9/anthropic"
tier = "T3"
"#,
        );
        let unreachable = check_internal_proxy_reachability(&cfg);
        assert!(!unreachable.is_empty(), "should report unreachable proxy");
        assert!(unreachable[0].contains("anthropic"));
    }

    // Already-canonical URLs produce no migrations; no path is given, so nothing
    // is persisted.
    #[test]
    fn migrate_legacy_proxy_urls_noop_for_already_canonical_urls() {
        let cfg_text = r#"
[agent]
name = "T"
id = "t"
[server]
bind = "127.0.0.1"
port = 18789
[database]
path = ":memory:"
[models]
primary = "anthropic/x"
[providers.anthropic]
url = "https://api.anthropic.com"
tier = "T3"
"#;
        let mut cfg = RoboticusConfig::from_str(cfg_text).unwrap();
        let migrations = migrate_legacy_proxy_urls(&mut cfg, None).unwrap();
        assert!(migrations.is_empty());
        assert_eq!(
            cfg.providers.get("anthropic").unwrap().url,
            "https://api.anthropic.com"
        );
    }

    // A nested `[providers.<name>.extra]` table must not have its `url` rewritten.
    #[test]
    fn rewrite_provider_urls_in_toml_only_changes_top_level_provider_section_url_field() {
        let source = r#"[providers.anthropic]
url = "http://127.0.0.1:8788/anthropic"
tier = "T3"
[providers.anthropic.extra]
url = "http://127.0.0.1:8788/should-not-change"
"#;
        let migrations = vec![ProviderUrlMigration {
            provider: "anthropic".into(),
            from_url: "http://127.0.0.1:8788/anthropic".into(),
            to_url: "https://api.anthropic.com".into(),
        }];
        let (rewritten, changed) = rewrite_provider_urls_in_toml(source, &migrations);
        assert!(changed);
        assert!(rewritten.contains("url = \"https://api.anthropic.com\""));
        assert!(rewritten.contains("url = \"http://127.0.0.1:8788/should-not-change\""));
    }

    // NOTE(review): pinned to the crate's CARGO_PKG_VERSION being >= 0.8; this test
    // fails if the package version is ever below that.
    #[test]
    fn legacy_loopback_mode_matches_current_package_version_rule() {
        let mode = legacy_loopback_mode();
        assert_eq!(mode, LegacyLoopbackMode::Unsupported);
    }

    #[test]
    fn canonical_provider_base_url_covers_known_providers() {
        assert_eq!(
            canonical_provider_base_url("openrouter"),
            Some("https://openrouter.ai/api")
        );
        assert_eq!(
            canonical_provider_base_url("moonshot"),
            Some("https://api.moonshot.ai")
        );
        assert_eq!(canonical_provider_base_url("unknown-provider"), None);
    }
}