langchain_rust/document_loaders/csv_loader/
csv_loader.rs1use crate::document_loaders::{process_doc_stream, LoaderError};
2use crate::{document_loaders::Loader, schemas::Document, text_splitter::TextSplitter};
3use async_stream::stream;
4use async_trait::async_trait;
5use csv;
6use futures::Stream;
7use serde_json::Value;
8
9use std::collections::HashMap;
10use std::fs::File;
11use std::io::{BufReader, Cursor, Read};
12use std::path::Path;
13use std::pin::Pin;
/// Loader that turns each record of a CSV byte source into a `Document`.
///
/// Only fields whose header name is listed in `columns` contribute to a document's text.
#[derive(Clone, Debug)]
pub struct CsvLoader<R> {
    /// Byte source the CSV data is read from.
    reader: R,
    /// Header names to keep; fields under any other header are skipped.
    columns: Vec<String>,
}

impl<R: Read> CsvLoader<R> {
    /// Creates a loader that reads CSV data from `reader`, keeping only the listed `columns`.
    pub fn new(reader: R, columns: Vec<String>) -> Self {
        CsvLoader { columns, reader }
    }
}

impl CsvLoader<Cursor<Vec<u8>>> {
    /// Creates a loader over CSV text held in memory.
    ///
    /// The text is turned into an owned byte buffer and wrapped in a `Cursor` starting at
    /// position 0.
    pub fn from_string<S: Into<String>>(input: S, columns: Vec<String>) -> Self {
        let text: String = input.into();
        let cursor = Cursor::new(text.into_bytes());
        CsvLoader::new(cursor, columns)
    }
}

35impl CsvLoader<BufReader<File>> {
36 pub fn from_path<P: AsRef<Path>>(path: P, columns: Vec<String>) -> Result<Self, LoaderError> {
37 let file = File::open(path)?;
38 let reader = BufReader::new(file);
39 Ok(Self::new(reader, columns))
40 }
41}
42
43#[async_trait]
44impl<R: Read + Send + Sync + 'static> Loader for CsvLoader<R> {
45 async fn load(
46 mut self,
47 ) -> Result<
48 Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>,
49 LoaderError,
50 > {
51 let mut reader = csv::Reader::from_reader(self.reader);
52 let headers = reader.headers()?.clone();
53
54 let mut row_number: i64 = 0;
56 let columns = self.columns.clone();
57
58 let stream = stream! {
59 for result in reader.records() {
60 let record = result?;
61 let mut content = String::new();
62
63 for (i, field) in record.iter().enumerate() {
64 let header = &headers[i];
65 if !columns.contains(&header.to_string()) {
66 continue;
67 }
68
69 let line = format!("{}: {}", header, field);
70 content.push_str(&line);
71 content.push('\n');
72 }
73
74 row_number += 1; let mut document = Document::new(content);
78 let mut metadata = HashMap::new();
79 metadata.insert("row".to_string(), Value::from(row_number));
80
81 document.metadata = metadata;
83
84 yield Ok(document);
85 }
86 };
87
88 Ok(Box::pin(stream))
89 }
90
91 async fn load_and_split<TS: TextSplitter + 'static>(
92 mut self,
93 splitter: TS,
94 ) -> Result<
95 Pin<Box<dyn Stream<Item = Result<Document, LoaderError>> + Send + 'static>>,
96 LoaderError,
97 > {
98 let doc_stream = self.load().await?;
99 let stream = process_doc_stream(doc_stream, splitter).await;
100 Ok(Box::pin(stream))
101 }
102}
103
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use super::*;

    /// Header names present in both fixtures; every column is selected.
    fn all_columns() -> Vec<String> {
        ["name", "age", "city", "country"]
            .iter()
            .map(|column| column.to_string())
            .collect()
    }

    /// Runs `loader` to completion, unwrapping the stream and every item it yields.
    async fn load_all<R: Read + Send + Sync + 'static>(loader: CsvLoader<R>) -> Vec<Document> {
        loader
            .load()
            .await
            .unwrap()
            .map(|item| item.unwrap())
            .collect::<Vec<_>>()
            .await
    }

    /// Asserts the first two documents are the John Doe and Jane Smith rows, numbered 1 and 2.
    fn assert_leading_rows(documents: &[Document]) {
        let expected = [
            "name: John Doe\nage: 25\ncity: New York\ncountry: United States\n",
            "name: Jane Smith\nage: 32\ncity: London\ncountry: United Kingdom\n",
        ];
        for (index, text) in expected.iter().enumerate() {
            let document = &documents[index];
            let row = Value::from(index as i64 + 1);
            assert_eq!(document.metadata.get("row").unwrap(), &row);
            assert_eq!(document.page_content, *text);
        }
    }

    #[tokio::test]
    async fn test_csv_loader() {
        let input = "name,age,city,country
John Doe,25,New York,United States
Jane Smith,32,London,United Kingdom";

        let documents = load_all(CsvLoader::new(input.as_bytes(), all_columns())).await;

        assert_eq!(documents.len(), 2);
        assert_leading_rows(&documents);
    }

    #[tokio::test]
    async fn test_csv_load_from_path() {
        let path = "./src/document_loaders/test_data/test.csv";
        let loader =
            CsvLoader::from_path(path, all_columns()).expect("Failed to create csv loader");

        let documents = load_all(loader).await;

        assert_eq!(documents.len(), 20);
        assert_leading_rows(&documents);
    }
}