use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context as _, Result, anyhow, bail};
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64;
use ed25519_dalek::{Signature, VerifyingKey};
use futures::AsyncReadExt as _;
use serde::{Deserialize, Serialize};
use sha2::{Digest as _, Sha256};
use kael_release::update::{UpdateChannel, UpdateManifest, verify_manifest};
use semantic_version::SemanticVersion;
/// Configuration for [`AutoUpdater`]; serializable so it can be stored alongside other settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoUpdaterConfig {
/// URL of the update feed. `AutoUpdater::check_for_updates` fetches it and
/// `parse_update_feed` accepts either a JSON feed or a Sparkle-style appcast.
pub feed_url: String,
/// Interval between update checks. Stored as a whole number of seconds (see
/// `duration_secs`), so any sub-second part is dropped when serialized.
/// NOTE(review): not read by `AutoUpdater` in this file — presumably the caller schedules checks.
#[serde(with = "duration_secs")]
pub check_interval: Duration,
/// Whether pre-release builds should be offered.
/// NOTE(review): not read anywhere in this file — confirm whether callers consult it.
pub allow_prerelease: bool,
}
/// One release entry parsed from the update feed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateInfo {
/// Version of this release; only versions newer than the running one are offered.
pub version: SemanticVersion,
/// Optional release notes (JSON `release_notes`, or the appcast `<description>` element).
pub release_notes: Option<String>,
/// URL the package is downloaded from; also part of the signed manifest.
pub download_url: String,
/// Base64-encoded ed25519 signature over the release manifest, when the feed provides one.
pub signature: Option<String>,
/// Expected hex-encoded SHA-256 of the package (compared case-insensitively).
#[serde(default)]
pub sha256: Option<String>,
/// Expected package size in bytes (the appcast `length` attribute for XML feeds).
#[serde(default)]
pub size_bytes: Option<u64>,
}
/// Snapshot of an in-flight package download, passed to the caller's progress callback.
#[derive(Debug, Clone, Copy)]
pub struct DownloadProgress {
    /// Bytes received so far.
    pub bytes_downloaded: u64,
    /// Total size advertised by the server (`Content-Length`), if it sent one.
    pub total_bytes: Option<u64>,
}

impl DownloadProgress {
    /// Returns the completed fraction in `0.0..=1.0`, or `None` when it cannot be computed.
    ///
    /// An unknown or zero total yields `None` instead of `NaN`/infinity. The result is
    /// clamped to `1.0` because the advertised length can be smaller than the bytes
    /// actually received (for example a `Content-Length` describing a compressed transfer).
    pub fn fraction(&self) -> Option<f64> {
        let total = self.total_bytes.filter(|&total| total > 0)?;
        Some((self.bytes_downloaded as f64 / total as f64).min(1.0))
    }
}
/// Coarse state of the updater, exposed through `AutoUpdater::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UpdateStatus {
/// Nothing in progress: the initial state, or the last check found no newer release.
Idle,
/// A feed request is in flight.
Checking,
/// A release newer than the running version was found.
UpdateAvailable(SemanticVersion),
/// The package download is in progress.
Downloading,
/// A package was downloaded, verified and staged on disk.
ReadyToInstall,
/// An operation failed; carries a human-readable message.
Error(String),
}
/// Platform-specific step that installs a staged package and relaunches the application.
pub trait PlatformInstaller: Send + Sync {
/// Installs the package at `package_path` and restarts the application. The
/// implementations in this file call `std::process::exit(0)` on success, so they only
/// ever return errors.
fn install_and_restart(&self, package_path: &std::path::Path) -> Result<()>;
}
/// Checks an update feed, downloads and verifies packages, and hands them to a
/// [`PlatformInstaller`].
pub struct AutoUpdater {
/// Feed URL and related options.
config: AutoUpdaterConfig,
/// Version of the running application; only strictly newer releases are offered.
current_version: SemanticVersion,
/// HTTP client used for both the feed request and the package download.
http_client: Arc<dyn http_client::HttpClient>,
/// Platform installer; `install_and_restart` fails until one is set.
installer: Option<Arc<dyn PlatformInstaller>>,
/// Current position in the check -> download -> install flow.
status: UpdateStatus,
/// Newest release recorded by the last successful feed check.
latest_update: Option<UpdateInfo>,
/// Path of the staged package once a download has passed verification.
downloaded_path: Option<std::path::PathBuf>,
/// ed25519 key that release manifests must be signed with, if configured.
verifying_key: Option<VerifyingKey>,
/// Channel placed into the manifest being verified; a package signed for another
/// channel therefore fails verification.
update_channel: UpdateChannel,
/// When true, refuse packages unless a key is configured and the entry has a sha256.
require_signature: bool,
}
impl AutoUpdater {
    /// Creates an updater in the [`UpdateStatus::Idle`] state.
    ///
    /// Defaults: no installer, no public key, the [`UpdateChannel::Stable`] channel and
    /// `require_signature = true`, so downloads are refused until a key is configured with
    /// [`Self::set_public_key`] or [`Self::set_public_key_hex`].
    pub fn new(
        config: AutoUpdaterConfig,
        current_version: SemanticVersion,
        http_client: Arc<dyn http_client::HttpClient>,
    ) -> Self {
        Self {
            config,
            current_version,
            http_client,
            installer: None,
            status: UpdateStatus::Idle,
            latest_update: None,
            downloaded_path: None,
            verifying_key: None,
            update_channel: UpdateChannel::Stable,
            require_signature: true,
        }
    }

    /// Sets the installer used by [`Self::install_and_restart`].
    pub fn set_installer(&mut self, installer: Arc<dyn PlatformInstaller>) {
        self.installer = Some(installer);
    }

    /// Sets the ed25519 public key that release manifests must be signed with.
    ///
    /// # Errors
    /// Fails if `public_key` is not exactly 32 bytes or is not a valid ed25519 key.
    pub fn set_public_key(&mut self, public_key: &[u8]) -> Result<()> {
        let key_array: [u8; 32] = public_key
            .try_into()
            .map_err(|_| anyhow!("ed25519 public key must be exactly 32 bytes"))?;
        let key = VerifyingKey::from_bytes(&key_array)
            .map_err(|_| anyhow!("invalid ed25519 public key"))?;
        self.verifying_key = Some(key);
        Ok(())
    }

    /// Hex-encoded form of [`Self::set_public_key`]; surrounding whitespace is ignored.
    ///
    /// # Errors
    /// Fails if the string is not valid hex or the decoded key is rejected.
    pub fn set_public_key_hex(&mut self, hex_key: &str) -> Result<()> {
        let bytes = hex::decode(hex_key.trim()).context("update public key is not valid hex")?;
        self.set_public_key(&bytes)
    }

    /// Sets the channel expected in signed manifests. `stable`, `beta` and `nightly` are
    /// recognised case-insensitively; any other (trimmed) name becomes a custom channel.
    pub fn set_update_channel(&mut self, channel: impl AsRef<str>) {
        self.update_channel = channel_from_str(channel.as_ref());
    }

    /// When `true` (the default), a package is refused if no public key is configured or
    /// its feed entry has no sha256 hash.
    pub fn set_require_signature(&mut self, require: bool) {
        self.require_signature = require;
    }

    /// Current updater state.
    pub fn status(&self) -> &UpdateStatus {
        &self.status
    }

    /// Newest release recorded by the last successful [`Self::check_for_updates`].
    pub fn latest_update(&self) -> Option<&UpdateInfo> {
        self.latest_update.as_ref()
    }

    /// Configuration this updater was created with.
    pub fn config(&self) -> &AutoUpdaterConfig {
        &self.config
    }

    /// Fetches the update feed and records the newest release strictly newer than the
    /// running version.
    ///
    /// On success the status becomes [`UpdateStatus::UpdateAvailable`] or
    /// [`UpdateStatus::Idle`]. On any failure (request, HTTP status, body read, feed parse)
    /// the status becomes [`UpdateStatus::Error`] instead of staying stuck at `Checking`;
    /// the previously recorded `latest_update` is left untouched in that case.
    pub async fn check_for_updates(&mut self) -> Result<Option<UpdateInfo>> {
        self.status = UpdateStatus::Checking;
        match self.fetch_latest_update().await {
            Ok(latest) => {
                self.status = match &latest {
                    Some(update) => UpdateStatus::UpdateAvailable(update.version),
                    None => UpdateStatus::Idle,
                };
                self.latest_update = latest.clone();
                Ok(latest)
            }
            Err(err) => {
                // `{:#}` renders the full context chain, not just the outermost message.
                self.status = UpdateStatus::Error(format!("{err:#}"));
                Err(err)
            }
        }
    }

    /// Requests and parses the feed; returns the newest entry above `current_version`.
    async fn fetch_latest_update(&self) -> Result<Option<UpdateInfo>> {
        let mut response = self
            .http_client
            .get(&self.config.feed_url, Default::default(), false)
            .await
            .context("failed to fetch update feed")?;
        let status = response.status();
        if !status.is_success() {
            bail!("update feed returned HTTP {}", status.as_u16());
        }
        let mut body = Vec::new();
        response
            .body_mut()
            .read_to_end(&mut body)
            .await
            .context("failed to read update feed body")?;
        let updates = parse_update_feed(&String::from_utf8_lossy(&body))?;
        Ok(updates
            .into_iter()
            .filter(|u| u.version > self.current_version)
            .max_by_key(|u| u.version))
    }

    /// Downloads the package for the recorded update, verifies it, and stages it on disk.
    ///
    /// `on_progress` is called once before the body is read and again after every chunk.
    /// On success the status becomes [`UpdateStatus::ReadyToInstall`] and the staged path
    /// is returned. Any earlier staged path is forgotten when the attempt starts, so after
    /// a failure (which sets [`UpdateStatus::Error`]) `install_and_restart` refuses to run
    /// rather than installing whatever a previous download left behind.
    ///
    /// # Errors
    /// No recorded update, request/HTTP/read failure, a body exceeding the signed size,
    /// failed verification (see [`Self::verify_package`]), or failure to write the package.
    pub async fn download_update(
        &mut self,
        on_progress: impl Fn(DownloadProgress) + Send + 'static,
    ) -> Result<std::path::PathBuf> {
        let update = self
            .latest_update
            .clone()
            .ok_or_else(|| anyhow!("no update available to download"))?;
        self.status = UpdateStatus::Downloading;
        self.downloaded_path = None;
        match self.fetch_and_stage(&update, on_progress).await {
            Ok(path) => {
                self.downloaded_path = Some(path.clone());
                self.status = UpdateStatus::ReadyToInstall;
                Ok(path)
            }
            Err(err) => {
                self.status = UpdateStatus::Error(format!("{err:#}"));
                Err(err)
            }
        }
    }

    /// Downloads, verifies and stages the package for `update`; returns the staged path.
    async fn fetch_and_stage(
        &self,
        update: &UpdateInfo,
        on_progress: impl Fn(DownloadProgress),
    ) -> Result<std::path::PathBuf> {
        let bytes = self.fetch_package_bytes(update, on_progress).await?;
        self.verify_package(update, &bytes)
            .context("update package failed verification; refusing to install")?;
        Self::stage_package(update, &bytes)
    }

    /// Streams the package body into memory, reporting progress after every chunk.
    async fn fetch_package_bytes(
        &self,
        update: &UpdateInfo,
        on_progress: impl Fn(DownloadProgress),
    ) -> Result<Vec<u8>> {
        const CHUNK_SIZE: usize = 64 * 1024;

        let mut response = self
            .http_client
            .get(&update.download_url, Default::default(), false)
            .await
            .context("failed to start update download")?;
        let status = response.status();
        if !status.is_success() {
            bail!("update download returned HTTP {}", status.as_u16());
        }
        let total_bytes = response
            .headers()
            .get("content-length")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse::<u64>().ok());
        // `verify_package` rejects any size mismatch whenever a sha256 is present, so once
        // the body outgrows that size there is no point buffering more of it.
        let size_limit = update.sha256.as_ref().and(update.size_bytes);

        let mut bytes = Vec::new();
        let mut chunk = vec![0u8; CHUNK_SIZE];
        on_progress(DownloadProgress {
            bytes_downloaded: 0,
            total_bytes,
        });
        loop {
            let read = match response.body_mut().read(&mut chunk).await {
                Ok(0) => break,
                Ok(read) => read,
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err).context("failed to read update package"),
            };
            bytes.extend_from_slice(&chunk[..read]);
            if let Some(limit) = size_limit {
                if bytes.len() as u64 > limit {
                    bail!("update download exceeded the expected {limit} bytes; aborting");
                }
            }
            on_progress(DownloadProgress {
                bytes_downloaded: bytes.len() as u64,
                total_bytes,
            });
        }
        Ok(bytes)
    }

    /// Writes `bytes` into a fresh, uniquely named, owner-only temp directory.
    ///
    /// The staging directory is removed again if the write fails.
    fn stage_package(update: &UpdateInfo, bytes: &[u8]) -> Result<std::path::PathBuf> {
        let staging_dir =
            std::env::temp_dir().join(format!("kael_update_{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&staging_dir)
            .context("failed to create update staging directory")?;
        restrict_dir_permissions(&staging_dir);
        let download_path = staging_dir.join(sanitize_package_filename(&update.download_url));
        if let Err(err) = std::fs::write(&download_path, bytes) {
            let _ = std::fs::remove_dir_all(&staging_dir);
            return Err(err).context("failed to write update package to disk");
        }
        Ok(download_path)
    }

    /// Checks `bytes` against the feed entry before they may be staged for installation.
    ///
    /// With a public key configured, the entry must carry a base64 ed25519 signature plus
    /// its sha256 and size, and the signature must cover a manifest built from the version,
    /// the configured channel, the URL, the hash and the size. Without a key, verification
    /// fails closed when `require_signature` is set. Independently, when a sha256 is present
    /// the size (if known) and the hash must match the downloaded bytes; a missing hash is
    /// only tolerated when `require_signature` is off.
    fn verify_package(&self, update: &UpdateInfo, bytes: &[u8]) -> Result<()> {
        match self.verifying_key.as_ref() {
            Some(key) => {
                let signature_b64 = update.signature.as_deref().ok_or_else(|| {
                    anyhow!("update is unsigned but signature verification is required")
                })?;
                let signature_bytes = BASE64
                    .decode(signature_b64)
                    .context("update signature is not valid base64")?;
                let signature_array: [u8; 64] = signature_bytes
                    .as_slice()
                    .try_into()
                    .map_err(|_| anyhow!("update signature must be 64 bytes"))?;
                let signature = Signature::from_bytes(&signature_array);
                let sha256 = update
                    .sha256
                    .as_deref()
                    .ok_or_else(|| anyhow!("signed update is missing its sha256 hash"))?;
                let size_bytes = update
                    .size_bytes
                    .ok_or_else(|| anyhow!("signed update is missing its size"))?;
                let manifest = UpdateManifest {
                    version: update.version.to_string(),
                    channel: self.update_channel.clone(),
                    url: update.download_url.clone(),
                    sha256: sha256.to_string(),
                    size_bytes,
                    release_notes: None,
                    min_version: None,
                };
                if !verify_manifest(&manifest, &signature, key) {
                    bail!("update signature verification failed");
                }
            }
            None => {
                if self.require_signature {
                    bail!(
                        "auto-update signature verification is required but no public key is configured"
                    );
                }
            }
        }
        match update.sha256.as_deref() {
            Some(expected) => {
                if let Some(expected_size) = update.size_bytes {
                    if bytes.len() as u64 != expected_size {
                        bail!(
                            "update size mismatch: expected {expected_size} bytes, downloaded {}",
                            bytes.len()
                        );
                    }
                }
                let actual = sha256_hex(bytes);
                if actual.len() != expected.len() || !actual.eq_ignore_ascii_case(expected) {
                    bail!("update hash mismatch: expected {expected}, downloaded {actual}");
                }
            }
            None => {
                if self.require_signature {
                    bail!("update is missing a sha256 hash; cannot verify integrity");
                }
            }
        }
        Ok(())
    }

    /// Hands the staged package to the configured installer.
    ///
    /// # Errors
    /// Fails if no installer is set, no verified package is staged, or the installer fails.
    pub fn install_and_restart(&self) -> Result<()> {
        let installer = self
            .installer
            .as_ref()
            .ok_or_else(|| anyhow!("no platform installer configured"))?;
        let path = self
            .downloaded_path
            .as_ref()
            .ok_or_else(|| anyhow!("no update has been downloaded"))?;
        installer.install_and_restart(path)
    }
}
/// Returns the hex-encoded SHA-256 digest of `bytes`.
fn sha256_hex(bytes: &[u8]) -> String {
    let digest = Sha256::digest(bytes);
    hex::encode(digest)
}
/// Maps a user-supplied channel name onto an [`UpdateChannel`].
///
/// Surrounding whitespace is ignored and the built-in names are matched ASCII
/// case-insensitively; any other name is kept (trimmed) as a custom channel.
fn channel_from_str(channel: &str) -> UpdateChannel {
    let name = channel.trim();
    match name.to_ascii_lowercase().as_str() {
        "stable" => UpdateChannel::Stable,
        "beta" => UpdateChannel::Beta,
        "nightly" => UpdateChannel::Nightly,
        _ => UpdateChannel::Custom(name.to_string()),
    }
}
/// Derives a safe, single-component file name for a downloaded package from its URL.
///
/// Processing steps:
/// - Take the last `/`- or `\`-separated segment and drop any query string or fragment.
/// - Keep only ASCII alphanumerics plus `.`, `-` and `_`, so no path separator can survive.
/// - Replace empty names and names made only of dots (`.`, `..`, `...`) with
///   `"update_package"`. Windows path normalization strips trailing dots, so such names
///   are unusable.
/// - Trim overlong names from the front to 255 bytes, the common file-name limit. Trimming
///   from the front keeps the extension, which the platform installers dispatch on.
fn sanitize_package_filename(download_url: &str) -> String {
    const FALLBACK: &str = "update_package";
    const MAX_LEN: usize = 255;

    let last_segment = download_url.rsplit(['/', '\\']).next().unwrap_or("");
    let without_suffix = last_segment.split(['?', '#']).next().unwrap_or("");
    let mut cleaned: String = without_suffix
        .chars()
        .filter(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '_'))
        .collect();
    if cleaned.len() > MAX_LEN {
        // Only ASCII survives the filter, so every byte index is a char boundary.
        cleaned = cleaned.split_off(cleaned.len() - MAX_LEN);
    }
    // `all` is true for the empty string, so this also covers empty names.
    if cleaned.chars().all(|c| c == '.') {
        FALLBACK.to_string()
    } else {
        cleaned
    }
}
/// Best-effort restriction of `dir` to owner-only access (mode `0o700`) on Unix.
///
/// Any failure is ignored on purpose, so staging continues with the default permissions.
/// On non-Unix platforms this is a no-op.
fn restrict_dir_permissions(dir: &std::path::Path) {
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt as _;
        let owner_only = std::fs::Permissions::from_mode(0o700);
        let _ = std::fs::set_permissions(dir, owner_only);
    }
    #[cfg(not(unix))]
    let _ = dir;
}
/// Parses an update feed body, detecting its format from the first non-whitespace
/// character: `<` selects the appcast XML parser, and `[` or `{` selects the JSON parser.
///
/// # Errors
/// Returns an error for any other leading character, or whatever the selected parser reports.
pub fn parse_update_feed(body: &str) -> Result<Vec<UpdateInfo>> {
    let feed = body.trim();
    match feed.chars().next() {
        Some('<') => parse_appcast_xml(feed),
        Some('[' | '{') => parse_json_feed(feed),
        _ => bail!("unrecognized update feed format"),
    }
}
/// Wire format of one entry in a JSON update feed. Only `version` and `download_url`
/// are required; the remaining fields may be omitted.
#[derive(Debug, Deserialize)]
struct JsonFeedItem {
/// Version string; parsed into a `SemanticVersion` by `parse_json_feed`.
version: String,
/// Optional release notes.
#[serde(default)]
release_notes: Option<String>,
/// URL of the package.
download_url: String,
/// Optional base64 ed25519 signature over the release manifest.
#[serde(default)]
signature: Option<String>,
/// Optional hex SHA-256 of the package.
#[serde(default)]
sha256: Option<String>,
/// Optional package size in bytes.
#[serde(default)]
size_bytes: Option<u64>,
}
/// Parses a JSON update feed given either as a bare array of items or as an object of
/// the form `{"items": [...]}`.
///
/// # Errors
/// Fails if the JSON matches neither shape, or if any item's `version` is not a valid
/// semantic version (the whole feed is rejected in that case).
fn parse_json_feed(body: &str) -> Result<Vec<UpdateInfo>> {
    #[derive(Deserialize)]
    struct Wrapper {
        items: Vec<JsonFeedItem>,
    }

    let items: Vec<JsonFeedItem> = if body.trim().starts_with('[') {
        serde_json::from_str(body).context("failed to parse JSON update feed as array")?
    } else {
        serde_json::from_str::<Wrapper>(body)
            .context("failed to parse JSON update feed as object")?
            .items
    };
    items
        .into_iter()
        .map(|item| {
            // Build the message lazily: only malformed entries pay for the allocation.
            let version = item
                .version
                .parse::<SemanticVersion>()
                .with_context(|| format!("invalid version string: {}", item.version))?;
            Ok(UpdateInfo {
                version,
                release_notes: item.release_notes,
                download_url: item.download_url,
                signature: item.signature,
                sha256: item.sha256,
                size_bytes: item.size_bytes,
            })
        })
        .collect()
}
/// Parses a Sparkle-style RSS appcast into release entries.
///
/// An `<item>` yields an entry only when it has a parseable version and a `url` attribute;
/// other items are skipped silently.
///
/// Field sources:
/// - version: the `sparkle:version` attribute, then `sparkle:shortVersionString`, then a
///   `<sparkle:version>` element;
/// - signature: `sparkle:edSignature`, falling back to `sparkle:dsaSignature`;
/// - size: `length`, when it is a valid integer.
fn parse_appcast_xml(body: &str) -> Result<Vec<UpdateInfo>> {
    let updates = split_xml_items(body)
        .iter()
        .filter_map(|block| {
            let version_text = extract_xml_attr(block, "sparkle:version")
                .or_else(|| extract_xml_attr(block, "sparkle:shortVersionString"))
                .or_else(|| extract_xml_tag_content(block, "sparkle:version"))?;
            let download_url = extract_xml_attr(block, "url")?;
            let version = version_text.parse::<SemanticVersion>().ok()?;
            Some(UpdateInfo {
                version,
                release_notes: extract_xml_tag_content(block, "description"),
                download_url,
                signature: extract_xml_attr(block, "sparkle:edSignature")
                    .or_else(|| extract_xml_attr(block, "sparkle:dsaSignature")),
                sha256: extract_xml_attr(block, "sparkle:sha256")
                    .or_else(|| extract_xml_attr(block, "sha256")),
                size_bytes: extract_xml_attr(block, "length")
                    .and_then(|len| len.parse::<u64>().ok()),
            })
        })
        .collect();
    Ok(updates)
}
/// Splits an appcast body into the raw text of each `<item>...</item>` element, in order.
///
/// Matching rules:
/// - Tags are matched ASCII case-insensitively.
/// - An opening tag is `<item` followed by `>` or whitespace, so `<item>`, `<item attr="..">`
///   and `<item\n>` count but `<items>` does not.
/// - An item without a closing `</item>` ends the scan.
/// - Nested items are not supported.
fn split_xml_items(body: &str) -> Vec<String> {
    const OPEN: &str = "<item";
    const CLOSE: &str = "</item>";

    // `to_ascii_lowercase` (unlike `to_lowercase`) never changes byte lengths, so offsets
    // found in `lower` are valid, char-aligned offsets into `body`.
    let lower = body.to_ascii_lowercase();
    let lower_bytes = lower.as_bytes();

    // Finds the earliest genuine `<item` opening tag at or after `from`.
    let find_open = |mut from: usize| -> Option<usize> {
        while let Some(rel) = lower[from..].find(OPEN) {
            let start = from + rel;
            let after = start + OPEN.len();
            match lower_bytes.get(after) {
                Some(b) if *b == b'>' || b.is_ascii_whitespace() => return Some(start),
                _ => from = after,
            }
        }
        None
    };

    let mut items = Vec::new();
    let mut search_from = 0;
    while let Some(start) = find_open(search_from) {
        let end = match lower[start..].find(CLOSE) {
            Some(rel_end) => start + rel_end + CLOSE.len(),
            None => break,
        };
        items.push(body[start..end].to_string());
        search_from = end;
    }
    items
}
/// Returns the value of the first `attr_name="..."` (or `attr_name='...'`) attribute in `block`.
///
/// Matching and decoding rules:
/// - The name must directly follow whitespace, so `url` does not match `data-url` and
///   `sha256` does not match `sparkle:sha256`.
/// - Whitespace around `=` is allowed.
/// - XML entities in the value are decoded (e.g. `&amp;` in a query string). The URL and
///   signature then match what the feed author published and what the manifest was signed
///   over.
///
/// This is a lightweight scanner, not a full XML parser.
fn extract_xml_attr(block: &str, attr_name: &str) -> Option<String> {
    if attr_name.is_empty() {
        return None;
    }
    let bytes = block.as_bytes();
    let mut search_from = 0;
    while let Some(rel) = block[search_from..].find(attr_name) {
        let start = search_from + rel;
        search_from = start + attr_name.len();
        let preceded_by_space = start > 0 && bytes[start - 1].is_ascii_whitespace();
        if !preceded_by_space {
            continue;
        }
        let after_eq = match block[search_from..].trim_start().strip_prefix('=') {
            Some(rest) => rest.trim_start(),
            None => continue,
        };
        let quote = match after_eq.chars().next() {
            Some(q @ ('"' | '\'')) => q,
            _ => continue,
        };
        let value = &after_eq[1..];
        let end = value.find(quote)?;
        return Some(unescape_xml_entities(&value[..end]));
    }
    None
}

/// Returns the trimmed text of the first `<tag_name>` element in `block`.
///
/// Returns `None` when the element is missing, self-closing, unterminated or blank. The
/// opening tag must be `<tag_name` followed by `>` or whitespace, so `description` does not
/// match `<descriptionFoo>`. A `<![CDATA[...]]>` body is returned verbatim (without the
/// wrapper); otherwise XML entities are decoded.
fn extract_xml_tag_content(block: &str, tag_name: &str) -> Option<String> {
    let open = format!("<{tag_name}");
    let close = format!("</{tag_name}>");
    let mut search_from = 0;
    let start = loop {
        let candidate = search_from + block[search_from..].find(&open)?;
        search_from = candidate + open.len();
        match block[search_from..].chars().next() {
            Some(c) if c == '>' || c.is_whitespace() => break candidate,
            _ => {}
        }
    };
    let gt = start + block[start..].find('>')?;
    if block[..gt].ends_with('/') {
        // `<tag_name ... />` has no content.
        return None;
    }
    let after_open = gt + 1;
    let end = after_open + block[after_open..].find(&close)?;
    let raw = block[after_open..end].trim();
    let content = match raw
        .strip_prefix("<![CDATA[")
        .and_then(|inner| inner.strip_suffix("]]>"))
    {
        Some(cdata) => cdata.trim().to_string(),
        None => unescape_xml_entities(raw),
    };
    if content.is_empty() {
        None
    } else {
        Some(content)
    }
}

/// Decodes the five predefined XML entities and numeric character references
/// (`&#65;`, `&#x41;`) in a single pass, so `&amp;lt;` becomes `&lt;`, not `<`.
/// Unrecognised or malformed references are left untouched.
fn unescape_xml_entities(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        out.push_str(&rest[..amp]);
        let tail = &rest[amp..];
        let decoded = tail.find(';').and_then(|semi| {
            let entity = &tail[1..semi];
            let ch = match entity {
                "lt" => Some('<'),
                "gt" => Some('>'),
                "amp" => Some('&'),
                "quot" => Some('"'),
                "apos" => Some('\''),
                _ => entity.strip_prefix('#').and_then(|num| {
                    let code = match num.strip_prefix(['x', 'X']) {
                        Some(hex) => u32::from_str_radix(hex, 16).ok(),
                        None => num.parse::<u32>().ok(),
                    };
                    code.and_then(char::from_u32)
                }),
            };
            ch.map(|c| (c, semi))
        });
        match decoded {
            Some((c, semi)) => {
                out.push(c);
                rest = &tail[semi + 1..];
            }
            None => {
                out.push('&');
                rest = &tail[1..];
            }
        }
    }
    out.push_str(rest);
    out
}
/// [`PlatformInstaller`] for macOS `.app` bundles shipped as `.zip` or `.dmg` packages.
#[cfg(target_os = "macos")]
pub struct MacInstaller;

#[cfg(target_os = "macos")]
impl PlatformInstaller for MacInstaller {
    /// Replaces the running `.app` bundle with the one inside `package_path`, relaunches
    /// it with `open -n`, and exits the current process.
    ///
    /// The format is chosen by extension (`zip` / `dmg`, case-insensitive) and checked
    /// before anything on disk is touched. Extraction and mounting happen in a fresh,
    /// randomly named, owner-only directory under the temp dir. That directory is removed
    /// afterwards, so concurrent or stale runs cannot collide on a shared fixed path.
    ///
    /// # Errors
    /// Unsupported extension, executable not inside a `.app` bundle, extraction or mount
    /// failure, no `.app` in the package, bundle replacement failure, or relaunch failure.
    /// On success this never returns.
    fn install_and_restart(&self, package_path: &std::path::Path) -> Result<()> {
        let ext = package_path
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("");
        let is_zip = ext.eq_ignore_ascii_case("zip");
        if !is_zip && !ext.eq_ignore_ascii_case("dmg") {
            bail!("unsupported macOS package format: .{}", ext);
        }
        let app_bundle = resolve_running_app_bundle()?;

        let work_dir = std::env::temp_dir()
            .join(format!("kael_update_install_{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&work_dir).context("failed to create update work directory")?;
        restrict_dir_permissions(&work_dir);
        let remove_work_dir = || {
            let _ = std::fs::remove_dir_all(&work_dir);
        };

        if is_zip {
            let extract_dir = work_dir.join("extract");
            let result = (|| -> Result<()> {
                std::fs::create_dir_all(&extract_dir)?;
                let status = std::process::Command::new("ditto")
                    .arg("-xk")
                    .arg(package_path)
                    .arg(&extract_dir)
                    .status()
                    .context("failed to run ditto to extract zip")?;
                if !status.success() {
                    bail!("ditto extraction failed with status {}", status);
                }
                let new_app = find_app_bundle_in(&extract_dir)?;
                replace_app_bundle(&new_app, &app_bundle)
            })();
            remove_work_dir();
            result?;
        } else {
            let mount_point = work_dir.join("mnt");
            let attached = (|| -> Result<()> {
                std::fs::create_dir_all(&mount_point)?;
                let status = std::process::Command::new("hdiutil")
                    .arg("attach")
                    .arg(package_path)
                    .arg("-mountpoint")
                    .arg(&mount_point)
                    .args(["-nobrowse", "-quiet"])
                    .status()
                    .context("failed to run hdiutil attach")?;
                if !status.success() {
                    bail!("hdiutil attach failed with status {}", status);
                }
                Ok(())
            })();
            if let Err(err) = attached {
                remove_work_dir();
                return Err(err);
            }
            let result = find_app_bundle_in(&mount_point)
                .and_then(|new_app| replace_app_bundle(&new_app, &app_bundle));
            let detached = std::process::Command::new("hdiutil")
                .arg("detach")
                .arg(&mount_point)
                .arg("-quiet")
                .status()
                .map(|status| status.success())
                .unwrap_or(false);
            // Never recurse into a volume that is still mounted.
            if detached {
                remove_work_dir();
            }
            result?;
        }

        let status = std::process::Command::new("open")
            .arg("-n")
            .arg(&app_bundle)
            .status()
            .context("failed to restart application")?;
        if !status.success() {
            bail!("failed to restart application, open returned {}", status);
        }
        std::process::exit(0);
    }
}

/// Returns the `.app` bundle directory containing the running executable.
///
/// Expects the standard layout `<Name>.app/Contents/MacOS/<exe>` and verifies it. Outside
/// a bundle (e.g. a binary run straight from a build directory), blindly walking three
/// parents up would name an arbitrary directory that the installer would then rename and
/// overwrite.
#[cfg(target_os = "macos")]
fn resolve_running_app_bundle() -> Result<std::path::PathBuf> {
    fn dir_name(path: Option<&std::path::Path>) -> Option<&str> {
        path?.file_name()?.to_str()
    }

    let exe = std::env::current_exe().context("failed to get current executable path")?;
    let macos_dir = exe.parent();
    let contents_dir = macos_dir.and_then(|p| p.parent());
    let app_bundle = contents_dir
        .and_then(|p| p.parent())
        .ok_or_else(|| anyhow!("could not determine .app bundle path from executable"))?;
    let is_bundle_layout = dir_name(macos_dir) == Some("MacOS")
        && dir_name(contents_dir) == Some("Contents")
        && app_bundle
            .extension()
            .and_then(|e| e.to_str())
            .map_or(false, |e| e.eq_ignore_ascii_case("app"));
    if !is_bundle_layout {
        bail!(
            "running executable {} is not inside a .app bundle; refusing to self-update",
            exe.display()
        );
    }
    Ok(app_bundle.to_path_buf())
}
/// Returns the first top-level entry of `dir` whose extension is exactly `app`.
/// Entries are visited in `read_dir` order, which the platform defines.
///
/// # Errors
/// Fails if `dir` or one of its entries cannot be read, or if no `.app` entry exists.
#[cfg(target_os = "macos")]
fn find_app_bundle_in(dir: &std::path::Path) -> Result<std::path::PathBuf> {
    let entries = std::fs::read_dir(dir).context("failed to read extraction directory")?;
    for entry in entries {
        let candidate = entry?.path();
        let is_app = matches!(candidate.extension().and_then(|e| e.to_str()), Some("app"));
        if is_app {
            return Ok(candidate);
        }
    }
    Err(anyhow!("no .app bundle found in {}", dir.display()))
}
/// Swaps `existing_app` for a copy of `new_app`.
///
/// The old bundle is kept as `<name>.app.bak` until the copy succeeds. If `cp` cannot be
/// started or exits unsuccessfully, any partially copied bundle is removed and the backup
/// is moved back, so the original application stays usable.
///
/// # Errors
/// Fails if a stale backup cannot be removed, the existing bundle cannot be moved aside,
/// or the copy fails (after the rollback described above has been attempted).
#[cfg(target_os = "macos")]
fn replace_app_bundle(new_app: &std::path::Path, existing_app: &std::path::Path) -> Result<()> {
    let backup = existing_app.with_extension("app.bak");
    if backup.exists() {
        std::fs::remove_dir_all(&backup)?;
    }
    std::fs::rename(existing_app, &backup)
        .context("failed to move existing app bundle to backup")?;
    let copied = std::process::Command::new("cp")
        .arg("-R")
        .arg(new_app)
        .arg(existing_app)
        .status();
    let failure = match copied {
        Ok(status) if status.success() => None,
        Ok(_) => Some(anyhow!("failed to copy new app bundle into place")),
        Err(err) => Some(anyhow::Error::new(err).context("failed to copy new app bundle")),
    };
    if let Some(err) = failure {
        // A failed `cp` can leave a partial bundle behind; renaming onto a non-empty
        // directory fails, so clear it before restoring the backup.
        let _ = std::fs::remove_dir_all(existing_app);
        let _ = std::fs::rename(&backup, existing_app);
        return Err(err);
    }
    let _ = std::fs::remove_dir_all(&backup);
    Ok(())
}
/// [`PlatformInstaller`] for Windows `.msi` (via `msiexec`) and `.exe` (NSIS) packages.
#[cfg(target_os = "windows")]
pub struct WindowsInstaller;

#[cfg(target_os = "windows")]
impl PlatformInstaller for WindowsInstaller {
    /// Runs the installer silently, relaunches the current executable, and exits.
    ///
    /// The format comes from the file extension, compared case-insensitively because names
    /// such as `Setup.EXE` are common on Windows. The package path is handed to the child
    /// process as an OS string rather than a lossy UTF-8 conversion.
    ///
    /// # Errors
    /// Unsupported extension, installer launch failure or unsuccessful exit, or failure to
    /// locate or relaunch the current executable. On success this never returns.
    fn install_and_restart(&self, package_path: &std::path::Path) -> Result<()> {
        let ext = package_path
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("");
        if ext.eq_ignore_ascii_case("msi") {
            let status = std::process::Command::new("msiexec")
                .arg("/i")
                .arg(package_path)
                .args(["/quiet", "/norestart"])
                .status()
                .context("failed to run msiexec")?;
            if !status.success() {
                bail!("msiexec failed with status {}", status);
            }
        } else if ext.eq_ignore_ascii_case("exe") {
            let status = std::process::Command::new(package_path)
                .arg("/S")
                .status()
                .context("failed to run NSIS installer")?;
            if !status.success() {
                bail!("NSIS installer failed with status {}", status);
            }
        } else {
            bail!("unsupported Windows package format: .{}", ext);
        }
        let exe = std::env::current_exe().context("failed to get current executable path")?;
        std::process::Command::new(exe)
            .spawn()
            .context("failed to restart application")?;
        std::process::exit(0);
    }
}
/// [`PlatformInstaller`] for Linux and FreeBSD, covering AppImage, Flatpak and Snap installs.
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
pub struct LinuxInstaller {
/// Forces a package format. When `None`, the format is detected from the package
/// extension and the `FLATPAK_ID` / `SNAP` environment variables.
pub format_hint: Option<LinuxPackageFormat>,
}
/// How the running application is packaged, which determines how it is updated.
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinuxPackageFormat {
/// A standalone AppImage file, replaced in place by the downloaded package.
AppImage,
/// A Flatpak app, updated with `flatpak update`.
Flatpak,
/// A Snap package, updated with `snap refresh`.
Snap,
}
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
impl LinuxInstaller {
    /// Creates an installer that detects the package format automatically.
    pub fn new() -> Self {
        Self { format_hint: None }
    }

    /// Creates an installer that always treats updates as `format`.
    pub fn with_format(format: LinuxPackageFormat) -> Self {
        Self {
            format_hint: Some(format),
        }
    }

    /// Decides how `package_path` should be installed.
    ///
    /// Precedence:
    /// 1. the explicit hint;
    /// 2. an `.AppImage` extension (case-insensitive);
    /// 3. a set `FLATPAK_ID` variable (Flatpak);
    /// 4. a set `SNAP` variable (Snap).
    ///
    /// Anything else is treated as an AppImage.
    fn detect_format(&self, package_path: &std::path::Path) -> LinuxPackageFormat {
        let has_appimage_ext = package_path
            .extension()
            .and_then(|e| e.to_str())
            .map_or(false, |e| e.eq_ignore_ascii_case("appimage"));
        match self.format_hint {
            Some(hint) => hint,
            None if has_appimage_ext => LinuxPackageFormat::AppImage,
            None if std::env::var("FLATPAK_ID").is_ok() => LinuxPackageFormat::Flatpak,
            None if std::env::var("SNAP").is_ok() => LinuxPackageFormat::Snap,
            None => LinuxPackageFormat::AppImage,
        }
    }
}
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
impl PlatformInstaller for LinuxInstaller {
fn install_and_restart(&self, package_path: &std::path::Path) -> Result<()> {
let format = self.detect_format(package_path);
match format {
LinuxPackageFormat::AppImage => {
let exe =
std::env::current_exe().context("failed to get current executable path")?;
let backup = exe.with_extension("bak");
if backup.exists() {
std::fs::remove_file(&backup)?;
}
std::fs::rename(&exe, &backup)
.context("failed to move current AppImage to backup")?;
if let Err(e) = std::fs::copy(package_path, &exe) {
let _ = std::fs::rename(&backup, &exe);
return Err(e).context("failed to copy new AppImage into place");
}
let status = std::process::Command::new("chmod")
.args(["+x", &exe.to_string_lossy()])
.status()
.context("failed to chmod new AppImage")?;
if !status.success() {
let _ = std::fs::rename(&backup, &exe);
bail!("chmod failed with status {}", status);
}
let _ = std::fs::remove_file(&backup);
let _ = std::process::Command::new(&exe)
.spawn()
.context("failed to restart AppImage")?;
std::process::exit(0);
}
LinuxPackageFormat::Flatpak => {
let app_id =
std::env::var("FLATPAK_ID").unwrap_or_else(|_| "current-app".to_string());
let status = std::process::Command::new("flatpak")
.args(["update", "-y", &app_id])
.status()
.context("failed to run flatpak update")?;
if !status.success() {
bail!("flatpak update failed with status {}", status);
}
let _ = std::process::Command::new("flatpak")
.args(["run", &app_id])
.spawn()
.context("failed to restart Flatpak application")?;
std::process::exit(0);
}
LinuxPackageFormat::Snap => {
let snap_name =
std::env::var("SNAP_NAME").unwrap_or_else(|_| "current-app".to_string());
let status = std::process::Command::new("snap")
.args(["refresh", &snap_name])
.status()
.context("failed to run snap refresh")?;
if !status.success() {
bail!("snap refresh failed with status {}", status);
}
let _ = std::process::Command::new("snap")
.args(["run", &snap_name])
.spawn()
.context("failed to restart Snap application")?;
std::process::exit(0);
}
}
}
}
/// Serde adapter that stores a [`Duration`] as a whole number of seconds (`u64`).
///
/// Serialization truncates any sub-second part; deserialization expects a `u64`.
mod duration_secs {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Writes `duration` as its whole-second count.
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let whole_seconds = duration.as_secs();
        serializer.serialize_u64(whole_seconds)
    }

    /// Reads a `u64` second count back into a `Duration`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(deserializer).map(Duration::from_secs)
    }
}
// Unit tests: feed parsing, serialization, package verification, download staging,
// and per-platform installer dispatch.
#[cfg(test)]
mod tests {
use super::*;
// A bare JSON array feed parses every entry, including the optional fields.
#[test]
fn test_parse_json_feed_array() {
let json = r#"[
{
"version": "1.2.3",
"release_notes": "Bug fixes",
"download_url": "https://example.com/update-1.2.3.zip",
"signature": "abc123"
},
{
"version": "1.1.0",
"download_url": "https://example.com/update-1.1.0.zip"
}
]"#;
let updates = parse_update_feed(json).unwrap();
assert_eq!(updates.len(), 2);
assert_eq!(updates[0].version, SemanticVersion::new(1, 2, 3));
assert_eq!(updates[0].release_notes.as_deref(), Some("Bug fixes"));
assert_eq!(
updates[0].download_url,
"https://example.com/update-1.2.3.zip"
);
assert_eq!(updates[0].signature.as_deref(), Some("abc123"));
assert_eq!(updates[1].version, SemanticVersion::new(1, 1, 0));
assert!(updates[1].release_notes.is_none());
assert!(updates[1].signature.is_none());
}
// The `{"items": [...]}` object form is accepted as well.
#[test]
fn test_parse_json_feed_object_wrapper() {
let json = r#"{
"items": [
{
"version": "2.0.0",
"download_url": "https://example.com/v2.zip"
}
]
}"#;
let updates = parse_update_feed(json).unwrap();
assert_eq!(updates.len(), 1);
assert_eq!(updates[0].version, SemanticVersion::new(2, 0, 0));
}
// A Sparkle appcast yields one entry per <item>, with URL, signature and notes extracted.
#[test]
fn test_parse_appcast_xml() {
let xml = r#"<?xml version="1.0" encoding="utf-8"?>
<rss version="2.0" xmlns:sparkle="http://www.andymatuschak.org/xml-namespaces/sparkle">
<channel>
<title>My App Updates</title>
<item>
<title>Version 3.1.0</title>
<description>New features and improvements</description>
<enclosure url="https://example.com/MyApp-3.1.0.dmg"
sparkle:version="3.1.0"
sparkle:dsaSignature="sig123"
length="12345678"
type="application/octet-stream" />
</item>
<item>
<title>Version 3.0.0</title>
<enclosure url="https://example.com/MyApp-3.0.0.dmg"
sparkle:version="3.0.0"
length="11111111"
type="application/octet-stream" />
</item>
</channel>
</rss>"#;
let updates = parse_update_feed(xml).unwrap();
assert_eq!(updates.len(), 2);
assert_eq!(updates[0].version, SemanticVersion::new(3, 1, 0));
assert_eq!(
updates[0].download_url,
"https://example.com/MyApp-3.1.0.dmg"
);
assert_eq!(updates[0].signature.as_deref(), Some("sig123"));
assert_eq!(
updates[0].release_notes.as_deref(),
Some("New features and improvements")
);
assert_eq!(updates[1].version, SemanticVersion::new(3, 0, 0));
assert!(updates[1].signature.is_none());
}
// `sparkle:edSignature` is picked up as the signature.
#[test]
fn test_parse_appcast_xml_with_ed_signature() {
let xml = r#"<rss><channel>
<item>
<enclosure url="https://example.com/app.zip"
sparkle:version="1.0.0"
sparkle:edSignature="ed_sig_value" />
</item>
</channel></rss>"#;
let updates = parse_update_feed(xml).unwrap();
assert_eq!(updates.len(), 1);
assert_eq!(updates[0].signature.as_deref(), Some("ed_sig_value"));
}
// An empty JSON array is a valid feed with no releases.
#[test]
fn test_parse_empty_json_array() {
let updates = parse_update_feed("[]").unwrap();
assert!(updates.is_empty());
}
// Bodies that start with neither `<`, `[` nor `{` are rejected.
#[test]
fn test_parse_unrecognized_format() {
let result = parse_update_feed("this is not valid");
assert!(result.is_err());
}
// `fraction` is bytes/total when the total is known, `None` otherwise.
#[test]
fn test_download_progress_fraction() {
let progress = DownloadProgress {
bytes_downloaded: 50,
total_bytes: Some(100),
};
assert_eq!(progress.fraction(), Some(0.5));
let unknown = DownloadProgress {
bytes_downloaded: 50,
total_bytes: None,
};
assert_eq!(unknown.fraction(), None);
}
// Config survives a JSON round trip, including the seconds-encoded interval.
#[test]
fn test_config_serialization_roundtrip() {
let config = AutoUpdaterConfig {
feed_url: "https://example.com/appcast.xml".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let json = serde_json::to_string(&config).unwrap();
let deserialized: AutoUpdaterConfig = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.feed_url, config.feed_url);
assert_eq!(deserialized.check_interval, config.check_interval);
assert_eq!(deserialized.allow_prerelease, config.allow_prerelease);
}
// UpdateInfo survives a JSON round trip with every field populated.
#[test]
fn test_update_info_serialization_roundtrip() {
let info = UpdateInfo {
version: SemanticVersion::new(2, 5, 1),
release_notes: Some("Fixed a bug".to_string()),
download_url: "https://example.com/v2.5.1.zip".to_string(),
signature: Some("sig_value".to_string()),
sha256: Some("a".repeat(64)),
size_bytes: Some(4096),
};
let json = serde_json::to_string(&info).unwrap();
let deserialized: UpdateInfo = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized.version, info.version);
assert_eq!(deserialized.release_notes, info.release_notes);
assert_eq!(deserialized.download_url, info.download_url);
assert_eq!(deserialized.signature, info.signature);
assert_eq!(deserialized.sha256, info.sha256);
assert_eq!(deserialized.size_bytes, info.size_bytes);
}
// A freshly constructed updater is Idle with no recorded update.
#[test]
fn test_auto_updater_initial_state() {
let config = AutoUpdaterConfig {
feed_url: "https://example.com/feed".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let client = http_client::FakeHttpClient::with_200_response();
let updater = AutoUpdater::new(config, SemanticVersion::new(1, 0, 0), client);
assert_eq!(*updater.status(), UpdateStatus::Idle);
assert!(updater.latest_update().is_none());
}
// Installing without a configured installer is an error, not a panic.
#[test]
fn test_install_without_installer_errors() {
let config = AutoUpdaterConfig {
feed_url: "https://example.com/feed".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let client = http_client::FakeHttpClient::with_200_response();
let updater = AutoUpdater::new(config, SemanticVersion::new(1, 0, 0), client);
let result = updater.install_and_restart();
assert!(result.is_err());
}
// Builds a correctly signed UpdateInfo for `bytes` on `channel`, using a fixed signing
// key ([7u8; 32]); returns the matching verifying key alongside it.
fn signed_update_fixture(bytes: &[u8], channel: UpdateChannel) -> (VerifyingKey, UpdateInfo) {
use ed25519_dalek::SigningKey;
let signing_key = SigningKey::from_bytes(&[7u8; 32]);
let verifying_key = signing_key.verifying_key();
let sha256 = sha256_hex(bytes);
let size_bytes = bytes.len() as u64;
let version = SemanticVersion::new(1, 2, 0);
let download_url = "https://example.com/MyApp-1.2.0.zip".to_string();
let manifest = UpdateManifest {
version: version.to_string(),
channel,
url: download_url.clone(),
sha256: sha256.clone(),
size_bytes,
release_notes: None,
min_version: None,
};
let signature = kael_release::update::sign_manifest(&manifest, &signing_key);
let signature_b64 = BASE64.encode(signature.to_bytes());
(
verifying_key,
UpdateInfo {
version,
release_notes: None,
download_url,
signature: Some(signature_b64),
sha256: Some(sha256),
size_bytes: Some(size_bytes),
},
)
}
// Creates an updater (current version 1.0.0) trusting `key`.
fn updater_with_key(key: &VerifyingKey) -> AutoUpdater {
let config = AutoUpdaterConfig {
feed_url: "https://example.com/feed".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let client = http_client::FakeHttpClient::with_200_response();
let mut updater = AutoUpdater::new(config, SemanticVersion::new(1, 0, 0), client);
updater.set_public_key(key.as_bytes()).unwrap();
updater
}
// Correctly signed, untampered bytes pass verification.
#[test]
fn test_verify_package_accepts_genuine_payload() {
let bytes = b"genuine update payload".to_vec();
let (key, update) = signed_update_fixture(&bytes, UpdateChannel::Stable);
let updater = updater_with_key(&key);
assert!(updater.verify_package(&update, &bytes).is_ok());
}
// Same-length substitute bytes are caught by the hash check.
#[test]
fn test_verify_package_rejects_tampered_bytes() {
let bytes = b"genuine update payload".to_vec();
let (key, update) = signed_update_fixture(&bytes, UpdateChannel::Stable);
let updater = updater_with_key(&key);
let tampered = b"malware payload xxxxxx".to_vec();
assert_eq!(tampered.len(), bytes.len());
let err = updater.verify_package(&update, &tampered).unwrap_err();
assert!(err.to_string().contains("hash mismatch"), "{err}");
}
// With a key configured, an entry without a signature is refused.
#[test]
fn test_verify_package_rejects_unsigned_when_key_configured() {
let bytes = b"genuine update payload".to_vec();
let (key, mut update) = signed_update_fixture(&bytes, UpdateChannel::Stable);
update.signature = None;
let updater = updater_with_key(&key);
let err = updater.verify_package(&update, &bytes).unwrap_err();
assert!(err.to_string().contains("unsigned"), "{err}");
}
// A signature made with a different key fails verification.
#[test]
fn test_verify_package_rejects_wrong_key() {
let bytes = b"genuine update payload".to_vec();
let (_real_key, update) = signed_update_fixture(&bytes, UpdateChannel::Stable);
let other = ed25519_dalek::SigningKey::from_bytes(&[9u8; 32]).verifying_key();
let updater = updater_with_key(&other);
let err = updater.verify_package(&update, &bytes).unwrap_err();
assert!(
err.to_string().contains("signature verification failed"),
"{err}"
);
}
// A package signed for the Beta channel is refused by a Stable updater.
#[test]
fn test_verify_package_rejects_channel_mismatch() {
let bytes = b"genuine update payload".to_vec();
let (key, update) = signed_update_fixture(&bytes, UpdateChannel::Beta);
let mut updater = updater_with_key(&key);
updater.set_update_channel("stable");
assert!(updater.verify_package(&update, &bytes).is_err());
}
// Without a public key, verification fails closed by default.
#[test]
fn test_verify_fails_closed_without_public_key() {
let bytes = b"genuine update payload".to_vec();
let (_key, update) = signed_update_fixture(&bytes, UpdateChannel::Stable);
let config = AutoUpdaterConfig {
feed_url: "https://example.com/feed".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let client = http_client::FakeHttpClient::with_200_response();
let updater = AutoUpdater::new(config, SemanticVersion::new(1, 0, 0), client);
let err = updater.verify_package(&update, &bytes).unwrap_err();
assert!(
err.to_string().contains("no public key is configured"),
"{err}"
);
}
// Every adversarial URL maps to exactly one normal path component.
#[test]
fn test_sanitize_package_filename_stays_a_single_path_component() {
use std::path::{Component, Path};
let adversarial = [
"https://example.com/releases/kael-1.2.3.dmg",
"https://example.com/kael.dmg?token=secret#frag",
"https://example.com/../../etc/passwd",
"https://example.com/foo/..",
"https://example.com/a\\b\\evil.exe",
"https://example.com/",
"https://example.com/???",
"file:///etc/shadow",
"../../../../root/.ssh/authorized_keys",
"",
".",
"..",
"/absolute/evil",
];
for url in adversarial {
let name = sanitize_package_filename(url);
assert!(!name.is_empty(), "empty name for {url:?}");
assert!(
!name.contains('/') && !name.contains('\\'),
"separator survived for {url:?}: {name:?}"
);
assert_ne!(name, "..", "traversal token survived for {url:?}");
let components: Vec<_> = Path::new(&name).components().collect();
assert_eq!(
components.len(),
1,
"{url:?} -> {name:?} is not exactly one path component"
);
assert!(
matches!(components[0], Component::Normal(_)),
"{url:?} -> {name:?} is not a normal path component"
);
}
}
// A tampered download ends in Error, leaves nothing staged, and cannot be installed.
#[test]
fn test_download_update_rejects_tampered_before_ready() {
use http_client::{AsyncBody, FakeHttpClient, Response};
let genuine = b"genuine update payload".to_vec();
let (key, update) = signed_update_fixture(&genuine, UpdateChannel::Stable);
let served = b"malware payload xxxxxx".to_vec();
let client = FakeHttpClient::create(move |_req| {
let body = served.clone();
async move {
Ok(Response::builder()
.status(200)
.body(AsyncBody::from(body))
.unwrap())
}
});
let config = AutoUpdaterConfig {
feed_url: "https://example.com/feed".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let mut updater = AutoUpdater::new(config, SemanticVersion::new(1, 0, 0), client);
updater.set_public_key(key.as_bytes()).unwrap();
updater.latest_update = Some(update);
let result = smol::block_on(updater.download_update(|_| {}));
assert!(result.is_err());
assert!(matches!(updater.status(), UpdateStatus::Error(_)));
assert_ne!(*updater.status(), UpdateStatus::ReadyToInstall);
assert!(updater.downloaded_path.is_none());
assert!(updater.install_and_restart().is_err());
}
// A genuine download is staged under a `kael_update_` directory with identical bytes.
#[test]
fn test_download_update_accepts_genuine_and_marks_ready() {
use http_client::{AsyncBody, FakeHttpClient, Response};
let genuine = b"genuine update payload".to_vec();
let (key, update) = signed_update_fixture(&genuine, UpdateChannel::Stable);
let served = genuine.clone();
let client = FakeHttpClient::create(move |_req| {
let body = served.clone();
async move {
Ok(Response::builder()
.status(200)
.body(AsyncBody::from(body))
.unwrap())
}
});
let config = AutoUpdaterConfig {
feed_url: "https://example.com/feed".to_string(),
check_interval: Duration::from_secs(3600),
allow_prerelease: false,
};
let mut updater = AutoUpdater::new(config, SemanticVersion::new(1, 0, 0), client);
updater.set_public_key(key.as_bytes()).unwrap();
updater.latest_update = Some(update);
let path = smol::block_on(updater.download_update(|_| {})).unwrap();
assert_eq!(*updater.status(), UpdateStatus::ReadyToInstall);
assert!(path.exists());
let path_str = path.to_string_lossy();
assert!(path_str.contains("kael_update_"), "{path_str}");
assert!(!path_str.contains("gpui_update_"), "{path_str}");
let on_disk = std::fs::read(&path).unwrap();
assert_eq!(on_disk, genuine);
// Clean up the staging directory created by the download.
if let Some(dir) = path.parent() {
let _ = std::fs::remove_dir_all(dir);
}
}
// Compile-time check that the trait can be used as `dyn PlatformInstaller`.
#[test]
fn test_platform_installer_trait_is_object_safe() {
fn _assert_object_safe(_: &dyn PlatformInstaller) {}
}
// macOS installer rejects extensions other than zip/dmg.
#[cfg(target_os = "macos")]
#[test]
fn test_mac_installer_rejects_unsupported_format() {
let installer = MacInstaller;
let path = std::path::Path::new("/tmp/update.tar.gz");
let result = installer.install_and_restart(path);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("unsupported macOS package format")
);
}
// Windows installer rejects extensions other than msi/exe.
#[cfg(target_os = "windows")]
#[test]
fn test_windows_installer_rejects_unsupported_format() {
let installer = WindowsInstaller;
let path = std::path::Path::new("C:\\temp\\update.tar.gz");
let result = installer.install_and_restart(path);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("unsupported Windows package format")
);
}
// An `.AppImage` extension is detected as AppImage.
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
#[test]
fn test_linux_installer_default_format_detection() {
let installer = LinuxInstaller::new();
let appimage_path = std::path::Path::new("/tmp/MyApp.AppImage");
assert_eq!(
installer.detect_format(appimage_path),
LinuxPackageFormat::AppImage
);
}
// An explicit format hint overrides extension-based detection.
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
#[test]
fn test_linux_installer_explicit_format_hint() {
let installer = LinuxInstaller::with_format(LinuxPackageFormat::Flatpak);
let path = std::path::Path::new("/tmp/MyApp.AppImage");
assert_eq!(installer.detect_format(path), LinuxPackageFormat::Flatpak);
}
// Unknown extensions fall back to AppImage.
// NOTE(review): this relies on FLATPAK_ID / SNAP being unset in the test environment.
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
#[test]
fn test_linux_installer_unknown_extension_defaults_to_appimage() {
let installer = LinuxInstaller::new();
let path = std::path::Path::new("/tmp/update.bin");
assert_eq!(installer.detect_format(path), LinuxPackageFormat::AppImage);
}
// Appcast items whose version does not parse are skipped rather than failing the feed.
#[test]
fn test_appcast_skips_invalid_versions() {
let xml = r#"<rss><channel>
<item>
<enclosure url="https://example.com/app.zip"
sparkle:version="not-a-version" />
</item>
<item>
<enclosure url="https://example.com/app2.zip"
sparkle:version="1.0.0" />
</item>
</channel></rss>"#;
let updates = parse_update_feed(xml).unwrap();
assert_eq!(updates.len(), 1);
assert_eq!(updates[0].version, SemanticVersion::new(1, 0, 0));
}
}