use crate::config::models::defaults::default_true;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
/// Backend family hosting an agent.
///
/// Every variant carries an explicit serde name. `rename_all = "snake_case"` alone splits on
/// each uppercase letter, which produced `"a2_a"`, `"lang_graph"`, `"vertex_a_i"`, etc. Those
/// names matched none of the spellings accepted by the `FromStr` impl, so configs written with
/// e.g. `"vertex_ai"` failed to load. The previously derived spellings remain accepted as
/// aliases so already-serialized configs keep deserializing. The container-level `rename_all`
/// is kept only for variants added later without an explicit `rename`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum AgentProvider {
    /// Generic A2A-protocol agent; used when no provider is configured.
    #[default]
    #[serde(rename = "a2a", alias = "a2_a", alias = "generic")]
    A2A,
    #[serde(rename = "langgraph", alias = "lang_graph", alias = "langchain")]
    LangGraph,
    #[serde(
        rename = "vertex_ai",
        alias = "vertex_a_i",
        alias = "vertex",
        alias = "vertexai",
        alias = "google"
    )]
    VertexAI,
    #[serde(
        rename = "azure_ai_foundry",
        alias = "azure_a_i_foundry",
        alias = "azure",
        alias = "azureai"
    )]
    AzureAIFoundry,
    #[serde(
        rename = "bedrock_agentcore",
        alias = "bedrock_agent_core",
        alias = "bedrock",
        alias = "aws"
    )]
    BedrockAgentCore,
    #[serde(
        rename = "pydantic_ai",
        alias = "pydantic_a_i",
        alias = "pydantic",
        alias = "pydanticai"
    )]
    PydanticAI,
    #[serde(rename = "custom")]
    Custom,
}
impl AgentProvider {
pub fn display_name(&self) -> &'static str {
match self {
AgentProvider::A2A => "A2A",
AgentProvider::LangGraph => "LangGraph",
AgentProvider::VertexAI => "Vertex AI Agent Engine",
AgentProvider::AzureAIFoundry => "Azure AI Foundry",
AgentProvider::BedrockAgentCore => "Bedrock AgentCore",
AgentProvider::PydanticAI => "Pydantic AI",
AgentProvider::Custom => "Custom",
}
}
pub fn supports_streaming(&self) -> bool {
matches!(
self,
AgentProvider::LangGraph
| AgentProvider::VertexAI
| AgentProvider::AzureAIFoundry
| AgentProvider::A2A
)
}
pub fn supports_async_tasks(&self) -> bool {
matches!(
self,
AgentProvider::LangGraph | AgentProvider::BedrockAgentCore | AgentProvider::A2A
)
}
}
impl std::fmt::Display for AgentProvider {
    /// Writes the provider's human-readable name (see `display_name`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.display_name())
    }
}
impl std::str::FromStr for AgentProvider {
    type Err = String;

    /// Parses a provider name case-insensitively, ignoring surrounding whitespace.
    ///
    /// Accepts the canonical snake_case identifiers, common short aliases, the older
    /// serde-derived spellings (`lang_graph`, `vertex_a_i`, ...), and the human-readable names
    /// produced by `Display`, so `provider.to_string().parse()` round-trips for every variant.
    ///
    /// # Errors
    ///
    /// Returns `"Unknown agent provider: <input>"` (with the input as given) for anything else.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.trim().to_lowercase().as_str() {
            "a2a" | "a2_a" | "generic" => Ok(AgentProvider::A2A),
            "langgraph" | "lang_graph" | "langchain" => Ok(AgentProvider::LangGraph),
            "vertex" | "vertexai" | "vertex_ai" | "vertex_a_i" | "google"
            | "vertex ai agent engine" => Ok(AgentProvider::VertexAI),
            "azure" | "azureai" | "azure_ai_foundry" | "azure_a_i_foundry"
            | "azure ai foundry" => Ok(AgentProvider::AzureAIFoundry),
            "bedrock" | "aws" | "bedrock_agentcore" | "bedrock_agent_core"
            | "bedrock agentcore" => Ok(AgentProvider::BedrockAgentCore),
            "pydantic" | "pydanticai" | "pydantic_ai" | "pydantic_a_i" | "pydantic ai" => {
                Ok(AgentProvider::PydanticAI)
            }
            "custom" => Ok(AgentProvider::Custom),
            _ => Err(format!("Unknown agent provider: {}", s)),
        }
    }
}
/// Configuration for a single downstream agent reachable over HTTP(S).
///
/// `Debug` is implemented by hand instead of derived so that `{:?}` (e.g. in log lines) never
/// prints the API key, header values, or `provider_config` values, any of which may carry
/// credentials. Only the presence of a key and the names of headers/provider settings are shown.
#[derive(Clone, Serialize, Deserialize)]
pub struct AgentConfig {
    /// Agent name; `A2AGatewayConfig::add_agent` stores the agent under this key.
    pub name: String,
    /// Backend provider; `AgentProvider::default()` when omitted.
    #[serde(default)]
    pub provider: AgentProvider,
    /// Endpoint URL; `validate` requires an http(s) URL whose host is not private/reserved.
    pub url: String,
    /// Optional API key (secret; redacted from `Debug` output).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Additional headers configured for this agent, name -> value (values redacted from `Debug`).
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub headers: HashMap<String, String>,
    /// Timeout in milliseconds; `default_timeout()` (60000) when omitted.
    #[serde(default = "default_timeout")]
    pub timeout_ms: u64,
    /// Whether the agent is enabled; `true` when omitted.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Optional human-readable description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Capability flags; `AgentCapabilities::default()` when omitted.
    #[serde(default)]
    pub capabilities: AgentCapabilities,
    /// Optional rate limit (the `_rpm` suffix suggests requests per minute).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate_limit_rpm: Option<u32>,
    /// Optional cost per request (unit not specified in this module).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_per_request: Option<f64>,
    /// Free-form tags.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Provider-specific settings as raw JSON values (values redacted from `Debug`).
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub provider_config: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for AgentConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Sorted so the output is stable despite HashMap's randomized iteration order.
        let mut header_names: Vec<&String> = self.headers.keys().collect();
        header_names.sort();
        let mut provider_config_keys: Vec<&String> = self.provider_config.keys().collect();
        provider_config_keys.sort();
        f.debug_struct("AgentConfig")
            .field("name", &self.name)
            .field("provider", &self.provider)
            .field("url", &self.url)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("headers", &header_names)
            .field("timeout_ms", &self.timeout_ms)
            .field("enabled", &self.enabled)
            .field("description", &self.description)
            .field("capabilities", &self.capabilities)
            .field("rate_limit_rpm", &self.rate_limit_rpm)
            .field("cost_per_request", &self.cost_per_request)
            .field("tags", &self.tags)
            .field("provider_config", &provider_config_keys)
            .finish()
    }
}
/// Default timeout in milliseconds (one minute); serde default for the `*timeout_ms` fields.
fn default_timeout() -> u64 {
    60_000
}
/// Serde default for `AgentConfig::enabled`: an agent is enabled unless configured otherwise.
fn default_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
impl Default for AgentConfig {
fn default() -> Self {
Self {
name: String::new(),
provider: AgentProvider::default(),
url: String::new(),
api_key: None,
headers: HashMap::new(),
timeout_ms: default_timeout(),
enabled: true,
description: None,
capabilities: AgentCapabilities::default(),
rate_limit_rpm: None,
cost_per_request: None,
tags: Vec::new(),
provider_config: HashMap::new(),
}
}
}
impl AgentConfig {
    /// Creates a config with the given name and URL; every other field takes its default.
    pub fn new(name: impl Into<String>, url: impl Into<String>) -> Self {
        let name = name.into();
        let url = url.into();
        Self {
            name,
            url,
            ..Self::default()
        }
    }

    /// Returns the config with `provider` replaced.
    pub fn with_provider(self, provider: AgentProvider) -> Self {
        Self { provider, ..self }
    }

    /// Returns the config with the API key set.
    pub fn with_api_key(self, key: impl Into<String>) -> Self {
        Self {
            api_key: Some(key.into()),
            ..self
        }
    }

    /// Returns the config with one header added (replacing an existing header of that name).
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (header_name, header_value) = (key.into(), value.into());
        self.headers.insert(header_name, header_value);
        self
    }

    /// Returns the config with `timeout_ms` replaced.
    pub fn with_timeout(self, timeout_ms: u64) -> Self {
        Self { timeout_ms, ..self }
    }

    /// Returns the config with the description set.
    pub fn with_description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Checks that the config is usable and that its URL does not target a private host.
    ///
    /// # Errors
    ///
    /// Returns a message when the name or URL is empty, the URL is not `http://`/`https://`,
    /// its host cannot be extracted, or the host is private/reserved (SSRF protection).
    pub fn validate(&self) -> Result<(), String> {
        let url = self.url.as_str();
        if self.name.is_empty() {
            return Err("Agent name cannot be empty".to_string());
        }
        if url.is_empty() {
            return Err("Agent URL cannot be empty".to_string());
        }
        let uses_http = ["http://", "https://"]
            .iter()
            .any(|scheme| url.starts_with(*scheme));
        if !uses_http {
            return Err(format!(
                "Agent URL must start with http:// or https://, got: {}",
                url
            ));
        }
        let host = match extract_url_host(url) {
            Some(host) => host,
            None => {
                return Err(format!(
                    "Agent URL has an invalid or missing host: {}",
                    url
                ))
            }
        };
        if is_private_or_reserved_host(&host) {
            Err(format!(
                "Agent URL targets a private or reserved address '{}', which is not allowed (SSRF protection)",
                host
            ))
        } else {
            Ok(())
        }
    }
}
/// Extracts the lowercased host from an `http://` or `https://` URL.
///
/// This is a lexical parser for SSRF checks. It aims to find the host an HTTP client will
/// actually connect to:
/// - userinfo (`user:pass@`) is skipped, splitting on the *last* `@` of the authority;
/// - `\` also ends the authority, as WHATWG URL parsers treat it like `/` for http(s);
/// - trailing dots are dropped (`localhost.` is the same host as `localhost`);
/// - hosts containing `%` are rejected, because WHATWG parsers percent-decode hosts;
/// - IPv6 literals are returned without their brackets.
///
/// Returns `None` for non-http(s) URLs, empty hosts, and hosts containing `%`.
fn extract_url_host(url: &str) -> Option<String> {
    let after_scheme = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))?;
    let authority = after_scheme.split(['/', '\\', '?', '#']).next()?;
    // Everything up to the last '@' is userinfo, never the host.
    let host_and_port = match authority.rfind('@') {
        Some(at) => &authority[at + 1..],
        None => authority,
    };
    let host = match host_and_port.strip_prefix('[') {
        // IPv6 literal: the host is everything up to the closing bracket.
        Some(bracketed) => {
            let end = bracketed.find(']')?;
            &bracketed[..end]
        }
        None => {
            let without_port = match host_and_port.rfind(':') {
                Some(colon) => &host_and_port[..colon],
                None => host_and_port,
            };
            without_port.trim_end_matches('.')
        }
    };
    if host.is_empty() || host.contains('%') {
        return None;
    }
    Some(host.to_lowercase())
}
fn is_private_or_reserved_host(host: &str) -> bool {
if host == "localhost"
|| host.ends_with(".localhost")
|| host == "metadata.google.internal"
|| host == "169.254.169.254"
{
return true;
}
if let Ok(ip) = host.parse::<IpAddr>() {
return is_private_or_reserved_ip(ip);
}
false
}
/// Returns `true` if `ip` is an address an agent URL must not target.
///
/// IPv4 blocks:
/// - 0.0.0.0/8 ("this network", including the unspecified address);
/// - 10/8, 172.16/12 and 192.168/16 (private);
/// - 100.64/10 (shared address space / CGNAT);
/// - 127/8 (loopback);
/// - 169.254/16 (link-local);
/// - 224/4 (multicast) and 240/4 (reserved, including broadcast).
///
/// IPv6 blocks:
/// - `::` (unspecified) and `::1` (loopback);
/// - fc00::/7 (unique local) and fe80::/10 (link-local);
/// - ff00::/8 (multicast);
/// - any IPv4-mapped (`::ffff:a.b.c.d`), IPv4-compatible (`::a.b.c.d`) or NAT64
///   (`64:ff9b::a.b.c.d`) address whose embedded IPv4 address is blocked.
fn is_private_or_reserved_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let [first, second, _, _] = v4.octets();
            matches!(
                (first, second),
                (0, _) // 0.0.0.0/8
                    | (10, _) // 10.0.0.0/8
                    | (100, 64..=127) // 100.64.0.0/10
                    | (127, _) // 127.0.0.0/8
                    | (169, 254) // 169.254.0.0/16
                    | (172, 16..=31) // 172.16.0.0/12
                    | (192, 168) // 192.168.0.0/16
                    | (224..=255, _) // 224.0.0.0/4 and 240.0.0.0/4
            )
        }
        IpAddr::V6(v6) => {
            let segments = v6.segments();
            // `::` must be blocked explicitly: connecting to it reaches the local host on Linux.
            if v6.is_unspecified()
                || v6.is_loopback()
                || v6.is_multicast()
                || (segments[0] & 0xfe00) == 0xfc00 // fc00::/7 unique local
                || (segments[0] & 0xffc0) == 0xfe80 // fe80::/10 link-local
            {
                return true;
            }
            let embedded_v4 = if matches!(segments, [0x64, 0xff9b, 0, 0, 0, 0, _, _]) {
                // NAT64 well-known prefix: the IPv4 address is in the low 32 bits.
                let octets = v6.octets();
                Some(IpAddr::from([octets[12], octets[13], octets[14], octets[15]]))
            } else {
                // Covers both IPv4-mapped (::ffff:0:0/96) and IPv4-compatible (::/96) forms.
                v6.to_ipv4().map(IpAddr::V4)
            };
            embedded_v4.map_or(false, is_private_or_reserved_ip)
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AgentCapabilities {
#[serde(default)]
pub streaming: bool,
#[serde(default)]
pub push_notifications: bool,
#[serde(default)]
pub task_cancellation: bool,
#[serde(default = "default_true")]
pub multi_turn: bool,
#[serde(default)]
pub file_attachments: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_input_length: Option<u32>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub input_types: Vec<String>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub output_types: Vec<String>,
}
impl AgentCapabilities {
    /// Every capability switched on, with text and image as both input and output types.
    pub fn full() -> Self {
        let text_and_image = || vec![String::from("text"), String::from("image")];
        Self {
            streaming: true,
            push_notifications: true,
            task_cancellation: true,
            multi_turn: true,
            file_attachments: true,
            max_input_length: None,
            input_types: text_and_image(),
            output_types: text_and_image(),
        }
    }

    /// Every capability switched off (including multi-turn), text-only input and output.
    pub fn minimal() -> Self {
        let text_only = || vec![String::from("text")];
        Self {
            streaming: false,
            push_notifications: false,
            task_cancellation: false,
            multi_turn: false,
            file_attachments: false,
            max_input_length: None,
            input_types: text_only(),
            output_types: text_only(),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct A2AGatewayConfig {
#[serde(default)]
pub agents: HashMap<String, AgentConfig>,
#[serde(default = "default_timeout")]
pub default_timeout_ms: u64,
#[serde(default = "default_true")]
pub enable_logging: bool,
#[serde(default)]
pub enable_cost_tracking: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub global_rate_limit: Option<u32>,
#[serde(default = "default_true")]
pub health_check_enabled: bool,
#[serde(default = "default_health_check_interval_secs")]
pub health_check_interval_secs: u64,
}
/// Serde default for `A2AGatewayConfig::health_check_interval_secs`: every 30 seconds.
fn default_health_check_interval_secs() -> u64 {
    const HALF_MINUTE_SECS: u64 = 30;
    HALF_MINUTE_SECS
}
impl A2AGatewayConfig {
    /// Registers `config` under its `name`, replacing any agent already stored under that name.
    pub fn add_agent(&mut self, config: AgentConfig) {
        self.agents.insert(config.name.clone(), config);
    }

    /// Looks up an agent by the key it is stored under.
    pub fn get_agent(&self, name: &str) -> Option<&AgentConfig> {
        self.agents.get(name)
    }

    /// Validates every configured agent with `AgentConfig::validate`.
    ///
    /// # Errors
    ///
    /// Returns one message per invalid agent, formatted as `agent '<key>': <reason>` and ordered
    /// by key. Agents are visited in sorted key order (not `HashMap` iteration order, which is
    /// randomized per process) so the report is reproducible and names the failing agent.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        let mut entries: Vec<(&String, &AgentConfig)> = self.agents.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let errors: Vec<String> = entries
            .into_iter()
            .filter_map(|(key, agent)| {
                agent
                    .validate()
                    .err()
                    .map(|reason| format!("agent '{}': {}", key, reason))
            })
            .collect();
        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for provider parsing, `AgentConfig` building and validation (including the
    //! SSRF host guard), capability presets, and the gateway config helpers.

    use super::*;
    use std::net::Ipv4Addr;

    /// Builds an agent named "test" that points at `url` and validates it.
    fn validate_url(url: &str) -> Result<(), String> {
        AgentConfig::new("test", url).validate()
    }

    /// Asserts that every URL in `urls` fails validation.
    fn assert_all_rejected(urls: &[&str]) {
        for url in urls {
            assert!(validate_url(url).is_err(), "expected {} to be rejected", url);
        }
    }

    /// Asserts that every URL in `urls` passes validation.
    fn assert_all_accepted(urls: &[&str]) {
        for url in urls {
            let outcome = validate_url(url);
            assert!(outcome.is_ok(), "expected {} to be accepted, got {:?}", url, outcome);
        }
    }

    #[test]
    fn test_agent_provider_display() {
        for (provider, expected) in [
            (AgentProvider::LangGraph, "LangGraph"),
            (AgentProvider::VertexAI, "Vertex AI Agent Engine"),
        ] {
            assert_eq!(provider.display_name(), expected);
        }
    }

    #[test]
    fn test_agent_provider_from_str() {
        for (input, expected) in [
            ("langgraph", AgentProvider::LangGraph),
            ("vertex", AgentProvider::VertexAI),
            ("azure", AgentProvider::AzureAIFoundry),
            ("bedrock", AgentProvider::BedrockAgentCore),
        ] {
            assert_eq!(input.parse::<AgentProvider>().unwrap(), expected);
        }
    }

    #[test]
    fn test_agent_provider_streaming_support() {
        assert!([AgentProvider::LangGraph, AgentProvider::VertexAI]
            .iter()
            .all(AgentProvider::supports_streaming));
        assert!(!AgentProvider::BedrockAgentCore.supports_streaming());
    }

    #[test]
    fn test_agent_config_new() {
        let agent = AgentConfig::new("my-agent", "https://api.example.com/agent");
        assert_eq!(agent.name, "my-agent");
        assert_eq!(agent.url, "https://api.example.com/agent");
        assert!(agent.enabled);
    }

    #[test]
    fn test_agent_config_builder() {
        let agent = AgentConfig::new("my-agent", "https://api.example.com/agent")
            .with_provider(AgentProvider::LangGraph)
            .with_api_key("sk-test123")
            .with_timeout(30000)
            .with_description("Test agent");
        assert_eq!(agent.provider, AgentProvider::LangGraph);
        assert_eq!(agent.api_key.as_deref(), Some("sk-test123"));
        assert_eq!(agent.timeout_ms, 30000);
        assert!(agent.description.is_some());
    }

    #[test]
    fn test_agent_config_validation() {
        assert_all_accepted(&["https://example.com"]);
        assert!(AgentConfig::new("", "https://example.com").validate().is_err());
        assert_all_rejected(&["", "ftp://example.com"]);
    }

    #[test]
    fn test_ssrf_loopback_ipv4_rejected() {
        let err = validate_url("http://127.0.0.1/api").unwrap_err();
        assert!(err.contains("private or reserved"), "got: {}", err);
    }

    #[test]
    fn test_ssrf_loopback_ipv4_any_port_rejected() {
        assert_all_rejected(&["https://127.0.0.1:8080/api"]);
    }

    #[test]
    fn test_ssrf_localhost_hostname_rejected() {
        let err = validate_url("http://localhost/api").unwrap_err();
        assert!(err.contains("private or reserved"), "got: {}", err);
    }

    #[test]
    fn test_ssrf_localhost_subdomain_rejected() {
        assert_all_rejected(&["http://my.localhost/api"]);
    }

    #[test]
    fn test_ssrf_rfc1918_10_rejected() {
        assert_all_rejected(&["http://10.0.0.1/api", "http://10.255.255.255/api"]);
    }

    #[test]
    fn test_ssrf_rfc1918_172_rejected() {
        assert_all_rejected(&["http://172.16.0.1/api", "http://172.31.255.255/api"]);
    }

    #[test]
    fn test_ssrf_rfc1918_172_boundary_allowed() {
        assert_all_accepted(&["https://172.15.0.1/api", "https://172.32.0.1/api"]);
    }

    #[test]
    fn test_ssrf_rfc1918_192_168_rejected() {
        assert_all_rejected(&["http://192.168.1.1/api"]);
    }

    #[test]
    fn test_ssrf_link_local_metadata_rejected() {
        assert_all_rejected(&[
            "http://169.254.169.254/latest/meta-data/",
            "http://169.254.0.1/api",
        ]);
    }

    #[test]
    fn test_ssrf_unspecified_ipv4_rejected() {
        assert_all_rejected(&["http://0.0.0.0/api"]);
    }

    #[test]
    fn test_ssrf_ipv6_loopback_rejected() {
        assert_all_rejected(&["http://[::1]/api"]);
    }

    #[test]
    fn test_ssrf_ipv6_unique_local_rejected() {
        assert_all_rejected(&["http://[fc00::1]/api", "http://[fd00::1]/api"]);
    }

    #[test]
    fn test_ssrf_google_metadata_hostname_rejected() {
        assert_all_rejected(&["http://metadata.google.internal/computeMetadata/v1/"]);
    }

    #[test]
    fn test_ssrf_public_ip_allowed() {
        assert_all_accepted(&["https://8.8.8.8/api"]);
    }

    #[test]
    fn test_ssrf_public_domain_allowed() {
        assert_all_accepted(&["https://api.example.com/v1/agent"]);
    }

    #[test]
    fn test_extract_url_host_basic() {
        let cases = [
            ("https://example.com/path", Some("example.com")),
            ("http://10.0.0.1:8080/api", Some("10.0.0.1")),
            ("http://[::1]:9000/api", Some("::1")),
            ("ftp://example.com", None),
            ("http://", None),
        ];
        for (url, expected) in cases {
            assert_eq!(extract_url_host(url).as_deref(), expected, "url: {}", url);
        }
    }

    #[test]
    fn test_is_private_or_reserved_ip_public() {
        for octets in [[8u8, 8, 8, 8], [1, 1, 1, 1]] {
            assert!(!is_private_or_reserved_ip(IpAddr::V4(Ipv4Addr::from(octets))));
        }
    }

    #[test]
    fn test_is_private_or_reserved_ip_private() {
        for octets in [
            [127u8, 0, 0, 1],
            [10, 1, 2, 3],
            [172, 20, 0, 1],
            [192, 168, 0, 1],
            [169, 254, 169, 254],
            [0, 0, 0, 0],
        ] {
            assert!(is_private_or_reserved_ip(IpAddr::V4(Ipv4Addr::from(octets))));
        }
    }

    #[test]
    fn test_agent_capabilities_full() {
        let caps = AgentCapabilities::full();
        let flags = [
            caps.streaming,
            caps.push_notifications,
            caps.task_cancellation,
            caps.multi_turn,
            caps.file_attachments,
        ];
        assert!(flags.iter().all(|&enabled| enabled));
    }

    #[test]
    fn test_agent_capabilities_minimal() {
        let caps = AgentCapabilities::minimal();
        assert!(!(caps.streaming || caps.push_notifications));
    }

    #[test]
    fn test_gateway_config() {
        let mut gateway = A2AGatewayConfig::default();
        gateway.add_agent(AgentConfig::new("agent1", "https://example.com/agent1"));
        assert!(gateway.get_agent("agent1").is_some());
        assert!(gateway.get_agent("nonexistent").is_none());
    }

    #[test]
    fn test_config_serialization() {
        let original =
            AgentConfig::new("test", "https://example.com").with_provider(AgentProvider::LangGraph);
        let json = serde_json::to_string(&original).unwrap();
        let round_tripped: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped.name, "test");
        assert_eq!(round_tripped.provider, AgentProvider::LangGraph);
    }
}