Skip to main content

aprender/mining/
mod.rs

1//! Pattern mining algorithms for association rule discovery.
2//!
3//! This module provides algorithms for discovering patterns in transactional data,
4//! particularly association rules used in market basket analysis.
5//!
6//! # Algorithms
7//!
8//! - [`Apriori`]: Frequent itemset mining and association rule generation
9//!
10//! # Example
11//!
12//! ```
13//! use aprender::mining::Apriori;
14//!
15//! // Market basket transactions (each transaction is a set of item IDs)
16//! let transactions = vec![
17//!     vec![1, 2, 3],    // Transaction 1: items 1, 2, 3
18//!     vec![1, 2],       // Transaction 2: items 1, 2
19//!     vec![1, 3],       // Transaction 3: items 1, 3
20//!     vec![2, 3],       // Transaction 4: items 2, 3
21//! ];
22//!
23//! // Find frequent itemsets with minimum support 0.5 (50%)
24//! let mut apriori = Apriori::new()
25//!     .with_min_support(0.5)
26//!     .with_min_confidence(0.7);
27//!
28//! apriori.fit(&transactions);
29//!
30//! // Get association rules
31//! let rules = apriori.get_rules();
32//! for rule in rules {
33//!     println!("{:?} => {:?} (conf={:.2}, lift={:.2})",
34//!         rule.antecedent, rule.consequent, rule.confidence, rule.lift);
35//! }
36//! ```
37
38use std::collections::HashSet;
39
/// Association rule `antecedent => consequent` derived from a frequent itemset.
#[derive(Debug, Clone, PartialEq)]
pub struct AssociationRule {
    /// Items in the antecedent (left side), sorted in ascending order.
    pub antecedent: Vec<usize>,
    /// Items in the consequent (right side), sorted in ascending order.
    pub consequent: Vec<usize>,
    /// Support: P(antecedent ∪ consequent)
    pub support: f64,
    /// Confidence: P(consequent | antecedent) = support / P(antecedent)
    pub confidence: f64,
    /// Lift: confidence / P(consequent)
    pub lift: f64,
}

/// Apriori algorithm for frequent itemset mining and association rule generation.
///
/// The Apriori algorithm discovers frequent itemsets in transactional data
/// and generates association rules based on support and confidence thresholds.
///
/// # Algorithm
///
/// 1. Find frequent 1-itemsets (support >= `min_support`)
/// 2. Generate candidate k-itemsets from frequent (k-1)-itemsets
/// 3. Prune candidates that don't meet minimum support
/// 4. Repeat until no more frequent itemsets can be generated
/// 5. Generate association rules from frequent itemsets
/// 6. Filter rules by minimum confidence
///
/// Support is counted per transaction: an item repeated inside one
/// transaction counts once.
///
/// # Parameters
///
/// - `min_support`: Minimum support threshold (0.0 to 1.0, inclusive)
/// - `min_confidence`: Minimum confidence threshold (0.0 to 1.0, inclusive)
///
/// # Output order
///
/// Results do not depend on hash iteration order. Items inside every itemset
/// and rule side are sorted ascending. Frequent itemsets are ordered by
/// support descending, and ties keep discovery order (smaller itemsets first).
/// Rules are ordered by confidence descending, then by antecedent, then by
/// consequent.
///
/// # Example
///
/// ```
/// use aprender::mining::Apriori;
///
/// let transactions = vec![
///     vec![1, 2, 3],
///     vec![1, 2],
///     vec![1, 3],
///     vec![2, 3],
/// ];
///
/// let mut apriori = Apriori::new()
///     .with_min_support(0.5)
///     .with_min_confidence(0.7);
///
/// apriori.fit(&transactions);
/// let rules = apriori.get_rules();
/// ```
#[derive(Debug, Clone)]
pub struct Apriori {
    min_support: f64,
    min_confidence: f64,
    /// Frequent itemsets found by the last `fit`, paired with their support.
    frequent_itemsets: Vec<(HashSet<usize>, f64)>,
    /// Association rules generated by the last `fit`.
    rules: Vec<AssociationRule>,
}

impl Apriori {
    /// Create a new Apriori instance with default parameters.
    ///
    /// # Default Parameters
    ///
    /// - `min_support`: 0.1 (10%)
    /// - `min_confidence`: 0.5 (50%)
    #[must_use]
    pub fn new() -> Self {
        Self {
            min_support: 0.1,
            min_confidence: 0.5,
            frequent_itemsets: Vec::new(),
            rules: Vec::new(),
        }
    }

    /// Set the minimum support threshold (inclusive).
    ///
    /// # Arguments
    ///
    /// * `min_support` - Minimum support (0.0 to 1.0)
    #[must_use]
    pub fn with_min_support(mut self, min_support: f64) -> Self {
        self.min_support = min_support;
        self
    }

    /// Set the minimum confidence threshold (inclusive).
    ///
    /// # Arguments
    ///
    /// * `min_confidence` - Minimum confidence (0.0 to 1.0)
    #[must_use]
    pub fn with_min_confidence(mut self, min_confidence: f64) -> Self {
        self.min_confidence = min_confidence;
        self
    }

    /// Return the items of `itemset` as a vector sorted in ascending order.
    ///
    /// The sorted vector is a canonical, hashable key for an itemset
    /// (`HashSet` itself does not implement `Hash`), and it makes the output
    /// independent of hash iteration order.
    fn sorted_items(itemset: &HashSet<usize>) -> Vec<usize> {
        let mut items: Vec<usize> = itemset.iter().copied().collect();
        items.sort_unstable();
        items
    }

    /// Find all frequent 1-itemsets, ordered by item id.
    ///
    /// Each distinct item is counted at most once per transaction, so items
    /// repeated within a transaction do not inflate support. This keeps the
    /// values consistent with [`Apriori::calculate_support`].
    fn find_frequent_1_itemsets(&self, transactions: &[Vec<usize>]) -> Vec<(HashSet<usize>, f64)> {
        use std::collections::HashMap;

        let mut item_counts: HashMap<usize, usize> = HashMap::new();
        for transaction in transactions {
            let distinct: HashSet<usize> = transaction.iter().copied().collect();
            for item in distinct {
                *item_counts.entry(item).or_default() += 1;
            }
        }

        let n_transactions = transactions.len() as f64;
        let mut frequent: Vec<(usize, f64)> = item_counts
            .into_iter()
            .map(|(item, count)| (item, count as f64 / n_transactions))
            .filter(|&(_, support)| support >= self.min_support)
            .collect();
        // HashMap iteration order is unspecified; sort so later levels (and
        // therefore the final output) are reproducible.
        frequent.sort_unstable_by_key(|&(item, _)| item);

        frequent
            .into_iter()
            .map(|(item, support)| (HashSet::from([item]), support))
            .collect()
    }

    /// Generate candidate k-itemsets from the frequent (k-1)-itemsets.
    ///
    /// This is the join step: two (k-1)-itemsets are combined when their union
    /// has exactly k items. In the prune step, the union is kept only if every
    /// one of its (k-1)-subsets is frequent. Candidates are deduplicated and
    /// returned in the order they are first produced.
    fn generate_candidates(prev_itemsets: &[(HashSet<usize>, f64)]) -> Vec<HashSet<usize>> {
        // Canonical keys turn the subset check and the dedup into hash lookups
        // instead of linear scans over `prev_itemsets` / the candidate list.
        let prev_keys: HashSet<Vec<usize>> = prev_itemsets
            .iter()
            .map(|(itemset, _)| Self::sorted_items(itemset))
            .collect();
        let mut seen: HashSet<Vec<usize>> = HashSet::new();
        let mut candidates = Vec::new();

        for (i, (set1, _)) in prev_itemsets.iter().enumerate() {
            for (set2, _) in &prev_itemsets[i + 1..] {
                let union: HashSet<usize> = set1.union(set2).copied().collect();
                if union.len() != set1.len() + 1 {
                    continue;
                }

                let key = Self::sorted_items(&union);
                if seen.contains(&key) || Self::has_infrequent_subset(&key, &prev_keys) {
                    continue;
                }
                seen.insert(key);
                candidates.push(union);
            }
        }

        candidates
    }

    /// Return `true` if dropping any single item from `items` yields a set
    /// that is not in `prev_keys`.
    ///
    /// `items` must be sorted ascending. Removing one element keeps it sorted,
    /// so each subset is already in canonical key form.
    fn has_infrequent_subset(items: &[usize], prev_keys: &HashSet<Vec<usize>>) -> bool {
        (0..items.len()).any(|skip| {
            let subset: Vec<usize> = items
                .iter()
                .enumerate()
                .filter(|&(idx, _)| idx != skip)
                .map(|(_, &item)| item)
                .collect();
            !prev_keys.contains(&subset)
        })
    }

    /// Keep only the candidates whose support reaches `min_support`, paired
    /// with that support. Input order is preserved.
    fn prune_candidates(
        &self,
        candidates: Vec<HashSet<usize>>,
        transactions: &[Vec<usize>],
    ) -> Vec<(HashSet<usize>, f64)> {
        candidates
            .into_iter()
            .filter_map(|candidate| {
                let support = Self::calculate_support(&candidate, transactions);
                (support >= self.min_support).then_some((candidate, support))
            })
            .collect()
    }

    /// Generate association rules from `self.frequent_itemsets` into `self.rules`.
    ///
    /// Every frequent itemset with at least two items is split in every way
    /// into a non-empty antecedent and a non-empty consequent. A split becomes a
    /// rule when its confidence reaches `min_confidence`. Both sides of each
    /// rule are sorted ascending.
    fn generate_rules(&mut self, transactions: &[Vec<usize>]) {
        use std::collections::HashMap;

        // Every subset of a frequent itemset is itself frequent, so the
        // supports needed here were already computed during mining. Look them
        // up instead of rescanning the transactions; the scan is only a
        // fallback.
        let known_support: HashMap<Vec<usize>, f64> = self
            .frequent_itemsets
            .iter()
            .map(|(itemset, support)| (Self::sorted_items(itemset), *support))
            .collect();
        let support_of = |items: &[usize]| -> f64 {
            known_support.get(items).copied().unwrap_or_else(|| {
                let itemset: HashSet<usize> = items.iter().copied().collect();
                Self::calculate_support(&itemset, transactions)
            })
        };

        let mut rules = Vec::new();
        for (itemset, itemset_support) in &self.frequent_itemsets {
            if itemset.len() < 2 {
                continue;
            }
            let itemset_support = *itemset_support;
            let items = Self::sorted_items(itemset);

            // Masks 1..full_mask skip the empty and the full set, so both
            // sides are non-empty. There are 2^k - 2 splits, so this is only
            // practical for small itemsets.
            let full_mask: usize = (1 << items.len()) - 1;
            for mask in 1..full_mask {
                let mut antecedent = Vec::new();
                let mut consequent = Vec::new();
                for (bit, &item) in items.iter().enumerate() {
                    if mask & (1 << bit) != 0 {
                        antecedent.push(item);
                    } else {
                        consequent.push(item);
                    }
                }

                // Confidence = support(itemset) / support(antecedent)
                let confidence = itemset_support / support_of(antecedent.as_slice());
                if confidence >= self.min_confidence {
                    // Lift = confidence / support(consequent)
                    let lift = confidence / support_of(consequent.as_slice());
                    rules.push(AssociationRule {
                        antecedent,
                        consequent,
                        support: itemset_support,
                        confidence,
                        lift,
                    });
                }
            }
        }

        self.rules = rules;
    }

    /// Fit the Apriori algorithm on transaction data.
    ///
    /// Any results from a previous call are discarded. Empty input yields no
    /// itemsets and no rules.
    ///
    /// # Arguments
    ///
    /// * `transactions` - Vector of transactions, where each transaction is a vector of item IDs
    pub fn fit(&mut self, transactions: &[Vec<usize>]) {
        self.frequent_itemsets = Vec::new();
        self.rules = Vec::new();
        if transactions.is_empty() {
            return;
        }

        // Level-wise search: frequent k-itemsets seed the (k+1)-candidates.
        let mut current_itemsets = self.find_frequent_1_itemsets(transactions);
        while !current_itemsets.is_empty() {
            let candidates = Self::generate_candidates(&current_itemsets);
            // Move the finished level into the results; no clone needed.
            self.frequent_itemsets.append(&mut current_itemsets);
            if candidates.is_empty() {
                break;
            }
            current_itemsets = self.prune_candidates(candidates, transactions);
        }

        self.generate_rules(transactions);

        // `total_cmp` is a total order on f64, so these sorts cannot panic.
        // `sort_by` is stable: equal-support itemsets keep discovery order.
        self.frequent_itemsets.sort_by(|a, b| b.1.total_cmp(&a.1));
        self.rules.sort_by(|a, b| {
            b.confidence
                .total_cmp(&a.confidence)
                .then_with(|| a.antecedent.cmp(&b.antecedent))
                .then_with(|| a.consequent.cmp(&b.consequent))
        });
    }

    /// Get the discovered frequent itemsets.
    ///
    /// Returns `(items, support)` tuples sorted by support descending. The
    /// items inside each itemset are sorted ascending.
    #[must_use]
    pub fn get_frequent_itemsets(&self) -> Vec<(Vec<usize>, f64)> {
        self.frequent_itemsets
            .iter()
            .map(|(itemset, support)| (Self::sorted_items(itemset), *support))
            .collect()
    }

    /// Get the generated association rules.
    ///
    /// Returns rules sorted by confidence descending.
    #[must_use]
    pub fn get_rules(&self) -> Vec<AssociationRule> {
        self.rules.clone()
    }

    /// Calculate support for a specific itemset.
    ///
    /// # Arguments
    ///
    /// * `itemset` - The itemset to calculate support for
    /// * `transactions` - Transaction data
    ///
    /// # Returns
    ///
    /// The fraction of transactions that contain every item of `itemset`
    /// (0.0 to 1.0). Returns 0.0 when `transactions` is empty. An empty
    /// itemset is contained in every transaction.
    #[must_use]
    pub fn calculate_support(itemset: &HashSet<usize>, transactions: &[Vec<usize>]) -> f64 {
        if transactions.is_empty() {
            return 0.0;
        }

        let matching = transactions
            .iter()
            .filter(|transaction| itemset.iter().all(|item| transaction.contains(item)))
            .count();

        matching as f64 / transactions.len() as f64
    }
}

impl Default for Apriori {
    fn default() -> Self {
        Self::new()
    }
}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Four baskets over items {1, 2, 3}: each single item occurs in three
    /// baskets, each pair in two, and the full triple in one.
    fn sample_transactions() -> Vec<Vec<usize>> {
        vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]]
    }

    /// Support of `items` computed directly from `transactions`.
    fn support_of(items: &[usize], transactions: &[Vec<usize>]) -> f64 {
        let set: HashSet<usize> = items.iter().copied().collect();
        Apriori::calculate_support(&set, transactions)
    }

    #[test]
    fn test_apriori_new() {
        let apriori = Apriori::new();
        assert_eq!(apriori.min_support, 0.1);
        assert_eq!(apriori.min_confidence, 0.5);
        assert!(apriori.frequent_itemsets.is_empty());
        assert!(apriori.rules.is_empty());
    }

    #[test]
    fn test_apriori_default_matches_new() {
        let apriori = Apriori::default();
        assert_eq!(apriori.min_support, 0.1);
        assert_eq!(apriori.min_confidence, 0.5);
        assert!(apriori.frequent_itemsets.is_empty());
        assert!(apriori.rules.is_empty());
    }

    #[test]
    fn test_apriori_with_min_support() {
        let apriori = Apriori::new().with_min_support(0.3);
        assert_eq!(apriori.min_support, 0.3);
    }

    #[test]
    fn test_apriori_with_min_confidence() {
        let apriori = Apriori::new().with_min_confidence(0.7);
        assert_eq!(apriori.min_confidence, 0.7);
    }

    #[test]
    fn test_apriori_fit_basic() {
        // Simple market basket transactions
        let transactions = vec![
            vec![1, 2, 3],    // Transaction 1: milk, bread, butter
            vec![1, 2],       // Transaction 2: milk, bread
            vec![1, 3],       // Transaction 3: milk, butter
            vec![2, 3],       // Transaction 4: bread, butter
            vec![1, 2, 3, 4], // Transaction 5: milk, bread, butter, eggs
        ];

        let mut apriori = Apriori::new()
            .with_min_support(0.4) // 40% support
            .with_min_confidence(0.6); // 60% confidence

        apriori.fit(&transactions);

        // Should have found frequent itemsets
        assert!(!apriori.frequent_itemsets.is_empty());
    }

    #[test]
    fn test_frequent_itemsets() {
        let mut apriori = Apriori::new().with_min_support(0.5); // 50% support
        apriori.fit(&sample_transactions());

        let itemsets = apriori.get_frequent_itemsets();

        // With 4 transactions and min_support=0.5, need >= 2 occurrences
        // {1}, {2}, {3} appear in 3 transactions (75%) - frequent
        // {1,2}, {1,3}, {2,3} appear in 2 transactions (50%) - frequent
        // {1,2,3} appears in 1 transaction (25%) - not frequent
        assert_eq!(itemsets.len(), 6);
        assert!(itemsets.iter().all(|(items, _)| items.len() <= 2));

        // Verify itemsets are sorted by support descending
        assert!(itemsets.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[test]
    fn test_association_rules() {
        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.6);

        apriori.fit(&sample_transactions());
        let rules = apriori.get_rules();

        // Should have generated some rules
        assert!(!rules.is_empty());

        // All rules should meet minimum confidence
        assert!(rules.iter().all(|rule| rule.confidence >= 0.6));

        // Rules should be sorted by confidence descending
        assert!(rules.windows(2).all(|w| w[0].confidence >= w[1].confidence));
    }

    #[test]
    fn test_support_calculation() {
        let transactions = sample_transactions();

        // {1,2} appears in 2 out of 4 transactions = 0.5
        assert!((support_of(&[1, 2], &transactions) - 0.5).abs() < 1e-10);

        // {1} appears in 3 out of 4 transactions = 0.75
        assert!((support_of(&[1], &transactions) - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_calculation() {
        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0); // Accept all rules to verify confidence

        apriori.fit(&sample_transactions());
        let rules = apriori.get_rules();

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // Confidence({1} => {2}) = P({1,2}) / P({1}) = 0.5 / 0.75 = 2/3
        assert!((rule.confidence - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_lift_calculation() {
        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0);

        apriori.fit(&sample_transactions());
        let rules = apriori.get_rules();

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // Lift({1} => {2}) = confidence / P({2}) = (2/3) / 0.75 = 8/9
        // Lift > 1.0 means positive correlation, < 1.0 negative, = 1.0 independence.
        assert!((rule.lift - 8.0 / 9.0).abs() < 1e-12);
    }

    #[test]
    fn test_rule_metrics_are_consistent() {
        let transactions = sample_transactions();
        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0);
        apriori.fit(&transactions);

        let rules = apriori.get_rules();
        assert!(!rules.is_empty());
        for rule in rules {
            let union: Vec<usize> = rule
                .antecedent
                .iter()
                .chain(&rule.consequent)
                .copied()
                .collect();
            let antecedent_support = support_of(&rule.antecedent, &transactions);
            let consequent_support = support_of(&rule.consequent, &transactions);

            assert!((rule.support - support_of(&union, &transactions)).abs() < 1e-12);
            assert!((rule.confidence - rule.support / antecedent_support).abs() < 1e-12);
            assert!((rule.lift - rule.confidence / consequent_support).abs() < 1e-12);
        }
    }

    #[test]
    fn test_min_support_filter() {
        let transactions = vec![
            vec![1, 2],
            vec![1, 2],
            vec![1, 2],
            vec![3, 4], // Infrequent items
        ];

        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Only {1}, {2}, {1,2} should be frequent (75% support each)
        // {3}, {4}, {3,4} are infrequent (25% support)
        assert_eq!(itemsets.len(), 3);
        for (itemset, support) in itemsets {
            assert!(support >= 0.5, "All itemsets should meet min_support");
            assert!(
                !itemset.contains(&3) && !itemset.contains(&4),
                "Infrequent items should be pruned"
            );
        }
    }

    #[test]
    fn test_min_confidence_filter() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![1]];

        let mut apriori = Apriori::new()
            .with_min_support(0.25)
            .with_min_confidence(0.8); // High confidence threshold

        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        // All rules should meet minimum confidence
        for rule in &rules {
            assert!(
                rule.confidence >= 0.8,
                "Rule {:?} => {:?} has confidence {:.2} < 0.8",
                rule.antecedent,
                rule.consequent,
                rule.confidence
            );
        }
    }

    #[test]
    fn test_empty_transactions() {
        let transactions: Vec<Vec<usize>> = vec![];

        let mut apriori = Apriori::new();
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no frequent itemsets");

        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules");
    }

    #[test]
    fn test_refit_replaces_previous_results() {
        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&sample_transactions());
        assert!(!apriori.get_frequent_itemsets().is_empty());

        let empty: Vec<Vec<usize>> = Vec::new();
        apriori.fit(&empty);
        assert!(apriori.get_frequent_itemsets().is_empty());
        assert!(apriori.get_rules().is_empty());
    }

    #[test]
    fn test_single_item_transactions() {
        let transactions = vec![vec![1], vec![2], vec![3], vec![4]];

        let mut apriori = Apriori::new().with_min_support(0.25);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Each item appears once (25% support): 4 frequent 1-itemsets
        assert_eq!(itemsets.len(), 4);

        // No multi-item itemsets possible
        for (itemset, _) in itemsets {
            assert_eq!(itemset.len(), 1);
        }

        // No rules can be generated from single-item itemsets
        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0);
    }

    #[test]
    fn test_get_rules_before_fit() {
        let apriori = Apriori::new();
        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules before fit");
    }

    #[test]
    fn test_get_itemsets_before_fit() {
        let apriori = Apriori::new();
        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no itemsets before fit");
    }
}