use std::collections::{HashMap, HashSet};

/// An association rule of the form `antecedent => consequent`, together with
/// the quality metrics computed for it.
#[derive(Debug, Clone, PartialEq)]
pub struct AssociationRule {
    /// Items on the left-hand side of the rule (the "if" part).
    pub antecedent: Vec<usize>,
    /// Items on the right-hand side of the rule (the "then" part).
    pub consequent: Vec<usize>,
    /// Support of the combined itemset (antecedent and consequent together),
    /// as a fraction of all transactions.
    pub support: f64,
    /// Support of the combined itemset divided by the support of the antecedent.
    pub confidence: f64,
    /// Confidence divided by the support of the consequent.
    pub lift: f64,
}
54
55#[derive(Debug, Clone)]
94pub struct Apriori {
95 min_support: f64,
96 min_confidence: f64,
97 frequent_itemsets: Vec<(HashSet<usize>, f64)>, rules: Vec<AssociationRule>,
99}
100
101impl Apriori {
102 #[must_use]
109 pub fn new() -> Self {
110 Self {
111 min_support: 0.1,
112 min_confidence: 0.5,
113 frequent_itemsets: Vec::new(),
114 rules: Vec::new(),
115 }
116 }
117
118 #[must_use]
124 pub fn with_min_support(mut self, min_support: f64) -> Self {
125 self.min_support = min_support;
126 self
127 }
128
129 #[must_use]
135 pub fn with_min_confidence(mut self, min_confidence: f64) -> Self {
136 self.min_confidence = min_confidence;
137 self
138 }
139
140 fn find_frequent_1_itemsets(&self, transactions: &[Vec<usize>]) -> Vec<(HashSet<usize>, f64)> {
142 use std::collections::HashMap;
143 let mut item_counts: HashMap<usize, usize> = HashMap::new();
144
145 for transaction in transactions {
147 for &item in transaction {
148 *item_counts.entry(item).or_insert(0) += 1;
149 }
150 }
151
152 let n_transactions = transactions.len() as f64;
154 let mut frequent_1_itemsets = Vec::new();
155
156 for (item, count) in item_counts {
157 let support = count as f64 / n_transactions;
158 if support >= self.min_support {
159 let mut itemset = HashSet::new();
160 itemset.insert(item);
161 frequent_1_itemsets.push((itemset, support));
162 }
163 }
164
165 frequent_1_itemsets
166 }
167
168 fn generate_candidates(&self, prev_itemsets: &[(HashSet<usize>, f64)]) -> Vec<HashSet<usize>> {
170 let mut candidates = Vec::new();
171
172 for i in 0..prev_itemsets.len() {
174 for j in (i + 1)..prev_itemsets.len() {
175 let set1 = &prev_itemsets[i].0;
176 let set2 = &prev_itemsets[j].0;
177
178 let union: HashSet<usize> = set1.union(set2).copied().collect();
180
181 if union.len() == set1.len() + 1 {
183 if self.has_infrequent_subset(&union, prev_itemsets) {
185 continue;
186 }
187
188 if !candidates.contains(&union) {
190 candidates.push(union);
191 }
192 }
193 }
194 }
195
196 candidates
197 }
198
199 #[allow(clippy::unused_self)]
201 fn has_infrequent_subset(
202 &self,
203 itemset: &HashSet<usize>,
204 prev_itemsets: &[(HashSet<usize>, f64)],
205 ) -> bool {
206 for &item in itemset {
208 let mut subset = itemset.clone();
209 subset.remove(&item);
210
211 let is_frequent = prev_itemsets
213 .iter()
214 .any(|(freq_set, _)| freq_set == &subset);
215
216 if !is_frequent {
217 return true; }
219 }
220
221 false }
223
224 fn prune_candidates(
226 &self,
227 candidates: Vec<HashSet<usize>>,
228 transactions: &[Vec<usize>],
229 ) -> Vec<(HashSet<usize>, f64)> {
230 let mut frequent = Vec::new();
231
232 for candidate in candidates {
233 let support = Self::calculate_support(&candidate, transactions);
234 if support >= self.min_support {
235 frequent.push((candidate, support));
236 }
237 }
238
239 frequent
240 }
241
242 fn generate_rules(&mut self, transactions: &[Vec<usize>]) {
244 let mut rules = Vec::new();
245
246 for (itemset, itemset_support) in &self.frequent_itemsets {
248 if itemset.len() < 2 {
249 continue;
250 }
251
252 let items: Vec<usize> = itemset.iter().copied().collect();
254 let subsets = self.generate_subsets(&items);
255
256 for antecedent_items in subsets {
257 if antecedent_items.is_empty() || antecedent_items.len() == items.len() {
258 continue; }
260
261 let antecedent_set: HashSet<usize> = antecedent_items.iter().copied().collect();
263 let consequent_set: HashSet<usize> =
264 itemset.difference(&antecedent_set).copied().collect();
265
266 let antecedent_support = Self::calculate_support(&antecedent_set, transactions);
268 let confidence = itemset_support / antecedent_support;
269
270 if confidence >= self.min_confidence {
271 let consequent_support = Self::calculate_support(&consequent_set, transactions);
273 let lift = confidence / consequent_support;
274
275 let rule = AssociationRule {
276 antecedent: antecedent_items,
277 consequent: consequent_set.into_iter().collect(),
278 support: *itemset_support,
279 confidence,
280 lift,
281 };
282
283 rules.push(rule);
284 }
285 }
286 }
287
288 self.rules = rules;
289 }
290
291 #[allow(clippy::unused_self)]
293 fn generate_subsets(&self, items: &[usize]) -> Vec<Vec<usize>> {
294 let mut subsets = Vec::new();
295 let n = items.len();
296
297 for mask in 1..(1 << n) {
299 let mut subset = Vec::new();
300 for (i, &item) in items.iter().enumerate() {
301 if (mask & (1 << i)) != 0 {
302 subset.push(item);
303 }
304 }
305 subsets.push(subset);
306 }
307
308 subsets
309 }
310
311 pub fn fit(&mut self, transactions: &[Vec<usize>]) {
317 if transactions.is_empty() {
318 self.frequent_itemsets = Vec::new();
319 self.rules = Vec::new();
320 return;
321 }
322
323 self.frequent_itemsets = Vec::new();
324
325 let mut current_itemsets = self.find_frequent_1_itemsets(transactions);
327
328 loop {
330 if current_itemsets.is_empty() {
331 break;
332 }
333
334 self.frequent_itemsets.extend(current_itemsets.clone());
336
337 let candidates = self.generate_candidates(¤t_itemsets);
339 if candidates.is_empty() {
340 break;
341 }
342
343 current_itemsets = self.prune_candidates(candidates, transactions);
345 }
346
347 self.generate_rules(transactions);
349
350 self.frequent_itemsets.sort_by(|a, b| {
352 b.1.partial_cmp(&a.1)
353 .expect("Support values must be valid f64 (not NaN)")
354 });
355
356 self.rules.sort_by(|a, b| {
358 b.confidence
359 .partial_cmp(&a.confidence)
360 .expect("Confidence values must be valid f64 (not NaN)")
361 });
362 }
363
364 #[must_use]
368 pub fn get_frequent_itemsets(&self) -> Vec<(Vec<usize>, f64)> {
369 self.frequent_itemsets
370 .iter()
371 .map(|(itemset, support)| (itemset.iter().copied().collect(), *support))
372 .collect()
373 }
374
375 #[must_use]
379 pub fn get_rules(&self) -> Vec<AssociationRule> {
380 self.rules.clone()
381 }
382
383 #[must_use]
394 pub fn calculate_support(itemset: &HashSet<usize>, transactions: &[Vec<usize>]) -> f64 {
395 if transactions.is_empty() {
396 return 0.0;
397 }
398
399 let mut count = 0;
400
401 for transaction in transactions {
402 if itemset.iter().all(|item| transaction.contains(item)) {
404 count += 1;
405 }
406 }
407
408 f64::from(count) / transactions.len() as f64
409 }
410}
411
412impl Default for Apriori {
413 fn default() -> Self {
414 Self::new()
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture shared by several tests. Over these four transactions each
    /// single item has support 0.75, each pair 0.5, and the triple 0.25.
    fn sample_transactions() -> Vec<Vec<usize>> {
        vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]]
    }

    /// Asserts that `values` never increases from one element to the next.
    fn assert_non_increasing(values: &[f64]) {
        for pair in values.windows(2) {
            assert!(pair[0] >= pair[1]);
        }
    }

    #[test]
    fn test_apriori_new() {
        let apriori = Apriori::new();
        assert_eq!(apriori.min_support, 0.1);
        assert_eq!(apriori.min_confidence, 0.5);
        assert_eq!(apriori.frequent_itemsets.len(), 0);
        assert_eq!(apriori.rules.len(), 0);
    }

    #[test]
    fn test_apriori_with_min_support() {
        let apriori = Apriori::new().with_min_support(0.3);
        assert_eq!(apriori.min_support, 0.3);
    }

    #[test]
    fn test_apriori_with_min_confidence() {
        let apriori = Apriori::new().with_min_confidence(0.7);
        assert_eq!(apriori.min_confidence, 0.7);
    }

    #[test]
    fn test_apriori_fit_basic() {
        let transactions = vec![
            vec![1, 2, 3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
            vec![1, 2, 3, 4],
        ];

        let mut apriori = Apriori::new()
            .with_min_support(0.4)
            .with_min_confidence(0.6);
        apriori.fit(&transactions);

        assert!(!apriori.frequent_itemsets.is_empty());
    }

    #[test]
    fn test_frequent_itemsets() {
        let transactions = sample_transactions();

        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Three single items plus three pairs reach 50% support.
        assert!(itemsets.len() >= 6);

        // Results must be ordered by descending support.
        let supports: Vec<f64> = itemsets.iter().map(|(_, support)| *support).collect();
        assert_non_increasing(&supports);
    }

    #[test]
    fn test_association_rules() {
        let transactions = sample_transactions();

        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.6);

        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        assert!(!rules.is_empty());

        for rule in &rules {
            assert!(rule.confidence >= 0.6);
        }

        // Rules must be ordered by descending confidence.
        let confidences: Vec<f64> = rules.iter().map(|rule| rule.confidence).collect();
        assert_non_increasing(&confidences);
    }

    #[test]
    fn test_support_calculation() {
        let transactions = sample_transactions();

        // {1, 2} occurs in 2 of 4 transactions.
        let itemset: HashSet<usize> = vec![1, 2].into_iter().collect();
        let support = Apriori::calculate_support(&itemset, &transactions);

        assert!((support - 0.5).abs() < 1e-10);

        // {1} occurs in 3 of 4 transactions.
        let itemset: HashSet<usize> = vec![1].into_iter().collect();
        let support = Apriori::calculate_support(&itemset, &transactions);

        assert!((support - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_calculation() {
        let transactions = sample_transactions();

        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0);
        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // confidence = support({1, 2}) / support({1}) = 0.5 / 0.75
        assert!((rule.confidence - 0.6666666).abs() < 1e-5);
    }

    #[test]
    fn test_lift_calculation() {
        let transactions = sample_transactions();

        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0);

        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // lift = confidence / support({2}) = 0.6667 / 0.75
        assert!((rule.lift - 0.8888888).abs() < 1e-5);
    }

    #[test]
    fn test_min_support_filter() {
        let transactions = vec![
            vec![1, 2],
            vec![1, 2],
            vec![1, 2],
            vec![3, 4],
        ];

        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Items 3 and 4 appear in only 25% of transactions.
        for (itemset, support) in itemsets {
            assert!(support >= 0.5, "All itemsets should meet min_support");
            assert!(
                !itemset.contains(&3) && !itemset.contains(&4),
                "Infrequent items should be pruned"
            );
        }
    }

    #[test]
    fn test_min_confidence_filter() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![1]];

        let mut apriori = Apriori::new()
            .with_min_support(0.25)
            .with_min_confidence(0.8);
        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        for rule in &rules {
            assert!(
                rule.confidence >= 0.8,
                "Rule {:?} => {:?} has confidence {:.2} < 0.8",
                rule.antecedent,
                rule.consequent,
                rule.confidence
            );
        }
    }

    #[test]
    fn test_empty_transactions() {
        let transactions: Vec<Vec<usize>> = vec![];

        let mut apriori = Apriori::new();
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no frequent itemsets");

        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules");
    }

    #[test]
    fn test_single_item_transactions() {
        let transactions = vec![vec![1], vec![2], vec![3], vec![4]];

        let mut apriori = Apriori::new().with_min_support(0.25);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Each item is frequent on its own, but no pair ever co-occurs.
        assert_eq!(itemsets.len(), 4);

        for (itemset, _) in itemsets {
            assert_eq!(itemset.len(), 1);
        }

        // Rules need itemsets of at least two items.
        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0);
    }

    #[test]
    fn test_get_rules_before_fit() {
        let apriori = Apriori::new();
        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules before fit");
    }

    #[test]
    fn test_get_itemsets_before_fit() {
        let apriori = Apriori::new();
        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no itemsets before fit");
    }
}