polars_redis/types/set/
convert.rs

1//! Arrow conversion for Redis set data.
2
3use std::sync::Arc;
4
5use arrow::array::{ArrayRef, RecordBatch, StringBuilder, UInt64Builder};
6use arrow::datatypes::{DataType, Field, Schema};
7
8use super::reader::SetData;
9use crate::error::Result;
10
/// Schema configuration for Redis set scanning.
///
/// Defines the output columns produced when scanning Redis sets. Each set
/// member becomes one row in the output DataFrame.
///
/// # Example
///
/// ```ignore
/// use polars_redis::SetSchema;
///
/// let schema = SetSchema::new()
///     .with_key(true)
///     .with_member_column_name("tag");
/// ```
///
/// # Output Schema
///
/// Columns are emitted in this order (names shown are the defaults):
///
/// - `_index` (optional, off by default): Row number (UInt64)
/// - `_key` (optional, on by default): The Redis key (Utf8)
/// - `member`: The set member value (Utf8)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SetSchema {
    /// Whether to include the Redis key as a column.
    pub include_key: bool,
    /// Name of the key column; only used when `include_key` is set.
    pub key_column_name: String,
    /// Name of the member column, which is always present.
    pub member_column_name: String,
    /// Whether to include a row index column.
    pub include_row_index: bool,
    /// Name of the row index column; only used when `include_row_index` is set.
    pub row_index_column_name: String,
}
44
45impl Default for SetSchema {
46    fn default() -> Self {
47        Self {
48            include_key: true,
49            key_column_name: "_key".to_string(),
50            member_column_name: "member".to_string(),
51            include_row_index: false,
52            row_index_column_name: "_index".to_string(),
53        }
54    }
55}
56
57impl SetSchema {
58    /// Create a new SetSchema with default settings.
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Set whether to include the key column.
64    pub fn with_key(mut self, include: bool) -> Self {
65        self.include_key = include;
66        self
67    }
68
69    /// Set the key column name.
70    pub fn with_key_column_name(mut self, name: &str) -> Self {
71        self.key_column_name = name.to_string();
72        self
73    }
74
75    /// Set the member column name.
76    pub fn with_member_column_name(mut self, name: &str) -> Self {
77        self.member_column_name = name.to_string();
78        self
79    }
80
81    /// Set whether to include a row index column.
82    pub fn with_row_index(mut self, include: bool) -> Self {
83        self.include_row_index = include;
84        self
85    }
86
87    /// Set the row index column name.
88    pub fn with_row_index_column_name(mut self, name: &str) -> Self {
89        self.row_index_column_name = name.to_string();
90        self
91    }
92
93    /// Build the Arrow schema for this configuration.
94    pub fn to_arrow_schema(&self) -> Schema {
95        let mut fields = Vec::new();
96
97        if self.include_row_index {
98            fields.push(Field::new(
99                &self.row_index_column_name,
100                DataType::UInt64,
101                false,
102            ));
103        }
104
105        if self.include_key {
106            fields.push(Field::new(&self.key_column_name, DataType::Utf8, false));
107        }
108
109        fields.push(Field::new(&self.member_column_name, DataType::Utf8, false));
110
111        Schema::new(fields)
112    }
113}
114
115/// Convert set data to an Arrow RecordBatch.
116///
117/// Each set member becomes a row in the output.
118pub fn sets_to_record_batch(
119    data: &[SetData],
120    schema: &SetSchema,
121    row_index_offset: u64,
122) -> Result<RecordBatch> {
123    // Count total members across all sets
124    let total_members: usize = data.iter().map(|s| s.members.len()).sum();
125
126    let arrow_schema = Arc::new(schema.to_arrow_schema());
127    let mut columns: Vec<ArrayRef> = Vec::new();
128
129    // Row index column
130    if schema.include_row_index {
131        let mut builder = UInt64Builder::with_capacity(total_members);
132        let mut idx = row_index_offset;
133        for set_data in data {
134            for _ in &set_data.members {
135                builder.append_value(idx);
136                idx += 1;
137            }
138        }
139        columns.push(Arc::new(builder.finish()));
140    }
141
142    // Key column
143    if schema.include_key {
144        let mut builder = StringBuilder::with_capacity(total_members, total_members * 32);
145        for set_data in data {
146            for _ in &set_data.members {
147                builder.append_value(&set_data.key);
148            }
149        }
150        columns.push(Arc::new(builder.finish()));
151    }
152
153    // Member column
154    let mut builder = StringBuilder::with_capacity(total_members, total_members * 64);
155    for set_data in data {
156        for member in &set_data.members {
157            builder.append_value(member);
158        }
159    }
160    columns.push(Arc::new(builder.finish()));
161
162    RecordBatch::try_new(arrow_schema, columns).map_err(|e| {
163        crate::error::Error::TypeConversion(format!("Failed to create RecordBatch: {}", e))
164    })
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SetData` fixture from borrowed strings.
    fn set_data(key: &str, members: &[&str]) -> SetData {
        SetData {
            key: key.to_string(),
            members: members.iter().map(|m| m.to_string()).collect(),
        }
    }

    /// A freshly created schema carries the documented defaults.
    #[test]
    fn test_set_schema_default() {
        let defaults = SetSchema::new();

        assert!(defaults.include_key);
        assert!(!defaults.include_row_index);
        assert_eq!(defaults.key_column_name, "_key");
        assert_eq!(defaults.member_column_name, "member");
    }

    /// Builder calls override only the settings they name.
    #[test]
    fn test_set_schema_builder() {
        let custom = SetSchema::new()
            .with_key(false)
            .with_member_column_name("value")
            .with_row_index(true)
            .with_row_index_column_name("idx");

        assert!(!custom.include_key);
        assert!(custom.include_row_index);
        assert_eq!(custom.member_column_name, "value");
        assert_eq!(custom.row_index_column_name, "idx");
    }

    /// Members of several sets are flattened into one row each.
    #[test]
    fn test_sets_to_record_batch_basic() {
        let sets = [
            set_data("set:1", &["a", "b", "c"]),
            set_data("set:2", &["x", "y"]),
        ];

        let batch = sets_to_record_batch(&sets, &SetSchema::new(), 0).unwrap();

        assert_eq!(batch.num_rows(), 5); // 3 + 2 members
        assert_eq!(batch.num_columns(), 2); // key + member
    }

    /// The row index column comes first and starts at the given offset.
    #[test]
    fn test_sets_to_record_batch_with_row_index() {
        let sets = [set_data("set:1", &["a", "b"])];
        let indexed = SetSchema::new().with_row_index(true);

        let batch = sets_to_record_batch(&sets, &indexed, 10).unwrap();

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 3); // index + key + member

        let indices = batch
            .column(0)
            .as_any()
            .downcast_ref::<arrow::array::UInt64Array>()
            .unwrap();
        assert_eq!(indices.value(0), 10);
        assert_eq!(indices.value(1), 11);
    }

    /// Disabling the key column leaves only the member column.
    #[test]
    fn test_sets_to_record_batch_no_key() {
        let sets = [set_data("set:1", &["a"])];
        let keyless = SetSchema::new().with_key(false);

        let batch = sets_to_record_batch(&sets, &keyless, 0).unwrap();

        assert_eq!(batch.num_columns(), 1); // member only
    }

    /// A set without members produces no rows.
    #[test]
    fn test_empty_set() {
        let sets = [set_data("empty:set", &[])];

        let batch = sets_to_record_batch(&sets, &SetSchema::new(), 0).unwrap();

        assert_eq!(batch.num_rows(), 0);
    }
}