Skip to main content

polars_core/frame/
arithmetic.rs

1use std::ops::{Add, Div, Mul, Rem, Sub};
2
3use rayon::prelude::*;
4
5use crate::POOL;
6use crate::prelude::*;
7use crate::utils::try_get_supertype;
8
9/// Get the supertype that is valid for all columns in the [`DataFrame`].
10/// This reduces casting of the rhs in arithmetic.
11fn get_supertype_all(df: &DataFrame, rhs: &Series) -> PolarsResult<DataType> {
12    df.columns().iter().try_fold(rhs.dtype().clone(), |dt, s| {
13        try_get_supertype(s.dtype(), &dt)
14    })
15}
16
/// Shared body for `DataFrame <op> Series` arithmetic.
///
/// Arguments:
/// * `$self`    - the `&DataFrame` on the left-hand side.
/// * `$rhs`     - the `&Series` on the right-hand side.
/// * `$operand` - a callable taking `(&Series, &Series)` and returning a `PolarsResult<Series>`.
///
/// Both sides are cast to the supertype computed by `get_supertype_all`, then `$operand`
/// is applied to every column through `try_apply_columns_par`. The expansion uses `?`, so
/// it must be placed in a function returning `PolarsResult<_>`.
macro_rules! impl_arithmetic {
    ($self:expr, $rhs:expr, $operand:expr) => {{
        let supertype = get_supertype_all($self, $rhs)?;
        // Cast the right-hand side once, outside of the per-column closure.
        let rhs_cast = $rhs.cast(&supertype)?;
        let result_columns = $self.try_apply_columns_par(|column| {
            let lhs_cast = column.as_materialized_series().cast(&supertype)?;
            $operand(&lhs_cast, &rhs_cast).map(Column::from)
        })?;
        // NOTE(review): `new_unchecked` presumably skips validation of the columns; this
        // appears to rely on every result column still having `$self.height()` rows and
        // its original name -- confirm against the `DataFrame::new_unchecked` contract.
        Ok(unsafe { DataFrame::new_unchecked($self.height(), result_columns) })
    }};
}
28
29impl Add<&Series> for &DataFrame {
30    type Output = PolarsResult<DataFrame>;
31
32    fn add(self, rhs: &Series) -> Self::Output {
33        impl_arithmetic!(self, rhs, std::ops::Add::add)
34    }
35}
36
37impl Add<&Series> for DataFrame {
38    type Output = PolarsResult<DataFrame>;
39
40    fn add(self, rhs: &Series) -> Self::Output {
41        (&self).add(rhs)
42    }
43}
44
45impl Sub<&Series> for &DataFrame {
46    type Output = PolarsResult<DataFrame>;
47
48    fn sub(self, rhs: &Series) -> Self::Output {
49        impl_arithmetic!(self, rhs, std::ops::Sub::sub)
50    }
51}
52
53impl Sub<&Series> for DataFrame {
54    type Output = PolarsResult<DataFrame>;
55
56    fn sub(self, rhs: &Series) -> Self::Output {
57        (&self).sub(rhs)
58    }
59}
60
61impl Mul<&Series> for &DataFrame {
62    type Output = PolarsResult<DataFrame>;
63
64    fn mul(self, rhs: &Series) -> Self::Output {
65        impl_arithmetic!(self, rhs, std::ops::Mul::mul)
66    }
67}
68
69impl Mul<&Series> for DataFrame {
70    type Output = PolarsResult<DataFrame>;
71
72    fn mul(self, rhs: &Series) -> Self::Output {
73        (&self).mul(rhs)
74    }
75}
76
77impl Div<&Series> for &DataFrame {
78    type Output = PolarsResult<DataFrame>;
79
80    fn div(self, rhs: &Series) -> Self::Output {
81        impl_arithmetic!(self, rhs, std::ops::Div::div)
82    }
83}
84
85impl Div<&Series> for DataFrame {
86    type Output = PolarsResult<DataFrame>;
87
88    fn div(self, rhs: &Series) -> Self::Output {
89        (&self).div(rhs)
90    }
91}
92
93impl Rem<&Series> for &DataFrame {
94    type Output = PolarsResult<DataFrame>;
95
96    fn rem(self, rhs: &Series) -> Self::Output {
97        impl_arithmetic!(self, rhs, std::ops::Rem::rem)
98    }
99}
100
101impl Rem<&Series> for DataFrame {
102    type Output = PolarsResult<DataFrame>;
103
104    fn rem(self, rhs: &Series) -> Self::Output {
105        (&self).rem(rhs)
106    }
107}
108
109impl DataFrame {
110    fn binary_aligned(
111        &self,
112        other: &DataFrame,
113        f: &(dyn Fn(&Series, &Series) -> PolarsResult<Series> + Sync + Send),
114    ) -> PolarsResult<DataFrame> {
115        let max_len = std::cmp::max(self.height(), other.height());
116        let max_width = std::cmp::max(self.width(), other.width());
117        let cols = self
118            .columns()
119            .par_iter()
120            .zip(other.columns().par_iter())
121            .map(|(l, r)| {
122                let l = l.as_materialized_series();
123                let r = r.as_materialized_series();
124
125                let diff_l = max_len - l.len();
126                let diff_r = max_len - r.len();
127
128                let st = try_get_supertype(l.dtype(), r.dtype())?;
129                let mut l = l.cast(&st)?;
130                let mut r = r.cast(&st)?;
131
132                if diff_l > 0 {
133                    l = l.extend_constant(AnyValue::Null, diff_l)?;
134                };
135                if diff_r > 0 {
136                    r = r.extend_constant(AnyValue::Null, diff_r)?;
137                };
138
139                f(&l, &r).map(Column::from)
140            });
141        let mut cols = POOL.install(|| cols.collect::<PolarsResult<Vec<_>>>())?;
142
143        let col_len = cols.len();
144        if col_len < max_width {
145            let df = if col_len < self.width() { self } else { other };
146
147            for i in col_len..max_len {
148                let s = &df.columns().get(i).ok_or_else(|| polars_err!(InvalidOperation: "cannot do arithmetic on DataFrames with shapes: {:?} and {:?}", self.shape(), other.shape()))?;
149                let name = s.name();
150                let dtype = s.dtype();
151
152                // trick to fill a series with nulls
153                let vals: &[Option<i32>] = &[None];
154                let s = Series::new(name.clone(), vals).cast(dtype)?;
155                cols.push(s.new_from_index(0, max_len).into())
156            }
157        }
158
159        DataFrame::new_infer_height(cols)
160    }
161}
162
163impl Add<&DataFrame> for &DataFrame {
164    type Output = PolarsResult<DataFrame>;
165
166    fn add(self, rhs: &DataFrame) -> Self::Output {
167        self.binary_aligned(rhs, &|a, b| a + b)
168    }
169}
170
171impl Sub<&DataFrame> for &DataFrame {
172    type Output = PolarsResult<DataFrame>;
173
174    fn sub(self, rhs: &DataFrame) -> Self::Output {
175        self.binary_aligned(rhs, &|a, b| a - b)
176    }
177}
178
179impl Div<&DataFrame> for &DataFrame {
180    type Output = PolarsResult<DataFrame>;
181
182    fn div(self, rhs: &DataFrame) -> Self::Output {
183        self.binary_aligned(rhs, &|a, b| a / b)
184    }
185}
186
187impl Mul<&DataFrame> for &DataFrame {
188    type Output = PolarsResult<DataFrame>;
189
190    fn mul(self, rhs: &DataFrame) -> Self::Output {
191        self.binary_aligned(rhs, &|a, b| a * b)
192    }
193}
194
195impl Rem<&DataFrame> for &DataFrame {
196    type Output = PolarsResult<DataFrame>;
197
198    fn rem(self, rhs: &DataFrame) -> Self::Output {
199        self.binary_aligned(rhs, &|a, b| a % b)
200    }
201}