use crate::Result;
use crate::core::config::ExtractionConfig;
use crate::plugins::{DocumentExtractor, Plugin};
use crate::text::utf8_validation;
use crate::types::{ExtractionResult, Metadata, Table};
use async_trait::async_trait;
/// Stateless extractor plugin for CSV and TSV documents.
///
/// The type carries no data; every extraction call works purely from the
/// bytes and configuration it is handed.
#[derive(Debug)]
pub struct CsvExtractor;

impl CsvExtractor {
    /// Creates a new extractor.
    ///
    /// `const` so the value can also be used in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}

impl Default for CsvExtractor {
    /// Equivalent to [`CsvExtractor::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Plugin for CsvExtractor {
    /// Identifier under which this plugin is reported.
    fn name(&self) -> &str {
        "csv-extractor"
    }

    /// Human-readable summary of what the plugin extracts.
    fn description(&self) -> &str {
        "CSV/TSV text extraction with table structure"
    }

    /// Author string reported for the plugin.
    fn author(&self) -> &str {
        "Kreuzberg Team"
    }

    /// Version of the crate this plugin was compiled from.
    fn version(&self) -> String {
        String::from(env!("CARGO_PKG_VERSION"))
    }

    /// Nothing to set up: the extractor holds no state.
    fn initialize(&self) -> Result<()> {
        Ok(())
    }

    /// Nothing to tear down: the extractor holds no state.
    fn shutdown(&self) -> Result<()> {
        Ok(())
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl DocumentExtractor for CsvExtractor {
async fn extract_bytes(
&self,
content: &[u8],
mime_type: &str,
config: &ExtractionConfig,
) -> Result<ExtractionResult> {
let text = decode_csv_bytes(content);
let delimiter = if mime_type == "text/tab-separated-values" {
'\t'
} else {
detect_delimiter(&text)
};
let rows = parse_csv(&text, delimiter);
let row_count = rows.len();
let col_count = rows.iter().map(|r| r.len()).max().unwrap_or(0);
let has_header = detect_header(&rows);
let content_text = build_content_text(&rows, has_header);
let column_types = infer_column_types(&rows, has_header);
let markdown = build_markdown_table(&rows);
let table = Table {
cells: rows,
markdown,
page_number: 1,
bounding_box: None,
};
let mut additional = ahash::AHashMap::new();
additional.insert(
std::borrow::Cow::Borrowed("row_count"),
serde_json::Value::Number(row_count.into()),
);
additional.insert(
std::borrow::Cow::Borrowed("column_count"),
serde_json::Value::Number(col_count.into()),
);
additional.insert(
std::borrow::Cow::Borrowed("extraction_method"),
serde_json::Value::String("native_csv".to_string()),
);
additional.insert(
std::borrow::Cow::Borrowed("has_header"),
serde_json::Value::Bool(has_header),
);
if !column_types.is_empty() {
additional.insert(
std::borrow::Cow::Borrowed("column_types"),
serde_json::json!(column_types),
);
}
let document = if config.include_document_structure && !table.cells.is_empty() {
use crate::types::builder::DocumentStructureBuilder;
let mut builder = DocumentStructureBuilder::new().source_format("csv");
builder.push_table_from_cells(&table.cells, None);
Some(builder.build())
} else {
None
};
Ok(ExtractionResult {
content: content_text,
mime_type: mime_type.to_string().into(),
metadata: Metadata {
additional,
..Default::default()
},
pages: None,
tables: vec![table],
detected_languages: None,
chunks: None,
images: None,
djot_content: None,
elements: None,
ocr_elements: None,
document,
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
extracted_keywords: None,
quality_score: None,
processing_warnings: Vec::new(),
annotations: None,
children: None,
})
}
fn supported_mime_types(&self) -> &[&str] {
&["text/csv", "text/tab-separated-values"]
}
fn priority(&self) -> i32 {
60 }
}
/// Guesses the field delimiter of `text` from its first ten lines.
///
/// Each candidate (`,`, tab, `|`, `;`) is used to parse the sample. A
/// candidate is considered only if it yields at least two rows and splits the
/// first row into more than one column. Its score is the number of rows whose
/// column count equals the first row's, multiplied by that column count. The
/// highest score wins; ties keep the earlier candidate, and `,` is returned
/// when no candidate qualifies.
fn detect_delimiter(text: &str) -> char {
    const CANDIDATES: &[char] = &[',', '\t', '|', ';'];

    // The sample does not depend on the candidate, so build it once instead of
    // re-collecting and re-joining the lines on every loop iteration.
    let sample = text.lines().take(10).collect::<Vec<_>>().join("\n");

    let mut best_delimiter = ',';
    let mut best_score = 0usize;
    for &candidate in CANDIDATES {
        let rows = parse_csv(&sample, candidate);
        if rows.len() < 2 {
            continue;
        }
        let first_count = rows[0].len();
        if first_count <= 1 {
            continue;
        }
        let consistent_rows = rows.iter().filter(|row| row.len() == first_count).count();
        let score = consistent_rows * first_count;
        // Strict comparison so that earlier candidates win ties.
        if score > best_score {
            best_score = score;
            best_delimiter = candidate;
        }
    }
    best_delimiter
}
/// Parses delimited text into rows of owned field strings.
///
/// Behaviour:
/// - A `"` toggles quoted mode wherever it appears; inside quotes, `""` yields
///   a literal `"` and delimiters and line breaks are kept as field content.
/// - Outside quotes, `\n`, `\r\n` and a bare `\r` each end the current row.
/// - Rows whose fields are all empty (blank lines, `,,`) are dropped.
/// - A final row without a trailing line break is still emitted.
///
/// Fields are returned verbatim (no trimming). Rows may differ in length.
fn parse_csv(text: &str, delimiter: char) -> Vec<Vec<String>> {
    /// Closes the pending field and row, keeping the row only if at least one
    /// of its fields is non-empty. Both buffers are left empty for reuse.
    fn finish_row(rows: &mut Vec<Vec<String>>, row: &mut Vec<String>, field: &mut String) {
        row.push(std::mem::take(field));
        let finished = std::mem::take(row);
        if finished.iter().any(|f| !f.is_empty()) {
            rows.push(finished);
        }
    }

    let mut rows: Vec<Vec<String>> = Vec::new();
    let mut current_row: Vec<String> = Vec::new();
    let mut current_field = String::new();
    let mut in_quotes = false;
    let mut chars = text.chars().peekable();

    while let Some(c) = chars.next() {
        if in_quotes {
            match c {
                // Doubled quote inside a quoted field is an escaped quote.
                '"' if chars.peek() == Some(&'"') => {
                    chars.next();
                    current_field.push('"');
                }
                '"' => in_quotes = false,
                _ => current_field.push(c),
            }
            continue;
        }

        // Arm order matters: the quote check comes first and the delimiter
        // check precedes the line-break arms, so a delimiter that happens to
        // be `\r` or `\n` is still treated as a delimiter.
        match c {
            '"' => in_quotes = true,
            c if c == delimiter => {
                // Move the buffer out instead of cloning and clearing it.
                current_row.push(std::mem::take(&mut current_field));
            }
            '\r' | '\n' => {
                // Treat CRLF as a single line break.
                if c == '\r' && chars.peek() == Some(&'\n') {
                    chars.next();
                }
                finish_row(&mut rows, &mut current_row, &mut current_field);
            }
            _ => current_field.push(c),
        }
    }

    // Flush a last row that was not terminated by a line break.
    if !current_field.is_empty() || !current_row.is_empty() {
        finish_row(&mut rows, &mut current_row, &mut current_field);
    }
    rows
}
/// Turns raw file bytes into a `String`.
///
/// Input accepted by `utf8_validation::from_utf8` is copied as-is. Anything
/// else is handed to a feature-dependent decoder: `crate::utils::safe_decode`
/// when the `quality` feature is enabled, `decode_csv_bytes_fallback`
/// otherwise.
fn decode_csv_bytes(content: &[u8]) -> String {
    match utf8_validation::from_utf8(content) {
        Ok(valid) => valid.to_string(),
        Err(_) => {
            #[cfg(feature = "quality")]
            {
                crate::utils::safe_decode(content, None)
            }
            #[cfg(not(feature = "quality"))]
            {
                decode_csv_bytes_fallback(content)
            }
        }
    }
}
/// Decodes non-UTF-8 bytes by trying a fixed list of legacy encodings.
///
/// The labels are tried in order and the first one that decodes without
/// reporting errors wins. If none does, the bytes are decoded as Shift_JIS
/// regardless of errors, and as a last resort with lossy UTF-8 conversion.
///
/// NOTE(review): `windows-1252` looks like it can decode any byte sequence
/// without reporting errors, in which case the labels listed after it and both
/// lossy fallbacks would never be reached — confirm against `encoding_rs`.
#[cfg(not(feature = "quality"))]
fn decode_csv_bytes_fallback(content: &[u8]) -> String {
    const CANDIDATE_LABELS: [&str; 6] = [
        "shift_jis",
        "windows-31j",
        "windows-1252",
        "iso-8859-1",
        "gb18030",
        "big5",
    ];

    let clean_decode = CANDIDATE_LABELS
        .iter()
        .filter_map(|label| encoding_rs::Encoding::for_label(label.as_bytes()))
        .find_map(|encoding| {
            let (text, _, malformed) = encoding.decode(content);
            if malformed { None } else { Some(text.into_owned()) }
        });
    if let Some(text) = clean_decode {
        return text;
    }

    match encoding_rs::Encoding::for_label(b"shift_jis") {
        Some(shift_jis) => {
            let (text, _, _) = shift_jis.decode(content);
            text.into_owned()
        }
        None => String::from_utf8_lossy(content).into_owned(),
    }
}
/// Heuristically decides whether the first row is a header.
///
/// Returns `true` only when all of the following hold:
/// - there are at least two rows and the first row has at least two cells;
/// - no cell of the first row parses as a number;
/// - some cell in the next (up to five) rows does parse as a number.
fn detect_header(rows: &[Vec<String>]) -> bool {
    /// A cell counts as numeric when its trimmed text is non-empty and parses as `f64`.
    fn is_numeric(cell: &str) -> bool {
        let text = cell.trim();
        !text.is_empty() && text.parse::<f64>().is_ok()
    }

    let (candidate, following) = match rows.split_first() {
        Some((first, rest)) if !rest.is_empty() && first.len() >= 2 => (first, rest),
        _ => return false,
    };

    if candidate.iter().any(|cell| is_numeric(cell)) {
        return false;
    }

    // Only the five rows directly after the candidate header are inspected.
    following
        .iter()
        .take(5)
        .flatten()
        .any(|cell| is_numeric(cell))
}
fn infer_column_types(rows: &[Vec<String>], has_header: bool) -> Vec<String> {
if rows.is_empty() {
return Vec::new();
}
let col_count = rows.iter().map(|r| r.len()).max().unwrap_or(0);
if col_count == 0 {
return Vec::new();
}
let data_start = if has_header { 1 } else { 0 };
let scan_end = rows.len().min(data_start + 20);
if data_start >= scan_end {
return vec!["text".to_string(); col_count];
}
let data_rows = &rows[data_start..scan_end];
let date_patterns = [
regex::Regex::new(r"^\d{4}-\d{2}-\d{2}").ok(),
regex::Regex::new(r"^\d{1,2}/\d{1,2}/\d{2,4}").ok(),
regex::Regex::new(r"^\d{1,2}\.\d{1,2}\.\d{2,4}").ok(),
];
(0..col_count)
.map(|col_idx| {
let mut numeric_count = 0usize;
let mut date_count = 0usize;
let mut non_empty_count = 0usize;
for row in data_rows {
let cell = row.get(col_idx).map(|s| s.trim()).unwrap_or("");
if cell.is_empty() {
continue;
}
non_empty_count += 1;
if cell.parse::<f64>().is_ok() {
numeric_count += 1;
} else {
for pat in &date_patterns {
if let Some(re) = pat
&& re.is_match(cell)
{
date_count += 1;
break;
}
}
}
}
if non_empty_count == 0 {
"text".to_string()
} else if numeric_count * 2 >= non_empty_count {
"numeric".to_string()
} else if date_count * 2 >= non_empty_count {
"date".to_string()
} else {
"text".to_string()
}
})
.collect()
}
/// Renders the parsed rows as plain text.
///
/// Without a header (or with fewer than two rows) each row becomes one line of
/// its trimmed, non-empty cells joined by single spaces; rows that end up
/// empty are omitted.
///
/// With a header, every data row becomes a `Row N:` section listing
/// `header: value` pairs for cells where both the header and the value are
/// non-empty after trimming. Sections are separated by a blank line. `N` is
/// the 1-based position among the data rows, so numbering has gaps when a row
/// contributes no pairs and is left out.
fn build_content_text(rows: &[Vec<String>], has_header: bool) -> String {
    let (headers, records) = match rows.split_first() {
        Some((first, rest)) if has_header && !rest.is_empty() => (first, rest),
        _ => {
            let mut plain = String::new();
            for row in rows {
                let line = row
                    .iter()
                    .map(|cell| cell.trim())
                    .filter(|cell| !cell.is_empty())
                    .collect::<Vec<_>>()
                    .join(" ");
                if line.is_empty() {
                    continue;
                }
                // Lines are never empty here, so an empty buffer means "first line".
                if !plain.is_empty() {
                    plain.push('\n');
                }
                plain.push_str(&line);
            }
            return plain;
        }
    };

    records
        .iter()
        .enumerate()
        .filter_map(|(idx, record)| {
            let pairs: Vec<String> = headers
                .iter()
                .zip(record)
                .filter_map(|(header, value)| {
                    let (h, v) = (header.trim(), value.trim());
                    if h.is_empty() || v.is_empty() {
                        None
                    } else {
                        Some(format!(" {}: {}", h, v))
                    }
                })
                .collect();
            if pairs.is_empty() {
                None
            } else {
                Some(format!("Row {}:\n{}", idx + 1, pairs.join("\n")))
            }
        })
        .collect::<Vec<_>>()
        .join("\n\n")
}
/// Renders the rows as a pipe-delimited Markdown table.
///
/// The first row is emitted as the table header, followed by a `---`
/// separator row. Every row is padded with empty cells up to the width of the
/// widest row. Cell text is trimmed, and characters that would break the
/// table layout are neutralised: `|` is written as `\|`, and line breaks
/// (`\n`, `\r\n`, `\r`) inside a cell are replaced by a single space.
///
/// Returns an empty string when there are no rows or no columns.
fn build_markdown_table(rows: &[Vec<String>]) -> String {
    /// Appends `cell` to `out`, escaping pipes and flattening line breaks so
    /// the cell cannot split the row or open an extra column.
    fn push_escaped_cell(out: &mut String, cell: &str) {
        let mut chars = cell.chars().peekable();
        while let Some(ch) = chars.next() {
            match ch {
                '|' => out.push_str("\\|"),
                '\r' => {
                    // Collapse CRLF into one space rather than two.
                    if chars.peek() == Some(&'\n') {
                        chars.next();
                    }
                    out.push(' ');
                }
                '\n' => out.push(' '),
                other => out.push(other),
            }
        }
    }

    let col_count = rows.iter().map(Vec::len).max().unwrap_or(0);
    if col_count == 0 {
        return String::new();
    }

    let mut markdown = String::new();
    for (i, row) in rows.iter().enumerate() {
        markdown.push('|');
        for j in 0..col_count {
            // Missing trailing cells in short rows are rendered as empty.
            let cell = row.get(j).map(|s| s.trim()).unwrap_or("");
            markdown.push(' ');
            push_escaped_cell(&mut markdown, cell);
            markdown.push_str(" |");
        }
        markdown.push('\n');

        // The separator row goes directly under the first (header) row.
        if i == 0 {
            markdown.push('|');
            for _ in 0..col_count {
                markdown.push_str(" --- |");
            }
            markdown.push('\n');
        }
    }
    markdown
}
// Unit tests for the CSV parser, the delimiter/header/type heuristics, byte
// decoding, and the extractor's end-to-end output.
#[cfg(test)]
mod tests {
use super::*;
// Plain comma-separated input with a trailing newline yields one row per line.
#[test]
fn test_parse_csv_simple() {
let rows = parse_csv("a,b,c\n1,2,3\n", ',');
assert_eq!(rows, vec![vec!["a", "b", "c"], vec!["1", "2", "3"]]);
}
// A delimiter inside a quoted field is kept as field content.
#[test]
fn test_parse_csv_quoted() {
let rows = parse_csv("\"hello, world\",b,c\n", ',');
assert_eq!(rows, vec![vec!["hello, world", "b", "c"]]);
}
// A doubled quote inside a quoted field becomes a single literal quote.
#[test]
fn test_parse_csv_escaped_quotes() {
let rows = parse_csv("\"say \"\"hello\"\"\",b\n", ',');
assert_eq!(rows, vec![vec!["say \"hello\"", "b"]]);
}
// The same parser handles tab-separated input when given '\t'.
#[test]
fn test_parse_tsv() {
let rows = parse_csv("a\tb\tc\n1\t2\t3\n", '\t');
assert_eq!(rows, vec![vec!["a", "b", "c"], vec!["1", "2", "3"]]);
}
// CRLF line endings terminate a row without leaving a stray '\r' in a field.
#[test]
fn test_parse_csv_crlf() {
let rows = parse_csv("a,b\r\n1,2\r\n", ',');
assert_eq!(rows, vec![vec!["a", "b"], vec!["1", "2"]]);
}
// An empty field between two delimiters is preserved as an empty string.
#[test]
fn test_parse_csv_empty_fields() {
let rows = parse_csv("a,,c\n", ',');
assert_eq!(rows, vec![vec!["a", "", "c"]]);
}
// The Markdown output contains the header row, a separator row and the data row.
#[test]
fn test_build_markdown_table() {
let rows = vec![
vec!["Name".to_string(), "Age".to_string()],
vec!["Alice".to_string(), "30".to_string()],
];
let md = build_markdown_table(&rows);
assert!(md.contains("| Name | Age |"));
assert!(md.contains("| --- | --- |"));
assert!(md.contains("| Alice | 30 |"));
}
// Pins the plugin metadata: name, version, priority and supported MIME types.
#[tokio::test]
async fn test_csv_extractor_plugin_interface() {
let extractor = CsvExtractor::new();
assert_eq!(extractor.name(), "csv-extractor");
assert_eq!(extractor.version(), env!("CARGO_PKG_VERSION"));
assert_eq!(extractor.priority(), 60);
assert_eq!(
extractor.supported_mime_types(),
&["text/csv", "text/tab-separated-values"]
);
}
// End-to-end: a CSV with a detected header produces "Row N:" sections with
// "header: value" pairs, and a single table that still includes the header row.
#[tokio::test]
async fn test_csv_extractor_output() {
let extractor = CsvExtractor::new();
let config = ExtractionConfig::default();
let csv_data = b"Name,Age,City\nAlice,30,NYC\nBob,25,LA\n";
let result = extractor
.extract_bytes(csv_data, "text/csv", &config)
.await
.expect("CSV extraction should succeed");
assert!(result.content.contains("Name: Alice"));
assert!(result.content.contains("Age: 30"));
assert!(result.content.contains("City: NYC"));
assert!(result.content.contains("Name: Bob"));
assert!(result.content.contains("Row 1:"));
assert!(result.content.contains("Row 2:"));
assert_eq!(result.tables.len(), 1);
assert_eq!(result.tables[0].cells.len(), 3);
assert_eq!(result.tables[0].cells[0], vec!["Name", "Age", "City"]);
}
// Quoted fields containing commas survive into both the text content and the table cells.
#[tokio::test]
async fn test_csv_extractor_quoted_fields() {
let extractor = CsvExtractor::new();
let config = ExtractionConfig::default();
let csv_data = b"Name,Description\n\"Smith, John\",\"Has a comma, inside\"\n";
let result = extractor
.extract_bytes(csv_data, "text/csv", &config)
.await
.expect("CSV extraction with quoted fields should succeed");
assert!(result.content.contains("Smith, John"));
assert_eq!(result.tables[0].cells[1][0], "Smith, John");
}
// Delimiter sniffing: one test per supported candidate.
#[test]
fn test_detect_delimiter_comma() {
assert_eq!(detect_delimiter("a,b,c\n1,2,3\n4,5,6"), ',');
}
#[test]
fn test_detect_delimiter_semicolon() {
assert_eq!(detect_delimiter("a;b;c\n1;2;3\n4;5;6"), ';');
}
#[test]
fn test_detect_delimiter_pipe() {
assert_eq!(detect_delimiter("a|b|c\n1|2|3\n4|5|6"), '|');
}
#[test]
fn test_detect_delimiter_tab() {
assert_eq!(detect_delimiter("a\tb\tc\n1\t2\t3\n4\t5\t6"), '\t');
}
// Commas that only occur inside quoted values must not outvote the real ';' delimiter.
#[test]
fn test_detect_delimiter_semicolons_with_commas_in_values() {
assert_eq!(
detect_delimiter("\"last, first\";age;city\n\"doe, john\";30;NYC\n\"smith, jane\";25;LA"),
';'
);
}
// Non-UTF-8 input: these bytes are expected to decode to text containing
// 名前, 年齢 and 住所, with no replacement characters.
#[test]
fn test_decode_csv_bytes_shift_jis() {
let shift_jis_data = vec![
0x96u8, 0xbc, 0x91, 0x4f, 0x2c, 0x94, 0x4e, 0x97, 0xee, 0x2c, 0x8f, 0x5a, 0x8f, 0x8a,
];
let decoded = decode_csv_bytes(&shift_jis_data);
assert!(decoded.contains("名前"), "Should contain '名前' (Name)");
assert!(decoded.contains("年齢"), "Should contain '年齢' (Age)");
assert!(decoded.contains("住所"), "Should contain '住所' (Address)");
assert!(
!decoded.contains("□"),
"Should not contain mojibake replacement characters"
);
assert!(
!decoded.contains("\u{FFFD}"),
"Should not contain Unicode replacement characters"
);
}
// Valid UTF-8 input is returned unchanged.
#[test]
fn test_decode_csv_bytes_utf8() {
let utf8_data = "名前,年齢,住所".as_bytes();
let decoded = decode_csv_bytes(utf8_data);
assert_eq!(decoded, "名前,年齢,住所");
}
// Text-only first row followed by rows containing numbers => header detected.
#[test]
fn test_detect_header_with_numeric_data() {
let rows = vec![
vec!["Name".to_string(), "Age".to_string(), "Score".to_string()],
vec!["Alice".to_string(), "30".to_string(), "95.5".to_string()],
vec!["Bob".to_string(), "25".to_string(), "88.0".to_string()],
];
assert!(detect_header(&rows), "Should detect header when data rows have numbers");
}
// With no numeric cells anywhere the heuristic has no evidence and reports no header.
#[test]
fn test_detect_header_all_text() {
let rows = vec![
vec!["Name".to_string(), "City".to_string()],
vec!["Alice".to_string(), "NYC".to_string()],
vec!["Bob".to_string(), "LA".to_string()],
];
assert!(!detect_header(&rows), "Should not detect header when all data is text");
}
// A numeric cell in the first row rules it out as a header.
#[test]
fn test_detect_header_numeric_first_row() {
let rows = vec![
vec!["1".to_string(), "2".to_string(), "3".to_string()],
vec!["4".to_string(), "5".to_string(), "6".to_string()],
];
assert!(
!detect_header(&rows),
"Should not detect header when first row has numbers"
);
}
// With the header skipped, columns are classified as text, numeric and date respectively.
#[test]
fn test_infer_column_types_basic() {
let rows = vec![
vec!["Name".to_string(), "Age".to_string(), "Date".to_string()],
vec!["Alice".to_string(), "30".to_string(), "2024-01-15".to_string()],
vec!["Bob".to_string(), "25".to_string(), "2024-02-20".to_string()],
];
let types = infer_column_types(&rows, true);
assert_eq!(types.len(), 3);
assert_eq!(types[0], "text");
assert_eq!(types[1], "numeric");
assert_eq!(types[2], "date");
}
// The extractor exposes the header decision and the column types through metadata.additional.
#[tokio::test]
async fn test_csv_extractor_header_detection_metadata() {
let extractor = CsvExtractor::new();
let config = ExtractionConfig::default();
let csv_data = b"Name,Age,City\nAlice,30,NYC\nBob,25,LA\n";
let result = extractor.extract_bytes(csv_data, "text/csv", &config).await.unwrap();
let has_header = result.metadata.additional.get("has_header");
assert_eq!(has_header, Some(&serde_json::Value::Bool(true)));
let col_types = result.metadata.additional.get("column_types");
assert!(col_types.is_some(), "Should have column_types metadata");
}
// Runs against an on-disk fixture resolved relative to CARGO_MANIFEST_DIR.
// NOTE: the test returns early (and therefore passes) when the fixture is missing.
#[tokio::test]
async fn test_csv_extractor_real_file() {
let test_file =
std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test_documents/csv/data_table.csv");
if !test_file.exists() {
return;
}
let content = std::fs::read(&test_file).expect("Failed to read test CSV");
let extractor = CsvExtractor::new();
let config = ExtractionConfig::default();
let result = extractor.extract_bytes(&content, "text/csv", &config).await.unwrap();
assert!(!result.content.is_empty());
assert!(result.content.contains("Name: Alice Johnson"));
assert!(result.content.contains("Department: Engineering"));
assert_eq!(result.tables.len(), 1);
// The table is expected to hold 11 rows in total for this fixture.
assert_eq!(result.tables[0].cells.len(), 11); }
}