use std::collections::HashMap;
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::error::Error;
const GITHUB_RELEASE_BASE: &str = "https://github.com/kreuzberg-dev/tree-sitter-language-pack/releases/download";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParserManifest {
pub version: String,
pub platforms: HashMap<String, PlatformBundle>,
pub languages: HashMap<String, LanguageInfo>,
pub groups: HashMap<String, Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformBundle {
pub url: String,
pub sha256: String,
pub size: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageInfo {
pub group: String,
pub size: u64,
}
pub struct DownloadManager {
version: String,
cache_dir: PathBuf,
manifest: Option<ParserManifest>,
}
impl DownloadManager {
pub fn new(version: &str) -> Result<Self, Error> {
let cache_dir = Self::default_cache_dir(version)?;
Ok(Self {
version: version.to_string(),
cache_dir,
manifest: None,
})
}
pub fn with_cache_dir(version: &str, cache_dir: PathBuf) -> Self {
Self {
version: version.to_string(),
cache_dir,
manifest: None,
}
}
pub fn default_cache_dir(version: &str) -> Result<PathBuf, Error> {
let base = dirs::cache_dir()
.ok_or_else(|| Error::Download("Could not determine system cache directory".to_string()))?;
Ok(base
.join("tree-sitter-language-pack")
.join(format!("v{version}"))
.join("libs"))
}
pub fn cache_dir(&self) -> &Path {
&self.cache_dir
}
pub fn installed_languages(&self) -> Vec<String> {
let mut langs = Vec::new();
if let Ok(entries) = fs::read_dir(&self.cache_dir) {
for entry in entries.flatten() {
let name = entry.file_name().to_string_lossy().to_string();
if let Some(lang) = Self::lang_from_lib_filename(&name) {
langs.push(lang);
}
}
}
langs.sort();
langs
}
fn lang_from_lib_filename(filename: &str) -> Option<String> {
let name = filename.strip_prefix("lib").unwrap_or(filename);
let name = name
.strip_prefix("tree_sitter_")
.or_else(|| name.strip_prefix("tree-sitter-"))?;
let name = name
.strip_suffix(".so")
.or_else(|| name.strip_suffix(".dylib"))
.or_else(|| name.strip_suffix(".dll"))?;
Some(name.to_string())
}
pub fn ensure_languages(&mut self, names: &[&str]) -> Result<(), Error> {
let missing: Vec<&str> = names.iter().filter(|name| !self.is_cached(name)).copied().collect();
if missing.is_empty() {
return Ok(());
}
if self.manifest.is_none() {
self.manifest = Some(self.fetch_manifest()?);
}
let manifest = self.manifest.as_ref().expect("manifest loaded above");
for name in &missing {
if !manifest.languages.contains_key(*name) {
return Err(Error::Download(format!(
"Language '{}' not available for download. Available groups: {:?}",
name,
manifest.groups.keys().collect::<Vec<_>>()
)));
}
}
let platform_key = Self::platform_key();
let bundle = manifest.platforms.get(&platform_key).ok_or_else(|| {
Error::Download(format!(
"No pre-built parsers available for platform '{}'. Available: {:?}",
platform_key,
manifest.platforms.keys().collect::<Vec<_>>()
))
})?;
let archive_data = self.download_bundle(&bundle.url)?;
let actual_hash = Self::sha256_hex(&archive_data);
if actual_hash != bundle.sha256 {
return Err(Error::ChecksumMismatch {
file: bundle.url.clone(),
expected: bundle.sha256.clone(),
actual: actual_hash,
});
}
self.extract_languages(&archive_data, &missing)?;
Ok(())
}
pub fn ensure_group(&mut self, group: &str) -> Result<(), Error> {
if self.manifest.is_none() {
self.manifest = Some(self.fetch_manifest()?);
}
let manifest = self.manifest.as_ref().expect("manifest loaded above");
let langs = manifest.groups.get(group).ok_or_else(|| {
Error::Download(format!(
"Group '{}' not found. Available: {:?}",
group,
manifest.groups.keys().collect::<Vec<_>>()
))
})?;
let lang_names: Vec<String> = langs.clone();
let names: Vec<&str> = lang_names.iter().map(String::as_str).collect();
self.ensure_languages(&names)
}
fn is_cached(&self, name: &str) -> bool {
self.lib_path(name).exists()
}
pub fn lib_path(&self, name: &str) -> PathBuf {
let lib_name = format!("tree_sitter_{name}");
let (prefix, ext) = if cfg!(target_os = "macos") {
("lib", "dylib")
} else if cfg!(target_os = "windows") {
("", "dll")
} else {
("lib", "so")
};
self.cache_dir.join(format!("{prefix}{lib_name}.{ext}"))
}
pub fn fetch_manifest(&mut self) -> Result<ParserManifest, Error> {
let manifest_path = self.cache_dir.parent().map(|p| p.join("manifest.json"));
if let Some(ref path) = manifest_path
&& path.exists()
{
let data = fs::read_to_string(path)?;
let manifest: ParserManifest = serde_json::from_str(&data)?;
if manifest.version == self.version {
return Ok(manifest);
}
}
let url = format!("{}/v{}/parsers.json", GITHUB_RELEASE_BASE, self.version);
let body = ureq::get(&url)
.call()
.map_err(|e| Error::Download(format!("Failed to fetch manifest from {}: {}", url, e)))?
.into_body()
.read_to_string()
.map_err(|e| Error::Download(format!("Failed to read manifest body: {}", e)))?;
let manifest: ParserManifest = serde_json::from_str(&body)?;
if let Some(ref path) = manifest_path {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(path, &body)?;
}
Ok(manifest)
}
fn download_bundle(&self, url: &str) -> Result<Vec<u8>, Error> {
let response = ureq::get(url)
.call()
.map_err(|e| Error::Download(format!("Failed to download {}: {}", url, e)))?;
let mut data = Vec::new();
response
.into_body()
.into_reader()
.read_to_end(&mut data)
.map_err(|e| Error::Download(format!("Failed to read download body: {}", e)))?;
Ok(data)
}
fn extract_languages(&self, archive_data: &[u8], names: &[&str]) -> Result<(), Error> {
fs::create_dir_all(&self.cache_dir)?;
let decoder = zstd::Decoder::new(archive_data)
.map_err(|e| Error::Download(format!("Failed to decompress archive: {}", e)))?;
let mut archive = tar::Archive::new(decoder);
let expected_files: HashMap<String, &str> = names
.iter()
.map(|name| {
let filename = self
.lib_path(name)
.file_name()
.expect("lib_path always has a filename")
.to_string_lossy()
.to_string();
(filename, *name)
})
.collect();
for entry in archive
.entries()
.map_err(|e| Error::Download(format!("Failed to read archive entries: {}", e)))?
{
let mut entry = entry.map_err(|e| Error::Download(format!("Failed to read archive entry: {}", e)))?;
let path = entry
.path()
.map_err(|e| Error::Download(format!("Failed to read entry path: {}", e)))?;
let filename = path
.file_name()
.map(|f| f.to_string_lossy().to_string())
.unwrap_or_default();
if expected_files.contains_key(&filename) {
let dest = self.cache_dir.join(&filename);
entry
.unpack(&dest)
.map_err(|e| Error::Download(format!("Failed to extract {}: {}", filename, e)))?;
}
}
Ok(())
}
pub fn clean_cache(&self) -> Result<(), Error> {
if self.cache_dir.exists() {
fs::remove_dir_all(&self.cache_dir)?;
}
Ok(())
}
fn sha256_hex(data: &[u8]) -> String {
let mut hasher = Sha256::new();
hasher.update(data);
format!("{:x}", hasher.finalize())
}
fn platform_key() -> String {
let os = if cfg!(target_os = "macos") {
"macos"
} else if cfg!(target_os = "windows") {
"windows"
} else {
"linux"
};
let arch = if cfg!(target_arch = "aarch64") {
if cfg!(target_os = "macos") { "arm64" } else { "aarch64" }
} else if cfg!(target_arch = "x86_64") {
"x86_64"
} else {
std::env::consts::ARCH
};
format!("{os}-{arch}")
}
}