// ggplot_rs/data/dataframe.rs

use std::cmp::Ordering;
use std::collections::HashSet;

use indexmap::IndexMap;

use super::Value;

5/// Internal columnar DataFrame for data storage and manipulation.
6#[derive(Clone, Debug)]
7pub struct DataFrame {
8    columns: IndexMap<String, Vec<Value>>,
9    nrows: usize,
10}
11
12impl DataFrame {
13    /// Create an empty DataFrame.
14    pub fn new() -> Self {
15        DataFrame {
16            columns: IndexMap::new(),
17            nrows: 0,
18        }
19    }
20
21    /// Get a column by name.
22    pub fn column(&self, name: &str) -> Option<&[Value]> {
23        self.columns.get(name).map(|v| v.as_slice())
24    }
25
26    /// Get number of rows.
27    pub fn nrows(&self) -> usize {
28        self.nrows
29    }
30
31    /// Get number of columns.
32    pub fn ncols(&self) -> usize {
33        self.columns.len()
34    }
35
36    /// Get column names.
37    pub fn column_names(&self) -> Vec<&str> {
38        self.columns.keys().map(|s| s.as_str()).collect()
39    }
40
41    /// Check if a column exists.
42    pub fn has_column(&self, name: &str) -> bool {
43        self.columns.contains_key(name)
44    }
45
46    /// Add a column. Panics if length doesn't match existing rows (unless empty).
47    pub fn add_column(&mut self, name: String, values: Vec<Value>) {
48        if self.columns.is_empty() {
49            self.nrows = values.len();
50        } else {
51            assert_eq!(
52                values.len(),
53                self.nrows,
54                "Column '{}' has {} values but DataFrame has {} rows",
55                name,
56                values.len(),
57                self.nrows
58            );
59        }
60        self.columns.insert(name, values);
61    }
62
63    /// Get a mutable reference to a column.
64    pub fn column_mut(&mut self, name: &str) -> Option<&mut Vec<Value>> {
65        self.columns.get_mut(name)
66    }
67
68    /// Group by one or more key columns. Returns a Vec of DataFrames, one per group.
69    pub fn group_by(&self, keys: &[&str]) -> Vec<DataFrame> {
70        if self.nrows == 0 {
71            return vec![];
72        }
73
74        // Build group keys for each row
75        let mut group_map: IndexMap<Vec<String>, Vec<usize>> = IndexMap::new();
76
77        for i in 0..self.nrows {
78            let key: Vec<String> = keys
79                .iter()
80                .map(|k| {
81                    self.columns
82                        .get(*k)
83                        .map(|col| col[i].to_group_key())
84                        .unwrap_or_else(|| "NA".to_string())
85                })
86                .collect();
87            group_map.entry(key).or_default().push(i);
88        }
89
90        group_map
91            .into_values()
92            .map(|indices| {
93                let mut df = DataFrame::new();
94                for (name, col) in &self.columns {
95                    let values: Vec<Value> = indices.iter().map(|&i| col[i].clone()).collect();
96                    df.add_column(name.clone(), values);
97                }
98                df
99            })
100            .collect()
101    }
102
103    /// Vertically stack another DataFrame onto this one.
104    pub fn vstack(&mut self, other: &DataFrame) {
105        if other.nrows == 0 {
106            return;
107        }
108        if self.columns.is_empty() {
109            *self = other.clone();
110            return;
111        }
112
113        // Add columns from other that we have
114        for (name, col) in &self.columns {
115            if let Some(other_col) = other.columns.get(name) {
116                // Will extend below
117                let _ = (col, other_col);
118            }
119        }
120
121        // Also add columns from other that we don't have (fill with NA)
122        for name in other.columns.keys() {
123            if !self.columns.contains_key(name) {
124                self.columns
125                    .insert(name.clone(), vec![Value::Na; self.nrows]);
126            }
127        }
128
129        let old_nrows = self.nrows;
130        self.nrows += other.nrows;
131
132        for (name, col) in &mut self.columns {
133            if let Some(other_col) = other.columns.get(name) {
134                col.extend(other_col.iter().cloned());
135            } else {
136                col.extend(std::iter::repeat_with(|| Value::Na).take(other.nrows));
137            }
138            debug_assert_eq!(col.len(), old_nrows + other.nrows);
139        }
140    }
141
142    /// Select a subset of columns.
143    pub fn select(&self, columns: &[&str]) -> DataFrame {
144        let mut df = DataFrame::new();
145        for &col_name in columns {
146            if let Some(col) = self.columns.get(col_name) {
147                df.add_column(col_name.to_string(), col.clone());
148            }
149        }
150        df
151    }
152
153    /// Get a single row as a map.
154    pub fn row(&self, idx: usize) -> IndexMap<String, Value> {
155        assert!(
156            idx < self.nrows,
157            "Row index {idx} out of bounds ({} rows)",
158            self.nrows
159        );
160        let mut map = IndexMap::new();
161        for (name, col) in &self.columns {
162            map.insert(name.clone(), col[idx].clone());
163        }
164        map
165    }
166
167    /// Sort by a column (ascending). Returns a new DataFrame.
168    pub fn sort_by(&self, column: &str) -> DataFrame {
169        let col = match self.columns.get(column) {
170            Some(c) => c,
171            None => return self.clone(),
172        };
173
174        let mut indices: Vec<usize> = (0..self.nrows).collect();
175        indices.sort_by(|&a, &b| {
176            let va = col[a].as_f64().unwrap_or(f64::NAN);
177            let vb = col[b].as_f64().unwrap_or(f64::NAN);
178            va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
179        });
180
181        let mut df = DataFrame::new();
182        for (name, c) in &self.columns {
183            let values: Vec<Value> = indices.iter().map(|&i| c[i].clone()).collect();
184            df.add_column(name.clone(), values);
185        }
186        df
187    }
188
189    /// Create from rows (list of maps).
190    pub fn from_rows(rows: Vec<IndexMap<String, Value>>) -> Self {
191        if rows.is_empty() {
192            return DataFrame::new();
193        }
194
195        // Collect all column names from all rows
196        let mut col_names: IndexMap<String, ()> = IndexMap::new();
197        for row in &rows {
198            for key in row.keys() {
199                col_names.entry(key.clone()).or_default();
200            }
201        }
202
203        let mut df = DataFrame::new();
204        for name in col_names.keys() {
205            let values: Vec<Value> = rows
206                .iter()
207                .map(|row| row.get(name).cloned().unwrap_or(Value::Na))
208                .collect();
209            df.add_column(name.clone(), values);
210        }
211        df
212    }
213
214    /// Get all unique values in a column.
215    pub fn unique_values(&self, column: &str) -> Vec<Value> {
216        let col = match self.columns.get(column) {
217            Some(c) => c,
218            None => return vec![],
219        };
220        let mut seen: Vec<String> = Vec::new();
221        let mut result = Vec::new();
222        for v in col {
223            let key = v.to_group_key();
224            if !seen.contains(&key) {
225                seen.push(key);
226                result.push(v.clone());
227            }
228        }
229        result
230    }
231}
232
233impl DataFrame {
234    /// Load a DataFrame from a CSV file.
235    /// First row is treated as column headers.
236    /// Values are parsed as Float if possible, otherwise kept as strings.
237    /// The literal string "NA" is parsed as Value::Na.
238    pub fn from_csv(path: &str) -> Result<Self, std::io::Error> {
239        let content = std::fs::read_to_string(path)?;
240        let mut lines = content.lines();
241
242        let header = match lines.next() {
243            Some(h) => h,
244            None => return Ok(DataFrame::new()),
245        };
246
247        let col_names: Vec<&str> = header.split(',').map(|s| s.trim()).collect();
248        let mut columns: Vec<Vec<Value>> = vec![Vec::new(); col_names.len()];
249
250        for line in lines {
251            let line = line.trim();
252            if line.is_empty() {
253                continue;
254            }
255            let fields: Vec<&str> = line.split(',').collect();
256            for (i, field) in fields.iter().enumerate() {
257                if i >= col_names.len() {
258                    continue;
259                }
260                let field = field.trim();
261                let val = if field == "NA" || field == "na" {
262                    Value::Na
263                } else if let Ok(f) = field.parse::<f64>() {
264                    Value::Float(f)
265                } else {
266                    Value::Str(field.to_string())
267                };
268                columns[i].push(val);
269            }
270            // Pad missing columns with NA
271            for col in columns.iter_mut().skip(fields.len()) {
272                col.push(Value::Na);
273            }
274        }
275
276        let mut df = DataFrame::new();
277        for (i, name) in col_names.iter().enumerate() {
278            if !columns[i].is_empty() {
279                df.add_column(name.to_string(), std::mem::take(&mut columns[i]));
280            }
281        }
282
283        Ok(df)
284    }
285}
286
287impl Default for DataFrame {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a column of floats.
    fn floats(xs: &[f64]) -> Vec<Value> {
        xs.iter().map(|&x| Value::Float(x)).collect()
    }

    /// Shorthand for a column of strings.
    fn strs(xs: &[&str]) -> Vec<Value> {
        xs.iter().map(|&s| Value::Str(s.to_string())).collect()
    }

    #[test]
    fn test_add_column_and_access() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), vec![Value::Float(1.0), Value::Float(2.0)]);
        df.add_column("y".into(), vec![Value::Float(3.0), Value::Float(4.0)]);

        assert_eq!(df.nrows(), 2);
        assert_eq!(df.ncols(), 2);
        assert!(df.has_column("x"));
        assert!(!df.has_column("z"));
    }

    #[test]
    #[should_panic(expected = "has 1 values but DataFrame has 2 rows")]
    fn test_add_column_length_mismatch_panics() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        df.add_column("y".into(), floats(&[1.0]));
    }

    #[test]
    fn test_group_by() {
        let mut df = DataFrame::new();
        df.add_column(
            "cat".into(),
            vec![
                Value::Str("a".into()),
                Value::Str("b".into()),
                Value::Str("a".into()),
            ],
        );
        df.add_column(
            "val".into(),
            vec![Value::Float(1.0), Value::Float(2.0), Value::Float(3.0)],
        );

        let groups = df.group_by(&["cat"]);
        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].nrows(), 2); // "a" group
        assert_eq!(groups[1].nrows(), 1); // "b" group
    }

    #[test]
    fn test_group_by_missing_key_and_empty_frame() {
        let mut df = DataFrame::new();
        df.add_column("cat".into(), strs(&["a", "b", "a"]));

        // A missing key column puts every row in one group.
        let groups = df.group_by(&["nope"]);
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].nrows(), 3);

        assert!(DataFrame::new().group_by(&["cat"]).is_empty());
    }

    #[test]
    fn test_vstack() {
        let mut df1 = DataFrame::new();
        df1.add_column("x".into(), vec![Value::Float(1.0)]);

        let mut df2 = DataFrame::new();
        df2.add_column("x".into(), vec![Value::Float(2.0)]);

        df1.vstack(&df2);
        assert_eq!(df1.nrows(), 2);
    }

    #[test]
    fn test_vstack_fills_missing_columns_with_na() {
        let mut top = DataFrame::new();
        top.add_column("x".into(), floats(&[1.0]));
        top.add_column("a".into(), floats(&[10.0]));

        let mut bottom = DataFrame::new();
        bottom.add_column("x".into(), floats(&[2.0]));
        bottom.add_column("b".into(), floats(&[20.0]));

        top.vstack(&bottom);
        assert_eq!(top.nrows(), 2);
        assert_eq!(top.column_names(), vec!["x", "a", "b"]);

        let x = top.column("x").unwrap();
        assert_eq!(x[0].as_f64(), Some(1.0));
        assert_eq!(x[1].as_f64(), Some(2.0));

        let a = top.column("a").unwrap();
        assert_eq!(a[0].as_f64(), Some(10.0));
        assert!(matches!(a[1], Value::Na));

        let b = top.column("b").unwrap();
        assert!(matches!(b[0], Value::Na));
        assert_eq!(b[1].as_f64(), Some(20.0));
    }

    #[test]
    fn test_vstack_onto_empty_adopts_other() {
        let mut df = DataFrame::new();
        let mut other = DataFrame::new();
        other.add_column("x".into(), floats(&[1.0, 2.0]));

        df.vstack(&other);
        assert_eq!(df.nrows(), 2);
        assert_eq!(df.column_names(), vec!["x"]);
    }

    #[test]
    fn test_sort_by() {
        let mut df = DataFrame::new();
        df.add_column(
            "x".into(),
            vec![Value::Float(3.0), Value::Float(1.0), Value::Float(2.0)],
        );
        let sorted = df.sort_by("x");
        let col = sorted.column("x").unwrap();
        assert_eq!(col[0].as_f64(), Some(1.0));
        assert_eq!(col[1].as_f64(), Some(2.0));
        assert_eq!(col[2].as_f64(), Some(3.0));
    }

    #[test]
    fn test_sort_by_is_stable() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[2.0, 1.0, 2.0, 1.0]));
        df.add_column("id".into(), floats(&[0.0, 1.0, 2.0, 3.0]));

        let sorted = df.sort_by("x");
        let ids: Vec<Option<f64>> = sorted
            .column("id")
            .unwrap()
            .iter()
            .map(|v| v.as_f64())
            .collect();
        assert_eq!(ids, vec![Some(1.0), Some(3.0), Some(0.0), Some(2.0)]);
    }

    #[test]
    fn test_select_skips_missing_columns() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        df.add_column("y".into(), floats(&[3.0, 4.0]));

        let picked = df.select(&["y", "nope"]);
        assert_eq!(picked.column_names(), vec!["y"]);
        assert_eq!(picked.nrows(), 2);
    }

    #[test]
    fn test_row_returns_values_by_name() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0, 2.0]));
        df.add_column("y".into(), strs(&["a", "b"]));

        let r = df.row(1);
        let keys: Vec<&str> = r.keys().map(String::as_str).collect();
        assert_eq!(keys, vec!["x", "y"]);
        assert_eq!(r["x"].as_f64(), Some(2.0));
        assert!(matches!(&r["y"], Value::Str(s) if s == "b"));
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_row_out_of_bounds_panics() {
        let mut df = DataFrame::new();
        df.add_column("x".into(), floats(&[1.0]));
        let _ = df.row(1);
    }

    #[test]
    fn test_from_rows_fills_absent_keys() {
        let mut r1 = IndexMap::new();
        r1.insert("x".to_string(), Value::Float(1.0));
        let mut r2 = IndexMap::new();
        r2.insert("y".to_string(), Value::Float(2.0));

        let df = DataFrame::from_rows(vec![r1, r2]);
        assert_eq!(df.nrows(), 2);
        assert_eq!(df.column_names(), vec!["x", "y"]);
        assert_eq!(df.column("x").unwrap()[0].as_f64(), Some(1.0));
        assert!(matches!(df.column("x").unwrap()[1], Value::Na));
        assert!(matches!(df.column("y").unwrap()[0], Value::Na));
        assert_eq!(df.column("y").unwrap()[1].as_f64(), Some(2.0));

        assert_eq!(DataFrame::from_rows(vec![]).nrows(), 0);
    }

    #[test]
    fn test_unique_values_first_occurrence_order() {
        let mut df = DataFrame::new();
        df.add_column("c".into(), strs(&["b", "a", "b", "c", "a"]));

        let uniq = df.unique_values("c");
        assert_eq!(uniq.len(), 3);
        assert!(matches!(&uniq[0], Value::Str(s) if s == "b"));
        assert!(matches!(&uniq[1], Value::Str(s) if s == "a"));
        assert!(matches!(&uniq[2], Value::Str(s) if s == "c"));

        assert!(df.unique_values("missing").is_empty());
    }
}