use std::collections::HashMap;
use std::sync::Arc;
use axum::middleware;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::{get, post};
use axum::Router;
use futures::stream;
use std::convert::Infallible;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::a2a::auth::middleware::{auth_middleware, AuthMiddlewareState};
use crate::a2a::types::{
AgentCapabilities, AgentCard, AgentSkill, JsonRpcRequest, Message,
TaskArtifactUpdateEvent, TaskStatus, TaskStatusUpdateEvent,
};
use crate::agents::adk::builder::AdkAgentBuilder;
use crate::agents::strategy::AgentExecutor;
use crate::config::agent::{AgentConfig, AgentType};
use crate::config::auth::AuthConfig;
use crate::error::{AgentKitError, Result};
use crate::tools::factory::ToolFactory;
use crate::utils::get_protocol;
/// Location and credentials of one remote agent listed in the factory configuration.
///
/// `AgentFactory::generate_remote_configs` turns each entry into a `(url, AuthConfig)`
/// pair, assembling the URL with `build_url`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoteAgentAddress {
/// Host part of the URL; placed directly after `<protocol>://` by `build_url`.
pub host: String,
/// Optional port; when present `build_url` appends `:<port>` after the host.
pub port: Option<u16>,
/// Optional path; `build_url` trims it and makes sure it is separated from the
/// host/port by a single leading `/`.
pub path: Option<String>,
/// Optional auth settings; `AuthConfig::default()` is substituted when absent.
pub auth: Option<AuthConfig>,
}
/// One skill entry from the configuration file.
///
/// `AgentFactory::agent_skills` copies every field verbatim into an `AgentSkill`
/// for the published agent card.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillConfig {
/// Identifier copied into `AgentSkill::id`.
pub id: String,
/// Name copied into `AgentSkill::name`.
pub name: String,
/// Description copied into `AgentSkill::description`.
pub description: String,
/// Tags for the skill; an empty list when the key is missing from the config.
#[serde(default)]
pub tags: Vec<String>,
/// Example entries for the skill; an empty list when the key is missing.
#[serde(default)]
pub examples: Vec<String>,
}
/// Capability flags copied into the agent card's `AgentCapabilities` by
/// `AgentFactory::build_agent_card`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapabilitiesConfig {
/// Streaming flag; defaults to `true` (via `bool_true`) when the key is missing.
#[serde(default = "bool_true")]
pub streaming: bool,
/// Push-notification flag; defaults to `false` when the key is missing.
#[serde(default)]
pub push_notifications: bool,
/// State-transition-history flag; defaults to `false` when the key is missing.
#[serde(default)]
pub state_transition_history: bool,
}
/// Serde default provider that always yields `true`.
///
/// Referenced by name from `#[serde(default = "bool_true")]`, which requires a
/// function path rather than a literal.
fn bool_true() -> bool {
    true
}
/// Optional monitoring settings, one optional sub-section per backend.
///
/// NOTE(review): nothing in this part of the file reads these values; they are
/// presumably consumed elsewhere — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
/// Phoenix settings, if configured.
pub phoenix: Option<PhoenixConfig>,
/// Arize settings, if configured.
pub arize: Option<ArizeConfig>,
}
/// Phoenix section of [`MonitoringConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhoenixConfig {
/// Host of the Phoenix endpoint (required).
pub host: String,
/// Port of the Phoenix endpoint (required).
pub port: u16,
/// Optional project name.
pub project_name: Option<String>,
}
/// Arize section of [`MonitoringConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArizeConfig {
/// Arize space identifier (required).
pub arize_space_id: String,
/// Optional Arize project name.
pub arize_project_name: Option<String>,
}
/// Top-level configuration consumed by [`AgentFactory`].
///
/// `AgentFactory::from_config` deserialises this from a YAML file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentFactoryConfig {
/// Agent name; published on the agent card and passed to the agent builder.
pub name: String,
/// Agent description published on the agent card.
pub description: String,
/// URL published on the agent card.
pub url: String,
/// Version published on the agent card; `"1.0.0"` when the key is missing.
#[serde(default = "default_version")]
pub version: String,
/// Optional instruction text for the agent; an empty string is used when absent.
pub instruction: Option<String>,
/// Agent settings: supplies the agent type and model, and is handed to the tool factory.
pub agent: AgentConfig,
/// Raw tool definitions, each passed to `ToolFactory::create_tool`; empty when missing.
#[serde(default)]
pub tools: Vec<serde_json::Value>,
/// Skills advertised on the agent card; empty when missing.
#[serde(default)]
pub skills: Vec<SkillConfig>,
/// Capability flags; when absent the card advertises streaming only.
pub capabilities: Option<CapabilitiesConfig>,
/// Auth settings used for the card's security fields and the server's auth
/// middleware; `AuthConfig::default()` when absent.
pub auth: Option<AuthConfig>,
/// Optional monitoring settings (not read in this part of the file).
pub monitoring: Option<MonitoringConfig>,
/// Remote agents to connect to when building the executor; empty when missing.
#[serde(default)]
pub remote_agents_addresses: Vec<RemoteAgentAddress>,
/// Input modes published on the card. Note the camelCase key `defaultInputModes`;
/// defaults to `["text", "text/plain"]`.
#[serde(rename = "defaultInputModes", default = "default_io_modes")]
pub default_input_modes: Vec<String>,
/// Output modes published on the card. Note the camelCase key `defaultOutputModes`;
/// defaults to `["text", "text/plain"]`.
#[serde(rename = "defaultOutputModes", default = "default_io_modes")]
pub default_output_modes: Vec<String>,
}
/// Serde default for `AgentFactoryConfig::version`: the string `"1.0.0"`.
fn default_version() -> String {
    String::from("1.0.0")
}
/// Serde default for the input/output mode lists: `["text", "text/plain"]`, in that order.
fn default_io_modes() -> Vec<String> {
    ["text", "text/plain"]
        .iter()
        .map(|mode| mode.to_string())
        .collect()
}
async fn build_sse_events(
id: Option<serde_json::Value>,
msg_val: serde_json::Value,
exec: Arc<dyn AgentExecutor>,
) -> Vec<serde_json::Value> {
let msg = match serde_json::from_value::<Message>(msg_val.clone()) {
Ok(m) => m,
Err(e) => {
tracing::error!(error = %e, raw = %msg_val, "Message deserialisation failed");
return vec![serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"error": {"code": -32600, "message": format!("Invalid message params: {e}")}
})];
}
};
tracing::info!("Message parsed OK, calling executor");
let task = match exec.execute(msg).await {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "Executor failed");
return vec![serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"error": {"code": -32603, "message": e.to_string()}
})];
}
};
tracing::info!(task_id = %task.id, artifacts = task.artifacts.len(), "Executor succeeded");
let mut evs = Vec::new();
for (i, artifact) in task.artifacts.iter().enumerate() {
let is_last = i == task.artifacts.len() - 1;
let ev = TaskArtifactUpdateEvent {
kind: "artifact-update".to_string(),
task_id: task.id.clone(),
context_id: task.context_id.clone(),
artifact: artifact.clone(),
append: Some(false),
last_chunk: Some(is_last),
metadata: None,
};
evs.push(serde_json::json!({"jsonrpc": "2.0", "id": id, "result": ev}));
}
let status_ev = TaskStatusUpdateEvent {
kind: "status-update".to_string(),
task_id: task.id.clone(),
context_id: task.context_id.clone(),
status: TaskStatus {
state: task.state.clone(),
message: task.history.last().cloned(),
timestamp: None,
},
r#final: true,
metadata: None,
};
evs.push(serde_json::json!({"jsonrpc": "2.0", "id": id, "result": status_ev}));
evs
}
/// Builds the agent executor, agent card and HTTP server from an [`AgentFactoryConfig`].
pub struct AgentFactory {
// Parsed configuration this factory was created from.
cfg: AgentFactoryConfig,
// Tools registered by name through `register_tool`; a clone of this map is handed
// to `ToolFactory::new` when tools are loaded.
tool_registry: HashMap<String, crate::tools::loader::DynTool>,
}
impl AgentFactory {
    /// Creates a factory from the YAML configuration file at `path`.
    ///
    /// The tool registry starts out empty; use [`AgentFactory::register_tool`] to fill it.
    ///
    /// # Errors
    /// Returns `AgentKitError::ConfigIo` (carrying `path` and the I/O error) when the file
    /// cannot be read, and the converted YAML error when it cannot be parsed.
    pub fn from_config(path: &str) -> Result<Self> {
        let raw = std::fs::read_to_string(path).map_err(|source| AgentKitError::ConfigIo {
            path: path.to_string(),
            source,
        })?;
        let cfg = serde_yaml::from_str::<AgentFactoryConfig>(&raw)?;
        Ok(Self {
            cfg,
            tool_registry: HashMap::new(),
        })
    }

    /// Registers `tool` under `name`, replacing any tool previously stored under that name.
    pub fn register_tool(&mut self, name: impl Into<String>, tool: crate::tools::loader::DynTool) {
        let key: String = name.into();
        self.tool_registry.insert(key, tool);
    }

    /// Returns a copy of `auth`, or `AuthConfig::default()` when none was supplied.
    fn parse_auth(&self, auth: Option<&AuthConfig>) -> AuthConfig {
        match auth {
            Some(configured) => configured.clone(),
            None => AuthConfig::default(),
        }
    }

    /// Converts every configured remote agent address into a `(url, auth)` pair,
    /// preserving the configuration order.
    fn generate_remote_configs(&self) -> Vec<(String, AuthConfig)> {
        let addresses = &self.cfg.remote_agents_addresses;
        let mut configs = Vec::with_capacity(addresses.len());
        for address in addresses {
            let url = build_url(&address.host, address.port, address.path.as_deref());
            let auth = self.parse_auth(address.auth.as_ref());
            configs.push((url, auth));
        }
        configs
    }

    /// Copies the configured skills into the `AgentSkill` form used on the agent card.
    fn agent_skills(&self) -> Vec<AgentSkill> {
        let mut skills = Vec::with_capacity(self.cfg.skills.len());
        for skill in &self.cfg.skills {
            skills.push(AgentSkill {
                id: skill.id.clone(),
                name: skill.name.clone(),
                description: skill.description.clone(),
                tags: skill.tags.clone(),
                examples: skill.examples.clone(),
            });
        }
        skills
    }

    /// Instantiates every tool listed in the configuration.
    ///
    /// Loading is best-effort: a tool whose creation fails is logged and skipped, and
    /// a definition for which the factory yields `None` is silently ignored, so this
    /// currently always returns `Ok`.
    async fn load_tools(&self) -> Result<Vec<crate::tools::loader::DynTool>> {
        let mut factory = ToolFactory::new(&self.cfg.agent, self.tool_registry.clone());
        let mut loaded = Vec::new();
        for definition in self.cfg.tools.iter() {
            let outcome = factory.create_tool(definition.clone()).await;
            match outcome {
                Err(e) => tracing::error!("Tool load error: {e}"),
                Ok(None) => {}
                Ok(Some(tool)) => loaded.push(tool),
            }
        }
        Ok(loaded)
    }

    /// Builds the executor that serves requests for the configured agent.
    ///
    /// Steps, in order: load tools, build the agent card, connect to the configured
    /// remote agents (connection failures are only logged as warnings), then build the
    /// agent and wrap it in an executor.
    ///
    /// # Errors
    /// Propagates errors from tool loading and from the agent builder.
    pub async fn build_executor(&self) -> Result<Arc<dyn AgentExecutor>> {
        let tools = self.load_tools().await?;
        let card = self.build_agent_card();
        let instruction = match &self.cfg.instruction {
            Some(text) => text.clone(),
            None => String::new(),
        };

        let remote_configs = self.generate_remote_configs();
        let mut remote_manager = crate::a2a::client::connector::RemoteAgentManager::new();
        if !remote_configs.is_empty() {
            // Keep going with whatever connections are available rather than aborting.
            let outcome = remote_manager.connect(remote_configs).await;
            if let Err(e) = outcome {
                tracing::warn!("Some remote agents could not be connected: {e}");
            }
        }
        let remote_connections = remote_manager.take_connections();

        // Both agent types are currently built through the ADK builder.
        match self.cfg.agent.agent_type {
            AgentType::Adk | AgentType::Langchain => {
                let builder = AdkAgentBuilder::new(self.cfg.agent.model.clone());
                let agent = builder.build_agent(
                    &self.cfg.name,
                    &instruction,
                    tools,
                    remote_connections,
                )?;
                Ok(builder.build_executor(agent, card))
            }
        }
    }

    /// Assembles the agent card from the configuration.
    ///
    /// Missing capabilities fall back to "streaming only"; missing auth falls back to
    /// `AuthConfig::default()`. A security scheme that fails to serialise is published
    /// as the default JSON value instead of aborting.
    pub fn build_agent_card(&self) -> AgentCard {
        let auth = match &self.cfg.auth {
            Some(configured) => configured.clone(),
            None => AuthConfig::default(),
        };
        let caps = match &self.cfg.capabilities {
            Some(configured) => configured.clone(),
            None => CapabilitiesConfig {
                streaming: true,
                push_notifications: false,
                state_transition_history: false,
            },
        };
        let capabilities = AgentCapabilities {
            streaming: caps.streaming,
            push_notifications: caps.push_notifications,
            state_transition_history: caps.state_transition_history,
        };

        AgentCard {
            name: self.cfg.name.clone(),
            description: self.cfg.description.clone(),
            url: self.cfg.url.clone(),
            version: self.cfg.version.clone(),
            default_input_modes: self.cfg.default_input_modes.clone(),
            default_output_modes: self.cfg.default_output_modes.clone(),
            capabilities,
            skills: self.agent_skills(),
            security: auth.security,
            security_schemes: auth.security_schemes.map(|schemes| {
                schemes
                    .into_iter()
                    .map(|(scheme_name, scheme)| {
                        let encoded = serde_json::to_value(scheme).unwrap_or_default();
                        (scheme_name, encoded)
                    })
                    .collect()
            }),
        }
    }

    /// Builds the executor and serves the agent over HTTP on `0.0.0.0:<port>`.
    ///
    /// Routes:
    /// * `GET /.well-known/agent.json` and `GET /.well-known/agent-card.json` – the agent card.
    /// * `GET /health` – `{"status": "ok"}`.
    /// * `POST /` – JSON-RPC request; the response is an SSE stream of the envelopes
    ///   produced by `build_sse_events`.
    ///
    /// All routes sit behind `auth_middleware`; the three GET paths above are passed to
    /// the middleware state as its path list.
    ///
    /// # Errors
    /// Propagates executor build errors and returns `AgentKitError::A2aServer` when
    /// binding the listener or serving fails.
    pub async fn run_server(self, port: u16) -> Result<()> {
        let executor = self.build_executor().await?;
        let card = self.build_agent_card();
        let auth_cfg = self.cfg.auth.clone().unwrap_or_default();

        let state = Arc::new(AuthMiddlewareState::new(
            card.clone(),
            auth_cfg,
            vec![
                "/.well-known/agent.json".into(),
                "/.well-known/agent-card.json".into(),
                "/health".into(),
            ],
        ));

        // One card handler shared by both well-known paths; every invocation runs on
        // its own clone of the closure, so the card can simply be moved into the reply.
        let serve_card = move || async move { axum::Json(card) };

        let serve_health = || async { axum::Json(serde_json::json!({"status": "ok"})) };

        let serve_rpc = move |axum::Json(rpc): axum::Json<JsonRpcRequest>| {
            let exec = Arc::clone(&executor);
            async move {
                let id = rpc.id.clone();
                tracing::info!(method = %rpc.method, id = ?id, "A2A request received");
                // Accept either `{ "message": ... }` params or params that are the
                // message itself.
                let msg_val = rpc
                    .params
                    .get("message")
                    .cloned()
                    .unwrap_or_else(|| rpc.params.clone());
                tracing::debug!(raw = %msg_val, "Extracted message value");

                let events: Vec<serde_json::Value> = build_sse_events(id, msg_val, exec).await;
                let frames = events.into_iter().map(|payload| {
                    Ok::<Event, Infallible>(Event::default().data(payload.to_string()))
                });
                Sse::new(stream::iter(frames)).keep_alive(KeepAlive::default())
            }
        };

        let app = Router::new()
            .route("/.well-known/agent.json", get(serve_card.clone()))
            .route("/.well-known/agent-card.json", get(serve_card))
            .route("/health", get(serve_health))
            .route("/", post(serve_rpc))
            .layer(middleware::from_fn_with_state(state, auth_middleware));

        let addr = format!("0.0.0.0:{port}");
        info!("Velocia AgentKit server starting on http://{addr}");

        let listener = match tokio::net::TcpListener::bind(&addr).await {
            Ok(bound) => bound,
            Err(e) => return Err(AgentKitError::A2aServer(e.to_string())),
        };
        let outcome = axum::serve(listener, app).await;
        outcome.map_err(|e| AgentKitError::A2aServer(e.to_string()))
    }
}
/// Assembles `<protocol>://<host>[:<port>][/<path>]`.
///
/// * `host` – inserted verbatim after the protocol returned by `get_protocol()`.
/// * `port` – appended as `:<port>` when present.
/// * `path` – trimmed of surrounding whitespace when present; a `/` is inserted in
///   front of it unless it already starts with one. A present-but-blank path
///   therefore yields a trailing `/`.
#[allow(dead_code)]
fn build_url(host: &str, port: Option<u16>, path: Option<&str>) -> String {
    let scheme = get_protocol();

    let authority = match port {
        Some(number) => format!("{host}:{number}"),
        None => host.to_string(),
    };

    let suffix = match path.map(str::trim) {
        None => String::new(),
        Some(trimmed) if trimmed.starts_with('/') => trimmed.to_string(),
        Some(trimmed) => format!("/{trimmed}"),
    };

    format!("{scheme}://{authority}{suffix}")
}