synth_claw/datasets/
local.rs1use polars::prelude::*;
2use std::fs::File;
3use std::io::{BufRead, BufReader};
4use std::path::PathBuf;
5
6use super::{DataSource, DatasetInfo, Record};
7use crate::config::FileFormat;
8use crate::{Error, Result};
9
/// A `DataSource` backed by a single file on the local filesystem.
pub struct LocalSource {
    // Path to the dataset file; existence is validated in `new`.
    path: PathBuf,
    // Declared on construction; selects the parser used by `load`.
    format: FileFormat,
    // Schema (column names) and row count detected once in `new`.
    info: DatasetInfo,
}
15
16impl LocalSource {
17 pub fn new(path: PathBuf, format: FileFormat) -> Result<Self> {
18 if !path.exists() {
19 return Err(Error::Dataset(format!("File not found: {:?}", path)));
20 }
21
22 let info = Self::detect_info(&path, &format)?;
23
24 Ok(Self { path, format, info })
25 }
26
27 fn detect_info(path: &PathBuf, format: &FileFormat) -> Result<DatasetInfo> {
28 let (columns, num_rows) = match format {
29 FileFormat::Jsonl => Self::detect_jsonl_info(path)?,
30 FileFormat::Json => Self::detect_json_info(path)?,
31 FileFormat::Csv => Self::detect_csv_info(path)?,
32 FileFormat::Parquet => Self::detect_parquet_info(path)?,
33 };
34
35 Ok(DatasetInfo {
36 name: path
37 .file_name()
38 .and_then(|n| n.to_str())
39 .unwrap_or("local")
40 .to_string(),
41 description: None,
42 num_rows,
43 columns,
44 splits: vec![],
45 })
46 }
47
48 fn detect_jsonl_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
49 let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
50 let reader = BufReader::new(file);
51 let mut columns = Vec::new();
52 let mut num_rows = 0;
53
54 for (i, line) in reader.lines().enumerate() {
55 let line = line.map_err(|e| Error::Dataset(e.to_string()))?;
56 if line.trim().is_empty() {
57 continue;
58 }
59 num_rows += 1;
60
61 if i == 0 {
62 let obj: serde_json::Value =
63 serde_json::from_str(&line).map_err(|e| Error::Dataset(e.to_string()))?;
64 if let Some(map) = obj.as_object() {
65 columns = map.keys().cloned().collect();
66 }
67 }
68 }
69
70 Ok((columns, num_rows))
71 }
72
73 fn detect_json_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
74 let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
75 let data: serde_json::Value =
76 serde_json::from_reader(file).map_err(|e| Error::Dataset(e.to_string()))?;
77
78 match data {
79 serde_json::Value::Array(arr) => {
80 let num_rows = arr.len();
81 let columns = arr
82 .first()
83 .and_then(|v| v.as_object())
84 .map(|m| m.keys().cloned().collect())
85 .unwrap_or_default();
86 Ok((columns, num_rows))
87 }
88 _ => Err(Error::Dataset("JSON file must contain an array".into())),
89 }
90 }
91
92 fn detect_csv_info(path: &std::path::Path) -> Result<(Vec<String>, usize)> {
93 let df = CsvReadOptions::default()
94 .with_has_header(true)
95 .try_into_reader_with_file_path(Some(path.to_path_buf()))
96 .map_err(|e| Error::Dataset(e.to_string()))?
97 .finish()
98 .map_err(|e| Error::Dataset(e.to_string()))?;
99
100 let columns: Vec<String> = df
101 .get_column_names()
102 .iter()
103 .map(|s| s.to_string())
104 .collect();
105 Ok((columns, df.height()))
106 }
107
108 fn detect_parquet_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
109 let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
110 let df = ParquetReader::new(file)
111 .finish()
112 .map_err(|e| Error::Dataset(e.to_string()))?;
113
114 let columns: Vec<String> = df
115 .get_column_names()
116 .iter()
117 .map(|s| s.to_string())
118 .collect();
119 Ok((columns, df.height()))
120 }
121
122 fn load_jsonl(&self, sample: Option<usize>) -> Result<Vec<Record>> {
123 let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
124 let reader = BufReader::new(file);
125 let mut records = Vec::new();
126
127 for (i, line) in reader.lines().enumerate() {
128 if sample.is_some_and(|n| records.len() >= n) {
129 break;
130 }
131
132 let line = line.map_err(|e| Error::Dataset(e.to_string()))?;
133 if line.trim().is_empty() {
134 continue;
135 }
136
137 let data: serde_json::Value =
138 serde_json::from_str(&line).map_err(|e| Error::Dataset(e.to_string()))?;
139 records.push(Record { data, index: i });
140 }
141
142 Ok(records)
143 }
144
145 fn load_json(&self, sample: Option<usize>) -> Result<Vec<Record>> {
146 let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
147 let data: serde_json::Value =
148 serde_json::from_reader(file).map_err(|e| Error::Dataset(e.to_string()))?;
149
150 match data {
151 serde_json::Value::Array(arr) => {
152 let limit = sample.unwrap_or(arr.len());
153 Ok(arr
154 .into_iter()
155 .take(limit)
156 .enumerate()
157 .map(|(i, data)| Record { data, index: i })
158 .collect())
159 }
160 _ => Err(Error::Dataset("JSON file must contain an array".into())),
161 }
162 }
163
164 fn load_csv(&self, sample: Option<usize>) -> Result<Vec<Record>> {
165 let mut df = CsvReadOptions::default()
166 .with_has_header(true)
167 .try_into_reader_with_file_path(Some(self.path.clone()))
168 .map_err(|e| Error::Dataset(e.to_string()))?
169 .finish()
170 .map_err(|e| Error::Dataset(e.to_string()))?;
171
172 if let Some(n) = sample {
173 df = df.head(Some(n));
174 }
175
176 dataframe_to_records(df)
177 }
178
179 fn load_parquet(&self, sample: Option<usize>) -> Result<Vec<Record>> {
180 let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
181 let mut df = ParquetReader::new(file)
182 .finish()
183 .map_err(|e| Error::Dataset(e.to_string()))?;
184
185 if let Some(n) = sample {
186 df = df.head(Some(n));
187 }
188
189 dataframe_to_records(df)
190 }
191}
192
193impl DataSource for LocalSource {
194 fn info(&self) -> &DatasetInfo {
195 &self.info
196 }
197
198 fn load(&mut self, sample: Option<usize>) -> Result<Vec<Record>> {
199 match self.format {
200 FileFormat::Jsonl => self.load_jsonl(sample),
201 FileFormat::Json => self.load_json(sample),
202 FileFormat::Csv => self.load_csv(sample),
203 FileFormat::Parquet => self.load_parquet(sample),
204 }
205 }
206}
207
208fn dataframe_to_records(df: DataFrame) -> Result<Vec<Record>> {
209 let mut records = Vec::with_capacity(df.height());
210
211 for i in 0..df.height() {
212 let row = df
213 .get(i)
214 .ok_or_else(|| Error::Dataset("Row not found".into()))?;
215 let mut map = serde_json::Map::new();
216
217 for (col_name, value) in df.get_column_names().iter().zip(row.iter()) {
218 let json_value = anyvalue_to_json(value);
219 map.insert(col_name.to_string(), json_value);
220 }
221
222 records.push(Record {
223 data: serde_json::Value::Object(map),
224 index: i,
225 });
226 }
227
228 Ok(records)
229}
230
231fn anyvalue_to_json(value: &AnyValue) -> serde_json::Value {
232 match value {
233 AnyValue::Null => serde_json::Value::Null,
234 AnyValue::Boolean(b) => serde_json::Value::Bool(*b),
235 AnyValue::String(s) => serde_json::Value::String(s.to_string()),
236 AnyValue::StringOwned(s) => serde_json::Value::String(s.to_string()),
237 AnyValue::Float32(n) => serde_json::Number::from_f64(*n as f64)
238 .map(serde_json::Value::Number)
239 .unwrap_or(serde_json::Value::Null),
240 AnyValue::Float64(n) => serde_json::Number::from_f64(*n)
241 .map(serde_json::Value::Number)
242 .unwrap_or(serde_json::Value::Null),
243 other => serde_json::Value::String(format!("{}", other)),
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Sampling a JSONL file returns the first `n` records in order.
    #[test]
    fn test_load_jsonl() {
        let mut tmp = NamedTempFile::new().unwrap();
        writeln!(tmp, r#"{{"text": "hello", "label": 1}}"#).unwrap();
        writeln!(tmp, r#"{{"text": "world", "label": 0}}"#).unwrap();
        writeln!(tmp, r#"{{"text": "test", "label": 1}}"#).unwrap();

        let mut src = LocalSource::new(tmp.path().to_path_buf(), FileFormat::Jsonl).unwrap();
        let rows = src.load(Some(2)).unwrap();

        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].data["text"], "hello");
        assert_eq!(rows[1].data["text"], "world");
    }

    /// A JSON array file loads every element when no sample cap is given.
    #[test]
    fn test_load_json() {
        let mut tmp = NamedTempFile::new().unwrap();
        write!(
            tmp,
            r#"[{{"text": "a", "n": 1}}, {{"text": "b", "n": 2}}]"#
        )
        .unwrap();

        let mut src = LocalSource::new(tmp.path().to_path_buf(), FileFormat::Json).unwrap();
        let rows = src.load(None).unwrap();

        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].data["text"], "a");
    }

    /// Construction detects row count and column names from the file.
    #[test]
    fn test_local_source_info() {
        let mut tmp = NamedTempFile::new().unwrap();
        writeln!(tmp, r#"{{"col1": "val1", "col2": 123}}"#).unwrap();
        writeln!(tmp, r#"{{"col1": "val2", "col2": 456}}"#).unwrap();

        let src = LocalSource::new(tmp.path().to_path_buf(), FileFormat::Jsonl).unwrap();
        let meta = src.info();

        assert_eq!(meta.num_rows, 2);
        assert!(meta.columns.contains(&"col1".to_string()));
        assert!(meta.columns.contains(&"col2".to_string()));
    }
}