use std::path::{Path, PathBuf};
use futures_util::StreamExt;
use tracing::{debug, info, instrument, warn};
use crate::{Entry, Error, LangCode, Result, Store};
/// Stage of a dataset fetch, reported through [`DownloadProgress`].
//
// Fieldless public enum: eagerly derive the cheap common traits
// (`Copy`, `Hash`) per Rust API guidelines so callers can freely copy
// and key on the phase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DownloadPhase {
    /// Raw data files are being fetched over HTTP.
    Downloading,
    /// Downloaded content is being parsed and written to the store.
    Importing,
}
/// Snapshot of download/import progress passed to progress callbacks.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Which stage the operation is currently in.
    pub phase: DownloadPhase,
    /// Total size of the current file in bytes when the server reports a
    /// Content-Length; `None` otherwise (and always `None` while importing).
    pub total_bytes: Option<u64>,
    /// Bytes received so far for the current file.
    pub downloaded_bytes: u64,
    /// Name of the file currently being downloaded; empty during the
    /// import phase.
    pub current_file: String,
    /// Number of files to download for this language; 0 during import.
    pub total_files: usize,
    /// 1-based index of the file being downloaded; 0 during import.
    pub current_file_index: usize,
}
/// Base URL for raw file access to repositories in the `unimorph` GitHub org.
const UNIMORPH_RAW_URL: &str = "https://raw.githubusercontent.com/unimorph";
/// Local cache of UniMorph language datasets: downloads data files from
/// GitHub and imports them into an on-disk [`Store`].
pub struct Repository {
    // Directory that holds the `datasets.db` store file.
    cache_dir: PathBuf,
    // Backing store for imported entries.
    store: Store,
}
impl Repository {
    /// Creates a repository under the platform cache directory
    /// (`<cache>/unimorph`).
    ///
    /// # Errors
    ///
    /// Returns [`Error::CacheDir`] when the platform cache directory cannot
    /// be determined or created; propagates store-open failures.
    #[instrument(level = "debug")]
    pub fn new() -> Result<Self> {
        let cache_dir = dirs::cache_dir()
            .ok_or_else(|| Error::CacheDir {
                path: PathBuf::from("~/.cache"),
                reason: "could not determine cache directory".to_string(),
            })?
            .join("unimorph");
        debug!(cache_dir = %cache_dir.display(), "using default cache directory");
        Self::with_cache_dir(cache_dir)
    }

    /// Creates a repository rooted at `cache_dir`, creating the directory if
    /// needed and opening the `datasets.db` store inside it.
    pub fn with_cache_dir<P: AsRef<Path>>(cache_dir: P) -> Result<Self> {
        let cache_dir = cache_dir.as_ref().to_path_buf();
        std::fs::create_dir_all(&cache_dir).map_err(|e| Error::CacheDir {
            path: cache_dir.clone(),
            reason: e.to_string(),
        })?;
        let db_path = cache_dir.join("datasets.db");
        let store = Store::open(&db_path)?;
        Ok(Self { cache_dir, store })
    }

    /// Directory this repository caches data under.
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }

    /// Shared access to the backing store.
    pub fn store(&self) -> &Store {
        &self.store
    }

    /// Mutable access to the backing store.
    pub fn store_mut(&mut self) -> &mut Store {
        &mut self.store
    }

    /// Downloads and imports `lang` unless it is already cached.
    ///
    /// Returns `Ok(true)` when a download happened, `Ok(false)` when the
    /// language was already present.
    #[instrument(level = "info", skip(self))]
    pub async fn ensure(&mut self, lang: &str) -> Result<bool> {
        let lang_code = LangCode::new(lang)?;
        // NOTE(review): the cache check uses the raw `lang` string while the
        // download uses the validated LangCode — confirm LangCode::new does
        // not normalize the code, otherwise the two can disagree.
        if self.store.has_language(lang)? {
            debug!(lang, "language already cached");
            return Ok(false);
        }
        info!(lang, "downloading language dataset");
        self.download_and_import(&lang_code).await?;
        Ok(true)
    }

    /// Re-downloads and re-imports `lang` unconditionally.
    #[instrument(level = "info", skip(self))]
    pub async fn refresh(&mut self, lang: &str) -> Result<()> {
        let lang_code = LangCode::new(lang)?;
        info!(lang, "refreshing language dataset");
        self.download_and_import(&lang_code).await
    }

    /// Like [`Self::refresh`], reporting progress through `on_progress`.
    #[instrument(level = "info", skip(self, on_progress))]
    pub async fn refresh_with_progress<F>(&mut self, lang: &str, on_progress: F) -> Result<()>
    where
        F: Fn(DownloadProgress) + Send + Sync,
    {
        let lang_code = LangCode::new(lang)?;
        info!(lang, "refreshing language dataset with progress");
        self.download_and_import_with_progress(&lang_code, on_progress)
            .await
    }

    /// Like [`Self::ensure`], reporting progress through `on_progress`.
    #[instrument(level = "info", skip(self, on_progress))]
    pub async fn ensure_with_progress<F>(&mut self, lang: &str, on_progress: F) -> Result<bool>
    where
        F: Fn(DownloadProgress) + Send + Sync,
    {
        let lang_code = LangCode::new(lang)?;
        // NOTE(review): same raw-string vs LangCode cache-check mismatch as
        // in `ensure` — confirm LangCode::new does not normalize.
        if self.store.has_language(lang)? {
            debug!(lang, "language already cached");
            return Ok(false);
        }
        info!(lang, "downloading language dataset with progress");
        self.download_and_import_with_progress(&lang_code, on_progress)
            .await?;
        Ok(true)
    }

    /// Fetches the dataset for `lang` and imports it into the store.
    #[instrument(level = "debug", skip(self))]
    async fn download_and_import(&mut self, lang: &LangCode) -> Result<()> {
        // Best effort: a failed SHA lookup must not block the data download.
        let commit_sha = fetch_commit_sha(lang).await.ok();
        debug!(lang = %lang, commit_sha = ?commit_sha, "fetched commit SHA");
        let content = download_language(lang).await?;
        self.import_content(lang, &content, commit_sha)
    }

    /// Progress-reporting variant of [`Self::download_and_import`].
    #[instrument(level = "debug", skip(self, on_progress))]
    async fn download_and_import_with_progress<F>(
        &mut self,
        lang: &LangCode,
        on_progress: F,
    ) -> Result<()>
    where
        F: Fn(DownloadProgress) + Send + Sync,
    {
        // Best effort: a failed SHA lookup must not block the data download.
        let commit_sha = fetch_commit_sha(lang).await.ok();
        debug!(lang = %lang, commit_sha = ?commit_sha, "fetched commit SHA");
        let content = download_language_with_progress(lang, &on_progress).await?;
        // Signal the import phase; file-level fields carry no data here.
        on_progress(DownloadProgress {
            phase: DownloadPhase::Importing,
            total_bytes: None,
            downloaded_bytes: 0,
            current_file: String::new(),
            total_files: 0,
            current_file_index: 0,
        });
        self.import_content(lang, &content, commit_sha)
    }

    /// Parses downloaded TSV `content` and writes the entries to the store.
    ///
    /// Shared tail of `download_and_import` and
    /// `download_and_import_with_progress`; extracted to remove the
    /// previously duplicated parse/import/logging code.
    fn import_content(
        &mut self,
        lang: &LangCode,
        content: &str,
        commit_sha: Option<String>,
    ) -> Result<()> {
        let (entries, skipped) = Entry::parse_tsv_lenient(content);
        if skipped > 0 {
            warn!(
                lang = %lang,
                skipped,
                "skipped malformed entries during import"
            );
        }
        debug!(
            lang = %lang,
            entries = entries.len(),
            "parsed entries from downloaded data"
        );
        let source_url = format!("https://github.com/unimorph/{}", lang.as_str());
        self.store
            .import(lang, &entries, Some(&source_url), commit_sha.as_deref())?;
        info!(
            lang = %lang,
            entries = entries.len(),
            commit_sha = ?commit_sha,
            "imported language dataset"
        );
        Ok(())
    }

    /// Language codes currently present in the store.
    pub fn cached_languages(&self) -> Result<Vec<LangCode>> {
        self.store.languages()
    }

    /// Removes a cached language from the store.
    pub fn delete(&mut self, lang: &str) -> Result<()> {
        self.store.delete_language(lang)
    }
}
/// Upstream file names to fetch for a language.
///
/// Most languages ship a single data file named after the language code;
/// Finnish is split into two files (`fin.1`, `fin.2`) upstream.
fn get_file_patterns(lang: &LangCode) -> Vec<String> {
    let code = lang.as_str();
    if code == "fin" {
        return vec!["fin.1".to_string(), "fin.2".to_string()];
    }
    vec![code.to_string()]
}
#[instrument(level = "debug")]
async fn download_language(lang: &LangCode) -> Result<String> {
let client = reqwest::Client::new();
let patterns = get_file_patterns(lang);
let mut all_content = String::new();
let mut found_any = false;
debug!(lang = %lang, patterns = ?patterns, "downloading from GitHub");
for pattern in &patterns {
let url = format!("{}/{}/master/{}", UNIMORPH_RAW_URL, lang.as_str(), pattern);
debug!(url = %url, "fetching file");
let response = client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::FORBIDDEN {
warn!(lang = %lang, "GitHub rate limit exceeded");
return Err(Error::RateLimited);
}
if response.status() == reqwest::StatusCode::NOT_FOUND {
debug!(url = %url, "file not found, trying next pattern");
continue;
}
if !response.status().is_success() {
return Err(Error::DownloadFailed(format!(
"HTTP {}: {}",
response.status(),
url
)));
}
let content = response.text().await?;
let bytes = content.len();
debug!(url = %url, bytes, "downloaded file");
all_content.push_str(&content);
if !content.ends_with('\n') {
all_content.push('\n');
}
found_any = true;
}
if !found_any {
return Err(Error::DownloadFailed(format!(
"No data files found for language: {}",
lang.as_str()
)));
}
Ok(all_content)
}
#[instrument(level = "debug", skip(on_progress))]
async fn download_language_with_progress<F>(lang: &LangCode, on_progress: &F) -> Result<String>
where
F: Fn(DownloadProgress) + Send + Sync,
{
let client = reqwest::Client::new();
let patterns = get_file_patterns(lang);
let total_files = patterns.len();
let mut all_content = String::new();
let mut found_any = false;
debug!(lang = %lang, patterns = ?patterns, "downloading from GitHub with progress");
for (file_index, pattern) in patterns.iter().enumerate() {
let url = format!("{}/{}/master/{}", UNIMORPH_RAW_URL, lang.as_str(), pattern);
debug!(url = %url, "fetching file");
let response = client.get(&url).send().await?;
if response.status() == reqwest::StatusCode::FORBIDDEN {
warn!(lang = %lang, "GitHub rate limit exceeded");
return Err(Error::RateLimited);
}
if response.status() == reqwest::StatusCode::NOT_FOUND {
debug!(url = %url, "file not found, trying next pattern");
continue;
}
if !response.status().is_success() {
return Err(Error::DownloadFailed(format!(
"HTTP {}: {}",
response.status(),
url
)));
}
let total_bytes = response.content_length();
let mut downloaded_bytes: u64 = 0;
let mut content = Vec::new();
on_progress(DownloadProgress {
phase: DownloadPhase::Downloading,
total_bytes,
downloaded_bytes,
current_file: pattern.clone(),
total_files,
current_file_index: file_index + 1,
});
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
downloaded_bytes += chunk.len() as u64;
content.extend_from_slice(&chunk);
on_progress(DownloadProgress {
phase: DownloadPhase::Downloading,
total_bytes,
downloaded_bytes,
current_file: pattern.clone(),
total_files,
current_file_index: file_index + 1,
});
}
let text = String::from_utf8_lossy(&content);
debug!(url = %url, bytes = content.len(), "downloaded file");
all_content.push_str(&text);
if !text.ends_with('\n') {
all_content.push('\n');
}
found_any = true;
}
if !found_any {
return Err(Error::DownloadFailed(format!(
"No data files found for language: {}",
lang.as_str()
)));
}
Ok(all_content)
}
/// Resolves the current `master` commit SHA of `unimorph/<lang>` via the
/// GitHub REST API.
///
/// # Errors
///
/// * [`Error::RateLimited`] on HTTP 403 or 429 (GitHub rate limiting).
/// * [`Error::DownloadFailed`] on any other non-success status, or when the
///   response JSON has no string `sha` field.
#[instrument(level = "debug")]
async fn fetch_commit_sha(lang: &LangCode) -> Result<String> {
    let client = reqwest::Client::new();
    let url = format!(
        "https://api.github.com/repos/unimorph/{}/commits/master",
        lang.as_str()
    );
    debug!(url = %url, "fetching commit SHA");
    let response = client
        .get(&url)
        // The GitHub API rejects requests without a User-Agent header.
        .header("User-Agent", "unimorph-rs")
        .header("Accept", "application/vnd.github.v3+json")
        .send()
        .await?;
    let status = response.status();
    // GitHub signals rate limiting with 403 (primary) or 429 (secondary);
    // previously only 403 was recognized.
    if status == reqwest::StatusCode::FORBIDDEN
        || status == reqwest::StatusCode::TOO_MANY_REQUESTS
    {
        return Err(Error::RateLimited);
    }
    if !status.is_success() {
        return Err(Error::DownloadFailed(format!(
            "Failed to fetch commit info: HTTP {}",
            status
        )));
    }
    let json: serde_json::Value = response.json().await?;
    let sha = json["sha"]
        .as_str()
        .ok_or_else(|| Error::DownloadFailed("No SHA in commit response".to_string()))?
        .to_string();
    debug!(sha = %sha, "fetched commit SHA");
    Ok(sha)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Opening a repository in a fresh directory must create both the
    // directory itself and the datasets.db store file.
    #[test]
    fn repository_with_custom_dir() {
        let dir = TempDir::new().unwrap();
        let repository = Repository::with_cache_dir(dir.path()).unwrap();
        let cache = repository.cache_dir();
        assert!(cache.exists());
        assert!(cache.join("datasets.db").exists());
    }

    // A brand-new store reports no cached languages.
    #[test]
    fn cached_languages_empty() {
        let dir = TempDir::new().unwrap();
        let repository = Repository::with_cache_dir(dir.path()).unwrap();
        assert!(repository.cached_languages().unwrap().is_empty());
    }

    // Finnish is split across two upstream files; every other language maps
    // to a single file named after its code.
    #[test]
    fn file_patterns() {
        let italian = "ita".parse::<LangCode>().unwrap();
        let finnish = "fin".parse::<LangCode>().unwrap();
        assert_eq!(get_file_patterns(&italian), ["ita"]);
        assert_eq!(get_file_patterns(&finnish), ["fin.1", "fin.2"]);
    }
}