use std::ops::Add;
#[cfg(feature = "diagonal_concat")]
use ahash::AHashSet;
use arrow::compute;
use arrow::types::simd::Simd;
use num_traits::{Float, NumCast, ToPrimitive};
#[cfg(feature = "concat_str")]
use polars_arrow::prelude::ValueSize;
use crate::chunked_array::ops::sort::prepare_arg_sort;
use crate::prelude::*;
use crate::utils::coalesce_nulls;
#[cfg(feature = "diagonal_concat")]
use crate::utils::concat_df;
pub fn cov_f<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<T::Native>
where
T: PolarsFloatType,
T::Native: Float,
<T::Native as Simd>::Simd: Add<Output = <T::Native as Simd>::Simd>
+ compute::aggregate::Sum<T::Native>
+ compute::aggregate::SimdOrd<T::Native>,
{
if a.len() != b.len() {
None
} else {
let tmp = (a - a.mean()?) * (b - b.mean()?);
let n = tmp.len() - tmp.null_count();
Some(tmp.sum()? / NumCast::from(n - 1).unwrap())
}
}
pub fn cov_i<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<f64>
where
T: PolarsIntegerType,
T::Native: ToPrimitive,
<T::Native as Simd>::Simd: Add<Output = <T::Native as Simd>::Simd>
+ compute::aggregate::Sum<T::Native>
+ compute::aggregate::SimdOrd<T::Native>,
{
if a.len() != b.len() {
None
} else {
let a_mean = a.mean()?;
let b_mean = b.mean()?;
let a = a.apply_cast_numeric::<_, Float64Type>(|a| a.to_f64().unwrap() - a_mean);
let b = b.apply_cast_numeric(|b| b.to_f64().unwrap() - b_mean);
let tmp = a * b;
let n = tmp.len() - tmp.null_count();
Some(tmp.sum()? / (n - 1) as f64)
}
}
/// Pearson correlation coefficient of two integer arrays, computed in `f64`.
///
/// Rows where either input is null are ignored. Returns `None` when the arrays
/// differ in length or when the covariance or either standard deviation is
/// unavailable.
///
/// `ddof` is accepted for backward compatibility but does not influence the
/// result. In a Pearson coefficient the degrees-of-freedom correction appears
/// in both the covariance and the product of standard deviations and cancels
/// out. [`cov_i`] always normalises by `n - 1`, so the standard deviations are
/// taken with the matching `ddof = 1`; mixing a different `ddof` in here would
/// scale the result by `(n - ddof) / (n - 1)` and could yield |r| > 1.
pub fn pearson_corr_i<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>, ddof: u8) -> Option<f64>
where
    T: PolarsIntegerType,
    T::Native: ToPrimitive,
    <T::Native as Simd>::Simd: Add<Output = <T::Native as Simd>::Simd>
        + compute::aggregate::Sum<T::Native>
        + compute::aggregate::SimdOrd<T::Native>,
    ChunkedArray<T>: ChunkVar<f64>,
{
    // Intentionally unused, see the doc comment.
    let _ = ddof;
    // Reject mismatched inputs before trying to combine their null masks.
    if a.len() != b.len() {
        return None;
    }
    let (a, b) = coalesce_nulls(a, b);
    let a = a.as_ref();
    let b = b.as_ref();
    Some(cov_i(a, b)? / (a.std(1)? * b.std(1)?))
}
/// Pearson correlation coefficient of two float arrays.
///
/// Rows where either input is null are ignored. Returns `None` when the arrays
/// differ in length or when the covariance or either standard deviation is
/// unavailable.
///
/// `ddof` is accepted for backward compatibility but does not influence the
/// result. In a Pearson coefficient the degrees-of-freedom correction appears
/// in both the covariance and the product of standard deviations and cancels
/// out. [`cov_f`] always normalises by `n - 1`, so the standard deviations are
/// taken with the matching `ddof = 1`; mixing a different `ddof` in here would
/// scale the result by `(n - ddof) / (n - 1)` and could yield |r| > 1.
pub fn pearson_corr_f<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>, ddof: u8) -> Option<T::Native>
where
    T: PolarsFloatType,
    T::Native: Float,
    <T::Native as Simd>::Simd: Add<Output = <T::Native as Simd>::Simd>
        + compute::aggregate::Sum<T::Native>
        + compute::aggregate::SimdOrd<T::Native>,
    ChunkedArray<T>: ChunkVar<T::Native>,
{
    // Intentionally unused, see the doc comment.
    let _ = ddof;
    // Reject mismatched inputs before trying to combine their null masks.
    if a.len() != b.len() {
        return None;
    }
    let (a, b) = coalesce_nulls(a, b);
    let a = a.as_ref();
    let b = b.as_ref();
    Some(cov_f(a, b)? / (a.std(1)? * b.std(1)?))
}
/// Compute the row indices that sort by several series at once.
///
/// `by` holds the sort keys and `descending` one ordering flag per key. The
/// inputs are prepared with `prepare_arg_sort` and the actual work is delegated
/// to `arg_sort_multiple` on the first prepared series.
///
/// # Errors
///
/// Returns an error when `by` is empty, when `by` and `descending` differ in
/// length, or when preparing or sorting the keys fails.
pub fn arg_sort_by(by: &[Series], descending: &[bool]) -> PolarsResult<IdxCa> {
    // There is no leading key to sort on without at least one series.
    polars_ensure!(
        !by.is_empty(),
        NoData: "expected at least one series in `arg_sort_by`"
    );
    polars_ensure!(
        by.len() == descending.len(),
        ComputeError: "the number of ordering booleans: {} does not match the number of series: {}",
        descending.len(), by.len()
    );
    // Propagate preparation failures instead of panicking.
    let (first, by, descending) = prepare_arg_sort(by.to_vec(), descending.to_vec())?;
    first.arg_sort_multiple(&by, &descending)
}
#[cfg(feature = "concat_str")]
/// Per-input value source used by `concat_str`.
enum IterBroadCast<'a> {
    /// Walks over the values of a column.
    Column(Box<dyn PolarsIterator<Item = Option<&'a str>> + 'a>),
    /// Hands out the same (possibly null) value on every call.
    Value(Option<&'a str>),
}

#[cfg(feature = "concat_str")]
impl<'a> IterBroadCast<'a> {
    /// Fetch the next value: the wrapped iterator's next item for `Column`,
    /// or the stored value (never exhausted) for `Value`.
    fn next(&mut self) -> Option<Option<&'a str>> {
        match self {
            Self::Column(values) => values.next(),
            Self::Value(constant) => Some(*constant),
        }
    }
}
#[cfg(feature = "concat_str")]
/// Concatenate several series row by row into one string column, separated by `delimiter`.
///
/// Every input is cast to `Utf8`. Inputs must either all share the same length
/// or have length 1; unit-length inputs are repeated for every row. A row is
/// null in the output as soon as any input is null in that row. The output is
/// named after the first series. If any input is empty, an empty column is
/// returned.
///
/// # Errors
///
/// Returns an error when `s` is empty, when the lengths are incompatible, or
/// when a series cannot be cast to `Utf8`.
pub fn concat_str(s: &[Series], delimiter: &str) -> PolarsResult<Utf8Chunked> {
    polars_ensure!(!s.is_empty(), NoData: "expected multiple series in `concat_str`");
    if s.iter().any(|s| s.is_empty()) {
        return Ok(Utf8Chunked::full_null(s[0].name(), 0));
    }
    // `s` is non-empty, so a maximum always exists.
    let len = s.iter().map(|s| s.len()).max().unwrap();

    // Validate the shapes up front, before paying for any cast.
    polars_ensure!(
        s.iter().all(|s| s.len() == 1 || s.len() == len),
        ComputeError: "all series in `concat_str` should have equal or unit length"
    );

    let cas = s
        .iter()
        .map(|s| {
            let s = s.cast(&DataType::Utf8)?;
            Ok(s.utf8()?.clone())
        })
        .collect::<PolarsResult<Vec<_>>>()?;

    // Unit-length inputs are broadcast lazily through `IterBroadCast::Value`
    // instead of being materialised to `len` rows.
    let mut iters = cas
        .iter()
        .map(|ca| match ca.len() {
            1 => IterBroadCast::Value(ca.get(0)),
            _ => IterBroadCast::Column(ca.into_iter()),
        })
        .collect::<Vec<_>>();

    // Upper bound of the output size: every value (broadcast ones counted once
    // per row) plus the delimiters between them.
    let values_cap: usize = cas
        .iter()
        .map(|ca| {
            let size = ca.get_values_size();
            if ca.len() == 1 {
                size * len
            } else {
                size
            }
        })
        .sum();
    let bytes_cap = values_cap + delimiter.len() * (cas.len() - 1) * len;

    let mut builder = Utf8ChunkedBuilder::new(s[0].name(), len, bytes_cap);
    // Scratch buffer reused for every row.
    let mut buf = String::with_capacity(128);
    for _ in 0..len {
        let mut has_null = false;
        // Every iterator must be advanced for every row, so a null cannot
        // short-circuit this loop.
        for (i, it) in iters.iter_mut().enumerate() {
            if i > 0 {
                buf.push_str(delimiter);
            }
            match it.next() {
                Some(Some(val)) => buf.push_str(val),
                Some(None) => has_null = true,
                // Lengths were validated above, so no iterator runs dry early.
                None => unreachable!(),
            }
        }
        if has_null {
            builder.append_null();
        } else {
            builder.append_value(&buf)
        }
        buf.clear();
    }
    Ok(builder.finish())
}
#[cfg(feature = "horizontal_concat")]
/// Concatenate dataframes horizontally (column-wise).
///
/// Frames shorter than the tallest one are padded with nulls at the end so
/// that all frames share the same height. The columns of every following frame
/// are then stacked onto a clone of the first one.
///
/// # Errors
///
/// Returns an error when `dfs` is empty, when padding a column with nulls
/// fails, or when stacking the columns fails.
pub fn hor_concat_df(dfs: &[DataFrame]) -> PolarsResult<DataFrame> {
    let max_len = dfs
        .iter()
        .map(|df| df.height())
        .max()
        .ok_or_else(|| polars_err!(ComputeError: "cannot concat empty dataframes"))?;

    // Only build padded copies when at least one frame is too short.
    let owned_df;
    let dfs = if dfs.iter().all(|df| df.height() == max_len) {
        dfs
    } else {
        owned_df = dfs
            .iter()
            .cloned()
            .map(|mut df| -> PolarsResult<DataFrame> {
                let diff = max_len - df.height();
                if diff > 0 {
                    for s in df.columns.iter_mut() {
                        // Propagate a failed extension instead of panicking.
                        *s = s.extend_constant(AnyValue::Null, diff)?;
                    }
                }
                Ok(df)
            })
            .collect::<PolarsResult<Vec<_>>>()?;
        owned_df.as_slice()
    };

    // `dfs` is non-empty here: `max()` above succeeded.
    let mut first_df = dfs[0].clone();
    for df in &dfs[1..] {
        first_df.hstack_mut(df.get_columns())?;
    }
    Ok(first_df)
}
#[cfg(feature = "diagonal_concat")]
/// Concatenate dataframes diagonally.
///
/// The output schema is the union of all column names, in order of first
/// appearance, each with the dtype of its first occurrence. Every frame is
/// widened to that schema by filling the columns it lacks with nulls, and the
/// widened frames are then concatenated with `concat_df`.
///
/// # Errors
///
/// Returns an error when `dfs` is empty or when `concat_df` fails.
pub fn diag_concat_df(dfs: &[DataFrame]) -> PolarsResult<DataFrame> {
    // Same contract as `hor_concat_df`: report empty input as an error rather
    // than handing an empty slice to `concat_df`.
    polars_ensure!(
        !dfs.is_empty(),
        ComputeError: "cannot concat empty dataframes"
    );

    let upper_bound_width = dfs.iter().map(|df| df.width()).sum();
    let mut column_names = AHashSet::with_capacity(upper_bound_width);
    let mut schema = Vec::with_capacity(upper_bound_width);

    // Collect the union schema; the first occurrence of a name decides its dtype.
    for df in dfs {
        for s in df.get_columns() {
            let name = s.name();
            if column_names.insert(name) {
                schema.push((name, s.dtype()))
            }
        }
    }

    let dfs = dfs
        .iter()
        .map(|df| {
            let height = df.height();
            let columns = schema
                .iter()
                .map(|(name, dtype)| match df.column(name) {
                    Ok(s) => s.clone(),
                    // Column absent from this frame: fill with nulls.
                    Err(_) => Series::full_null(name, height, dtype),
                })
                .collect::<Vec<_>>();
            DataFrame::new_no_checks(columns)
        })
        .collect::<Vec<_>>();

    concat_df(&dfs)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Covariance of the same data must agree for the float and integer paths.
    #[test]
    fn test_cov() {
        let x = Series::new("a", &[1.0f32, 2.0, 5.0]);
        let y = Series::new("b", &[1.0f32, 2.0, -3.0]);
        assert_eq!(cov_f(x.f32().unwrap(), y.f32().unwrap()), Some(-5.0));

        let x_int = x.cast(&DataType::Int32).unwrap();
        let y_int = y.cast(&DataType::Int32).unwrap();
        assert_eq!(cov_i(x_int.i32().unwrap(), y_int.i32().unwrap()), Some(-5.0));
    }

    /// Two identical columns have covariance 0.5 and correlation 1.
    #[test]
    fn test_pearson_corr() {
        let x = Series::new("a", &[1.0f32, 2.0]);
        let y = Series::new("b", &[1.0f32, 2.0]);
        let xs = x.f32().unwrap();
        let ys = y.f32().unwrap();

        let cov = cov_f(xs, ys).unwrap();
        assert!((cov - 0.5).abs() < 0.001);

        let corr = pearson_corr_f(xs, ys, 1).unwrap();
        assert!((corr - 1.0).abs() < 0.001);
    }

    /// Equal-length columns are joined row-wise; a unit-length column is broadcast.
    #[test]
    #[cfg(feature = "concat_str")]
    fn test_concat_str() {
        let left = Series::new("a", &["foo", "bar"]);
        let right = Series::new("b", &["spam", "ham"]);

        let joined = concat_str(&[left.clone(), right.clone()], "_").unwrap();
        assert_eq!(Vec::from(&joined), &[Some("foo_spam"), Some("bar_ham")]);

        let literal = Series::new("b", &["literal"]);
        let joined = concat_str(&[left, right, literal], "_").unwrap();
        assert_eq!(
            Vec::from(&joined),
            &[Some("foo_spam_literal"), Some("bar_ham_literal")]
        );
    }

    /// Missing columns are filled with nulls and frames are stacked vertically.
    #[test]
    #[cfg(feature = "diagonal_concat")]
    fn test_diag_concat() -> PolarsResult<()> {
        let first = df![
            "a" => [1, 2],
            "b" => ["a", "b"]
        ]?;
        let second = df![
            "b" => ["a", "b"],
            "c" => [1, 2]
        ]?;
        let third = df![
            "a" => [5, 7],
            "c" => [1, 2],
            "d" => [1, 2]
        ]?;

        let expected = df![
            "a" => [Some(1), Some(2), None, None, Some(5), Some(7)],
            "b" => [Some("a"), Some("b"), Some("a"), Some("b"), None, None],
            "c" => [None, None, Some(1), Some(2), Some(1), Some(2)],
            "d" => [None, None, None, None, Some(1), Some(2)]
        ]?;

        let out = diag_concat_df(&[first, second, third])?;
        assert!(out.frame_equal_missing(&expected));
        Ok(())
    }
}