1mod companion_admin;
2mod provider_admin;
3mod runtime_control;
4
5use agent_diva_agent::runtime_control::RuntimeControlCommand;
6use agent_diva_agent::tool_config::network::{
7 NetworkToolConfig, WebFetchRuntimeConfig, WebRuntimeConfig, WebSearchRuntimeConfig,
8};
9use agent_diva_channels::ChannelManager;
10use agent_diva_core::bus::MessageBus;
11use agent_diva_core::config::{ConfigLoader, CustomProviderConfig};
12use agent_diva_core::cron::CronService;
13use agent_diva_files::FileManager;
14use agent_diva_providers::{DynamicProvider, ProviderCatalogService, ProviderRegistry};
15use std::sync::Arc;
16use tokio::sync::mpsc;
17use tracing::{debug, error, info};
18
19use crate::state::{ManagerCommand, ProviderCommand};
20
21pub struct Manager {
22 api_rx: mpsc::Receiver<ManagerCommand>,
23 bus: MessageBus,
24 provider: Arc<DynamicProvider>,
25 loader: ConfigLoader,
26 current_provider: Option<String>,
28 current_model: String,
29 current_api_base: Option<String>,
30 current_api_key: Option<String>,
31 channel_manager: Option<Arc<ChannelManager>>,
32 runtime_control_tx: Option<mpsc::UnboundedSender<RuntimeControlCommand>>,
33 cron_service: Arc<CronService>,
34 file_manager: Arc<FileManager>,
35}
36
37enum ProviderConfigTarget<'a> {
38 Builtin(&'a mut agent_diva_core::config::schema::ProviderConfig),
39 Shadow(&'a mut CustomProviderConfig),
40}
41
42impl ProviderConfigTarget<'_> {
43 fn set_api_key(&mut self, api_key: String) {
44 match self {
45 Self::Builtin(config) => config.api_key = api_key,
46 Self::Shadow(config) => config.api_key = api_key,
47 }
48 }
49
50 fn set_api_base(&mut self, api_base: Option<String>) {
51 match self {
52 Self::Builtin(config) => config.api_base = api_base,
53 Self::Shadow(config) => config.api_base = api_base,
54 }
55 }
56}
57
58impl Manager {
59 #[allow(clippy::too_many_arguments)]
60 pub fn new(
61 api_rx: mpsc::Receiver<ManagerCommand>,
62 bus: MessageBus,
63 provider: Arc<DynamicProvider>,
64 loader: ConfigLoader,
65 initial_provider: Option<String>,
66 initial_model: String,
67 api_key: Option<String>,
68 api_base: Option<String>,
69 channel_manager: Option<Arc<ChannelManager>>,
70 runtime_control_tx: Option<mpsc::UnboundedSender<RuntimeControlCommand>>,
71 cron_service: Arc<CronService>,
72 file_manager: Arc<FileManager>,
73 ) -> Self {
74 Self {
75 api_rx,
76 bus,
77 provider,
78 loader,
79 current_provider: initial_provider
80 .or_else(|| Self::provider_name_for_model(None, &initial_model)),
81 current_model: initial_model,
82 current_api_base: api_base,
83 current_api_key: api_key,
84 channel_manager,
85 runtime_control_tx,
86 cron_service,
87 file_manager,
88 }
89 }
90
91 fn provider_name_for_model(preferred_provider: Option<&str>, model: &str) -> Option<String> {
92 let registry = ProviderRegistry::new();
93 preferred_provider
94 .map(str::trim)
95 .filter(|value| !value.is_empty())
96 .and_then(|name| registry.find_by_name(name))
97 .map(|spec| spec.name.clone())
98 .or_else(|| {
99 model
100 .split('/')
101 .next()
102 .and_then(|prefix| registry.find_by_name(prefix))
103 .map(|spec| spec.name.clone())
104 })
105 .or_else(|| registry.find_by_model(model).map(|spec| spec.name.clone()))
106 }
107
108 fn map_network_config(config: &agent_diva_core::config::schema::Config) -> NetworkToolConfig {
109 let api_key = config.tools.web.search.api_key.trim().to_string();
110 NetworkToolConfig {
111 web: WebRuntimeConfig {
112 search: WebSearchRuntimeConfig {
113 provider: config.tools.web.search.provider.clone(),
114 enabled: config.tools.web.search.enabled,
115 api_key: if api_key.is_empty() {
116 None
117 } else {
118 Some(api_key)
119 },
120 max_results: config.tools.web.search.max_results,
121 },
122 fetch: WebFetchRuntimeConfig {
123 enabled: config.tools.web.fetch.enabled,
124 },
125 },
126 }
127 }
128
129 fn reload_runtime_mcp(&self) {
130 let Some(tx) = &self.runtime_control_tx else {
131 return;
132 };
133 let Ok(config) = self.loader.load() else {
134 error!("Failed to load config for MCP runtime update");
135 return;
136 };
137 if let Err(e) = tx.send(RuntimeControlCommand::UpdateMcp {
138 servers: config.tools.active_mcp_servers(),
139 }) {
140 error!("Failed to send runtime MCP update: {}", e);
141 }
142 }
143
144 fn model_matches_provider(provider_id: &str, model: &str) -> bool {
145 let trimmed_provider = provider_id.trim();
146 let trimmed_model = model.trim();
147 if trimmed_provider.is_empty() || trimmed_model.is_empty() {
148 return false;
149 }
150
151 if trimmed_model
152 .split('/')
153 .next()
154 .is_some_and(|prefix| prefix == trimmed_provider)
155 {
156 return true;
157 }
158
159 ProviderRegistry::new()
160 .find_by_model(trimmed_model)
161 .is_some_and(|spec| spec.name == trimmed_provider)
162 }
163
164 async fn normalize_model_for_provider(
165 config: &agent_diva_core::config::schema::Config,
166 catalog: &ProviderCatalogService,
167 provider_id: &str,
168 requested_model: &str,
169 provider_explicit: bool,
170 model_explicit: bool,
171 ) -> String {
172 let requested_model = requested_model.trim();
173 let provider_models = catalog
174 .list_provider_models(config, provider_id, false, None)
175 .await
176 .ok();
177
178 if !requested_model.is_empty() {
179 if provider_explicit && model_explicit {
180 return requested_model.to_string();
181 }
182
183 let in_catalog = provider_models.as_ref().is_some_and(|catalog| {
184 catalog
185 .models
186 .iter()
187 .any(|entry| entry.id == requested_model)
188 });
189 if in_catalog || Self::model_matches_provider(provider_id, requested_model) {
190 return requested_model.to_string();
191 }
192 }
193
194 if let Some(default_model) = catalog
195 .get_provider_view(config, provider_id)
196 .and_then(|view| view.default_model)
197 .map(|value| value.trim().to_string())
198 .filter(|value| !value.is_empty())
199 {
200 return default_model;
201 }
202
203 if let Some(first_model) = provider_models
204 .and_then(|catalog| catalog.models.into_iter().next().map(|entry| entry.id))
205 .map(|value| value.trim().to_string())
206 .filter(|value| !value.is_empty())
207 {
208 return first_model;
209 }
210
211 requested_model.to_string()
212 }
213
214 fn ensure_provider_credentials_slot<'a>(
215 config: &'a mut agent_diva_core::config::schema::Config,
216 provider_id: &str,
217 ) -> ProviderConfigTarget<'a> {
218 if agent_diva_core::config::schema::ProvidersConfig::is_builtin_provider(provider_id) {
219 let provider = config
220 .providers
221 .get_mut(provider_id)
222 .expect("builtin provider slot must exist");
223 return ProviderConfigTarget::Builtin(provider);
224 }
225
226 let provider = config
227 .providers
228 .custom_providers
229 .entry(provider_id.to_string())
230 .or_default();
231 ProviderConfigTarget::Shadow(provider)
232 }
233
234 pub async fn run(mut self) -> anyhow::Result<()> {
235 info!("Manager loop started");
236
237 loop {
238 debug!("Waiting for command...");
239 tokio::select! {
240 msg = self.api_rx.recv() => {
241 let cmd = match msg {
242 Some(cmd) => {
243 debug!("Received command");
244 cmd
245 },
246 None => {
247 info!("Manager channel closed, stopping loop");
248 break Ok(());
249 }
250 };
251 match cmd {
252 ManagerCommand::Chat(req) => self.handle_chat(req),
253 ManagerCommand::StopChat(req, reply) => {
254 self.handle_stop_chat(req, reply);
255 }
256 ManagerCommand::ResetSession(req, reply) => {
257 self.handle_reset_session(req, reply);
258 }
259 ManagerCommand::GetSessions(reply) => {
260 self.handle_get_sessions(reply).await;
261 }
262 ManagerCommand::GetSessionHistory(session_key, reply) => {
263 self.handle_get_session_history(session_key, reply).await;
264 }
265 ManagerCommand::DeleteSession(session_key, reply) => {
266 self.handle_delete_session(session_key, reply).await;
267 }
268 ManagerCommand::ListCronJobs(reply) => {
269 self.handle_list_cron_jobs(reply).await;
270 }
271 ManagerCommand::GetCronJob(job_id, reply) => {
272 self.handle_get_cron_job(job_id, reply).await;
273 }
274 ManagerCommand::CreateCronJob(request, reply) => {
275 self.handle_create_cron_job(request, reply).await;
276 }
277 ManagerCommand::UpdateCronJob(job_id, request, reply) => {
278 self.handle_update_cron_job(job_id, request, reply).await;
279 }
280 ManagerCommand::DeleteCronJob(job_id, reply) => {
281 self.handle_delete_cron_job(job_id, reply).await;
282 }
283 ManagerCommand::SetCronJobEnabled(job_id, enabled, reply) => {
284 self.handle_set_cron_job_enabled(job_id, enabled, reply)
285 .await;
286 }
287 ManagerCommand::RunCronJobNow(job_id, force, reply) => {
288 self.handle_run_cron_job_now(job_id, force, reply).await;
289 }
290 ManagerCommand::StopCronJobRun(job_id, reply) => {
291 self.handle_stop_cron_job_run(job_id, reply).await;
292 }
293 ManagerCommand::UpdateConfig(update) => {
294 self.handle_update_config(update).await?;
295 }
296 ManagerCommand::GetConfig(reply) => {
297 self.handle_get_config(reply);
298 }
299 ManagerCommand::GetChannels(reply) => {
300 self.handle_get_channels(reply);
301 }
302 ManagerCommand::GetTools(reply) => {
303 self.handle_get_tools(reply);
304 }
305 ManagerCommand::GetSkills(reply) => self.handle_get_skills(reply),
306 ManagerCommand::GetMcps(reply) => self.handle_get_mcps(reply),
307 ManagerCommand::CreateMcp(payload, reply) => {
308 self.handle_create_mcp(payload, reply);
309 }
310 ManagerCommand::UpdateMcp(name, payload, reply) => {
311 self.handle_update_mcp(name, payload, reply);
312 }
313 ManagerCommand::DeleteMcp(name, reply) => {
314 self.handle_delete_mcp(name, reply);
315 }
316 ManagerCommand::SetMcpEnabled(name, enabled, reply) => {
317 self.handle_set_mcp_enabled(name, enabled, reply);
318 }
319 ManagerCommand::RefreshMcpStatus(name, reply) => {
320 self.handle_refresh_mcp_status(name, reply);
321 }
322 ManagerCommand::UploadSkill(request, reply) => {
323 self.handle_upload_skill(request, reply);
324 }
325 ManagerCommand::DeleteSkill(name, reply) => {
326 self.handle_delete_skill(name, reply);
327 }
328 ManagerCommand::UploadFile(request, reply) => {
329 self.handle_upload_file(request, reply).await;
330 }
331 ManagerCommand::Provider(command) => {
332 self.handle_provider_command(command).await;
333 }
334 ManagerCommand::UpdateTools(update) => {
335 self.handle_update_tools(update);
336 }
337 ManagerCommand::UpdateChannel(update) => {
338 self.handle_update_channel(update).await;
339 }
340 }
341 }
342 }
343 }
344 }
345}
346
347impl Manager {
348 async fn handle_provider_command(&self, command: ProviderCommand) {
349 match command {
350 ProviderCommand::GetProviders(reply) => self.handle_get_providers(reply),
351 ProviderCommand::GetProvider(name, reply) => self.handle_get_provider(name, reply),
352 ProviderCommand::GetProviderModels(name, runtime, reply) => {
353 self.handle_get_provider_models(name, runtime, reply).await;
354 }
355 ProviderCommand::ResolveProvider(model, preferred_provider, reply) => {
356 self.handle_resolve_provider(model, preferred_provider, reply);
357 }
358 ProviderCommand::AddProviderModel(name, model, reply) => {
359 self.handle_add_provider_model(name, model, reply).await;
360 }
361 ProviderCommand::DeleteProviderModel(name, model_id, reply) => {
362 self.handle_delete_provider_model(name, model_id, reply)
363 .await;
364 }
365 ProviderCommand::CreateProvider(payload, reply) => {
366 self.handle_create_provider(payload, reply).await;
367 }
368 ProviderCommand::UpdateProvider(name, payload, reply) => {
369 self.handle_update_provider(name, payload, reply).await;
370 }
371 ProviderCommand::DeleteProvider(name, reply) => {
372 self.handle_delete_provider(name, reply).await;
373 }
374 }
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider/model pairs that must be accepted, plus one cross-provider
    /// pair that must be rejected.
    #[test]
    fn model_matches_provider_accepts_registry_resolved_models() {
        let accepted = [
            ("openai", "gpt-4o"),
            ("deepseek", "deepseek-chat"),
            ("openrouter", "openrouter/anthropic/claude-sonnet-4"),
        ];
        for (provider_id, model) in accepted {
            assert!(Manager::model_matches_provider(provider_id, model));
        }

        assert!(!Manager::model_matches_provider("openai", "deepseek-chat"));
    }

    /// A model owned by another provider is swapped for the target provider's
    /// fallback when the model was not chosen explicitly.
    #[tokio::test]
    async fn normalize_model_for_provider_replaces_cross_provider_model() {
        let config = agent_diva_core::config::schema::Config::default();
        let catalog = ProviderCatalogService::new();

        let provider_explicit = true;
        let model_explicit = false;
        let resolved = Manager::normalize_model_for_provider(
            &config,
            &catalog,
            "openai",
            "deepseek-chat",
            provider_explicit,
            model_explicit,
        )
        .await;

        assert_eq!(resolved, "openai/gpt-4o");
    }

    /// An explicitly requested model is passed through untouched when the
    /// provider was chosen explicitly as well.
    #[tokio::test]
    async fn normalize_model_for_provider_keeps_explicit_model_for_explicit_provider() {
        let config = agent_diva_core::config::schema::Config::default();
        let catalog = ProviderCatalogService::new();

        let provider_explicit = true;
        let model_explicit = true;
        let resolved = Manager::normalize_model_for_provider(
            &config,
            &catalog,
            "silicon",
            "ByteDance-Seed/Seed-OSS-36B-Instruct",
            provider_explicit,
            model_explicit,
        )
        .await;

        assert_eq!(resolved, "ByteDance-Seed/Seed-OSS-36B-Instruct");
    }
}