Skip to main content

agent_diva_manager/
manager.rs

1mod companion_admin;
2mod provider_admin;
3mod runtime_control;
4
5use agent_diva_agent::runtime_control::RuntimeControlCommand;
6use agent_diva_agent::tool_config::network::{
7    NetworkToolConfig, WebFetchRuntimeConfig, WebRuntimeConfig, WebSearchRuntimeConfig,
8};
9use agent_diva_channels::ChannelManager;
10use agent_diva_core::bus::MessageBus;
11use agent_diva_core::config::{ConfigLoader, CustomProviderConfig};
12use agent_diva_core::cron::CronService;
13use agent_diva_files::FileManager;
14use agent_diva_providers::{DynamicProvider, ProviderCatalogService, ProviderRegistry};
15use std::sync::Arc;
16use tokio::sync::mpsc;
17use tracing::{debug, error, info};
18
19use crate::state::{ManagerCommand, ProviderCommand};
20
/// Actor that serves [`ManagerCommand`]s received on `api_rx`, handling them one at a
/// time inside [`Manager::run`].
///
/// Besides the shared services used by the command handlers, it tracks the provider /
/// model selection that is currently in effect.
pub struct Manager {
    /// Inbound command channel; `run` returns once every sender has been dropped.
    api_rx: mpsc::Receiver<ManagerCommand>,
    // NOTE(review): not referenced by the code in this file; presumably used by the
    // handlers in the submodules — confirm there.
    bus: MessageBus,
    /// Shared (`Arc`) provider handle.
    provider: Arc<DynamicProvider>,
    /// Loads the configuration from disk (e.g. in `reload_runtime_mcp`).
    loader: ConfigLoader,
    // Current config state
    /// Active provider id; inferred from the model in `new` when none is supplied.
    current_provider: Option<String>,
    current_model: String,
    current_api_base: Option<String>,
    current_api_key: Option<String>,
    /// Optional channel manager (may be absent).
    channel_manager: Option<Arc<ChannelManager>>,
    /// Optional sender for runtime control commands such as MCP server updates;
    /// when `None`, `reload_runtime_mcp` does nothing.
    runtime_control_tx: Option<mpsc::UnboundedSender<RuntimeControlCommand>>,
    cron_service: Arc<CronService>,
    file_manager: Arc<FileManager>,
}
36
37enum ProviderConfigTarget<'a> {
38    Builtin(&'a mut agent_diva_core::config::schema::ProviderConfig),
39    Shadow(&'a mut CustomProviderConfig),
40}
41
42impl ProviderConfigTarget<'_> {
43    fn set_api_key(&mut self, api_key: String) {
44        match self {
45            Self::Builtin(config) => config.api_key = api_key,
46            Self::Shadow(config) => config.api_key = api_key,
47        }
48    }
49
50    fn set_api_base(&mut self, api_base: Option<String>) {
51        match self {
52            Self::Builtin(config) => config.api_base = api_base,
53            Self::Shadow(config) => config.api_base = api_base,
54        }
55    }
56}
57
58impl Manager {
59    #[allow(clippy::too_many_arguments)]
60    pub fn new(
61        api_rx: mpsc::Receiver<ManagerCommand>,
62        bus: MessageBus,
63        provider: Arc<DynamicProvider>,
64        loader: ConfigLoader,
65        initial_provider: Option<String>,
66        initial_model: String,
67        api_key: Option<String>,
68        api_base: Option<String>,
69        channel_manager: Option<Arc<ChannelManager>>,
70        runtime_control_tx: Option<mpsc::UnboundedSender<RuntimeControlCommand>>,
71        cron_service: Arc<CronService>,
72        file_manager: Arc<FileManager>,
73    ) -> Self {
74        Self {
75            api_rx,
76            bus,
77            provider,
78            loader,
79            current_provider: initial_provider
80                .or_else(|| Self::provider_name_for_model(None, &initial_model)),
81            current_model: initial_model,
82            current_api_base: api_base,
83            current_api_key: api_key,
84            channel_manager,
85            runtime_control_tx,
86            cron_service,
87            file_manager,
88        }
89    }
90
91    fn provider_name_for_model(preferred_provider: Option<&str>, model: &str) -> Option<String> {
92        let registry = ProviderRegistry::new();
93        preferred_provider
94            .map(str::trim)
95            .filter(|value| !value.is_empty())
96            .and_then(|name| registry.find_by_name(name))
97            .map(|spec| spec.name.clone())
98            .or_else(|| {
99                model
100                    .split('/')
101                    .next()
102                    .and_then(|prefix| registry.find_by_name(prefix))
103                    .map(|spec| spec.name.clone())
104            })
105            .or_else(|| registry.find_by_model(model).map(|spec| spec.name.clone()))
106    }
107
108    fn map_network_config(config: &agent_diva_core::config::schema::Config) -> NetworkToolConfig {
109        let api_key = config.tools.web.search.api_key.trim().to_string();
110        NetworkToolConfig {
111            web: WebRuntimeConfig {
112                search: WebSearchRuntimeConfig {
113                    provider: config.tools.web.search.provider.clone(),
114                    enabled: config.tools.web.search.enabled,
115                    api_key: if api_key.is_empty() {
116                        None
117                    } else {
118                        Some(api_key)
119                    },
120                    max_results: config.tools.web.search.max_results,
121                },
122                fetch: WebFetchRuntimeConfig {
123                    enabled: config.tools.web.fetch.enabled,
124                },
125            },
126        }
127    }
128
129    fn reload_runtime_mcp(&self) {
130        let Some(tx) = &self.runtime_control_tx else {
131            return;
132        };
133        let Ok(config) = self.loader.load() else {
134            error!("Failed to load config for MCP runtime update");
135            return;
136        };
137        if let Err(e) = tx.send(RuntimeControlCommand::UpdateMcp {
138            servers: config.tools.active_mcp_servers(),
139        }) {
140            error!("Failed to send runtime MCP update: {}", e);
141        }
142    }
143
144    fn model_matches_provider(provider_id: &str, model: &str) -> bool {
145        let trimmed_provider = provider_id.trim();
146        let trimmed_model = model.trim();
147        if trimmed_provider.is_empty() || trimmed_model.is_empty() {
148            return false;
149        }
150
151        if trimmed_model
152            .split('/')
153            .next()
154            .is_some_and(|prefix| prefix == trimmed_provider)
155        {
156            return true;
157        }
158
159        ProviderRegistry::new()
160            .find_by_model(trimmed_model)
161            .is_some_and(|spec| spec.name == trimmed_provider)
162    }
163
164    async fn normalize_model_for_provider(
165        config: &agent_diva_core::config::schema::Config,
166        catalog: &ProviderCatalogService,
167        provider_id: &str,
168        requested_model: &str,
169        provider_explicit: bool,
170        model_explicit: bool,
171    ) -> String {
172        let requested_model = requested_model.trim();
173        let provider_models = catalog
174            .list_provider_models(config, provider_id, false, None)
175            .await
176            .ok();
177
178        if !requested_model.is_empty() {
179            if provider_explicit && model_explicit {
180                return requested_model.to_string();
181            }
182
183            let in_catalog = provider_models.as_ref().is_some_and(|catalog| {
184                catalog
185                    .models
186                    .iter()
187                    .any(|entry| entry.id == requested_model)
188            });
189            if in_catalog || Self::model_matches_provider(provider_id, requested_model) {
190                return requested_model.to_string();
191            }
192        }
193
194        if let Some(default_model) = catalog
195            .get_provider_view(config, provider_id)
196            .and_then(|view| view.default_model)
197            .map(|value| value.trim().to_string())
198            .filter(|value| !value.is_empty())
199        {
200            return default_model;
201        }
202
203        if let Some(first_model) = provider_models
204            .and_then(|catalog| catalog.models.into_iter().next().map(|entry| entry.id))
205            .map(|value| value.trim().to_string())
206            .filter(|value| !value.is_empty())
207        {
208            return first_model;
209        }
210
211        requested_model.to_string()
212    }
213
214    fn ensure_provider_credentials_slot<'a>(
215        config: &'a mut agent_diva_core::config::schema::Config,
216        provider_id: &str,
217    ) -> ProviderConfigTarget<'a> {
218        if agent_diva_core::config::schema::ProvidersConfig::is_builtin_provider(provider_id) {
219            let provider = config
220                .providers
221                .get_mut(provider_id)
222                .expect("builtin provider slot must exist");
223            return ProviderConfigTarget::Builtin(provider);
224        }
225
226        let provider = config
227            .providers
228            .custom_providers
229            .entry(provider_id.to_string())
230            .or_default();
231        ProviderConfigTarget::Shadow(provider)
232    }
233
234    pub async fn run(mut self) -> anyhow::Result<()> {
235        info!("Manager loop started");
236
237        loop {
238            debug!("Waiting for command...");
239            tokio::select! {
240                msg = self.api_rx.recv() => {
241                    let cmd = match msg {
242                        Some(cmd) => {
243                            debug!("Received command");
244                            cmd
245                        },
246                        None => {
247                            info!("Manager channel closed, stopping loop");
248                            break Ok(());
249                        }
250                    };
251                    match cmd {
252                        ManagerCommand::Chat(req) => self.handle_chat(req),
253                        ManagerCommand::StopChat(req, reply) => {
254                            self.handle_stop_chat(req, reply);
255                        }
256                        ManagerCommand::ResetSession(req, reply) => {
257                            self.handle_reset_session(req, reply);
258                        }
259                        ManagerCommand::GetSessions(reply) => {
260                            self.handle_get_sessions(reply).await;
261                        }
262                        ManagerCommand::GetSessionHistory(session_key, reply) => {
263                            self.handle_get_session_history(session_key, reply).await;
264                        }
265                        ManagerCommand::DeleteSession(session_key, reply) => {
266                            self.handle_delete_session(session_key, reply).await;
267                        }
268                        ManagerCommand::ListCronJobs(reply) => {
269                            self.handle_list_cron_jobs(reply).await;
270                        }
271                        ManagerCommand::GetCronJob(job_id, reply) => {
272                            self.handle_get_cron_job(job_id, reply).await;
273                        }
274                        ManagerCommand::CreateCronJob(request, reply) => {
275                            self.handle_create_cron_job(request, reply).await;
276                        }
277                        ManagerCommand::UpdateCronJob(job_id, request, reply) => {
278                            self.handle_update_cron_job(job_id, request, reply).await;
279                        }
280                        ManagerCommand::DeleteCronJob(job_id, reply) => {
281                            self.handle_delete_cron_job(job_id, reply).await;
282                        }
283                        ManagerCommand::SetCronJobEnabled(job_id, enabled, reply) => {
284                            self.handle_set_cron_job_enabled(job_id, enabled, reply)
285                                .await;
286                        }
287                        ManagerCommand::RunCronJobNow(job_id, force, reply) => {
288                            self.handle_run_cron_job_now(job_id, force, reply).await;
289                        }
290                        ManagerCommand::StopCronJobRun(job_id, reply) => {
291                            self.handle_stop_cron_job_run(job_id, reply).await;
292                        }
293                        ManagerCommand::UpdateConfig(update) => {
294                            self.handle_update_config(update).await?;
295                        }
296                        ManagerCommand::GetConfig(reply) => {
297                            self.handle_get_config(reply);
298                        }
299                        ManagerCommand::GetChannels(reply) => {
300                            self.handle_get_channels(reply);
301                        }
302                        ManagerCommand::GetTools(reply) => {
303                            self.handle_get_tools(reply);
304                        }
305                        ManagerCommand::GetSkills(reply) => self.handle_get_skills(reply),
306                        ManagerCommand::GetMcps(reply) => self.handle_get_mcps(reply),
307                        ManagerCommand::CreateMcp(payload, reply) => {
308                            self.handle_create_mcp(payload, reply);
309                        }
310                        ManagerCommand::UpdateMcp(name, payload, reply) => {
311                            self.handle_update_mcp(name, payload, reply);
312                        }
313                        ManagerCommand::DeleteMcp(name, reply) => {
314                            self.handle_delete_mcp(name, reply);
315                        }
316                        ManagerCommand::SetMcpEnabled(name, enabled, reply) => {
317                            self.handle_set_mcp_enabled(name, enabled, reply);
318                        }
319                        ManagerCommand::RefreshMcpStatus(name, reply) => {
320                            self.handle_refresh_mcp_status(name, reply);
321                        }
322                        ManagerCommand::UploadSkill(request, reply) => {
323                            self.handle_upload_skill(request, reply);
324                        }
325                        ManagerCommand::DeleteSkill(name, reply) => {
326                            self.handle_delete_skill(name, reply);
327                        }
328                        ManagerCommand::UploadFile(request, reply) => {
329                            self.handle_upload_file(request, reply).await;
330                        }
331                        ManagerCommand::Provider(command) => {
332                            self.handle_provider_command(command).await;
333                        }
334                        ManagerCommand::UpdateTools(update) => {
335                            self.handle_update_tools(update);
336                        }
337                        ManagerCommand::UpdateChannel(update) => {
338                            self.handle_update_channel(update).await;
339                        }
340                    }
341                }
342            }
343        }
344    }
345}
346
347impl Manager {
348    async fn handle_provider_command(&self, command: ProviderCommand) {
349        match command {
350            ProviderCommand::GetProviders(reply) => self.handle_get_providers(reply),
351            ProviderCommand::GetProvider(name, reply) => self.handle_get_provider(name, reply),
352            ProviderCommand::GetProviderModels(name, runtime, reply) => {
353                self.handle_get_provider_models(name, runtime, reply).await;
354            }
355            ProviderCommand::ResolveProvider(model, preferred_provider, reply) => {
356                self.handle_resolve_provider(model, preferred_provider, reply);
357            }
358            ProviderCommand::AddProviderModel(name, model, reply) => {
359                self.handle_add_provider_model(name, model, reply).await;
360            }
361            ProviderCommand::DeleteProviderModel(name, model_id, reply) => {
362                self.handle_delete_provider_model(name, model_id, reply)
363                    .await;
364            }
365            ProviderCommand::CreateProvider(payload, reply) => {
366                self.handle_create_provider(payload, reply).await;
367            }
368            ProviderCommand::UpdateProvider(name, payload, reply) => {
369                self.handle_update_provider(name, payload, reply).await;
370            }
371            ProviderCommand::DeleteProvider(name, reply) => {
372                self.handle_delete_provider(name, reply).await;
373            }
374        }
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `Manager::normalize_model_for_provider` against a fresh catalog and the
    /// default config.
    async fn normalize_with_defaults(
        provider_id: &str,
        requested_model: &str,
        provider_explicit: bool,
        model_explicit: bool,
    ) -> String {
        let catalog = ProviderCatalogService::new();
        let config = agent_diva_core::config::schema::Config::default();
        Manager::normalize_model_for_provider(
            &config,
            &catalog,
            provider_id,
            requested_model,
            provider_explicit,
            model_explicit,
        )
        .await
    }

    #[test]
    fn model_matches_provider_accepts_registry_resolved_models() {
        let expected_matches = [
            ("openai", "gpt-4o"),
            ("deepseek", "deepseek-chat"),
            ("openrouter", "openrouter/anthropic/claude-sonnet-4"),
        ];
        for (provider_id, model) in expected_matches {
            assert!(Manager::model_matches_provider(provider_id, model));
        }

        assert!(!Manager::model_matches_provider("openai", "deepseek-chat"));
    }

    #[tokio::test]
    async fn normalize_model_for_provider_replaces_cross_provider_model() {
        let normalized = normalize_with_defaults("openai", "deepseek-chat", true, false).await;
        assert_eq!(normalized, "openai/gpt-4o");
    }

    #[tokio::test]
    async fn normalize_model_for_provider_keeps_explicit_model_for_explicit_provider() {
        let normalized = normalize_with_defaults(
            "silicon",
            "ByteDance-Seed/Seed-OSS-36B-Instruct",
            true,
            true,
        )
        .await;
        assert_eq!(normalized, "ByteDance-Seed/Seed-OSS-36B-Instruct");
    }
}