1pub mod array;
55pub mod dataframe;
56
57pub use array::Array;
59pub use dataframe::core::JoinType;
60pub use dataframe::{DataFrame, Series};
61
62#[cfg(test)]
63mod tests {
64 use super::*;
65 use std::io::Write;
66 use tempfile::NamedTempFile;
67
68 #[test]
69 fn test_array_n_dimensional() {
70 let arr = Array::from_vec((0..24).map(|x| x as f64).collect(), vec![2, 3, 4]);
72 assert_eq!(arr.shape, vec![2, 3, 4]);
73 assert_eq!(arr.ndim(), 3);
74 assert_eq!(arr[&[0, 0, 0][..]], 0.0);
75 assert_eq!(arr[&[1, 2, 3][..]], 23.0);
76
77 let reshaped = arr.reshape(vec![6, 4]);
79 assert_eq!(reshaped.shape, vec![6, 4]);
80 assert_eq!(reshaped.data.len(), 24);
81 }
82
83 #[test]
84 fn test_array_broadcasting() {
85 let arr1 = Array::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
87 let arr2 = Array::from_vec(vec![10.0, 20.0], vec![2, 1]);
88
89 if let Some(result) = arr1.add_broadcast(&arr2) {
91 assert_eq!(result.shape, vec![2, 3]);
92 assert_eq!(result[&[0, 0][..]], 11.0); assert_eq!(result[&[1, 2][..]], 23.0); }
95
96 let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
98 let scaled = &arr + 5.0;
99 assert_eq!(scaled.data, vec![6.0, 7.0, 8.0, 9.0]);
100 }
101
102 #[test]
103 fn test_array_reductions() {
104 let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
105
106 assert_eq!(arr.sum(), 21.0);
108 assert_eq!(arr.mean(), 3.5);
109 assert_eq!(arr.max(), 6.0);
110 assert_eq!(arr.min(), 1.0);
111
112 let sum_axis_0 = arr.sum_axis(0);
114 assert_eq!(sum_axis_0.shape, vec![3]);
115 assert_eq!(sum_axis_0.data, vec![5.0, 7.0, 9.0]); let mean_axis_1 = arr.mean_axis(1);
118 assert_eq!(mean_axis_1.shape, vec![2]);
119 assert_eq!(mean_axis_1.data, vec![2.0, 5.0]); }
121
122 #[test]
123 fn test_linear_algebra() {
124 let matrix = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
125
126 assert_eq!(matrix.det(), -2.0);
128
129 assert_eq!(matrix.trace(), 5.0);
131
132 let other = Array::from_vec(vec![2.0, 0.0, 1.0, 3.0], vec![2, 2]);
134 let product = matrix.dot(&other);
135 assert_eq!(product.data, vec![4.0, 6.0, 10.0, 12.0]);
136
137 if let Some(inv) = matrix.inv() {
139 let should_be_identity = matrix.dot(&inv);
141 assert!((should_be_identity[(0, 0)] - 1.0).abs() < 1e-10);
142 assert!((should_be_identity[(1, 1)] - 1.0).abs() < 1e-10);
143 assert!(should_be_identity[(0, 1)].abs() < 1e-10);
144 assert!(should_be_identity[(1, 0)].abs() < 1e-10);
145 }
146
147 let (q, r) = matrix.qr();
149 assert_eq!(q.shape, vec![2, 2]);
150 assert_eq!(r.shape, vec![2, 2]);
151
152 let qt = q.transpose();
154 let should_be_identity = q.dot(&qt);
155 assert!((should_be_identity[(0, 0)] - 1.0).abs() < 1e-10);
156 assert!((should_be_identity[(1, 1)] - 1.0).abs() < 1e-10);
157 }
158
159 #[test]
160 fn test_dataframe_enhanced() {
161 let df = DataFrame::new(vec![
162 ("id".to_string(), Series::from(vec![1, 2, 3, 4])),
163 (
164 "name".to_string(),
165 Series::from(vec!["Alice", "Bob", "Charlie", "Diana"]),
166 ),
167 (
168 "score".to_string(),
169 Series::from(vec![85.5, 92.0, 78.5, 88.0]),
170 ),
171 (
172 "active".to_string(),
173 Series::from(vec![true, true, false, true]),
174 ),
175 ]);
176
177 assert_eq!(df.shape(), (4, 4));
179 assert_eq!(df.len(), 4);
180 assert!(!df.is_empty());
181
182 let head = df.head(2);
184 assert_eq!(head.len(), 2);
185
186 let tail = df.tail(2);
187 assert_eq!(tail.len(), 2);
188
189 let mask = vec![true, false, true, false];
191 let filtered = df.filter(&mask);
192 assert_eq!(filtered.len(), 2);
193
194 let sorted = df.sort_by("score", true); if let Some(Series::Float64(scores)) = sorted.get_column("score") {
197 assert!(scores[0] < scores[1]); }
199
200 let with_bonus = df.with_column(
202 "bonus".to_string(),
203 Series::from(vec![100.0, 150.0, 75.0, 120.0]),
204 );
205 assert_eq!(with_bonus.shape().1, 5); let dropped = df.drop(&["active"]);
208 assert_eq!(dropped.shape().1, 3); }
210
211 #[test]
212 fn test_groupby_enhanced() {
213 let df = DataFrame::new(vec![
214 (
215 "department".to_string(),
216 Series::from(vec!["IT", "HR", "IT", "Finance", "HR"]),
217 ),
218 (
219 "salary".to_string(),
220 Series::from(vec![75000, 65000, 80000, 70000, 68000]),
221 ),
222 ("experience".to_string(), Series::from(vec![3, 5, 7, 4, 6])),
223 ]);
224
225 let grouped = df.groupby("department");
226
227 let counts = grouped.count();
229 assert_eq!(counts.len(), 3); let sums = grouped.sum();
233 assert_eq!(sums.columns.len(), 3); let means = grouped.mean();
237 assert_eq!(means.columns.len(), 3);
238
239 let first = grouped.first();
241 assert_eq!(first.len(), 3);
242
243 let last = grouped.last();
244 assert_eq!(last.len(), 3);
245 }
246
247 #[test]
248 fn test_joins() {
249 let left = DataFrame::new(vec![
250 ("id".to_string(), Series::from(vec!["1", "2", "3"])),
251 (
252 "name".to_string(),
253 Series::from(vec!["Alice", "Bob", "Charlie"]),
254 ),
255 ]);
256
257 let right = DataFrame::new(vec![
258 ("id".to_string(), Series::from(vec!["1", "2", "4"])),
259 ("score".to_string(), Series::from(vec!["85", "92", "78"])),
260 ]);
261
262 let joined = left.join(&right, "id", JoinType::Inner);
264 assert_eq!(joined.len(), 2);
265 assert_eq!(joined.columns.len(), 3); }
267
268 #[test]
269 fn test_csv_io_with_inference() -> Result<(), Box<dyn std::error::Error>> {
270 let mut temp_file = NamedTempFile::new()?;
272 writeln!(temp_file, "name,age,salary,active")?;
273 writeln!(temp_file, "Alice,25,50000.5,true")?;
274 writeln!(temp_file, "Bob,30,60000.0,false")?;
275 writeln!(temp_file, "Charlie,35,70000.25,true")?;
276
277 let df = DataFrame::from_csv(temp_file.path().to_str().unwrap())?;
278
279 assert_eq!(df.shape(), (3, 4));
281
282 match df.get_column("age") {
284 Some(Series::Int64(_)) => {} _ => panic!("Age should be inferred as Int64"),
286 }
287
288 match df.get_column("salary") {
289 Some(Series::Float64(_)) => {} _ => panic!("Salary should be inferred as Float64"),
291 }
292
293 match df.get_column("active") {
294 Some(Series::Bool(_)) => {} _ => panic!("Active should be inferred as Bool"),
296 }
297
298 let output_file = NamedTempFile::new()?;
300 df.to_csv(output_file.path().to_str().unwrap())?;
301
302 let df2 = DataFrame::from_csv(output_file.path().to_str().unwrap())?;
304 assert_eq!(df2.shape(), df.shape());
305
306 Ok(())
307 }
308
309 #[test]
310 fn test_json_io() -> Result<(), Box<dyn std::error::Error>> {
311 let df = DataFrame::new(vec![
312 ("name".to_string(), Series::from(vec!["Alice", "Bob"])),
313 ("age".to_string(), Series::from(vec![25, 30])),
314 ("active".to_string(), Series::from(vec![true, false])),
315 ]);
316
317 let jsonl_file = NamedTempFile::new()?;
319 df.to_jsonl(jsonl_file.path().to_str().unwrap())?;
320
321 let df_from_jsonl = DataFrame::from_jsonl(jsonl_file.path().to_str().unwrap())?;
322 assert_eq!(df_from_jsonl.shape(), (2, 3));
323
324 let json_file = NamedTempFile::new()?;
326 df.to_json(json_file.path().to_str().unwrap())?;
327
328 Ok(())
329 }
330
331 #[test]
332 fn test_statistical_summary() {
333 let df = DataFrame::new(vec![
334 (
335 "values".to_string(),
336 Series::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
337 ),
338 (
339 "integers".to_string(),
340 Series::from(vec![10, 20, 30, 40, 50]),
341 ),
342 (
343 "text".to_string(),
344 Series::from(vec!["a", "b", "c", "d", "e"]),
345 ),
346 ]);
347
348 let stats = df.describe();
349
350 assert_eq!(
352 stats.columns,
353 vec!["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
354 );
355
356 if let Some(Series::Float64(means)) = stats.data.get(1) {
358 assert_eq!(means[0], 3.0); assert_eq!(means[1], 30.0); }
361 }
362
363 #[test]
364 fn test_mathematical_functions() {
365 let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
366
367 let exp_arr = arr.exp();
369 assert!((exp_arr[&[0, 0][..]] - 1.0_f64.exp()).abs() < 1e-10);
370
371 let ln_arr = arr.ln();
372 assert!((ln_arr[(0, 0)] - 1.0_f64.ln()).abs() < 1e-10);
373
374 let sin_arr = arr.sin();
375 assert!((sin_arr[(0, 0)] - 1.0_f64.sin()).abs() < 1e-10);
376
377 let sqrt_arr = arr.sqrt();
378 assert!((sqrt_arr[(0, 0)] - 1.0).abs() < 1e-10);
379
380 let pow_arr = arr.pow(2.0);
381 assert_eq!(pow_arr[(0, 0)], 1.0);
382 assert_eq!(pow_arr[(1, 1)], 16.0);
383 }
384}