use polars_arrow::prelude::FromData;
#[cfg(feature = "random")]
use rand::prelude::SliceRandom;
#[cfg(feature = "random")]
use rand::{rngs::SmallRng, thread_rng, SeedableRng};
use crate::prelude::*;
/// Strategy used to assign ranks to values that compare equal (ties).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RankMethod {
    /// Tied values receive the mean of the ranks they span.
    Average,
    /// Tied values receive the lowest rank they span.
    Min,
    /// Tied values receive the highest rank they span.
    Max,
    /// Tied values share one rank and successive groups get consecutive ranks (no gaps).
    Dense,
    /// Every value receives a distinct rank; ties are broken by the order the
    /// underlying argsort yields them in.
    Ordinal,
    /// Every value receives a distinct rank; ties are broken randomly.
    #[cfg(feature = "random")]
    Random,
}

/// Options for a rank operation.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct RankOptions {
    /// Tie-resolution strategy.
    pub method: RankMethod,
    /// Whether to rank in descending order (presumably forwarded as the
    /// `reverse` flag of the rank kernel — confirm at the call site).
    pub descending: bool,
}

impl Default for RankOptions {
    /// Dense ranking in ascending order.
    fn default() -> Self {
        Self {
            method: RankMethod::Dense,
            descending: false,
        }
    }
}
/// Rank the values of `s`, resolving ties according to `method`.
///
/// * `method` – tie-resolution strategy, see [`RankMethod`].
/// * `reverse` – when `true` the sort feeding the ranks is descending, so the
///   largest value receives rank 1.
///
/// Returns a `Float32` series for [`RankMethod::Average`] and an `IdxSize`
/// series for every other method. Ranks start at 1.
///
/// Null entries yield null ranks. To rank the remaining values, nulls are
/// first filled with the dtype's max bound (min bound when `reverse`) so that
/// they sort last. NOTE(review): a valid value equal to that bound therefore
/// ties with the nulls, and a dtype whose `fill_null` rejects these strategies
/// makes the `unwrap` below panic.
pub(crate) fn rank(s: &Series, method: RankMethod, reverse: bool) -> Series {
    use RankMethod::*;

    /// Write `first + k` into `inv[order[k]]` for every position `k`, i.e. invert
    /// the argsort permutation `order` into per-element ranks.
    fn scatter_ranks(inv: &mut [IdxSize], order: &[IdxSize], first: IdxSize) {
        for (r, &i) in (first..).zip(order) {
            inv[i as usize] = r;
        }
    }

    match s.len() {
        0 => {
            return match method {
                Average => Float32Chunked::from_slice(s.name(), &[]).into_series(),
                _ => IdxCa::from_slice(s.name(), &[]).into_series(),
            };
        }
        1 => {
            // A single value ranks 1, unless it is null: then its rank is null as
            // well, consistent with the multi-element null path below.
            let is_null = s.null_count() == 1;
            return match (method, is_null) {
                (Average, false) => Series::new(s.name(), &[1.0f32]),
                (Average, true) => Series::full_null(s.name(), 1, &DataType::Float32),
                (_, false) => Series::new(s.name(), &[1 as IdxSize]),
                (_, true) => Series::full_null(s.name(), 1, &IDX_DTYPE),
            };
        }
        _ => {}
    }

    if s.null_count() > 0 {
        // Remember which entries are valid so it can be re-applied to the ranks.
        let not_null = s.is_not_null().rechunk();
        let validity = not_null.downcast_iter().next().unwrap().values().clone();
        // Fill nulls with the bound that sorts last for the requested direction.
        let null_strategy = if reverse {
            FillNullStrategy::MinBound
        } else {
            FillNullStrategy::MaxBound
        };
        let filled = s.fill_null(null_strategy).unwrap();
        let mut out = rank(&filled, method, reverse);
        // SAFETY: only the validity of the first chunk is replaced; this assumes
        // `out` is a single chunk of length `s.len()`, which the no-null path
        // below produces — TODO confirm for every method.
        unsafe {
            let arr = &mut out.chunks_mut()[0];
            *arr = arr.with_validity(Some(validity));
        }
        return out;
    }

    // From here on `s` has at least two elements and contains no nulls.
    let len = s.len();
    let sort_idx_ca = s.argsort(SortOptions {
        descending: reverse,
        ..Default::default()
    });
    let sort_idx = sort_idx_ca.downcast_iter().next().unwrap().values();

    // Ordinal/Random ranks are 1-based; for the other methods `inv` holds 0-based
    // sorted positions that are later used as gather indices.
    #[cfg(feature = "random")]
    let first = IdxSize::from(matches!(method, Ordinal | Random));
    #[cfg(not(feature = "random"))]
    let first = IdxSize::from(matches!(method, Ordinal));

    // `inv[i]` is the sorted position of element `i`, offset by `first`.
    let mut inv: Vec<IdxSize> = vec![0; len];
    scatter_ranks(&mut inv, sort_idx, first);

    match method {
        Ordinal => IdxCa::from_vec(s.name(), inv).into_series(),
        #[cfg(feature = "random")]
        Random => {
            // SAFETY: `sort_idx_ca` is an argsort of `s`, so every index is in bounds.
            let sorted = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() };
            // `obs[k]` is true when sorted[k + 1] != sorted[k].
            let not_consecutive_same = sorted
                .slice(1, len - 1)
                .not_equal(&sorted.slice(0, len - 1))
                .unwrap()
                .rechunk();
            let obs = not_consecutive_same.downcast_iter().next().unwrap();

            // Start offsets of every run of equal values in sorted order, followed
            // by `len` as the final end offset.
            let mut ties_indices = Vec::with_capacity(len + 1);
            ties_indices.push(0usize);
            ties_indices.extend(
                obs.values_iter()
                    .enumerate()
                    .filter(|&(_, changed)| changed)
                    .map(|(k, _)| k + 1),
            );
            ties_indices.push(len);

            // Shuffle each run of ties independently, then re-derive the ranks.
            let mut shuffled = sort_idx.to_vec();
            let mut thread_rng = thread_rng();
            let rng = &mut SmallRng::from_rng(&mut thread_rng).unwrap();
            for bounds in ties_indices.windows(2) {
                let (start, end) = (bounds[0], bounds[1]);
                if end - start > 1 {
                    shuffled[start..end].shuffle(rng);
                }
            }
            scatter_ranks(&mut inv, &shuffled, 1);
            IdxCa::from_vec(s.name(), inv).into_series()
        }
        _ => {
            let inv_ca = IdxCa::from_vec(s.name(), inv);
            // SAFETY: `sort_idx_ca` is an argsort of `s`, so every index is in bounds.
            let sorted = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() };
            let validity = sorted.chunks()[0].validity().cloned();
            // `obs[k]` is true when sorted[k + 1] != sorted[k], i.e. a new group
            // of equal values starts at sorted position k + 1.
            let not_consecutive_same = sorted
                .slice(1, len - 1)
                .not_equal(&sorted.slice(0, len - 1))
                .unwrap()
                .rechunk();
            let obs = not_consecutive_same.downcast_iter().next().unwrap();

            // Group id of every sorted position: 0-based for `Min`, 1-based otherwise.
            let mut group = IdxSize::from(!matches!(method, Min));
            let mut dense = Vec::with_capacity(len);
            dense.push(group);
            for changed in obs.values_iter() {
                if changed {
                    group += 1;
                }
                dense.push(group);
            }
            let arr = IdxArr::from_data_default(dense.into(), validity);
            let dense: IdxCa = (s.name(), arr).into();
            // Map group ids from sorted order back to the original order.
            // SAFETY: `inv_ca` holds 0-based sorted positions, all `< len`.
            let dense = unsafe { dense.take_unchecked((&inv_ca).into()) };
            if matches!(method, Dense) {
                return dense.into_series();
            }

            // `count[g]` is the 1-based sorted position where group `g` (numbered
            // from 1) ends, with `count[0] = 0`; so `count[g - 1] + 1` is where it starts.
            let n_boundaries = obs.values().len() - obs.values().unset_bits();
            let mut count: Vec<IdxSize> = Vec::with_capacity(n_boundaries + 2);
            count.push(0);
            count.extend(
                obs.values_iter()
                    .enumerate()
                    .filter(|&(_, changed)| changed)
                    .map(|(k, _)| (k + 1) as IdxSize),
            );
            count.push(len as IdxSize);
            let count = IdxCa::from_vec(s.name(), count);

            // SAFETY (all arms): group ids lie in `0..count.len()` — 1-based ids for
            // `Max`/`Average` index `count[g]` and `count[g - 1]`, 0-based ids for `Min`.
            match method {
                Max => unsafe { count.take_unchecked((&dense).into()).into_series() },
                Min => unsafe { (count.take_unchecked((&dense).into()) + 1).into_series() },
                Average => {
                    // Mean of the group's first and last sorted position.
                    let a = unsafe { count.take_unchecked((&dense).into()) }
                        .cast(&DataType::Float32)
                        .unwrap();
                    let b = unsafe { count.take_unchecked((&(dense - 1)).into()) }
                        .cast(&DataType::Float32)
                        .unwrap()
                        + 1.0;
                    (&a + &b) * 0.5
                }
                // These methods returned before reaching this point.
                #[cfg(feature = "random")]
                Dense | Ordinal | Random => unreachable!(),
                #[cfg(not(feature = "random"))]
                Dense | Ordinal => unreachable!(),
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Ranks `s` and returns the underlying `IdxSize` values, ignoring validity.
    fn idx_ranks(s: &Series, method: RankMethod, reverse: bool) -> PolarsResult<Vec<IdxSize>> {
        let ranked = rank(s, method, reverse);
        let values: Vec<IdxSize> = ranked.idx()?.into_no_null_iter().collect();
        Ok(values)
    }

    /// Ranks `s` and returns the `IdxSize` values with nulls kept as `None`.
    fn idx_ranks_nullable(
        s: &Series,
        method: RankMethod,
        reverse: bool,
    ) -> PolarsResult<Vec<Option<IdxSize>>> {
        let ranked = rank(s, method, reverse);
        let values: Vec<Option<IdxSize>> = ranked.idx()?.into_iter().collect();
        Ok(values)
    }

    /// Ranks `s` and returns the underlying `f32` values, ignoring validity.
    fn f32_ranks(s: &Series, method: RankMethod, reverse: bool) -> PolarsResult<Vec<f32>> {
        let ranked = rank(s, method, reverse);
        let values: Vec<f32> = ranked.f32()?.into_no_null_iter().collect();
        Ok(values)
    }

    /// Ranks `s` and returns the `f32` values with nulls kept as `None`.
    fn f32_ranks_nullable(
        s: &Series,
        method: RankMethod,
        reverse: bool,
    ) -> PolarsResult<Vec<Option<f32>>> {
        let ranked = rank(s, method, reverse);
        let values: Vec<Option<f32>> = ranked.f32()?.into_iter().collect();
        Ok(values)
    }

    #[test]
    fn test_rank() -> PolarsResult<()> {
        let values = Series::new("a", &[1, 2, 3, 2, 2, 3, 0]);

        assert_eq!(
            idx_ranks(&values, RankMethod::Ordinal, false)?,
            &[2 as IdxSize, 3, 6, 4, 5, 7, 1]
        );

        #[cfg(feature = "random")]
        {
            let random = idx_ranks(&values, RankMethod::Random, false)?;
            // Unique values always get the same rank.
            assert_eq!(random[0], 2);
            assert_eq!(random[6], 1);
            // The three 2s share ranks 3..=5 and the two 3s share 6..=7, in any order.
            assert_eq!(random[1] + random[3] + random[4], 12);
            assert_eq!(random[2] + random[5], 13);
            assert_ne!(random[1], random[3]);
            assert_ne!(random[1], random[4]);
            assert_ne!(random[3], random[4]);
        }

        assert_eq!(
            idx_ranks(&values, RankMethod::Dense, false)?,
            &[2, 3, 4, 3, 3, 4, 1]
        );
        assert_eq!(
            idx_ranks(&values, RankMethod::Max, false)?,
            &[2, 5, 7, 5, 5, 7, 1]
        );
        assert_eq!(
            idx_ranks(&values, RankMethod::Min, false)?,
            &[2, 3, 6, 3, 3, 6, 1]
        );
        assert_eq!(
            f32_ranks(&values, RankMethod::Average, false)?,
            &[2.0f32, 4.0, 6.5, 4.0, 4.0, 6.5, 1.0]
        );

        let with_nulls = Series::new(
            "a",
            &[Some(1), Some(2), Some(3), Some(2), None, None, Some(0)],
        );
        assert_eq!(
            f32_ranks_nullable(&with_nulls, RankMethod::Average, false)?,
            &[
                Some(2.0f32),
                Some(3.5),
                Some(5.0),
                Some(3.5),
                None,
                None,
                Some(1.0)
            ]
        );

        let sparse = Series::new(
            "a",
            &[
                Some(5),
                Some(6),
                Some(4),
                None,
                Some(78),
                Some(4),
                Some(2),
                Some(8),
            ],
        );
        assert_eq!(
            idx_ranks_nullable(&sparse, RankMethod::Max, false)?,
            &[
                Some(4),
                Some(5),
                Some(3),
                None,
                Some(7),
                Some(3),
                Some(1),
                Some(6)
            ]
        );
        Ok(())
    }

    #[test]
    fn test_rank_all_null() -> PolarsResult<()> {
        let nulls = UInt32Chunked::new("", &[None, None, None]).into_series();
        assert_eq!(
            f32_ranks(&nulls, RankMethod::Average, false)?,
            &[2.0f32, 2.0, 2.0]
        );
        assert_eq!(idx_ranks(&nulls, RankMethod::Dense, false)?, &[1, 1, 1]);
        Ok(())
    }

    #[test]
    fn test_rank_empty() {
        let empty = UInt32Chunked::from_slice("", &[]).into_series();
        assert_eq!(
            rank(&empty, RankMethod::Average, false).dtype(),
            &DataType::Float32
        );
        assert_eq!(rank(&empty, RankMethod::Max, false).dtype(), &IDX_DTYPE);
    }

    #[test]
    fn test_rank_reverse() -> PolarsResult<()> {
        let padded = Series::new("", &[None, Some(1), Some(1), Some(5), None]);
        assert_eq!(
            idx_ranks_nullable(&padded, RankMethod::Dense, true)?,
            &[None, Some(2 as IdxSize), Some(2), Some(1), None]
        );
        Ok(())
    }
}