rustframes/
lib.rs

1//! # RustFrames
2//!
3//! A blazing-fast, memory-safe alternative to NumPy + Pandas, written in Rust.
4//!
5//! RustFrames provides:
6//! - N-dimensional arrays with broadcasting support
7//! - DataFrame operations with groupby, joins, and filtering
8//! - Linear algebra operations (matrix multiplication, decompositions, etc.)
9//! - CSV and JSON I/O with automatic type inference
10//! - Memory-safe operations with zero-cost abstractions
11//!
12//! ## Quick Start
13//!
14//! ### Arrays
15//! ```rust
16//! use rustframes::array::Array;
17//!
18//! // Create a 2D array
19//! let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
20//!
21//! // Element-wise operations with broadcasting
22//! let scalar_mult = &arr * 2.0;
23//!
24//! // Linear algebra
//! let ones = Array::<f64>::ones(vec![2, 2]);
//! let product = arr.dot(&ones);
27//!
28//! // Reductions
29//! println!("Sum: {}", arr.sum());
30//! println!("Mean: {}", arr.mean());
31//! ```
32//!
33//! ### DataFrames
34//! ```rust
35//! use rustframes::dataframe::{DataFrame, Series};
36//!
37//! // Create DataFrame
38//! let df = DataFrame::new(vec![
39//!     ("name".to_string(), Series::from(vec!["Alice", "Bob", "Charlie"])),
40//!     ("age".to_string(), Series::from(vec![25, 30, 35])),
41//!     ("score".to_string(), Series::from(vec![85.5, 92.0, 78.5])),
42//! ]);
43//!
44//! // Operations
45//! let filtered = df.filter(&[true, false, true]);
46//! let sorted = df.sort_by("age", true);
47//! let grouped = df.groupby("age").mean();
48//!
49//! // I/O
50//! let df_from_csv = DataFrame::from_csv("tests/data/test.csv")?;
51//! df.to_csv("output.csv")?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
53
54pub mod array;
55pub mod dataframe;
56
57// Re-export main types for convenience
58pub use array::Array;
59pub use dataframe::core::JoinType;
60pub use dataframe::{DataFrame, Series};
61
62#[cfg(test)]
63mod tests {
64    use super::*;
65    use std::io::Write;
66    use tempfile::NamedTempFile;
67
68    #[test]
69    fn test_array_n_dimensional() {
70        // Test N-dimensional array creation and indexing
71        let arr = Array::from_vec((0..24).map(|x| x as f64).collect(), vec![2, 3, 4]);
72        assert_eq!(arr.shape, vec![2, 3, 4]);
73        assert_eq!(arr.ndim(), 3);
74        assert_eq!(arr[&[0, 0, 0][..]], 0.0);
75        assert_eq!(arr[&[1, 2, 3][..]], 23.0);
76
77        // Test reshape
78        let reshaped = arr.reshape(vec![6, 4]);
79        assert_eq!(reshaped.shape, vec![6, 4]);
80        assert_eq!(reshaped.data.len(), 24);
81    }
82
83    #[test]
84    fn test_array_broadcasting() {
85        // Test broadcasting with different shapes
86        let arr1 = Array::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
87        let arr2 = Array::from_vec(vec![10.0, 20.0], vec![2, 1]);
88
89        // This should broadcast to shape [2, 3]
90        if let Some(result) = arr1.add_broadcast(&arr2) {
91            assert_eq!(result.shape, vec![2, 3]);
92            assert_eq!(result[&[0, 0][..]], 11.0); // 1 + 10
93            assert_eq!(result[&[1, 2][..]], 23.0); // 3 + 20
94        }
95
96        // Test scalar operations
97        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
98        let scaled = &arr + 5.0;
99        assert_eq!(scaled.data, vec![6.0, 7.0, 8.0, 9.0]);
100    }
101
102    #[test]
103    fn test_array_reductions() {
104        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
105
106        // Test basic reductions
107        assert_eq!(arr.sum(), 21.0);
108        assert_eq!(arr.mean(), 3.5);
109        assert_eq!(arr.max(), 6.0);
110        assert_eq!(arr.min(), 1.0);
111
112        // Test axis reductions
113        let sum_axis_0 = arr.sum_axis(0);
114        assert_eq!(sum_axis_0.shape, vec![3]);
115        assert_eq!(sum_axis_0.data, vec![5.0, 7.0, 9.0]); // [1+4, 2+5, 3+6]
116
117        let mean_axis_1 = arr.mean_axis(1);
118        assert_eq!(mean_axis_1.shape, vec![2]);
119        assert_eq!(mean_axis_1.data, vec![2.0, 5.0]); // [(1+2+3)/3, (4+5+6)/3]
120    }
121
122    #[test]
123    fn test_linear_algebra() {
124        let matrix = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
125
126        // Test determinant
127        assert_eq!(matrix.det(), -2.0);
128
129        // Test trace
130        assert_eq!(matrix.trace(), 5.0);
131
132        // Test matrix multiplication
133        let other = Array::from_vec(vec![2.0, 0.0, 1.0, 3.0], vec![2, 2]);
134        let product = matrix.dot(&other);
135        assert_eq!(product.data, vec![4.0, 6.0, 10.0, 12.0]);
136
137        // Test matrix inverse (for 2x2)
138        if let Some(inv) = matrix.inv() {
139            // matrix * inv should be approximately identity
140            let should_be_identity = matrix.dot(&inv);
141            assert!((should_be_identity[(0, 0)] - 1.0).abs() < 1e-10);
142            assert!((should_be_identity[(1, 1)] - 1.0).abs() < 1e-10);
143            assert!(should_be_identity[(0, 1)].abs() < 1e-10);
144            assert!(should_be_identity[(1, 0)].abs() < 1e-10);
145        }
146
147        // Test QR decomposition
148        let (q, r) = matrix.qr();
149        assert_eq!(q.shape, vec![2, 2]);
150        assert_eq!(r.shape, vec![2, 2]);
151
152        // Q should be orthogonal (Q * Q^T = I)
153        let qt = q.transpose();
154        let should_be_identity = q.dot(&qt);
155        assert!((should_be_identity[(0, 0)] - 1.0).abs() < 1e-10);
156        assert!((should_be_identity[(1, 1)] - 1.0).abs() < 1e-10);
157    }
158
159    #[test]
160    fn test_dataframe_enhanced() {
161        let df = DataFrame::new(vec![
162            ("id".to_string(), Series::from(vec![1, 2, 3, 4])),
163            (
164                "name".to_string(),
165                Series::from(vec!["Alice", "Bob", "Charlie", "Diana"]),
166            ),
167            (
168                "score".to_string(),
169                Series::from(vec![85.5, 92.0, 78.5, 88.0]),
170            ),
171            (
172                "active".to_string(),
173                Series::from(vec![true, true, false, true]),
174            ),
175        ]);
176
177        // Test shape
178        assert_eq!(df.shape(), (4, 4));
179        assert_eq!(df.len(), 4);
180        assert!(!df.is_empty());
181
182        // Test head/tail
183        let head = df.head(2);
184        assert_eq!(head.len(), 2);
185
186        let tail = df.tail(2);
187        assert_eq!(tail.len(), 2);
188
189        // Test filtering
190        let mask = vec![true, false, true, false];
191        let filtered = df.filter(&mask);
192        assert_eq!(filtered.len(), 2);
193
194        // Test sorting
195        let sorted = df.sort_by("score", true); // ascending
196        if let Some(Series::Float64(scores)) = sorted.get_column("score") {
197            assert!(scores[0] < scores[1]); // Should be sorted ascending
198        }
199
200        // Test column operations
201        let with_bonus = df.with_column(
202            "bonus".to_string(),
203            Series::from(vec![100.0, 150.0, 75.0, 120.0]),
204        );
205        assert_eq!(with_bonus.shape().1, 5); // One more column
206
207        let dropped = df.drop(&["active"]);
208        assert_eq!(dropped.shape().1, 3); // One less column
209    }
210
211    #[test]
212    fn test_groupby_enhanced() {
213        let df = DataFrame::new(vec![
214            (
215                "department".to_string(),
216                Series::from(vec!["IT", "HR", "IT", "Finance", "HR"]),
217            ),
218            (
219                "salary".to_string(),
220                Series::from(vec![75000, 65000, 80000, 70000, 68000]),
221            ),
222            ("experience".to_string(), Series::from(vec![3, 5, 7, 4, 6])),
223        ]);
224
225        let grouped = df.groupby("department");
226
227        // Test count
228        let counts = grouped.count();
229        assert_eq!(counts.len(), 3); // 3 unique departments
230
231        // Test sum
232        let sums = grouped.sum();
233        assert_eq!(sums.columns.len(), 3); // department + 2 numeric columns
234
235        // Test mean
236        let means = grouped.mean();
237        assert_eq!(means.columns.len(), 3);
238
239        // Test first/last
240        let first = grouped.first();
241        assert_eq!(first.len(), 3);
242
243        let last = grouped.last();
244        assert_eq!(last.len(), 3);
245    }
246
247    #[test]
248    fn test_joins() {
249        let left = DataFrame::new(vec![
250            ("id".to_string(), Series::from(vec!["1", "2", "3"])),
251            (
252                "name".to_string(),
253                Series::from(vec!["Alice", "Bob", "Charlie"]),
254            ),
255        ]);
256
257        let right = DataFrame::new(vec![
258            ("id".to_string(), Series::from(vec!["1", "2", "4"])),
259            ("score".to_string(), Series::from(vec!["85", "92", "78"])),
260        ]);
261
262        // Inner join should match on ids 1 and 2
263        let joined = left.join(&right, "id", JoinType::Inner);
264        assert_eq!(joined.len(), 2);
265        assert_eq!(joined.columns.len(), 3); // id, name, score
266    }
267
268    #[test]
269    fn test_csv_io_with_inference() -> Result<(), Box<dyn std::error::Error>> {
270        // Create temporary CSV with mixed types
271        let mut temp_file = NamedTempFile::new()?;
272        writeln!(temp_file, "name,age,salary,active")?;
273        writeln!(temp_file, "Alice,25,50000.5,true")?;
274        writeln!(temp_file, "Bob,30,60000.0,false")?;
275        writeln!(temp_file, "Charlie,35,70000.25,true")?;
276
277        let df = DataFrame::from_csv(temp_file.path().to_str().unwrap())?;
278
279        // Check shape
280        assert_eq!(df.shape(), (3, 4));
281
282        // Check type inference
283        match df.get_column("age") {
284            Some(Series::Int64(_)) => {} // Expected
285            _ => panic!("Age should be inferred as Int64"),
286        }
287
288        match df.get_column("salary") {
289            Some(Series::Float64(_)) => {} // Expected
290            _ => panic!("Salary should be inferred as Float64"),
291        }
292
293        match df.get_column("active") {
294            Some(Series::Bool(_)) => {} // Expected
295            _ => panic!("Active should be inferred as Bool"),
296        }
297
298        // Test writing back to CSV
299        let output_file = NamedTempFile::new()?;
300        df.to_csv(output_file.path().to_str().unwrap())?;
301
302        // Read it back
303        let df2 = DataFrame::from_csv(output_file.path().to_str().unwrap())?;
304        assert_eq!(df2.shape(), df.shape());
305
306        Ok(())
307    }
308
309    #[test]
310    fn test_json_io() -> Result<(), Box<dyn std::error::Error>> {
311        let df = DataFrame::new(vec![
312            ("name".to_string(), Series::from(vec!["Alice", "Bob"])),
313            ("age".to_string(), Series::from(vec![25, 30])),
314            ("active".to_string(), Series::from(vec![true, false])),
315        ]);
316
317        // Test JSON Lines format
318        let jsonl_file = NamedTempFile::new()?;
319        df.to_jsonl(jsonl_file.path().to_str().unwrap())?;
320
321        let df_from_jsonl = DataFrame::from_jsonl(jsonl_file.path().to_str().unwrap())?;
322        assert_eq!(df_from_jsonl.shape(), (2, 3));
323
324        // Test regular JSON format
325        let json_file = NamedTempFile::new()?;
326        df.to_json(json_file.path().to_str().unwrap())?;
327
328        Ok(())
329    }
330
331    #[test]
332    fn test_statistical_summary() {
333        let df = DataFrame::new(vec![
334            (
335                "values".to_string(),
336                Series::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
337            ),
338            (
339                "integers".to_string(),
340                Series::from(vec![10, 20, 30, 40, 50]),
341            ),
342            (
343                "text".to_string(),
344                Series::from(vec!["a", "b", "c", "d", "e"]),
345            ),
346        ]);
347
348        let stats = df.describe();
349
350        // Should have statistics for numeric columns only
351        assert_eq!(
352            stats.columns,
353            vec!["count", "mean", "std", "min", "25%", "50%", "75%", "max"]
354        );
355
356        // Check some basic statistics
357        if let Some(Series::Float64(means)) = stats.data.get(1) {
358            assert_eq!(means[0], 3.0); // Mean of [1,2,3,4,5]
359            assert_eq!(means[1], 30.0); // Mean of [10,20,30,40,50]
360        }
361    }
362
363    #[test]
364    fn test_mathematical_functions() {
365        let arr = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
366
367        // Test element-wise functions
368        let exp_arr = arr.exp();
369        assert!((exp_arr[&[0, 0][..]] - 1.0_f64.exp()).abs() < 1e-10);
370
371        let ln_arr = arr.ln();
372        assert!((ln_arr[(0, 0)] - 1.0_f64.ln()).abs() < 1e-10);
373
374        let sin_arr = arr.sin();
375        assert!((sin_arr[(0, 0)] - 1.0_f64.sin()).abs() < 1e-10);
376
377        let sqrt_arr = arr.sqrt();
378        assert!((sqrt_arr[(0, 0)] - 1.0).abs() < 1e-10);
379
380        let pow_arr = arr.pow(2.0);
381        assert_eq!(pow_arr[(0, 0)], 1.0);
382        assert_eq!(pow_arr[(1, 1)], 16.0);
383    }
384}