// velocia 0.3.1
//
// velocia – production-ready AI agent framework using ADK-Rust, A2A protocol, and AWS DynamoDB
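//! `AgentFactory` – the primary entry point for the Velocia AgentKit.
//!
//! Mirrors Python's `AgentA2AFactory`.  Load a YAML config with
//! `AgentFactory::from_config()`, then either call `build_executor()` and wire
//! up the A2A server yourself, or call `run_server()` to do both in one step.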

use std::collections::HashMap;
use std::sync::Arc;

use axum::middleware;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::{get, post};
use axum::Router;
use futures::stream;
use std::convert::Infallible;
use serde::{Deserialize, Serialize};
use tracing::info;

use crate::a2a::auth::middleware::{auth_middleware, AuthMiddlewareState};
use crate::a2a::types::{
    AgentCapabilities, AgentCard, AgentSkill, JsonRpcRequest, Message,
    TaskArtifactUpdateEvent, TaskStatus, TaskStatusUpdateEvent,
};
use crate::agents::adk::builder::AdkAgentBuilder;
use crate::agents::strategy::AgentExecutor;
use crate::config::agent::{AgentConfig, AgentType};
use crate::config::auth::AuthConfig;
use crate::error::{AgentKitError, Result};
use crate::tools::factory::ToolFactory;
use crate::utils::get_protocol;

// ── Top-level config schema ───────────────────────────────────────────────────

/// One entry of `remote_agents_addresses`: where a remote agent lives and,
/// optionally, how to authenticate against it.
///
/// `AgentFactory::generate_remote_configs` turns each entry into a
/// `(url, AuthConfig)` pair via `build_url`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoteAgentAddress {
    /// Host part of the URL; the scheme is supplied by `get_protocol()`.
    pub host: String,
    /// Optional port, appended as `:port` when present.
    pub port: Option<u16>,
    /// Optional URL path; it is trimmed and given a leading `/` if it lacks one.
    pub path: Option<String>,
    /// Per-agent auth settings; `AuthConfig::default()` is used when omitted.
    pub auth: Option<AuthConfig>,
}

/// Configuration for one advertised skill.
///
/// `AgentFactory::agent_skills` copies every field one-to-one into an
/// `AgentSkill` for the agent card.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillConfig {
    /// Skill identifier, copied to `AgentSkill::id`.
    pub id: String,
    /// Skill name, copied to `AgentSkill::name`.
    pub name: String,
    /// Skill description, copied to `AgentSkill::description`.
    pub description: String,
    /// Tags for the skill; an empty list when the key is omitted.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Example entries for the skill; an empty list when the key is omitted.
    #[serde(default)]
    pub examples: Vec<String>,
}

/// Capability flags copied into the agent card's `AgentCapabilities`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapabilitiesConfig {
    /// Copied to `AgentCapabilities::streaming`; defaults to `true` when the
    /// key is omitted (see `bool_true`).
    #[serde(default = "bool_true")]
    pub streaming: bool,
    /// Copied to `AgentCapabilities::push_notifications`; defaults to `false`.
    #[serde(default)]
    pub push_notifications: bool,
    /// Copied to `AgentCapabilities::state_transition_history`; defaults to `false`.
    #[serde(default)]
    pub state_transition_history: bool,
}

/// Serde default helper: supplies `true` for `CapabilitiesConfig::streaming`
/// when the key is absent from the config.
fn bool_true() -> bool {
    true
}

/// Optional monitoring settings parsed from the `monitoring` config section.
///
/// NOTE(review): nothing in the visible part of this file reads these values;
/// presumably they are consumed elsewhere — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Phoenix settings, if configured.
    pub phoenix: Option<PhoenixConfig>,
    /// Arize settings, if configured.
    pub arize: Option<ArizeConfig>,
}

/// Settings for the `phoenix` entry of `MonitoringConfig`.
///
/// NOTE(review): presumably the address of a Phoenix tracing collector —
/// confirm against the code that consumes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhoenixConfig {
    /// Host of the Phoenix endpoint (required).
    pub host: String,
    /// Port of the Phoenix endpoint (required).
    pub port: u16,
    /// Optional project name.
    pub project_name: Option<String>,
}

/// Settings for the `arize` entry of `MonitoringConfig`.
///
/// NOTE(review): not read anywhere in the visible part of this file — confirm
/// where it is consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArizeConfig {
    /// Arize space identifier (required).
    pub arize_space_id: String,
    /// Optional Arize project name.
    pub arize_project_name: Option<String>,
}

/// The complete structure of `agent_config.yaml`.
///
/// Parsed by `AgentFactory::from_config` and then used to build the agent
/// card, the executor and the HTTP server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentFactoryConfig {
    /// Agent name; used for the agent card and passed to the agent builder.
    pub name: String,
    /// Description shown on the agent card.
    pub description: String,
    /// URL advertised on the agent card.
    pub url: String,
    /// Version advertised on the agent card; `"1.0.0"` when omitted.
    #[serde(default = "default_version")]
    pub version: String,
    /// Instruction text handed to the agent builder; an empty string is used
    /// when omitted.
    pub instruction: Option<String>,
    /// Agent settings; this file reads `agent_type` and `model` from it and
    /// passes the whole value to `ToolFactory::new`.
    pub agent: AgentConfig,
    /// Raw tool entries; each one is passed to `ToolFactory::create_tool`.
    /// Entries that fail to load are logged and skipped.
    #[serde(default)]
    pub tools: Vec<serde_json::Value>,
    /// Skills advertised on the agent card.
    #[serde(default)]
    pub skills: Vec<SkillConfig>,
    /// Capability flags; when omitted the card reports streaming enabled and
    /// the other two flags disabled.
    pub capabilities: Option<CapabilitiesConfig>,
    /// Auth settings used for the card's security fields and for the server's
    /// auth middleware; `AuthConfig::default()` when omitted.
    pub auth: Option<AuthConfig>,
    /// Optional monitoring settings (not read in the visible part of this file).
    pub monitoring: Option<MonitoringConfig>,
    /// Remote agents to connect to while building the executor.
    #[serde(default)]
    pub remote_agents_addresses: Vec<RemoteAgentAddress>,
    /// Input modes for the card; YAML key `defaultInputModes`, defaulting to
    /// `["text", "text/plain"]`.
    #[serde(rename = "defaultInputModes", default = "default_io_modes")]
    pub default_input_modes: Vec<String>,
    /// Output modes for the card; YAML key `defaultOutputModes`, defaulting to
    /// `["text", "text/plain"]`.
    #[serde(rename = "defaultOutputModes", default = "default_io_modes")]
    pub default_output_modes: Vec<String>,
}

/// Serde default for `AgentFactoryConfig::version`.
fn default_version() -> String {
    String::from("1.0.0")
}
/// Serde default shared by `default_input_modes` and `default_output_modes`.
fn default_io_modes() -> Vec<String> {
    ["text", "text/plain"]
        .iter()
        .map(|mode| (*mode).to_owned())
        .collect()
}

// ── SSE event builder (extracted to avoid brace nesting issues) ───────────────

/// Runs one A2A message through the executor and renders the outcome as the
/// JSON-RPC envelopes that are streamed back to the client.
///
/// # Arguments
/// * `id` – JSON-RPC request id, echoed into every envelope.
/// * `msg_val` – raw JSON that is expected to deserialise into a [`Message`].
/// * `exec` – executor that turns the message into a task.
///
/// # Returns
/// On success: one `artifact-update` envelope per task artifact, in order
/// (only the last one carries `last_chunk = true`), followed by a single final
/// `status-update` envelope whose status message is the last history entry.
///
/// On failure: exactly one JSON-RPC error envelope —
/// `-32602` (Invalid params) when `msg_val` is not a valid `Message`,
/// `-32603` (Internal error) when the executor returns an error.
async fn build_sse_events(
    id: Option<serde_json::Value>,
    msg_val: serde_json::Value,
    exec: Arc<dyn AgentExecutor>,
) -> Vec<serde_json::Value> {
    // `msg_val` is cloned so the raw payload is still available for the log
    // line if deserialisation fails.
    let msg = match serde_json::from_value::<Message>(msg_val.clone()) {
        Ok(m) => m,
        Err(e) => {
            tracing::error!(error = %e, raw = %msg_val, "Message deserialisation failed");
            // JSON-RPC 2.0 reserves -32602 for "Invalid params". -32600 would
            // claim the request envelope itself was malformed, which it is not:
            // the envelope was already parsed by the time we get here.
            return vec![serde_json::json!({
                "jsonrpc": "2.0",
                "id": id,
                "error": {"code": -32602, "message": format!("Invalid message params: {e}")}
            })];
        }
    };

    tracing::info!("Message parsed OK, calling executor");

    let task = match exec.execute(msg).await {
        Ok(t) => t,
        Err(e) => {
            tracing::error!(error = %e, "Executor failed");
            return vec![serde_json::json!({
                "jsonrpc": "2.0",
                "id": id,
                "error": {"code": -32603, "message": e.to_string()}
            })];
        }
    };

    tracing::info!(task_id = %task.id, artifacts = task.artifacts.len(), "Executor succeeded");

    // One envelope per artifact plus the trailing status update.
    let mut evs = Vec::with_capacity(task.artifacts.len() + 1);

    // Index of the final artifact; `saturating_sub` keeps this well-defined
    // for an empty list (the loop body then never runs).
    let last_idx = task.artifacts.len().saturating_sub(1);

    for (i, artifact) in task.artifacts.iter().enumerate() {
        let ev = TaskArtifactUpdateEvent {
            kind: "artifact-update".to_string(),
            task_id: task.id.clone(),
            context_id: task.context_id.clone(),
            artifact: artifact.clone(),
            append: Some(false),
            last_chunk: Some(i == last_idx),
            metadata: None,
        };
        evs.push(serde_json::json!({"jsonrpc": "2.0", "id": id, "result": ev}));
    }

    // The closing event always marks the stream as final.
    let status_ev = TaskStatusUpdateEvent {
        kind: "status-update".to_string(),
        task_id: task.id.clone(),
        context_id: task.context_id.clone(),
        status: TaskStatus {
            state: task.state.clone(),
            message: task.history.last().cloned(),
            timestamp: None,
        },
        r#final: true,
        metadata: None,
    };
    evs.push(serde_json::json!({"jsonrpc": "2.0", "id": id, "result": status_ev}));

    evs
}

// ── Factory ───────────────────────────────────────────────────────────────────

/// Primary factory: reads `agent_config.yaml` and produces agents and servers.
///
/// Create one with `AgentFactory::from_config`, optionally add tool
/// implementations with `register_tool`, then call `build_executor` or
/// `run_server`.
pub struct AgentFactory {
    /// Parsed contents of the YAML configuration file.
    cfg: AgentFactoryConfig,
    /// Tool function registry (function_name → async handler).
    /// Populated through `register_tool`; a clone of it is handed to the
    /// `ToolFactory` each time tools are loaded, so register tools before
    /// building the executor.
    tool_registry: HashMap<String, crate::tools::loader::DynTool>,
}

impl AgentFactory {
    /// Load configuration from the given YAML file path.
    pub fn from_config(path: &str) -> Result<Self> {
        let content = std::fs::read_to_string(path).map_err(|e| AgentKitError::ConfigIo {
            path: path.to_string(),
            source: e,
        })?;
        let cfg: AgentFactoryConfig = serde_yaml::from_str(&content)?;
        Ok(Self { cfg, tool_registry: HashMap::new() })
    }

    /// Register a named tool implementation for `function` and `class_method`
    /// tool types.  Call this before `build_agent()`.
    pub fn register_tool(&mut self, name: impl Into<String>, tool: crate::tools::loader::DynTool) {
        self.tool_registry.insert(name.into(), tool);
    }

    // ── Internal helpers ──────────────────────────────────────────────────────

    fn parse_auth(&self, auth: Option<&AuthConfig>) -> AuthConfig {
        auth.cloned().unwrap_or_default()
    }

    fn generate_remote_configs(&self) -> Vec<(String, AuthConfig)> {
        self.cfg
            .remote_agents_addresses
            .iter()
            .map(|ra| {
                let url = build_url(&ra.host, ra.port, ra.path.as_deref());
                let auth = self.parse_auth(ra.auth.as_ref());
                (url, auth)
            })
            .collect()
    }

    fn agent_skills(&self) -> Vec<AgentSkill> {
        self.cfg
            .skills
            .iter()
            .map(|s| AgentSkill {
                id: s.id.clone(),
                name: s.name.clone(),
                description: s.description.clone(),
                tags: s.tags.clone(),
                examples: s.examples.clone(),
            })
            .collect()
    }

    async fn load_tools(&self) -> Result<Vec<crate::tools::loader::DynTool>> {
        let mut factory = ToolFactory::new(&self.cfg.agent, self.tool_registry.clone());
        let mut tools = Vec::new();

        for raw in &self.cfg.tools {
            match factory.create_tool(raw.clone()).await {
                Ok(Some(t)) => tools.push(t),
                Ok(None) => {}
                Err(e) => tracing::error!("Tool load error: {e}"),
            }
        }
        Ok(tools)
    }

    // ── Public API ────────────────────────────────────────────────────────────

    /// Build and return the agent executor configured by the YAML file.
    pub async fn build_executor(&self) -> Result<Arc<dyn AgentExecutor>> {
        let tools = self.load_tools().await?;
        let card = self.build_agent_card();
        let instruction = self.cfg.instruction.clone().unwrap_or_default();

        // Connect remote agents declared in `remote_agents_addresses`.
        let remote_configs = self.generate_remote_configs();
        let mut remote_manager = crate::a2a::client::connector::RemoteAgentManager::new();
        if !remote_configs.is_empty() {
            if let Err(e) = remote_manager.connect(remote_configs).await {
                tracing::warn!("Some remote agents could not be connected: {e}");
            }
        }
        let remote_connections = remote_manager.take_connections();

        match self.cfg.agent.agent_type {
            AgentType::Adk | AgentType::Langchain => {
                let builder = AdkAgentBuilder::new(self.cfg.agent.model.clone());
                let agent = builder.build_agent(&self.cfg.name, &instruction, tools, remote_connections)?;
                Ok(builder.build_executor(agent, card))
            }
        }
    }

    /// Build the `AgentCard` from configuration.
    pub fn build_agent_card(&self) -> AgentCard {
        let auth = self.cfg.auth.clone().unwrap_or_default();
        let caps = self
            .cfg
            .capabilities
            .clone()
            .unwrap_or(CapabilitiesConfig { streaming: true, push_notifications: false, state_transition_history: false });

        AgentCard {
            name: self.cfg.name.clone(),
            description: self.cfg.description.clone(),
            url: self.cfg.url.clone(),
            version: self.cfg.version.clone(),
            default_input_modes: self.cfg.default_input_modes.clone(),
            default_output_modes: self.cfg.default_output_modes.clone(),
            capabilities: AgentCapabilities {
                streaming: caps.streaming,
                push_notifications: caps.push_notifications,
                state_transition_history: caps.state_transition_history,
            },
            skills: self.agent_skills(),
            security: auth.security,
            security_schemes: auth.security_schemes.map(|ss| {
                ss.into_iter()
                    .map(|(k, v)| (k, serde_json::to_value(v).unwrap_or_default()))
                    .collect()
            }),
        }
    }

    /// Start an HTTP server on `port` exposing the A2A endpoints.
    pub async fn run_server(self, port: u16) -> Result<()> {
        let executor = self.build_executor().await?;
        let card = self.build_agent_card();
        let auth_cfg = self.cfg.auth.clone().unwrap_or_default();

        let state = Arc::new(AuthMiddlewareState::new(
            card.clone(),
            auth_cfg,
            vec!["/.well-known/agent.json".into(), "/.well-known/agent-card.json".into(), "/health".into()],
        ));

        let executor = Arc::clone(&executor);
        let card_clone = card.clone();

        let app = Router::new()
            .route("/.well-known/agent.json", get({
                let card = card_clone.clone();
                move || async move { axum::Json(card.clone()) }
            }))
            // Alias used by the A2A inspector and newer A2A spec revisions
            .route("/.well-known/agent-card.json", get({
                let card = card_clone.clone();
                move || async move { axum::Json(card.clone()) }
            }))
            .route("/health", get(|| async { axum::Json(serde_json::json!({"status": "ok"})) }))
            .route("/", post({
                let exec = Arc::clone(&executor);
                move |axum::Json(rpc): axum::Json<JsonRpcRequest>| {
                    let exec = Arc::clone(&exec);
                    async move {
                        let id = rpc.id.clone();
                        tracing::info!(method = %rpc.method, id = ?id, "A2A request received");

                        // a2a-sdk wraps the message under params.message
                        let msg_val = rpc.params.get("message")
                            .cloned()
                            .unwrap_or_else(|| rpc.params.clone());
                        tracing::debug!(raw = %msg_val, "Extracted message value");

                        let events: Vec<serde_json::Value> = build_sse_events(id, msg_val, exec).await;

                        let s = stream::iter(
                            events.into_iter().map(|e| Ok::<Event, Infallible>(Event::default().data(e.to_string())))
                        );
                        Sse::new(s).keep_alive(KeepAlive::default())
                    }
                }
            }))
            .layer(middleware::from_fn_with_state(state, auth_middleware));

        let addr = format!("0.0.0.0:{port}");
        info!("Velocia AgentKit server starting on http://{addr}");

        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .map_err(|e| AgentKitError::A2aServer(e.to_string()))?;

        axum::serve(listener, app)
            .await
            .map_err(|e| AgentKitError::A2aServer(e.to_string()))
    }
}

// ── URL helpers ───────────────────────────────────────────────────────────────

/// Assembles `<protocol>://<host>[:<port>][/<path>]`, taking the scheme from
/// `get_protocol()`.
///
/// `path` is trimmed of surrounding whitespace and given a leading `/` when it
/// does not already start with one.
#[allow(dead_code)]
fn build_url(host: &str, port: Option<u16>, path: Option<&str>) -> String {
    let scheme = get_protocol();

    let authority = match port {
        Some(number) => format!("{host}:{number}"),
        None => host.to_string(),
    };

    let suffix = match path.map(str::trim) {
        Some(segment) if segment.starts_with('/') => segment.to_string(),
        Some(segment) => format!("/{segment}"),
        None => String::new(),
    };

    format!("{scheme}://{authority}{suffix}")
}