1use polars::prelude::*;
2use regex::Regex;
3use std::collections::HashMap;
4
5use crate::error::{DtransformError, Result};
6use crate::parser::ast::*;
7
/// Tree-walking interpreter for parsed dtransform programs.
pub struct Executor {
    // Named DataFrames produced by assignment statements; consulted when a
    // pipeline starts from (or an expression references) a variable.
    variables: HashMap<String, DataFrame>,
}
11
12fn auto_detect_delimiter(content: &str, file_extension: Option<&str>) -> Result<(char, bool)> {
15 if file_extension == Some("tsv") {
17 let needs_trim = content.lines().take(100).any(|line| {
18 line.trim() != line || line.contains(" ")
19 });
20 return Ok(('\t', needs_trim));
21 }
22
23 if file_extension == Some("csv") {
25 let needs_trim = content.lines().take(100).any(|line| {
27 line.trim() != line || line.contains(" ")
28 });
29 return Ok((',', needs_trim));
30 }
31
32 let sample_lines: Vec<&str> = content
34 .lines()
35 .filter(|l| !l.trim().is_empty())
36 .take(100)
37 .collect();
38
39 if sample_lines.is_empty() {
40 return Err(DtransformError::InvalidOperation(
41 "File is empty or contains no data".to_string()
42 ));
43 }
44
45 let needs_trim = sample_lines.iter().any(|line| {
47 line.trim() != *line || line.contains(" ")
48 });
49
50 let detection_lines: Vec<String> = if needs_trim {
52 sample_lines.iter().map(|line| {
53 let trimmed = line.trim();
54 trimmed.split_whitespace().collect::<Vec<_>>().join(" ")
55 }).collect()
56 } else {
57 sample_lines.iter().map(|s| s.to_string()).collect()
58 };
59
60 let delimiters = [',', '\t', '|', ';', ' '];
62 let mut delimiter_counts: HashMap<char, Vec<usize>> = HashMap::new();
63
64 for line in &detection_lines {
65 for &delim in &delimiters {
66 let count = line.matches(delim).count();
67 delimiter_counts.entry(delim).or_insert_with(Vec::new).push(count);
68 }
69 }
70
71 let mut best_delimiter = None;
75 let mut best_score = 0.0;
76
77 for (&delim, counts) in &delimiter_counts {
78 let non_zero_counts: Vec<usize> = counts.iter().filter(|&&c| c > 0).copied().collect();
80
81 if non_zero_counts.is_empty() {
82 continue;
83 }
84
85 let min = *non_zero_counts.iter().min().unwrap();
87 let max = *non_zero_counts.iter().max().unwrap();
88 let avg = non_zero_counts.iter().sum::<usize>() as f64 / non_zero_counts.len() as f64;
89
90 let consistency = non_zero_counts.len() as f64 / detection_lines.len() as f64;
93
94 if min == max || (max as f64 - min as f64) / avg < 0.3 {
96 let score = avg * consistency;
97 if score > best_score {
98 best_score = score;
99 best_delimiter = Some(delim);
100 }
101 }
102 }
103
104 match best_delimiter {
105 Some(delim) => Ok((delim, needs_trim)),
106 None => {
107 let all_zero = delimiter_counts.values().all(|counts| {
110 counts.iter().all(|&c| c == 0)
111 });
112
113 if all_zero {
114 Ok((',', needs_trim))
116 } else {
117 Err(DtransformError::InvalidOperation(
119 "Could not auto-detect delimiter. The file format is ambiguous.\n\n\
120 Please specify the delimiter explicitly:\n\
121 • Comma: read('file', delimiter=',')\n\
122 • Tab: read('file', delimiter='\\t')\n\
123 • Pipe: read('file', delimiter='|')\n\
124 • Semicolon: read('file', delimiter=';')\n\
125 • Space: read('file', delimiter=' ')".to_string()
126 ))
127 }
128 }
129 }
130}
131
132impl Executor {
133 pub fn new() -> Self {
134 Self {
135 variables: HashMap::new(),
136 }
137 }
138
139 pub fn execute_program(&mut self, program: Program) -> Result<Option<DataFrame>> {
140 let mut last_result = None;
141
142 for statement in program.statements {
143 match statement {
144 Statement::Assignment { name, pipeline } => {
145 let df = self.execute_pipeline(pipeline)?;
146 self.variables.insert(name, df);
147 }
149 Statement::Pipeline(pipeline) => {
150 let df = self.execute_pipeline(pipeline)?;
151 last_result = Some(df);
152 }
153 }
154 }
155
156 Ok(last_result)
157 }
158
159 pub fn execute_statement(&mut self, statement: Statement) -> Result<Option<DataFrame>> {
160 match statement {
161 Statement::Assignment { name, pipeline } => {
162 let df = self.execute_pipeline(pipeline)?;
163 self.variables.insert(name.clone(), df.clone());
164 Ok(Some(df))
165 }
166 Statement::Pipeline(pipeline) => {
167 let df = self.execute_pipeline(pipeline)?;
168 Ok(Some(df))
169 }
170 }
171 }
172
173 pub fn execute_pipeline(&mut self, pipeline: Pipeline) -> Result<DataFrame> {
174 let mut df = match pipeline.source {
175 Some(Source::Read(read_op)) => self.execute_read(read_op)?,
176 Some(Source::Variable(var_name)) => {
177 self.variables
178 .get(&var_name)
179 .ok_or_else(|| DtransformError::VariableNotFound(var_name.clone()))?
180 .clone()
181 }
182 None => {
183 return Err(DtransformError::InvalidOperation(
184 "Pipeline must start with a data source (read() or variable)".to_string(),
185 ));
186 }
187 };
188
189 for operation in pipeline.operations {
190 df = self.execute_operation(df, operation)?;
191 }
192
193 Ok(df)
194 }
195
    /// Dispatches a single pipeline operation against the current DataFrame.
    fn execute_operation(&mut self, df: DataFrame, op: Operation) -> Result<DataFrame> {
        match op {
            Operation::Read(read_op) => self.execute_read(read_op),
            // A bare variable mid-pipeline is rejected: it would silently
            // discard the incoming `df`, so it is only legal as a source.
            Operation::Variable(_var_name) => {
                Err(DtransformError::InvalidOperation(
                    "Variable references can only be used as pipeline sources, not as operations".to_string()
                ))
            }
            Operation::Write(write_op) => self.execute_write(df, write_op),
            Operation::Select(select_op) => self.execute_select(df, select_op),
            Operation::Filter(filter_op) => self.execute_filter(df, filter_op),
            Operation::Mutate(mutate_op) => self.execute_mutate(df, mutate_op),
            Operation::Rename(rename_op) => self.execute_rename(df, rename_op),
            Operation::RenameAll(rename_all_op) => self.execute_rename_all(df, rename_all_op),
            Operation::Sort(sort_op) => self.execute_sort(df, sort_op),
            Operation::Take(take_op) => self.execute_take(df, take_op),
            Operation::Skip(skip_op) => self.execute_skip(df, skip_op),
            Operation::Slice(slice_op) => self.execute_slice(df, slice_op),
            Operation::Drop(drop_op) => self.execute_drop(df, drop_op),
            Operation::Distinct(distinct_op) => self.execute_distinct(df, distinct_op),
        }
    }
219
220 fn check_duplicate_columns(&self, df: &DataFrame) -> Result<()> {
221 use std::collections::HashSet;
222 let column_names: Vec<String> = df.get_column_names().iter().map(|s| s.to_string()).collect();
223 let mut seen = HashSet::new();
224 let mut duplicates = Vec::new();
225
226 for name in &column_names {
227 if !seen.insert(name) {
228 duplicates.push(name.clone());
229 }
230 }
231
232 if !duplicates.is_empty() {
233 return Err(DtransformError::InvalidOperation(format!(
234 "File contains duplicate column names: {}. Malformed files with repeated columns are not allowed.",
235 duplicates.join(", ")
236 )));
237 }
238
239 Ok(())
240 }
241
242 fn execute_read(&self, op: ReadOp) -> Result<DataFrame> {
243 let path = std::path::Path::new(&op.path);
244
245 let format = op.format.as_deref().or_else(|| path.extension()?.to_str());
247
248 match format {
249 Some("csv") | Some("tsv") | None => {
250 let has_header = op.header.unwrap_or(true);
251 let skip_rows = op.skip_rows.unwrap_or(0);
252
253 let (delimiter, trim_whitespace) = if op.delimiter.is_none() || op.trim_whitespace.is_none() {
255 let content = std::fs::read_to_string(path)?;
257 let (detected_delim, detected_trim) = auto_detect_delimiter(&content, format)?;
258
259 (
260 op.delimiter.unwrap_or(detected_delim),
261 op.trim_whitespace.unwrap_or(detected_trim)
262 )
263 } else {
264 (op.delimiter.unwrap(), op.trim_whitespace.unwrap())
265 };
266
267 let result = if trim_whitespace {
268 let content = std::fs::read_to_string(path)?;
270 let trimmed_content: String = content
271 .lines()
272 .map(|line| {
273 let trimmed = line.trim();
275 trimmed.split_whitespace().collect::<Vec<_>>().join(" ")
277 })
278 .collect::<Vec<_>>()
279 .join("\n");
280
281 let cursor = std::io::Cursor::new(trimmed_content.as_bytes());
282 CsvReadOptions::default()
283 .with_has_header(has_header)
284 .with_skip_rows(skip_rows)
285 .with_parse_options(
286 CsvParseOptions::default()
287 .with_separator(delimiter as u8)
288 )
289 .into_reader_with_file_handle(cursor)
290 .finish()
291 } else {
292 CsvReadOptions::default()
294 .with_has_header(has_header)
295 .with_skip_rows(skip_rows)
296 .with_parse_options(
297 CsvParseOptions::default()
298 .with_separator(delimiter as u8)
299 )
300 .try_into_reader_with_file_path(Some(path.into()))?
301 .finish()
302 };
303
304 match result {
305 Ok(df) => {
306 self.check_duplicate_columns(&df)?;
307 Ok(df)
308 },
309 Err(e) => {
310 let error_msg = e.to_string();
311 if error_msg.contains("found more fields") || error_msg.contains("Schema") {
312 Err(DtransformError::InvalidOperation(
313 format!(
314 "CSV parsing error: Rows have different numbers of fields.\n\n\
315 The auto-detected settings may be incorrect:\n\
316 • Detected delimiter: {:?}\n\
317 • Detected trim_whitespace: {}\n\n\
318 Try specifying explicitly:\n\
319 • read('{}', delimiter=' ') # space-separated\n\
320 • read('{}', delimiter='\\t') # tab-separated\n\
321 • read('{}', trim_whitespace=true)\n\
322 • read('{}', skip_rows=N) # skip header lines",
323 delimiter, trim_whitespace,
324 path.display(), path.display(), path.display(), path.display()
325 )
326 ))
327 } else {
328 Err(DtransformError::PolarsError(e))
329 }
330 }
331 }
332 }
333 Some("json") => {
334 let file = std::fs::File::open(path)?;
335 let df = JsonReader::new(file).finish()?;
336 self.check_duplicate_columns(&df)?;
337 Ok(df)
338 }
339 Some("parquet") => {
340 let file = std::fs::File::open(path)?;
341 let df = ParquetReader::new(file).finish()?;
342 self.check_duplicate_columns(&df)?;
343 Ok(df)
344 }
345 Some(_) => {
346 let has_header = op.header.unwrap_or(true);
348 let skip_rows = op.skip_rows.unwrap_or(0);
349
350 let (delimiter, trim_whitespace) = if op.delimiter.is_none() || op.trim_whitespace.is_none() {
352 let content = std::fs::read_to_string(path)?;
354 let (detected_delim, detected_trim) = auto_detect_delimiter(&content, format)?;
355
356 (
357 op.delimiter.unwrap_or(detected_delim),
358 op.trim_whitespace.unwrap_or(detected_trim)
359 )
360 } else {
361 (op.delimiter.unwrap(), op.trim_whitespace.unwrap())
362 };
363
364 let result = if trim_whitespace {
365 let content = std::fs::read_to_string(path)?;
367 let trimmed_content: String = content
368 .lines()
369 .map(|line| {
370 let trimmed = line.trim();
372 trimmed.split_whitespace().collect::<Vec<_>>().join(" ")
374 })
375 .collect::<Vec<_>>()
376 .join("\n");
377
378 let cursor = std::io::Cursor::new(trimmed_content.as_bytes());
379 CsvReadOptions::default()
380 .with_has_header(has_header)
381 .with_skip_rows(skip_rows)
382 .with_parse_options(
383 CsvParseOptions::default()
384 .with_separator(delimiter as u8)
385 )
386 .into_reader_with_file_handle(cursor)
387 .finish()
388 } else {
389 CsvReadOptions::default()
391 .with_has_header(has_header)
392 .with_skip_rows(skip_rows)
393 .with_parse_options(
394 CsvParseOptions::default()
395 .with_separator(delimiter as u8)
396 )
397 .try_into_reader_with_file_path(Some(path.into()))?
398 .finish()
399 };
400
401 match result {
402 Ok(df) => {
403 self.check_duplicate_columns(&df)?;
404 Ok(df)
405 },
406 Err(e) => {
407 let error_msg = e.to_string();
408 if error_msg.contains("found more fields") || error_msg.contains("Schema") {
409 Err(DtransformError::InvalidOperation(
410 format!(
411 "CSV parsing error: Rows have different numbers of fields.\n\n\
412 The auto-detected settings may be incorrect:\n\
413 • Detected delimiter: {:?}\n\
414 • Detected trim_whitespace: {}\n\n\
415 Try specifying explicitly:\n\
416 • read('{}', delimiter=' ') # space-separated\n\
417 • read('{}', delimiter='\\t') # tab-separated\n\
418 • read('{}', trim_whitespace=true)\n\
419 • read('{}', skip_rows=N) # skip header lines",
420 delimiter, trim_whitespace,
421 path.display(), path.display(), path.display(), path.display()
422 )
423 ))
424 } else {
425 Err(DtransformError::PolarsError(e))
426 }
427 }
428 }
429 }
430 }
431 }
432
433 fn execute_write(&self, df: DataFrame, op: WriteOp) -> Result<DataFrame> {
434 let path = std::path::Path::new(&op.path);
435 let format = op.format.as_deref().or_else(|| path.extension()?.to_str());
436
437 match format {
438 Some("csv") | Some("tsv") | None => {
439 let mut file = std::fs::File::create(path)?;
440 let delimiter = op.delimiter.unwrap_or(if format == Some("tsv") { '\t' } else { ',' });
441 let has_header = op.header.unwrap_or(true); CsvWriter::new(&mut file)
444 .with_separator(delimiter as u8)
445 .include_header(has_header)
446 .finish(&mut df.clone())?;
447 }
448 Some("json") => {
449 let mut file = std::fs::File::create(path)?;
450 JsonWriter::new(&mut file)
451 .finish(&mut df.clone())?;
452 }
453 Some("parquet") => {
454 let mut file = std::fs::File::create(path)?;
455 ParquetWriter::new(&mut file)
456 .finish(&mut df.clone())?;
457 }
458 Some(_) => {
459 let mut file = std::fs::File::create(path)?;
461 let delimiter = op.delimiter.unwrap_or(',');
462 let has_header = op.header.unwrap_or(true);
463
464 CsvWriter::new(&mut file)
465 .with_separator(delimiter as u8)
466 .include_header(has_header)
467 .finish(&mut df.clone())?;
468 }
469 }
470
471 Ok(df)
472 }
473
474 fn execute_select(&self, df: DataFrame, op: SelectOp) -> Result<DataFrame> {
475 let schema = df.schema();
476 let mut selected_columns = Vec::new();
477 let mut aliases = Vec::new();
478
479 for (selector, alias) in op.selectors {
480 let cols = self.resolve_selector(&selector, &schema, &df)?;
481
482 for col in cols {
485 selected_columns.push(col);
486 aliases.push(alias.clone());
487 }
488 }
489
490 if selected_columns.is_empty() {
491 return Err(DtransformError::InvalidOperation(
492 "No columns selected".to_string(),
493 ));
494 }
495
496 let mut result = df.select(&selected_columns)?;
497
498 for (i, alias_opt) in aliases.iter().enumerate() {
500 if let Some(alias) = alias_opt {
501 let old_name = result.get_column_names()[i].to_string();
502 result.rename(&old_name, PlSmallStr::from(alias.as_str()))?;
503 }
504 }
505
506 Ok(result)
507 }
508
    /// Expands a column selector into the concrete column names it matches,
    /// in schema order.
    ///
    /// Name and index selectors error when the column is absent or out of
    /// bounds; an empty range errors; regex and type selectors may silently
    /// match zero columns.
    fn resolve_selector(
        &self,
        selector: &ColumnSelector,
        schema: &Schema,
        df: &DataFrame,
    ) -> Result<Vec<String>> {
        match selector {
            ColumnSelector::Name(name) => {
                if schema.contains(name) {
                    Ok(vec![name.clone()])
                } else {
                    Err(DtransformError::ColumnNotFound(name.clone()))
                }
            }

            // Zero-based index into the schema.
            ColumnSelector::Index(idx) => {
                let name = schema
                    .get_at_index(*idx)
                    .ok_or_else(|| {
                        DtransformError::InvalidOperation(format!("Column index {} out of bounds", idx))
                    })?
                    .0
                    .clone();
                Ok(vec![name.as_str().to_string()])
            }

            // Inclusive zero-based range; the error message prints it 1-based
            // (`$N`) to match the user-facing positional syntax.
            ColumnSelector::Range(start, end) => {
                let names: Vec<String> = schema
                    .iter()
                    .enumerate()
                    .filter(|(i, _)| i >= start && i <= end)
                    .map(|(_, (name, _))| name.as_str().to_string())
                    .collect();

                if names.is_empty() {
                    return Err(DtransformError::InvalidOperation(
                        format!("Range ${}..${} is out of bounds or invalid", start + 1, end + 1)
                    ));
                }

                Ok(names)
            }

            // All columns whose name matches the pattern; may be empty.
            ColumnSelector::Regex(pattern) => {
                let re = Regex::new(pattern)?;
                let names: Vec<String> = schema
                    .iter()
                    .filter(|(name, _)| re.is_match(name.as_str()))
                    .map(|(name, _)| name.as_str().to_string())
                    .collect();
                Ok(names)
            }

            // All columns whose dtype falls in any of the requested AST type
            // categories; may be empty.
            ColumnSelector::Type(dtypes) => {
                let names: Vec<String> = schema
                    .iter()
                    .filter(|(_, field)| {
                        dtypes.iter().any(|dt| self.matches_dtype(dt, field))
                    })
                    .map(|(name, _)| name.as_str().to_string())
                    .collect();
                Ok(names)
            }

            ColumnSelector::All => Ok(schema.iter().map(|(name, _)| name.as_str().to_string()).collect()),

            // Set difference: everything except what the inner selector matches.
            ColumnSelector::Except(inner) => {
                let all_cols: Vec<String> = schema.iter().map(|(name, _)| name.as_str().to_string()).collect();
                let excluded = self.resolve_selector(inner, schema, df)?;
                Ok(all_cols
                    .into_iter()
                    .filter(|col| !excluded.contains(col))
                    .collect())
            }

            // Set intersection, keeping the left selector's order.
            ColumnSelector::And(left, right) => {
                let left_cols = self.resolve_selector(left, schema, df)?;
                let right_cols = self.resolve_selector(right, schema, df)?;
                Ok(left_cols
                    .into_iter()
                    .filter(|col| right_cols.contains(col))
                    .collect())
            }
        }
    }
594
595 fn matches_dtype(&self, dt: &crate::parser::ast::DataType, polars_dt: &polars::datatypes::DataType) -> bool {
596 use polars::datatypes::DataType as PDT;
597 use crate::parser::ast::DataType as AstDT;
598 match dt {
599 AstDT::Number => matches!(
600 polars_dt,
601 PDT::Int8
602 | PDT::Int16
603 | PDT::Int32
604 | PDT::Int64
605 | PDT::UInt8
606 | PDT::UInt16
607 | PDT::UInt32
608 | PDT::UInt64
609 | PDT::Float32
610 | PDT::Float64
611 ),
612 AstDT::String => matches!(polars_dt, PDT::String),
613 AstDT::Boolean => matches!(polars_dt, PDT::Boolean),
614 AstDT::Date => matches!(polars_dt, PDT::Date),
615 AstDT::DateTime => matches!(polars_dt, PDT::Datetime(_, _)),
616 }
617 }
618
619 fn execute_filter(&self, df: DataFrame, op: FilterOp) -> Result<DataFrame> {
620 let mask = self.evaluate_expression(&op.condition, &df)?;
621 let mask_bool = mask.bool()?;
622 Ok(df.filter(mask_bool)?)
623 }
624
625 fn execute_mutate(&self, mut df: DataFrame, op: MutateOp) -> Result<DataFrame> {
626 for assignment in op.assignments {
627 let series = self.evaluate_expression(&assignment.expression, &df)?;
628
629 let col_name = match &assignment.column {
631 AssignmentTarget::Name(name) => name.clone(),
632 AssignmentTarget::Position(pos) => {
633 if *pos == 0 {
634 return Err(DtransformError::InvalidOperation(
635 "Position $0 is invalid; column positions start at $1".to_string()
636 ));
637 }
638 let col_names = df.get_column_names();
639 if *pos <= col_names.len() {
640 col_names[pos - 1].to_string()
642 } else {
643 format!("column_{}", pos)
645 }
646 }
647 };
648
649 let renamed_series = series.with_name(PlSmallStr::from(col_name.as_str()));
650 let _ = df.with_column(renamed_series)?;
651 }
652
653 Ok(df)
654 }
655
656 fn execute_rename(&self, df: DataFrame, op: RenameOp) -> Result<DataFrame> {
657 let mut result = df;
658 for (col_ref, new_name) in op.mappings {
659 let old_name = self.resolve_column_name(&col_ref, &result)?;
660 result.rename(&old_name, PlSmallStr::from(new_name.as_str()))?;
661 }
662 Ok(result)
663 }
664
665 fn execute_rename_all(&self, mut df: DataFrame, op: RenameAllOp) -> Result<DataFrame> {
666 match &op.strategy {
667 RenameStrategy::Replace { old, new } => {
668 let old_names: Vec<String> = df
669 .get_column_names()
670 .iter()
671 .map(|s| s.as_str().to_string())
672 .collect();
673
674 for old_name in old_names {
675 let new_name = old_name.replace(old, new);
676 df.rename(&old_name, PlSmallStr::from(new_name.as_str()))?;
677 }
678
679 Ok(df)
680 }
681 RenameStrategy::Sequential { prefix, start, end } => {
682 let num_cols = df.width();
683 let range_size = end - start + 1;
684
685 if range_size != num_cols {
686 return Err(DtransformError::InvalidOperation(format!(
687 "Range {}..{} ({} columns) doesn't match table width ({} columns). Use select() first.",
688 start, end, range_size, num_cols
689 )));
690 }
691
692 let old_names: Vec<String> = df
693 .get_column_names()
694 .iter()
695 .map(|s| s.as_str().to_string())
696 .collect();
697
698 for (i, old_name) in old_names.iter().enumerate() {
699 let new_name = format!("{}{}", prefix, start + i);
700 df.rename(old_name, PlSmallStr::from(new_name.as_str()))?;
701 }
702
703 Ok(df)
704 }
705 }
706 }
707
708 fn execute_sort(&self, df: DataFrame, op: SortOp) -> Result<DataFrame> {
709 let col_names: Vec<String> = op
710 .columns
711 .iter()
712 .map(|(col_ref, _)| self.resolve_column_name(col_ref, &df))
713 .collect::<Result<Vec<_>>>()?;
714
715 let descending: Vec<bool> = op
716 .columns
717 .iter()
718 .map(|(_, desc)| *desc)
719 .collect();
720
721 Ok(df.sort(col_names, SortMultipleOptions::default().with_order_descending_multi(descending))?)
722 }
723
724 fn execute_take(&self, df: DataFrame, op: TakeOp) -> Result<DataFrame> {
725 Ok(df.head(Some(op.n)))
726 }
727
728 fn execute_skip(&self, df: DataFrame, op: SkipOp) -> Result<DataFrame> {
729 let height = df.height();
730 if op.n >= height {
731 Ok(df.head(Some(0)))
732 } else {
733 Ok(df.slice(op.n as i64, height - op.n))
734 }
735 }
736
737 fn execute_slice(&self, df: DataFrame, op: SliceOp) -> Result<DataFrame> {
738 let start = op.start.min(df.height());
739 let len = (op.end.saturating_sub(start)).min(df.height() - start);
740 Ok(df.slice(start as i64, len))
741 }
742
743 fn execute_drop(&self, df: DataFrame, op: DropOp) -> Result<DataFrame> {
744 let schema = df.schema();
745 let mut columns_to_drop: Vec<String> = Vec::new();
746
747 for selector in op.columns {
749 let names = self.resolve_selector(&selector, &schema, &df)?;
750 columns_to_drop.extend(names);
751 }
752
753 let mut result = df;
755 for col_name in columns_to_drop {
756 result = result.drop(&col_name)?;
757 }
758 Ok(result)
759 }
760
761 fn execute_distinct(&self, df: DataFrame, op: DistinctOp) -> Result<DataFrame> {
762 use polars::prelude::UniqueKeepStrategy;
763
764 match op.columns {
765 None => {
767 df.unique::<Vec<String>, String>(None, UniqueKeepStrategy::First, None)
768 .map_err(DtransformError::from)
769 }
770
771 Some(ref selectors) => {
773 let schema = df.schema();
775 let mut column_names: Vec<String> = Vec::new();
776
777 for selector in selectors {
778 let names = self.resolve_selector(selector, &schema, &df)?;
779 column_names.extend(names);
780 }
781
782 df.unique::<Vec<String>, String>(
784 Some(&column_names),
785 UniqueKeepStrategy::First,
786 None
787 ).map_err(DtransformError::from)
788 }
789 }
790 }
791
792 fn resolve_column_name(&self, col_ref: &ColumnRef, df: &DataFrame) -> Result<String> {
793 match col_ref {
794 ColumnRef::Name(name) => Ok(name.clone()),
795 ColumnRef::Index(idx) => {
796 let col_names = df.get_column_names();
797 if *idx < col_names.len() {
798 Ok(col_names[*idx].to_string())
799 } else {
800 Err(DtransformError::InvalidOperation(format!(
801 "Column index {} out of bounds (table has {} columns)",
802 idx, col_names.len()
803 )))
804 }
805 }
806 ColumnRef::Position(pos) => {
807 if *pos == 0 {
809 return Err(DtransformError::InvalidOperation(
810 "Positional columns start at $1, not $0".to_string()
811 ));
812 }
813 let zero_based_idx = pos - 1;
814 let col_names = df.get_column_names();
815 if zero_based_idx < col_names.len() {
816 Ok(col_names[zero_based_idx].to_string())
817 } else {
818 Err(DtransformError::InvalidOperation(format!(
819 "Column ${} out of bounds (table has {} columns)",
820 pos, col_names.len()
821 )))
822 }
823 }
824 }
825 }
826
    /// Recursively evaluates an expression tree against `df`, producing one
    /// Series (scalars are broadcast to the frame height by
    /// `literal_to_series`).
    fn evaluate_expression(&self, expr: &Expression, df: &DataFrame) -> Result<Series> {
        match expr {
            Expression::Literal(lit) => self.literal_to_series(lit, df.height()),

            Expression::List(literals) => {
                use crate::parser::ast::Literal as AstLiteral;
                if literals.is_empty() {
                    return Ok(Series::new_empty(PlSmallStr::from("list"), &polars::datatypes::DataType::Null));
                }
                // The first element decides the list's dtype; elements of any
                // other type are coerced to a default (0.0 / "" / false).
                match &literals[0] {
                    AstLiteral::Number(_) => {
                        let values: Vec<f64> = literals.iter().map(|lit| {
                            match lit {
                                AstLiteral::Number(n) => *n,
                                _ => 0.0,
                            }
                        }).collect();
                        Ok(Series::new(PlSmallStr::from("list"), values))
                    }
                    AstLiteral::String(_) => {
                        let values: Vec<String> = literals.iter().map(|lit| {
                            match lit {
                                AstLiteral::String(s) => s.clone(),
                                _ => String::new(),
                            }
                        }).collect();
                        Ok(Series::new(PlSmallStr::from("list"), values))
                    }
                    AstLiteral::Boolean(_) => {
                        let values: Vec<bool> = literals.iter().map(|lit| {
                            match lit {
                                AstLiteral::Boolean(b) => *b,
                                _ => false,
                            }
                        }).collect();
                        Ok(Series::new(PlSmallStr::from("list"), values))
                    }
                    AstLiteral::Null => {
                        Ok(Series::new_null(PlSmallStr::from("list"), literals.len()))
                    }
                }
            }

            Expression::Column(col_ref) => {
                // A bare name matching a stored variable resolves to that
                // variable's first column — variables shadow columns here.
                if let ColumnRef::Name(name) = col_ref {
                    if let Some(var_df) = self.variables.get(name) {
                        let col = var_df.get_columns().first()
                            .ok_or_else(|| DtransformError::InvalidOperation(
                                format!("Variable '{}' has no columns", name)
                            ))?;
                        return Ok(col.as_materialized_series().clone());
                    }
                }

                let col_name = self.resolve_column_name(col_ref, df)?;
                df.column(&col_name)
                    .map(|col| col.as_materialized_series().clone())
                    .map_err(|e| DtransformError::PolarsError(e))
            }

            // Explicit variable reference: also yields the first column.
            Expression::Variable(var_name) => {
                let var_df = self.variables.get(var_name)
                    .ok_or_else(|| DtransformError::VariableNotFound(var_name.clone()))?;
                let col = var_df.get_columns().first()
                    .ok_or_else(|| DtransformError::InvalidOperation(
                        format!("Variable '{}' has no columns", var_name)
                    ))?;
                Ok(col.as_materialized_series().clone())
            }

            Expression::BinaryOp { left, op, right } => {
                let left_series = self.evaluate_expression(left, df)?;
                let right_series = self.evaluate_expression(right, df)?;
                self.apply_binary_op(&left_series, op, &right_series, df)
            }

            Expression::MethodCall { object, method, args } => {
                let obj_series = self.evaluate_expression(object, df)?;
                self.apply_method(&obj_series, method, args, df)
            }

            // split(string, delimiter)[index]: per-row string split, taking
            // the `index`-th part (null when missing or input is null).
            Expression::Split { string, delimiter, index } => {
                let string_series = self.evaluate_expression(string, df)?;
                let delimiter_series = self.evaluate_expression(delimiter, df)?;

                // Only the first value of the delimiter expression is used —
                // it is expected to evaluate to a scalar string.
                let delim = match delimiter_series.dtype() {
                    polars::datatypes::DataType::String => {
                        delimiter_series.str()
                            .map_err(|_| DtransformError::InvalidOperation("Delimiter must be a string".to_string()))?
                            .get(0)
                            .ok_or_else(|| DtransformError::InvalidOperation("Delimiter is null".to_string()))?
                            .to_string()
                    }
                    _ => return Err(DtransformError::InvalidOperation("Delimiter must be a string".to_string())),
                };

                let string_ca = string_series.str()
                    .map_err(|_| DtransformError::InvalidOperation("Split can only be applied to string columns".to_string()))?;

                let result: Vec<Option<String>> = string_ca.into_iter().map(|opt_str| {
                    opt_str.and_then(|s| {
                        let parts: Vec<&str> = s.split(&delim).collect();
                        parts.get(*index).map(|&part| part.to_string())
                    })
                }).collect();

                Ok(Series::new(PlSmallStr::from("split"), result))
            }

            // lookup(table, key, on=..., return=...): hash-join-like lookup
            // of each key against a stored variable table.
            Expression::Lookup { table, key, on, return_field } => {
                use crate::parser::ast::LookupField;

                let lookup_df = self.variables.get(table)
                    .ok_or_else(|| DtransformError::VariableNotFound(table.clone()))?;

                // Resolve `on=` to a concrete column ($N positions are 1-based).
                let on_col_name = match on {
                    LookupField::Name(name) => name.clone(),
                    LookupField::Position(pos) => {
                        let schema = lookup_df.schema();
                        let col_names: Vec<_> = schema.iter_names().collect();
                        if *pos == 0 || *pos > col_names.len() {
                            return Err(DtransformError::InvalidOperation(format!(
                                "Lookup table '{}' has {} columns, but on=${} was specified",
                                table, col_names.len(), pos
                            )));
                        }
                        col_names[pos - 1].to_string()
                    }
                };

                // Resolve `return=` the same way.
                let return_col_name = match return_field {
                    LookupField::Name(name) => name.clone(),
                    LookupField::Position(pos) => {
                        let schema = lookup_df.schema();
                        let col_names: Vec<_> = schema.iter_names().collect();
                        if *pos == 0 || *pos > col_names.len() {
                            return Err(DtransformError::InvalidOperation(format!(
                                "Lookup table '{}' has {} columns, but return=${} was specified",
                                table, col_names.len(), pos
                            )));
                        }
                        col_names[pos - 1].to_string()
                    }
                };

                // Validate both named columns exist before touching the data.
                if !lookup_df.schema().contains(&on_col_name) {
                    return Err(DtransformError::ColumnNotFound(format!(
                        "Lookup table '{}' does not have column '{}' (specified in on=)",
                        table, on_col_name
                    )));
                }
                if !lookup_df.schema().contains(&return_col_name) {
                    return Err(DtransformError::ColumnNotFound(format!(
                        "Lookup table '{}' does not have column '{}' (specified in return=)",
                        table, return_col_name
                    )));
                }

                let lookup_key_col = lookup_df.column(&on_col_name)
                    .map_err(|e| DtransformError::PolarsError(e))?
                    .as_materialized_series();

                let lookup_value_col = lookup_df.column(&return_col_name)
                    .map_err(|e| DtransformError::PolarsError(e))?
                    .as_materialized_series();

                let key_series = self.evaluate_expression(key, df)?;

                use std::collections::HashMap;
                use polars::datatypes::DataType;

                // Only string keys are supported; values may be strings or
                // any numeric dtype (cast to f64 for the map).
                match (lookup_key_col.dtype(), lookup_value_col.dtype()) {
                    (DataType::String, DataType::String) => {
                        let lookup_keys = lookup_key_col.str()
                            .map_err(|_| DtransformError::TypeMismatch {
                                expected: "String".to_string(),
                                got: format!("{:?}", lookup_key_col.dtype()),
                            })?;
                        let lookup_values = lookup_value_col.str()
                            .map_err(|_| DtransformError::TypeMismatch {
                                expected: "String".to_string(),
                                got: format!("{:?}", lookup_value_col.dtype()),
                            })?;

                        // Later rows overwrite earlier ones for duplicate keys.
                        let mut map: HashMap<String, String> = HashMap::new();
                        for i in 0..lookup_df.height() {
                            if let (Some(k), Some(v)) = (lookup_keys.get(i), lookup_values.get(i)) {
                                map.insert(k.to_string(), v.to_string());
                            }
                        }

                        let input_keys = key_series.str()
                            .map_err(|_| DtransformError::TypeMismatch {
                                expected: "String".to_string(),
                                got: format!("{:?}", key_series.dtype()),
                            })?;

                        // Missing keys (and null inputs) produce nulls.
                        let result: Vec<Option<String>> = input_keys.into_iter()
                            .map(|opt_key| {
                                opt_key.and_then(|k| map.get(k).cloned())
                            })
                            .collect();

                        Ok(Series::new(PlSmallStr::from(return_col_name.as_str()), result))
                    }
                    (DataType::String, value_dtype) if matches!(
                        value_dtype,
                        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 |
                        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 |
                        DataType::Float32 | DataType::Float64
                    ) => {
                        let lookup_keys = lookup_key_col.str()
                            .map_err(|_| DtransformError::TypeMismatch {
                                expected: "String".to_string(),
                                got: format!("{:?}", lookup_key_col.dtype()),
                            })?;

                        // Normalize every numeric value dtype to f64.
                        let lookup_values_f64 = lookup_value_col.cast(&DataType::Float64)
                            .map_err(|e| DtransformError::PolarsError(e))?;
                        let lookup_values = lookup_values_f64.f64()
                            .map_err(|_| DtransformError::InvalidOperation("Failed to cast to Float64".to_string()))?;

                        let mut map: HashMap<String, f64> = HashMap::new();
                        for i in 0..lookup_df.height() {
                            if let (Some(k), Some(v)) = (lookup_keys.get(i), lookup_values.get(i)) {
                                map.insert(k.to_string(), v);
                            }
                        }

                        let input_keys = key_series.str()
                            .map_err(|_| DtransformError::TypeMismatch {
                                expected: "String".to_string(),
                                got: format!("{:?}", key_series.dtype()),
                            })?;

                        let result: Vec<Option<f64>> = input_keys.into_iter()
                            .map(|opt_key| {
                                opt_key.and_then(|k| map.get(k).copied())
                            })
                            .collect();

                        Ok(Series::new(PlSmallStr::from(return_col_name.as_str()), result))
                    }
                    _ => {
                        Err(DtransformError::InvalidOperation(
                            format!(
                                "Unsupported lookup type combination: key={:?}, value={:?}",
                                lookup_key_col.dtype(),
                                lookup_value_col.dtype()
                            )
                        ))
                    }
                }
            }

            // replace(text, old, new): literal substring replacement, or
            // regex replace-all when `old` is a regex literal.
            Expression::Replace { text, old, new } => {
                let text_series = self.evaluate_expression(text, df)?;
                let new_series = self.evaluate_expression(new, df)?;

                let text_ca = text_series.str()
                    .map_err(|_| DtransformError::InvalidOperation(
                        "replace() can only be applied to string columns".to_string()
                    ))?;

                use polars::datatypes::DataType;
                // The replacement expression must be a scalar string; only
                // its first value is used.
                let new_str = match new_series.dtype() {
                    DataType::String => {
                        new_series.str()
                            .map_err(|_| DtransformError::InvalidOperation("Replacement text must be a string".to_string()))?
                            .get(0)
                            .ok_or_else(|| DtransformError::InvalidOperation("Replacement text is null".to_string()))?
                            .to_string()
                    }
                    _ => return Err(DtransformError::InvalidOperation("Replacement text must be a string".to_string())),
                };

                match old.as_ref() {
                    // Regex pattern: replace all matches per row.
                    Expression::Regex(pattern) => {
                        let re = Regex::new(pattern)
                            .map_err(|e| DtransformError::InvalidOperation(
                                format!("Invalid regex pattern '{}': {}", pattern, e)
                            ))?;

                        let result: Vec<Option<String>> = text_ca.into_iter().map(|opt_str| {
                            opt_str.map(|s| re.replace_all(s, &new_str).to_string())
                        }).collect();

                        Ok(Series::new(PlSmallStr::from("replace"), result))
                    }
                    // Anything else: evaluate to a scalar string and do a
                    // plain substring replacement.
                    _ => {
                        let old_series = self.evaluate_expression(old, df)?;
                        let old_str = match old_series.dtype() {
                            DataType::String => {
                                old_series.str()
                                    .map_err(|_| DtransformError::InvalidOperation("Pattern must be a string".to_string()))?
                                    .get(0)
                                    .ok_or_else(|| DtransformError::InvalidOperation("Pattern is null".to_string()))?
                                    .to_string()
                            }
                            _ => return Err(DtransformError::InvalidOperation("Pattern must be a string".to_string())),
                        };

                        let result: Vec<Option<String>> = text_ca.into_iter().map(|opt_str| {
                            opt_str.map(|s| s.replace(&old_str, &new_str))
                        }).collect();

                        Ok(Series::new(PlSmallStr::from("replace"), result))
                    }
                }
            }

            // A bare regex literal has no value on its own; it is only legal
            // as the `old` argument of replace() above.
            Expression::Regex(pattern) => {
                Err(DtransformError::InvalidOperation(
                    format!("Regex pattern '{}' cannot be used directly. Use it with replace() function.", pattern)
                ))
            }
        }
    }
1179
1180 fn literal_to_series(&self, lit: &crate::parser::ast::Literal, len: usize) -> Result<Series> {
1181 use crate::parser::ast::Literal as Lit;
1182 match lit {
1183 Lit::Number(n) => Ok(Series::new(PlSmallStr::from("literal"), vec![*n; len])),
1184 Lit::String(s) => Ok(Series::new(PlSmallStr::from("literal"), vec![s.as_str(); len])),
1185 Lit::Boolean(b) => Ok(Series::new(PlSmallStr::from("literal"), vec![*b; len])),
1186 Lit::Null => Ok(Series::new_null(PlSmallStr::from("literal"), len)),
1187 }
1188 }
1189
    /// Apply a binary operator element-wise to two equal-length series.
    ///
    /// Arithmetic and comparison operators delegate to polars (which
    /// propagates nulls); `+` on two string columns is special-cased as
    /// row-wise concatenation. `And`/`Or` require boolean inputs, and
    /// `In` performs per-row membership tests against the right-hand series.
    ///
    /// Returns a `TypeMismatch` error for `In` on non-string, non-numeric
    /// operands; other dtype problems surface as polars errors via `?`.
    fn apply_binary_op(&self, left: &Series, op: &BinOp, right: &Series, _df: &DataFrame) -> Result<Series> {
        use polars::datatypes::DataType;

        let result = match op {
            BinOp::Add => {
                match (left.dtype(), right.dtype()) {
                    // String + String: element-wise concatenation rather than
                    // numeric addition. A null on either side yields null.
                    (DataType::String, DataType::String) => {
                        let left_str = left.str().map_err(|_| DtransformError::TypeMismatch {
                            expected: "String".to_string(),
                            got: format!("{:?}", left.dtype()),
                        })?;
                        let right_str = right.str().map_err(|_| DtransformError::TypeMismatch {
                            expected: "String".to_string(),
                            got: format!("{:?}", right.dtype()),
                        })?;

                        let result: Vec<Option<String>> = left_str.into_iter()
                            .zip(right_str.into_iter())
                            .map(|(l, r)| {
                                match (l, r) {
                                    (Some(ls), Some(rs)) => Some(format!("{}{}", ls, rs)),
                                    _ => None,
                                }
                            })
                            .collect();

                        Series::new(PlSmallStr::from("concat"), result)
                    }
                    // Any other dtype pairing: defer to polars' arithmetic add.
                    _ => (left + right)?,
                }
            }
            BinOp::Sub => (left - right)?,
            BinOp::Mul => (left * right)?,
            BinOp::Div => (left / right)?,
            BinOp::Gt => left.gt(right)?.into_series(),
            BinOp::Lt => left.lt(right)?.into_series(),
            BinOp::Gte => left.gt_eq(right)?.into_series(),
            BinOp::Lte => left.lt_eq(right)?.into_series(),
            BinOp::Eq => left.equal(right)?.into_series(),
            BinOp::Neq => left.not_equal(right)?.into_series(),
            BinOp::And => {
                // Logical operators require both operands to already be boolean;
                // bool() errors otherwise.
                let left_bool = left.bool()?;
                let right_bool = right.bool()?;
                (left_bool & right_bool).into_series()
            }
            BinOp::Or => {
                let left_bool = left.bool()?;
                let right_bool = right.bool()?;
                (left_bool | right_bool).into_series()
            }
            BinOp::In => {
                use std::collections::HashSet;
                use polars::datatypes::DataType;

                match left.dtype() {
                    DataType::String => {
                        // Hash the right-hand values once for O(1) membership
                        // checks. Note: a null left value matches a null in the
                        // right-hand set (the set stores Option<&str>).
                        let left_str = left.str()?;
                        let right_str = right.str()?;

                        let right_set: HashSet<Option<&str>> = right_str.into_iter().collect();

                        let mask: BooleanChunked = left_str
                            .into_iter()
                            .map(|val| right_set.contains(&val))
                            .collect();

                        mask.into_series()
                    }
                    DataType::Int64 | DataType::Int32 | DataType::Float64 | DataType::Float32 => {
                        // Normalize both sides to f64 so int and float columns
                        // can be compared uniformly.
                        let left_f64 = left.cast(&DataType::Float64)?;
                        let right_f64 = right.cast(&DataType::Float64)?;

                        let left_num = left_f64.f64()?;
                        let right_num = right_f64.f64()?;

                        let right_values: Vec<Option<f64>> = right_num.into_iter().collect();

                        // Linear scan per left element with an epsilon compare.
                        // NOTE(review): absolute-epsilon equality misbehaves for
                        // large magnitudes, and a null left value matching a null
                        // in the list returns true — confirm both are intended.
                        let mask: BooleanChunked = left_num
                            .into_iter()
                            .map(|left_val| {
                                right_values.iter().any(|right_val| {
                                    match (left_val, right_val) {
                                        (Some(l), Some(r)) => (l - r).abs() < f64::EPSILON,
                                        (None, None) => true,
                                        _ => false,
                                    }
                                })
                            })
                            .collect();

                        mask.into_series()
                    }
                    _ => {
                        return Err(DtransformError::TypeMismatch {
                            expected: "String or Number".to_string(),
                            got: format!("{:?}", left.dtype()),
                        });
                    }
                }
            }
        };
        Ok(result)
    }
1305
1306 fn apply_method(&self, _obj: &Series, method: &str, _args: &[Expression], _df: &DataFrame) -> Result<Series> {
1307 Err(DtransformError::InvalidOperation(format!(
1310 "Method '{}' is not supported. Use function-based operations instead.\n\
1311 Example: mutate(clean = replace(text, 'old', 'new'))",
1312 method
1313 )))
1314 }
1315
1316 pub fn get_variable(&self, name: &str) -> Option<&DataFrame> {
1317 self.variables.get(name)
1318 }
1319
1320 pub fn set_variable(&mut self, name: String, df: DataFrame) {
1321 self.variables.insert(name, df);
1322 }
1323
1324 pub fn remove_variable(&mut self, name: &str) {
1325 self.variables.remove(name);
1326 }
1327
1328 pub fn list_variables(&self) -> Vec<String> {
1329 self.variables.keys().cloned().collect()
1330 }
1331
1332 pub fn get_all_variables(&self) -> HashMap<String, DataFrame> {
1333 self.variables.clone()
1334 }
1335
1336 pub fn restore_variables(&mut self, snapshot: HashMap<String, DataFrame>) {
1337 self.variables = snapshot;
1338 }
1339}