Skip to main content

synth_claw/output/mod.rs

1use crate::generation::GenerationResult;
2use crate::{Error, Result};
3use polars::prelude::*;
4use serde_json::json;
5use std::fs::{File, OpenOptions};
6use std::io::{BufWriter, Write};
7use std::path::PathBuf;
8
9pub trait OutputWriter: Send {
10    fn write(&mut self, result: &GenerationResult) -> Result<()>;
11    fn flush(&mut self) -> Result<()>;
12}
13
14pub struct JsonlWriter {
15    writer: BufWriter<File>,
16}
17
18impl JsonlWriter {
19    pub fn new(path: PathBuf) -> Result<Self> {
20        if let Some(parent) = path.parent() {
21            std::fs::create_dir_all(parent)?;
22        }
23        let file = OpenOptions::new()
24            .write(true)
25            .create(true)
26            .truncate(true)
27            .open(path)?;
28        Ok(Self {
29            writer: BufWriter::new(file),
30        })
31    }
32}
33
34impl OutputWriter for JsonlWriter {
35    fn write(&mut self, result: &GenerationResult) -> Result<()> {
36        let mut obj = json!({
37            "content": result.content,
38        });
39
40        if let Some(idx) = result.source_index {
41            obj["source_index"] = json!(idx);
42        }
43        if let Some(cat) = &result.category {
44            obj["category"] = json!(cat);
45        }
46
47        obj["input_tokens"] = json!(result.input_tokens);
48        obj["output_tokens"] = json!(result.output_tokens);
49
50        writeln!(self.writer, "{}", serde_json::to_string(&obj)?)?;
51        Ok(())
52    }
53
54    fn flush(&mut self) -> Result<()> {
55        self.writer.flush()?;
56        Ok(())
57    }
58}
59
60pub struct CsvWriter {
61    writer: csv::Writer<File>,
62}
63
64impl CsvWriter {
65    pub fn new(path: PathBuf) -> Result<Self> {
66        if let Some(parent) = path.parent() {
67            std::fs::create_dir_all(parent)?;
68        }
69        let file = OpenOptions::new()
70            .write(true)
71            .create(true)
72            .truncate(true)
73            .open(path)?;
74        let mut writer = csv::Writer::from_writer(file);
75        writer.write_record([
76            "content",
77            "source_index",
78            "category",
79            "input_tokens",
80            "output_tokens",
81        ])?;
82        Ok(Self { writer })
83    }
84}
85
86impl OutputWriter for CsvWriter {
87    fn write(&mut self, result: &GenerationResult) -> Result<()> {
88        self.writer.write_record([
89            &result.content,
90            &result
91                .source_index
92                .map(|i| i.to_string())
93                .unwrap_or_default(),
94            &result.category.clone().unwrap_or_default(),
95            &result.input_tokens.to_string(),
96            &result.output_tokens.to_string(),
97        ])?;
98        Ok(())
99    }
100
101    fn flush(&mut self) -> Result<()> {
102        self.writer.flush()?;
103        Ok(())
104    }
105}
106
107pub struct ParquetWriter {
108    results: Vec<GenerationResult>,
109    path: PathBuf,
110}
111
112impl ParquetWriter {
113    pub fn new(path: PathBuf) -> Result<Self> {
114        if let Some(parent) = path.parent() {
115            std::fs::create_dir_all(parent)?;
116        }
117        Ok(Self {
118            results: Vec::new(),
119            path,
120        })
121    }
122}
123
124impl OutputWriter for ParquetWriter {
125    fn write(&mut self, result: &GenerationResult) -> Result<()> {
126        self.results.push(GenerationResult {
127            content: result.content.clone(),
128            source_index: result.source_index,
129            category: result.category.clone(),
130            input_tokens: result.input_tokens,
131            output_tokens: result.output_tokens,
132        });
133        Ok(())
134    }
135
136    fn flush(&mut self) -> Result<()> {
137        if self.results.is_empty() {
138            return Ok(());
139        }
140
141        let content: Vec<String> = self.results.iter().map(|r| r.content.clone()).collect();
142        let source_index: Vec<Option<i64>> = self
143            .results
144            .iter()
145            .map(|r| r.source_index.map(|i| i as i64))
146            .collect();
147        let category: Vec<Option<String>> =
148            self.results.iter().map(|r| r.category.clone()).collect();
149        let input_tokens: Vec<u32> = self.results.iter().map(|r| r.input_tokens).collect();
150        let output_tokens: Vec<u32> = self.results.iter().map(|r| r.output_tokens).collect();
151
152        let df = DataFrame::new(vec![
153            Series::new("content".into(), content).into(),
154            Series::new("source_index".into(), source_index).into(),
155            Series::new("category".into(), category).into(),
156            Series::new("input_tokens".into(), input_tokens).into(),
157            Series::new("output_tokens".into(), output_tokens).into(),
158        ])
159        .map_err(|e| Error::Dataset(e.to_string()))?;
160
161        let file = File::create(&self.path)?;
162        ParquetWriter::write_parquet(df, file)?;
163
164        Ok(())
165    }
166}
167
168impl ParquetWriter {
169    fn write_parquet(df: DataFrame, file: File) -> Result<()> {
170        polars::prelude::ParquetWriter::new(file)
171            .finish(&mut df.clone())
172            .map_err(|e| Error::Dataset(e.to_string()))?;
173        Ok(())
174    }
175}
176
177pub fn create_writer(
178    format: &crate::config::OutputFormat,
179    path: PathBuf,
180) -> Result<Box<dyn OutputWriter>> {
181    match format {
182        crate::config::OutputFormat::Jsonl => Ok(Box::new(JsonlWriter::new(path)?)),
183        crate::config::OutputFormat::Csv => Ok(Box::new(CsvWriter::new(path)?)),
184        crate::config::OutputFormat::Parquet => Ok(Box::new(ParquetWriter::new(path)?)),
185        crate::config::OutputFormat::Json => Err(Error::Config(
186            "JSON array output not yet implemented, use JSONL instead".to_string(),
187        )),
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// A result with every optional field populated.
    fn sample_result() -> GenerationResult {
        GenerationResult {
            content: "Test content".to_string(),
            source_index: Some(42),
            category: Some("test".to_string()),
            input_tokens: 10,
            output_tokens: 20,
        }
    }

    /// A result with both optional fields absent.
    fn sparse_result() -> GenerationResult {
        GenerationResult {
            source_index: None,
            category: None,
            ..sample_result()
        }
    }

    #[test]
    fn test_jsonl_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = JsonlWriter::new(path.clone()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.flush().unwrap();
        drop(writer);

        let content = std::fs::read_to_string(path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 1);

        let record: serde_json::Value = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(record["content"], "Test content");
        assert_eq!(record["source_index"], 42);
        assert_eq!(record["category"], "test");
        assert_eq!(record["input_tokens"], 10);
        assert_eq!(record["output_tokens"], 20);
    }

    #[test]
    fn test_jsonl_writer_omits_missing_optional_fields() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = JsonlWriter::new(path.clone()).unwrap();
        writer.write(&sparse_result()).unwrap();
        writer.flush().unwrap();
        drop(writer);

        let content = std::fs::read_to_string(path).unwrap();
        let record: serde_json::Value = serde_json::from_str(content.trim_end()).unwrap();
        assert_eq!(record["content"], "Test content");
        assert!(record.get("source_index").is_none());
        assert!(record.get("category").is_none());
    }

    #[test]
    fn test_csv_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = CsvWriter::new(path.clone()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.flush().unwrap();
        drop(writer);

        let content = std::fs::read_to_string(path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(
            lines,
            [
                "content,source_index,category,input_tokens,output_tokens",
                "Test content,42,test,10,20",
            ]
        );
    }

    #[test]
    fn test_csv_writer_leaves_missing_optional_fields_empty() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = CsvWriter::new(path.clone()).unwrap();
        writer.write(&sparse_result()).unwrap();
        writer.flush().unwrap();
        drop(writer);

        let content = std::fs::read_to_string(path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 2);
        assert_eq!(lines[1], "Test content,,,10,20");
    }

    #[test]
    fn test_parquet_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = ParquetWriter::new(path.clone()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.write(&sample_result()).unwrap();
        writer.flush().unwrap();

        let file = File::open(path).unwrap();
        let df = polars::prelude::ParquetReader::new(file).finish().unwrap();
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 5);

        let names: Vec<String> = df
            .get_column_names()
            .iter()
            .map(|s| s.as_str().to_owned())
            .collect();
        assert_eq!(
            names,
            [
                "content",
                "source_index",
                "category",
                "input_tokens",
                "output_tokens",
            ]
        );
    }
}