// rustvani 0.3.0
//
// Voice AI framework for Rust — real-time speech pipelines with STT, LLM, TTS, and Dhara conversation flows
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use tokio::time::{timeout, Duration};

use crate::clock::BaseClock;
use crate::error::{PipecatError, Result};
use crate::observer::BaseObserver;

use super::base::Agent;
use super::bus::{AgentBus, AgentRegistryEntry, BusMessage, BusPayload, BusSubscriber};
use super::registry::{AgentInfo, AgentRegistry};

// ---------------------------------------------------------------------------
// AgentRunner
// ---------------------------------------------------------------------------

/// Owns a set of agents and drives their lifecycle over a shared bus:
/// subscription, setup, one spawned task per agent, and shutdown.
pub struct AgentRunner {
    /// Runner name; used in log lines and as the `source` of bus messages
    /// the runner sends.
    name: String,
    /// Bus the runner and every agent are subscribed to in `run`.
    bus: Arc<dyn AgentBus>,
    /// Registry created from the runner name in `new`; handed to each agent
    /// during setup and to the runner's bus subscriber.
    registry: Arc<AgentRegistry>,
    /// Agents added through `add_agent`, keyed by agent name.
    agents: Mutex<HashMap<String, Arc<dyn Agent>>>,
    /// Clock passed to every agent's `run`.
    clock: Arc<dyn BaseClock>,
    /// Optional observer passed to every agent's `run`.
    observer: Option<Arc<dyn BaseObserver>>,
    /// Signalled when shutdown is requested; `run` waits on it.
    shutdown: Arc<Notify>,
    /// Set by the first shutdown request so later requests are ignored.
    shutdown_requested: Arc<AtomicBool>,
    /// Set by the first call to `run`; later calls are rejected.
    started: Arc<AtomicBool>,
    /// Join handles of the agent tasks spawned by `run`.
    agent_tasks: Mutex<Vec<JoinHandle<()>>>,
}

impl AgentRunner {
    pub fn new(name: impl Into<String>, bus: Arc<dyn AgentBus>, clock: Arc<dyn BaseClock>) -> Self {
        let name = name.into();
        let registry = AgentRegistry::new(&name);
        Self {
            name: name.clone(),
            bus,
            registry,
            agents: Mutex::new(HashMap::new()),
            clock,
            observer: None,
            shutdown: Arc::new(Notify::new()),
            shutdown_requested: Arc::new(AtomicBool::new(false)),
            started: Arc::new(AtomicBool::new(false)),
            agent_tasks: Mutex::new(Vec::new()),
        }
    }

    pub fn with_observer(mut self, observer: Arc<dyn BaseObserver>) -> Self {
        self.observer = Some(observer);
        self
    }

    pub fn bus(&self) -> &Arc<dyn AgentBus> {
        &self.bus
    }

    pub fn registry(&self) -> &Arc<AgentRegistry> {
        &self.registry
    }

    pub async fn add_agent(&self, agent: Arc<dyn Agent>) -> Result<()> {
        let name = agent.name().to_string();
        let mut agents = self.agents.lock().await;
        if agents.contains_key(&name) {
            log::error!("Agent '{}' already exists, skipping", name);
            return Ok(());
        }
        agents.insert(name, agent);
        Ok(())
    }

    pub async fn run(&self) -> Result<()> {
        if self.started.swap(true, Ordering::Relaxed) {
            return Err(PipecatError::pipeline(
                "AgentRunner::run() called more than once",
            ));
        }

        log::debug!("AgentRunner '{}' starting", self.name);

        // Subscribe runner to bus
        let runner_sub = Arc::new(RunnerSubscriber {
            runner_name: self.name.clone(),
            shutdown: self.shutdown.clone(),
            shutdown_requested: self.shutdown_requested.clone(),
            registry: self.registry.clone(),
            bus: self.bus.clone(),
        });
        self.bus.subscribe(runner_sub).await?;
        self.bus.start().await;

        // Subscribe all agents to bus
        let agents = self.agents.lock().await.clone();
        for agent in agents.values() {
            let wrapper = Arc::new(AgentSubscriberWrapper(agent.clone()));
            self.bus.subscribe(wrapper).await?;
        }

        // Build the parent → children index. The End/Cancel cascade itself
        // runs through the registry (BaseAgent::end), so this is primarily
        // validation: a dangling parent reference means that agent will
        // never receive a cascaded End.
        {
            let mut children_index: HashMap<String, Vec<String>> = HashMap::new();
            for (name, agent) in &agents {
                if let Some(parent) = agent.parent() {
                    if !agents.contains_key(parent) {
                        log::warn!(
                            "Agent '{}' declares unknown parent '{}' — it will not \
                             receive cascaded End/Cancel",
                            name,
                            parent
                        );
                    }
                    children_index
                        .entry(parent.to_string())
                        .or_default()
                        .push(name.clone());
                }
            }
            if !children_index.is_empty() {
                log::debug!(
                    "AgentRunner '{}' agent tree: {:?}",
                    self.name,
                    children_index
                );
            }
        }

        // Setup all agents
        for agent in agents.values() {
            agent.setup(self.bus.clone(), self.registry.clone()).await?;
        }

        // Start all agents
        let mut tasks = self.agent_tasks.lock().await;
        for agent in agents.values() {
            let agent = agent.clone();
            let clock = self.clock.clone();
            let observer = self.observer.clone();
            let handle = tokio::spawn(async move {
                if let Err(e) = agent.run(clock, observer).await {
                    log::error!("Agent '{}' run error: {}", agent.name(), e);
                }
            });
            tasks.push(handle);
        }
        drop(tasks);

        // Wait for shutdown
        self.shutdown.notified().await;

        // End root agents
        let agents = self.agents.lock().await;
        for agent in agents.values() {
            if agent.parent().is_none() {
                agent.end(None).await.ok();
            }
        }

        // Wait for agent tasks to finish (with timeout)
        let mut tasks = self.agent_tasks.lock().await;
        for handle in tasks.drain(..) {
            let _ = timeout(Duration::from_secs(10), handle).await;
        }

        self.bus.stop().await;

        log::debug!("AgentRunner '{}' finished", self.name);
        Ok(())
    }

    pub async fn end(&self, reason: Option<String>) {
        if self.shutdown_requested.swap(true, Ordering::Relaxed) {
            return;
        }
        log::debug!("AgentRunner '{}' ending gracefully", self.name);

        let agents = self.agents.lock().await;
        for agent in agents.values() {
            if agent.parent().is_none() {
                let msg = BusMessage::new(
                    self.name.clone(),
                    Some(agent.name().to_string()),
                    BusPayload::End {
                        reason: reason.clone(),
                    },
                );
                self.bus.send(msg).await;
            }
        }
        self.shutdown.notify_one();
    }

    pub async fn cancel(&self, reason: Option<String>) {
        if self.shutdown_requested.swap(true, Ordering::Relaxed) {
            return;
        }
        log::debug!("AgentRunner '{}' cancelling", self.name);

        let agents = self.agents.lock().await;
        for agent in agents.values() {
            if agent.parent().is_none() {
                let msg = BusMessage::new(
                    self.name.clone(),
                    Some(agent.name().to_string()),
                    BusPayload::Cancel {
                        reason: reason.clone(),
                    },
                );
                self.bus.send(msg).await;
            }
        }
        self.shutdown.notify_one();
    }
}

// ---------------------------------------------------------------------------
// AgentSubscriberWrapper — bridges Arc<dyn Agent> into Arc<dyn BusSubscriber>
// ---------------------------------------------------------------------------

/// Newtype around a shared agent so it can be handed to the bus as a
/// `BusSubscriber`.
struct AgentSubscriberWrapper(Arc<dyn Agent>);

#[async_trait]
impl BusSubscriber for AgentSubscriberWrapper {
    /// The subscriber name is the wrapped agent's name.
    fn name(&self) -> &str {
        let agent = &self.0;
        agent.name()
    }

    /// Hands every bus message straight to the wrapped agent.
    async fn on_bus_message(&self, message: Arc<BusMessage>) {
        let agent = &self.0;
        agent.on_bus_message(message).await;
    }
}

// ---------------------------------------------------------------------------
// RunnerSubscriber — handles runner-level bus messages
// ---------------------------------------------------------------------------

/// Bus subscriber that `AgentRunner::run` registers on the runner's behalf.
struct RunnerSubscriber {
    /// Name of the owning runner; doubles as the subscriber name and is used
    /// to recognise the runner's own messages and its local agents.
    runner_name: String,
    /// Shared with the runner; notified when an End/Cancel message arrives.
    shutdown: Arc<Notify>,
    /// Shared with the runner; ensures shutdown is signalled only once.
    shutdown_requested: Arc<AtomicBool>,
    /// Registry that agents announced by other runners are recorded in.
    registry: Arc<AgentRegistry>,
    /// Bus used to broadcast this runner's registry snapshot.
    bus: Arc<dyn AgentBus>,
}

#[async_trait]
impl BusSubscriber for RunnerSubscriber {
    /// The runner subscribes under its own name.
    fn name(&self) -> &str {
        self.runner_name.as_str()
    }

    /// Reacts to runner-level traffic:
    ///
    /// * `End` / `Cancel` — signal shutdown, once.
    /// * `AgentReady` — register agents of other runners; when the agent
    ///   belongs to this runner, broadcast the local registry instead.
    /// * `AgentRegistry` from another runner — register every listed agent.
    ///
    /// Messages whose source is this runner are ignored, as is any other payload.
    async fn on_bus_message(&self, message: Arc<BusMessage>) {
        let sent_by_us = message.source == self.runner_name;
        if sent_by_us {
            return;
        }

        match &message.payload {
            BusPayload::End { .. } | BusPayload::Cancel { .. } => {
                // Only the first shutdown request wakes the runner.
                let already_requested = self.shutdown_requested.swap(true, Ordering::Relaxed);
                if !already_requested {
                    self.shutdown.notify_one();
                }
            }
            BusPayload::AgentReady {
                runner,
                parent,
                active,
                bridged,
                started_at,
            } => {
                if *runner == self.runner_name {
                    // When a local agent becomes ready, broadcast our registry
                    // so remote runners can discover us.
                    self.broadcast_local_registry().await;
                } else {
                    let remote_agent = AgentInfo {
                        name: message.source.clone(),
                        runner: runner.clone(),
                        parent: parent.clone(),
                        active: *active,
                        bridged: *bridged,
                        started_at: *started_at,
                    };
                    self.registry.register(remote_agent).await;
                }
            }
            BusPayload::AgentRegistry { runner, agents } if *runner != self.runner_name => {
                for remote in agents {
                    self.registry
                        .register(AgentInfo {
                            name: remote.name.clone(),
                            runner: runner.clone(),
                            parent: remote.parent.clone(),
                            active: remote.active,
                            bridged: remote.bridged,
                            started_at: remote.started_at,
                        })
                        .await;
                }
            }
            _ => {}
        }
    }
}

impl RunnerSubscriber {
    /// Collects the registry entries of all local agents and broadcasts them
    /// as a single untargeted `AgentRegistry` message. Nothing is sent when
    /// no entry could be collected.
    async fn broadcast_local_registry(&self) {
        let mut snapshot: Vec<AgentRegistryEntry> = Vec::new();
        for agent_name in self.registry.local_agents().await {
            if let Some(info) = self.registry.get(&agent_name).await {
                snapshot.push(AgentRegistryEntry {
                    name: info.name,
                    parent: info.parent,
                    active: info.active,
                    bridged: info.bridged,
                    started_at: info.started_at,
                });
            }
        }

        if snapshot.is_empty() {
            return;
        }

        let announcement = BusMessage::new(
            self.runner_name.clone(),
            None,
            BusPayload::AgentRegistry {
                runner: self.runner_name.clone(),
                agents: snapshot,
            },
        );
        self.bus.send(announcement).await;
    }
}