use std::io::Read;
use std::path::{Path, PathBuf};
use futures_util::StreamExt;
use tracing::{debug, info, instrument, warn};
use crate::types::CompressionFormat;
use crate::{Entry, Error, LangCode, Result, Store};
/// Stage of a dataset acquisition, reported through [`DownloadProgress`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadPhase {
    /// Bytes are being fetched over HTTP.
    Downloading,
    /// Downloaded text is being parsed and written into the local store.
    Importing,
}
/// Snapshot of download/import progress delivered to progress callbacks.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Which stage of the operation this snapshot describes.
    pub phase: DownloadPhase,
    /// Total size of the current file, when the server reported a
    /// Content-Length; `None` otherwise (and always `None` while importing).
    pub total_bytes: Option<u64>,
    /// Bytes received so far for the current file.
    pub downloaded_bytes: u64,
    /// Name of the file currently being downloaded; empty during import.
    pub current_file: String,
    /// Number of files in the download alternative being tried; 0 during import.
    pub total_files: usize,
    /// 1-based index of the current file within the alternative; 0 during import.
    pub current_file_index: usize,
}
// Base URL for raw file contents in UniMorph language repositories.
const UNIMORPH_RAW_URL: &str = "https://raw.githubusercontent.com/unimorph";
// Base URL that resolves Git LFS pointers to their actual media payload.
const UNIMORPH_LFS_URL: &str = "https://media.githubusercontent.com/media/unimorph";
// Leading bytes every Git LFS pointer file starts with.
const GIT_LFS_PREFIX: &[u8] = b"version https://git-lfs.github.com/spec/v1";

/// Reports whether `bytes` looks like a Git LFS pointer file rather than the
/// real dataset payload.
fn is_lfs_pointer(bytes: &[u8]) -> bool {
    matches!(bytes.get(..GIT_LFS_PREFIX.len()), Some(head) if head == GIT_LFS_PREFIX)
}
/// Outcome of successfully downloading one file-layout alternative.
#[derive(Debug)]
struct DownloadResult {
    /// Concatenated text of all downloaded files, newline-terminated.
    content: String,
    /// Names of the files that were fetched, in order.
    filenames: Vec<String>,
    /// Compression format detected from the filename.
    // NOTE(review): for multi-file alternatives this holds the format of the
    // last file only — the split files are presumably all the same format.
    compression: CompressionFormat,
    /// True when any file was fetched via the Git LFS media endpoint.
    from_lfs: bool,
}
/// Maps a dataset filename to its compression format based solely on the
/// file-extension suffix (`.xz`, `.gz`, `.zip`); anything else is plain text.
fn detect_compression(filename: &str) -> CompressionFormat {
    if filename.ends_with(".xz") {
        return CompressionFormat::Xz;
    }
    if filename.ends_with(".gz") {
        return CompressionFormat::Gzip;
    }
    if filename.ends_with(".zip") {
        return CompressionFormat::Zip;
    }
    CompressionFormat::None
}
fn decompress_content(filename: &str, bytes: &[u8]) -> Result<String> {
if filename.ends_with(".xz") {
debug!(filename, "decompressing XZ/LZMA content");
let mut decoder = xz2::read::XzDecoder::new(bytes);
let mut content = String::new();
decoder
.read_to_string(&mut content)
.map_err(|e| Error::DecompressionFailed(format!("XZ decompression failed: {}", e)))?;
Ok(content)
} else if filename.ends_with(".gz") {
debug!(filename, "decompressing gzip content");
let mut decoder = flate2::read::GzDecoder::new(bytes);
let mut content = String::new();
decoder
.read_to_string(&mut content)
.map_err(|e| Error::DecompressionFailed(format!("gzip decompression failed: {}", e)))?;
Ok(content)
} else if filename.ends_with(".zip") {
debug!(filename, "extracting ZIP content");
let cursor = std::io::Cursor::new(bytes);
let mut archive = zip::ZipArchive::new(cursor)
.map_err(|e| Error::DecompressionFailed(format!("ZIP archive error: {}", e)))?;
if archive.is_empty() {
return Err(Error::DecompressionFailed(
"ZIP archive is empty".to_string(),
));
}
let mut file = archive
.by_index(0)
.map_err(|e| Error::DecompressionFailed(format!("ZIP extraction error: {}", e)))?;
let mut content = String::new();
file.read_to_string(&mut content)
.map_err(|e| Error::DecompressionFailed(format!("ZIP read error: {}", e)))?;
Ok(content)
} else {
String::from_utf8(bytes.to_vec())
.map_err(|e| Error::DecompressionFailed(format!("UTF-8 conversion failed: {}", e)))
}
}
/// Manages the local cache of UniMorph datasets: downloads language data from
/// GitHub and imports it into an on-disk store under the cache directory.
pub struct Repository {
    /// Directory holding cached data, including the `datasets.db` store.
    cache_dir: PathBuf,
    /// Backing store of imported entries.
    store: Store,
}
impl Repository {
#[instrument(level = "debug")]
pub fn new() -> Result<Self> {
let cache_dir = dirs::cache_dir()
.ok_or_else(|| Error::CacheDir {
path: PathBuf::from("~/.cache"),
reason: "could not determine cache directory".to_string(),
})?
.join("unimorph");
debug!(cache_dir = %cache_dir.display(), "using default cache directory");
Self::with_cache_dir(cache_dir)
}
pub fn with_cache_dir<P: AsRef<Path>>(cache_dir: P) -> Result<Self> {
let cache_dir = cache_dir.as_ref().to_path_buf();
std::fs::create_dir_all(&cache_dir).map_err(|e| Error::CacheDir {
path: cache_dir.clone(),
reason: e.to_string(),
})?;
let db_path = cache_dir.join("datasets.db");
let store = Store::open(&db_path)?;
Ok(Self { cache_dir, store })
}
pub fn cache_dir(&self) -> &Path {
&self.cache_dir
}
pub fn store(&self) -> &Store {
&self.store
}
pub fn store_mut(&mut self) -> &mut Store {
&mut self.store
}
#[instrument(level = "info", skip(self))]
pub async fn ensure(&mut self, lang: &str) -> Result<bool> {
let lang_code = LangCode::new(lang)?;
if self.store.has_language(lang)? {
debug!(lang, "language already cached");
return Ok(false);
}
info!(lang, "downloading language dataset");
self.download_and_import(&lang_code).await?;
Ok(true)
}
#[instrument(level = "info", skip(self))]
pub async fn refresh(&mut self, lang: &str) -> Result<()> {
let lang_code = LangCode::new(lang)?;
info!(lang, "refreshing language dataset");
self.download_and_import(&lang_code).await
}
#[instrument(level = "info", skip(self, on_progress))]
pub async fn refresh_with_progress<F>(&mut self, lang: &str, on_progress: F) -> Result<()>
where
F: Fn(DownloadProgress) + Send + Sync,
{
let lang_code = LangCode::new(lang)?;
info!(lang, "refreshing language dataset with progress");
self.download_and_import_with_progress(&lang_code, on_progress)
.await
}
#[instrument(level = "info", skip(self, on_progress))]
pub async fn ensure_with_progress<F>(&mut self, lang: &str, on_progress: F) -> Result<bool>
where
F: Fn(DownloadProgress) + Send + Sync,
{
let lang_code = LangCode::new(lang)?;
if self.store.has_language(lang)? {
debug!(lang, "language already cached");
return Ok(false);
}
info!(lang, "downloading language dataset with progress");
self.download_and_import_with_progress(&lang_code, on_progress)
.await?;
Ok(true)
}
#[instrument(level = "debug", skip(self))]
async fn download_and_import(&mut self, lang: &LangCode) -> Result<()> {
let commit_sha = fetch_commit_sha(lang).await.ok();
debug!(lang = %lang, commit_sha = ?commit_sha, "fetched commit SHA");
let download = download_language(lang).await?;
let (entries, mut report) = Entry::parse_tsv_with_report(&download.content);
report.compression = download.compression;
report.from_lfs = download.from_lfs;
report.filename = Some(download.filenames.join(", "));
info!(
lang = %lang,
filename = ?download.filenames,
compression = %download.compression,
from_lfs = download.from_lfs,
valid_entries = report.valid_entries,
blank_lines = report.blank_lines,
malformed = report.malformed_count,
"parsed downloaded data"
);
if report.malformed_count > 0 {
warn!(
lang = %lang,
malformed = report.malformed_count,
"skipped malformed entries during import"
);
for entry in &report.malformed {
warn!(
lang = %lang,
line = entry.line_num,
reason = %entry.reason,
"malformed entry"
);
}
if report.malformed_count > report.malformed.len() {
warn!(
lang = %lang,
additional = report.malformed_count - report.malformed.len(),
"additional malformed entries not shown"
);
}
}
let source_url = format!("https://github.com/unimorph/{}", lang.as_str());
self.store
.import(lang, &entries, Some(&source_url), commit_sha.as_deref())?;
info!(
lang = %lang,
entries = entries.len(),
commit_sha = ?commit_sha,
"imported language dataset"
);
Ok(())
}
#[instrument(level = "debug", skip(self, on_progress))]
async fn download_and_import_with_progress<F>(
&mut self,
lang: &LangCode,
on_progress: F,
) -> Result<()>
where
F: Fn(DownloadProgress) + Send + Sync,
{
let commit_sha = fetch_commit_sha(lang).await.ok();
debug!(lang = %lang, commit_sha = ?commit_sha, "fetched commit SHA");
let download = download_language_with_progress(lang, &on_progress).await?;
on_progress(DownloadProgress {
phase: DownloadPhase::Importing,
total_bytes: None,
downloaded_bytes: 0,
current_file: String::new(),
total_files: 0,
current_file_index: 0,
});
let (entries, mut report) = Entry::parse_tsv_with_report(&download.content);
report.compression = download.compression;
report.from_lfs = download.from_lfs;
report.filename = Some(download.filenames.join(", "));
info!(
lang = %lang,
filename = ?download.filenames,
compression = %download.compression,
from_lfs = download.from_lfs,
valid_entries = report.valid_entries,
blank_lines = report.blank_lines,
malformed = report.malformed_count,
"parsed downloaded data"
);
if report.malformed_count > 0 {
warn!(
lang = %lang,
malformed = report.malformed_count,
"skipped malformed entries during import"
);
for entry in &report.malformed {
warn!(
lang = %lang,
line = entry.line_num,
reason = %entry.reason,
"malformed entry"
);
}
if report.malformed_count > report.malformed.len() {
warn!(
lang = %lang,
additional = report.malformed_count - report.malformed.len(),
"additional malformed entries not shown"
);
}
}
let source_url = format!("https://github.com/unimorph/{}", lang.as_str());
self.store
.import(lang, &entries, Some(&source_url), commit_sha.as_deref())?;
info!(
lang = %lang,
entries = entries.len(),
commit_sha = ?commit_sha,
"imported language dataset"
);
Ok(())
}
pub fn cached_languages(&self) -> Result<Vec<LangCode>> {
self.store.languages()
}
pub fn delete(&mut self, lang: &str) -> Result<()> {
self.store.delete_language(lang)
}
}
/// Returns the candidate file layouts for `lang`, in preference order.
///
/// Each inner vector is one alternative: the set of files that together make
/// up the dataset. Most languages ship a single file (compressed variants
/// are preferred); Finnish is published as two uncompressed parts.
fn get_file_alternatives(lang: &LangCode) -> Vec<Vec<String>> {
    let code = lang.as_str();
    if code == "fin" {
        return vec![vec!["fin.1".to_string(), "fin.2".to_string()]];
    }
    [
        format!("{}.xz", code),
        format!("{}.gz", code),
        code.to_string(),
    ]
    .into_iter()
    .map(|filename| vec![filename])
    .collect()
}
#[instrument(level = "debug")]
async fn download_language(lang: &LangCode) -> Result<DownloadResult> {
let client = reqwest::Client::new();
let alternatives = get_file_alternatives(lang);
debug!(lang = %lang, alternatives = ?alternatives, "downloading from GitHub");
for files in &alternatives {
match try_download_files(&client, lang, files).await {
Ok(result) => return Ok(result),
Err(e) => {
debug!(files = ?files, error = %e, "alternative failed, trying next");
continue;
}
}
}
Err(Error::DownloadFailed(format!(
"No data files found for language: {}",
lang.as_str()
)))
}
async fn try_download_files(
client: &reqwest::Client,
lang: &LangCode,
files: &[String],
) -> Result<DownloadResult> {
let mut all_content = String::new();
let mut from_lfs = false;
let mut compression = CompressionFormat::None;
for filename in files {
let url = format!("{}/{}/master/{}", UNIMORPH_RAW_URL, lang.as_str(), filename);
debug!(url = %url, "fetching file");
let response = client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::FORBIDDEN {
warn!(lang = %lang, "GitHub rate limit exceeded");
return Err(Error::RateLimited);
}
if response.status() == reqwest::StatusCode::NOT_FOUND {
debug!(url = %url, "file not found");
return Err(Error::DownloadFailed(format!("File not found: {}", url)));
}
if !response.status().is_success() {
return Err(Error::DownloadFailed(format!(
"HTTP {}: {}",
response.status(),
url
)));
}
let mut bytes = response.bytes().await?;
debug!(url = %url, bytes = bytes.len(), "downloaded file");
if is_lfs_pointer(&bytes) {
debug!(url = %url, "detected Git LFS pointer, fetching from media endpoint");
let lfs_url = format!("{}/{}/master/{}", UNIMORPH_LFS_URL, lang.as_str(), filename);
let lfs_response = client.get(&lfs_url).send().await?;
if !lfs_response.status().is_success() {
return Err(Error::DownloadFailed(format!(
"LFS fetch failed HTTP {}: {}",
lfs_response.status(),
lfs_url
)));
}
bytes = lfs_response.bytes().await?;
debug!(url = %lfs_url, bytes = bytes.len(), "downloaded LFS file");
from_lfs = true;
}
compression = detect_compression(filename);
let content = decompress_content(filename, &bytes)?;
all_content.push_str(&content);
if !content.ends_with('\n') {
all_content.push('\n');
}
}
Ok(DownloadResult {
content: all_content,
filenames: files.to_vec(),
compression,
from_lfs,
})
}
#[instrument(level = "debug", skip(on_progress))]
async fn download_language_with_progress<F>(
lang: &LangCode,
on_progress: &F,
) -> Result<DownloadResult>
where
F: Fn(DownloadProgress) + Send + Sync,
{
let client = reqwest::Client::new();
let alternatives = get_file_alternatives(lang);
debug!(lang = %lang, alternatives = ?alternatives, "downloading from GitHub with progress");
for files in &alternatives {
match try_download_files_with_progress(&client, lang, files, on_progress).await {
Ok(result) => return Ok(result),
Err(e) => {
debug!(files = ?files, error = %e, "alternative failed, trying next");
continue;
}
}
}
Err(Error::DownloadFailed(format!(
"No data files found for language: {}",
lang.as_str()
)))
}
/// Downloads and concatenates `files` for `lang`, reporting byte-level
/// progress through `on_progress` and transparently resolving Git LFS
/// pointers via the media endpoint.
///
/// # Errors
///
/// Returns [`Error::RateLimited`] on HTTP 403 and [`Error::DownloadFailed`]
/// for missing files, other HTTP failures, or decompression problems.
async fn try_download_files_with_progress<F>(
    client: &reqwest::Client,
    lang: &LangCode,
    files: &[String],
    on_progress: &F,
) -> Result<DownloadResult>
where
    F: Fn(DownloadProgress) + Send + Sync,
{
    let total_files = files.len();
    let mut all_content = String::new();
    let mut from_lfs = false;
    let mut compression = CompressionFormat::None;
    for (file_index, filename) in files.iter().enumerate() {
        let url = format!("{}/{}/master/{}", UNIMORPH_RAW_URL, lang.as_str(), filename);
        debug!(url = %url, "fetching file");
        let response = client.get(&url).send().await?;
        if response.status() == reqwest::StatusCode::FORBIDDEN {
            warn!(lang = %lang, "GitHub rate limit exceeded");
            return Err(Error::RateLimited);
        }
        if response.status() == reqwest::StatusCode::NOT_FOUND {
            debug!(url = %url, "file not found");
            return Err(Error::DownloadFailed(format!("File not found: {}", url)));
        }
        if !response.status().is_success() {
            return Err(Error::DownloadFailed(format!(
                "HTTP {}: {}",
                response.status(),
                url
            )));
        }
        let mut bytes = stream_body_with_progress(
            response,
            filename.clone(),
            total_files,
            file_index + 1,
            on_progress,
        )
        .await?;
        debug!(url = %url, bytes = bytes.len(), "downloaded file");
        if is_lfs_pointer(&bytes) {
            debug!(url = %url, "detected Git LFS pointer, fetching from media endpoint");
            let lfs_url = format!("{}/{}/master/{}", UNIMORPH_LFS_URL, lang.as_str(), filename);
            let lfs_response = client.get(&lfs_url).send().await?;
            if !lfs_response.status().is_success() {
                return Err(Error::DownloadFailed(format!(
                    "LFS fetch failed HTTP {}: {}",
                    lfs_response.status(),
                    lfs_url
                )));
            }
            // Re-stream the real payload; the pointer bytes are discarded.
            bytes = stream_body_with_progress(
                lfs_response,
                format!("{} (LFS)", filename),
                total_files,
                file_index + 1,
                on_progress,
            )
            .await?;
            debug!(url = %lfs_url, bytes = bytes.len(), "downloaded LFS file");
            from_lfs = true;
        }
        compression = detect_compression(filename);
        let content = decompress_content(filename, &bytes)?;
        all_content.push_str(&content);
        // Keep file boundaries intact when concatenating split datasets.
        if !content.ends_with('\n') {
            all_content.push('\n');
        }
    }
    Ok(DownloadResult {
        content: all_content,
        filenames: files.to_vec(),
        compression,
        from_lfs,
    })
}

/// Streams a response body into memory, emitting a progress event before the
/// first chunk (so callers can render "0 / total") and after every chunk.
///
/// Extracted because the raw-fetch and LFS-fetch paths previously duplicated
/// this emit-then-stream loop verbatim.
async fn stream_body_with_progress<F>(
    response: reqwest::Response,
    current_file: String,
    total_files: usize,
    current_file_index: usize,
    on_progress: &F,
) -> Result<Vec<u8>>
where
    F: Fn(DownloadProgress) + Send + Sync,
{
    let total_bytes = response.content_length();
    let emit = |downloaded_bytes: u64| {
        on_progress(DownloadProgress {
            phase: DownloadPhase::Downloading,
            total_bytes,
            downloaded_bytes,
            current_file: current_file.clone(),
            total_files,
            current_file_index,
        })
    };
    emit(0);
    let mut bytes = Vec::new();
    let mut downloaded: u64 = 0;
    let mut stream = response.bytes_stream();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        downloaded += chunk.len() as u64;
        bytes.extend_from_slice(&chunk);
        emit(downloaded);
    }
    Ok(bytes)
}
#[instrument(level = "debug")]
async fn fetch_commit_sha(lang: &LangCode) -> Result<String> {
let client = reqwest::Client::new();
let url = format!(
"https://api.github.com/repos/unimorph/{}/commits/master",
lang.as_str()
);
debug!(url = %url, "fetching commit SHA");
let response = client
.get(&url)
.header("User-Agent", "unimorph-rs")
.header("Accept", "application/vnd.github.v3+json")
.send()
.await?;
if response.status() == reqwest::StatusCode::FORBIDDEN {
return Err(Error::RateLimited);
}
if !response.status().is_success() {
return Err(Error::DownloadFailed(format!(
"Failed to fetch commit info: HTTP {}",
response.status()
)));
}
let json: serde_json::Value = response.json().await?;
let sha = json["sha"]
.as_str()
.ok_or_else(|| Error::DownloadFailed("No SHA in commit response".to_string()))?
.to_string();
debug!(sha = %sha, "fetched commit SHA");
Ok(sha)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Constructing a repository must create the cache dir and the store file.
    #[test]
    fn repository_with_custom_dir() {
        let temp_dir = TempDir::new().unwrap();
        let repo = Repository::with_cache_dir(temp_dir.path()).unwrap();
        assert!(repo.cache_dir().exists());
        assert!(repo.cache_dir().join("datasets.db").exists());
    }

    // A fresh store reports no cached languages.
    #[test]
    fn cached_languages_empty() {
        let temp_dir = TempDir::new().unwrap();
        let repo = Repository::with_cache_dir(temp_dir.path()).unwrap();
        let langs = repo.cached_languages().unwrap();
        assert!(langs.is_empty());
    }

    // Ordinary languages get three single-file alternatives (xz, gz, plain);
    // Finnish gets the special two-part split layout.
    #[test]
    fn file_alternatives() {
        let ita: LangCode = "ita".parse().unwrap();
        let fin: LangCode = "fin".parse().unwrap();
        let ita_alts = get_file_alternatives(&ita);
        assert_eq!(ita_alts.len(), 3);
        assert_eq!(ita_alts[0], vec!["ita.xz"]);
        assert_eq!(ita_alts[1], vec!["ita.gz"]);
        assert_eq!(ita_alts[2], vec!["ita"]);
        let fin_alts = get_file_alternatives(&fin);
        assert_eq!(fin_alts.len(), 1);
        assert_eq!(fin_alts[0], vec!["fin.1", "fin.2"]);
    }

    // A filename without a compression suffix is passed through as UTF-8.
    #[test]
    fn decompress_plain_text() {
        let content = b"lemma\tform\tfeatures\n";
        let result = decompress_content("test.txt", content).unwrap();
        assert_eq!(result, "lemma\tform\tfeatures\n");
    }

    // Only payloads starting with the LFS spec line count as pointers;
    // real data (including XZ magic bytes) must not.
    #[test]
    fn detect_lfs_pointer() {
        let lfs_content =
            b"version https://git-lfs.github.com/spec/v1\noid sha256:abc123\nsize 12345\n";
        assert!(is_lfs_pointer(lfs_content));
        let normal_content = b"lemma\tform\tfeatures\n";
        assert!(!is_lfs_pointer(normal_content));
        let xz_magic = b"\xfd7zXZ\x00";
        assert!(!is_lfs_pointer(xz_magic));
    }

    // Round-trip through gzip: compress with flate2, decompress via the
    // extension-dispatched path.
    #[test]
    fn decompress_gzip() {
        use flate2::Compression;
        use flate2::write::GzEncoder;
        use std::io::Write;
        let original = "lemma\tform\tV;IND;PRS\n";
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(original.as_bytes()).unwrap();
        let compressed = encoder.finish().unwrap();
        let result = decompress_content("test.gz", &compressed).unwrap();
        assert_eq!(result, original);
    }

    // Round-trip through XZ at compression level 6.
    #[test]
    fn decompress_xz() {
        use std::io::Write;
        use xz2::write::XzEncoder;
        let original = "lemma\tform\tV;IND;PRS\n";
        let mut encoder = XzEncoder::new(Vec::new(), 6);
        encoder.write_all(original.as_bytes()).unwrap();
        let compressed = encoder.finish().unwrap();
        let result = decompress_content("test.xz", &compressed).unwrap();
        assert_eq!(result, original);
    }

    // End-to-end: download a plain (uncompressed) dataset, and verify that a
    // second ensure() is a cache hit.
    #[tokio::test]
    #[ignore = "requires network access"]
    async fn download_italian_uncompressed() {
        let temp_dir = TempDir::new().unwrap();
        let mut repo = Repository::with_cache_dir(temp_dir.path()).unwrap();
        let downloaded = repo.ensure("ita").await.unwrap();
        assert!(downloaded);
        let stats = repo.store().stats("ita").unwrap().unwrap();
        assert!(stats.total_entries > 0);
        let downloaded_again = repo.ensure("ita").await.unwrap();
        assert!(!downloaded_again);
    }

    // End-to-end: an XZ-compressed dataset; the entry-count threshold guards
    // against silently importing a truncated download.
    #[tokio::test]
    #[ignore = "requires network access"]
    async fn download_polish_compressed_xz() {
        let temp_dir = TempDir::new().unwrap();
        let mut repo = Repository::with_cache_dir(temp_dir.path()).unwrap();
        let downloaded = repo.ensure("pol").await.unwrap();
        assert!(downloaded);
        let stats = repo.store().stats("pol").unwrap().unwrap();
        assert!(stats.total_entries > 0);
        assert!(stats.total_entries > 100_000);
    }

    // End-to-end: the multi-file (fin.1 + fin.2) layout.
    #[tokio::test]
    #[ignore = "requires network access"]
    async fn download_finnish_split_files() {
        let temp_dir = TempDir::new().unwrap();
        let mut repo = Repository::with_cache_dir(temp_dir.path()).unwrap();
        let downloaded = repo.ensure("fin").await.unwrap();
        assert!(downloaded);
        let stats = repo.store().stats("fin").unwrap().unwrap();
        assert!(stats.total_entries > 0);
    }

    // End-to-end: a dataset served via Git LFS pointer resolution.
    #[tokio::test]
    #[ignore = "requires network access"]
    async fn download_czech_lfs() {
        let temp_dir = TempDir::new().unwrap();
        let mut repo = Repository::with_cache_dir(temp_dir.path()).unwrap();
        let downloaded = repo.ensure("ces").await.unwrap();
        assert!(downloaded);
        let stats = repo.store().stats("ces").unwrap().unwrap();
        assert!(stats.total_entries > 0);
        assert!(stats.total_entries > 1_000_000);
    }
}