//! polars_redis/types/stream/convert.rs
//!
//! Conversion of Redis stream data into Arrow record batches.

use std::sync::Arc;

use arrow::array::{ArrayRef, RecordBatch, StringBuilder, UInt64Builder};
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};

use super::reader::StreamData;
use crate::error::Result;
/// Configuration describing how Redis stream entries are mapped to Arrow
/// columns: which metadata columns to emit, what to name them, and which
/// entry fields to project.
///
/// Defaults (see the `Default` impl): key, id and timestamp columns enabled;
/// sequence and row-index columns disabled; no entry fields projected.
#[derive(Debug, Clone)]
pub struct StreamSchema {
    // Emit the Redis stream key as a column.
    pub include_key: bool,
    // Column name for the stream key (default "_key").
    pub key_column_name: String,
    // Emit the raw entry id ("<ms>-<seq>") as a column.
    pub include_id: bool,
    // Column name for the entry id (default "_id").
    pub id_column_name: String,
    // Emit the entry's millisecond timestamp as a column.
    pub include_timestamp: bool,
    // Column name for the timestamp (default "_ts").
    pub timestamp_column_name: String,
    // Emit the entry's sequence number as a column.
    pub include_sequence: bool,
    // Column name for the sequence number (default "_seq").
    pub sequence_column_name: String,
    // Entry field names to project as nullable string columns.
    pub fields: Vec<String>,
    // Emit a monotonically increasing row index as a column.
    pub include_row_index: bool,
    // Column name for the row index (default "_index").
    pub row_index_column_name: String,
}
64
65impl Default for StreamSchema {
66 fn default() -> Self {
67 Self {
68 include_key: true,
69 key_column_name: "_key".to_string(),
70 include_id: true,
71 id_column_name: "_id".to_string(),
72 include_timestamp: true,
73 timestamp_column_name: "_ts".to_string(),
74 include_sequence: false,
75 sequence_column_name: "_seq".to_string(),
76 fields: Vec::new(),
77 include_row_index: false,
78 row_index_column_name: "_index".to_string(),
79 }
80 }
81}
82
83impl StreamSchema {
84 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn with_fields(fields: Vec<String>) -> Self {
91 Self {
92 fields,
93 ..Default::default()
94 }
95 }
96
97 pub fn with_key(mut self, include: bool) -> Self {
99 self.include_key = include;
100 self
101 }
102
103 pub fn with_key_column_name(mut self, name: &str) -> Self {
105 self.key_column_name = name.to_string();
106 self
107 }
108
109 pub fn with_id(mut self, include: bool) -> Self {
111 self.include_id = include;
112 self
113 }
114
115 pub fn with_id_column_name(mut self, name: &str) -> Self {
117 self.id_column_name = name.to_string();
118 self
119 }
120
121 pub fn with_timestamp(mut self, include: bool) -> Self {
123 self.include_timestamp = include;
124 self
125 }
126
127 pub fn with_timestamp_column_name(mut self, name: &str) -> Self {
129 self.timestamp_column_name = name.to_string();
130 self
131 }
132
133 pub fn with_sequence(mut self, include: bool) -> Self {
135 self.include_sequence = include;
136 self
137 }
138
139 pub fn with_sequence_column_name(mut self, name: &str) -> Self {
141 self.sequence_column_name = name.to_string();
142 self
143 }
144
145 pub fn add_field(mut self, name: &str) -> Self {
147 self.fields.push(name.to_string());
148 self
149 }
150
151 pub fn set_fields(mut self, fields: Vec<String>) -> Self {
153 self.fields = fields;
154 self
155 }
156
157 pub fn with_row_index(mut self, include: bool) -> Self {
159 self.include_row_index = include;
160 self
161 }
162
163 pub fn with_row_index_column_name(mut self, name: &str) -> Self {
165 self.row_index_column_name = name.to_string();
166 self
167 }
168
169 pub fn to_arrow_schema(&self) -> Schema {
171 let mut arrow_fields = Vec::new();
172
173 if self.include_row_index {
174 arrow_fields.push(Field::new(
175 &self.row_index_column_name,
176 DataType::UInt64,
177 false,
178 ));
179 }
180
181 if self.include_key {
182 arrow_fields.push(Field::new(&self.key_column_name, DataType::Utf8, false));
183 }
184
185 if self.include_id {
186 arrow_fields.push(Field::new(&self.id_column_name, DataType::Utf8, false));
187 }
188
189 if self.include_timestamp {
190 arrow_fields.push(Field::new(
191 &self.timestamp_column_name,
192 DataType::Timestamp(TimeUnit::Millisecond, None),
193 false,
194 ));
195 }
196
197 if self.include_sequence {
198 arrow_fields.push(Field::new(
199 &self.sequence_column_name,
200 DataType::UInt64,
201 false,
202 ));
203 }
204
205 for field_name in &self.fields {
207 arrow_fields.push(Field::new(field_name, DataType::Utf8, true));
208 }
209
210 Schema::new(arrow_fields)
211 }
212}
213
214pub fn streams_to_record_batch(
218 data: &[StreamData],
219 schema: &StreamSchema,
220 row_index_offset: u64,
221) -> Result<RecordBatch> {
222 let total_entries: usize = data.iter().map(|s| s.entries.len()).sum();
224
225 let arrow_schema = Arc::new(schema.to_arrow_schema());
226 let mut columns: Vec<ArrayRef> = Vec::new();
227
228 if schema.include_row_index {
230 let mut builder = UInt64Builder::with_capacity(total_entries);
231 let mut idx = row_index_offset;
232 for stream_data in data {
233 for _ in &stream_data.entries {
234 builder.append_value(idx);
235 idx += 1;
236 }
237 }
238 columns.push(Arc::new(builder.finish()));
239 }
240
241 if schema.include_key {
243 let mut builder = StringBuilder::with_capacity(total_entries, total_entries * 32);
244 for stream_data in data {
245 for _ in &stream_data.entries {
246 builder.append_value(&stream_data.key);
247 }
248 }
249 columns.push(Arc::new(builder.finish()));
250 }
251
252 if schema.include_id {
254 let mut builder = StringBuilder::with_capacity(total_entries, total_entries * 24);
255 for stream_data in data {
256 for entry in &stream_data.entries {
257 builder.append_value(&entry.id);
258 }
259 }
260 columns.push(Arc::new(builder.finish()));
261 }
262
263 if schema.include_timestamp {
265 let mut values = Vec::with_capacity(total_entries);
266 for stream_data in data {
267 for entry in &stream_data.entries {
268 values.push(entry.timestamp_ms);
269 }
270 }
271 let ts_array = arrow::array::TimestampMillisecondArray::from(values);
272 columns.push(Arc::new(ts_array));
273 }
274
275 if schema.include_sequence {
277 let mut builder = UInt64Builder::with_capacity(total_entries);
278 for stream_data in data {
279 for entry in &stream_data.entries {
280 builder.append_value(entry.sequence);
281 }
282 }
283 columns.push(Arc::new(builder.finish()));
284 }
285
286 for field_name in &schema.fields {
288 let mut builder = StringBuilder::with_capacity(total_entries, total_entries * 32);
289 for stream_data in data {
290 for entry in &stream_data.entries {
291 match entry.fields.get(field_name) {
292 Some(value) => builder.append_value(value),
293 None => builder.append_null(),
294 }
295 }
296 }
297 columns.push(Arc::new(builder.finish()));
298 }
299
300 RecordBatch::try_new(arrow_schema, columns).map_err(|e| {
301 crate::error::Error::TypeConversion(format!("Failed to create RecordBatch: {}", e))
302 })
303}
304
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::types::stream::reader::StreamEntry;

    /// Fixture helper: builds one entry from its id, millisecond timestamp,
    /// sequence number, and field map.
    fn entry(id: &str, ts: i64, seq: u64, fields: HashMap<String, String>) -> StreamEntry {
        StreamEntry {
            id: id.to_string(),
            timestamp_ms: ts,
            sequence: seq,
            fields,
        }
    }

    /// Fixture helper: wraps entries in a `StreamData` with the given key.
    fn stream(key: &str, entries: Vec<StreamEntry>) -> StreamData {
        StreamData {
            key: key.to_string(),
            entries,
        }
    }

    #[test]
    fn test_stream_schema_default() {
        let s = StreamSchema::new();
        assert!(s.include_key);
        assert!(s.include_id);
        assert!(s.include_timestamp);
        assert!(!s.include_sequence);
        assert!(!s.include_row_index);
        assert!(s.fields.is_empty());
    }

    #[test]
    fn test_stream_schema_builder() {
        let s = StreamSchema::new()
            .with_key(false)
            .with_id(false)
            .with_timestamp(true)
            .with_sequence(true)
            .add_field("action")
            .add_field("user");

        assert!(!s.include_key);
        assert!(!s.include_id);
        assert!(s.include_timestamp);
        assert!(s.include_sequence);
        assert_eq!(s.fields, vec!["action", "user"]);
    }

    #[test]
    fn test_streams_to_record_batch_basic() {
        let mut action = HashMap::new();
        action.insert("action".to_string(), "login".to_string());

        let data = vec![stream(
            "stream:1",
            vec![
                entry("1234567890123-0", 1234567890123, 0, action.clone()),
                entry("1234567890124-0", 1234567890124, 0, action.clone()),
            ],
        )];

        let schema = StreamSchema::new().add_field("action");
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        // Default metadata (_key, _id, _ts) plus the "action" column.
        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 4);
    }

    #[test]
    fn test_streams_to_record_batch_with_sequence() {
        let data = vec![stream(
            "stream:1",
            vec![entry("1234567890123-5", 1234567890123, 5, HashMap::new())],
        )];

        let schema = StreamSchema::new().with_sequence(true);
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 1);

        // Column order: _key, _id, _ts, _seq — sequence sits at index 3.
        let seq_col = batch
            .column(3)
            .as_any()
            .downcast_ref::<arrow::array::UInt64Array>()
            .unwrap();
        assert_eq!(seq_col.value(0), 5);
    }

    #[test]
    fn test_streams_to_record_batch_missing_field() {
        use arrow::array::Array;

        let mut with_action = HashMap::new();
        with_action.insert("action".to_string(), "login".to_string());

        let data = vec![stream(
            "stream:1",
            vec![
                entry("1234567890123-0", 1234567890123, 0, with_action),
                // Second entry deliberately lacks the "action" field.
                entry("1234567890124-0", 1234567890124, 0, HashMap::new()),
            ],
        )];

        let schema = StreamSchema::new().add_field("action");
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        // The "action" column follows _key, _id, _ts at index 3.
        let action_col = batch
            .column(3)
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .unwrap();
        assert_eq!(action_col.value(0), "login");
        assert!(action_col.is_null(1));
    }

    #[test]
    fn test_streams_to_record_batch_with_row_index() {
        let data = vec![stream(
            "stream:1",
            vec![
                entry("1234567890123-0", 1234567890123, 0, HashMap::new()),
                entry("1234567890124-0", 1234567890124, 0, HashMap::new()),
            ],
        )];

        let schema = StreamSchema::new().with_row_index(true);
        // Offset of 100 simulates a second batch in a paged scan.
        let batch = streams_to_record_batch(&data, &schema, 100).unwrap();

        let idx_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow::array::UInt64Array>()
            .unwrap();
        assert_eq!(idx_col.value(0), 100);
        assert_eq!(idx_col.value(1), 101);
    }

    #[test]
    fn test_empty_stream() {
        let data = vec![stream("empty:stream", vec![])];

        let schema = StreamSchema::new();
        let batch = streams_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 0);
    }
}