use serde_json::Value;
use std::collections::HashMap;
use std::fs;
use std::path::Path;
/// One entry of the `dependencies` object of a `package.json`, as produced by
/// [`parse_dependencies`].
///
/// `PartialEq`/`Eq` are derived so values can be compared directly, e.g. with
/// `assert_eq!` in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dependency {
    /// Package name, copied from the key in the `dependencies` object.
    pub name: String,
    /// Version text after normalisation: leading `^` / `~` removed, or, for a
    /// git-style spec containing `#`, the text after the last `#`.
    pub version: String,
    /// `true` when the raw spec contains `github.com` or starts with `git`.
    pub is_git: bool,
}
pub fn parse_dependencies(
package_json_path: &Path,
) -> Result<HashMap<String, Dependency>, Box<dyn std::error::Error>> {
let content = fs::read_to_string(package_json_path)?;
let json: Value = serde_json::from_str(&content)?;
let deps = json
.get("dependencies")
.and_then(|d| d.as_object())
.ok_or("no dependencies section found in package.json")?;
let mut dependencies = HashMap::new();
for (name, value) in deps {
if let Some(version_str) = value.as_str() {
let is_git = version_str.contains("github.com") || version_str.starts_with("git");
let version = extract_version(version_str);
validate_package_name(name)?;
validate_version(&version)?;
dependencies.insert(
name.clone(),
Dependency {
name: name.clone(),
version,
is_git,
},
);
}
}
Ok(dependencies)
}
/// Checks that a dependency name is acceptable.
///
/// A name is accepted when it
/// * is between 1 and 200 bytes long,
/// * does not contain `..`,
/// * consists only of ASCII letters, digits and `.`, `-`, `_`, `@`, `/`,
/// * does not start with `/`.
///
/// `/` is on the allow-list (presumably for scoped names such as
/// `@scope/pkg`), but a *leading* `/` turns the name into an absolute path.
/// With `std::path::Path::join` an absolute component replaces the base
/// directory entirely, so such a name would bypass the protection the `..`
/// check is there to provide.
/// NOTE(review): assumes callers may join the name onto a directory; confirm.
///
/// # Errors
///
/// Returns a boxed error whose message names the offending package and the
/// rule it violated.
fn validate_package_name(name: &str) -> Result<(), Box<dyn std::error::Error>> {
    if name.is_empty() || name.len() > 200 {
        return Err(format!("package name {name:?} has invalid length").into());
    }
    if name.contains("..") {
        return Err(format!("package name {name:?} contains '..'").into());
    }
    if !name
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b'_' | b'@' | b'/'))
    {
        return Err(format!("package name {name:?} contains disallowed characters").into());
    }
    // Checked last so every input that was already rejected keeps its
    // original error message; only previously accepted absolute names change.
    if name.starts_with('/') {
        return Err(format!("package name {name:?} must not start with '/'").into());
    }
    Ok(())
}
/// Checks that a normalised version string is acceptable.
///
/// The rules are evaluated in this order and the first violation is reported:
/// 1. length must be between 1 and 100 bytes,
/// 2. the text must not contain `..`,
/// 3. every byte must be an ASCII letter, digit, or one of `.`, `-`, `+`, `_`.
///
/// # Errors
///
/// Returns a boxed error whose message quotes the version and states the
/// violated rule.
fn validate_version(version: &str) -> Result<(), Box<dyn std::error::Error>> {
    let is_allowed = |byte: u8| {
        matches!(
            byte,
            b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'-' | b'+' | b'_'
        )
    };

    let violation = if !(1..=100).contains(&version.len()) {
        Some("has invalid length")
    } else if version.contains("..") {
        Some("contains '..'")
    } else if version.bytes().any(|byte| !is_allowed(byte)) {
        Some("contains disallowed characters")
    } else {
        None
    };

    match violation {
        Some(reason) => Err(format!("version {version:?} {reason}").into()),
        None => Ok(()),
    }
}
/// Reduces a raw dependency spec to the version text that gets stored.
///
/// * For a git-style spec (starts with `git` or contains `github.com`) that
///   has a `#`, the result is everything after the last `#`.
/// * Otherwise all leading `^` characters and then all leading `~`
///   characters are removed; the rest is returned unchanged.
fn extract_version(value: &str) -> String {
    let git_like = value.starts_with("git") || value.contains("github.com");

    let fragment = if git_like {
        value.rsplit_once('#').map(|(_, after_hash)| after_hash)
    } else {
        None
    };

    match fragment {
        Some(reference) => reference.to_owned(),
        None => {
            let without_caret = value.trim_start_matches('^');
            without_caret.trim_start_matches('~').to_owned()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// An exact version, a caret range and a `github:` spec with a commit
    /// fragment are all parsed from a manifest written to a temp directory.
    #[test]
    fn parses_pinned_caret_and_git_specs() {
        // The raw string is kept flush-left so the manifest bytes are unchanged.
        let manifest_text = r#"{ "dependencies": {
"lit": "3.3.3",
"bootstrap": "^5.3.8",
"forked": "github:owner/repo#abc123"
} }"#;

        let scratch = tempdir().unwrap();
        let manifest_path = scratch.path().join("package.json");
        fs::write(&manifest_path, manifest_text).unwrap();

        let parsed = parse_dependencies(&manifest_path).unwrap();

        let lit = &parsed["lit"];
        assert_eq!(lit.version, "3.3.3");
        assert!(!lit.is_git);

        let bootstrap = &parsed["bootstrap"];
        assert_eq!(bootstrap.version, "5.3.8");

        let forked = &parsed["forked"];
        assert_eq!(forked.version, "abc123");
        assert!(forked.is_git);
    }
}