#[cfg(feature = "audio-loader")]
pub mod audio;
#[cfg(feature = "csv-loader")]
pub mod csv_loader;
#[cfg(feature = "html-loader")]
pub mod html;
#[cfg(feature = "image-loader")]
pub mod image;
#[cfg(any(feature = "pdf-pdfium", feature = "pdf-pure-rust"))]
pub mod pdf;
pub mod text;
#[cfg(feature = "unstructured")]
pub mod unstructured;
use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use cognee_models::Document;
use thiserror::Error;
/// Errors a [`DocumentLoader`] can return from `extract`.
///
/// Every variant wraps a `String` that is substituted for `{0}` in the
/// variant's `#[error(...)]` display message.
#[derive(Debug, Error)]
pub enum LoaderError {
/// Displayed as `Invalid UTF-8 in document content: <detail>`.
#[error("Invalid UTF-8 in document content: {0}")]
InvalidUtf8(String),
/// Displayed as `Unsupported document format: <detail>`.
#[error("Unsupported document format: {0}")]
UnsupportedFormat(String),
/// Displayed as `IO error during extraction: <detail>`.
///
/// The underlying error is carried as text only; no source error is chained.
#[error("IO error during extraction: {0}")]
IoError(String),
/// Displayed as `Extraction failed: <detail>`.
#[error("Extraction failed: {0}")]
ExtractionFailed(String),
}
/// A pluggable extractor that turns raw document bytes into a [`LoaderOutput`].
///
/// The `Send + Sync` supertraits let implementations be shared behind
/// `Arc<dyn DocumentLoader>`, which is how [`LoaderRegistry`] stores them.
#[async_trait]
pub trait DocumentLoader: Send + Sync {
/// Extracts content from `bytes`.
///
/// `doc` is passed alongside the bytes — presumably the record describing the
/// document they came from (NOTE(review): confirm against implementations).
///
/// # Errors
///
/// Returns a [`LoaderError`] when the implementation cannot produce output.
async fn extract(&self, bytes: &[u8], doc: &Document) -> Result<LoaderOutput, LoaderError>;
/// Returns a static identifier for this loader.
///
/// The tests in this file expect values such as `"text_loader"` and
/// `"beautiful_soup_loader"`.
fn engine_name(&self) -> &'static str;
}
/// The result of a successful [`DocumentLoader::extract`] call.
#[derive(Debug)]
pub enum LoaderOutput {
/// A single block of extracted text.
Text(String),
/// A list of strings — presumably one entry per row of tabular input
/// (NOTE(review): confirm against the loaders that produce it).
Rows(Vec<String>),
/// Text paired with a static `cut_type` label.
///
/// Presumably meant to be emitted downstream as one chunk without further
/// splitting — TODO confirm against the chunking code.
SingleChunk {
text: String,
cut_type: &'static str,
},
}
/// Lookup table from a document-type key (e.g. `"text"`) to the
/// [`DocumentLoader`] that handles it.
pub struct LoaderRegistry {
// Keyed by the exact document-type string passed to `register`; loaders are
// reference-counted trait objects.
loaders: HashMap<String, Arc<dyn DocumentLoader>>,
}
impl LoaderRegistry {
    /// Creates a registry with no loaders registered.
    pub fn new() -> Self {
        let loaders = HashMap::new();
        Self { loaders }
    }

    /// Associates `loader` with the key `document_type`.
    ///
    /// If a loader is already registered under the same key, it is replaced
    /// and the registry's reference to the previous loader is dropped.
    pub fn register(&mut self, document_type: &str, loader: Arc<dyn DocumentLoader>) {
        let key = document_type.to_owned();
        self.loaders.insert(key, loader);
    }

    /// Returns the loader registered under `document_type`, or `None` when no
    /// loader has been registered for that exact key.
    pub fn get(&self, document_type: &str) -> Option<&Arc<dyn DocumentLoader>> {
        self.loaders.get(document_type)
    }

    /// Builds a registry pre-populated with the loaders compiled into this build.
    ///
    /// `"text"` is always present. `"html"`, `"pdf"`, `"csv"` and
    /// `"unstructured"` are added only when their Cargo feature is enabled.
    /// No `"image"` or `"audio"` loader is registered here; the tests in this
    /// file construct those with an external handle and register them manually.
    pub fn default_registry() -> Self {
        let mut defaults = Self::new();

        defaults.register("text", Arc::new(text::TextLoader));

        #[cfg(feature = "html-loader")]
        defaults.register("html", Arc::new(html::HtmlLoader));

        #[cfg(any(feature = "pdf-pdfium", feature = "pdf-pure-rust"))]
        defaults.register("pdf", Arc::new(pdf::PdfLoader));

        #[cfg(feature = "csv-loader")]
        defaults.register("csv", Arc::new(csv_loader::CsvLoader));

        #[cfg(feature = "unstructured")]
        defaults.register("unstructured", Arc::new(unstructured::UnstructuredLoader));

        defaults
    }
}
impl Default for LoaderRegistry {
fn default() -> Self {
Self::default_registry()
}
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// The default registry always ships the plain-text loader.
    #[test]
    fn default_registry_has_text_loader() {
        let registry = LoaderRegistry::default();
        let loader = registry
            .get("text")
            .expect("default registry must contain a \"text\" loader");
        assert_eq!(loader.engine_name(), "text_loader");
    }

    /// Looking up a key that was never registered yields `None`.
    #[test]
    fn get_nonexistent_returns_none() {
        let registry = LoaderRegistry::default();
        assert!(registry.get("nonexistent").is_none());
    }

    /// A loader registered under an arbitrary key can be retrieved by that key.
    #[test]
    fn custom_registration_works() {
        let mut registry = LoaderRegistry::new();
        registry.register("custom", Arc::new(text::TextLoader));
        let loader = registry
            .get("custom")
            .expect("registry must contain the \"custom\" loader that was just registered");
        assert_eq!(loader.engine_name(), "text_loader");
    }

    /// Registering under an existing key swaps the stored loader.
    ///
    /// Both loaders are `TextLoader`s, so `engine_name` cannot tell them apart.
    /// Reference counts are used instead: the registry must release the old
    /// `Arc` and hold a reference to the new one.
    #[test]
    fn register_replaces_existing() {
        let mut registry = LoaderRegistry::default();

        // Keep our own handle to the built-in loader so its count reveals
        // whether the registry still owns it.
        let original = Arc::clone(
            registry
                .get("text")
                .expect("default registry must contain a \"text\" loader"),
        );
        assert_eq!(
            Arc::strong_count(&original),
            2,
            "the registry and this test should each hold one reference"
        );

        let replacement: Arc<dyn DocumentLoader> = Arc::new(text::TextLoader);
        registry.register("text", Arc::clone(&replacement));

        assert_eq!(
            Arc::strong_count(&original),
            1,
            "the registry must drop the loader it previously held"
        );
        assert_eq!(
            Arc::strong_count(&replacement),
            2,
            "the registry must now hold the replacement loader"
        );
        let loader = registry
            .get("text")
            .expect("\"text\" must still be registered after replacement");
        assert_eq!(loader.engine_name(), "text_loader");
    }

    /// With the `html-loader` feature, the default registry contains the HTML loader.
    #[cfg(feature = "html-loader")]
    #[test]
    fn default_registry_has_html_loader() {
        let registry = LoaderRegistry::default_registry();
        let loader = registry
            .get("html")
            .expect("registry must contain an \"html\" loader");
        assert_eq!(
            loader.engine_name(),
            "beautiful_soup_loader",
            "HtmlLoader engine_name must be \"beautiful_soup_loader\""
        );
    }

    /// An `ImageLoader` built around a mock LLM can be registered manually and
    /// retrieved (it is not part of the default registry).
    #[cfg(feature = "image-loader")]
    #[test]
    fn test_registry_has_image_loader() {
        use cognee_test_utils::MockLlm;

        use super::image::ImageLoader;

        let mock_llm = Arc::new(MockLlm::empty());
        let image_loader = ImageLoader::new(mock_llm);

        let mut registry = LoaderRegistry::default_registry();
        registry.register("image", Arc::new(image_loader));

        let loader = registry
            .get("image")
            .expect("registry must contain an \"image\" loader");
        assert_eq!(
            loader.engine_name(),
            "image_loader",
            "ImageLoader engine_name must be \"image_loader\""
        );
    }

    /// An `AudioLoader` built around a mock transcriber can be registered
    /// manually and retrieved (it is not part of the default registry).
    #[cfg(feature = "audio-loader")]
    #[test]
    fn test_registry_has_audio_loader() {
        use cognee_llm::TranscriptionOutput;
        use cognee_test_utils::MockTranscriber;

        use super::audio::AudioLoader;

        let mock_transcriber = Arc::new(MockTranscriber::new(
            "mock-whisper",
            vec![TranscriptionOutput {
                text: "hello".to_string(),
                language: None,
                duration: None,
            }],
        ));
        let audio_loader = AudioLoader::new(mock_transcriber);

        let mut registry = LoaderRegistry::default_registry();
        registry.register("audio", Arc::new(audio_loader));

        let loader = registry
            .get("audio")
            .expect("registry must contain an \"audio\" loader");
        assert_eq!(
            loader.engine_name(),
            "audio_loader",
            "AudioLoader engine_name must be \"audio_loader\""
        );
    }
}