mod bootstrap;
mod shutdown;
mod task_runtime;
use crate::state::ManagerCommand;
use agent_diva_agent::{
agent_loop::SoulGovernanceSettings, context::SoulContextSettings,
runtime_control::RuntimeControlCommand, tool_config::network::NetworkToolConfig,
tool_config::network::WebFetchRuntimeConfig, tool_config::network::WebRuntimeConfig,
tool_config::network::WebSearchRuntimeConfig, AgentLoop, ToolConfig,
};
use agent_diva_channels::ChannelManager;
use agent_diva_core::bus::{InboundMessage, MessageBus};
use agent_diva_core::config::{Config, ConfigLoader};
use agent_diva_core::cron::service::JobCallback;
use agent_diva_core::cron::CronService;
use agent_diva_files::{default_data_dir_or_fallback, FileConfig, FileManager};
use agent_diva_providers::{
DynamicProvider, LLMProvider, LiteLLMClient, ProviderAccess, ProviderCatalogService,
ProviderRegistry,
};
use anyhow::Result;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, watch};
use tokio::task::JoinHandle;
use tracing::error;
/// Default port number for the gateway (3000).
pub const DEFAULT_GATEWAY_PORT: u16 = 3000;
/// Inputs required to start a gateway runtime.
///
/// Passed by value to [`run_local_gateway`] and [`start_embedded_gateway_runtime`],
/// which hand it to `bootstrap::bootstrap_runtime`.
#[derive(Clone)]
pub struct GatewayRuntimeConfig {
/// Application configuration.
pub config: Config,
/// Config loader that accompanies `config`; how it is used is not visible in this file.
pub loader: ConfigLoader,
/// Workspace path (presumably the agent's working directory — confirm in `bootstrap`).
pub workspace: PathBuf,
/// Path for the cron store (presumably forwarded to `CronService::new` — confirm in `bootstrap`).
pub cron_store: PathBuf,
/// Port reported in the "Gateway ready" log line of [`run_local_gateway`].
pub port: u16,
}
/// Handle to a gateway runtime started by [`start_embedded_gateway_runtime`].
///
/// Call [`EmbeddedGatewayRuntime::shutdown`] to stop the runtime.
pub struct EmbeddedGatewayRuntime {
// Wrapped in `Option` so `shutdown` can move the tasks out with `take()`.
tasks: Option<GatewayTasks>,
}
impl EmbeddedGatewayRuntime {
/// Consumes the runtime handle and shuts the runtime down.
///
/// Takes the stored tasks (if any are still present) and passes them to
/// `shutdown::shutdown_runtime`. The second argument is `false`; in
/// [`run_local_gateway`] the same argument receives `manager_handle_completed`,
/// so here the manager handle is reported as not already completed.
pub async fn shutdown(mut self) {
if let Some(tasks) = self.tasks.take() {
shutdown::shutdown_runtime(tasks, false).await;
}
}
}
/// Bundle of configuration, shared services and the agent loop that is handed
/// to the task-start functions in `task_runtime`.
///
/// NOTE(review): presumably the value produced by `bootstrap::bootstrap_runtime`
/// (its result is read via `.config` and `.bus` in this file) — confirm in `bootstrap`.
struct GatewayBootstrap {
// Configuration and its loader.
config: Config,
loader: ConfigLoader,
// Gateway port.
port: u16,
// Message bus; cloned in this file when bootstrapping the channel runtime.
bus: MessageBus,
// Shared cron service.
cron_service: Arc<CronService>,
// Shared dynamic LLM provider.
dynamic_provider: Arc<DynamicProvider>,
// Sending half for runtime control commands; `build_agent_loop` accepts the
// matching receiver type.
runtime_control_tx: mpsc::UnboundedSender<RuntimeControlCommand>,
// Provider credentials; same shape as the pair returned by
// `resolve_provider_credentials` (presumably filled from it — confirm).
provider_api_key: Option<String>,
provider_api_base: Option<String>,
// The agent loop instance.
agent: AgentLoop,
// Shared file manager.
file_manager: Arc<FileManager>,
}
/// Channel-side state handed to the task-start functions in `task_runtime`.
///
/// NOTE(review): presumably the value produced by
/// `bootstrap::bootstrap_channel_runtime` — confirm in `bootstrap`.
struct ChannelBootstrap {
// Shared channel manager.
channel_manager: Arc<ChannelManager>,
// Join handle of the inbound bridge task (what the task does is not visible here).
inbound_bridge_handle: JoinHandle<()>,
}
/// Shared services and task join handles of a started gateway runtime.
///
/// Stored inside [`EmbeddedGatewayRuntime`] and consumed by
/// `shutdown::shutdown_runtime`.
struct GatewayTasks {
// Shared services.
bus: MessageBus,
cron_service: Arc<CronService>,
channel_manager: Arc<ChannelManager>,
// Broadcast sender of `()`; presumably used to signal the server task to stop — confirm in `shutdown`.
server_shutdown_tx: broadcast::Sender<()>,
// Join handles of the runtime's tasks.
inbound_bridge_handle: JoinHandle<()>,
// Optional: may be absent (when it is populated is not visible in this file).
neuro_link_bridge_handle: Option<JoinHandle<()>>,
outbound_dispatch_handle: JoinHandle<()>,
channel_handle: JoinHandle<()>,
agent_handle: JoinHandle<()>,
// The only handle whose task yields a `Result`.
manager_handle: JoinHandle<Result<()>>,
server_handle: JoinHandle<()>,
// NOTE(review): underscore-prefixed and not read in this file; presumably held
// only so the manager command channel stays open — confirm.
_api_tx_keepalive: mpsc::Sender<ManagerCommand>,
}
/// Creates a fresh [`ProviderRegistry`] via `ProviderRegistry::new()`.
///
/// A new registry is constructed on every call; nothing is cached here.
fn provider_registry() -> ProviderRegistry {
ProviderRegistry::new()
}
/// Infers a provider name from a model identifier.
///
/// The text before the first `/` (or the whole string when there is no `/`)
/// is looked up as a provider name in the registry. If that finds nothing,
/// the full model string is looked up with `find_by_model`. Returns the name
/// of the matching registry entry, or `None` when neither lookup succeeds.
fn infer_provider_name_from_model(model: &str) -> Option<String> {
    let registry = provider_registry();
    // Same value as `model.split('/').next()`, which is always `Some` for a `&str`.
    let prefix = model.split_once('/').map_or(model, |(head, _)| head);
    let matched_spec = registry
        .find_by_name(prefix)
        .or_else(|| registry.find_by_model(model));
    matched_spec.map(|spec| spec.name.clone())
}
/// Returns the provider name that applies to the configured default model.
///
/// An explicitly configured default provider wins when it is non-blank after
/// trimming; otherwise the provider is inferred from the default model name.
fn current_provider_name(config: &Config) -> Option<String> {
    let defaults = &config.agents.defaults;
    match defaults.provider.as_deref().map(str::trim) {
        Some(name) if !name.is_empty() => Some(name.to_owned()),
        // Missing or blank provider setting: fall back to model-based inference.
        _ => infer_provider_name_from_model(&defaults.model),
    }
}
/// Resolves the provider name to use for `model`.
///
/// Resolution order:
/// 1. `preferred_provider` (trimmed, non-blank) when the registry knows it by name.
/// 2. `preferred_provider` when it equals the provider inferred from `model`.
/// 3. The provider inferred from `model`.
/// 4. When `model` is the configured default model, whatever
///    [`current_provider_name`] yields.
///
/// Returns `None` when none of the steps produces a name.
fn resolve_provider_name_for_model(
    config: &Config,
    model: &str,
    preferred_provider: Option<&str>,
) -> Option<String> {
    let inferred = infer_provider_name_from_model(model);
    let preferred = preferred_provider
        .map(str::trim)
        .filter(|name| !name.is_empty());

    if let Some(name) = preferred {
        if provider_registry().find_by_name(name).is_some() {
            return Some(name.to_string());
        }
        if inferred.as_deref() == Some(name) {
            return inferred;
        }
    }

    if inferred.is_some() {
        return inferred;
    }

    // Last resort: only the default model may borrow the configured provider.
    if model == config.agents.defaults.model {
        current_provider_name(config)
    } else {
        None
    }
}
/// Builds a [`LiteLLMClient`] for `model`.
///
/// The configured default provider is only offered as a preference when
/// `model` is the configured default model. Access details (API key, API base,
/// extra headers) come from the provider catalog; when the catalog has none,
/// `ProviderAccess::from_config(None)` is used instead.
///
/// # Errors
///
/// Fails when no provider can be resolved for `model`.
fn build_provider(config: &Config, model: &str) -> Result<LiteLLMClient> {
    let catalog = ProviderCatalogService::new();
    let defaults = &config.agents.defaults;

    let preferred_provider = if model == defaults.model {
        defaults.provider.as_deref()
    } else {
        None
    };
    let provider_name = resolve_provider_name_for_model(config, model, preferred_provider)
        .ok_or_else(|| anyhow::anyhow!("No provider found for model: {}", model))?;

    let access = catalog
        .get_provider_access(config, &provider_name)
        .unwrap_or_else(|| ProviderAccess::from_config(None));

    // An empty header collection is passed on as `None`.
    let extra_headers = if access.extra_headers.is_empty() {
        None
    } else {
        Some(
            access
                .extra_headers
                .into_iter()
                .collect::<std::collections::HashMap<String, String>>(),
        )
    };

    Ok(LiteLLMClient::new(
        access.api_key,
        access.api_base,
        model.to_string(),
        extra_headers,
        Some(provider_name),
        defaults.reasoning_effort.clone(),
    ))
}
/// Translates the web tool settings in `config` into a [`NetworkToolConfig`].
///
/// The search API key is trimmed; a key that is empty after trimming becomes
/// `None`. All other values are copied from the configuration unchanged.
fn build_network_tool_config(config: &Config) -> NetworkToolConfig {
    let web = &config.tools.web;
    let trimmed_key = web.search.api_key.trim();

    let search = WebSearchRuntimeConfig {
        provider: web.search.provider.clone(),
        enabled: web.search.enabled,
        api_key: (!trimmed_key.is_empty()).then(|| trimmed_key.to_string()),
        max_results: web.search.max_results,
    };
    let fetch = WebFetchRuntimeConfig {
        enabled: web.fetch.enabled,
    };

    NetworkToolConfig {
        web: WebRuntimeConfig { search, fetch },
    }
}
/// Runs the gateway in the foreground until it is shut down.
///
/// Bootstraps the runtime and the channel runtime, starts the runtime tasks,
/// logs the local HTTP API address, waits in `shutdown::wait_for_shutdown`,
/// and finally runs `shutdown::shutdown_runtime`.
///
/// # Errors
///
/// Propagates any error from `bootstrap::bootstrap_runtime`.
pub async fn run_local_gateway(runtime: GatewayRuntimeConfig) -> Result<()> {
    // Read the port first: `runtime` is moved into the bootstrap call below.
    let port = runtime.port;

    let gateway = bootstrap::bootstrap_runtime(runtime).await?;
    let channels =
        bootstrap::bootstrap_channel_runtime(&gateway.config, gateway.bus.clone()).await;
    let mut tasks = task_runtime::start_runtime_tasks(gateway, channels).await;

    tracing::info!(
        "Gateway ready; HTTP API at http://127.0.0.1:{} (Ctrl+C to stop)",
        port
    );

    // The wait result is forwarded unchanged to the shutdown routine.
    let manager_handle_completed = shutdown::wait_for_shutdown(&mut tasks).await;
    shutdown::shutdown_runtime(tasks, manager_handle_completed).await;
    Ok(())
}
/// Starts the gateway runtime for embedding in a host process.
///
/// Unlike [`run_local_gateway`], this returns once the tasks are started. The
/// caller supplies the already-bound `listener` and a `shutdown_rx` watch
/// channel, both of which are passed to
/// `task_runtime::start_embedded_runtime_tasks`. The returned
/// [`EmbeddedGatewayRuntime`] owns the started tasks.
///
/// # Errors
///
/// Propagates any error from `bootstrap::bootstrap_runtime`.
pub async fn start_embedded_gateway_runtime(
    runtime: GatewayRuntimeConfig,
    listener: tokio::net::TcpListener,
    shutdown_rx: watch::Receiver<bool>,
) -> Result<EmbeddedGatewayRuntime> {
    let gateway = bootstrap::bootstrap_runtime(runtime).await?;
    let channels =
        bootstrap::bootstrap_channel_runtime(&gateway.config, gateway.bus.clone()).await;
    let started_tasks =
        task_runtime::start_embedded_runtime_tasks(gateway, channels, listener, shutdown_rx)
            .await;
    Ok(EmbeddedGatewayRuntime {
        tasks: Some(started_tasks),
    })
}
/// Creates a [`CronService`] over `cron_store`, starts it, and returns it.
///
/// The service is given the callback from [`build_cron_callback`], which
/// publishes to `bus`.
async fn start_cron_service(cron_store: PathBuf, bus: MessageBus) -> Arc<CronService> {
    let on_job = build_cron_callback(bus);
    let service = Arc::new(CronService::new(cron_store, Some(on_job)));
    service.start().await;
    service
}
/// Builds the callback the cron service invokes for each job.
///
/// For every invocation the callback returns a status string:
/// - `"Error: cancelled"` when the cancel token is already cancelled;
/// - `"skipped (deliver=false)"` when the job payload has `deliver == false`;
/// - otherwise it publishes an inbound message from sender `"cron"` on `bus`
///   and returns `"triggered agent turn"`, or a failure description when
///   publishing fails.
///
/// Routing: the delivery channel defaults to `"cli"` and the chat id to
/// `"direct"`. Jobs targeting `"gui"` are published on the `"api"` channel
/// with a chat id carrying a `cron:` prefix (added only when missing). The
/// originally requested channel is recorded in the `cron_delivery_channel`
/// metadata entry.
fn build_cron_callback(bus: MessageBus) -> JobCallback {
    type CronFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Option<String>> + Send>>;

    Arc::new(
        move |job: agent_diva_core::cron::CronJob, cancel_token| -> CronFuture {
            let bus = bus.clone();
            Box::pin(async move {
                if cancel_token.is_cancelled() {
                    return Some("Error: cancelled".to_string());
                }
                if !job.payload.deliver {
                    return Some("skipped (deliver=false)".to_string());
                }

                let delivery_channel = job.payload.channel.as_deref().unwrap_or("cli").to_string();
                let requested_chat_id = job.payload.to.as_deref().unwrap_or("direct").to_string();

                // "gui" deliveries are carried on the "api" channel under a
                // "cron:"-prefixed chat id; every other channel is used as-is.
                let (conversation_channel, conversation_chat_id) = match delivery_channel.as_str() {
                    "gui" if requested_chat_id.starts_with("cron:") => {
                        ("api".to_string(), requested_chat_id)
                    }
                    "gui" => ("api".to_string(), format!("cron:{}", requested_chat_id)),
                    _ => (delivery_channel.clone(), requested_chat_id),
                };

                let inbound = InboundMessage::new(
                    conversation_channel,
                    "cron",
                    conversation_chat_id,
                    job.payload.message,
                )
                .with_metadata("cron_job_id", job.id.clone())
                .with_metadata("cron_trigger", "scheduled")
                .with_metadata("cron_delivery_channel", delivery_channel);

                match bus.publish_inbound(inbound) {
                    Ok(_) => Some("triggered agent turn".to_string()),
                    Err(e) => {
                        error!("Failed to publish cron inbound job {}: {}", job.id, e);
                        Some(format!(
                            "failed to publish cron inbound job {}: {}",
                            job.id, e
                        ))
                    }
                }
            })
        },
    )
}
/// Assembles the [`ToolConfig`] from `config` and constructs the [`AgentLoop`].
///
/// The loop receives the message bus, the dynamic provider (as a
/// `dyn LLMProvider`), the workspace path, the configured default model and
/// tool-iteration limit, the runtime-control receiver, and the file manager.
///
/// # Errors
///
/// Any failure from `AgentLoop::with_tools` is wrapped as
/// "Failed to create agent loop: ...".
async fn build_agent_loop(
    config: &Config,
    bus: MessageBus,
    dynamic_provider: Arc<DynamicProvider>,
    workspace: PathBuf,
    runtime_control_rx: mpsc::UnboundedReceiver<RuntimeControlCommand>,
    cron_service: Arc<CronService>,
    file_manager: Arc<FileManager>,
) -> Result<AgentLoop> {
    let defaults = &config.agents.defaults;
    let soul = &config.agents.soul;
    let tools = &config.tools;

    let soul_context = SoulContextSettings {
        enabled: soul.enabled,
        max_chars: soul.max_chars,
        bootstrap_once: soul.bootstrap_once,
    };
    let soul_governance = SoulGovernanceSettings {
        frequent_change_window_secs: soul.frequent_change_window_secs,
        frequent_change_threshold: soul.frequent_change_threshold,
        boundary_confirmation_hint: soul.boundary_confirmation_hint,
    };
    let tool_config = ToolConfig {
        network: build_network_tool_config(config),
        exec_timeout: tools.exec.timeout,
        restrict_to_workspace: tools.restrict_to_workspace,
        mcp_servers: tools.active_mcp_servers(),
        cron_service: Some(cron_service),
        soul_context,
        notify_on_soul_change: soul.notify_on_change,
        soul_governance,
    };

    // Coerce the concrete provider into the trait object before handing it over.
    let provider: Arc<dyn LLMProvider> = dynamic_provider;
    let default_model = defaults.model.clone();
    let max_tool_iterations = defaults.max_tool_iterations as usize;

    AgentLoop::with_tools(
        bus,
        provider,
        workspace,
        Some(default_model),
        Some(max_tool_iterations),
        tool_config,
        Some(runtime_control_rx),
        file_manager,
    )
    .await
    .map_err(|e| anyhow::anyhow!("Failed to create agent loop: {}", e))
}
/// Resolves the API key and API base for the currently configured provider.
///
/// The provider is chosen by [`current_provider_name`]. Access details come
/// from the provider catalog; when the catalog has none,
/// `ProviderAccess::from_config(None)` is used. If the access entry carries no
/// API base, the API base of the catalog's provider view is used instead.
///
/// Returns `(api_key, api_base)`; either element may be `None`.
///
/// # Errors
///
/// Fails when no provider can be determined from the configured default
/// provider or default model. The message names the default model, matching
/// the wording used by `build_provider`.
fn resolve_provider_credentials(config: &Config) -> Result<(Option<String>, Option<String>)> {
    let provider_name = current_provider_name(config).ok_or_else(|| {
        anyhow::anyhow!(
            "No provider found for model: {}",
            config.agents.defaults.model
        )
    })?;
    let catalog = ProviderCatalogService::new();
    let access = catalog
        .get_provider_access(config, &provider_name)
        .unwrap_or_else(|| ProviderAccess::from_config(None));
    // `access` is not needed afterwards, so its fields are moved out rather
    // than cloned.
    let resolved_api_base = access.api_base.or_else(|| {
        catalog
            .get_provider_view(config, &provider_name)
            .and_then(|view| view.api_base)
    });
    Ok((access.api_key, resolved_api_base))
}