// synthclaw 0.1.3
//
// Lightweight synthetic data generation library/CLI.
// Documentation
use polars::prelude::*;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;

use super::{DataSource, DatasetInfo, Record};
use crate::config::FileFormat;
use crate::{Error, Result};

/// A dataset backed by a single file on the local filesystem.
///
/// Construct it with [`LocalSource::new`], which inspects the file once and
/// caches the resulting metadata.
pub struct LocalSource {
    /// Location of the data file.
    path: PathBuf,
    /// Format used to decide how the file is inspected and loaded.
    format: FileFormat,
    /// Metadata (name, column names, row count) computed in `new`.
    info: DatasetInfo,
}

impl LocalSource {
    pub fn new(path: PathBuf, format: FileFormat) -> Result<Self> {
        if !path.exists() {
            return Err(Error::Dataset(format!("File not found: {:?}", path)));
        }

        let info = Self::detect_info(&path, &format)?;

        Ok(Self { path, format, info })
    }

    /// Builds the `DatasetInfo` for `path` by delegating to the
    /// format-specific inspector.
    ///
    /// The dataset name is the file name component of `path`, falling back to
    /// `"local"` when there is none or it is not valid UTF-8. `description`
    /// is left unset and `splits` is empty.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the format-specific inspector returns.
    fn detect_info(path: &PathBuf, format: &FileFormat) -> Result<DatasetInfo> {
        let detected = match format {
            FileFormat::Jsonl => Self::detect_jsonl_info(path),
            FileFormat::Json => Self::detect_json_info(path),
            FileFormat::Csv => Self::detect_csv_info(path),
            FileFormat::Parquet => Self::detect_parquet_info(path),
        };
        let (columns, num_rows) = detected?;

        let name = match path.file_name().and_then(|n| n.to_str()) {
            Some(file_name) => file_name.to_string(),
            None => "local".to_string(),
        };

        Ok(DatasetInfo {
            name,
            description: None,
            num_rows,
            columns,
            splits: Vec::new(),
        })
    }

    /// Scans a JSON Lines file and returns `(columns, num_rows)`.
    ///
    /// Whitespace-only lines are skipped and do not count as rows. Column
    /// names are the keys of the first non-blank line when it parses to a
    /// JSON object; if it parses to anything else the column list is empty.
    /// Only that first record is parsed; later lines are counted without
    /// being validated.
    ///
    /// # Errors
    ///
    /// Returns `Error::Dataset` if the file cannot be opened or read, or if
    /// the first non-blank line is not valid JSON.
    fn detect_jsonl_info(path: &std::path::Path) -> Result<(Vec<String>, usize)> {
        let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
        let reader = BufReader::new(file);
        let mut columns = Vec::new();
        let mut num_rows = 0;

        for line in reader.lines() {
            let line = line.map_err(|e| Error::Dataset(e.to_string()))?;
            if line.trim().is_empty() {
                continue;
            }
            num_rows += 1;

            // Key off the record count rather than the physical line index so
            // that leading blank lines do not suppress column detection.
            if num_rows == 1 {
                let obj: serde_json::Value =
                    serde_json::from_str(&line).map_err(|e| Error::Dataset(e.to_string()))?;
                if let Some(map) = obj.as_object() {
                    columns = map.keys().cloned().collect();
                }
            }
        }

        Ok((columns, num_rows))
    }

    fn detect_json_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
        let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
        let data: serde_json::Value =
            serde_json::from_reader(file).map_err(|e| Error::Dataset(e.to_string()))?;

        match data {
            serde_json::Value::Array(arr) => {
                let num_rows = arr.len();
                let columns = arr
                    .first()
                    .and_then(|v| v.as_object())
                    .map(|m| m.keys().cloned().collect())
                    .unwrap_or_default();
                Ok((columns, num_rows))
            }
            _ => Err(Error::Dataset("JSON file must contain an array".into())),
        }
    }

    /// Reads a CSV file (first line treated as the header) into a `DataFrame`
    /// and returns `(columns, num_rows)` taken from that frame.
    ///
    /// # Errors
    ///
    /// Returns `Error::Dataset` if the reader cannot be created or the file
    /// cannot be parsed.
    fn detect_csv_info(path: &std::path::Path) -> Result<(Vec<String>, usize)> {
        let reader = CsvReadOptions::default()
            .with_has_header(true)
            .try_into_reader_with_file_path(Some(path.to_path_buf()))
            .map_err(|e| Error::Dataset(e.to_string()))?;
        let frame = reader
            .finish()
            .map_err(|e| Error::Dataset(e.to_string()))?;

        let mut columns = Vec::with_capacity(frame.width());
        for name in frame.get_column_names() {
            columns.push(name.to_string());
        }

        Ok((columns, frame.height()))
    }

    /// Reads a Parquet file into a `DataFrame` and returns
    /// `(columns, num_rows)` taken from that frame.
    ///
    /// # Errors
    ///
    /// Returns `Error::Dataset` if the file cannot be opened or decoded.
    fn detect_parquet_info(path: &PathBuf) -> Result<(Vec<String>, usize)> {
        let file = File::open(path).map_err(|e| Error::Dataset(e.to_string()))?;
        let reader = ParquetReader::new(file);
        let frame = reader
            .finish()
            .map_err(|e| Error::Dataset(e.to_string()))?;

        let mut columns = Vec::with_capacity(frame.width());
        for name in frame.get_column_names() {
            columns.push(name.to_string());
        }

        Ok((columns, frame.height()))
    }

    fn load_jsonl(&self, sample: Option<usize>) -> Result<Vec<Record>> {
        let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
        let reader = BufReader::new(file);
        let mut records = Vec::new();

        for (i, line) in reader.lines().enumerate() {
            if sample.is_some_and(|n| records.len() >= n) {
                break;
            }

            let line = line.map_err(|e| Error::Dataset(e.to_string()))?;
            if line.trim().is_empty() {
                continue;
            }

            let data: serde_json::Value =
                serde_json::from_str(&line).map_err(|e| Error::Dataset(e.to_string()))?;
            records.push(Record { data, index: i });
        }

        Ok(records)
    }

    fn load_json(&self, sample: Option<usize>) -> Result<Vec<Record>> {
        let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
        let data: serde_json::Value =
            serde_json::from_reader(file).map_err(|e| Error::Dataset(e.to_string()))?;

        match data {
            serde_json::Value::Array(arr) => {
                let limit = sample.unwrap_or(arr.len());
                Ok(arr
                    .into_iter()
                    .take(limit)
                    .enumerate()
                    .map(|(i, data)| Record { data, index: i })
                    .collect())
            }
            _ => Err(Error::Dataset("JSON file must contain an array".into())),
        }
    }

    /// Loads records from the CSV file (first line treated as the header).
    ///
    /// The file is read into a `DataFrame` first; when `sample` is `Some(n)`
    /// the frame is then cut down to its first `n` rows before conversion.
    ///
    /// # Errors
    ///
    /// Returns `Error::Dataset` if the reader cannot be created, the file
    /// cannot be parsed, or row conversion fails.
    fn load_csv(&self, sample: Option<usize>) -> Result<Vec<Record>> {
        let reader = CsvReadOptions::default()
            .with_has_header(true)
            .try_into_reader_with_file_path(Some(self.path.clone()))
            .map_err(|e| Error::Dataset(e.to_string()))?;
        let frame = reader
            .finish()
            .map_err(|e| Error::Dataset(e.to_string()))?;

        let frame = match sample {
            Some(n) => frame.head(Some(n)),
            None => frame,
        };

        dataframe_to_records(frame)
    }

    /// Loads records from the Parquet file.
    ///
    /// The file is read into a `DataFrame` first; when `sample` is `Some(n)`
    /// the frame is then cut down to its first `n` rows before conversion.
    ///
    /// # Errors
    ///
    /// Returns `Error::Dataset` if the file cannot be opened or decoded, or
    /// if row conversion fails.
    fn load_parquet(&self, sample: Option<usize>) -> Result<Vec<Record>> {
        let file = File::open(&self.path).map_err(|e| Error::Dataset(e.to_string()))?;
        let frame = ParquetReader::new(file)
            .finish()
            .map_err(|e| Error::Dataset(e.to_string()))?;

        let frame = match sample {
            Some(n) => frame.head(Some(n)),
            None => frame,
        };

        dataframe_to_records(frame)
    }
}

impl DataSource for LocalSource {
    /// Returns the metadata computed when the source was constructed.
    fn info(&self) -> &DatasetInfo {
        &self.info
    }

    /// Loads records from the backing file using the loader that matches the
    /// configured format.
    ///
    /// `sample`, when `Some(n)`, limits the result to at most `n` records.
    fn load(&mut self, sample: Option<usize>) -> Result<Vec<Record>> {
        // Pick the loader first, then invoke it once.
        let loader: fn(&Self, Option<usize>) -> Result<Vec<Record>> = match self.format {
            FileFormat::Jsonl => Self::load_jsonl,
            FileFormat::Json => Self::load_json,
            FileFormat::Csv => Self::load_csv,
            FileFormat::Parquet => Self::load_parquet,
        };

        loader(self, sample)
    }
}

/// Converts every row of `df` into a `Record` whose `data` is a JSON object
/// keyed by column name and whose `index` is the row position in the frame.
///
/// Cell values are converted with `anyvalue_to_json`.
///
/// # Errors
///
/// Returns `Error::Dataset("Row not found")` if the frame fails to yield a
/// row for an index below its reported height.
fn dataframe_to_records(df: DataFrame) -> Result<Vec<Record>> {
    // The column names are the same for every row; fetch them once instead of
    // rebuilding the list on each iteration.
    let column_names = df.get_column_names();
    let mut records = Vec::with_capacity(df.height());

    for i in 0..df.height() {
        let row = df
            .get(i)
            .ok_or_else(|| Error::Dataset("Row not found".into()))?;

        let map: serde_json::Map<String, serde_json::Value> = column_names
            .iter()
            .zip(row.iter())
            .map(|(col_name, value)| (col_name.to_string(), anyvalue_to_json(value)))
            .collect();

        records.push(Record {
            data: serde_json::Value::Object(map),
            index: i,
        });
    }

    Ok(records)
}

fn anyvalue_to_json(value: &AnyValue) -> serde_json::Value {
    match value {
        AnyValue::Null => serde_json::Value::Null,
        AnyValue::Boolean(b) => serde_json::Value::Bool(*b),
        AnyValue::String(s) => serde_json::Value::String(s.to_string()),
        AnyValue::StringOwned(s) => serde_json::Value::String(s.to_string()),
        AnyValue::Float32(n) => serde_json::Number::from_f64(*n as f64)
            .map(serde_json::Value::Number)
            .unwrap_or(serde_json::Value::Null),
        AnyValue::Float64(n) => serde_json::Number::from_f64(*n)
            .map(serde_json::Value::Number)
            .unwrap_or(serde_json::Value::Null),
        other => serde_json::Value::String(format!("{}", other)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Loading a JSONL file with a sample size returns only that many
    /// records, in file order.
    #[test]
    fn test_load_jsonl() {
        let mut tmp = NamedTempFile::new().unwrap();
        writeln!(tmp, r#"{{"text": "hello", "label": 1}}"#).unwrap();
        writeln!(tmp, r#"{{"text": "world", "label": 0}}"#).unwrap();
        writeln!(tmp, r#"{{"text": "test", "label": 1}}"#).unwrap();

        let path = tmp.path().to_path_buf();
        let mut src = LocalSource::new(path, FileFormat::Jsonl).unwrap();
        let loaded = src.load(Some(2)).unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].data["text"], "hello");
        assert_eq!(loaded[1].data["text"], "world");
    }

    /// Loading a JSON array without a sample size returns every element.
    #[test]
    fn test_load_json() {
        let mut tmp = NamedTempFile::new().unwrap();
        write!(
            tmp,
            r#"[{{"text": "a", "n": 1}}, {{"text": "b", "n": 2}}]"#
        )
        .unwrap();

        let path = tmp.path().to_path_buf();
        let mut src = LocalSource::new(path, FileFormat::Json).unwrap();
        let loaded = src.load(None).unwrap();

        assert_eq!(loaded.len(), 2);
        assert_eq!(loaded[0].data["text"], "a");
    }

    /// Constructing a source populates the row count and column names.
    #[test]
    fn test_local_source_info() {
        let mut tmp = NamedTempFile::new().unwrap();
        writeln!(tmp, r#"{{"col1": "val1", "col2": 123}}"#).unwrap();
        writeln!(tmp, r#"{{"col1": "val2", "col2": 456}}"#).unwrap();

        let path = tmp.path().to_path_buf();
        let src = LocalSource::new(path, FileFormat::Jsonl).unwrap();
        let meta = src.info();

        assert_eq!(meta.num_rows, 2);
        assert!(meta.columns.contains(&"col1".to_string()));
        assert!(meta.columns.contains(&"col2".to_string()));
    }
}