langchain_rust/document_loaders/csv_loader/
csv_loader.rs

1use crate::document_loaders::{process_doc_stream, LoaderError};
2use crate::{document_loaders::Loader, schemas::Document, text_splitter::TextSplitter};
3use async_stream::stream;
4use async_trait::async_trait;
5use csv;
6use futures::Stream;
7use serde_json::Value;
8
9use std::collections::HashMap;
10use std::fs::File;
11use std::io::{BufReader, Cursor, Read};
12use std::path::Path;
13use std::pin::Pin;
14
/// Document loader that turns CSV data into documents.
///
/// `R` is the source the CSV bytes are read from; the trait bounds it must
/// satisfy are declared on the `impl` blocks rather than on the struct.
#[derive(Clone, Debug)]
pub struct CsvLoader<R> {
    /// Source of the raw CSV data.
    reader: R,
    /// Header names of the columns that are copied into each document.
    columns: Vec<String>,
}
20
21impl<R: Read> CsvLoader<R> {
22    pub fn new(reader: R, columns: Vec<String>) -> Self {
23        Self { reader, columns }
24    }
25}
26
27impl CsvLoader<Cursor<Vec<u8>>> {
28    pub fn from_string<S: Into<String>>(input: S, columns: Vec<String>) -> Self {
29        let input = input.into();
30        let reader = Cursor::new(input.into_bytes());
31        Self::new(reader, columns)
32    }
33}
34
35impl CsvLoader<BufReader<File>> {
36    pub fn from_path<P: AsRef<Path>>(path: P, columns: Vec<String>) -> Result<Self, LoaderError> {
37        let file = File::open(path)?;
38        let reader = BufReader::new(file);
39        Ok(Self::new(reader, columns))
40    }
41}
42
43#[async_trait]
44impl<R: Read + Send + Sync + 'static> Loader for CsvLoader<R> {
45    async fn load(
46        mut self,
47    ) -> Result<
48        Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>,
49        LoaderError,
50    > {
51        let mut reader = csv::Reader::from_reader(self.reader);
52        let headers = reader.headers()?.clone();
53
54        // Initialize rown to track row number
55        let mut row_number: i64 = 0;
56        let columns = self.columns.clone();
57
58        let stream = stream! {
59            for result in reader.records() {
60                let record = result?;
61                let mut content = String::new();
62
63                for (i, field) in record.iter().enumerate() {
64                    let header = &headers[i];
65                    if !columns.contains(&header.to_string()) {
66                        continue;
67                    }
68
69                    let line = format!("{}: {}", header, field);
70                    content.push_str(&line);
71                    content.push('\n');
72                }
73
74                row_number += 1; // Increment the row number by 1 for each row
75
76                // Generate document with the content and metadata
77                let mut document = Document::new(content);
78                let mut metadata = HashMap::new();
79                metadata.insert("row".to_string(), Value::from(row_number));
80
81                // Attach the metadata to the document
82                document.metadata = metadata;
83
84                yield Ok(document);
85            }
86        };
87
88        Ok(Box::pin(stream))
89    }
90
91    async fn load_and_split<TS: TextSplitter + 'static>(
92        mut self,
93        splitter: TS,
94    ) -> Result<
95        Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>,
96        LoaderError,
97    > {
98        let doc_stream = self.load().await?;
99        let stream = process_doc_stream(doc_stream, splitter).await;
100        Ok(Box::pin(stream))
101    }
102}
103
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use super::*;

    /// Column selection shared by both tests: every column of the fixture data.
    fn all_columns() -> Vec<String> {
        ["name", "age", "city", "country"]
            .iter()
            .map(|column| column.to_string())
            .collect()
    }

    /// Runs the loader to completion, panicking on any load or record error.
    async fn collect_documents<R>(loader: CsvLoader<R>) -> Vec<Document>
    where
        R: Read + Send + Sync + 'static,
    {
        loader
            .load()
            .await
            .unwrap()
            .map(|item| item.unwrap())
            .collect::<Vec<_>>()
            .await
    }

    /// Checks the `"row"` metadata and the page content of a single document.
    fn assert_row(document: &Document, row: i64, expected_content: &str) {
        assert_eq!(document.metadata.get("row").unwrap(), &Value::from(row));
        assert_eq!(document.page_content, expected_content);
    }

    /// Both fixtures start with the same two records.
    fn assert_first_two_rows(documents: &[Document]) {
        assert_row(
            &documents[0],
            1,
            "name: John Doe\nage: 25\ncity: New York\ncountry: United States\n",
        );
        assert_row(
            &documents[1],
            2,
            "name: Jane Smith\nage: 32\ncity: London\ncountry: United Kingdom\n",
        );
    }

    #[tokio::test]
    async fn test_csv_loader() {
        // In-memory CSV fixture: a header row followed by two records.
        let input = "name,age,city,country
John Doe,25,New York,United States
Jane Smith,32,London,United Kingdom";

        let csv_loader = CsvLoader::new(input.as_bytes(), all_columns());
        let documents = collect_documents(csv_loader).await;

        assert_eq!(documents.len(), 2);
        assert_first_two_rows(&documents);
    }

    #[tokio::test]
    async fn test_csv_load_from_path() {
        let path = "./src/document_loaders/test_data/test.csv";
        let csv_loader =
            CsvLoader::from_path(path, all_columns()).expect("Failed to create csv loader");
        let documents = collect_documents(csv_loader).await;

        assert_eq!(documents.len(), 20);
        assert_first_two_rows(&documents);
    }
}