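//! synthclaw 0.1.3
//!
//! Lightweight synthetic data generation library/CLI.
//!
//! This module provides the output writers (JSONL, CSV, Parquet) used to persist generation results.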
use crate::generation::GenerationResult;
use crate::{Error, Result};
use polars::prelude::*;
use serde_json::json;
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Write};
use std::path::PathBuf;

/// Sink that persists [`GenerationResult`]s to a file.
///
/// `create_writer` picks an implementation from an `OutputFormat`. The trait
/// requires `Send`, so a boxed writer can be moved to another thread.
pub trait OutputWriter: Send {
    /// Records a single generation result.
    ///
    /// Depending on the implementation, the data may only be buffered in
    /// memory until [`OutputWriter::flush`] is called.
    fn write(&mut self, result: &GenerationResult) -> Result<()>;
    /// Pushes buffered output to the destination file.
    ///
    /// For the Parquet writer, this is the point where the file is actually
    /// written.
    fn flush(&mut self) -> Result<()>;
}

/// Writes each [`GenerationResult`] as one JSON object per line (JSON Lines).
///
/// Output goes through a `BufWriter`, so lines may stay in memory until
/// [`OutputWriter::flush`] is called.
pub struct JsonlWriter {
    writer: BufWriter<File>,
}

impl JsonlWriter {
    pub fn new(path: PathBuf) -> Result<Self> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let file = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)?;
        Ok(Self {
            writer: BufWriter::new(file),
        })
    }
}

impl OutputWriter for JsonlWriter {
    fn write(&mut self, result: &GenerationResult) -> Result<()> {
        let mut obj = json!({
            "content": result.content,
        });

        if let Some(idx) = result.source_index {
            obj["source_index"] = json!(idx);
        }
        if let Some(cat) = &result.category {
            obj["category"] = json!(cat);
        }

        obj["input_tokens"] = json!(result.input_tokens);
        obj["output_tokens"] = json!(result.output_tokens);

        writeln!(self.writer, "{}", serde_json::to_string(&obj)?)?;
        Ok(())
    }

    fn flush(&mut self) -> Result<()> {
        self.writer.flush()?;
        Ok(())
    }
}

/// Writes each [`GenerationResult`] as a CSV row.
///
/// The header row `content,source_index,category,input_tokens,output_tokens`
/// is written when the writer is created. Unset optional fields become
/// empty cells.
pub struct CsvWriter {
    writer: csv::Writer<File>,
}

impl CsvWriter {
    pub fn new(path: PathBuf) -> Result<Self> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let file = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(path)?;
        let mut writer = csv::Writer::from_writer(file);
        writer.write_record([
            "content",
            "source_index",
            "category",
            "input_tokens",
            "output_tokens",
        ])?;
        Ok(Self { writer })
    }
}

impl OutputWriter for CsvWriter {
    fn write(&mut self, result: &GenerationResult) -> Result<()> {
        self.writer.write_record([
            &result.content,
            &result
                .source_index
                .map(|i| i.to_string())
                .unwrap_or_default(),
            &result.category.clone().unwrap_or_default(),
            &result.input_tokens.to_string(),
            &result.output_tokens.to_string(),
        ])?;
        Ok(())
    }

    fn flush(&mut self) -> Result<()> {
        self.writer.flush()?;
        Ok(())
    }
}

/// Collects [`GenerationResult`]s in memory and writes them to a single
/// Parquet file when [`OutputWriter::flush`] is called.
pub struct ParquetWriter {
    // Every result written so far. This is never cleared, so each flush
    // rewrites the complete file rather than appending to it.
    results: Vec<GenerationResult>,
    // Destination file, created (or truncated) on each non-empty flush.
    path: PathBuf,
}

impl ParquetWriter {
    /// Prepares a Parquet writer targeting `path`.
    ///
    /// Only the parent directories are created here. The file itself is not
    /// opened until [`OutputWriter::flush`] has rows to write.
    ///
    /// # Errors
    /// Propagates I/O errors from directory creation.
    pub fn new(path: PathBuf) -> Result<Self> {
        path.parent().map(std::fs::create_dir_all).transpose()?;
        Ok(Self {
            path,
            results: Vec::new(),
        })
    }
}

impl OutputWriter for ParquetWriter {
    fn write(&mut self, result: &GenerationResult) -> Result<()> {
        self.results.push(GenerationResult {
            content: result.content.clone(),
            source_index: result.source_index,
            category: result.category.clone(),
            input_tokens: result.input_tokens,
            output_tokens: result.output_tokens,
        });
        Ok(())
    }

    fn flush(&mut self) -> Result<()> {
        if self.results.is_empty() {
            return Ok(());
        }

        let content: Vec<String> = self.results.iter().map(|r| r.content.clone()).collect();
        let source_index: Vec<Option<i64>> = self
            .results
            .iter()
            .map(|r| r.source_index.map(|i| i as i64))
            .collect();
        let category: Vec<Option<String>> =
            self.results.iter().map(|r| r.category.clone()).collect();
        let input_tokens: Vec<u32> = self.results.iter().map(|r| r.input_tokens).collect();
        let output_tokens: Vec<u32> = self.results.iter().map(|r| r.output_tokens).collect();

        let df = DataFrame::new(vec![
            Series::new("content".into(), content).into(),
            Series::new("source_index".into(), source_index).into(),
            Series::new("category".into(), category).into(),
            Series::new("input_tokens".into(), input_tokens).into(),
            Series::new("output_tokens".into(), output_tokens).into(),
        ])
        .map_err(|e| Error::Dataset(e.to_string()))?;

        let file = File::create(&self.path)?;
        ParquetWriter::write_parquet(df, file)?;

        Ok(())
    }
}

impl ParquetWriter {
    /// Serializes `df` as Parquet into `file`.
    ///
    /// The frame is taken by value and bound `mut`, so it can be passed to
    /// polars' writer (whose `finish` takes `&mut DataFrame`) without the
    /// previous redundant clone.
    ///
    /// # Errors
    /// Returns `Error::Dataset` if polars fails to encode or write the frame.
    fn write_parquet(mut df: DataFrame, file: File) -> Result<()> {
        polars::prelude::ParquetWriter::new(file)
            .finish(&mut df)
            .map_err(|e| Error::Dataset(e.to_string()))?;
        Ok(())
    }
}

pub fn create_writer(
    format: &crate::config::OutputFormat,
    path: PathBuf,
) -> Result<Box<dyn OutputWriter>> {
    match format {
        crate::config::OutputFormat::Jsonl => Ok(Box::new(JsonlWriter::new(path)?)),
        crate::config::OutputFormat::Csv => Ok(Box::new(CsvWriter::new(path)?)),
        crate::config::OutputFormat::Parquet => Ok(Box::new(ParquetWriter::new(path)?)),
        crate::config::OutputFormat::Json => Err(Error::Config(
            "JSON array output not yet implemented, use JSONL instead".to_string(),
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// A result with every optional field set.
    fn sample_result() -> GenerationResult {
        GenerationResult {
            content: "Test content".to_string(),
            source_index: Some(42),
            category: Some("test".to_string()),
            input_tokens: 10,
            output_tokens: 20,
        }
    }

    /// A result with both optional fields unset.
    fn bare_result() -> GenerationResult {
        GenerationResult {
            content: "Bare".to_string(),
            source_index: None,
            category: None,
            input_tokens: 1,
            output_tokens: 2,
        }
    }

    #[test]
    fn test_jsonl_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();
        // Pre-existing content must be truncated away by `JsonlWriter::new`.
        std::fs::write(&path, "stale data that is not json\n").unwrap();

        let mut writer = JsonlWriter::new(path.clone()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.write(&bare_result()).unwrap();
        writer.flush().unwrap();
        drop(writer);

        let content = std::fs::read_to_string(path).unwrap();
        let lines: Vec<serde_json::Value> = content
            .lines()
            .map(|line| serde_json::from_str(line).unwrap())
            .collect();
        assert_eq!(lines.len(), 2);
        assert_eq!(
            lines[0],
            json!({
                "content": "Test content",
                "source_index": 42,
                "category": "test",
                "input_tokens": 10,
                "output_tokens": 20,
            })
        );
        // Unset optional fields are omitted, not written as null.
        assert_eq!(
            lines[1],
            json!({
                "content": "Bare",
                "input_tokens": 1,
                "output_tokens": 2,
            })
        );
    }

    #[test]
    fn test_csv_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = CsvWriter::new(path.clone()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.write(&bare_result()).unwrap();
        writer.flush().unwrap();
        drop(writer);

        let content = std::fs::read_to_string(path).unwrap();
        assert_eq!(
            content,
            "content,source_index,category,input_tokens,output_tokens\n\
             Test content,42,test,10,20\n\
             Bare,,,1,2\n"
        );
    }

    #[test]
    fn test_parquet_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = ParquetWriter::new(path.clone()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.write(&bare_result()).unwrap();
        writer.flush().unwrap();

        let file = File::open(path).unwrap();
        let df = polars::prelude::ParquetReader::new(file).finish().unwrap();
        assert_eq!(df.height(), 2);

        let expected = [
            "content",
            "source_index",
            "category",
            "input_tokens",
            "output_tokens",
        ];
        let names = df.get_column_names();
        assert_eq!(names.len(), expected.len());
        for (name, want) in names.iter().zip(expected) {
            assert_eq!(name.as_str(), want);
        }
    }

    #[test]
    fn test_parquet_writer_empty_flush_writes_nothing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nested").join("out.parquet");

        let mut writer = ParquetWriter::new(path.clone()).unwrap();
        writer.flush().unwrap();

        // `new` creates the parent directory, but an empty flush creates no file.
        assert!(path.parent().unwrap().is_dir());
        assert!(!path.exists());
    }

    #[test]
    fn test_writers_create_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let jsonl = dir.path().join("a").join("out.jsonl");
        let csv = dir.path().join("b").join("c").join("out.csv");

        JsonlWriter::new(jsonl.clone()).unwrap();
        CsvWriter::new(csv.clone()).unwrap();

        assert!(jsonl.is_file());
        assert!(csv.is_file());
    }

    #[test]
    fn test_create_writer_rejects_json_array() {
        let temp = NamedTempFile::new().unwrap();
        let result = create_writer(
            &crate::config::OutputFormat::Json,
            temp.path().to_path_buf(),
        );
        assert!(matches!(result, Err(Error::Config(_))));
    }
}