1use crate::config::{get_cache_dir, Config, OFFICIAL_PUBLIC_KEY};
2use crate::db::resolve_db_path;
3use anyhow::{Context, Result};
4use cmdhub_shared::{CmdHubError, UpdateManifest};
5use ed25519_dalek::{Signature, Verifier, VerifyingKey};
6use fs2::FileExt;
7use reqwest::Client;
8use sha2::{Digest, Sha256};
9use std::fs;
10
11pub async fn update_database(config: &Config, force: bool) -> Result<()> {
12 let client = Client::builder()
13 .timeout(std::time::Duration::from_secs(config.timeout_seconds))
14 .build()?;
15
16 let update_url = format!("{}/db/update", config.api_url);
17
18 eprintln!("Checking for updates at {}...", update_url);
19
20 let manifest_resp = client.get(&update_url).send().await;
22 let manifest: UpdateManifest = match manifest_resp {
23 Ok(resp) => {
24 if resp.status().is_success() {
25 resp.json()
26 .await
27 .context("Failed to parse UpdateManifest JSON")?
28 } else {
29 return Err(anyhow::anyhow!(CmdHubError::UpdateFailed(format!(
30 "Cloud returned status code: {}",
31 resp.status()
32 ))));
33 }
34 }
35 Err(e) => {
36 return Err(anyhow::anyhow!(CmdHubError::UpdateFailed(format!(
37 "Failed to fetch database update manifest: {}",
38 e
39 ))));
40 }
41 };
42
43 let cache_dir = get_cache_dir();
44 let downloads_dir = cache_dir.join("downloads");
45 fs::create_dir_all(&downloads_dir).context("Failed to create downloads cache directory")?;
46
47 let db_zst_path = downloads_dir.join("latest.db.zst");
48 let sig_path = downloads_dir.join("latest.db.sig");
49
50 eprintln!(
51 "Downloading database update (version: {})...",
52 manifest.version
53 );
54
55 let db_resp = client
57 .get(&manifest.db_url)
58 .send()
59 .await
60 .context("Failed to download database file")?;
61 let db_bytes = db_resp
62 .bytes()
63 .await
64 .context("Failed to read database bytes")?;
65 fs::write(&db_zst_path, &db_bytes).context("Failed to write downloaded database payload")?;
66
67 let sig_resp = client
69 .get(&manifest.sig_url)
70 .send()
71 .await
72 .context("Failed to download database signature file")?;
73 let sig_bytes = sig_resp
74 .bytes()
75 .await
76 .context("Failed to read database signature bytes")?;
77 fs::write(&sig_path, &sig_bytes).context("Failed to write downloaded signature payload")?;
78
79 eprintln!("Verifying database integrity and signature...");
81 let mut hasher = Sha256::new();
82 hasher.update(&db_bytes);
83 let hash_result: [u8; 32] = hasher.finalize().into();
84
85 let computed_hex = hash_result
87 .iter()
88 .map(|b| format!("{:02x}", b))
89 .collect::<String>();
90 if !force && computed_hex != manifest.sha256 {
91 return Err(anyhow::anyhow!(CmdHubError::Validation(format!(
92 "SHA-256 mismatch: computed {}, manifest {}",
93 computed_hex, manifest.sha256
94 ))));
95 }
96
97 let pub_key_bytes = match hex_decode(&config.public_key) {
99 Ok(bytes) => {
100 let mut arr = [0u8; 32];
101 if bytes.len() == 32 {
102 arr.copy_from_slice(&bytes);
103 arr
104 } else {
105 OFFICIAL_PUBLIC_KEY
106 }
107 }
108 Err(_) => OFFICIAL_PUBLIC_KEY,
109 };
110
111 let verifying_key = VerifyingKey::from_bytes(&pub_key_bytes).map_err(|e| {
112 anyhow::anyhow!(CmdHubError::SignatureVerification(format!(
113 "Invalid public key: {}",
114 e
115 )))
116 })?;
117
118 let signature = Signature::from_slice(&sig_bytes).map_err(|e| {
119 anyhow::anyhow!(CmdHubError::SignatureVerification(format!(
120 "Invalid signature format: {}",
121 e
122 )))
123 })?;
124
125 verifying_key
126 .verify(&hash_result, &signature)
127 .map_err(|e| {
128 anyhow::anyhow!(CmdHubError::SignatureVerification(format!(
129 "Ed25519 signature verification failed: {}",
130 e
131 )))
132 })?;
133
134 eprintln!("Decompressing database...");
136 let decompressed =
137 zstd::decode_all(&db_bytes[..]).context("Failed to decompress zstd payload")?;
138
139 let tmp_dir = cache_dir.join("tmp");
140 fs::create_dir_all(&tmp_dir).context("Failed to create temporary staging directory")?;
141 let staging_path = tmp_dir.join("latest.db");
142 fs::write(&staging_path, &decompressed)
143 .context("Failed to write decompressed staging database")?;
144
145 eprintln!("Applying atomic database replacement...");
147 let lock_path = cache_dir.join("update.lock");
148 let lock_file = fs::OpenOptions::new()
149 .read(true)
150 .write(true)
151 .create(true)
152 .truncate(true)
153 .open(&lock_path)
154 .context("Failed to open update.lock file")?;
155
156 lock_file
157 .lock_exclusive()
158 .context("Failed to acquire exclusive lock on update.lock")?;
159
160 let live_db_path = resolve_db_path();
161 if let Some(parent) = live_db_path.parent() {
162 fs::create_dir_all(parent).context("Failed to create live database directory")?;
163 }
164
165 eprintln!("Safely applying database changes...");
167 let src_conn =
168 rusqlite::Connection::open(&staging_path).context("Failed to open staging database")?;
169 let mut dst_conn =
170 rusqlite::Connection::open(&live_db_path).context("Failed to open live database")?;
171
172 let _ = dst_conn.execute("PRAGMA journal_mode = WAL;", []);
174 let _ = dst_conn.execute("PRAGMA synchronous = NORMAL;", []);
175
176 let backup = rusqlite::backup::Backup::new(&src_conn, &mut dst_conn)
177 .context("Failed to initialize SQLite backup")?;
178
179 backup
180 .run_to_completion(100, std::time::Duration::from_millis(10), None)
181 .context("SQLite backup to live database failed")?;
182
183 let _ = fs::remove_file(&staging_path);
185
186 eprintln!(
187 "Database successfully updated to version {}!",
188 manifest.version
189 );
190 Ok(())
191}
192
193fn hex_decode(s: &str) -> Result<Vec<u8>> {
194 let mut bytes = Vec::new();
195 let mut chars = s.chars().peekable();
196 while let Some(c1) = chars.next() {
197 if let Some(c2) = chars.next() {
198 let hex = format!("{}{}", c1, c2);
199 let b = u8::from_str_radix(&hex, 16)?;
200 bytes.push(b);
201 }
202 }
203 Ok(bytes)
204}