use std::path::PathBuf;
use async_trait::async_trait;
use cognis_core::document_loaders::BaseLoader;
use cognis_core::document_loaders::DocumentStream;
use cognis_core::documents::Document;
use cognis_core::error::{CognisError, Result};
use futures::stream;
use super::csv::CsvLoader;
use super::html::HTMLLoader;
use super::json::JsonLoader;
use super::text::TextLoader;
/// Loads documents from every file in a directory.
///
/// Files are discovered with a glob pattern rooted at `path` and each file is
/// handed to a format-specific loader chosen by its extension.
#[derive(Debug, Clone)]
pub struct DirectoryLoader {
    /// Directory that is searched for files.
    path: PathBuf,
    /// Optional glob pattern applied to entries under `path`; `*` is used when unset.
    glob_pattern: Option<String>,
    /// Whether subdirectories are searched as well.
    recursive: bool,
}
impl DirectoryLoader {
    /// Creates a loader for the directory at `path`.
    ///
    /// No glob filter is set and subdirectories are not searched.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            glob_pattern: None,
            recursive: false,
        }
    }

    /// Sets the glob pattern used to select files under the directory.
    pub fn with_glob(mut self, pattern: impl Into<String>) -> Self {
        self.glob_pattern = Some(pattern.into());
        self
    }

    /// Enables or disables searching subdirectories.
    pub fn recursive(mut self, recursive: bool) -> Self {
        self.recursive = recursive;
        self
    }

    /// Returns the directory path with glob metacharacters escaped.
    ///
    /// Only normal path components are escaped; prefix, root, `.` and `..`
    /// components are copied verbatim so the path keeps its anchoring.
    fn escaped_base(&self) -> PathBuf {
        let mut base = PathBuf::new();
        for component in self.path.components() {
            match component {
                std::path::Component::Normal(part) => {
                    base.push(glob::Pattern::escape(&part.to_string_lossy()));
                }
                other => base.push(other.as_os_str()),
            }
        }
        base
    }

    /// Collects the regular files selected by the configured glob, sorted by path.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Other` when the assembled glob pattern is invalid
    /// or when an entry cannot be read while walking the matches.
    fn collect_files(&self) -> Result<Vec<PathBuf>> {
        let pattern = self.glob_pattern.as_deref().unwrap_or("*");
        // The directory is a literal location, not a pattern: without escaping,
        // a directory such as `data[1]` would be parsed as a character class
        // and match the wrong entries (or nothing at all).
        let base = self.escaped_base();
        let glob_str = if self.recursive {
            format!("{}/**/{}", base.display(), pattern)
        } else {
            format!("{}/{}", base.display(), pattern)
        };

        let entries = glob::glob(&glob_str)
            .map_err(|e| CognisError::Other(format!("Invalid glob pattern: {}", e)))?;

        let mut files: Vec<PathBuf> = Vec::new();
        for entry in entries {
            let path = entry.map_err(|e| CognisError::Other(format!("Glob entry error: {}", e)))?;
            // Globs also match directories; only regular files can be loaded.
            if path.is_file() {
                files.push(path);
            }
        }
        // Sort so the resulting document order is deterministic.
        files.sort();
        Ok(files)
    }

    /// Loads a single file with the loader matching its extension.
    ///
    /// The extension is compared case-insensitively: `json`/`jsonl` use
    /// `JsonLoader`, `csv` uses `CsvLoader`, `html`/`htm` use `HTMLLoader`, and
    /// anything else (including no extension) falls back to `TextLoader`.
    async fn load_file(&self, path: &PathBuf) -> Result<Vec<Document>> {
        let ext = path
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("")
            .to_lowercase();
        match ext.as_str() {
            "json" | "jsonl" => JsonLoader::new(path).load().await,
            "csv" => CsvLoader::new(path).load().await,
            "html" | "htm" => HTMLLoader::new(path).load().await,
            _ => TextLoader::new(path).load().await,
        }
    }
}
#[async_trait]
impl BaseLoader for DirectoryLoader {
    /// Loads every matching file and returns the results as a stream.
    ///
    /// All files are read before the stream is returned. A file that fails to
    /// load contributes a single `Err` item at its position instead of
    /// aborting the whole run; a failure while collecting the file list is
    /// returned directly.
    async fn lazy_load(&self) -> Result<DocumentStream> {
        let mut results: Vec<Result<Document>> = Vec::new();
        for file in self.collect_files()? {
            match self.load_file(&file).await {
                Ok(docs) => results.extend(docs.into_iter().map(Ok)),
                Err(err) => results.push(Err(err)),
            }
        }
        Ok(Box::pin(stream::iter(results)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Writes one text, one JSON and one CSV file into a temporary directory
    /// and checks the total number of documents produced by the loader.
    #[tokio::test]
    async fn test_directory_loader() {
        let tmp = TempDir::new().unwrap();
        let fixtures = [
            ("notes.txt", "Some notes here"),
            ("data.json", r#"[{"text":"item1"},{"text":"item2"}]"#),
            ("people.csv", "name,age\nAlice,30\n"),
        ];
        for (name, contents) in fixtures {
            fs::write(tmp.path().join(name), contents).unwrap();
        }

        let documents = DirectoryLoader::new(tmp.path()).load().await.unwrap();

        // Presumably 1 text document + 2 JSON items + 1 CSV row; the per-format
        // counts depend on the individual loaders.
        assert_eq!(documents.len(), 4);
    }
}