agent_diva_manager/runtime.rs

1mod bootstrap;
2mod shutdown;
3mod task_runtime;
4
5use crate::state::ManagerCommand;
6use agent_diva_agent::{
7    agent_loop::SoulGovernanceSettings, context::SoulContextSettings,
8    runtime_control::RuntimeControlCommand, tool_config::network::NetworkToolConfig,
9    tool_config::network::WebFetchRuntimeConfig, tool_config::network::WebRuntimeConfig,
10    tool_config::network::WebSearchRuntimeConfig, AgentLoop, ToolConfig,
11};
12use agent_diva_channels::ChannelManager;
13use agent_diva_core::bus::{InboundMessage, MessageBus};
14use agent_diva_core::config::{Config, ConfigLoader};
15use agent_diva_core::cron::service::JobCallback;
16use agent_diva_core::cron::CronService;
17use agent_diva_files::{default_data_dir_or_fallback, FileConfig, FileManager};
18use agent_diva_providers::{
19    DynamicProvider, LLMProvider, LiteLLMClient, ProviderAccess, ProviderCatalogService,
20    ProviderRegistry,
21};
22use anyhow::Result;
23use std::path::PathBuf;
24use std::sync::Arc;
25use tokio::sync::{broadcast, mpsc, watch};
26use tokio::task::JoinHandle;
27use tracing::error;
28
/// Default TCP port for the gateway's HTTP API (presumably used when the caller does not
/// pick another port — confirm at call sites).
pub const DEFAULT_GATEWAY_PORT: u16 = 3_000;
30
31#[derive(Clone)]
32pub struct GatewayRuntimeConfig {
33    pub config: Config,
34    pub loader: ConfigLoader,
35    pub workspace: PathBuf,
36    pub cron_store: PathBuf,
37    pub port: u16,
38}
39
40pub struct EmbeddedGatewayRuntime {
41    tasks: Option<GatewayTasks>,
42}
43
44impl EmbeddedGatewayRuntime {
45    pub async fn shutdown(mut self) {
46        if let Some(tasks) = self.tasks.take() {
47            shutdown::shutdown_runtime(tasks, false).await;
48        }
49    }
50}
51
52struct GatewayBootstrap {
53    config: Config,
54    loader: ConfigLoader,
55    port: u16,
56    bus: MessageBus,
57    cron_service: Arc<CronService>,
58    dynamic_provider: Arc<DynamicProvider>,
59    runtime_control_tx: mpsc::UnboundedSender<RuntimeControlCommand>,
60    provider_api_key: Option<String>,
61    provider_api_base: Option<String>,
62    agent: AgentLoop,
63    file_manager: Arc<FileManager>,
64}
65
66struct ChannelBootstrap {
67    channel_manager: Arc<ChannelManager>,
68    inbound_bridge_handle: JoinHandle<()>,
69}
70
71struct GatewayTasks {
72    bus: MessageBus,
73    cron_service: Arc<CronService>,
74    channel_manager: Arc<ChannelManager>,
75    server_shutdown_tx: broadcast::Sender<()>,
76    inbound_bridge_handle: JoinHandle<()>,
77    neuro_link_bridge_handle: Option<JoinHandle<()>>,
78    outbound_dispatch_handle: JoinHandle<()>,
79    channel_handle: JoinHandle<()>,
80    agent_handle: JoinHandle<()>,
81    manager_handle: JoinHandle<Result<()>>,
82    server_handle: JoinHandle<()>,
83    _api_tx_keepalive: mpsc::Sender<ManagerCommand>,
84}
85
86fn provider_registry() -> ProviderRegistry {
87    ProviderRegistry::new()
88}
89
90fn infer_provider_name_from_model(model: &str) -> Option<String> {
91    let registry = provider_registry();
92    model
93        .split('/')
94        .next()
95        .and_then(|prefix| registry.find_by_name(prefix))
96        .or_else(|| registry.find_by_model(model))
97        .map(|spec| spec.name.clone())
98}
99
100fn current_provider_name(config: &Config) -> Option<String> {
101    let preferred_provider = config
102        .agents
103        .defaults
104        .provider
105        .as_deref()
106        .map(str::trim)
107        .filter(|value| !value.is_empty())
108        .map(ToOwned::to_owned);
109    preferred_provider.or_else(|| infer_provider_name_from_model(&config.agents.defaults.model))
110}
111
112fn resolve_provider_name_for_model(
113    config: &Config,
114    model: &str,
115    preferred_provider: Option<&str>,
116) -> Option<String> {
117    let inferred_provider = infer_provider_name_from_model(model);
118    if let Some(provider_name) = preferred_provider
119        .map(str::trim)
120        .filter(|value| !value.is_empty())
121    {
122        let registry = provider_registry();
123        if registry.find_by_name(provider_name).is_some() {
124            return Some(provider_name.to_string());
125        }
126        if inferred_provider.as_deref() == Some(provider_name) {
127            return inferred_provider;
128        }
129    }
130
131    inferred_provider.or_else(|| {
132        (model == config.agents.defaults.model)
133            .then(|| current_provider_name(config))
134            .flatten()
135    })
136}
137
138fn build_provider(config: &Config, model: &str) -> Result<LiteLLMClient> {
139    let catalog = ProviderCatalogService::new();
140    let provider_name = resolve_provider_name_for_model(
141        config,
142        model,
143        (model == config.agents.defaults.model)
144            .then_some(config.agents.defaults.provider.as_deref())
145            .flatten(),
146    )
147    .ok_or_else(|| anyhow::anyhow!("No provider found for model: {}", model))?;
148    let access = catalog
149        .get_provider_access(config, &provider_name)
150        .unwrap_or_else(|| ProviderAccess::from_config(None));
151    let extra_headers = (!access.extra_headers.is_empty()).then(|| {
152        access
153            .extra_headers
154            .into_iter()
155            .collect::<std::collections::HashMap<String, String>>()
156    });
157
158    Ok(LiteLLMClient::new(
159        access.api_key,
160        access.api_base,
161        model.to_string(),
162        extra_headers,
163        Some(provider_name),
164        config.agents.defaults.reasoning_effort.clone(),
165    ))
166}
167
168fn build_network_tool_config(config: &Config) -> NetworkToolConfig {
169    let api_key = config.tools.web.search.api_key.trim().to_string();
170    NetworkToolConfig {
171        web: WebRuntimeConfig {
172            search: WebSearchRuntimeConfig {
173                provider: config.tools.web.search.provider.clone(),
174                enabled: config.tools.web.search.enabled,
175                api_key: if api_key.is_empty() {
176                    None
177                } else {
178                    Some(api_key)
179                },
180                max_results: config.tools.web.search.max_results,
181            },
182            fetch: WebFetchRuntimeConfig {
183                enabled: config.tools.web.fetch.enabled,
184            },
185        },
186    }
187}
188
189pub async fn run_local_gateway(runtime: GatewayRuntimeConfig) -> Result<()> {
190    let port = runtime.port;
191    let bootstrap = bootstrap::bootstrap_runtime(runtime).await?;
192    let channel_bootstrap =
193        bootstrap::bootstrap_channel_runtime(&bootstrap.config, bootstrap.bus.clone()).await;
194    let mut tasks = task_runtime::start_runtime_tasks(bootstrap, channel_bootstrap).await;
195    tracing::info!(
196        "Gateway ready; HTTP API at http://127.0.0.1:{} (Ctrl+C to stop)",
197        port
198    );
199    let manager_handle_completed = shutdown::wait_for_shutdown(&mut tasks).await;
200    shutdown::shutdown_runtime(tasks, manager_handle_completed).await;
201    Ok(())
202}
203
204pub async fn start_embedded_gateway_runtime(
205    runtime: GatewayRuntimeConfig,
206    listener: tokio::net::TcpListener,
207    shutdown_rx: watch::Receiver<bool>,
208) -> Result<EmbeddedGatewayRuntime> {
209    let bootstrap = bootstrap::bootstrap_runtime(runtime).await?;
210    let channel_bootstrap =
211        bootstrap::bootstrap_channel_runtime(&bootstrap.config, bootstrap.bus.clone()).await;
212    let tasks = task_runtime::start_embedded_runtime_tasks(
213        bootstrap,
214        channel_bootstrap,
215        listener,
216        shutdown_rx,
217    )
218    .await;
219    Ok(EmbeddedGatewayRuntime { tasks: Some(tasks) })
220}
221
222async fn start_cron_service(cron_store: PathBuf, bus: MessageBus) -> Arc<CronService> {
223    let cron_service = Arc::new(CronService::new(cron_store, Some(build_cron_callback(bus))));
224    cron_service.start().await;
225    cron_service
226}
227
228fn build_cron_callback(bus: MessageBus) -> JobCallback {
229    Arc::new(
230        move |job: agent_diva_core::cron::CronJob,
231              cancel_token|
232              -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send>> {
233            let bus = bus.clone();
234            Box::pin(async move {
235                if cancel_token.is_cancelled() {
236                    return Some("Error: cancelled".to_string());
237                }
238                let deliver = job.payload.deliver;
239                if !deliver {
240                    return Some("skipped (deliver=false)".to_string());
241                }
242
243                let target_channel = job
244                    .payload
245                    .channel
246                    .clone()
247                    .unwrap_or_else(|| "cli".to_string());
248                let target_chat_id = job
249                    .payload
250                    .to
251                    .clone()
252                    .unwrap_or_else(|| "direct".to_string());
253                let (conversation_channel, conversation_chat_id) = if target_channel == "gui" {
254                    let chat_id = if target_chat_id.starts_with("cron:") {
255                        target_chat_id
256                    } else {
257                        format!("cron:{}", target_chat_id)
258                    };
259                    ("api".to_string(), chat_id)
260                } else {
261                    (target_channel.clone(), target_chat_id)
262                };
263
264                let inbound = InboundMessage::new(
265                    conversation_channel,
266                    "cron",
267                    conversation_chat_id,
268                    job.payload.message,
269                )
270                .with_metadata("cron_job_id", job.id.clone())
271                .with_metadata("cron_trigger", "scheduled")
272                .with_metadata("cron_delivery_channel", target_channel);
273
274                if let Err(e) = bus.publish_inbound(inbound) {
275                    error!("Failed to publish cron inbound job {}: {}", job.id, e);
276                    return Some(format!(
277                        "failed to publish cron inbound job {}: {}",
278                        job.id, e
279                    ));
280                }
281
282                Some("triggered agent turn".to_string())
283            })
284        },
285    )
286}
287
288async fn build_agent_loop(
289    config: &Config,
290    bus: MessageBus,
291    dynamic_provider: Arc<DynamicProvider>,
292    workspace: PathBuf,
293    runtime_control_rx: mpsc::UnboundedReceiver<RuntimeControlCommand>,
294    cron_service: Arc<CronService>,
295    file_manager: Arc<FileManager>,
296) -> Result<AgentLoop> {
297    let agent_provider: Arc<dyn LLMProvider> = dynamic_provider;
298    let tool_config = ToolConfig {
299        network: build_network_tool_config(config),
300        exec_timeout: config.tools.exec.timeout,
301        restrict_to_workspace: config.tools.restrict_to_workspace,
302        mcp_servers: config.tools.active_mcp_servers(),
303        cron_service: Some(cron_service),
304        soul_context: SoulContextSettings {
305            enabled: config.agents.soul.enabled,
306            max_chars: config.agents.soul.max_chars,
307            bootstrap_once: config.agents.soul.bootstrap_once,
308        },
309        notify_on_soul_change: config.agents.soul.notify_on_change,
310        soul_governance: SoulGovernanceSettings {
311            frequent_change_window_secs: config.agents.soul.frequent_change_window_secs,
312            frequent_change_threshold: config.agents.soul.frequent_change_threshold,
313            boundary_confirmation_hint: config.agents.soul.boundary_confirmation_hint,
314        },
315    };
316
317    AgentLoop::with_tools(
318        bus,
319        agent_provider,
320        workspace,
321        Some(config.agents.defaults.model.clone()),
322        Some(config.agents.defaults.max_tool_iterations as usize),
323        tool_config,
324        Some(runtime_control_rx),
325        file_manager,
326    )
327    .await
328    .map_err(|e| anyhow::anyhow!("Failed to create agent loop: {}", e))
329}
330
331fn resolve_provider_credentials(config: &Config) -> Result<(Option<String>, Option<String>)> {
332    let provider_name = current_provider_name(config)
333        .ok_or_else(|| anyhow::anyhow!("No provider found for model"))?;
334    let catalog = ProviderCatalogService::new();
335    let access = catalog
336        .get_provider_access(config, &provider_name)
337        .unwrap_or_else(|| ProviderAccess::from_config(None));
338    let resolved_api_base = access.api_base.clone().or_else(|| {
339        catalog
340            .get_provider_view(config, &provider_name)
341            .and_then(|view| view.api_base)
342    });
343    Ok((access.api_key, resolved_api_base))
344}