Skip to main content

nexcore_dataframe/
counter.rs

//! Counter: optimized hash-based group-count for the FAERS pipeline.
//!
//! Replaces the pattern `df.lazy().group_by([cols]).agg([col.count()])` with a
//! direct `HashMap<Vec<String>, u64>` accumulation. For 20-50M rows where the
//! only aggregation is a count, this avoids building an intermediate DataFrame.
6
// Counter is purpose-built for amortized O(1) key lookup on 20-50M FAERS rows.
// A BTreeMap would raise insert/lookup from amortized O(1) to O(log n), which is unacceptable at this scale.
// Iteration order is explicitly unspecified: callers must sort the resulting DataFrame if they need a stable order.
10#[allow(
11    clippy::disallowed_types,
12    reason = "HashMap required for O(1) amortized insert/lookup at FAERS scale (20-50M rows); BTreeMap O(log n) cost is prohibitive here"
13)]
14use std::collections::HashMap;
15
16use crate::column::Column;
17use crate::dataframe::DataFrame;
18use crate::error::DataFrameError;
19
/// Group counter that maps each composite key to the number of times it was seen.
///
/// Designed for FAERS drug×event counting: `(key tuple → count)` pairs are
/// accumulated directly, so no intermediate DataFrame is ever materialized.
#[derive(Clone, Debug)]
pub struct Counter {
    /// Ordered column names; position `i` names the `i`-th element of every key.
    key_names: Vec<String>,
    /// Occurrence count for each distinct key combination seen so far.
    #[allow(
        clippy::disallowed_types,
        reason = "HashMap required for O(1) amortized insert/lookup at FAERS scale (20-50M rows); BTreeMap O(log n) cost is prohibitive here"
    )]
    counts: HashMap<Vec<String>, u64>,
}
33
34impl Counter {
35    /// Create a new counter for the given key columns.
36    #[must_use]
37    pub fn new(key_names: Vec<String>) -> Self {
38        Self {
39            key_names,
40            #[allow(
41                clippy::disallowed_types,
42                reason = "HashMap::new() for the counts field; see field-level allow"
43            )]
44            counts: HashMap::new(),
45        }
46    }
47
48    /// Increment the count for a key combination.
49    pub fn increment(&mut self, key: Vec<String>) {
50        // count starts at 0 and is incremented once per call; overflow at u64::MAX (>1.8×10^19)
51        // is not a realistic concern for any DataFrame workload
52        #[allow(
53            clippy::arithmetic_side_effects,
54            reason = "u64 counter incremented by 1; realistic row counts are far below u64::MAX"
55        )]
56        {
57            *self.counts.entry(key).or_insert(0) += 1;
58        }
59    }
60
61    /// Increment by a specific amount.
62    pub fn increment_by(&mut self, key: Vec<String>, n: u64) {
63        // Same reasoning: accumulating u64 counts from finite data cannot realistically overflow
64        #[allow(
65            clippy::arithmetic_side_effects,
66            reason = "u64 accumulator; sum of row counts bounded by total dataset size which is far below u64::MAX"
67        )]
68        {
69            *self.counts.entry(key).or_insert(0) += n;
70        }
71    }
72
73    /// Number of unique key combinations.
74    #[must_use]
75    pub fn len(&self) -> usize {
76        self.counts.len()
77    }
78
79    /// Whether no counts have been accumulated.
80    #[must_use]
81    pub fn is_empty(&self) -> bool {
82        self.counts.is_empty()
83    }
84
85    /// Get the count for a specific key combination.
86    #[must_use]
87    pub fn get(&self, key: &[String]) -> u64 {
88        self.counts.get(key).copied().unwrap_or(0)
89    }
90
91    /// Iterate over all (key, count) pairs.
92    pub fn iter(&self) -> impl Iterator<Item = (&Vec<String>, &u64)> {
93        self.counts.iter()
94    }
95
96    /// Total count across all keys.
97    #[must_use]
98    pub fn total(&self) -> u64 {
99        self.counts.values().sum()
100    }
101
102    /// Convert the counter into a DataFrame with key columns + a "count" column.
103    ///
104    /// Each unique key combination becomes a row. The last column is always "count".
105    pub fn into_dataframe(self) -> Result<DataFrame, DataFrameError> {
106        let n_keys = self.key_names.len();
107        let n_rows = self.counts.len();
108
109        // Pre-allocate column vecs
110        let mut key_vecs: Vec<Vec<Option<String>>> =
111            (0..n_keys).map(|_| Vec::with_capacity(n_rows)).collect();
112        let mut count_vec: Vec<Option<u64>> = Vec::with_capacity(n_rows);
113
114        // Iteration order of HashMap is unspecified but consistent within a single call;
115        // the resulting DataFrame rows are not required to be in any particular order.
116        #[allow(
117            clippy::iter_over_hash_type,
118            reason = "HashMap iteration builds parallel column vecs; output row order is explicitly unspecified — callers sort if order matters"
119        )]
120        for (key, count) in &self.counts {
121            for (i, val) in key.iter().enumerate() {
122                if i < n_keys {
123                    // i < n_keys <= key_vecs.len() by construction above
124                    #[allow(
125                        clippy::indexing_slicing,
126                        reason = "i is bounded by n_keys = key_vecs.len(); the guard i < n_keys ensures the index is valid"
127                    )]
128                    key_vecs[i].push(Some(val.clone()));
129                }
130            }
131            count_vec.push(Some(*count));
132        }
133
134        let mut columns: Vec<Column> = key_vecs
135            .into_iter()
136            .enumerate()
137            .map(|(i, data)| {
138                // i < n_keys = key_names.len() because key_vecs was built with n_keys elements
139                #[allow(clippy::indexing_slicing, reason = "i iterates over 0..n_keys which equals key_names.len(); index is always valid")]
140                Column::new_string(self.key_names[i].clone(), data)
141            })
142            .collect();
143        columns.push(Column::new_u64("count", count_vec));
144
145        DataFrame::new(columns)
146    }
147
148    /// Build a counter from a DataFrame by counting rows grouped by specified columns.
149    pub fn from_dataframe(df: &DataFrame, group_cols: &[&str]) -> Result<Self, DataFrameError> {
150        // Validate columns exist — propagate error if not found
151        for name in group_cols {
152            df.column(name)?;
153        }
154
155        let key_names: Vec<String> = group_cols.iter().map(|s| (*s).to_string()).collect();
156        let mut counter = Self::new(key_names);
157
158        for row_idx in 0..df.height() {
159            let key: Vec<String> = group_cols
160                .iter()
161                .map(|name| {
162                    df.column(name)
163                        .ok()
164                        .and_then(|col| col.get(row_idx))
165                        .map_or_else(|| "null".to_string(), |s| s.to_string())
166                })
167                .collect();
168            counter.increment(key);
169        }
170
171        Ok(counter)
172    }
173
174    /// Filter the counter, keeping only entries where count >= min_count.
175    #[must_use]
176    pub fn filter_min_count(&self, min_count: u64) -> Self {
177        Self {
178            key_names: self.key_names.clone(),
179            counts: self
180                .counts
181                .iter()
182                .filter(|&(_, &count)| count >= min_count)
183                .map(|(k, v)| (k.clone(), *v))
184                .collect(),
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned composite key from string slices.
    fn key(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|p| (*p).to_string()).collect()
    }

    #[test]
    fn counter_basic() {
        let mut c = Counter::new(vec!["drug".into(), "event".into()]);
        c.increment(key(&["aspirin", "headache"]));
        c.increment(key(&["aspirin", "headache"]));
        c.increment(key(&["aspirin", "nausea"]));

        assert_eq!(c.len(), 2);
        assert_eq!(c.get(&key(&["aspirin", "headache"])), 2);
        assert_eq!(c.get(&key(&["aspirin", "nausea"])), 1);
        assert_eq!(c.total(), 3);
    }

    #[test]
    fn counter_increment_by_accumulates() {
        let mut c = Counter::new(vec!["drug".into()]);
        c.increment_by(key(&["asp"]), 5);
        c.increment(key(&["asp"]));
        c.increment_by(key(&["met"]), 0);

        assert_eq!(c.len(), 2);
        assert_eq!(c.get(&key(&["asp"])), 6);
        assert_eq!(c.get(&key(&["met"])), 0);
        assert_eq!(c.total(), 6);
    }

    #[test]
    fn counter_get_missing_key_is_zero() {
        let mut c = Counter::new(vec!["drug".into()]);
        c.increment(key(&["asp"]));
        assert_eq!(c.get(&key(&["never-seen"])), 0);
    }

    #[test]
    fn counter_into_dataframe() {
        let mut c = Counter::new(vec!["drug".into()]);
        c.increment(key(&["asp"]));
        c.increment(key(&["asp"]));
        c.increment(key(&["met"]));

        let df = c.into_dataframe().unwrap_or_else(|_| unreachable!());
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 2); // drug + count
        assert!(df.column("drug").is_ok());
        assert!(df.column("count").is_ok());
    }

    #[test]
    fn counter_dataframe_round_trip_keeps_key_columns_aligned() {
        let mut c = Counter::new(vec!["drug".into(), "event".into()]);
        c.increment(key(&["asp", "ha"]));
        c.increment_by(key(&["met", "na"]), 4);

        let df = c.into_dataframe().unwrap_or_else(|_| unreachable!());
        assert_eq!(df.width(), 3); // drug + event + count

        // Re-grouping the produced rows must reconstruct the original key tuples,
        // proving the key columns were emitted row-aligned.
        let back =
            Counter::from_dataframe(&df, &["drug", "event"]).unwrap_or_else(|_| unreachable!());
        assert_eq!(back.len(), 2);
        assert_eq!(back.get(&key(&["asp", "ha"])), 1);
        assert_eq!(back.get(&key(&["met", "na"])), 1);
        assert_eq!(back.total(), 2);
    }

    #[test]
    fn counter_from_dataframe() {
        let df = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "met", "asp", "asp", "met"]),
            Column::from_strs("event", &["ha", "na", "ha", "di", "na"]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let c = Counter::from_dataframe(&df, &["drug", "event"]).unwrap_or_else(|_| unreachable!());
        assert_eq!(c.len(), 3); // asp+ha, met+na, asp+di
        assert_eq!(c.get(&key(&["asp", "ha"])), 2);
        assert_eq!(c.total(), 5);
    }

    #[test]
    fn counter_filter_min_count() {
        let mut c = Counter::new(vec!["x".into()]);
        c.increment(key(&["a"]));
        c.increment(key(&["b"]));
        c.increment(key(&["b"]));
        c.increment(key(&["b"]));

        let filtered = c.filter_min_count(2);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered.get(&key(&["b"])), 3);
    }

    #[test]
    fn counter_from_dataframe_missing_column() {
        let df = DataFrame::new(vec![Column::from_i64s("x", vec![1])])
            .unwrap_or_else(|_| unreachable!());
        assert!(Counter::from_dataframe(&df, &["missing"]).is_err());
    }

    #[test]
    fn counter_empty() {
        let c = Counter::new(vec!["k".into()]);
        assert!(c.is_empty());
        assert_eq!(c.len(), 0);
        assert_eq!(c.total(), 0);
    }
}