Skip to main content

ncu/
npm.rs

1use anyhow::{Context, Result};
2use check_updates_core::{PackageInfo, Version};
3use serde::Deserialize;
4use std::collections::HashMap;
5use std::str::FromStr;
6use std::sync::Arc;
7use tokio::sync::Semaphore;
8
/// Base URL of the public npm registry; package documents are requested from
/// `{NPM_REGISTRY}/<package>`.
const NPM_REGISTRY: &str = "https://registry.npmjs.org";
10
11#[derive(Debug, Deserialize)]
12struct NpmPackageResponse {
13    name: String,
14    #[serde(rename = "dist-tags")]
15    dist_tags: HashMap<String, String>,
16    versions: HashMap<String, serde_json::Value>,
17}
18
19#[derive(Clone)]
20pub struct NpmClient {
21    client: reqwest::Client,
22    include_prerelease: bool,
23}
24
25impl NpmClient {
26    pub fn new(include_prerelease: bool) -> Self {
27        Self {
28            client: reqwest::Client::new(),
29            include_prerelease,
30        }
31    }
32
33    /// Get package info from npm registry
34    pub async fn get_package(&self, name: &str) -> Result<PackageInfo> {
35        let url = format!("{NPM_REGISTRY}/{name}");
36
37        let response = self
38            .client
39            .get(&url)
40            .header("Accept", "application/json")
41            .send()
42            .await
43            .with_context(|| format!("Failed to fetch package: {name}"))?;
44
45        if response.status() == reqwest::StatusCode::NOT_FOUND {
46            anyhow::bail!("Package '{name}' not found on npm");
47        }
48
49        let data: NpmPackageResponse = response
50            .json()
51            .await
52            .with_context(|| format!("Failed to parse npm response for: {name}"))?;
53
54        let mut versions: Vec<Version> = data
55            .versions
56            .keys()
57            .filter_map(|v| Version::from_str(v).ok())
58            .filter(|v| self.include_prerelease || !v.is_prerelease())
59            .collect();
60
61        versions.sort();
62
63        let latest = data
64            .dist_tags
65            .get("latest")
66            .and_then(|v| Version::from_str(v).ok())
67            .unwrap_or_else(|| versions.last().cloned().unwrap_or_else(|| Version::new(0, 0, 0)));
68
69        let latest_stable = versions.iter().rfind(|v| !v.is_prerelease()).cloned();
70
71        Ok(PackageInfo {
72            name: data.name,
73            versions,
74            latest,
75            latest_stable,
76        })
77    }
78
79    /// Get multiple packages concurrently with progress callback and rate limiting
80    pub async fn get_packages(
81        &self,
82        names: &[String],
83        progress_callback: impl Fn(usize, usize) + Send + Sync + 'static,
84    ) -> Vec<(String, Result<PackageInfo>)> {
85        let total = names.len();
86        let progress_callback = Arc::new(progress_callback);
87        let semaphore = Arc::new(Semaphore::new(10));
88
89        let mut tasks = Vec::new();
90
91        for name in names {
92            let client = self.clone();
93            let name = name.clone();
94            let semaphore = Arc::clone(&semaphore);
95
96            let task = tokio::spawn(async move {
97                let _permit = semaphore.acquire().await.expect("semaphore closed");
98                let result = client.get_package(&name).await;
99                (name, result)
100            });
101
102            tasks.push(task);
103        }
104
105        let mut results = Vec::new();
106        for (i, task) in tasks.into_iter().enumerate() {
107            match task.await {
108                Ok(result) => results.push(result),
109                Err(e) => results.push(("unknown".to_string(), Err(anyhow::anyhow!("Task failed: {e}")))),
110            }
111            progress_callback(i + 1, total);
112        }
113
114        results
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    // These tests query the live registry at `NPM_REGISTRY` and need network access.

    #[tokio::test]
    async fn test_get_package_express() {
        let client = NpmClient::new(false);
        // `expect` (rather than `assert!(result.is_ok())`) prints the error chain on failure.
        let info = client
            .get_package("express")
            .await
            .expect("fetching `express` should succeed");
        assert_eq!(info.name, "express");
        assert!(!info.versions.is_empty());
        // Prereleases were excluded, so every listed version is stable and the newest
        // stable version is simply the last (highest) entry.
        assert!(info.versions.iter().all(|v| !v.is_prerelease()));
        assert!(info.latest_stable.as_ref() == info.versions.last());
    }

    #[tokio::test]
    async fn test_get_package_not_found() {
        let client = NpmClient::new(false);
        let result = client
            .get_package("this-package-definitely-does-not-exist-12345")
            .await;
        let Err(err) = result else {
            panic!("lookup of a nonexistent package should fail");
        };
        // Make sure we got the 404 path, not some unrelated network/parse failure.
        assert!(
            err.to_string().contains("not found"),
            "unexpected error: {err:#}"
        );
    }
}
140}