use std::collections::HashMap;
use std::path::PathBuf;
use serde::{Serialize, Deserialize};
use crate::hel::error::HlxError;
use crate::atp::types::Value;
use arrow::datatypes::{Schema, Field, DataType};
use arrow::array::{Array, ArrayRef, StringArray, Float64Array, Int64Array};
use arrow::record_batch::RecordBatch;
use std::sync::Arc;

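/// Output formats the pipeline can emit. `Hlxc` is the compressed variant of
/// `Helix` (it also parses from the alias "compressed"). Parse a format from
/// a string via [`OutputFormat::from`] or through `str::parse`; a minimal
/// sketch, assuming this module's items are in scope:
///
/// ```ignore
/// let fmt: OutputFormat = "jsonl".parse().unwrap();
/// assert_eq!(fmt, OutputFormat::Jsonl);
/// ```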
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    Helix,
    Hlxc,
    Parquet,
    MsgPack,
    Jsonl,
    Csv,
}

impl OutputFormat {
    pub fn from(s: &str) -> Result<Self, HlxError> {
        match s.to_lowercase().as_str() {
            "helix" | "hlx" => Ok(OutputFormat::Helix),
            "hlxc" | "compressed" => Ok(OutputFormat::Hlxc),
            "parquet" => Ok(OutputFormat::Parquet),
            "msgpack" | "messagepack" => Ok(OutputFormat::MsgPack),
            "jsonl" | "json" => Ok(OutputFormat::Jsonl),
            "csv" => Ok(OutputFormat::Csv),
            _ => Err(HlxError::validation_error(
                format!("Unsupported output format: {}", s),
                "Supported formats: helix, hlxc, parquet, msgpack, jsonl, csv",
            )),
        }
    }
}

impl std::str::FromStr for OutputFormat {
    type Err = HlxError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::from(s)
    }
}
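
/// Where and how results are written. `batch_size` controls how many rows
/// `OutputManager` buffers before flushing a batch. A sketch of overriding a
/// few fields while keeping the remaining defaults:
///
/// ```ignore
/// let config = OutputConfig {
///     formats: vec![OutputFormat::Hlxc],
///     batch_size: 500,
///     ..OutputConfig::default()
/// };
/// ```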
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputConfig {
    pub output_dir: PathBuf,
    pub formats: Vec<OutputFormat>,
    pub compression: CompressionConfig,
    pub batch_size: usize,
    pub include_preview: bool,
    pub preview_rows: usize,
}
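
/// Compression settings for formats that support it. The defaults enable
/// zstd at level 4, a common middle ground between ratio and speed.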
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    pub enabled: bool,
    pub algorithm: CompressionAlgorithm,
    pub level: u32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
    Zstd,
    Lz4,
    Snappy,
}

impl Default for CompressionConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            algorithm: CompressionAlgorithm::Zstd,
            level: 4,
        }
    }
}

impl Default for OutputConfig {
    fn default() -> Self {
        Self {
            output_dir: PathBuf::from("output"),
            formats: vec![OutputFormat::Helix, OutputFormat::Jsonl],
            compression: CompressionConfig::default(),
            batch_size: 1000,
            include_preview: true,
            preview_rows: 10,
        }
    }
}
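
/// A sink for Arrow record batches. `write_batch` is called once per flushed
/// batch; `finalize` is called exactly once after the last batch, so
/// implementations should flush buffers and close resources there.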
pub trait DataWriter {
    fn write_batch(&mut self, batch: RecordBatch) -> Result<(), HlxError>;
    fn finalize(&mut self) -> Result<(), HlxError>;
}
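
/// Buffers rows, infers an Arrow schema from the first row, and flushes
/// batches of `batch_size` rows to every configured writer. A minimal usage
/// sketch (the row contents are illustrative):
///
/// ```ignore
/// let mut manager = OutputManager::new(OutputConfig::default());
/// let mut row = HashMap::new();
/// row.insert("name".to_string(), Value::String("John".to_string()));
/// manager.add_row(row)?;
/// manager.finalize_all()?; // flushes any remaining rows, then finalizes writers
/// ```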
pub struct OutputManager {
    config: OutputConfig,
    writers: HashMap<OutputFormat, Box<dyn DataWriter>>,
    current_batch: Vec<HashMap<String, Value>>,
    schema: Option<Schema>,
    batch_count: usize,
    writers_initialized: bool,
}

impl OutputManager {
    pub fn new(config: OutputConfig) -> Self {
        Self {
            config,
            writers: HashMap::new(),
            current_batch: Vec::new(),
            schema: None,
            batch_count: 0,
            writers_initialized: false,
        }
    }

    pub fn add_row(&mut self, row: HashMap<String, Value>) -> Result<(), HlxError> {
        // The schema is inferred from the first row; later rows are expected
        // to match it.
        if self.schema.is_none() {
            self.schema = Some(infer_schema(&row));
        }
        self.current_batch.push(row);
        if self.current_batch.len() >= self.config.batch_size {
            self.flush_batch()?;
        }
        Ok(())
    }

    pub fn flush_batch(&mut self) -> Result<(), HlxError> {
        if self.current_batch.is_empty() {
            return Ok(());
        }
        if let Some(schema) = &self.schema {
            let batch = convert_to_record_batch(schema, &self.current_batch)?;
            self.write_batch_to_all_writers(batch)?;
        }
        self.current_batch.clear();
        // Track the number of flushed batches; used by get_output_files().
        self.batch_count += 1;
        Ok(())
    }

    pub fn finalize_all(&mut self) -> Result<(), HlxError> {
        self.flush_batch()?;
        for writer in self.writers.values_mut() {
            writer.finalize()?;
        }
        Ok(())
    }

    fn initialize_writers(&mut self) -> Result<(), HlxError> {
        if self.writers_initialized {
            return Ok(());
        }
        for format in &self.config.formats {
            let writer: Box<dyn DataWriter> = match format {
                OutputFormat::Hlxc => Box::new(HlxcDataWriter::new(self.config.clone())),
                // Only the HLXC writer is implemented here; other formats are
                // skipped until their writers exist.
                _ => continue,
            };
            self.writers.insert(format.clone(), writer);
        }
        self.writers_initialized = true;
        Ok(())
    }

    fn write_batch_to_all_writers(&mut self, batch: RecordBatch) -> Result<(), HlxError> {
        self.initialize_writers()?;
        for (format, writer) in &mut self.writers {
            // Only the HLXC writer consumes batches for now.
            if *format == OutputFormat::Hlxc {
                writer.write_batch(batch.clone())?;
            }
        }
        Ok(())
    }

    pub fn get_output_files(&self) -> Vec<PathBuf> {
        let mut files = Vec::new();
        for format in &self.config.formats {
            let extension = match format {
                OutputFormat::Helix => "helix",
                OutputFormat::Hlxc => "hlxc",
                OutputFormat::Parquet => "parquet",
                OutputFormat::MsgPack => "msgpack",
                OutputFormat::Jsonl => "jsonl",
                OutputFormat::Csv => "csv",
            };
            let filename = format!("output_{:04}.{}", self.batch_count, extension);
            files.push(self.config.output_dir.join(filename));
        }
        files
    }
}

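/// Placeholder writer for the HLXC format. It currently appends only batch
/// metadata to an in-memory buffer; real column serialization (and the
/// compression configured in `OutputConfig`) is not wired in yet.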
pub struct HlxcDataWriter {
    config: OutputConfig,
    buffer: Vec<u8>,
}

impl HlxcDataWriter {
    pub fn new(config: OutputConfig) -> Self {
        Self {
            config,
            buffer: Vec::new(),
        }
    }
}

impl DataWriter for HlxcDataWriter {
    fn write_batch(&mut self, batch: RecordBatch) -> Result<(), HlxError> {
        // Placeholder: serialize only schema metadata and the row count as
        // JSON, not the actual column data.
        let schema_info = format!(
            "{{\"fields\": {}, \"rows\": {}}}",
            batch.schema().fields().len(),
            batch.num_rows()
        );
        let data_json = format!("{{\"schema\": {}, \"rows\": {}}}", schema_info, batch.num_rows());
        self.buffer.extend_from_slice(data_json.as_bytes());
        Ok(())
    }

    fn finalize(&mut self) -> Result<(), HlxError> {
        // Stub: the buffered bytes are not yet persisted anywhere.
        Ok(())
    }
}

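/// Infers an Arrow schema from a single row: strings map to `Utf8`, numbers
/// to `Float64`, booleans to `Boolean`, and anything else falls back to
/// `Utf8`. Every field is nullable. Field order follows `HashMap` iteration
/// order, so it is not deterministic across runs.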
fn infer_schema(row: &HashMap<String, Value>) -> Schema {
    let fields: Vec<Field> = row
        .iter()
        .map(|(name, value)| {
            let data_type = match value {
                Value::String(_) => DataType::Utf8,
                Value::Number(_) => DataType::Float64,
                Value::Bool(_) => DataType::Boolean,
                _ => DataType::Utf8,
            };
            Field::new(name, data_type, true)
        })
        .collect();
    Schema::new(fields)
}
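
/// Builds an Arrow `RecordBatch` from buffered rows, one column per schema
/// field. Missing values become nulls; values that do not fit a column's
/// type are parsed where possible (numeric and boolean columns) or rendered
/// as strings (string columns).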
fn convert_to_record_batch(
    schema: &Schema,
    batch: &[HashMap<String, Value>],
) -> Result<RecordBatch, HlxError> {
    let arrays: Result<Vec<ArrayRef>, HlxError> = schema
        .fields()
        .iter()
        .map(|field| {
            let column_data: Vec<Value> = batch
                .iter()
                .map(|row| row.get(field.name()).cloned().unwrap_or(Value::Null))
                .collect();
            match field.data_type() {
                DataType::Utf8 => {
                    let string_data: Vec<Option<String>> = column_data
                        .into_iter()
                        .map(|v| match v {
                            Value::String(s) => Some(s),
                            // Keep missing values as nulls rather than the
                            // string rendering of Value::Null.
                            Value::Null => None,
                            _ => Some(v.to_string()),
                        })
                        .collect();
                    Ok(Arc::new(StringArray::from(string_data)) as ArrayRef)
                }
                DataType::Float64 => {
                    let float_data: Vec<Option<f64>> = column_data
                        .into_iter()
                        .map(|v| match v {
                            Value::Number(n) => Some(n),
                            Value::String(s) => s.parse().ok(),
                            _ => None,
                        })
                        .collect();
                    Ok(Arc::new(Float64Array::from(float_data)) as ArrayRef)
                }
                DataType::Int64 => {
                    let int_data: Vec<Option<i64>> = column_data
                        .into_iter()
                        .map(|v| match v {
                            Value::Number(n) => Some(n as i64),
                            Value::String(s) => s.parse().ok(),
                            _ => None,
                        })
                        .collect();
                    Ok(Arc::new(Int64Array::from(int_data)) as ArrayRef)
                }
                DataType::Boolean => {
                    let bool_data: Vec<Option<bool>> = column_data
                        .into_iter()
                        .map(|v| match v {
                            Value::Bool(b) => Some(b),
                            Value::String(s) => match s.to_lowercase().as_str() {
                                "true" | "1" | "yes" => Some(true),
                                "false" | "0" | "no" => Some(false),
                                _ => None,
                            },
                            _ => None,
                        })
                        .collect();
                    Ok(Arc::new(arrow::array::BooleanArray::from(bool_data)) as ArrayRef)
                }
                _ => {
                    let string_data: Vec<Option<String>> = column_data
                        .into_iter()
                        .map(|v| Some(v.to_string()))
                        .collect();
                    Ok(Arc::new(StringArray::from(string_data)) as ArrayRef)
                }
            }
        })
        .collect();
    let arrays = arrays?;
    RecordBatch::try_new(Arc::new(schema.clone()), arrays)
        .map_err(|e| HlxError::validation_error(
            format!("Failed to create record batch: {}", e),
            "Check that all rows match the inferred schema",
        ))
}
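
/// Converts a `RecordBatch` back into a column-oriented map: each field name
/// maps to a `Value::Array` of that column's values. Only `Utf8`, `Float64`,
/// and `Int64` columns are handled; columns of other types are skipped.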
fn convert_batch_to_hashmap(batch: &RecordBatch) -> HashMap<String, Value> {
    let mut result = HashMap::new();
    for (field_idx, field) in batch.schema().fields().iter().enumerate() {
        if let Some(array) = batch
            .column(field_idx)
            .as_any()
            .downcast_ref::<StringArray>()
        {
            let values: Vec<Value> = (0..batch.num_rows())
                .map(|i| {
                    if array.is_valid(i) {
                        Value::String(array.value(i).to_string())
                    } else {
                        Value::Null
                    }
                })
                .collect();
            result.insert(field.name().clone(), Value::Array(values));
        } else if let Some(array) = batch
            .column(field_idx)
            .as_any()
            .downcast_ref::<Float64Array>()
        {
            let values: Vec<Value> = (0..batch.num_rows())
                .map(|i| {
                    if array.is_valid(i) {
                        Value::Number(array.value(i))
                    } else {
                        Value::Null
                    }
                })
                .collect();
            result.insert(field.name().clone(), Value::Array(values));
        } else if let Some(array) = batch
            .column(field_idx)
            .as_any()
            .downcast_ref::<Int64Array>()
        {
            let values: Vec<Value> = (0..batch.num_rows())
                .map(|i| {
                    if array.is_valid(i) {
                        Value::Number(array.value(i) as f64)
                    } else {
                        Value::Null
                    }
                })
                .collect();
            result.insert(field.name().clone(), Value::Array(values));
        }
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_infer_schema() {
        let mut row = HashMap::new();
        row.insert("name".to_string(), Value::String("John".to_string()));
        row.insert("age".to_string(), Value::Number(30.0));
        row.insert("active".to_string(), Value::Bool(true));
        let schema = infer_schema(&row);
        assert_eq!(schema.fields().len(), 3);
        // Field order follows HashMap iteration order, so look fields up by
        // name instead of by index.
        let name = schema.field_with_name("name").expect("missing 'name' field");
        assert_eq!(name.data_type(), &DataType::Utf8);
        let age = schema.field_with_name("age").expect("missing 'age' field");
        assert_eq!(age.data_type(), &DataType::Float64);
    }

    #[test]
    fn test_output_format_from_str() {
        assert_eq!(
            OutputFormat::from("helix").expect("Failed to parse 'helix'"),
            OutputFormat::Helix
        );
        assert_eq!(
            OutputFormat::from("hlxc").expect("Failed to parse 'hlxc'"),
            OutputFormat::Hlxc
        );
        assert_eq!(
            OutputFormat::from("compressed").expect("Failed to parse 'compressed'"),
            OutputFormat::Hlxc
        );
        assert_eq!(
            OutputFormat::from("parquet").expect("Failed to parse 'parquet'"),
            OutputFormat::Parquet
        );
        assert_eq!(
            OutputFormat::from("msgpack").expect("Failed to parse 'msgpack'"),
            OutputFormat::MsgPack
        );
        assert_eq!(
            OutputFormat::from("jsonl").expect("Failed to parse 'jsonl'"),
            OutputFormat::Jsonl
        );
        assert_eq!(
            OutputFormat::from("csv").expect("Failed to parse 'csv'"),
            OutputFormat::Csv
        );
        assert!(OutputFormat::from("invalid").is_err());
    }
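
    // A sketch of an end-to-end shape check for the row-to-Arrow conversion
    // path, using the same Value variants exercised above.
    #[test]
    fn test_convert_to_record_batch_shape() {
        let mut row = HashMap::new();
        row.insert("name".to_string(), Value::String("John".to_string()));
        row.insert("age".to_string(), Value::Number(30.0));
        let schema = infer_schema(&row);
        let batch = convert_to_record_batch(&schema, &[row])
            .expect("conversion should succeed");
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 2);
    }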
}