helix/dna/atp/
output.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use serde::{Serialize, Deserialize};
4use crate::hel::error::HlxError;
5use crate::atp::types::Value;
6use arrow::datatypes::{Schema, Field, DataType};
7use arrow::array::{Array, ArrayRef, StringArray, Float64Array, Int64Array};
8use arrow::record_batch::RecordBatch;
9use std::sync::Arc;
10
/// Output serialization formats the pipeline can emit.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// Native Helix text format (parse aliases: "helix", "hlx").
    Helix,
    /// Compressed Helix binary container (parse aliases: "hlxc", "compressed").
    Hlxc,
    /// Apache Parquet columnar files.
    Parquet,
    /// MessagePack binary serialization (parse aliases: "msgpack", "messagepack").
    MsgPack,
    /// JSON Lines, one JSON object per row (parse aliases: "jsonl", "json").
    Jsonl,
    /// Comma-separated values.
    Csv,
}
20impl OutputFormat {
21    pub fn from(s: &str) -> Result<Self, HlxError> {
22        match s.to_lowercase().as_str() {
23            "helix" | "hlx" => Ok(OutputFormat::Helix),
24            "hlxc" | "compressed" => Ok(OutputFormat::Hlxc),
25            "parquet" => Ok(OutputFormat::Parquet),
26            "msgpack" | "messagepack" => Ok(OutputFormat::MsgPack),
27            "jsonl" | "json" => Ok(OutputFormat::Jsonl),
28            "csv" => Ok(OutputFormat::Csv),
29            _ => {
30                Err(
31                    HlxError::validation_error(
32                        format!("Unsupported output format: {}", s),
33                        "Supported formats: helix, hlxc, parquet, msgpack, jsonl, csv",
34                    ),
35                )
36            }
37        }
38    }
39}
40impl std::str::FromStr for OutputFormat {
41    type Err = HlxError;
42    fn from_str(s: &str) -> Result<Self, Self::Err> {
43        Self::from(s)
44    }
45}
/// Where and how processed data is written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputConfig {
    /// Directory into which output files are placed.
    pub output_dir: PathBuf,
    /// Each listed format gets its own writer/output file.
    pub formats: Vec<OutputFormat>,
    /// Compression settings (used by writers that support compression).
    pub compression: CompressionConfig,
    /// Number of buffered rows before a batch is flushed to the writers.
    pub batch_size: usize,
    /// Whether to include a data preview — presumably consumed by a
    /// reporting layer outside this module; TODO confirm.
    pub include_preview: bool,
    /// Row count for the preview when `include_preview` is set.
    pub preview_rows: usize,
}
/// Compression settings for writers that support it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Master switch; algorithm/level presumably ignored when false.
    pub enabled: bool,
    /// Which codec to use.
    pub algorithm: CompressionAlgorithm,
    /// Codec-specific compression level (e.g. zstd accepts 1-22).
    pub level: u32,
}
/// Supported compression codecs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
    /// Zstandard — good ratio with a tunable level.
    Zstd,
    /// LZ4 — very fast, lower compression ratio.
    Lz4,
    /// Snappy — fast with moderate ratio.
    Snappy,
}
impl Default for CompressionConfig {
    /// Defaults to zstd at level 4 — a balanced speed/ratio trade-off.
    fn default() -> Self {
        Self {
            enabled: true,
            algorithm: CompressionAlgorithm::Zstd,
            level: 4,
        }
    }
}
impl Default for OutputConfig {
    /// Defaults: Helix + JSONL written into ./output, 1000-row batches,
    /// compression enabled, 10-row preview.
    fn default() -> Self {
        Self {
            output_dir: PathBuf::from("output"),
            formats: vec![OutputFormat::Helix, OutputFormat::Jsonl],
            compression: CompressionConfig::default(),
            batch_size: 1000,
            include_preview: true,
            preview_rows: 10,
        }
    }
}
/// Sink for Arrow record batches; one implementation per output format.
pub trait DataWriter {
    /// Consume one batch; implementations may buffer internally.
    fn write_batch(&mut self, batch: RecordBatch) -> Result<(), HlxError>;
    /// Flush any buffered data; called once after the last batch.
    fn finalize(&mut self) -> Result<(), HlxError>;
}
/// Buffers rows, infers an Arrow schema from the first row, and fans
/// batches out to one `DataWriter` per configured format.
pub struct OutputManager {
    config: OutputConfig,
    // Lazily populated by initialize_writers() on the first flush.
    writers: HashMap<OutputFormat, Box<dyn DataWriter>>,
    // Rows accumulated since the last flush.
    current_batch: Vec<HashMap<String, Value>>,
    // Inferred from the first row added; fixed thereafter.
    schema: Option<Schema>,
    // Batch counter used when composing output file names.
    batch_count: usize,
    // Guards initialize_writers() so it runs at most once.
    writers_initialized: bool,
}
100impl OutputManager {
101    pub fn new(config: OutputConfig) -> Self {
102        Self {
103            config,
104            writers: HashMap::new(),
105            current_batch: Vec::new(),
106            schema: None,
107            batch_count: 0,
108            writers_initialized: false,
109        }
110    }
111    pub fn add_row(&mut self, row: HashMap<String, Value>) -> Result<(), HlxError> {
112        if self.schema.is_none() {
113            self.schema = Some(infer_schema(&row));
114        }
115        self.current_batch.push(row);
116        if self.current_batch.len() >= self.config.batch_size {
117            self.flush_batch()?;
118        }
119        Ok(())
120    }
121    pub fn flush_batch(&mut self) -> Result<(), HlxError> {
122        if self.current_batch.is_empty() {
123            return Ok(());
124        }
125        if let Some(schema) = &self.schema {
126            let batch = convert_to_record_batch(schema, &self.current_batch)?;
127            self.write_batch_to_all_writers(batch)?;
128        }
129        self.current_batch.clear();
130        Ok(())
131    }
132    pub fn finalize_all(&mut self) -> Result<(), HlxError> {
133        self.flush_batch()?;
134        for writer in self.writers.values_mut() {
135            writer.finalize()?;
136        }
137        Ok(())
138    }
139    fn initialize_writers(&mut self) -> Result<(), HlxError> {
140        if self.writers_initialized {
141            return Ok(());
142        }
143        for format in &self.config.formats {
144            let writer: Box<dyn DataWriter> = match format {
145                OutputFormat::Hlxc => Box::new(HlxcDataWriter::new(self.config.clone())),
146                _ => {
147                    continue;
148                }
149            };
150            self.writers.insert(format.clone(), writer);
151        }
152        self.writers_initialized = true;
153        Ok(())
154    }
155    fn write_batch_to_all_writers(
156        &mut self,
157        batch: RecordBatch,
158    ) -> Result<(), HlxError> {
159        self.initialize_writers()?;
160        for (format, writer) in &mut self.writers {
161            if *format == OutputFormat::Hlxc {
162                writer.write_batch(batch.clone())?;
163            }
164        }
165        Ok(())
166    }
167    pub fn get_output_files(&self) -> Vec<PathBuf> {
168        let mut files = Vec::new();
169        for format in &self.config.formats {
170            let extension = match format {
171                OutputFormat::Helix => "helix",
172                OutputFormat::Hlxc => "hlxc",
173                OutputFormat::Parquet => "parquet",
174                OutputFormat::MsgPack => "msgpack",
175                OutputFormat::Jsonl => "jsonl",
176                OutputFormat::Csv => "csv",
177            };
178            let filename = format!("output_{:04}.{}", self.batch_count, extension);
179            files.push(self.config.output_dir.join(filename));
180        }
181        files
182    }
183}
184
/// Placeholder HLXC writer: accumulates serialized batch metadata in memory.
pub struct HlxcDataWriter {
    // Held for future use (compression settings etc.); not read yet.
    config: OutputConfig,
    // In-memory output buffer; not persisted anywhere yet (see finalize()).
    buffer: Vec<u8>,
}
189
190impl HlxcDataWriter {
191    pub fn new(config: OutputConfig) -> Self {
192        Self {
193            config,
194            buffer: Vec::new(),
195        }
196    }
197}
198
199impl DataWriter for HlxcDataWriter {
200    fn write_batch(&mut self, batch: RecordBatch) -> Result<(), HlxError> {
201        // For now, just serialize to JSON as placeholder
202        // In real implementation, this would write compressed binary format
203        let schema_info = format!("{{\"fields\": {}, \"rows\": {}}}", 
204            batch.schema().fields().len(), 
205            batch.num_rows()
206        );
207        let data_json = format!("{{\"schema\": {}, \"rows\": {}}}", schema_info, batch.num_rows());
208        self.buffer.extend_from_slice(data_json.as_bytes());
209        Ok(())
210    }
211
212    fn finalize(&mut self) -> Result<(), HlxError> {
213        // Write buffer to file if configured
214        // For now, this is a placeholder implementation
215        Ok(())
216    }
217}
218
219fn infer_schema(row: &HashMap<String, Value>) -> Schema {
220    let fields: Vec<arrow::datatypes::Field> = row
221        .iter()
222        .map(|(name, value)| {
223            let data_type = match value {
224                Value::String(_) => DataType::Utf8,
225                Value::Number(_) => DataType::Float64,
226                Value::Bool(_) => DataType::Boolean,
227                _ => DataType::Utf8,
228            };
229            Field::new(name, data_type, true)
230        })
231        .collect();
232    Schema::new(fields)
233}
234fn convert_to_record_batch(
235    schema: &Schema,
236    batch: &[HashMap<String, Value>],
237) -> Result<RecordBatch, HlxError> {
238    let arrays: Result<Vec<ArrayRef>, HlxError> = schema
239        .fields()
240        .iter()
241        .map(|field| {
242            let column_data: Vec<Value> = batch
243                .iter()
244                .map(|row| { row.get(field.name()).cloned().unwrap_or(Value::Null) })
245                .collect();
246            match field.data_type() {
247                DataType::Utf8 => {
248                    let string_data: Vec<Option<String>> = column_data
249                        .into_iter()
250                        .map(|v| {
251                            match v {
252                                Value::String(s) => Some(s),
253                                _ => Some(v.to_string()),
254                            }
255                        })
256                        .collect();
257                    Ok(Arc::new(StringArray::from(string_data)) as ArrayRef)
258                }
259                DataType::Float64 => {
260                    let float_data: Vec<Option<f64>> = column_data
261                        .into_iter()
262                        .map(|v| {
263                            match v {
264                                Value::Number(n) => Some(n),
265                                Value::String(s) => s.parse().ok(),
266                                _ => None,
267                            }
268                        })
269                        .collect();
270                    Ok(Arc::new(Float64Array::from(float_data)) as ArrayRef)
271                }
272                DataType::Int64 => {
273                    let int_data: Vec<Option<i64>> = column_data
274                        .into_iter()
275                        .map(|v| {
276                            match v {
277                                Value::Number(n) => Some(n as i64),
278                                Value::String(s) => s.parse().ok(),
279                                _ => None,
280                            }
281                        })
282                        .collect();
283                    Ok(Arc::new(Int64Array::from(int_data)) as ArrayRef)
284                }
285                DataType::Boolean => {
286                    let bool_data: Vec<Option<bool>> = column_data
287                        .into_iter()
288                        .map(|v| {
289                            match v {
290                                Value::Bool(b) => Some(b),
291                                Value::String(s) => {
292                                    match s.to_lowercase().as_str() {
293                                        "true" | "1" | "yes" => Some(true),
294                                        "false" | "0" | "no" => Some(false),
295                                        _ => None,
296                                    }
297                                }
298                                _ => None,
299                            }
300                        })
301                        .collect();
302                    Ok(Arc::new(arrow::array::BooleanArray::from(bool_data)) as ArrayRef)
303                }
304                _ => {
305                    let string_data: Vec<Option<String>> = column_data
306                        .into_iter()
307                        .map(|v| { Some(v.to_string()) })
308                        .collect();
309                    Ok(Arc::new(StringArray::from(string_data)) as ArrayRef)
310                }
311            }
312        })
313        .collect();
314    let arrays = arrays?;
315    RecordBatch::try_new(Arc::new(schema.clone()), arrays)
316        .map_err(|e| HlxError::validation_error(
317            format!("Failed to create record batch: {}", e),
318            "",
319        ))
320}
321fn convert_batch_to_hashmap(batch: &RecordBatch) -> HashMap<String, Value> {
322    let mut result = HashMap::new();
323    for (field_idx, field) in batch.schema().fields().iter().enumerate() {
324        if let Some(array) = batch
325            .column(field_idx)
326            .as_any()
327            .downcast_ref::<StringArray>()
328        {
329            let values: Vec<Value> = (0..batch.num_rows())
330                .map(|i| {
331                    if array.is_valid(i) {
332                        Value::String(array.value(i).to_string())
333                    } else {
334                        Value::Null
335                    }
336                })
337                .collect();
338            result.insert(field.name().clone(), Value::Array(values));
339        } else if let Some(array) = batch
340            .column(field_idx)
341            .as_any()
342            .downcast_ref::<Float64Array>()
343        {
344            let values: Vec<Value> = (0..batch.num_rows())
345                .map(|i| {
346                    if array.is_valid(i) {
347                        Value::Number(array.value(i))
348                    } else {
349                        Value::Null
350                    }
351                })
352                .collect();
353            result.insert(field.name().clone(), Value::Array(values));
354        } else if let Some(array) = batch
355            .column(field_idx)
356            .as_any()
357            .downcast_ref::<Int64Array>()
358        {
359            let values: Vec<Value> = (0..batch.num_rows())
360                .map(|i| {
361                    if array.is_valid(i) {
362                        Value::Number(array.value(i) as f64)
363                    } else {
364                        Value::Null
365                    }
366                })
367                .collect();
368            result.insert(field.name().clone(), Value::Array(values));
369        }
370    }
371    result
372}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Look up a field's type by NAME. The previous positional asserts
    /// (`schema.field(0)`) depended on `HashMap` iteration order and were
    /// therefore flaky.
    fn field_type<'a>(schema: &'a Schema, name: &str) -> &'a DataType {
        schema
            .fields()
            .iter()
            .find(|f| f.name() == name)
            .unwrap_or_else(|| panic!("missing field: {}", name))
            .data_type()
    }

    #[test]
    fn test_infer_schema() {
        let mut row = HashMap::new();
        row.insert("name".to_string(), Value::String("John".to_string()));
        row.insert("age".to_string(), Value::Number(30.0));
        row.insert("active".to_string(), Value::Bool(true));
        let schema = infer_schema(&row);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(field_type(&schema, "name"), &DataType::Utf8);
        assert_eq!(field_type(&schema, "age"), &DataType::Float64);
        assert_eq!(field_type(&schema, "active"), &DataType::Boolean);
    }

    #[test]
    fn test_output_format_from_str() {
        // Canonical names and aliases all parse to the expected variant.
        let cases = [
            ("helix", OutputFormat::Helix),
            ("hlxc", OutputFormat::Hlxc),
            ("compressed", OutputFormat::Hlxc),
            ("parquet", OutputFormat::Parquet),
            ("msgpack", OutputFormat::MsgPack),
            ("jsonl", OutputFormat::Jsonl),
            ("csv", OutputFormat::Csv),
        ];
        for (input, expected) in cases {
            assert_eq!(
                OutputFormat::from(input).expect("parse should succeed"),
                expected,
                "failed to parse {:?}",
                input
            );
        }
        assert!(OutputFormat::from("invalid").is_err());
    }
}
421}