1use anyhow::{Context, Result};
2use check_updates_core::{PackageInfo, Version};
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::str::FromStr;
6use std::sync::Arc;
7use tokio::sync::Semaphore;
8
/// Base URL of the public npm registry; package metadata is requested from
/// `{NPM_REGISTRY}/{package-name}`.
const NPM_REGISTRY: &str = "https://registry.npmjs.org";
10
11#[derive(Debug, Deserialize)]
12struct NpmPackageResponse {
13 name: String,
14 #[serde(rename = "dist-tags")]
15 dist_tags: HashMap<String, String>,
16 versions: HashMap<String, serde_json::Value>,
17}
18
19#[derive(Clone)]
20pub struct NpmClient {
21 client: reqwest::Client,
22 include_prerelease: bool,
23}
24
25impl NpmClient {
26 pub fn new(include_prerelease: bool) -> Self {
27 Self {
28 client: reqwest::Client::new(),
29 include_prerelease,
30 }
31 }
32
33 pub async fn get_package(&self, name: &str) -> Result<PackageInfo> {
35 let url = format!("{NPM_REGISTRY}/{name}");
36
37 let response = self
38 .client
39 .get(&url)
40 .header("Accept", "application/json")
41 .send()
42 .await
43 .with_context(|| format!("Failed to fetch package: {name}"))?;
44
45 if response.status() == reqwest::StatusCode::NOT_FOUND {
46 anyhow::bail!("Package '{name}' not found on npm");
47 }
48
49 let data: NpmPackageResponse = response
50 .json()
51 .await
52 .with_context(|| format!("Failed to parse npm response for: {name}"))?;
53
54 let mut versions: Vec<Version> = data
55 .versions
56 .keys()
57 .filter_map(|v| Version::from_str(v).ok())
58 .filter(|v| self.include_prerelease || !v.is_prerelease())
59 .collect();
60
61 versions.sort();
62
63 let latest = data
64 .dist_tags
65 .get("latest")
66 .and_then(|v| Version::from_str(v).ok())
67 .unwrap_or_else(|| versions.last().cloned().unwrap_or_else(|| Version::new(0, 0, 0)));
68
69 let latest_stable = versions.iter().rfind(|v| !v.is_prerelease()).cloned();
70
71 Ok(PackageInfo {
72 name: data.name,
73 versions,
74 latest,
75 latest_stable,
76 })
77 }
78
79 pub async fn get_packages(
81 &self,
82 names: &[String],
83 progress_callback: impl Fn(usize, usize) + Send + Sync + 'static,
84 ) -> Vec<(String, Result<PackageInfo>)> {
85 let total = names.len();
86 let progress_callback = Arc::new(progress_callback);
87 let semaphore = Arc::new(Semaphore::new(10));
88
89 let mut tasks = Vec::new();
90
91 for name in names {
92 let client = self.clone();
93 let name = name.clone();
94 let semaphore = Arc::clone(&semaphore);
95
96 let task = tokio::spawn(async move {
97 let _permit = semaphore.acquire().await.expect("semaphore closed");
98 let result = client.get_package(&name).await;
99 (name, result)
100 });
101
102 tasks.push(task);
103 }
104
105 let mut results = Vec::new();
106 for (i, task) in tasks.into_iter().enumerate() {
107 match task.await {
108 Ok(result) => results.push(result),
109 Err(e) => results.push(("unknown".to_string(), Err(anyhow::anyhow!("Task failed: {e}")))),
110 }
111 progress_callback(i + 1, total);
112 }
113
114 results
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: these tests query the live npm registry and need network access.

    /// A well-known package resolves, with prereleases filtered out by default.
    #[tokio::test]
    async fn test_get_package_express() {
        let client = NpmClient::new(false);
        let info = client
            .get_package("express")
            .await
            .expect("fetching `express` from npm should succeed");
        assert_eq!(info.name, "express");
        assert!(!info.versions.is_empty());
        // include_prerelease = false must drop every prerelease version.
        assert!(info.versions.iter().all(|v| !v.is_prerelease()));
        assert!(info.latest_stable.is_some());
    }

    /// A nonexistent package yields the dedicated not-found error, not just any
    /// error (e.g. a network failure would otherwise make this pass vacuously).
    #[tokio::test]
    async fn test_get_package_not_found() {
        let client = NpmClient::new(false);
        let result = client
            .get_package("this-package-definitely-does-not-exist-12345")
            .await;
        match result {
            Ok(_) => panic!("expected an error for a nonexistent package"),
            Err(err) => assert!(
                err.to_string().contains("not found"),
                "unexpected error: {err:#}"
            ),
        }
    }
}