1mod bootstrap;
2mod shutdown;
3mod task_runtime;
4
5use crate::state::ManagerCommand;
6use agent_diva_agent::{
7 agent_loop::SoulGovernanceSettings, context::SoulContextSettings,
8 runtime_control::RuntimeControlCommand, tool_config::network::NetworkToolConfig,
9 tool_config::network::WebFetchRuntimeConfig, tool_config::network::WebRuntimeConfig,
10 tool_config::network::WebSearchRuntimeConfig, AgentLoop, ToolConfig,
11};
12use agent_diva_channels::ChannelManager;
13use agent_diva_core::bus::{InboundMessage, MessageBus};
14use agent_diva_core::config::{Config, ConfigLoader};
15use agent_diva_core::cron::service::JobCallback;
16use agent_diva_core::cron::CronService;
17use agent_diva_files::{default_data_dir_or_fallback, FileConfig, FileManager};
18use agent_diva_providers::{
19 DynamicProvider, LLMProvider, LiteLLMClient, ProviderAccess, ProviderCatalogService,
20 ProviderRegistry,
21};
22use anyhow::Result;
23use std::path::PathBuf;
24use std::sync::Arc;
25use tokio::sync::{broadcast, mpsc, watch};
26use tokio::task::JoinHandle;
27use tracing::error;
28
/// Default port for the gateway's HTTP API.
///
/// NOTE(review): not referenced in this part of the file; presumably used by callers
/// that build a `GatewayRuntimeConfig` without an explicit port — confirm.
pub const DEFAULT_GATEWAY_PORT: u16 = 3000;
30
/// Inputs needed to bootstrap a gateway runtime.
///
/// Consumed by [`run_local_gateway`] and [`start_embedded_gateway_runtime`], which
/// both hand it to `bootstrap::bootstrap_runtime`.
#[derive(Clone)]
pub struct GatewayRuntimeConfig {
    /// Application configuration.
    pub config: Config,
    /// Config loader handed to the bootstrap step (NOTE(review): presumably kept so the
    /// runtime can reload or persist configuration — confirm in `bootstrap`).
    pub loader: ConfigLoader,
    /// Workspace directory (assumed to become the agent loop's `workspace` — confirm
    /// in `bootstrap`).
    pub workspace: PathBuf,
    /// Cron job storage location (assumed to be passed to `start_cron_service` —
    /// confirm in `bootstrap`).
    pub cron_store: PathBuf,
    /// HTTP API port; [`run_local_gateway`] reports it in its "Gateway ready" log line.
    pub port: u16,
}
39
/// Handle to a gateway started by [`start_embedded_gateway_runtime`].
///
/// Call [`EmbeddedGatewayRuntime::shutdown`] to stop the runtime's tasks.
/// NOTE(review): no `Drop` impl is visible here. If this handle is simply dropped, the
/// contained tokio `JoinHandle`s are dropped too, which detaches (does not cancel)
/// their tasks — confirm that is intended.
pub struct EmbeddedGatewayRuntime {
    /// Running task set; `Some` from construction until `shutdown` takes it.
    tasks: Option<GatewayTasks>,
}
43
44impl EmbeddedGatewayRuntime {
45 pub async fn shutdown(mut self) {
46 if let Some(tasks) = self.tasks.take() {
47 shutdown::shutdown_runtime(tasks, false).await;
48 }
49 }
50}
51
/// Core runtime state produced by `bootstrap::bootstrap_runtime`.
///
/// Its `config` and `bus` feed `bootstrap::bootstrap_channel_runtime`; the whole value is
/// then handed to the `task_runtime` start functions.
struct GatewayBootstrap {
    config: Config,
    loader: ConfigLoader,
    /// Port taken from the runtime config (assumed — confirm in `bootstrap`).
    port: u16,
    /// Message bus shared with the channel layer (cloned for `bootstrap_channel_runtime`).
    bus: MessageBus,
    cron_service: Arc<CronService>,
    dynamic_provider: Arc<DynamicProvider>,
    /// Sending half of the runtime-control channel (NOTE(review): presumably the
    /// receiver was given to the agent loop via `build_agent_loop` — confirm).
    runtime_control_tx: mpsc::UnboundedSender<RuntimeControlCommand>,
    /// Provider credentials resolved at bootstrap (assumed to come from
    /// `resolve_provider_credentials` — confirm in `bootstrap`).
    provider_api_key: Option<String>,
    provider_api_base: Option<String>,
    agent: AgentLoop,
    file_manager: Arc<FileManager>,
}
65
/// Channel-layer state produced by `bootstrap::bootstrap_channel_runtime`.
struct ChannelBootstrap {
    channel_manager: Arc<ChannelManager>,
    /// Handle of the inbound bridge task (presumably moved into
    /// `GatewayTasks::inbound_bridge_handle` — confirm in `task_runtime`).
    inbound_bridge_handle: JoinHandle<()>,
}
70
/// Everything needed to stop a running gateway.
///
/// Built by the `task_runtime` start functions, borrowed mutably by
/// `shutdown::wait_for_shutdown`, and consumed by `shutdown::shutdown_runtime`.
///
/// If this value is dropped without being taken apart, its fields drop in declaration
/// order; dropping a tokio `JoinHandle` detaches its task rather than cancelling it.
struct GatewayTasks {
    bus: MessageBus,
    cron_service: Arc<CronService>,
    channel_manager: Arc<ChannelManager>,
    /// Broadcast sender for stopping the HTTP server (assumed from the name — confirm
    /// in `shutdown`).
    server_shutdown_tx: broadcast::Sender<()>,
    inbound_bridge_handle: JoinHandle<()>,
    /// `None` when no neuro-link bridge task was spawned (presumably optional — confirm
    /// in `task_runtime`).
    neuro_link_bridge_handle: Option<JoinHandle<()>>,
    outbound_dispatch_handle: JoinHandle<()>,
    channel_handle: JoinHandle<()>,
    agent_handle: JoinHandle<()>,
    /// The manager task is the only one whose output is a `Result`.
    manager_handle: JoinHandle<Result<()>>,
    server_handle: JoinHandle<()>,
    /// Not used directly in this file; held so a `ManagerCommand` sender stays alive
    /// with the tasks (NOTE(review): presumably keeps the manager's command channel
    /// open — confirm).
    _api_tx_keepalive: mpsc::Sender<ManagerCommand>,
}
85
/// Returns a new [`ProviderRegistry`].
///
/// A fresh registry is built on every call; nothing is cached between calls. All
/// registry lookups in this module go through this single constructor.
fn provider_registry() -> ProviderRegistry {
    ProviderRegistry::new()
}
89
90fn infer_provider_name_from_model(model: &str) -> Option<String> {
91 let registry = provider_registry();
92 model
93 .split('/')
94 .next()
95 .and_then(|prefix| registry.find_by_name(prefix))
96 .or_else(|| registry.find_by_model(model))
97 .map(|spec| spec.name.clone())
98}
99
100fn current_provider_name(config: &Config) -> Option<String> {
101 let preferred_provider = config
102 .agents
103 .defaults
104 .provider
105 .as_deref()
106 .map(str::trim)
107 .filter(|value| !value.is_empty())
108 .map(ToOwned::to_owned);
109 preferred_provider.or_else(|| infer_provider_name_from_model(&config.agents.defaults.model))
110}
111
112fn resolve_provider_name_for_model(
113 config: &Config,
114 model: &str,
115 preferred_provider: Option<&str>,
116) -> Option<String> {
117 let inferred_provider = infer_provider_name_from_model(model);
118 if let Some(provider_name) = preferred_provider
119 .map(str::trim)
120 .filter(|value| !value.is_empty())
121 {
122 let registry = provider_registry();
123 if registry.find_by_name(provider_name).is_some() {
124 return Some(provider_name.to_string());
125 }
126 if inferred_provider.as_deref() == Some(provider_name) {
127 return inferred_provider;
128 }
129 }
130
131 inferred_provider.or_else(|| {
132 (model == config.agents.defaults.model)
133 .then(|| current_provider_name(config))
134 .flatten()
135 })
136}
137
138fn build_provider(config: &Config, model: &str) -> Result<LiteLLMClient> {
139 let catalog = ProviderCatalogService::new();
140 let provider_name = resolve_provider_name_for_model(
141 config,
142 model,
143 (model == config.agents.defaults.model)
144 .then_some(config.agents.defaults.provider.as_deref())
145 .flatten(),
146 )
147 .ok_or_else(|| anyhow::anyhow!("No provider found for model: {}", model))?;
148 let access = catalog
149 .get_provider_access(config, &provider_name)
150 .unwrap_or_else(|| ProviderAccess::from_config(None));
151 let extra_headers = (!access.extra_headers.is_empty()).then(|| {
152 access
153 .extra_headers
154 .into_iter()
155 .collect::<std::collections::HashMap<String, String>>()
156 });
157
158 Ok(LiteLLMClient::new(
159 access.api_key,
160 access.api_base,
161 model.to_string(),
162 extra_headers,
163 Some(provider_name),
164 config.agents.defaults.reasoning_effort.clone(),
165 ))
166}
167
168fn build_network_tool_config(config: &Config) -> NetworkToolConfig {
169 let api_key = config.tools.web.search.api_key.trim().to_string();
170 NetworkToolConfig {
171 web: WebRuntimeConfig {
172 search: WebSearchRuntimeConfig {
173 provider: config.tools.web.search.provider.clone(),
174 enabled: config.tools.web.search.enabled,
175 api_key: if api_key.is_empty() {
176 None
177 } else {
178 Some(api_key)
179 },
180 max_results: config.tools.web.search.max_results,
181 },
182 fetch: WebFetchRuntimeConfig {
183 enabled: config.tools.web.fetch.enabled,
184 },
185 },
186 }
187}
188
189pub async fn run_local_gateway(runtime: GatewayRuntimeConfig) -> Result<()> {
190 let port = runtime.port;
191 let bootstrap = bootstrap::bootstrap_runtime(runtime).await?;
192 let channel_bootstrap =
193 bootstrap::bootstrap_channel_runtime(&bootstrap.config, bootstrap.bus.clone()).await;
194 let mut tasks = task_runtime::start_runtime_tasks(bootstrap, channel_bootstrap).await;
195 tracing::info!(
196 "Gateway ready; HTTP API at http://127.0.0.1:{} (Ctrl+C to stop)",
197 port
198 );
199 let manager_handle_completed = shutdown::wait_for_shutdown(&mut tasks).await;
200 shutdown::shutdown_runtime(tasks, manager_handle_completed).await;
201 Ok(())
202}
203
204pub async fn start_embedded_gateway_runtime(
205 runtime: GatewayRuntimeConfig,
206 listener: tokio::net::TcpListener,
207 shutdown_rx: watch::Receiver<bool>,
208) -> Result<EmbeddedGatewayRuntime> {
209 let bootstrap = bootstrap::bootstrap_runtime(runtime).await?;
210 let channel_bootstrap =
211 bootstrap::bootstrap_channel_runtime(&bootstrap.config, bootstrap.bus.clone()).await;
212 let tasks = task_runtime::start_embedded_runtime_tasks(
213 bootstrap,
214 channel_bootstrap,
215 listener,
216 shutdown_rx,
217 )
218 .await;
219 Ok(EmbeddedGatewayRuntime { tasks: Some(tasks) })
220}
221
222async fn start_cron_service(cron_store: PathBuf, bus: MessageBus) -> Arc<CronService> {
223 let cron_service = Arc::new(CronService::new(cron_store, Some(build_cron_callback(bus))));
224 cron_service.start().await;
225 cron_service
226}
227
228fn build_cron_callback(bus: MessageBus) -> JobCallback {
229 Arc::new(
230 move |job: agent_diva_core::cron::CronJob,
231 cancel_token|
232 -> std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send>> {
233 let bus = bus.clone();
234 Box::pin(async move {
235 if cancel_token.is_cancelled() {
236 return Some("Error: cancelled".to_string());
237 }
238 let deliver = job.payload.deliver;
239 if !deliver {
240 return Some("skipped (deliver=false)".to_string());
241 }
242
243 let target_channel = job
244 .payload
245 .channel
246 .clone()
247 .unwrap_or_else(|| "cli".to_string());
248 let target_chat_id = job
249 .payload
250 .to
251 .clone()
252 .unwrap_or_else(|| "direct".to_string());
253 let (conversation_channel, conversation_chat_id) = if target_channel == "gui" {
254 let chat_id = if target_chat_id.starts_with("cron:") {
255 target_chat_id
256 } else {
257 format!("cron:{}", target_chat_id)
258 };
259 ("api".to_string(), chat_id)
260 } else {
261 (target_channel.clone(), target_chat_id)
262 };
263
264 let inbound = InboundMessage::new(
265 conversation_channel,
266 "cron",
267 conversation_chat_id,
268 job.payload.message,
269 )
270 .with_metadata("cron_job_id", job.id.clone())
271 .with_metadata("cron_trigger", "scheduled")
272 .with_metadata("cron_delivery_channel", target_channel);
273
274 if let Err(e) = bus.publish_inbound(inbound) {
275 error!("Failed to publish cron inbound job {}: {}", job.id, e);
276 return Some(format!(
277 "failed to publish cron inbound job {}: {}",
278 job.id, e
279 ));
280 }
281
282 Some("triggered agent turn".to_string())
283 })
284 },
285 )
286}
287
288async fn build_agent_loop(
289 config: &Config,
290 bus: MessageBus,
291 dynamic_provider: Arc<DynamicProvider>,
292 workspace: PathBuf,
293 runtime_control_rx: mpsc::UnboundedReceiver<RuntimeControlCommand>,
294 cron_service: Arc<CronService>,
295 file_manager: Arc<FileManager>,
296) -> Result<AgentLoop> {
297 let agent_provider: Arc<dyn LLMProvider> = dynamic_provider;
298 let tool_config = ToolConfig {
299 network: build_network_tool_config(config),
300 exec_timeout: config.tools.exec.timeout,
301 restrict_to_workspace: config.tools.restrict_to_workspace,
302 mcp_servers: config.tools.active_mcp_servers(),
303 cron_service: Some(cron_service),
304 soul_context: SoulContextSettings {
305 enabled: config.agents.soul.enabled,
306 max_chars: config.agents.soul.max_chars,
307 bootstrap_once: config.agents.soul.bootstrap_once,
308 },
309 notify_on_soul_change: config.agents.soul.notify_on_change,
310 soul_governance: SoulGovernanceSettings {
311 frequent_change_window_secs: config.agents.soul.frequent_change_window_secs,
312 frequent_change_threshold: config.agents.soul.frequent_change_threshold,
313 boundary_confirmation_hint: config.agents.soul.boundary_confirmation_hint,
314 },
315 };
316
317 AgentLoop::with_tools(
318 bus,
319 agent_provider,
320 workspace,
321 Some(config.agents.defaults.model.clone()),
322 Some(config.agents.defaults.max_tool_iterations as usize),
323 tool_config,
324 Some(runtime_control_rx),
325 file_manager,
326 )
327 .await
328 .map_err(|e| anyhow::anyhow!("Failed to create agent loop: {}", e))
329}
330
331fn resolve_provider_credentials(config: &Config) -> Result<(Option<String>, Option<String>)> {
332 let provider_name = current_provider_name(config)
333 .ok_or_else(|| anyhow::anyhow!("No provider found for model"))?;
334 let catalog = ProviderCatalogService::new();
335 let access = catalog
336 .get_provider_access(config, &provider_name)
337 .unwrap_or_else(|| ProviderAccess::from_config(None));
338 let resolved_api_base = access.api_base.clone().or_else(|| {
339 catalog
340 .get_provider_view(config, &provider_name)
341 .and_then(|view| view.api_base)
342 });
343 Ok((access.api_key, resolved_api_base))
344}