1use crate::generation::GenerationResult;
2use crate::{Error, Result};
3use polars::prelude::*;
4use serde_json::json;
5use std::fs::{File, OpenOptions};
6use std::io::{BufWriter, Write};
7use std::path::PathBuf;
8
9pub trait OutputWriter: Send {
10 fn write(&mut self, result: &GenerationResult) -> Result<()>;
11 fn flush(&mut self) -> Result<()>;
12}
13
14pub struct JsonlWriter {
15 writer: BufWriter<File>,
16}
17
18impl JsonlWriter {
19 pub fn new(path: PathBuf) -> Result<Self> {
20 if let Some(parent) = path.parent() {
21 std::fs::create_dir_all(parent)?;
22 }
23 let file = OpenOptions::new()
24 .write(true)
25 .create(true)
26 .truncate(true)
27 .open(path)?;
28 Ok(Self {
29 writer: BufWriter::new(file),
30 })
31 }
32}
33
34impl OutputWriter for JsonlWriter {
35 fn write(&mut self, result: &GenerationResult) -> Result<()> {
36 let mut obj = json!({
37 "content": result.content,
38 });
39
40 if let Some(idx) = result.source_index {
41 obj["source_index"] = json!(idx);
42 }
43 if let Some(cat) = &result.category {
44 obj["category"] = json!(cat);
45 }
46
47 obj["input_tokens"] = json!(result.input_tokens);
48 obj["output_tokens"] = json!(result.output_tokens);
49
50 writeln!(self.writer, "{}", serde_json::to_string(&obj)?)?;
51 Ok(())
52 }
53
54 fn flush(&mut self) -> Result<()> {
55 self.writer.flush()?;
56 Ok(())
57 }
58}
59
60pub struct CsvWriter {
61 writer: csv::Writer<File>,
62}
63
64impl CsvWriter {
65 pub fn new(path: PathBuf) -> Result<Self> {
66 if let Some(parent) = path.parent() {
67 std::fs::create_dir_all(parent)?;
68 }
69 let file = OpenOptions::new()
70 .write(true)
71 .create(true)
72 .truncate(true)
73 .open(path)?;
74 let mut writer = csv::Writer::from_writer(file);
75 writer.write_record([
76 "content",
77 "source_index",
78 "category",
79 "input_tokens",
80 "output_tokens",
81 ])?;
82 Ok(Self { writer })
83 }
84}
85
86impl OutputWriter for CsvWriter {
87 fn write(&mut self, result: &GenerationResult) -> Result<()> {
88 self.writer.write_record([
89 &result.content,
90 &result
91 .source_index
92 .map(|i| i.to_string())
93 .unwrap_or_default(),
94 &result.category.clone().unwrap_or_default(),
95 &result.input_tokens.to_string(),
96 &result.output_tokens.to_string(),
97 ])?;
98 Ok(())
99 }
100
101 fn flush(&mut self) -> Result<()> {
102 self.writer.flush()?;
103 Ok(())
104 }
105}
106
107pub struct ParquetWriter {
108 results: Vec<GenerationResult>,
109 path: PathBuf,
110}
111
112impl ParquetWriter {
113 pub fn new(path: PathBuf) -> Result<Self> {
114 if let Some(parent) = path.parent() {
115 std::fs::create_dir_all(parent)?;
116 }
117 Ok(Self {
118 results: Vec::new(),
119 path,
120 })
121 }
122}
123
124impl OutputWriter for ParquetWriter {
125 fn write(&mut self, result: &GenerationResult) -> Result<()> {
126 self.results.push(GenerationResult {
127 content: result.content.clone(),
128 source_index: result.source_index,
129 category: result.category.clone(),
130 input_tokens: result.input_tokens,
131 output_tokens: result.output_tokens,
132 });
133 Ok(())
134 }
135
136 fn flush(&mut self) -> Result<()> {
137 if self.results.is_empty() {
138 return Ok(());
139 }
140
141 let content: Vec<String> = self.results.iter().map(|r| r.content.clone()).collect();
142 let source_index: Vec<Option<i64>> = self
143 .results
144 .iter()
145 .map(|r| r.source_index.map(|i| i as i64))
146 .collect();
147 let category: Vec<Option<String>> =
148 self.results.iter().map(|r| r.category.clone()).collect();
149 let input_tokens: Vec<u32> = self.results.iter().map(|r| r.input_tokens).collect();
150 let output_tokens: Vec<u32> = self.results.iter().map(|r| r.output_tokens).collect();
151
152 let df = DataFrame::new(vec![
153 Series::new("content".into(), content).into(),
154 Series::new("source_index".into(), source_index).into(),
155 Series::new("category".into(), category).into(),
156 Series::new("input_tokens".into(), input_tokens).into(),
157 Series::new("output_tokens".into(), output_tokens).into(),
158 ])
159 .map_err(|e| Error::Dataset(e.to_string()))?;
160
161 let file = File::create(&self.path)?;
162 ParquetWriter::write_parquet(df, file)?;
163
164 Ok(())
165 }
166}
167
168impl ParquetWriter {
169 fn write_parquet(df: DataFrame, file: File) -> Result<()> {
170 polars::prelude::ParquetWriter::new(file)
171 .finish(&mut df.clone())
172 .map_err(|e| Error::Dataset(e.to_string()))?;
173 Ok(())
174 }
175}
176
177pub fn create_writer(
178 format: &crate::config::OutputFormat,
179 path: PathBuf,
180) -> Result<Box<dyn OutputWriter>> {
181 match format {
182 crate::config::OutputFormat::Jsonl => Ok(Box::new(JsonlWriter::new(path)?)),
183 crate::config::OutputFormat::Csv => Ok(Box::new(CsvWriter::new(path)?)),
184 crate::config::OutputFormat::Parquet => Ok(Box::new(ParquetWriter::new(path)?)),
185 crate::config::OutputFormat::Json => Err(Error::Config(
186 "JSON array output not yet implemented, use JSONL instead".to_string(),
187 )),
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a result with every optional field populated.
    fn sample_result() -> GenerationResult {
        GenerationResult {
            content: "Test content".to_string(),
            source_index: Some(42),
            category: Some("test".to_string()),
            input_tokens: 10,
            output_tokens: 20,
        }
    }

    /// A JSONL record contains the content and the numeric source index.
    #[test]
    fn test_jsonl_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        // Scope the writer so it is dropped before the file is read back.
        {
            let mut writer = JsonlWriter::new(path.clone()).unwrap();
            writer.write(&sample_result()).unwrap();
            writer.flush().unwrap();
        }

        let written = std::fs::read_to_string(path).unwrap();
        assert!(written.contains("Test content"));
        assert!(written.contains("\"source_index\":42"));
    }

    /// A CSV row contains the content and the source index.
    #[test]
    fn test_csv_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        // Scope the writer so it is dropped before the file is read back.
        {
            let mut writer = CsvWriter::new(path.clone()).unwrap();
            writer.write(&sample_result()).unwrap();
            writer.flush().unwrap();
        }

        let written = std::fs::read_to_string(path).unwrap();
        assert!(written.contains("Test content"));
        assert!(written.contains("42"));
    }

    /// Two buffered results come back as two rows with a `content` column.
    #[test]
    fn test_parquet_writer() {
        let temp = NamedTempFile::new().unwrap();
        let path = temp.path().to_path_buf();

        let mut writer = ParquetWriter::new(path.clone()).unwrap();
        for _ in 0..2 {
            writer.write(&sample_result()).unwrap();
        }
        writer.flush().unwrap();

        let file = File::open(path).unwrap();
        let df = polars::prelude::ParquetReader::new(file).finish().unwrap();
        assert_eq!(df.height(), 2);

        let has_content_column = df
            .get_column_names()
            .iter()
            .any(|name| name.as_str() == "content");
        assert!(has_content_column);
    }
}