// polars_redis/types/set/convert.rs
//! Conversion of Redis set data into Arrow record batches.

use std::sync::Arc;

use arrow::array::{ArrayRef, RecordBatch, StringBuilder, UInt64Builder};
use arrow::datatypes::{DataType, Field, Schema};

use super::reader::SetData;
use crate::error::Result;

/// Column layout used when flattening Redis sets into an Arrow record batch.
///
/// Every set member becomes one output row. Columns are emitted in the order:
/// optional row index, optional key, member.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SetSchema {
    /// Whether to emit a column holding the key each member came from.
    pub include_key: bool,
    /// Name of the key column (only used when `include_key` is true).
    pub key_column_name: String,
    /// Name of the column holding the set members.
    pub member_column_name: String,
    /// Whether to emit a `UInt64` row-index column as the first column.
    pub include_row_index: bool,
    /// Name of the row-index column (only used when `include_row_index` is true).
    pub row_index_column_name: String,
}

impl Default for SetSchema {
    /// Key column `_key` enabled, member column `member`, row index `_index` disabled.
    fn default() -> Self {
        Self {
            include_key: true,
            key_column_name: "_key".to_string(),
            member_column_name: "member".to_string(),
            include_row_index: false,
            row_index_column_name: "_index".to_string(),
        }
    }
}

57impl SetSchema {
58 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn with_key(mut self, include: bool) -> Self {
65 self.include_key = include;
66 self
67 }
68
69 pub fn with_key_column_name(mut self, name: &str) -> Self {
71 self.key_column_name = name.to_string();
72 self
73 }
74
75 pub fn with_member_column_name(mut self, name: &str) -> Self {
77 self.member_column_name = name.to_string();
78 self
79 }
80
81 pub fn with_row_index(mut self, include: bool) -> Self {
83 self.include_row_index = include;
84 self
85 }
86
87 pub fn with_row_index_column_name(mut self, name: &str) -> Self {
89 self.row_index_column_name = name.to_string();
90 self
91 }
92
93 pub fn to_arrow_schema(&self) -> Schema {
95 let mut fields = Vec::new();
96
97 if self.include_row_index {
98 fields.push(Field::new(
99 &self.row_index_column_name,
100 DataType::UInt64,
101 false,
102 ));
103 }
104
105 if self.include_key {
106 fields.push(Field::new(&self.key_column_name, DataType::Utf8, false));
107 }
108
109 fields.push(Field::new(&self.member_column_name, DataType::Utf8, false));
110
111 Schema::new(fields)
112 }
113}
114
115pub fn sets_to_record_batch(
119 data: &[SetData],
120 schema: &SetSchema,
121 row_index_offset: u64,
122) -> Result<RecordBatch> {
123 let total_members: usize = data.iter().map(|s| s.members.len()).sum();
125
126 let arrow_schema = Arc::new(schema.to_arrow_schema());
127 let mut columns: Vec<ArrayRef> = Vec::new();
128
129 if schema.include_row_index {
131 let mut builder = UInt64Builder::with_capacity(total_members);
132 let mut idx = row_index_offset;
133 for set_data in data {
134 for _ in &set_data.members {
135 builder.append_value(idx);
136 idx += 1;
137 }
138 }
139 columns.push(Arc::new(builder.finish()));
140 }
141
142 if schema.include_key {
144 let mut builder = StringBuilder::with_capacity(total_members, total_members * 32);
145 for set_data in data {
146 for _ in &set_data.members {
147 builder.append_value(&set_data.key);
148 }
149 }
150 columns.push(Arc::new(builder.finish()));
151 }
152
153 let mut builder = StringBuilder::with_capacity(total_members, total_members * 64);
155 for set_data in data {
156 for member in &set_data.members {
157 builder.append_value(member);
158 }
159 }
160 columns.push(Arc::new(builder.finish()));
161
162 RecordBatch::try_new(arrow_schema, columns).map_err(|e| {
163 crate::error::Error::TypeConversion(format!("Failed to create RecordBatch: {}", e))
164 })
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Downcasts column `idx` of `batch` to a Utf8 array and collects its values.
    fn string_values(batch: &RecordBatch, idx: usize) -> Vec<String> {
        let col = batch
            .column(idx)
            .as_any()
            .downcast_ref::<arrow::array::StringArray>()
            .expect("column should be a Utf8 StringArray");
        (0..batch.num_rows())
            .map(|i| col.value(i).to_string())
            .collect()
    }

    /// Downcasts column `idx` of `batch` to a UInt64 array and collects its values.
    fn u64_values(batch: &RecordBatch, idx: usize) -> Vec<u64> {
        let col = batch
            .column(idx)
            .as_any()
            .downcast_ref::<arrow::array::UInt64Array>()
            .expect("column should be a UInt64Array");
        (0..batch.num_rows()).map(|i| col.value(i)).collect()
    }

    /// Collects the field names of `batch`'s schema in column order.
    fn column_names(batch: &RecordBatch) -> Vec<String> {
        let schema = batch.schema();
        schema.fields().iter().map(|f| f.name().clone()).collect()
    }

    #[test]
    fn test_set_schema_default() {
        let schema = SetSchema::new();
        assert!(schema.include_key);
        assert_eq!(schema.key_column_name, "_key");
        assert_eq!(schema.member_column_name, "member");
        assert!(!schema.include_row_index);
        assert_eq!(schema.row_index_column_name, "_index");
    }

    #[test]
    fn test_set_schema_builder() {
        let schema = SetSchema::new()
            .with_key(false)
            .with_key_column_name("k")
            .with_member_column_name("value")
            .with_row_index(true)
            .with_row_index_column_name("idx");

        assert!(!schema.include_key);
        assert_eq!(schema.key_column_name, "k");
        assert_eq!(schema.member_column_name, "value");
        assert!(schema.include_row_index);
        assert_eq!(schema.row_index_column_name, "idx");
    }

    #[test]
    fn test_to_arrow_schema_field_order_and_types() {
        let schema = SetSchema::new().with_row_index(true).to_arrow_schema();
        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|f| f.name().as_str())
            .collect();
        assert_eq!(names, vec!["_index", "_key", "member"]);
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Utf8);
        assert!(schema.fields().iter().all(|f| !f.is_nullable()));
    }

    #[test]
    fn test_sets_to_record_batch_basic() {
        let data = vec![
            SetData {
                key: "set:1".to_string(),
                members: vec!["a".to_string(), "b".to_string(), "c".to_string()],
            },
            SetData {
                key: "set:2".to_string(),
                members: vec!["x".to_string(), "y".to_string()],
            },
        ];

        let schema = SetSchema::new();
        let batch = sets_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 5);
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(column_names(&batch), vec!["_key", "member"]);
        // The key is repeated once per member, in input order.
        assert_eq!(
            string_values(&batch, 0),
            vec!["set:1", "set:1", "set:1", "set:2", "set:2"]
        );
        assert_eq!(string_values(&batch, 1), vec!["a", "b", "c", "x", "y"]);
    }

    #[test]
    fn test_sets_to_record_batch_with_row_index() {
        let data = vec![
            SetData {
                key: "set:1".to_string(),
                members: vec!["a".to_string(), "b".to_string()],
            },
            SetData {
                key: "set:2".to_string(),
                members: vec!["c".to_string()],
            },
        ];

        let schema = SetSchema::new().with_row_index(true);
        let batch = sets_to_record_batch(&data, &schema, 10).unwrap();

        assert_eq!(batch.num_rows(), 3);
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(column_names(&batch), vec!["_index", "_key", "member"]);
        // The row index starts at the offset and continues across sets.
        assert_eq!(u64_values(&batch, 0), vec![10, 11, 12]);
        assert_eq!(string_values(&batch, 2), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_sets_to_record_batch_no_key() {
        let data = vec![SetData {
            key: "set:1".to_string(),
            members: vec!["a".to_string()],
        }];

        let schema = SetSchema::new().with_key(false);
        let batch = sets_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_columns(), 1);
        assert_eq!(column_names(&batch), vec!["member"]);
        assert_eq!(string_values(&batch, 0), vec!["a"]);
    }

    #[test]
    fn test_sets_to_record_batch_custom_names() {
        let data = vec![SetData {
            key: "s".to_string(),
            members: vec!["m".to_string()],
        }];

        let schema = SetSchema::new()
            .with_key_column_name("k")
            .with_member_column_name("v");
        let batch = sets_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(column_names(&batch), vec!["k", "v"]);
    }

    #[test]
    fn test_empty_set() {
        let data = vec![SetData {
            key: "empty:set".to_string(),
            members: vec![],
        }];

        let schema = SetSchema::new();
        let batch = sets_to_record_batch(&data, &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 2);
    }

    #[test]
    fn test_empty_set_mixed_with_non_empty() {
        let data = vec![
            SetData {
                key: "empty:set".to_string(),
                members: vec![],
            },
            SetData {
                key: "set:1".to_string(),
                members: vec!["a".to_string()],
            },
        ];

        let batch = sets_to_record_batch(&data, &SetSchema::new(), 0).unwrap();

        assert_eq!(batch.num_rows(), 1);
        assert_eq!(string_values(&batch, 0), vec!["set:1"]);
        assert_eq!(string_values(&batch, 1), vec!["a"]);
    }

    #[test]
    fn test_empty_input() {
        let schema = SetSchema::new().with_row_index(true);
        let batch = sets_to_record_batch(&[], &schema, 0).unwrap();

        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 3);
    }
}