1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use crate::prelude::compare_inner::PartialOrdInner;
use crate::prelude::{
    CategoricalChunked, IntoTakeRandom, NumTakeRandomChunked, NumTakeRandomCont,
    NumTakeRandomSingleChunk, PlHashMap, RevMapping, TakeRandBranch3, TakeRandom,
};
use arrow::array::Utf8Array;
use std::cmp::Ordering;

/// Random-access reader over the `u32` category codes of a categorical array.
///
/// This is the type produced by `ca.logical().take_rand()` in the comparator
/// constructors of this file. `TakeRandBranch3` holds exactly one of the three
/// `NumTakeRandom*` readers listed below; presumably the arm is chosen from the
/// chunk layout (contiguous / single chunk / multiple chunks) — TODO confirm in
/// `take_rand`.
type TakeCats<'a> = TakeRandBranch3<
    NumTakeRandomCont<'a, u32>,
    NumTakeRandomSingleChunk<'a, u32>,
    NumTakeRandomChunked<'a, u32>,
>;

/// Random-access comparator input for a categorical array whose reverse
/// mapping is the `RevMapping::Local` variant.
pub(crate) struct CategoricalTakeRandomLocal<'a> {
    /// String values, indexed directly by category code.
    rev_map: &'a Utf8Array<i64>,
    /// Random-access reader over the physical `u32` category codes.
    cats: TakeCats<'a>,
}

impl<'a> CategoricalTakeRandomLocal<'a> {
    /// Builds the comparator input from `ca`.
    ///
    /// # Panics
    ///
    /// Panics if `ca` is not backed by exactly one chunk, or if its reverse
    /// mapping is not `RevMapping::Local`.
    pub(crate) fn new(ca: &'a CategoricalChunked) -> Self {
        // Callers are responsible for rechunking before we get here.
        assert_eq!(ca.logical.chunks.len(), 1, "implementation error");
        match &**ca.get_rev_map() {
            RevMapping::Local(rev_map) => Self {
                rev_map,
                cats: ca.logical().take_rand(),
            },
            _ => unreachable!(),
        }
    }
}

impl PartialOrdInner for CategoricalTakeRandomLocal<'_> {
    /// Compares the string values behind the category codes at `idx_a` and
    /// `idx_b`, so the ordering is lexical rather than by code.
    ///
    /// Null entries (`None`) order before every non-null string, following the
    /// `Ord` implementation of `Option`.
    ///
    /// # Safety
    ///
    /// `idx_a` and `idx_b` must be in bounds for the categorical array, and every
    /// non-null category code must be a valid index into `rev_map`.
    unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
        // Resolve a row index to its (optional) string value.
        let value_at = |idx: usize| {
            self.cats
                .get_unchecked(idx)
                .map(|cat| self.rev_map.value_unchecked(cat as usize))
        };
        // `Option<&str>` is totally ordered, so `cmp` needs no fallible step.
        value_at(idx_a).cmp(&value_at(idx_b))
    }
}

/// Random-access comparator input for a categorical array whose reverse
/// mapping is the `RevMapping::Global` variant.
pub(crate) struct CategoricalTakeRandomGlobal<'a> {
    /// Maps a category code to a position in `rev_map_part_2`.
    rev_map_part_1: &'a PlHashMap<u32, u32>,
    /// String values, indexed by the positions stored in `rev_map_part_1`.
    rev_map_part_2: &'a Utf8Array<i64>,
    /// Random-access reader over the physical `u32` category codes.
    cats: TakeCats<'a>,
}

impl<'a> CategoricalTakeRandomGlobal<'a> {
    /// Builds the comparator input from `ca`.
    ///
    /// # Panics
    ///
    /// Panics if `ca` is not backed by exactly one chunk, or if its reverse
    /// mapping is not `RevMapping::Global`.
    pub(crate) fn new(ca: &'a CategoricalChunked) -> Self {
        // Callers are responsible for rechunking before we get here.
        assert_eq!(ca.logical.chunks.len(), 1, "implementation error");
        match &**ca.get_rev_map() {
            RevMapping::Global(code_to_position, values, _) => Self {
                rev_map_part_1: code_to_position,
                rev_map_part_2: values,
                cats: ca.logical().take_rand(),
            },
            _ => unreachable!(),
        }
    }
}

impl PartialOrdInner for CategoricalTakeRandomGlobal<'_> {
    /// Compares the string values behind the category codes at `idx_a` and
    /// `idx_b`, so the ordering is lexical rather than by code.
    ///
    /// Each code is first translated through `rev_map_part_1` to a position in
    /// `rev_map_part_2`, which holds the strings. Null entries (`None`) order
    /// before every non-null string, following the `Ord` implementation of
    /// `Option`.
    ///
    /// # Safety
    ///
    /// `idx_a` and `idx_b` must be in bounds for the categorical array, and every
    /// position stored in `rev_map_part_1` must be a valid index into
    /// `rev_map_part_2`.
    ///
    /// # Panics
    ///
    /// Panics if a non-null category code has no entry in `rev_map_part_1`.
    unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
        // Resolve a row index to its (optional) string value.
        let value_at = |idx: usize| {
            self.cats.get_unchecked(idx).map(|cat| {
                let position = *self
                    .rev_map_part_1
                    .get(&cat)
                    .expect("category code missing from the global reverse mapping");
                self.rev_map_part_2.value_unchecked(position as usize)
            })
        };
        // `Option<&str>` is totally ordered, so `cmp` needs no fallible step.
        value_at(idx_a).cmp(&value_at(idx_b))
    }
}