1use super::Series;
2use std::collections::{HashMap, HashSet};
3
4#[derive(Debug, Clone, PartialEq)]
5pub struct DataFrame {
6 pub columns: Vec<String>,
7 pub data: Vec<Series>,
8}
9
10impl DataFrame {
11 pub fn new(columns: Vec<(String, Series)>) -> Self {
12 if !columns.is_empty() {
13 let first_len = columns[0].1.len();
14 for (name, series) in &columns {
15 if series.len() != first_len {
16 panic!("All columns must have the same length. Column '{}' has length {}, expected {}", name, series.len(), first_len);
17 }
18 }
19 }
20
21 let (names, series): (Vec<_>, Vec<_>) = columns.into_iter().unzip();
22 DataFrame {
23 columns: names,
24 data: series,
25 }
26 }
27
28 pub fn empty(columns: Vec<(String, SeriesType)>) -> Self {
30 let series: Vec<Series> = columns
31 .iter()
32 .map(|(_, dtype)| match dtype {
33 SeriesType::Int64 => Series::Int64(Vec::new()),
34 SeriesType::Float64 => Series::Float64(Vec::new()),
35 SeriesType::Bool => Series::Bool(Vec::new()),
36 SeriesType::Utf8 => Series::Utf8(Vec::new()),
37 })
38 .collect();
39
40 let names: Vec<String> = columns.into_iter().map(|(name, _)| name).collect();
41 DataFrame {
42 columns: names,
43 data: series,
44 }
45 }
46
47 pub fn len(&self) -> usize {
49 if self.data.is_empty() {
50 0
51 } else {
52 self.data[0].len()
53 }
54 }
55
56 pub fn is_empty(&self) -> bool {
58 self.len() == 0
59 }
60
61 pub fn shape(&self) -> (usize, usize) {
63 (self.len(), self.columns.len())
64 }
65
66 pub fn head(&self, n: usize) -> DataFrame {
68 let new_data: Vec<Series> = self
69 .data
70 .iter()
71 .map(|s| match s {
72 Series::Int64(v) => Series::Int64(v.iter().take(n).cloned().collect()),
73 Series::Float64(v) => Series::Float64(v.iter().take(n).cloned().collect()),
74 Series::Bool(v) => Series::Bool(v.iter().take(n).cloned().collect()),
75 Series::Utf8(v) => Series::Utf8(v.iter().take(n).cloned().collect()),
76 })
77 .collect();
78
79 DataFrame {
80 columns: self.columns.clone(),
81 data: new_data,
82 }
83 }
84
85 pub fn tail(&self, n: usize) -> DataFrame {
87 let len = self.len();
88 let start = len.saturating_sub(n);
89
90 let new_data: Vec<Series> = self
91 .data
92 .iter()
93 .map(|s| match s {
94 Series::Int64(v) => Series::Int64(v.iter().skip(start).cloned().collect()),
95 Series::Float64(v) => Series::Float64(v.iter().skip(start).cloned().collect()),
96 Series::Bool(v) => Series::Bool(v.iter().skip(start).cloned().collect()),
97 Series::Utf8(v) => Series::Utf8(v.iter().skip(start).cloned().collect()),
98 })
99 .collect();
100
101 DataFrame {
102 columns: self.columns.clone(),
103 data: new_data,
104 }
105 }
106
107 pub fn select(&self, cols: &[&str]) -> DataFrame {
109 let mut new_cols = Vec::new();
110 let mut new_data = Vec::new();
111
112 for col in cols {
113 if let Some(pos) = self.columns.iter().position(|c| c == col) {
114 new_cols.push(self.columns[pos].clone());
115 new_data.push(self.data[pos].clone());
116 } else {
117 panic!("Column '{}' not found", col);
118 }
119 }
120
121 DataFrame {
122 columns: new_cols,
123 data: new_data,
124 }
125 }
126
127 pub fn get_column(&self, name: &str) -> Option<&Series> {
129 self.columns
130 .iter()
131 .position(|c| c == name)
132 .map(|pos| &self.data[pos])
133 }
134
135 pub fn filter(&self, mask: &[bool]) -> DataFrame {
137 assert_eq!(
138 mask.len(),
139 self.len(),
140 "Mask length must match DataFrame length"
141 );
142
143 let new_data: Vec<Series> = self
144 .data
145 .iter()
146 .map(|s| match s {
147 Series::Int64(v) => Series::Int64(
148 v.iter()
149 .zip(mask)
150 .filter_map(|(&val, &keep)| if keep { Some(val) } else { None })
151 .collect(),
152 ),
153 Series::Float64(v) => Series::Float64(
154 v.iter()
155 .zip(mask)
156 .filter_map(|(&val, &keep)| if keep { Some(val) } else { None })
157 .collect(),
158 ),
159 Series::Bool(v) => Series::Bool(
160 v.iter()
161 .zip(mask)
162 .filter_map(|(&val, &keep)| if keep { Some(val) } else { None })
163 .collect(),
164 ),
165 Series::Utf8(v) => Series::Utf8(
166 v.iter()
167 .zip(mask)
168 .filter_map(|(val, &keep)| if keep { Some(val.clone()) } else { None })
169 .collect(),
170 ),
171 })
172 .collect();
173
174 DataFrame {
175 columns: self.columns.clone(),
176 data: new_data,
177 }
178 }
179
180 pub fn sort_by(&self, column: &str, ascending: bool) -> DataFrame {
182 let col_idx = self
183 .columns
184 .iter()
185 .position(|c| c == column)
186 .expect("Column not found");
187
188 let mut indices: Vec<usize> = (0..self.len()).collect();
189
190 match &self.data[col_idx] {
191 Series::Int64(values) => {
192 indices.sort_by(|&a, &b| {
193 if ascending {
194 values[a].cmp(&values[b])
195 } else {
196 values[b].cmp(&values[a])
197 }
198 });
199 }
200 Series::Float64(values) => {
201 indices.sort_by(|&a, &b| {
202 if ascending {
203 values[a].partial_cmp(&values[b]).unwrap()
204 } else {
205 values[b].partial_cmp(&values[a]).unwrap()
206 }
207 });
208 }
209 Series::Bool(values) => {
210 indices.sort_by(|&a, &b| {
211 if ascending {
212 values[a].cmp(&values[b])
213 } else {
214 values[b].cmp(&values[a])
215 }
216 });
217 }
218 Series::Utf8(values) => {
219 indices.sort_by(|&a, &b| {
220 if ascending {
221 values[a].cmp(&values[b])
222 } else {
223 values[b].cmp(&values[a])
224 }
225 });
226 }
227 }
228
229 let new_data: Vec<Series> = self
230 .data
231 .iter()
232 .map(|s| match s {
233 Series::Int64(v) => Series::Int64(indices.iter().map(|&i| v[i]).collect()),
234 Series::Float64(v) => Series::Float64(indices.iter().map(|&i| v[i]).collect()),
235 Series::Bool(v) => Series::Bool(indices.iter().map(|&i| v[i]).collect()),
236 Series::Utf8(v) => Series::Utf8(indices.iter().map(|&i| v[i].clone()).collect()),
237 })
238 .collect();
239
240 DataFrame {
241 columns: self.columns.clone(),
242 data: new_data,
243 }
244 }
245
246 pub fn with_column(&self, name: String, series: Series) -> DataFrame {
248 assert_eq!(
249 series.len(),
250 self.len(),
251 "New column length must match DataFrame length"
252 );
253
254 let mut new_columns = self.columns.clone();
255 let mut new_data = self.data.clone();
256
257 if let Some(pos) = new_columns.iter().position(|c| c == &name) {
259 new_data[pos] = series;
260 } else {
261 new_columns.push(name);
262 new_data.push(series);
263 }
264
265 DataFrame {
266 columns: new_columns,
267 data: new_data,
268 }
269 }
270
271 pub fn drop(&self, cols: &[&str]) -> DataFrame {
273 let cols_to_drop: HashSet<&str> = cols.iter().cloned().collect();
274 let mut new_columns = Vec::new();
275 let mut new_data = Vec::new();
276
277 for (i, col_name) in self.columns.iter().enumerate() {
278 if !cols_to_drop.contains(col_name.as_str()) {
279 new_columns.push(col_name.clone());
280 new_data.push(self.data[i].clone());
281 }
282 }
283
284 DataFrame {
285 columns: new_columns,
286 data: new_data,
287 }
288 }
289
290 pub fn join(&self, other: &DataFrame, on: &str, how: JoinType) -> DataFrame {
292 let left_col_idx = self
293 .columns
294 .iter()
295 .position(|c| c == on)
296 .expect("Join column not found in left DataFrame");
297 let right_col_idx = other
298 .columns
299 .iter()
300 .position(|c| c == on)
301 .expect("Join column not found in right DataFrame");
302
303 match how {
304 JoinType::Inner => self.inner_join(other, left_col_idx, right_col_idx, on),
305 JoinType::Left => self.left_join(other, left_col_idx, right_col_idx, on),
306 JoinType::Right => other.left_join(self, right_col_idx, left_col_idx, on),
307 JoinType::Outer => self.outer_join(other, left_col_idx, right_col_idx, on),
308 }
309 }
310
311 fn inner_join(
312 &self,
313 other: &DataFrame,
314 left_col_idx: usize,
315 right_col_idx: usize,
316 _on: &str,
317 ) -> DataFrame {
318 let mut result_columns = self.columns.clone();
319
320 for (i, col) in other.columns.iter().enumerate() {
322 if i != right_col_idx {
323 let mut new_name = col.clone();
324 if result_columns.contains(&new_name) {
325 new_name = format!("{}_y", col);
326 }
327 result_columns.push(new_name);
328 }
329 }
330
331 let mut right_map: HashMap<String, Vec<usize>> = HashMap::new();
333 if let Series::Utf8(right_values) = &other.data[right_col_idx] {
334 for (idx, value) in right_values.iter().enumerate() {
335 right_map.entry(value.clone()).or_default().push(idx);
336 }
337 }
338
339 let mut result_data: Vec<Vec<String>> = vec![Vec::new(); result_columns.len()];
340
341 if let Series::Utf8(left_values) = &self.data[left_col_idx] {
343 for (left_idx, left_value) in left_values.iter().enumerate() {
344 if let Some(right_indices) = right_map.get(left_value) {
345 for &right_idx in right_indices {
346 for (col_idx, series) in self.data.iter().enumerate() {
348 let value = match series {
349 Series::Int64(v) => v[left_idx].to_string(),
350 Series::Float64(v) => v[left_idx].to_string(),
351 Series::Bool(v) => v[left_idx].to_string(),
352 Series::Utf8(v) => v[left_idx].clone(),
353 };
354 result_data[col_idx].push(value);
355 }
356
357 let mut result_col_idx = self.columns.len();
359 for (col_idx, series) in other.data.iter().enumerate() {
360 if col_idx != right_col_idx {
361 let value = match series {
362 Series::Int64(v) => v[right_idx].to_string(),
363 Series::Float64(v) => v[right_idx].to_string(),
364 Series::Bool(v) => v[right_idx].to_string(),
365 Series::Utf8(v) => v[right_idx].clone(),
366 };
367 result_data[result_col_idx].push(value);
368 result_col_idx += 1;
369 }
370 }
371 }
372 }
373 }
374 }
375
376 let result_series: Vec<Series> = result_data.into_iter().map(Series::Utf8).collect();
378
379 DataFrame {
380 columns: result_columns,
381 data: result_series,
382 }
383 }
384
385 fn left_join(
386 &self,
387 other: &DataFrame,
388 left_col_idx: usize,
389 right_col_idx: usize,
390 on: &str,
391 ) -> DataFrame {
392 self.inner_join(other, left_col_idx, right_col_idx, on) }
396
397 fn outer_join(
398 &self,
399 other: &DataFrame,
400 left_col_idx: usize,
401 right_col_idx: usize,
402 on: &str,
403 ) -> DataFrame {
404 self.inner_join(other, left_col_idx, right_col_idx, on) }
408
409 pub fn describe(&self) -> DataFrame {
411 let mut stats_data: Vec<(String, Series)> = Vec::new();
412 let stats = vec!["count", "mean", "std", "min", "25%", "50%", "75%", "max"];
413
414 for stat in stats {
415 let mut values = Vec::new();
416
417 for series in &self.data {
418 let value = match series {
419 Series::Float64(v) if !v.is_empty() => match stat {
420 "count" => v.len() as f64,
421 "mean" => v.iter().sum::<f64>() / v.len() as f64,
422 "std" => {
423 let mean = v.iter().sum::<f64>() / v.len() as f64;
424 let variance =
425 v.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / v.len() as f64;
426 variance.sqrt()
427 }
428 "min" => v.iter().fold(f64::INFINITY, |a, &b| a.min(b)),
429 "max" => v.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b)),
430 "25%" | "50%" | "75%" => {
431 let mut sorted = v.clone();
432 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
433 let idx = match stat {
434 "25%" => sorted.len() / 4,
435 "50%" => sorted.len() / 2,
436 "75%" => 3 * sorted.len() / 4,
437 _ => 0,
438 };
439 sorted.get(idx).copied().unwrap_or(0.0)
440 }
441 _ => 0.0,
442 },
443 Series::Int64(v) if !v.is_empty() => match stat {
444 "count" => v.len() as f64,
445 "mean" => v.iter().sum::<i64>() as f64 / v.len() as f64,
446 "std" => {
447 let mean = v.iter().sum::<i64>() as f64 / v.len() as f64;
448 let variance =
449 v.iter().map(|&x| (x as f64 - mean).powi(2)).sum::<f64>()
450 / v.len() as f64;
451 variance.sqrt()
452 }
453 "min" => *v.iter().min().unwrap() as f64,
454 "max" => *v.iter().max().unwrap() as f64,
455 _ => 0.0,
456 },
457 _ => f64::NAN, };
459
460 values.push(value);
461 }
462
463 stats_data.push((stat.to_string(), Series::Float64(values)));
464 }
465
466 DataFrame::new(stats_data)
467 }
468}
469
/// Join strategy accepted by `DataFrame::join`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full outer join.
    Outer,
}
477
/// Element-type tag; `DataFrame::empty` uses it to pick which empty `Series`
/// variant to create for a column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SeriesType {
    /// Selects `Series::Int64`.
    Int64,
    /// Selects `Series::Float64`.
    Float64,
    /// Selects `Series::Bool`.
    Bool,
    /// Selects `Series::Utf8`.
    Utf8,
}