1pub mod background;
2pub mod cache;
3pub mod infer_sources;
4pub mod sources;
5
6use std::path::PathBuf;
7
8use crate::constants::DEFAULT_CHECK_INTERVAL_MS;
9use crate::errors::UpdateKitError;
10use crate::types::{CheckMode, UpdateStatus};
11
12use self::cache::{create_cache_entry, is_cache_stale, read_cache, write_cache};
13use self::sources::{FetchOptions, VersionSource, VersionSourceResult};
14
15pub struct CheckUpdateOptions {
17 pub app_name: String,
18 pub current_version: String,
19 pub sources: Vec<Box<dyn VersionSource>>,
20 pub cache_dir: PathBuf,
21 pub check_interval: Option<u64>,
22}
23
24pub async fn check_update(
32 options: &CheckUpdateOptions,
33 mode: CheckMode,
34) -> Result<UpdateStatus, UpdateKitError> {
35 let interval = options
36 .check_interval
37 .unwrap_or(DEFAULT_CHECK_INTERVAL_MS);
38
39 match mode {
40 CheckMode::Blocking => check_blocking(options, interval).await,
41 CheckMode::NonBlocking => check_non_blocking(options, interval),
42 }
43}
44
45async fn check_blocking(
46 options: &CheckUpdateOptions,
47 interval: u64,
48) -> Result<UpdateStatus, UpdateKitError> {
49 for source in &options.sources {
51 let etag = read_cache(&options.cache_dir, &options.app_name)
52 .and_then(|e| e.etag.clone());
53
54 let result = source
55 .fetch_latest(FetchOptions {
56 etag: etag.clone(),
57 })
58 .await;
59
60 match result {
61 VersionSourceResult::Found { info, etag } => {
62 let entry = create_cache_entry(
63 &info.version,
64 &options.current_version,
65 source.name(),
66 etag,
67 info.release_url.clone(),
68 info.release_notes.clone(),
69 );
70 let _ = write_cache(&options.cache_dir, &options.app_name, &entry);
71
72 return Ok(compare_versions(
73 &options.current_version,
74 &info.version,
75 info.release_url,
76 info.release_notes,
77 info.assets,
78 ));
79 }
80 VersionSourceResult::NotModified { etag: _ } => {
81 if let Some(cached) = read_cache(&options.cache_dir, &options.app_name) {
83 if !is_cache_stale(&cached, interval) {
84 return Ok(compare_versions(
85 &options.current_version,
86 &cached.latest_version,
87 cached.release_url,
88 cached.release_notes,
89 None,
90 ));
91 }
92 }
93 continue;
95 }
96 VersionSourceResult::Error { .. } => {
97 continue;
99 }
100 }
101 }
102
103 Ok(UpdateStatus::Unknown {
104 reason: "All sources failed".into(),
105 cached_latest: None,
106 })
107}
108
109fn check_non_blocking(
110 options: &CheckUpdateOptions,
111 interval: u64,
112) -> Result<UpdateStatus, UpdateKitError> {
113 match read_cache(&options.cache_dir, &options.app_name) {
114 Some(cached) => {
115 if is_cache_stale(&cached, interval) {
116 let _ = try_spawn_background(&options.app_name);
118 Ok(compare_versions(
119 &options.current_version,
120 &cached.latest_version,
121 cached.release_url,
122 cached.release_notes,
123 None,
124 ))
125 } else {
126 Ok(compare_versions(
128 &options.current_version,
129 &cached.latest_version,
130 cached.release_url,
131 cached.release_notes,
132 None,
133 ))
134 }
135 }
136 None => {
137 let _ = try_spawn_background(&options.app_name);
139 Ok(UpdateStatus::Unknown {
140 reason: "No cached data available, background check spawned".into(),
141 cached_latest: None,
142 })
143 }
144 }
145}
146
147fn try_spawn_background(_app_name: &str) -> Result<(), UpdateKitError> {
150 let exe = std::env::current_exe().map_err(|e| {
151 UpdateKitError::CommandSpawnFailed(format!("Cannot determine current exe: {}", e))
152 })?;
153 background::spawn_background_check(&exe, "{}")
155}
156
157fn compare_versions(
159 current: &str,
160 latest: &str,
161 release_url: Option<String>,
162 release_notes: Option<String>,
163 assets: Option<Vec<crate::types::AssetInfo>>,
164) -> UpdateStatus {
165 let current_normalized = normalize_version(current);
166 let latest_normalized = normalize_version(latest);
167
168 match (current_normalized, latest_normalized) {
169 (Some(cur), Some(lat)) => {
170 if lat > cur {
171 UpdateStatus::Available {
172 current: cur.to_string(),
173 latest: lat.to_string(),
174 release_url,
175 release_notes,
176 assets,
177 }
178 } else {
179 UpdateStatus::UpToDate {
180 current: cur.to_string(),
181 }
182 }
183 }
184 _ => {
185 if current == latest {
187 UpdateStatus::UpToDate {
188 current: current.to_string(),
189 }
190 } else {
191 UpdateStatus::Available {
192 current: current.to_string(),
193 latest: latest.to_string(),
194 release_url,
195 release_notes,
196 assets,
197 }
198 }
199 }
200 }
201}
202
203pub fn normalize_version(version: &str) -> Option<semver::Version> {
206 let stripped = version.strip_prefix('v').unwrap_or(version);
207 semver::Version::parse(stripped).ok()
208}
209
#[cfg(test)]
mod tests {
    use super::sources::VersionInfo;
    use super::*;
    use std::path::Path;
    use std::sync::Mutex;

    /// Staleness interval shared by every test: one hour, in milliseconds.
    const ONE_HOUR_MS: u64 = 3_600_000;

    /// Version source stub: yields its canned result once, then an "Already consumed" error.
    struct MockSource {
        name_val: &'static str,
        result: Mutex<Option<VersionSourceResult>>,
    }

    impl MockSource {
        fn new(name: &'static str, result: VersionSourceResult) -> Self {
            Self {
                name_val: name,
                result: Mutex::new(Some(result)),
            }
        }
    }

    #[async_trait::async_trait]
    impl VersionSource for MockSource {
        fn name(&self) -> &str {
            self.name_val
        }

        async fn fetch_latest(&self, _options: FetchOptions) -> VersionSourceResult {
            // The guard is a temporary dropped at the end of this statement, before returning.
            let canned = self.result.lock().unwrap().take();
            canned.unwrap_or(VersionSourceResult::Error {
                reason: "Already consumed".into(),
                status: None,
            })
        }
    }

    /// Options for "test-app" running 1.0.0, caching under `cache_dir`.
    fn options_with(cache_dir: &Path, sources: Vec<Box<dyn VersionSource>>) -> CheckUpdateOptions {
        CheckUpdateOptions {
            app_name: "test-app".into(),
            current_version: "1.0.0".into(),
            sources,
            cache_dir: cache_dir.to_path_buf(),
            check_interval: Some(ONE_HOUR_MS),
        }
    }

    /// Release metadata with only the version number set.
    fn release(version: &str) -> VersionInfo {
        VersionInfo {
            version: version.into(),
            release_url: None,
            release_notes: None,
            assets: None,
            published_at: None,
        }
    }

    /// A `Found` result for `version` with no ETag and no extra metadata.
    fn found(version: &str) -> VersionSourceResult {
        VersionSourceResult::Found {
            info: release(version),
            etag: None,
        }
    }

    #[test]
    fn test_normalize_version_strips_v() {
        assert_eq!(normalize_version("v1.2.3"), Some(semver::Version::new(1, 2, 3)));
    }

    #[test]
    fn test_normalize_version_without_v() {
        assert_eq!(normalize_version("1.2.3"), Some(semver::Version::new(1, 2, 3)));
    }

    #[test]
    fn test_normalize_version_invalid() {
        assert!(normalize_version("not-a-version").is_none());
    }

    #[test]
    fn test_normalize_version_empty() {
        assert!(normalize_version("").is_none());
    }

    #[test]
    fn test_compare_versions_available() {
        let status = compare_versions("1.0.0", "2.0.0", None, None, None);
        assert!(matches!(status, UpdateStatus::Available { .. }));
    }

    #[test]
    fn test_compare_versions_up_to_date() {
        let status = compare_versions("2.0.0", "2.0.0", None, None, None);
        assert!(matches!(status, UpdateStatus::UpToDate { .. }));
    }

    #[test]
    fn test_compare_versions_newer_current() {
        let status = compare_versions("3.0.0", "2.0.0", None, None, None);
        assert!(matches!(status, UpdateStatus::UpToDate { .. }));
    }

    #[tokio::test]
    async fn test_check_update_blocking_no_sources() {
        // Isolated directory (not a fixed shared temp path) so parallel runs cannot interfere.
        let tmp = tempfile::TempDir::new().unwrap();
        let options = options_with(tmp.path(), vec![]);

        let result = check_update(&options, CheckMode::Blocking).await.unwrap();
        assert!(matches!(result, UpdateStatus::Unknown { .. }));
    }

    #[test]
    fn test_check_update_non_blocking_no_cache() {
        let tmp = tempfile::TempDir::new().unwrap();
        let options = options_with(tmp.path(), vec![]);

        let result = check_non_blocking(&options, ONE_HOUR_MS).unwrap();
        assert!(matches!(result, UpdateStatus::Unknown { .. }));
    }

    #[test]
    fn test_check_update_non_blocking_fresh_cache() {
        let tmp = tempfile::TempDir::new().unwrap();
        let entry = cache::create_cache_entry("2.0.0", "1.0.0", "github", None, None, None);
        cache::write_cache(tmp.path(), "test-app", &entry).unwrap();
        let options = options_with(tmp.path(), vec![]);

        match check_non_blocking(&options, ONE_HOUR_MS).unwrap() {
            UpdateStatus::Available { latest, .. } => assert_eq!(latest, "2.0.0"),
            _ => panic!("expected UpdateStatus::Available"),
        }
    }

    #[tokio::test]
    async fn blocking_first_source_succeeds() {
        let tmp = tempfile::TempDir::new().unwrap();
        let source = MockSource::new(
            "github",
            VersionSourceResult::Found {
                info: VersionInfo {
                    release_url: Some("https://example.com".into()),
                    release_notes: Some("notes".into()),
                    ..release("2.0.0")
                },
                etag: Some("etag1".into()),
            },
        );
        let options = options_with(tmp.path(), vec![Box::new(source)]);

        match check_update(&options, CheckMode::Blocking).await.unwrap() {
            UpdateStatus::Available { latest, .. } => assert_eq!(latest, "2.0.0"),
            _ => panic!("expected UpdateStatus::Available"),
        }
    }

    #[tokio::test]
    async fn blocking_first_fails_second_succeeds() {
        let tmp = tempfile::TempDir::new().unwrap();
        let rate_limited = MockSource::new(
            "github",
            VersionSourceResult::Error {
                reason: "Rate limited".into(),
                status: Some(429),
            },
        );
        let healthy = MockSource::new("npm", found("2.0.0"));
        let options = options_with(tmp.path(), vec![Box::new(rate_limited), Box::new(healthy)]);

        let result = check_update(&options, CheckMode::Blocking).await.unwrap();
        assert!(matches!(result, UpdateStatus::Available { .. }));
    }

    #[tokio::test]
    async fn blocking_all_fail_returns_unknown() {
        let tmp = tempfile::TempDir::new().unwrap();
        let failing = MockSource::new(
            "github",
            VersionSourceResult::Error {
                reason: "Network error".into(),
                status: None,
            },
        );
        let options = options_with(tmp.path(), vec![Box::new(failing)]);

        let result = check_update(&options, CheckMode::Blocking).await.unwrap();
        assert!(matches!(result, UpdateStatus::Unknown { .. }));
    }

    #[tokio::test]
    async fn blocking_writes_cache_on_success() {
        let tmp = tempfile::TempDir::new().unwrap();
        let source = MockSource::new("github", found("3.0.0"));
        let options = options_with(tmp.path(), vec![Box::new(source)]);

        check_update(&options, CheckMode::Blocking).await.unwrap();

        let cached = cache::read_cache(tmp.path(), "test-app")
            .expect("a successful blocking check should write the cache");
        assert_eq!(cached.latest_version, "3.0.0");
    }

    #[tokio::test]
    async fn blocking_up_to_date_when_same_version() {
        let tmp = tempfile::TempDir::new().unwrap();
        let source = MockSource::new("npm", found("1.0.0"));
        let options = options_with(tmp.path(), vec![Box::new(source)]);

        let result = check_update(&options, CheckMode::Blocking).await.unwrap();
        assert!(matches!(result, UpdateStatus::UpToDate { .. }));
    }

    #[test]
    fn compare_with_v_prefix() {
        let status = compare_versions("v1.0.0", "v2.0.0", None, None, None);
        assert!(matches!(status, UpdateStatus::Available { .. }));
    }

    #[test]
    fn compare_prerelease_less_than_release() {
        let status = compare_versions("1.0.0-beta.1", "1.0.0", None, None, None);
        assert!(matches!(status, UpdateStatus::Available { .. }));
    }

    #[test]
    fn compare_same_invalid_semver() {
        let status = compare_versions("abc", "abc", None, None, None);
        assert!(matches!(status, UpdateStatus::UpToDate { .. }));
    }

    #[test]
    fn compare_different_invalid_semver() {
        let status = compare_versions("abc", "def", None, None, None);
        assert!(matches!(status, UpdateStatus::Available { .. }));
    }

    #[test]
    fn compare_includes_release_url_and_notes() {
        let status = compare_versions(
            "1.0.0",
            "2.0.0",
            Some("https://url".into()),
            Some("notes".into()),
            None,
        );

        match status {
            UpdateStatus::Available {
                release_url,
                release_notes,
                ..
            } => {
                assert_eq!(release_url.as_deref(), Some("https://url"));
                assert_eq!(release_notes.as_deref(), Some("notes"));
            }
            _ => panic!("Expected Available"),
        }
    }
}