use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use log::warn;
use tempfile::TempDir;
use crate::core::archive;
use crate::error::SubXError;
/// Describes a set of user-supplied input paths and how they are expanded
/// into a list of files (see `collect_files`).
#[derive(Debug, Clone)]
pub struct InputPathHandler {
/// Input paths as given by the caller; each entry is handled as either a
/// file or a directory when files are collected.
pub paths: Vec<PathBuf>,
/// When `true`, directories are scanned recursively; otherwise only their
/// direct entries are considered.
pub recursive: bool,
/// Extension filter, compared ASCII case-insensitively against each file's
/// extension. An empty list accepts every file.
pub file_extensions: Vec<String>,
/// When `true`, archive detection/extraction is skipped and archive files
/// are treated like any other file.
pub no_extract: bool,
}
impl InputPathHandler {
/// Combines paths coming from three kinds of sources into a single list.
///
/// The result preserves this order: every `Some` entry of `optional_paths`,
/// then all of `multiple_paths`, then each entry of `string_paths` converted
/// to a `PathBuf`.
///
/// # Errors
///
/// Returns `SubXError::NoInputSpecified` when the combined list is empty.
pub fn merge_paths_from_multiple_sources(
    optional_paths: &[Option<PathBuf>],
    multiple_paths: &[PathBuf],
    string_paths: &[String],
) -> Result<Vec<PathBuf>, SubXError> {
    let merged: Vec<PathBuf> = optional_paths
        .iter()
        .flatten()
        .cloned()
        .chain(multiple_paths.iter().cloned())
        .chain(string_paths.iter().map(PathBuf::from))
        .collect();

    if merged.is_empty() {
        Err(SubXError::NoInputSpecified)
    } else {
        Ok(merged)
    }
}
/// Builds a handler from already-parsed path arguments.
///
/// The handler starts with an empty extension filter and with archive
/// extraction enabled (`no_extract` is `false`).
///
/// # Errors
///
/// Propagates the error from [`Self::validate`], i.e. fails when any of
/// `input_args` does not exist.
pub fn from_args(input_args: &[PathBuf], recursive: bool) -> Result<Self, SubXError> {
    let paths = input_args.to_vec();
    let candidate = Self {
        paths,
        recursive,
        file_extensions: Vec::new(),
        no_extract: false,
    };
    // Only hand the handler out once validation has succeeded.
    candidate.validate().map(|()| candidate)
}
/// Replaces the extension filter with `extensions`, each stored lowercased.
///
/// Consumes and returns the handler so calls can be chained.
pub fn with_extensions(mut self, extensions: &[&str]) -> Self {
    let lowered: Vec<String> = extensions.iter().map(|ext| ext.to_lowercase()).collect();
    self.file_extensions = lowered;
    self
}
pub fn with_no_extract(mut self, no_extract: bool) -> Self {
self.no_extract = no_extract;
self
}
/// Checks that every configured path exists.
///
/// # Errors
///
/// Returns `SubXError::PathNotFound` carrying the first path (in input
/// order) for which `Path::exists` is `false`.
pub fn validate(&self) -> Result<(), SubXError> {
    match self.paths.iter().find(|candidate| !candidate.exists()) {
        Some(missing) => Err(SubXError::PathNotFound(missing.clone())),
        None => Ok(()),
    }
}
/// Returns the unique directories implied by the input paths.
///
/// A directory path contributes itself; a file path contributes its parent
/// directory. Paths that are neither a file nor a directory are ignored.
///
/// For a bare relative file name such as `movie.srt`, `Path::parent` yields
/// the empty path, which is not usable as a directory; it is reported as `.`
/// instead. Directories are returned in first-seen order, without duplicates.
pub fn get_directories(&self) -> Vec<PathBuf> {
    let mut seen = std::collections::HashSet::new();
    let mut directories = Vec::new();
    for path in &self.paths {
        let directory = if path.is_dir() {
            path.clone()
        } else if path.is_file() {
            match path.parent() {
                // `Path::new("movie.srt").parent()` is `Some("")`: the file
                // lives in the current directory.
                Some(parent) if parent.as_os_str().is_empty() => PathBuf::from("."),
                Some(parent) => parent.to_path_buf(),
                None => continue,
            }
        } else {
            continue;
        };
        // `insert` returns `false` for a directory that was already recorded.
        if seen.insert(directory.clone()) {
            directories.push(directory);
        }
    }
    directories
}
pub fn collect_files(&self) -> Result<CollectedFiles, SubXError> {
let mut files = Vec::new();
let mut temp_dirs = Vec::new();
let mut archive_origins: HashMap<PathBuf, PathBuf> = HashMap::new();
for base in &self.paths {
if base.is_file() {
if !self.no_extract {
if let Some(_format) = archive::detect_format(base) {
match self.extract_and_collect(base) {
Ok((extracted, temp_dir)) => {
let temp_root = temp_dir.path().to_path_buf();
archive_origins.insert(temp_root, base.clone());
files.extend(extracted);
temp_dirs.push(temp_dir);
continue;
}
Err(e) => {
warn!(
"Failed to extract archive {}, skipping: {e}",
base.display()
);
continue;
}
}
}
}
if self.matches_extension(base) {
files.push(base.clone());
}
} else if base.is_dir() {
if self.recursive {
files.extend(self.scan_directory_recursive(base)?);
} else {
files.extend(self.scan_directory_flat(base)?);
}
} else {
return Err(SubXError::InvalidPath(base.clone()));
}
}
if temp_dirs.is_empty() {
Ok(CollectedFiles::new(files))
} else {
Ok(CollectedFiles::with_archives(
files,
temp_dirs,
archive_origins,
))
}
}
/// Extracts `archive_path` into a fresh temporary directory and returns the
/// extracted paths that pass the extension filter, together with the
/// temporary directory handle.
///
/// # Errors
///
/// Returns `SubXError::CommandExecution` when the temporary directory cannot
/// be created or when extraction fails.
fn extract_and_collect(
    &self,
    archive_path: &Path,
) -> Result<(Vec<PathBuf>, TempDir), SubXError> {
    let temp_dir = match TempDir::new() {
        Ok(dir) => dir,
        Err(e) => {
            return Err(SubXError::CommandExecution(format!(
                "Failed to create temp directory: {e}"
            )));
        }
    };

    let entries = match archive::extract_archive(archive_path, temp_dir.path()) {
        Ok(entries) => entries,
        Err(e) => {
            return Err(SubXError::CommandExecution(format!(
                "Failed to extract {}: {e}",
                archive_path.display()
            )));
        }
    };

    let mut kept: Vec<PathBuf> = Vec::new();
    for entry in entries {
        if self.matches_extension(&entry) {
            kept.push(entry);
        }
    }
    Ok((kept, temp_dir))
}
/// Reports whether `path` passes the extension filter.
///
/// An empty filter accepts everything. Otherwise the path's extension must be
/// valid UTF-8 and equal (ASCII case-insensitively) to one of the configured
/// extensions; a path without an extension is rejected.
fn matches_extension(&self, path: &Path) -> bool {
    if self.file_extensions.is_empty() {
        return true;
    }
    match path.extension().and_then(|raw| raw.to_str()) {
        Some(actual) => self
            .file_extensions
            .iter()
            .any(|wanted| wanted.eq_ignore_ascii_case(actual)),
        None => false,
    }
}
fn scan_directory_flat(&self, dir: &Path) -> Result<Vec<PathBuf>, SubXError> {
let mut result = Vec::new();
let rd = fs::read_dir(dir).map_err(|e| SubXError::DirectoryReadError {
path: dir.to_path_buf(),
source: e,
})?;
for entry in rd {
let entry = entry.map_err(|e| SubXError::DirectoryReadError {
path: dir.to_path_buf(),
source: e,
})?;
let ft = entry
.file_type()
.map_err(|e| SubXError::DirectoryReadError {
path: dir.to_path_buf(),
source: e,
})?;
if ft.is_symlink() {
log::debug!("Skipping symlink: {}", entry.path().display());
continue;
}
let p = entry.path();
if ft.is_file() && self.matches_extension(&p) {
result.push(p);
}
}
Ok(result)
}
fn scan_directory_recursive(&self, dir: &Path) -> Result<Vec<PathBuf>, SubXError> {
let mut result = Vec::new();
let rd = fs::read_dir(dir).map_err(|e| SubXError::DirectoryReadError {
path: dir.to_path_buf(),
source: e,
})?;
for entry in rd {
let entry = entry.map_err(|e| SubXError::DirectoryReadError {
path: dir.to_path_buf(),
source: e,
})?;
let ft = entry
.file_type()
.map_err(|e| SubXError::DirectoryReadError {
path: dir.to_path_buf(),
source: e,
})?;
if ft.is_symlink() {
log::debug!("Skipping symlink: {}", entry.path().display());
continue;
}
let p = entry.path();
if ft.is_file() {
if self.matches_extension(&p) {
result.push(p.clone());
}
} else if ft.is_dir() {
result.extend(self.scan_directory_recursive(&p)?);
}
}
Ok(result)
}
}
/// A list of collected file paths together with the bookkeeping for files
/// that were extracted from archives.
///
/// Dereferences to the underlying `Vec<PathBuf>` of paths.
#[derive(Debug)]
pub struct CollectedFiles {
/// The collected file paths.
paths: Vec<PathBuf>,
/// Temporary directory handles owned by this value. Never read in this
/// file; presumably held only so the extracted files outlive the scan —
/// TODO confirm against `tempfile::TempDir` drop semantics.
_temp_dirs: Vec<TempDir>,
/// Maps a temporary extraction root to the archive it was extracted from
/// (as populated by `InputPathHandler::collect_files`).
archive_origins: HashMap<PathBuf, PathBuf>,
}
impl CollectedFiles {
    /// Wraps plain paths with no temporary directories and no archive origins.
    pub fn new(paths: Vec<PathBuf>) -> Self {
        Self::with_archives(paths, Vec::new(), HashMap::new())
    }

    /// Wraps paths together with the temporary directories that back them and
    /// the mapping from each temporary root to its source archive.
    pub fn with_archives(
        paths: Vec<PathBuf>,
        temp_dirs: Vec<TempDir>,
        archive_origins: HashMap<PathBuf, PathBuf>,
    ) -> Self {
        Self {
            paths,
            _temp_dirs: temp_dirs,
            archive_origins,
        }
    }

    /// Returns the archive that `path` was extracted from, if any.
    ///
    /// A path belongs to an archive when it starts with (component-wise) one
    /// of the recorded temporary roots.
    pub fn archive_origin(&self, path: &Path) -> Option<&Path> {
        self.archive_origins
            .iter()
            .find_map(|(temp_root, archive_path)| {
                path.starts_with(temp_root)
                    .then(|| archive_path.as_path())
            })
    }

    /// Consumes the collection and returns only the paths.
    ///
    /// NOTE(review): this drops the held `TempDir` handles. `tempfile`'s
    /// `TempDir` is documented to delete its directory on drop, so paths that
    /// point into extracted archives presumably stop existing after this
    /// call — confirm before using the result for archive-derived files.
    pub fn into_paths(self) -> Vec<PathBuf> {
        self.paths
    }
}
/// Lets a `CollectedFiles` be used wherever a `&Vec<PathBuf>` of its paths is
/// expected (indexing, `len`, iteration, ...).
impl std::ops::Deref for CollectedFiles {
    type Target = Vec<PathBuf>;

    fn deref(&self) -> &Vec<PathBuf> {
        let Self { paths, .. } = self;
        paths
    }
}
/// Exposes the collected paths as a slice.
impl AsRef<[PathBuf]> for CollectedFiles {
    fn as_ref(&self) -> &[PathBuf] {
        self.paths.as_slice()
    }
}
#[cfg(test)]
mod symlink_tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a temporary directory containing a regular file `real.txt` and
    /// a symlink `link.txt` pointing at it.
    ///
    /// Returns the directory handle (which must be kept alive for the duration
    /// of the test) along with the paths of the file and the symlink.
    #[cfg(unix)]
    fn dir_with_file_and_symlink() -> (TempDir, PathBuf, PathBuf) {
        let dir = TempDir::new().unwrap();
        let target = dir.path().join("real.txt");
        fs::write(&target, b"x").unwrap();
        let alias = dir.path().join("link.txt");
        std::os::unix::fs::symlink(&target, &alias).unwrap();
        (dir, target, alias)
    }

    /// The recursive scan returns the regular file but not the symlink to it.
    #[cfg(unix)]
    #[test]
    fn test_scan_directory_recursive_skips_symlinks() {
        let (dir, target, alias) = dir_with_file_and_symlink();
        let handler = InputPathHandler::from_args(&[dir.path().to_path_buf()], true).unwrap();
        let found = handler.scan_directory_recursive(dir.path()).unwrap();
        assert!(found.contains(&target));
        assert!(
            !found.contains(&alias),
            "symlinked file should have been skipped"
        );
    }

    /// The flat scan returns the regular file but not the symlink to it.
    #[cfg(unix)]
    #[test]
    fn test_scan_directory_flat_skips_symlinks() {
        let (dir, target, alias) = dir_with_file_and_symlink();
        let handler = InputPathHandler::from_args(&[dir.path().to_path_buf()], false).unwrap();
        let found = handler.scan_directory_flat(dir.path()).unwrap();
        assert!(found.contains(&target));
        assert!(!found.contains(&alias));
    }
}