polars_redis/types/stream/
convert.rs

1//! Arrow conversion for Redis Stream data.
2
3use std::sync::Arc;
4
5use arrow::array::{ArrayRef, RecordBatch, StringBuilder, UInt64Builder};
6use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
7
8use super::reader::StreamData;
9use crate::error::Result;
10
/// Schema configuration for Redis Stream scanning.
///
/// Defines output columns when scanning Redis Streams. Each stream entry becomes
/// a row in the output DataFrame.
///
/// # Example
///
/// ```ignore
/// use polars_redis::StreamSchema;
///
/// let schema = StreamSchema::with_fields(vec![
///     "user_id".to_string(),
///     "action".to_string(),
///     "payload".to_string(),
/// ])
/// .with_key(true)
/// .with_id(true)
/// .with_timestamp(true);
/// ```
///
/// # Output Schema
///
/// Columns appear in the order produced by [`StreamSchema::to_arrow_schema`]:
///
/// - Row index (optional, default name `_index`): global row number (UInt64)
/// - Key (optional, default name `_key`): the Redis stream key (Utf8)
/// - Entry ID (optional, default name `_id`): e.g. "1234567890123-0" (Utf8)
/// - Timestamp (optional, default name `_ts`): milliseconds-since-epoch
///   extracted from the entry ID (Timestamp\[ms\])
/// - Sequence (optional, default name `_seq`): sequence number from the entry ID (UInt64)
/// - User-defined fields extracted from entry data (nullable Utf8)
#[derive(Debug, Clone)]
pub struct StreamSchema {
    /// Whether to include the Redis key as a column.
    pub include_key: bool,
    /// Name of the key column.
    pub key_column_name: String,
    /// Whether to include the entry ID as a column.
    pub include_id: bool,
    /// Name of the entry ID column.
    pub id_column_name: String,
    /// Whether to include the timestamp as a column.
    pub include_timestamp: bool,
    /// Name of the timestamp column.
    pub timestamp_column_name: String,
    /// Whether to include the sequence number as a column.
    pub include_sequence: bool,
    /// Name of the sequence column.
    pub sequence_column_name: String,
    /// Field names to extract from entries.
    pub fields: Vec<String>,
    /// Whether to include a global row index column.
    pub include_row_index: bool,
    /// Name of the row index column.
    pub row_index_column_name: String,
}
64
65impl Default for StreamSchema {
66    fn default() -> Self {
67        Self {
68            include_key: true,
69            key_column_name: "_key".to_string(),
70            include_id: true,
71            id_column_name: "_id".to_string(),
72            include_timestamp: true,
73            timestamp_column_name: "_ts".to_string(),
74            include_sequence: false,
75            sequence_column_name: "_seq".to_string(),
76            fields: Vec::new(),
77            include_row_index: false,
78            row_index_column_name: "_index".to_string(),
79        }
80    }
81}
82
83impl StreamSchema {
84    /// Create a new StreamSchema with default settings.
85    pub fn new() -> Self {
86        Self::default()
87    }
88
89    /// Create a new StreamSchema with the specified fields.
90    pub fn with_fields(fields: Vec<String>) -> Self {
91        Self {
92            fields,
93            ..Default::default()
94        }
95    }
96
97    /// Set whether to include the key column.
98    pub fn with_key(mut self, include: bool) -> Self {
99        self.include_key = include;
100        self
101    }
102
103    /// Set the key column name.
104    pub fn with_key_column_name(mut self, name: &str) -> Self {
105        self.key_column_name = name.to_string();
106        self
107    }
108
109    /// Set whether to include the entry ID column.
110    pub fn with_id(mut self, include: bool) -> Self {
111        self.include_id = include;
112        self
113    }
114
115    /// Set the entry ID column name.
116    pub fn with_id_column_name(mut self, name: &str) -> Self {
117        self.id_column_name = name.to_string();
118        self
119    }
120
121    /// Set whether to include the timestamp column.
122    pub fn with_timestamp(mut self, include: bool) -> Self {
123        self.include_timestamp = include;
124        self
125    }
126
127    /// Set the timestamp column name.
128    pub fn with_timestamp_column_name(mut self, name: &str) -> Self {
129        self.timestamp_column_name = name.to_string();
130        self
131    }
132
133    /// Set whether to include the sequence column.
134    pub fn with_sequence(mut self, include: bool) -> Self {
135        self.include_sequence = include;
136        self
137    }
138
139    /// Set the sequence column name.
140    pub fn with_sequence_column_name(mut self, name: &str) -> Self {
141        self.sequence_column_name = name.to_string();
142        self
143    }
144
145    /// Add a field to extract from entries.
146    pub fn add_field(mut self, name: &str) -> Self {
147        self.fields.push(name.to_string());
148        self
149    }
150
151    /// Set the fields to extract from entries.
152    pub fn set_fields(mut self, fields: Vec<String>) -> Self {
153        self.fields = fields;
154        self
155    }
156
157    /// Set whether to include a global row index column.
158    pub fn with_row_index(mut self, include: bool) -> Self {
159        self.include_row_index = include;
160        self
161    }
162
163    /// Set the row index column name.
164    pub fn with_row_index_column_name(mut self, name: &str) -> Self {
165        self.row_index_column_name = name.to_string();
166        self
167    }
168
169    /// Build the Arrow schema for this configuration.
170    pub fn to_arrow_schema(&self) -> Schema {
171        let mut arrow_fields = Vec::new();
172
173        if self.include_row_index {
174            arrow_fields.push(Field::new(
175                &self.row_index_column_name,
176                DataType::UInt64,
177                false,
178            ));
179        }
180
181        if self.include_key {
182            arrow_fields.push(Field::new(&self.key_column_name, DataType::Utf8, false));
183        }
184
185        if self.include_id {
186            arrow_fields.push(Field::new(&self.id_column_name, DataType::Utf8, false));
187        }
188
189        if self.include_timestamp {
190            arrow_fields.push(Field::new(
191                &self.timestamp_column_name,
192                DataType::Timestamp(TimeUnit::Millisecond, None),
193                false,
194            ));
195        }
196
197        if self.include_sequence {
198            arrow_fields.push(Field::new(
199                &self.sequence_column_name,
200                DataType::UInt64,
201                false,
202            ));
203        }
204
205        // Add user-defined fields (all as nullable Utf8)
206        for field_name in &self.fields {
207            arrow_fields.push(Field::new(field_name, DataType::Utf8, true));
208        }
209
210        Schema::new(arrow_fields)
211    }
212}
213
214/// Convert stream data to an Arrow RecordBatch.
215///
216/// Each stream entry becomes a row in the output.
217pub fn streams_to_record_batch(
218    data: &[StreamData],
219    schema: &StreamSchema,
220    row_index_offset: u64,
221) -> Result<RecordBatch> {
222    // Count total entries across all streams
223    let total_entries: usize = data.iter().map(|s| s.entries.len()).sum();
224
225    let arrow_schema = Arc::new(schema.to_arrow_schema());
226    let mut columns: Vec<ArrayRef> = Vec::new();
227
228    // Row index column (global)
229    if schema.include_row_index {
230        let mut builder = UInt64Builder::with_capacity(total_entries);
231        let mut idx = row_index_offset;
232        for stream_data in data {
233            for _ in &stream_data.entries {
234                builder.append_value(idx);
235                idx += 1;
236            }
237        }
238        columns.push(Arc::new(builder.finish()));
239    }
240
241    // Key column
242    if schema.include_key {
243        let mut builder = StringBuilder::with_capacity(total_entries, total_entries * 32);
244        for stream_data in data {
245            for _ in &stream_data.entries {
246                builder.append_value(&stream_data.key);
247            }
248        }
249        columns.push(Arc::new(builder.finish()));
250    }
251
252    // Entry ID column
253    if schema.include_id {
254        let mut builder = StringBuilder::with_capacity(total_entries, total_entries * 24);
255        for stream_data in data {
256            for entry in &stream_data.entries {
257                builder.append_value(&entry.id);
258            }
259        }
260        columns.push(Arc::new(builder.finish()));
261    }
262
263    // Timestamp column (milliseconds since epoch)
264    if schema.include_timestamp {
265        let mut values = Vec::with_capacity(total_entries);
266        for stream_data in data {
267            for entry in &stream_data.entries {
268                values.push(entry.timestamp_ms);
269            }
270        }
271        let ts_array = arrow::array::TimestampMillisecondArray::from(values);
272        columns.push(Arc::new(ts_array));
273    }
274
275    // Sequence column
276    if schema.include_sequence {
277        let mut builder = UInt64Builder::with_capacity(total_entries);
278        for stream_data in data {
279            for entry in &stream_data.entries {
280                builder.append_value(entry.sequence);
281            }
282        }
283        columns.push(Arc::new(builder.finish()));
284    }
285
286    // User-defined fields
287    for field_name in &schema.fields {
288        let mut builder = StringBuilder::with_capacity(total_entries, total_entries * 32);
289        for stream_data in data {
290            for entry in &stream_data.entries {
291                match entry.fields.get(field_name) {
292                    Some(value) => builder.append_value(value),
293                    None => builder.append_null(),
294                }
295            }
296        }
297        columns.push(Arc::new(builder.finish()));
298    }
299
300    RecordBatch::try_new(arrow_schema, columns).map_err(|e| {
301        crate::error::Error::TypeConversion(format!("Failed to create RecordBatch: {}", e))
302    })
303}
304
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::types::stream::reader::StreamEntry;

    // Default schema: key/id/timestamp on, sequence/row-index off, no fields.
    #[test]
    fn test_stream_schema_default() {
        let schema = StreamSchema::new();
        assert!(schema.include_key);
        assert!(schema.include_id);
        assert!(schema.include_timestamp);
        assert!(!schema.include_sequence);
        assert!(!schema.include_row_index);
        assert!(schema.fields.is_empty());
    }

    // Builder methods override the defaults and accumulate fields in order.
    #[test]
    fn test_stream_schema_builder() {
        let schema = StreamSchema::new()
            .with_key(false)
            .with_id(false)
            .with_timestamp(true)
            .with_sequence(true)
            .add_field("action")
            .add_field("user");

        assert!(!schema.include_key);
        assert!(!schema.include_id);
        assert!(schema.include_timestamp);
        assert!(schema.include_sequence);
        assert_eq!(schema.fields, vec!["action", "user"]);
    }

    // One stream with two entries yields two rows; column count reflects the
    // enabled metadata columns plus the single user field.
    #[test]
    fn test_streams_to_record_batch_basic() {
        let mut fields = HashMap::new();
        fields.insert("action".to_string(), "login".to_string());

        let data = vec![StreamData {
            key: "stream:1".to_string(),
            entries: vec![
                StreamEntry {
                    id: "1234567890123-0".to_string(),
                    timestamp_ms: 1234567890123,
                    sequence: 0,
                    fields: fields.clone(),
                },
                StreamEntry {
                    id: "1234567890124-0".to_string(),
                    timestamp_ms: 1234567890124,
                    sequence: 0,
                    fields: fields.clone(),
                },
            ],
        }];

        let schema = StreamSchema::new().add_field("action");
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 2);
        // key + id + timestamp + action = 4 columns
        assert_eq!(batch.num_columns(), 4);
    }

    // The sequence column carries the "-N" suffix parsed from the entry ID.
    #[test]
    fn test_streams_to_record_batch_with_sequence() {
        let data = vec![StreamData {
            key: "stream:1".to_string(),
            entries: vec![StreamEntry {
                id: "1234567890123-5".to_string(),
                timestamp_ms: 1234567890123,
                sequence: 5,
                fields: HashMap::new(),
            }],
        }];

        let schema = StreamSchema::new().with_sequence(true);
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 1);

        // Check sequence value
        let seq_col = batch
            .column(3) // key, id, timestamp, sequence
            .as_any()
            .downcast_ref::<arrow::array::UInt64Array>()
            .unwrap();
        assert_eq!(seq_col.value(0), 5);
    }

    // Entries missing a requested field produce nulls in that column.
    #[test]
    fn test_streams_to_record_batch_missing_field() {
        use arrow::array::Array;

        let mut fields1 = HashMap::new();
        fields1.insert("action".to_string(), "login".to_string());

        let fields2 = HashMap::new(); // No action field

        let data = vec![StreamData {
            key: "stream:1".to_string(),
            entries: vec![
                StreamEntry {
                    id: "1234567890123-0".to_string(),
                    timestamp_ms: 1234567890123,
                    sequence: 0,
                    fields: fields1,
                },
                StreamEntry {
                    id: "1234567890124-0".to_string(),
                    timestamp_ms: 1234567890124,
                    sequence: 0,
                    fields: fields2,
                },
            ],
        }];

        let schema = StreamSchema::new().add_field("action");
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        // Check that second entry has null action
        let action_col = batch
            .column(3) // key, id, timestamp, action
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();
        assert_eq!(action_col.value(0), "login");
        assert!(action_col.is_null(1));
    }

    // Row index column starts at the supplied offset and increments per row;
    // it is emitted as the FIRST column when enabled.
    #[test]
    fn test_streams_to_record_batch_with_row_index() {
        let data = vec![StreamData {
            key: "stream:1".to_string(),
            entries: vec![
                StreamEntry {
                    id: "1234567890123-0".to_string(),
                    timestamp_ms: 1234567890123,
                    sequence: 0,
                    fields: HashMap::new(),
                },
                StreamEntry {
                    id: "1234567890124-0".to_string(),
                    timestamp_ms: 1234567890124,
                    sequence: 0,
                    fields: HashMap::new(),
                },
            ],
        }];

        let schema = StreamSchema::new().with_row_index(true);
        let batch = streams_to_record_batch(&data, &schema, 100).unwrap();

        // Check row indices start at offset
        let idx_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow::array::UInt64Array>()
            .unwrap();
        assert_eq!(idx_col.value(0), 100);
        assert_eq!(idx_col.value(1), 101);
    }

    // A stream with no entries produces an empty (zero-row) batch, not an error.
    #[test]
    fn test_empty_stream() {
        let data = vec![StreamData {
            key: "empty:stream".to_string(),
            entries: vec![],
        }];

        let schema = StreamSchema::new();
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 0);
    }
}