1use std::sync::Arc;
2
3use arrow::datatypes::{DataType as ArrowDataType, Field as ArrowField, Schema, TimeUnit};
4
5use crate::error::WpArrowError;
6
/// Logical column types understood by WP schemas.
///
/// [`parse_wp_type`] reads these from text (`chars`, `digit`, `float`, `bool`,
/// `time`, `ip`, `hex`, `array<...>`), and [`to_arrow_type`] maps each one onto
/// an Arrow data type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WpDataType {
    /// Textual value; mapped to Arrow `Utf8`.
    Chars,
    /// Integer value; mapped to Arrow `Int64`.
    Digit,
    /// Floating-point value; mapped to Arrow `Float64`.
    Float,
    /// Boolean value; mapped to Arrow `Boolean`.
    Bool,
    /// Point in time; mapped to an Arrow nanosecond `Timestamp` with no time zone.
    Time,
    /// Presumably an IP address; kept as text (Arrow `Utf8`).
    Ip,
    /// Presumably a hex-encoded value; kept as text (Arrow `Utf8`).
    Hex,
    /// Homogeneous list whose elements all share the boxed inner type;
    /// mapped to an Arrow `List`.
    Array(Box<Self>),
}
19
20#[derive(Debug, Clone, PartialEq, Eq)]
22pub struct FieldDef {
23 pub name: String,
24 pub data_type: WpDataType,
25 pub nullable: bool,
26}
27
28impl FieldDef {
29 pub fn new(name: impl Into<String>, data_type: WpDataType) -> Self {
30 Self {
31 name: name.into(),
32 data_type,
33 nullable: true,
34 }
35 }
36
37 pub fn with_nullable(mut self, nullable: bool) -> Self {
38 self.nullable = nullable;
39 self
40 }
41}
42
43pub fn to_arrow_type(wp_type: &WpDataType) -> ArrowDataType {
45 match wp_type {
46 WpDataType::Chars => ArrowDataType::Utf8,
47 WpDataType::Digit => ArrowDataType::Int64,
48 WpDataType::Float => ArrowDataType::Float64,
49 WpDataType::Bool => ArrowDataType::Boolean,
50 WpDataType::Time => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
51 WpDataType::Ip => ArrowDataType::Utf8,
52 WpDataType::Hex => ArrowDataType::Utf8,
53 WpDataType::Array(inner) => {
54 let inner_arrow = to_arrow_type(inner);
55 ArrowDataType::List(Arc::new(ArrowField::new("item", inner_arrow, true)))
56 }
57 }
58}
59
60pub fn to_arrow_field(field: &FieldDef) -> Result<ArrowField, WpArrowError> {
64 if field.name.is_empty() {
65 return Err(WpArrowError::EmptyFieldName);
66 }
67 let arrow_type = to_arrow_type(&field.data_type);
68 Ok(ArrowField::new(&field.name, arrow_type, field.nullable))
69}
70
71pub fn to_arrow_schema(fields: &[FieldDef]) -> Result<Schema, WpArrowError> {
73 let arrow_fields: Vec<ArrowField> = fields
74 .iter()
75 .map(to_arrow_field)
76 .collect::<Result<_, _>>()?;
77 Ok(Schema::new(arrow_fields))
78}
79
80pub fn parse_wp_type(s: &str) -> Result<WpDataType, WpArrowError> {
88 let s = s.trim();
89 let lower = s.to_ascii_lowercase();
90
91 match lower.as_str() {
92 "chars" => Ok(WpDataType::Chars),
93 "digit" => Ok(WpDataType::Digit),
94 "float" => Ok(WpDataType::Float),
95 "bool" => Ok(WpDataType::Bool),
96 "time" => Ok(WpDataType::Time),
97 "ip" => Ok(WpDataType::Ip),
98 "hex" => Ok(WpDataType::Hex),
99 _ if lower.starts_with("array<") && lower.ends_with('>') => {
100 let inner_str = &s[6..s.len() - 1];
101 let inner_trimmed = inner_str.trim();
102 if inner_trimmed.is_empty() {
103 return Err(WpArrowError::InvalidArrayInnerType(String::new()));
104 }
105 let inner = parse_wp_type(inner_trimmed)?;
106 Ok(WpDataType::Array(Box::new(inner)))
107 }
108 _ => Err(WpArrowError::UnsupportedDataType(s.to_string())),
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    // ----- to_arrow_type ---------------------------------------------------

    #[test]
    fn arrow_type_chars() {
        assert_eq!(to_arrow_type(&WpDataType::Chars), ArrowDataType::Utf8);
    }

    #[test]
    fn arrow_type_digit() {
        assert_eq!(to_arrow_type(&WpDataType::Digit), ArrowDataType::Int64);
    }

    #[test]
    fn arrow_type_float() {
        assert_eq!(to_arrow_type(&WpDataType::Float), ArrowDataType::Float64);
    }

    #[test]
    fn arrow_type_bool() {
        assert_eq!(to_arrow_type(&WpDataType::Bool), ArrowDataType::Boolean);
    }

    #[test]
    fn arrow_type_time() {
        assert_eq!(
            to_arrow_type(&WpDataType::Time),
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[test]
    fn arrow_type_ip() {
        assert_eq!(to_arrow_type(&WpDataType::Ip), ArrowDataType::Utf8);
    }

    #[test]
    fn arrow_type_hex() {
        assert_eq!(to_arrow_type(&WpDataType::Hex), ArrowDataType::Utf8);
    }

    #[test]
    fn arrow_type_array_digit() {
        let wp = WpDataType::Array(Box::new(WpDataType::Digit));
        let arrow = to_arrow_type(&wp);
        assert_eq!(
            arrow,
            ArrowDataType::List(Arc::new(ArrowField::new(
                "item",
                ArrowDataType::Int64,
                true
            )))
        );
    }

    #[test]
    fn arrow_type_array_chars() {
        let wp = WpDataType::Array(Box::new(WpDataType::Chars));
        let arrow = to_arrow_type(&wp);
        assert_eq!(
            arrow,
            ArrowDataType::List(Arc::new(ArrowField::new("item", ArrowDataType::Utf8, true)))
        );
    }

    #[test]
    fn arrow_type_nested_array() {
        let wp = WpDataType::Array(Box::new(WpDataType::Array(Box::new(WpDataType::Float))));
        let inner_list = ArrowDataType::List(Arc::new(ArrowField::new(
            "item",
            ArrowDataType::Float64,
            true,
        )));
        let expected = ArrowDataType::List(Arc::new(ArrowField::new("item", inner_list, true)));
        assert_eq!(to_arrow_type(&wp), expected);
    }

    // ----- to_arrow_field --------------------------------------------------

    #[test]
    fn arrow_field_basic() {
        let fd = FieldDef::new("src_ip", WpDataType::Ip);
        let field = to_arrow_field(&fd).unwrap();
        assert_eq!(field.name(), "src_ip");
        assert_eq!(field.data_type(), &ArrowDataType::Utf8);
        assert!(field.is_nullable());
    }

    #[test]
    fn arrow_field_non_nullable() {
        let fd = FieldDef::new("count", WpDataType::Digit).with_nullable(false);
        let field = to_arrow_field(&fd).unwrap();
        assert!(!field.is_nullable());
    }

    #[test]
    fn arrow_field_empty_name_errors() {
        let fd = FieldDef::new("", WpDataType::Chars);
        assert_eq!(to_arrow_field(&fd), Err(WpArrowError::EmptyFieldName));
    }

    // ----- to_arrow_schema -------------------------------------------------

    #[test]
    fn arrow_schema_firewall_log() {
        let fields = vec![
            FieldDef::new("src_ip", WpDataType::Ip),
            FieldDef::new("dst_ip", WpDataType::Ip),
            FieldDef::new("port", WpDataType::Digit),
            FieldDef::new("protocol", WpDataType::Chars),
            FieldDef::new("timestamp", WpDataType::Time),
            FieldDef::new("allowed", WpDataType::Bool),
        ];
        let schema = to_arrow_schema(&fields).unwrap();
        assert_eq!(schema.fields().len(), 6);
        assert_eq!(schema.field(0).name(), "src_ip");
        assert_eq!(schema.field(2).data_type(), &ArrowDataType::Int64);
        assert_eq!(
            schema.field(4).data_type(),
            &ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[test]
    fn arrow_schema_with_array_field() {
        let fields = vec![
            FieldDef::new("name", WpDataType::Chars),
            FieldDef::new("tags", WpDataType::Array(Box::new(WpDataType::Chars))),
        ];
        let schema = to_arrow_schema(&fields).unwrap();
        assert_eq!(schema.fields().len(), 2);
        assert!(matches!(
            schema.field(1).data_type(),
            ArrowDataType::List(_)
        ));
    }

    #[test]
    fn arrow_schema_empty_fields() {
        let schema = to_arrow_schema(&[]).unwrap();
        assert_eq!(schema.fields().len(), 0);
    }

    #[test]
    fn arrow_schema_error_propagation() {
        let fields = vec![
            FieldDef::new("ok", WpDataType::Chars),
            FieldDef::new("", WpDataType::Digit),
        ];
        assert_eq!(to_arrow_schema(&fields), Err(WpArrowError::EmptyFieldName));
    }

    #[test]
    fn arrow_schema_preserves_nullability() {
        let fields = vec![
            FieldDef::new("id", WpDataType::Digit).with_nullable(false),
            FieldDef::new("note", WpDataType::Chars),
        ];
        let schema = to_arrow_schema(&fields).unwrap();
        assert!(!schema.field(0).is_nullable());
        assert!(schema.field(1).is_nullable());
    }

    // ----- parse_wp_type ---------------------------------------------------

    #[test]
    fn parse_chars() {
        assert_eq!(parse_wp_type("chars"), Ok(WpDataType::Chars));
    }

    #[test]
    fn parse_digit() {
        assert_eq!(parse_wp_type("digit"), Ok(WpDataType::Digit));
    }

    #[test]
    fn parse_float() {
        assert_eq!(parse_wp_type("float"), Ok(WpDataType::Float));
    }

    #[test]
    fn parse_bool() {
        assert_eq!(parse_wp_type("bool"), Ok(WpDataType::Bool));
    }

    #[test]
    fn parse_time() {
        assert_eq!(parse_wp_type("time"), Ok(WpDataType::Time));
    }

    #[test]
    fn parse_ip() {
        assert_eq!(parse_wp_type("ip"), Ok(WpDataType::Ip));
    }

    #[test]
    fn parse_hex() {
        assert_eq!(parse_wp_type("hex"), Ok(WpDataType::Hex));
    }

    #[test]
    fn parse_case_insensitive() {
        assert_eq!(parse_wp_type("CHARS"), Ok(WpDataType::Chars));
        assert_eq!(parse_wp_type("Digit"), Ok(WpDataType::Digit));
        assert_eq!(parse_wp_type("BOOL"), Ok(WpDataType::Bool));
    }

    #[test]
    fn parse_trims_scalar_names() {
        assert_eq!(parse_wp_type("  hex\t"), Ok(WpDataType::Hex));
    }

    #[test]
    fn parse_array_chars() {
        assert_eq!(
            parse_wp_type("array<chars>"),
            Ok(WpDataType::Array(Box::new(WpDataType::Chars)))
        );
    }

    #[test]
    fn parse_array_digit() {
        assert_eq!(
            parse_wp_type("array<digit>"),
            Ok(WpDataType::Array(Box::new(WpDataType::Digit)))
        );
    }

    #[test]
    fn parse_array_case_insensitive_prefix() {
        assert_eq!(
            parse_wp_type("ARRAY<Digit>"),
            Ok(WpDataType::Array(Box::new(WpDataType::Digit)))
        );
    }

    #[test]
    fn parse_nested_array() {
        assert_eq!(
            parse_wp_type("array<array<float>>"),
            Ok(WpDataType::Array(Box::new(WpDataType::Array(Box::new(
                WpDataType::Float
            )))))
        );
    }

    #[test]
    fn parse_array_with_whitespace() {
        assert_eq!(
            parse_wp_type(" array< chars > "),
            Ok(WpDataType::Array(Box::new(WpDataType::Chars)))
        );
    }

    #[test]
    fn parse_unsupported_type() {
        let err = parse_wp_type("unknown").unwrap_err();
        assert_eq!(
            err,
            WpArrowError::UnsupportedDataType("unknown".to_string())
        );
    }

    #[test]
    fn parse_unsupported_reports_trimmed_name() {
        assert_eq!(
            parse_wp_type("  unknown  "),
            Err(WpArrowError::UnsupportedDataType("unknown".to_string()))
        );
    }

    #[test]
    fn parse_non_ascii_is_rejected_without_panic() {
        assert_eq!(
            parse_wp_type("数组<chars>"),
            Err(WpArrowError::UnsupportedDataType("数组<chars>".to_string()))
        );
    }

    #[test]
    fn parse_array_empty_inner() {
        let err = parse_wp_type("array<>").unwrap_err();
        assert_eq!(err, WpArrowError::InvalidArrayInnerType(String::new()));
    }

    #[test]
    fn parse_array_whitespace_only_inner() {
        assert_eq!(
            parse_wp_type("array<   >"),
            Err(WpArrowError::InvalidArrayInnerType(String::new()))
        );
    }

    #[test]
    fn parse_nested_array_empty_inner() {
        assert_eq!(
            parse_wp_type("array<array<>>"),
            Err(WpArrowError::InvalidArrayInnerType(String::new()))
        );
    }

    #[test]
    fn parse_unclosed_array() {
        assert_eq!(
            parse_wp_type("array<chars"),
            Err(WpArrowError::UnsupportedDataType("array<chars".to_string()))
        );
    }

    #[test]
    fn parse_array_invalid_inner() {
        let err = parse_wp_type("array<invalid>").unwrap_err();
        assert_eq!(
            err,
            WpArrowError::UnsupportedDataType("invalid".to_string())
        );
    }

    #[test]
    fn parse_nested_array_invalid_inner() {
        assert_eq!(
            parse_wp_type("array<array<bogus>>"),
            Err(WpArrowError::UnsupportedDataType("bogus".to_string()))
        );
    }

    #[test]
    fn parse_array_then_convert_to_arrow() {
        let wp = parse_wp_type("array<ip>").unwrap();
        assert_eq!(
            to_arrow_type(&wp),
            ArrowDataType::List(Arc::new(ArrowField::new("item", ArrowDataType::Utf8, true)))
        );
    }

    // ----- derived traits / defaults ---------------------------------------

    #[test]
    fn wp_data_type_clone_eq() {
        let a = WpDataType::Array(Box::new(WpDataType::Chars));
        let b = a.clone();
        assert_eq!(a, b);
    }

    #[test]
    fn wp_data_type_hash_consistent() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(WpDataType::Digit);
        set.insert(WpDataType::Digit);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn field_def_default_nullable() {
        let fd = FieldDef::new("test", WpDataType::Bool);
        assert!(fd.nullable);
    }
}
418}