aprender/mining/
mod.rs

1//! Pattern mining algorithms for association rule discovery.
2//!
3//! This module provides algorithms for discovering patterns in transactional data,
4//! particularly association rules used in market basket analysis.
5//!
6//! # Algorithms
7//!
8//! - [`Apriori`]: Frequent itemset mining and association rule generation
9//!
10//! # Example
11//!
12//! ```
13//! use aprender::mining::Apriori;
14//!
15//! // Market basket transactions (each transaction is a set of item IDs)
16//! let transactions = vec![
17//!     vec![1, 2, 3],    // Transaction 1: items 1, 2, 3
18//!     vec![1, 2],       // Transaction 2: items 1, 2
19//!     vec![1, 3],       // Transaction 3: items 1, 3
20//!     vec![2, 3],       // Transaction 4: items 2, 3
21//! ];
22//!
23//! // Find frequent itemsets with minimum support 0.5 (50%)
24//! let mut apriori = Apriori::new()
25//!     .with_min_support(0.5)
26//!     .with_min_confidence(0.7);
27//!
28//! apriori.fit(&transactions);
29//!
30//! // Get association rules
31//! let rules = apriori.get_rules();
32//! for rule in rules {
33//!     println!("{:?} => {:?} (conf={:.2}, lift={:.2})",
34//!         rule.antecedent, rule.consequent, rule.confidence, rule.lift);
35//! }
36//! ```
37
38use std::collections::HashSet;
39
/// An association rule of the form `antecedent => consequent`.
///
/// Rules are produced by [`Apriori::fit`] and retrieved with
/// [`Apriori::get_rules`]. Both item lists are stored in ascending item order
/// so that results are reproducible between runs.
#[derive(Debug, Clone, PartialEq)]
pub struct AssociationRule {
    /// Items in the antecedent (left side), in ascending order
    pub antecedent: Vec<usize>,
    /// Items in the consequent (right side), in ascending order
    pub consequent: Vec<usize>,
    /// Support: P(antecedent ∪ consequent)
    pub support: f64,
    /// Confidence: P(consequent | antecedent) = support / P(antecedent)
    pub confidence: f64,
    /// Lift: confidence / P(consequent)
    pub lift: f64,
}

/// Apriori algorithm for frequent itemset mining and association rule generation.
///
/// The Apriori algorithm discovers frequent itemsets in transactional data
/// and generates association rules based on support and confidence thresholds.
///
/// # Algorithm
///
/// 1. Find frequent 1-itemsets (support >= min_support)
/// 2. Generate candidate k-itemsets from frequent (k-1)-itemsets
/// 3. Prune candidates that don't meet minimum support
/// 4. Repeat until no more frequent itemsets can be generated
/// 5. Generate association rules from frequent itemsets
/// 6. Filter rules by minimum confidence
///
/// # Parameters
///
/// - `min_support`: Minimum support threshold (0.0 to 1.0)
/// - `min_confidence`: Minimum confidence threshold (0.0 to 1.0)
///
/// The thresholds are stored as given; they are not validated or clamped.
///
/// # Example
///
/// ```
/// use aprender::mining::Apriori;
///
/// let transactions = vec![
///     vec![1, 2, 3],
///     vec![1, 2],
///     vec![1, 3],
///     vec![2, 3],
/// ];
///
/// let mut apriori = Apriori::new()
///     .with_min_support(0.5)
///     .with_min_confidence(0.7);
///
/// apriori.fit(&transactions);
/// let rules = apriori.get_rules();
/// ```
#[derive(Debug, Clone)]
pub struct Apriori {
    /// Minimum fraction of transactions an itemset must appear in.
    min_support: f64,
    /// Minimum confidence a rule must reach to be kept.
    min_confidence: f64,
    /// Frequent itemsets found by the last `fit`, as (itemset, support) pairs.
    frequent_itemsets: Vec<(HashSet<usize>, f64)>,
    /// Association rules generated by the last `fit`.
    rules: Vec<AssociationRule>,
}

impl Apriori {
    /// Create a new Apriori instance with default parameters.
    ///
    /// # Default Parameters
    ///
    /// - `min_support`: 0.1 (10%)
    /// - `min_confidence`: 0.5 (50%)
    pub fn new() -> Self {
        Self {
            min_support: 0.1,
            min_confidence: 0.5,
            frequent_itemsets: Vec::new(),
            rules: Vec::new(),
        }
    }

    /// Set the minimum support threshold.
    ///
    /// # Arguments
    ///
    /// * `min_support` - Minimum support (0.0 to 1.0)
    pub fn with_min_support(mut self, min_support: f64) -> Self {
        self.min_support = min_support;
        self
    }

    /// Set the minimum confidence threshold.
    ///
    /// # Arguments
    ///
    /// * `min_confidence` - Minimum confidence (0.0 to 1.0)
    pub fn with_min_confidence(mut self, min_confidence: f64) -> Self {
        self.min_confidence = min_confidence;
        self
    }

    /// Find all frequent 1-itemsets.
    ///
    /// An item is counted at most once per transaction, so a transaction that
    /// lists the same item several times cannot push its support above 1.0 and
    /// the result agrees with [`Apriori::calculate_support`]. The itemsets are
    /// returned in ascending item order.
    fn find_frequent_1_itemsets(&self, transactions: &[Vec<usize>]) -> Vec<(HashSet<usize>, f64)> {
        use std::collections::BTreeMap;

        // An ordered map keeps the output order independent of hashing.
        let mut item_counts: BTreeMap<usize, usize> = BTreeMap::new();

        for transaction in transactions {
            // Collapse repeated items so each transaction contributes once per item.
            let distinct: HashSet<usize> = transaction.iter().copied().collect();
            for item in distinct {
                *item_counts.entry(item).or_insert(0) += 1;
            }
        }

        let n_transactions = transactions.len() as f64;
        let mut frequent_1_itemsets = Vec::new();

        for (item, count) in item_counts {
            let support = count as f64 / n_transactions;
            if support >= self.min_support {
                let mut itemset = HashSet::new();
                itemset.insert(item);
                frequent_1_itemsets.push((itemset, support));
            }
        }

        frequent_1_itemsets
    }

    /// Generate candidate k-itemsets from frequent (k-1)-itemsets.
    ///
    /// Two (k-1)-itemsets are joined when their union has exactly k items.
    /// A candidate is dropped if any of its (k-1)-subsets is not frequent, and
    /// each distinct candidate is emitted once.
    fn generate_candidates(prev_itemsets: &[(HashSet<usize>, f64)]) -> Vec<HashSet<usize>> {
        let mut candidates: Vec<HashSet<usize>> = Vec::new();

        for (i, (left, _)) in prev_itemsets.iter().enumerate() {
            for (right, _) in &prev_itemsets[i + 1..] {
                // Join step: the two sets must differ by exactly one item.
                let union: HashSet<usize> = left.union(right).copied().collect();
                if union.len() != left.len() + 1 {
                    continue;
                }

                // Prune step: every (k-1)-subset has to be frequent.
                if Self::has_infrequent_subset(&union, prev_itemsets) {
                    continue;
                }

                // Different pairs can produce the same union; keep one copy.
                if !candidates.contains(&union) {
                    candidates.push(union);
                }
            }
        }

        candidates
    }

    /// Check if an itemset has any infrequent subset.
    ///
    /// Only the subsets obtained by removing a single item are inspected; each
    /// must be present in `prev_itemsets` to count as frequent.
    fn has_infrequent_subset(
        itemset: &HashSet<usize>,
        prev_itemsets: &[(HashSet<usize>, f64)],
    ) -> bool {
        itemset.iter().any(|dropped| {
            let mut subset = itemset.clone();
            subset.remove(dropped);

            !prev_itemsets
                .iter()
                .any(|(frequent, _)| *frequent == subset)
        })
    }

    /// Prune candidates by minimum support.
    ///
    /// Returns the surviving candidates paired with their support, in the
    /// order they were supplied.
    fn prune_candidates(
        &self,
        candidates: Vec<HashSet<usize>>,
        transactions: &[Vec<usize>],
    ) -> Vec<(HashSet<usize>, f64)> {
        candidates
            .into_iter()
            .filter_map(|candidate| {
                let support = Self::calculate_support(&candidate, transactions);
                if support >= self.min_support {
                    Some((candidate, support))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Generate association rules from frequent itemsets.
    ///
    /// Every frequent itemset with at least two items is split into each
    /// non-empty proper subset (antecedent) and its complement (consequent).
    /// Rules whose confidence reaches `min_confidence` replace `self.rules`.
    fn generate_rules(&mut self, transactions: &[Vec<usize>]) {
        let mut rules = Vec::new();

        for (itemset, itemset_support) in &self.frequent_itemsets {
            if itemset.len() < 2 {
                continue;
            }

            // Sorting fixes the enumeration order and keeps antecedents sorted.
            let mut items: Vec<usize> = itemset.iter().copied().collect();
            items.sort_unstable();

            for antecedent_items in Self::generate_subsets(&items) {
                // The full itemset would leave an empty consequent.
                if antecedent_items.len() == items.len() {
                    continue;
                }

                // Consequent = itemset \ antecedent
                let antecedent_set: HashSet<usize> = antecedent_items.iter().copied().collect();
                let consequent_set: HashSet<usize> =
                    itemset.difference(&antecedent_set).copied().collect();

                // confidence = support(itemset) / support(antecedent).
                // A 0/0 division yields NaN, which fails the comparison below.
                let antecedent_support = Self::calculate_support(&antecedent_set, transactions);
                let confidence = *itemset_support / antecedent_support;

                if confidence >= self.min_confidence {
                    // lift = confidence / support(consequent)
                    let consequent_support = Self::calculate_support(&consequent_set, transactions);
                    let lift = confidence / consequent_support;

                    let mut consequent_items: Vec<usize> = consequent_set.into_iter().collect();
                    consequent_items.sort_unstable();

                    rules.push(AssociationRule {
                        antecedent: antecedent_items,
                        consequent: consequent_items,
                        support: *itemset_support,
                        confidence,
                        lift,
                    });
                }
            }
        }

        self.rules = rules;
    }

    /// Generate all non-empty subsets of items.
    ///
    /// Subsets are enumerated by bitmask from 1 to 2^n - 1, so the last subset
    /// is the full set; callers that need proper subsets must skip it. Items
    /// inside each subset keep their relative order from `items`.
    ///
    /// Panics if `items` has at least as many elements as `usize` has bits,
    /// since the subsets can no longer be indexed by a `usize` mask.
    fn generate_subsets(items: &[usize]) -> Vec<Vec<usize>> {
        let n = items.len();
        assert!(
            n < usize::BITS as usize,
            "itemset is too large to enumerate its subsets"
        );

        let subset_count = 1usize << n;
        (1..subset_count)
            .map(|mask| {
                items
                    .iter()
                    .enumerate()
                    .filter(|&(i, _)| (mask & (1usize << i)) != 0)
                    .map(|(_, &item)| item)
                    .collect()
            })
            .collect()
    }

    /// Fit the Apriori algorithm on transaction data.
    ///
    /// Any results from a previous call are discarded. Repeated items inside a
    /// single transaction are treated as one occurrence. With no transactions
    /// the itemsets and rules are left empty.
    ///
    /// # Arguments
    ///
    /// * `transactions` - Vector of transactions, where each transaction is a vector of item IDs
    pub fn fit(&mut self, transactions: &[Vec<usize>]) {
        self.frequent_itemsets.clear();
        self.rules.clear();

        if transactions.is_empty() {
            return;
        }

        // Step 1: Find frequent 1-itemsets
        let mut level = self.find_frequent_1_itemsets(transactions);

        // Step 2: Iteratively generate frequent k-itemsets (k >= 2)
        while !level.is_empty() {
            // Candidates are derived before the level is moved into the results.
            let candidates = Self::generate_candidates(&level);
            self.frequent_itemsets.append(&mut level);
            level = self.prune_candidates(candidates, transactions);
        }

        // Step 3: Generate association rules from frequent itemsets
        self.generate_rules(transactions);

        // Both sorts are stable, so ties keep their generation order.
        self.frequent_itemsets.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .expect("Support values must be valid f64 (not NaN)")
        });

        self.rules.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .expect("Confidence values must be valid f64 (not NaN)")
        });
    }

    /// Get the discovered frequent itemsets.
    ///
    /// Returns a vector of (itemset, support) tuples sorted by support descending.
    /// The items inside each itemset are listed in ascending order.
    pub fn get_frequent_itemsets(&self) -> Vec<(Vec<usize>, f64)> {
        self.frequent_itemsets
            .iter()
            .map(|(itemset, support)| {
                let mut items: Vec<usize> = itemset.iter().copied().collect();
                items.sort_unstable();
                (items, *support)
            })
            .collect()
    }

    /// Get the generated association rules.
    ///
    /// Returns rules sorted by confidence descending.
    pub fn get_rules(&self) -> Vec<AssociationRule> {
        self.rules.clone()
    }

    /// Calculate support for a specific itemset.
    ///
    /// Support is the fraction of transactions that contain every item of
    /// `itemset`. An empty itemset is contained in every transaction.
    ///
    /// # Arguments
    ///
    /// * `itemset` - The itemset to calculate support for
    /// * `transactions` - Transaction data
    ///
    /// # Returns
    ///
    /// Support value (0.0 to 1.0); 0.0 when there are no transactions
    pub fn calculate_support(itemset: &HashSet<usize>, transactions: &[Vec<usize>]) -> f64 {
        if transactions.is_empty() {
            return 0.0;
        }

        let matching = transactions
            .iter()
            .filter(|transaction| itemset.iter().all(|item| transaction.contains(item)))
            .count();

        matching as f64 / transactions.len() as f64
    }
}

impl Default for Apriori {
    /// Equivalent to [`Apriori::new`].
    fn default() -> Self {
        Self::new()
    }
}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-transaction basket shared by most tests.
    ///
    /// Supports: {1}, {2}, {3} = 0.75; {1,2}, {1,3}, {2,3} = 0.5; {1,2,3} = 0.25.
    fn basket() -> Vec<Vec<usize>> {
        vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]]
    }

    /// Fit on `basket()` with the given thresholds and return the rules.
    fn basket_rules(min_support: f64, min_confidence: f64) -> Vec<AssociationRule> {
        let mut apriori = Apriori::new()
            .with_min_support(min_support)
            .with_min_confidence(min_confidence);
        apriori.fit(&basket());
        apriori.get_rules()
    }

    #[test]
    fn test_apriori_new() {
        let apriori = Apriori::new();
        assert_eq!(apriori.min_support, 0.1);
        assert_eq!(apriori.min_confidence, 0.5);
        assert!(apriori.frequent_itemsets.is_empty());
        assert!(apriori.rules.is_empty());
    }

    #[test]
    fn test_apriori_with_min_support() {
        let apriori = Apriori::new().with_min_support(0.3);
        assert_eq!(apriori.min_support, 0.3);
    }

    #[test]
    fn test_apriori_with_min_confidence() {
        let apriori = Apriori::new().with_min_confidence(0.7);
        assert_eq!(apriori.min_confidence, 0.7);
    }

    #[test]
    fn test_apriori_fit_basic() {
        // Simple market basket transactions
        let transactions = vec![
            vec![1, 2, 3],    // Transaction 1: milk, bread, butter
            vec![1, 2],       // Transaction 2: milk, bread
            vec![1, 3],       // Transaction 3: milk, butter
            vec![2, 3],       // Transaction 4: bread, butter
            vec![1, 2, 3, 4], // Transaction 5: milk, bread, butter, eggs
        ];

        let mut apriori = Apriori::new()
            .with_min_support(0.4) // 40% support
            .with_min_confidence(0.6); // 60% confidence

        apriori.fit(&transactions);

        // {1}, {2}, {3} (80%), {1,2}, {1,3}, {2,3} (60%) and {1,2,3} (40%)
        // are frequent; eggs (20%) are not.
        assert_eq!(apriori.frequent_itemsets.len(), 7);
        assert!(!apriori.rules.is_empty());
    }

    #[test]
    fn test_frequent_itemsets() {
        let mut apriori = Apriori::new().with_min_support(0.5); // 50% support
        apriori.fit(&basket());

        let itemsets = apriori.get_frequent_itemsets();

        // With 4 transactions and min_support=0.5, an itemset needs >= 2
        // occurrences: the three single items and the three pairs qualify,
        // while {1,2,3} (25%) does not.
        assert_eq!(itemsets.len(), 6);

        // Verify itemsets are sorted by support descending
        for pair in itemsets.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    #[test]
    fn test_association_rules() {
        let rules = basket_rules(0.5, 0.6);

        // Should have generated some rules
        assert!(!rules.is_empty());

        // All rules should meet minimum confidence
        assert!(rules.iter().all(|rule| rule.confidence >= 0.6));

        // Rules should be sorted by confidence descending
        for pair in rules.windows(2) {
            assert!(pair[0].confidence >= pair[1].confidence);
        }
    }

    #[test]
    fn test_support_calculation() {
        let transactions = basket();

        // {1,2} appears in 2 out of 4 transactions = 0.5
        let itemset: HashSet<usize> = vec![1, 2].into_iter().collect();
        let support = Apriori::calculate_support(&itemset, &transactions);
        assert!((support - 0.5).abs() < 1e-10);

        // {1} appears in 3 out of 4 transactions = 0.75
        let itemset: HashSet<usize> = vec![1].into_iter().collect();
        let support = Apriori::calculate_support(&itemset, &transactions);
        assert!((support - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_calculation() {
        // min_confidence = 0.0 accepts every rule so the value can be checked
        let rules = basket_rules(0.5, 0.0);

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // Confidence({1} => {2}) = P({1,2}) / P({1}) = 0.5 / 0.75 = 0.667
        assert!((rule.confidence - 0.6666666).abs() < 1e-5);
    }

    #[test]
    fn test_lift_calculation() {
        let rules = basket_rules(0.5, 0.0);

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // Lift({1} => {2}) = confidence / P({2}) = 0.667 / 0.75 = 0.889
        // (lift > 1 means positive correlation, < 1 negative, = 1 independence)
        assert!((rule.lift - 0.8888888).abs() < 1e-5);
    }

    #[test]
    fn test_min_support_filter() {
        let transactions = vec![
            vec![1, 2],
            vec![1, 2],
            vec![1, 2],
            vec![3, 4], // Infrequent items
        ];

        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Only {1}, {2}, {1,2} should be frequent (75% support each);
        // {3}, {4}, {3,4} are infrequent (25% support).
        assert_eq!(itemsets.len(), 3);
        for (itemset, support) in itemsets {
            assert!(support >= 0.5, "All itemsets should meet min_support");
            assert!(
                !itemset.contains(&3) && !itemset.contains(&4),
                "Infrequent items should be pruned"
            );
        }
    }

    #[test]
    fn test_min_confidence_filter() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![1]];

        let mut apriori = Apriori::new()
            .with_min_support(0.25)
            .with_min_confidence(0.8); // High confidence threshold

        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        // {2} => {1} has confidence 1.0, so the filter must keep something.
        assert!(!rules.is_empty(), "Should keep high-confidence rules");

        // All rules should meet minimum confidence
        for rule in &rules {
            assert!(
                rule.confidence >= 0.8,
                "Rule {:?} => {:?} has confidence {:.2} < 0.8",
                rule.antecedent,
                rule.consequent,
                rule.confidence
            );
        }
    }

    #[test]
    fn test_empty_transactions() {
        let transactions: Vec<Vec<usize>> = vec![];

        let mut apriori = Apriori::new();
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no frequent itemsets");

        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules");
    }

    #[test]
    fn test_single_item_transactions() {
        let transactions = vec![vec![1], vec![2], vec![3], vec![4]];

        let mut apriori = Apriori::new().with_min_support(0.25);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Each item appears once (25% support): 4 frequent 1-itemsets
        assert_eq!(itemsets.len(), 4);

        // No multi-item itemsets possible
        assert!(itemsets.iter().all(|(itemset, _)| itemset.len() == 1));

        // No rules can be generated from single-item itemsets
        assert_eq!(apriori.get_rules().len(), 0);
    }

    #[test]
    fn test_get_rules_before_fit() {
        let apriori = Apriori::new();
        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules before fit");
    }

    #[test]
    fn test_get_itemsets_before_fit() {
        let apriori = Apriori::new();
        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no itemsets before fit");
    }
}