Skip to main content

roboticus_cli/cli/update/
update_providers.rs

1//! Provider/model configuration update logic.
2
3use std::path::PathBuf;
4
5use super::{
6    ContentRecord, OverwriteChoice, UpdateState, bytes_sha256, colors, confirm_action,
7    confirm_overwrite, fetch_file, fetch_manifest, file_sha256, icons, now_iso, print_diff,
8    registry_base_url, resolve_registry_url,
9};
10use crate::cli::{CRT_DRAW_MS, heading, theme};
11
12pub(super) fn providers_local_path(config_path: &str) -> PathBuf {
13    if let Ok(content) = std::fs::read_to_string(config_path)
14        && let Ok(config) = content.parse::<toml::Value>()
15        && let Some(path) = config.get("providers_file").and_then(|v| v.as_str())
16    {
17        return PathBuf::from(path);
18    }
19    super::roboticus_home().join("providers.toml")
20}
21
22pub(super) async fn apply_providers_update(
23    yes: bool,
24    registry_url: &str,
25    config_path: &str,
26) -> Result<bool, Box<dyn std::error::Error>> {
27    let (DIM, BOLD, _, GREEN, YELLOW, _, _, RESET, MONO) = colors();
28    let (OK, _, WARN, DETAIL, _) = icons();
29    let client = super::http_client()?;
30
31    println!("\n  {BOLD}Provider Configs{RESET}\n");
32
33    let manifest = match fetch_manifest(&client, registry_url).await {
34        Ok(m) => m,
35        Err(e) => {
36            println!("    {WARN} Could not fetch registry manifest: {e}");
37            return Ok(false);
38        }
39    };
40
41    let base_url = registry_base_url(registry_url);
42    let remote_content = match fetch_file(&client, &base_url, &manifest.packs.providers.path).await
43    {
44        Ok(c) => c,
45        Err(e) => {
46            println!("    {WARN} Could not fetch providers.toml: {e}");
47            return Ok(false);
48        }
49    };
50
51    let remote_hash = bytes_sha256(remote_content.as_bytes());
52    let state = UpdateState::load();
53
54    let local_path = providers_local_path(config_path);
55    let local_exists = local_path.exists();
56    let local_content = if local_exists {
57        std::fs::read_to_string(&local_path).unwrap_or_default()
58    } else {
59        String::new()
60    };
61
62    if local_exists {
63        let local_hash = bytes_sha256(local_content.as_bytes());
64        if local_hash == remote_hash {
65            println!("    {OK} Provider configs are up to date");
66            return Ok(false);
67        }
68    }
69
70    let user_modified = if let Some(ref record) = state.installed_content.providers {
71        if local_exists {
72            let current_hash = file_sha256(&local_path).unwrap_or_default();
73            current_hash != record.sha256
74        } else {
75            false
76        }
77    } else {
78        local_exists
79    };
80
81    if !local_exists {
82        println!("    {GREEN}+ New provider configuration available{RESET}");
83        print_diff("", &remote_content);
84    } else if user_modified {
85        println!("    {YELLOW}Provider config has been modified locally{RESET}");
86        println!("    Changes from registry:");
87        print_diff(&local_content, &remote_content);
88    } else {
89        println!("    Updated provider configuration available");
90        print_diff(&local_content, &remote_content);
91    }
92
93    println!();
94
95    if user_modified {
96        match confirm_overwrite("providers config") {
97            OverwriteChoice::Overwrite => {}
98            OverwriteChoice::Backup => {
99                let backup = local_path.with_extension("toml.bak");
100                std::fs::copy(&local_path, &backup)?;
101                println!("    {DETAIL} Backed up to {}", backup.display());
102            }
103            OverwriteChoice::Skip => {
104                println!("    Skipped.");
105                return Ok(false);
106            }
107        }
108    } else if !yes && !confirm_action("Apply provider updates?", true) {
109        println!("    Skipped.");
110        return Ok(false);
111    }
112
113    if let Some(parent) = local_path.parent() {
114        std::fs::create_dir_all(parent)?;
115    }
116    std::fs::write(&local_path, &remote_content)?;
117
118    let mut state = UpdateState::load();
119    state.installed_content.providers = Some(ContentRecord {
120        version: manifest.version.clone(),
121        sha256: remote_hash,
122        installed_at: now_iso(),
123    });
124    state.last_check = now_iso();
125    state
126        .save()
127        .inspect_err(
128            |e| tracing::warn!(error = %e, "failed to save update state after provider install"),
129        )
130        .ok();
131
132    println!("    {OK} Provider configs updated to v{}", manifest.version);
133    Ok(true)
134}
135
136// ── CLI entry point ──────────────────────────────────────────
137
138pub async fn cmd_update_providers(
139    yes: bool,
140    registry_url_override: Option<&str>,
141    config_path: &str,
142    hygiene_fn: Option<&super::HygieneFn>,
143) -> Result<(), Box<dyn std::error::Error>> {
144    heading("Provider Config Update");
145    let registry_url = resolve_registry_url(registry_url_override, config_path);
146    apply_providers_update(yes, &registry_url, config_path).await?;
147    super::run_oauth_storage_maintenance();
148    super::run_mechanic_checks_maintenance(config_path, hygiene_fn);
149    println!();
150    Ok(())
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::EnvGuard;

    /// Write a minimal `roboticus.toml` into `dir` whose `providers_file` points at
    /// `providers_path`, and return the config file's path.
    fn write_config_pointing_at(
        dir: &std::path::Path,
        providers_path: &std::path::Path,
    ) -> PathBuf {
        let config_path = dir.join("roboticus.toml");
        // Forward slashes keep the TOML basic string free of backslash escapes on Windows.
        let target = providers_path.display().to_string().replace('\\', "/");
        std::fs::write(&config_path, format!("providers_file = \"{target}\"\n")).unwrap();
        config_path
    }

    #[test]
    fn local_path_helpers_fallback_when_config_missing() {
        let resolved = providers_local_path("/no/such/file.toml");
        assert!(resolved.ends_with("providers.toml"));
    }

    #[serial_test::serial]
    #[tokio::test]
    async fn apply_providers_update_fetches_and_writes_local_file() {
        let home = tempfile::tempdir().unwrap();
        let _home_guard = EnvGuard::set("HOME", home.path().to_str().unwrap());
        let providers_path = home.path().join("providers.toml");
        let config_path = write_config_pointing_at(home.path(), &providers_path);
        let config = config_path.to_str().unwrap();

        let expected = String::from("[providers.openai]\nurl = \"https://api.openai.com\"\n");
        let (registry_url, server) = crate::cli::update::tests_support::start_mock_registry(
            expected.clone(),
            "# hello\nbody\n".to_string(),
        )
        .await;

        // First run installs the registry content.
        let first = apply_providers_update(true, &registry_url, config)
            .await
            .unwrap();
        assert!(first);
        assert_eq!(std::fs::read_to_string(&providers_path).unwrap(), expected);

        // Second run sees identical content and reports no change.
        let second = apply_providers_update(true, &registry_url, config)
            .await
            .unwrap();
        assert!(!second);

        server.abort();
    }
}