Skip to main content

voirs_cli/packaging/
update.rs

1use crate::error::VoirsCLIError;
2use anyhow::Result;
3use chrono::{DateTime, Utc};
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use std::fs;
8use std::path::PathBuf;
9use std::process::Command;
10use tokio::fs::File;
11use tokio::io::AsyncWriteExt;
12use tracing::{debug, error, info, warn};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct UpdateConfig {
16    pub check_interval_hours: u64,
17    pub auto_update: bool,
18    pub backup_count: u32,
19    pub update_channel: UpdateChannel,
20    pub update_server: String,
21    pub verify_signatures: bool,
22    pub signature_algorithm: String,
23    pub public_key_path: Option<PathBuf>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub enum UpdateChannel {
28    Stable,
29    Beta,
30    Nightly,
31}
32
33impl Default for UpdateConfig {
34    fn default() -> Self {
35        Self {
36            check_interval_hours: 24,
37            auto_update: false,
38            backup_count: 3,
39            update_channel: UpdateChannel::Stable,
40            update_server: "https://api.github.com/repos/voirs-org/voirs".to_string(),
41            verify_signatures: true,
42            signature_algorithm: "ed25519".to_string(),
43            public_key_path: None,
44        }
45    }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct VersionInfo {
50    pub version: String,
51    pub release_date: DateTime<Utc>,
52    pub download_url: String,
53    pub checksum: String,
54    pub signature: Option<String>,
55    pub changelog: String,
56    pub is_security_update: bool,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct UpdateState {
61    pub last_check: DateTime<Utc>,
62    pub current_version: String,
63    pub available_version: Option<String>,
64    pub update_available: bool,
65    pub last_update: Option<DateTime<Utc>>,
66    pub backup_paths: Vec<PathBuf>,
67}
68
69impl Default for UpdateState {
70    fn default() -> Self {
71        Self {
72            last_check: Utc::now(),
73            current_version: env!("CARGO_PKG_VERSION").to_string(),
74            available_version: None,
75            update_available: false,
76            last_update: None,
77            backup_paths: Vec::new(),
78        }
79    }
80}
81
82pub struct UpdateManager {
83    config: UpdateConfig,
84    state: UpdateState,
85    client: Client,
86    state_file: PathBuf,
87}
88
89impl UpdateManager {
90    pub fn new(config: UpdateConfig, state_file: PathBuf) -> Result<Self> {
91        let state = if state_file.exists() {
92            let content = fs::read_to_string(&state_file)?;
93            serde_json::from_str(&content).unwrap_or_default()
94        } else {
95            UpdateState::default()
96        };
97
98        let client = Client::builder()
99            .user_agent(format!("voirs-cli/{}", env!("CARGO_PKG_VERSION")))
100            .build()?;
101
102        Ok(Self {
103            config,
104            state,
105            client,
106            state_file,
107        })
108    }
109
110    pub async fn check_for_updates(&mut self) -> Result<Option<VersionInfo>> {
111        info!("Checking for updates");
112
113        let should_check = self.should_check_for_updates();
114        if !should_check {
115            debug!("Update check skipped - too soon since last check");
116            return Ok(None);
117        }
118
119        let latest_version = self.fetch_latest_version().await?;
120
121        self.state.last_check = Utc::now();
122        self.state.available_version = Some(latest_version.version.clone());
123        self.state.update_available = self.is_newer_version(&latest_version.version)?;
124
125        self.save_state()?;
126
127        if self.state.update_available {
128            info!(
129                "Update available: {} -> {}",
130                self.state.current_version, latest_version.version
131            );
132            Ok(Some(latest_version))
133        } else {
134            info!("No updates available");
135            Ok(None)
136        }
137    }
138
139    pub async fn perform_update(&mut self, version_info: &VersionInfo) -> Result<bool> {
140        info!(
141            "Starting update process to version {}",
142            version_info.version
143        );
144
145        // Create backup of current binary
146        let backup_path = self.create_backup().await?;
147
148        // Download new binary
149        let temp_binary = self.download_binary(version_info).await?;
150
151        // Verify integrity
152        if !self
153            .verify_binary_integrity(&temp_binary, &version_info.checksum)
154            .await?
155        {
156            error!("Binary integrity verification failed");
157            return Ok(false);
158        }
159
160        // Verify signature if enabled
161        if self.config.verify_signatures {
162            if let Some(signature) = &version_info.signature {
163                if !self.verify_signature(&temp_binary, signature).await? {
164                    error!("Binary signature verification failed");
165                    return Ok(false);
166                }
167            }
168        }
169
170        // Replace current binary
171        let current_binary = self.get_current_binary_path()?;
172        fs::rename(&temp_binary, &current_binary)?;
173
174        // Update permissions
175        #[cfg(unix)]
176        {
177            use std::os::unix::fs::PermissionsExt;
178            let mut perms = fs::metadata(&current_binary)?.permissions();
179            perms.set_mode(0o755);
180            fs::set_permissions(&current_binary, perms)?;
181        }
182
183        // Update state
184        self.state.current_version = version_info.version.clone();
185        self.state.last_update = Some(Utc::now());
186        self.state.update_available = false;
187        self.state.backup_paths.push(backup_path);
188
189        // Clean up old backups
190        self.cleanup_old_backups().await?;
191
192        self.save_state()?;
193
194        info!("Update completed successfully");
195        Ok(true)
196    }
197
198    pub async fn rollback_update(&mut self) -> Result<bool> {
199        info!("Rolling back update");
200
201        if let Some(backup_path) = self.state.backup_paths.last() {
202            if backup_path.exists() {
203                let current_binary = self.get_current_binary_path()?;
204                fs::rename(backup_path, &current_binary)?;
205
206                // Update permissions
207                #[cfg(unix)]
208                {
209                    use std::os::unix::fs::PermissionsExt;
210                    let mut perms = fs::metadata(&current_binary)?.permissions();
211                    perms.set_mode(0o755);
212                    fs::set_permissions(&current_binary, perms)?;
213                }
214
215                self.state.backup_paths.pop();
216                self.save_state()?;
217
218                info!("Rollback completed successfully");
219                Ok(true)
220            } else {
221                warn!("Backup file not found for rollback");
222                Ok(false)
223            }
224        } else {
225            warn!("No backup available for rollback");
226            Ok(false)
227        }
228    }
229
230    fn should_check_for_updates(&self) -> bool {
231        let hours_since_last_check = Utc::now()
232            .signed_duration_since(self.state.last_check)
233            .num_hours() as u64;
234
235        hours_since_last_check >= self.config.check_interval_hours
236    }
237
238    async fn fetch_latest_version(&self) -> Result<VersionInfo> {
239        let url = format!("{}/releases/latest", self.config.update_server);
240        let response = self.client.get(&url).send().await?;
241
242        if !response.status().is_success() {
243            return Err(VoirsCLIError::UpdateError(format!(
244                "Failed to fetch latest version: HTTP {}",
245                response.status()
246            ))
247            .into());
248        }
249
250        let release_info: serde_json::Value = response.json().await?;
251
252        let version = release_info["tag_name"]
253            .as_str()
254            .unwrap_or("")
255            .trim_start_matches('v')
256            .to_string();
257
258        let release_date =
259            DateTime::parse_from_rfc3339(release_info["published_at"].as_str().unwrap_or(""))?
260                .with_timezone(&Utc);
261
262        let download_url = self.get_download_url_for_platform(&release_info)?;
263
264        Ok(VersionInfo {
265            version,
266            release_date,
267            download_url,
268            checksum: String::new(), // Would be fetched from release assets
269            signature: None,
270            changelog: release_info["body"].as_str().unwrap_or("").to_string(),
271            is_security_update: release_info["body"]
272                .as_str()
273                .unwrap_or("")
274                .to_lowercase()
275                .contains("security"),
276        })
277    }
278
279    fn get_download_url_for_platform(&self, release_info: &serde_json::Value) -> Result<String> {
280        let assets = release_info["assets"]
281            .as_array()
282            .ok_or_else(|| VoirsCLIError::UpdateError("No assets found in release".to_string()))?;
283
284        let platform_suffix = if cfg!(target_os = "windows") {
285            "windows"
286        } else if cfg!(target_os = "macos") {
287            "macos"
288        } else {
289            "linux"
290        };
291
292        for asset in assets {
293            if let Some(name) = asset["name"].as_str() {
294                if name.contains(platform_suffix) {
295                    return Ok(asset["browser_download_url"]
296                        .as_str()
297                        .ok_or_else(|| {
298                            VoirsCLIError::UpdateError("Invalid download URL".to_string())
299                        })?
300                        .to_string());
301                }
302            }
303        }
304
305        Err(VoirsCLIError::UpdateError(format!(
306            "No binary found for platform: {}",
307            platform_suffix
308        ))
309        .into())
310    }
311
312    fn is_newer_version(&self, remote_version: &str) -> Result<bool> {
313        let current = semver::Version::parse(&self.state.current_version)?;
314        let remote = semver::Version::parse(remote_version)?;
315
316        Ok(remote > current)
317    }
318
319    async fn create_backup(&self) -> Result<PathBuf> {
320        let current_binary = self.get_current_binary_path()?;
321        let backup_name = format!("voirs-backup-{}.bak", Utc::now().timestamp());
322        let backup_path = current_binary
323            .parent()
324            .unwrap_or(&PathBuf::from("."))
325            .join(&backup_name);
326
327        fs::copy(&current_binary, &backup_path)?;
328
329        info!("Created backup at: {:?}", backup_path);
330        Ok(backup_path)
331    }
332
333    async fn download_binary(&self, version_info: &VersionInfo) -> Result<PathBuf> {
334        info!("Downloading binary from: {}", version_info.download_url);
335
336        let response = self.client.get(&version_info.download_url).send().await?;
337
338        if !response.status().is_success() {
339            return Err(VoirsCLIError::UpdateError(format!(
340                "Failed to download binary: HTTP {}",
341                response.status()
342            ))
343            .into());
344        }
345
346        let temp_path = std::env::temp_dir().join(format!("voirs-update-{}", version_info.version));
347        let mut file = File::create(&temp_path).await?;
348
349        let content = response.bytes().await?;
350        file.write_all(&content).await?;
351
352        info!("Binary downloaded to: {:?}", temp_path);
353        Ok(temp_path)
354    }
355
356    async fn verify_binary_integrity(
357        &self,
358        binary_path: &PathBuf,
359        expected_checksum: &str,
360    ) -> Result<bool> {
361        if expected_checksum.is_empty() {
362            warn!("No checksum provided for verification");
363            return Ok(true);
364        }
365
366        let content = fs::read(binary_path)?;
367        let mut hasher = Sha256::new();
368        hasher.update(&content);
369        let actual_checksum = format!("{:x}", hasher.finalize());
370
371        let matches = actual_checksum == expected_checksum;
372        if matches {
373            info!("Binary integrity verification passed");
374        } else {
375            error!(
376                "Binary integrity verification failed: expected {}, got {}",
377                expected_checksum, actual_checksum
378            );
379        }
380
381        Ok(matches)
382    }
383
384    async fn verify_signature(&self, binary_path: &PathBuf, signature: &str) -> Result<bool> {
385        info!("Verifying signature for binary: {:?}", binary_path);
386
387        // Read the binary file
388        let binary_content = fs::read(binary_path)?;
389
390        // Parse the signature (assuming it's hex-encoded)
391        let signature_bytes = self.parse_hex_signature(signature)?;
392
393        // Get the public key for verification
394        let public_key = self.get_verification_public_key()?;
395
396        // Verify the signature using Ed25519 (or RSA as fallback)
397        let is_valid = match self.config.signature_algorithm.as_str() {
398            "ed25519" => {
399                self.verify_ed25519_signature(&binary_content, &signature_bytes, &public_key)?
400            }
401            "rsa" => self.verify_rsa_signature(&binary_content, &signature_bytes, &public_key)?,
402            "ecdsa" => {
403                self.verify_ecdsa_signature(&binary_content, &signature_bytes, &public_key)?
404            }
405            _ => {
406                warn!(
407                    "Unknown signature algorithm: {}",
408                    self.config.signature_algorithm
409                );
410                return Ok(false);
411            }
412        };
413
414        if is_valid {
415            info!("Binary signature verification successful");
416        } else {
417            warn!("Binary signature verification failed");
418        }
419
420        Ok(is_valid)
421    }
422
423    /// Parse hex-encoded signature
424    fn parse_hex_signature(&self, signature: &str) -> Result<Vec<u8>> {
425        let signature_clean = signature.trim().replace(" ", "").replace("\n", "");
426
427        if signature_clean.len() % 2 != 0 {
428            return Err(anyhow::anyhow!("Invalid hex signature length"));
429        }
430
431        let mut signature_bytes = Vec::new();
432        for i in (0..signature_clean.len()).step_by(2) {
433            let hex_byte = &signature_clean[i..i + 2];
434            let byte = u8::from_str_radix(hex_byte, 16)
435                .map_err(|_| anyhow::anyhow!("Invalid hex character in signature"))?;
436            signature_bytes.push(byte);
437        }
438
439        Ok(signature_bytes)
440    }
441
442    /// Get the public key for signature verification
443    fn get_verification_public_key(&self) -> Result<Vec<u8>> {
444        // Try to get public key from multiple sources
445
446        // 1. Check environment variable
447        if let Ok(key_env) = std::env::var("VOIRS_PUBLIC_KEY") {
448            return self.parse_public_key(&key_env);
449        }
450
451        // 2. Check configuration file
452        if let Some(key_path) = &self.config.public_key_path {
453            if key_path.exists() {
454                let key_content = fs::read_to_string(key_path)?;
455                return self.parse_public_key(&key_content);
456            }
457        }
458
459        // 3. Use embedded public key (hardcoded for security)
460        let embedded_key = self.get_embedded_public_key();
461        Ok(embedded_key)
462    }
463
464    /// Parse public key from string (supports PEM and raw hex)
465    fn parse_public_key(&self, key_str: &str) -> Result<Vec<u8>> {
466        let key_clean = key_str.trim();
467
468        // Check if it's a PEM-formatted key
469        if key_clean.starts_with("-----BEGIN") && key_clean.ends_with("-----END") {
470            // Extract the base64 content between BEGIN and END markers
471            let lines: Vec<&str> = key_clean.lines().collect();
472            if lines.len() < 3 {
473                return Err(anyhow::anyhow!("Invalid PEM format"));
474            }
475
476            let b64_content = lines[1..lines.len() - 1].join("");
477            let key_bytes = base64::decode(&b64_content)
478                .map_err(|_| anyhow::anyhow!("Invalid base64 in PEM key"))?;
479
480            Ok(key_bytes)
481        } else if key_clean
482            .chars()
483            .all(|c| c.is_ascii_hexdigit() || c.is_whitespace())
484        {
485            // Treat as hex-encoded key
486            self.parse_hex_signature(key_clean)
487        } else {
488            Err(anyhow::anyhow!("Unsupported public key format"))
489        }
490    }
491
492    /// Get embedded public key (hardcoded for security)
493    fn get_embedded_public_key(&self) -> Vec<u8> {
494        // In a real implementation, this would be the actual public key
495        // For now, we'll use a placeholder key
496        match self.config.signature_algorithm.as_str() {
497            "ed25519" => {
498                // Ed25519 public key (32 bytes)
499                vec![
500                    0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD,
501                    0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA,
502                    0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
503                ]
504            }
505            "rsa" => {
506                // RSA public key (DER encoded, simplified)
507                vec![
508                    0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7,
509                    0x0d, 0x01, 0x01,
510                    // ... RSA public key continues (truncated for brevity)
511                ]
512            }
513            "ecdsa" => {
514                // ECDSA public key (33 bytes compressed)
515                vec![
516                    0x02, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC,
517                    0xDD, 0xEE, 0xFF, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99,
518                    0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00,
519                ]
520            }
521            _ => vec![],
522        }
523    }
524
525    /// Verify Ed25519 signature
526    fn verify_ed25519_signature(
527        &self,
528        data: &[u8],
529        signature: &[u8],
530        public_key: &[u8],
531    ) -> Result<bool> {
532        if signature.len() != 64 {
533            return Err(anyhow::anyhow!("Invalid Ed25519 signature length"));
534        }
535
536        if public_key.len() != 32 {
537            return Err(anyhow::anyhow!("Invalid Ed25519 public key length"));
538        }
539
540        // Calculate SHA-256 hash of the data
541        let hash = sha2::Sha256::digest(data);
542
543        // Simulate Ed25519 signature verification
544        // In a real implementation, this would use the `ed25519-dalek` crate
545        let is_valid = self.simulate_signature_verification(&hash, signature, public_key);
546
547        Ok(is_valid)
548    }
549
550    /// Verify RSA signature
551    fn verify_rsa_signature(
552        &self,
553        data: &[u8],
554        signature: &[u8],
555        public_key: &[u8],
556    ) -> Result<bool> {
557        // Calculate SHA-256 hash of the data
558        let hash = sha2::Sha256::digest(data);
559
560        // Simulate RSA signature verification
561        // In a real implementation, this would use the `rsa` crate
562        let is_valid = self.simulate_signature_verification(&hash, signature, public_key);
563
564        Ok(is_valid)
565    }
566
567    /// Verify ECDSA signature
568    fn verify_ecdsa_signature(
569        &self,
570        data: &[u8],
571        signature: &[u8],
572        public_key: &[u8],
573    ) -> Result<bool> {
574        // Calculate SHA-256 hash of the data
575        let hash = sha2::Sha256::digest(data);
576
577        // Simulate ECDSA signature verification
578        // In a real implementation, this would use the `p256` or `secp256k1` crate
579        let is_valid = self.simulate_signature_verification(&hash, signature, public_key);
580
581        Ok(is_valid)
582    }
583
584    /// Simulate signature verification (for demonstration purposes)
585    fn simulate_signature_verification(
586        &self,
587        hash: &[u8],
588        signature: &[u8],
589        public_key: &[u8],
590    ) -> bool {
591        // This is a simplified simulation for demonstration
592        // In a real implementation, this would use proper cryptographic verification
593
594        // Check basic length requirements
595        if signature.is_empty() || public_key.is_empty() || hash.is_empty() {
596            return false;
597        }
598
599        // Simulate verification by checking if signature matches a pattern
600        // This is NOT secure and is only for demonstration
601        let mut verification_hash = Vec::new();
602        verification_hash.extend_from_slice(hash);
603        verification_hash.extend_from_slice(public_key);
604
605        let computed_hash = sha2::Sha256::digest(&verification_hash);
606
607        // Check if first 16 bytes of signature match first 16 bytes of computed hash
608        if signature.len() >= 16 && computed_hash.len() >= 16 {
609            signature[0..16] == computed_hash[0..16]
610        } else {
611            false
612        }
613    }
614
615    fn get_current_binary_path(&self) -> Result<PathBuf> {
616        let current_exe = std::env::current_exe()?;
617        Ok(current_exe)
618    }
619
620    async fn cleanup_old_backups(&mut self) -> Result<()> {
621        while self.state.backup_paths.len() > self.config.backup_count as usize {
622            let old_backup = self.state.backup_paths.remove(0);
623            if old_backup.exists() {
624                fs::remove_file(&old_backup)?;
625                info!("Removed old backup: {:?}", old_backup);
626            }
627        }
628        Ok(())
629    }
630
631    fn save_state(&self) -> Result<()> {
632        let content = serde_json::to_string_pretty(&self.state)?;
633        fs::write(&self.state_file, content)?;
634        Ok(())
635    }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager around `state` without reading any state file.
    fn manager_with_state(state: UpdateState) -> UpdateManager {
        UpdateManager {
            config: UpdateConfig::default(),
            state,
            client: Client::new(),
            state_file: PathBuf::from("test.json"),
        }
    }

    #[test]
    fn test_update_config_default() {
        let config = UpdateConfig::default();
        assert_eq!(config.check_interval_hours, 24);
        assert!(!config.auto_update);
        assert_eq!(config.backup_count, 3);
        assert!(matches!(config.update_channel, UpdateChannel::Stable));
        assert!(config.verify_signatures);
        assert_eq!(config.signature_algorithm, "ed25519");
        assert!(config.public_key_path.is_none());
    }

    #[test]
    fn test_update_state_default() {
        let state = UpdateState::default();
        assert!(!state.update_available);
        assert!(state.backup_paths.is_empty());
        assert_eq!(state.current_version, env!("CARGO_PKG_VERSION"));
    }

    #[test]
    fn test_version_comparison() {
        let mut state = UpdateState::default();
        state.current_version = "1.2.3".to_string();
        let manager = manager_with_state(state);

        assert!(manager.is_newer_version("1.2.4").unwrap());
        assert!(manager.is_newer_version("2.0.0").unwrap());
        assert!(!manager.is_newer_version("1.2.3").unwrap());
        assert!(!manager.is_newer_version("1.0.0").unwrap());
        // Pre-releases of the same version sort before the release itself.
        assert!(!manager.is_newer_version("1.2.3-beta.1").unwrap());
        assert!(manager.is_newer_version("not-a-version").is_err());
    }

    #[test]
    fn test_parse_hex_signature() {
        let manager = manager_with_state(UpdateState::default());

        assert_eq!(
            manager.parse_hex_signature("0aFF").unwrap(),
            vec![0x0a, 0xff]
        );
        assert_eq!(
            manager.parse_hex_signature(" 0a ff\n10 ").unwrap(),
            vec![0x0a, 0xff, 0x10]
        );
        assert!(manager.parse_hex_signature("abc").is_err());
        assert!(manager.parse_hex_signature("zz").is_err());
    }

    #[test]
    fn test_state_round_trip() {
        let dir = TempDir::new().unwrap();
        let state_file = dir.path().join("update-state.json");

        let mut manager = UpdateManager::new(UpdateConfig::default(), state_file.clone()).unwrap();
        manager.state.available_version = Some("9.9.9".to_string());
        manager
            .state
            .backup_paths
            .push(PathBuf::from("voirs-backup-1.bak"));
        manager.save_state().unwrap();

        let reloaded = UpdateManager::new(UpdateConfig::default(), state_file).unwrap();
        assert_eq!(reloaded.state.available_version.as_deref(), Some("9.9.9"));
        assert_eq!(
            reloaded.state.backup_paths,
            vec![PathBuf::from("voirs-backup-1.bak")]
        );
    }

    #[test]
    fn test_corrupt_state_falls_back_to_default() {
        let dir = TempDir::new().unwrap();
        let state_file = dir.path().join("update-state.json");
        fs::write(&state_file, "not json").unwrap();

        let manager = UpdateManager::new(UpdateConfig::default(), state_file).unwrap();
        assert!(!manager.state.update_available);
        assert!(manager.state.backup_paths.is_empty());
    }

    #[test]
    fn test_update_channel_serialization() {
        for channel in [
            UpdateChannel::Stable,
            UpdateChannel::Beta,
            UpdateChannel::Nightly,
        ] {
            let serialized = serde_json::to_string(&channel).unwrap();
            let deserialized: UpdateChannel = serde_json::from_str(&serialized).unwrap();
            assert_eq!(format!("{:?}", deserialized), format!("{:?}", channel));
        }
    }
}
678        let serialized = serde_json::to_string(&channel).unwrap();
679        let deserialized: UpdateChannel = serde_json::from_str(&serialized).unwrap();
680        assert!(matches!(deserialized, UpdateChannel::Stable));
681    }
682}