// rustframes/dataframe/core.rs

use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};

use super::Series;
4#[derive(Debug, Clone, PartialEq)]
5pub struct DataFrame {
6    pub columns: Vec<String>,
7    pub data: Vec<Series>,
8}
9
10impl DataFrame {
11    pub fn new(columns: Vec<(String, Series)>) -> Self {
12        if !columns.is_empty() {
13            let first_len = columns[0].1.len();
14            for (name, series) in &columns {
15                if series.len() != first_len {
16                    panic!("All columns must have the same length. Column '{}' has length {}, expected {}", name, series.len(), first_len);
17                }
18            }
19        }
20
21        let (names, series): (Vec<_>, Vec<_>) = columns.into_iter().unzip();
22        DataFrame {
23            columns: names,
24            data: series,
25        }
26    }
27
28    /// Create empty DataFrame with specified column names and types
29    pub fn empty(columns: Vec<(String, SeriesType)>) -> Self {
30        let series: Vec<Series> = columns
31            .iter()
32            .map(|(_, dtype)| match dtype {
33                SeriesType::Int64 => Series::Int64(Vec::new()),
34                SeriesType::Float64 => Series::Float64(Vec::new()),
35                SeriesType::Bool => Series::Bool(Vec::new()),
36                SeriesType::Utf8 => Series::Utf8(Vec::new()),
37            })
38            .collect();
39
40        let names: Vec<String> = columns.into_iter().map(|(name, _)| name).collect();
41        DataFrame {
42            columns: names,
43            data: series,
44        }
45    }
46
47    /// Get number of rows
48    pub fn len(&self) -> usize {
49        if self.data.is_empty() {
50            0
51        } else {
52            self.data[0].len()
53        }
54    }
55
56    /// Check if DataFrame is empty
57    pub fn is_empty(&self) -> bool {
58        self.len() == 0
59    }
60
61    /// Get shape (rows, columns)
62    pub fn shape(&self) -> (usize, usize) {
63        (self.len(), self.columns.len())
64    }
65
66    /// Get first n rows
67    pub fn head(&self, n: usize) -> DataFrame {
68        let new_data: Vec<Series> = self
69            .data
70            .iter()
71            .map(|s| match s {
72                Series::Int64(v) => Series::Int64(v.iter().take(n).cloned().collect()),
73                Series::Float64(v) => Series::Float64(v.iter().take(n).cloned().collect()),
74                Series::Bool(v) => Series::Bool(v.iter().take(n).cloned().collect()),
75                Series::Utf8(v) => Series::Utf8(v.iter().take(n).cloned().collect()),
76            })
77            .collect();
78
79        DataFrame {
80            columns: self.columns.clone(),
81            data: new_data,
82        }
83    }
84
85    /// Get last n rows
86    pub fn tail(&self, n: usize) -> DataFrame {
87        let len = self.len();
88        let start = len.saturating_sub(n);
89
90        let new_data: Vec<Series> = self
91            .data
92            .iter()
93            .map(|s| match s {
94                Series::Int64(v) => Series::Int64(v.iter().skip(start).cloned().collect()),
95                Series::Float64(v) => Series::Float64(v.iter().skip(start).cloned().collect()),
96                Series::Bool(v) => Series::Bool(v.iter().skip(start).cloned().collect()),
97                Series::Utf8(v) => Series::Utf8(v.iter().skip(start).cloned().collect()),
98            })
99            .collect();
100
101        DataFrame {
102            columns: self.columns.clone(),
103            data: new_data,
104        }
105    }
106
107    /// Select specific columns
108    pub fn select(&self, cols: &[&str]) -> DataFrame {
109        let mut new_cols = Vec::new();
110        let mut new_data = Vec::new();
111
112        for col in cols {
113            if let Some(pos) = self.columns.iter().position(|c| c == col) {
114                new_cols.push(self.columns[pos].clone());
115                new_data.push(self.data[pos].clone());
116            } else {
117                panic!("Column '{}' not found", col);
118            }
119        }
120
121        DataFrame {
122            columns: new_cols,
123            data: new_data,
124        }
125    }
126
127    /// Get a single column as a Series
128    pub fn get_column(&self, name: &str) -> Option<&Series> {
129        self.columns
130            .iter()
131            .position(|c| c == name)
132            .map(|pos| &self.data[pos])
133    }
134
135    /// Filter rows based on a boolean mask
136    pub fn filter(&self, mask: &[bool]) -> DataFrame {
137        assert_eq!(
138            mask.len(),
139            self.len(),
140            "Mask length must match DataFrame length"
141        );
142
143        let new_data: Vec<Series> = self
144            .data
145            .iter()
146            .map(|s| match s {
147                Series::Int64(v) => Series::Int64(
148                    v.iter()
149                        .zip(mask)
150                        .filter_map(|(&val, &keep)| if keep { Some(val) } else { None })
151                        .collect(),
152                ),
153                Series::Float64(v) => Series::Float64(
154                    v.iter()
155                        .zip(mask)
156                        .filter_map(|(&val, &keep)| if keep { Some(val) } else { None })
157                        .collect(),
158                ),
159                Series::Bool(v) => Series::Bool(
160                    v.iter()
161                        .zip(mask)
162                        .filter_map(|(&val, &keep)| if keep { Some(val) } else { None })
163                        .collect(),
164                ),
165                Series::Utf8(v) => Series::Utf8(
166                    v.iter()
167                        .zip(mask)
168                        .filter_map(|(val, &keep)| if keep { Some(val.clone()) } else { None })
169                        .collect(),
170                ),
171            })
172            .collect();
173
174        DataFrame {
175            columns: self.columns.clone(),
176            data: new_data,
177        }
178    }
179
180    /// Sort by column
181    pub fn sort_by(&self, column: &str, ascending: bool) -> DataFrame {
182        let col_idx = self
183            .columns
184            .iter()
185            .position(|c| c == column)
186            .expect("Column not found");
187
188        let mut indices: Vec<usize> = (0..self.len()).collect();
189
190        match &self.data[col_idx] {
191            Series::Int64(values) => {
192                indices.sort_by(|&a, &b| {
193                    if ascending {
194                        values[a].cmp(&values[b])
195                    } else {
196                        values[b].cmp(&values[a])
197                    }
198                });
199            }
200            Series::Float64(values) => {
201                indices.sort_by(|&a, &b| {
202                    if ascending {
203                        values[a].partial_cmp(&values[b]).unwrap()
204                    } else {
205                        values[b].partial_cmp(&values[a]).unwrap()
206                    }
207                });
208            }
209            Series::Bool(values) => {
210                indices.sort_by(|&a, &b| {
211                    if ascending {
212                        values[a].cmp(&values[b])
213                    } else {
214                        values[b].cmp(&values[a])
215                    }
216                });
217            }
218            Series::Utf8(values) => {
219                indices.sort_by(|&a, &b| {
220                    if ascending {
221                        values[a].cmp(&values[b])
222                    } else {
223                        values[b].cmp(&values[a])
224                    }
225                });
226            }
227        }
228
229        let new_data: Vec<Series> = self
230            .data
231            .iter()
232            .map(|s| match s {
233                Series::Int64(v) => Series::Int64(indices.iter().map(|&i| v[i]).collect()),
234                Series::Float64(v) => Series::Float64(indices.iter().map(|&i| v[i]).collect()),
235                Series::Bool(v) => Series::Bool(indices.iter().map(|&i| v[i]).collect()),
236                Series::Utf8(v) => Series::Utf8(indices.iter().map(|&i| v[i].clone()).collect()),
237            })
238            .collect();
239
240        DataFrame {
241            columns: self.columns.clone(),
242            data: new_data,
243        }
244    }
245
246    /// Add a new column
247    pub fn with_column(&self, name: String, series: Series) -> DataFrame {
248        assert_eq!(
249            series.len(),
250            self.len(),
251            "New column length must match DataFrame length"
252        );
253
254        let mut new_columns = self.columns.clone();
255        let mut new_data = self.data.clone();
256
257        // Check if column already exists
258        if let Some(pos) = new_columns.iter().position(|c| c == &name) {
259            new_data[pos] = series;
260        } else {
261            new_columns.push(name);
262            new_data.push(series);
263        }
264
265        DataFrame {
266            columns: new_columns,
267            data: new_data,
268        }
269    }
270
271    /// Drop columns
272    pub fn drop(&self, cols: &[&str]) -> DataFrame {
273        let cols_to_drop: HashSet<&str> = cols.iter().cloned().collect();
274        let mut new_columns = Vec::new();
275        let mut new_data = Vec::new();
276
277        for (i, col_name) in self.columns.iter().enumerate() {
278            if !cols_to_drop.contains(col_name.as_str()) {
279                new_columns.push(col_name.clone());
280                new_data.push(self.data[i].clone());
281            }
282        }
283
284        DataFrame {
285            columns: new_columns,
286            data: new_data,
287        }
288    }
289
290    /// Inner join with another DataFrame
291    pub fn join(&self, other: &DataFrame, on: &str, how: JoinType) -> DataFrame {
292        let left_col_idx = self
293            .columns
294            .iter()
295            .position(|c| c == on)
296            .expect("Join column not found in left DataFrame");
297        let right_col_idx = other
298            .columns
299            .iter()
300            .position(|c| c == on)
301            .expect("Join column not found in right DataFrame");
302
303        match how {
304            JoinType::Inner => self.inner_join(other, left_col_idx, right_col_idx, on),
305            JoinType::Left => self.left_join(other, left_col_idx, right_col_idx, on),
306            JoinType::Right => other.left_join(self, right_col_idx, left_col_idx, on),
307            JoinType::Outer => self.outer_join(other, left_col_idx, right_col_idx, on),
308        }
309    }
310
311    fn inner_join(
312        &self,
313        other: &DataFrame,
314        left_col_idx: usize,
315        right_col_idx: usize,
316        _on: &str,
317    ) -> DataFrame {
318        let mut result_columns = self.columns.clone();
319
320        // Add columns from right DataFrame (excluding join column)
321        for (i, col) in other.columns.iter().enumerate() {
322            if i != right_col_idx {
323                let mut new_name = col.clone();
324                if result_columns.contains(&new_name) {
325                    new_name = format!("{}_y", col);
326                }
327                result_columns.push(new_name);
328            }
329        }
330
331        // Build hash map for right DataFrame
332        let mut right_map: HashMap<String, Vec<usize>> = HashMap::new();
333        if let Series::Utf8(right_values) = &other.data[right_col_idx] {
334            for (idx, value) in right_values.iter().enumerate() {
335                right_map.entry(value.clone()).or_default().push(idx);
336            }
337        }
338
339        let mut result_data: Vec<Vec<String>> = vec![Vec::new(); result_columns.len()];
340
341        // Process left DataFrame
342        if let Series::Utf8(left_values) = &self.data[left_col_idx] {
343            for (left_idx, left_value) in left_values.iter().enumerate() {
344                if let Some(right_indices) = right_map.get(left_value) {
345                    for &right_idx in right_indices {
346                        // Add left row
347                        for (col_idx, series) in self.data.iter().enumerate() {
348                            let value = match series {
349                                Series::Int64(v) => v[left_idx].to_string(),
350                                Series::Float64(v) => v[left_idx].to_string(),
351                                Series::Bool(v) => v[left_idx].to_string(),
352                                Series::Utf8(v) => v[left_idx].clone(),
353                            };
354                            result_data[col_idx].push(value);
355                        }
356
357                        // Add right row (excluding join column)
358                        let mut result_col_idx = self.columns.len();
359                        for (col_idx, series) in other.data.iter().enumerate() {
360                            if col_idx != right_col_idx {
361                                let value = match series {
362                                    Series::Int64(v) => v[right_idx].to_string(),
363                                    Series::Float64(v) => v[right_idx].to_string(),
364                                    Series::Bool(v) => v[right_idx].to_string(),
365                                    Series::Utf8(v) => v[right_idx].clone(),
366                                };
367                                result_data[result_col_idx].push(value);
368                                result_col_idx += 1;
369                            }
370                        }
371                    }
372                }
373            }
374        }
375
376        // Convert result to DataFrame
377        let result_series: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
378
379        DataFrame {
380            columns: result_columns,
381            data: result_series,
382        }
383    }
384
385    fn left_join(
386        &self,
387        other: &DataFrame,
388        left_col_idx: usize,
389        right_col_idx: usize,
390        on: &str,
391    ) -> DataFrame {
392        // Similar to inner join but includes all left rows
393        // Implementation would be similar but always include left rows, padding with nulls
394        self.inner_join(other, left_col_idx, right_col_idx, on) // Simplified for now
395    }
396
397    fn outer_join(
398        &self,
399        other: &DataFrame,
400        left_col_idx: usize,
401        right_col_idx: usize,
402        on: &str,
403    ) -> DataFrame {
404        // Full outer join - includes all rows from both DataFrames
405        // Implementation would combine left and right joins
406        self.inner_join(other, left_col_idx, right_col_idx, on) // Simplified for now
407    }
408
409    /// Describe numeric columns (basic statistics)
410    pub fn describe(&self) -> DataFrame {
411        let mut stats_data: Vec<(String, Series)> = Vec::new();
412        let stats = vec!["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
413
414        for stat in stats {
415            let mut values = Vec::new();
416
417            for series in &self.data {
418                let value = match series {
419                    Series::Float64(v) if !v.is_empty() => match stat {
420                        "count" => v.len() as f64,
421                        "mean" => v.iter().sum::<f64>() / v.len() as f64,
422                        "std" => {
423                            let mean = v.iter().sum::<f64>() / v.len() as f64;
424                            let variance =
425                                v.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / v.len() as f64;
426                            variance.sqrt()
427                        }
428                        "min" => v.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
429                        "max" => v.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
430                        "25%" | "50%" | "75%" => {
431                            let mut sorted = v.clone();
432                            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
433                            let idx = match stat {
434                                "25%" => sorted.len() / 4,
435                                "50%" => sorted.len() / 2,
436                                "75%" => 3 * sorted.len() / 4,
437                                _ => 0,
438                            };
439                            sorted.get(idx).copied().unwrap_or(0.0)
440                        }
441                        _ => 0.0,
442                    },
443                    Series::Int64(v) if !v.is_empty() => match stat {
444                        "count" => v.len() as f64,
445                        "mean" => v.iter().sum::<i64>() as f64 / v.len() as f64,
446                        "std" => {
447                            let mean = v.iter().sum::<i64>() as f64 / v.len() as f64;
448                            let variance =
449                                v.iter().map(|&x| (x as f64 - mean).powi(2)).sum::<f64>()
450                                    / v.len() as f64;
451                            variance.sqrt()
452                        }
453                        "min" => *v.iter().min().unwrap() as f64,
454                        "max" => *v.iter().max().unwrap() as f64,
455                        _ => 0.0,
456                    },
457                    _ => f64::NAN, // Non-numeric or empty series
458                };
459
460                values.push(value);
461            }
462
463            stats_data.push((stat.to_string(), Series::Float64(values)));
464        }
465
466        DataFrame::new(stats_data)
467    }
468}
469
/// Join strategy passed to [`DataFrame::join`].
#[derive(Clone, Debug, PartialEq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full outer join.
    Outer,
}
/// Element type of a column, used to declare columns in [`DataFrame::empty`].
#[derive(Clone, Debug, PartialEq)]
pub enum SeriesType {
    /// Integer column (`Series::Int64`).
    Int64,
    /// Floating-point column (`Series::Float64`).
    Float64,
    /// Boolean column (`Series::Bool`).
    Bool,
    /// String column (`Series::Utf8`).
    Utf8,
}