1use std::collections::HashSet;
39
/// An association rule `antecedent => consequent` mined from frequent itemsets.
///
/// Both item lists are sorted in ascending order and are disjoint.
#[derive(Debug, Clone, PartialEq)]
pub struct AssociationRule {
    /// Items on the left-hand side of the rule (ascending).
    pub antecedent: Vec<usize>,
    /// Items on the right-hand side of the rule (ascending).
    pub consequent: Vec<usize>,
    /// Support of `antecedent ∪ consequent`: fraction of transactions containing all its items.
    pub support: f64,
    /// `support(antecedent ∪ consequent) / support(antecedent)`.
    pub confidence: f64,
    /// `confidence / support(consequent)`.
    pub lift: f64,
}

/// Apriori frequent-itemset miner and association-rule generator.
///
/// Configure thresholds with [`Apriori::with_min_support`] and
/// [`Apriori::with_min_confidence`], call [`Apriori::fit`], then read the results with
/// [`Apriori::get_frequent_itemsets`] and [`Apriori::get_rules`].
#[derive(Debug, Clone)]
pub struct Apriori {
    /// Minimum fraction of transactions an itemset must appear in (default 0.1).
    min_support: f64,
    /// Minimum confidence a rule must reach to be kept (default 0.5).
    min_confidence: f64,
    /// Frequent itemsets with their support, sorted by descending support after `fit`.
    frequent_itemsets: Vec<(HashSet<usize>, f64)>,
    /// Rules found by the last `fit`, sorted by descending confidence.
    rules: Vec<AssociationRule>,
}

impl Apriori {
    /// Creates a miner with `min_support = 0.1`, `min_confidence = 0.5` and no results.
    pub fn new() -> Self {
        Self {
            min_support: 0.1,
            min_confidence: 0.5,
            frequent_itemsets: Vec::new(),
            rules: Vec::new(),
        }
    }

    /// Sets the minimum support: the fraction of transactions an itemset must appear in
    /// to be considered frequent.
    pub fn with_min_support(mut self, min_support: f64) -> Self {
        self.min_support = min_support;
        self
    }

    /// Sets the minimum confidence a rule must reach to be reported.
    pub fn with_min_confidence(mut self, min_confidence: f64) -> Self {
        self.min_confidence = min_confidence;
        self
    }

    /// Returns the items of `itemset` as an ascending `Vec`.
    ///
    /// `HashSet<usize>` is not `Hash`, so sorted vectors serve as canonical keys for
    /// deduplication and support lookup, and give callers a deterministic item order.
    fn sorted_items(itemset: &HashSet<usize>) -> Vec<usize> {
        let mut items: Vec<usize> = itemset.iter().copied().collect();
        items.sort_unstable();
        items
    }

    /// Finds all frequent 1-itemsets, ordered by ascending item id.
    ///
    /// Each item is counted at most once per transaction, so the supports agree with
    /// [`Apriori::calculate_support`] even when a transaction lists an item twice.
    fn find_frequent_1_itemsets(
        &self,
        transactions: &[Vec<usize>],
    ) -> Vec<(HashSet<usize>, f64)> {
        use std::collections::HashMap;

        let mut item_counts: HashMap<usize, usize> = HashMap::new();
        for transaction in transactions {
            let distinct: HashSet<usize> = transaction.iter().copied().collect();
            for item in distinct {
                *item_counts.entry(item).or_insert(0) += 1;
            }
        }

        let n_transactions = transactions.len() as f64;
        let mut frequent: Vec<(usize, f64)> = item_counts
            .into_iter()
            .map(|(item, count)| (item, count as f64 / n_transactions))
            .filter(|&(_, support)| support >= self.min_support)
            .collect();
        // HashMap iteration order is unspecified; sorting makes every later stage deterministic.
        frequent.sort_unstable_by_key(|&(item, _)| item);

        frequent
            .into_iter()
            .map(|(item, support)| (std::iter::once(item).collect::<HashSet<usize>>(), support))
            .collect()
    }

    /// Joins frequent k-itemsets into candidate (k+1)-itemsets.
    ///
    /// Two k-itemsets are joined when their union has exactly k+1 items. The union is
    /// dropped if one of its k-subsets is not frequent (Apriori pruning) or if it was
    /// already produced. Candidate order follows the order of `prev_itemsets`.
    fn generate_candidates(prev_itemsets: &[(HashSet<usize>, f64)]) -> Vec<HashSet<usize>> {
        let frequent_keys: HashSet<Vec<usize>> = prev_itemsets
            .iter()
            .map(|(itemset, _)| Self::sorted_items(itemset))
            .collect();
        let mut seen: HashSet<Vec<usize>> = HashSet::new();
        let mut candidates = Vec::new();

        for (i, (set1, _)) in prev_itemsets.iter().enumerate() {
            for (set2, _) in &prev_itemsets[i + 1..] {
                let union: HashSet<usize> = set1.union(set2).copied().collect();
                if union.len() != set1.len() + 1 {
                    continue;
                }
                if Self::has_infrequent_subset(&union, &frequent_keys) {
                    continue;
                }
                if seen.insert(Self::sorted_items(&union)) {
                    candidates.push(union);
                }
            }
        }

        candidates
    }

    /// Returns `true` if removing any single item from `itemset` yields a set whose
    /// sorted key is not contained in `frequent_keys`.
    fn has_infrequent_subset(
        itemset: &HashSet<usize>,
        frequent_keys: &HashSet<Vec<usize>>,
    ) -> bool {
        let items = Self::sorted_items(itemset);
        (0..items.len()).any(|skip| {
            // Dropping one element of a sorted Vec keeps it sorted, i.e. a valid key.
            let subset: Vec<usize> = items
                .iter()
                .enumerate()
                .filter(|&(i, _)| i != skip)
                .map(|(_, &item)| item)
                .collect();
            !frequent_keys.contains(&subset)
        })
    }

    /// Keeps the candidates whose support is positive and at least `min_support`.
    ///
    /// Itemsets that never occur are rejected even when `min_support` is 0. Otherwise every
    /// item combination would be "frequent" and rule metrics would divide by zero.
    fn prune_candidates(
        &self,
        candidates: Vec<HashSet<usize>>,
        transactions: &[Vec<usize>],
    ) -> Vec<(HashSet<usize>, f64)> {
        let mut frequent = Vec::new();
        for candidate in candidates {
            let support = Self::calculate_support(&candidate, transactions);
            if support > 0.0 && support >= self.min_support {
                frequent.push((candidate, support));
            }
        }
        frequent
    }

    /// Builds association rules from every frequent itemset with at least two items.
    ///
    /// Antecedent and consequent supports are looked up among the frequent itemsets,
    /// since every subset of a frequent itemset is itself frequent. `transactions` is only
    /// rescanned as a fallback for a subset that is unexpectedly missing.
    fn generate_rules(&mut self, transactions: &[Vec<usize>]) {
        use std::collections::HashMap;

        let support_by_key: HashMap<Vec<usize>, f64> = self
            .frequent_itemsets
            .iter()
            .map(|(itemset, support)| (Self::sorted_items(itemset), *support))
            .collect();
        let support_of = |items: &[usize]| -> f64 {
            support_by_key.get(items).copied().unwrap_or_else(|| {
                let itemset: HashSet<usize> = items.iter().copied().collect();
                Self::calculate_support(&itemset, transactions)
            })
        };

        let mut rules = Vec::new();
        for (itemset, itemset_support) in &self.frequent_itemsets {
            if itemset.len() < 2 {
                continue;
            }
            let itemset_support = *itemset_support;
            let items = Self::sorted_items(itemset);

            for antecedent in Self::proper_nonempty_subsets(&items) {
                let consequent: Vec<usize> = items
                    .iter()
                    .copied()
                    .filter(|item| !antecedent.contains(item))
                    .collect();

                let confidence = itemset_support / support_of(&antecedent[..]);
                if confidence >= self.min_confidence {
                    let lift = confidence / support_of(&consequent[..]);
                    rules.push(AssociationRule {
                        antecedent,
                        consequent,
                        support: itemset_support,
                        confidence,
                        lift,
                    });
                }
            }
        }

        self.rules = rules;
    }

    /// Returns every non-empty proper subset of `items`, preserving the item order.
    ///
    /// # Panics
    ///
    /// Panics if `items` has 64 or more elements, because the subset masks would not fit in `u64`.
    fn proper_nonempty_subsets(items: &[usize]) -> Vec<Vec<usize>> {
        let n = items.len();
        assert!(n < 64, "cannot enumerate subsets of an itemset with {} items", n);
        // Masks in 1..full exclude both the empty set (0) and the full set (full).
        let full: u64 = (1u64 << n) - 1;
        (1..full)
            .map(|mask| {
                items
                    .iter()
                    .enumerate()
                    .filter(|&(i, _)| (mask & (1u64 << i)) != 0)
                    .map(|(_, &item)| item)
                    .collect::<Vec<usize>>()
            })
            .collect()
    }

    /// Mines frequent itemsets and association rules from `transactions`.
    ///
    /// Results from any previous call are discarded. Afterwards frequent itemsets are
    /// sorted by descending support and rules by descending confidence. Both sorts are
    /// stable over a deterministic generation order, so ties come out the same every run.
    pub fn fit(&mut self, transactions: &[Vec<usize>]) {
        self.frequent_itemsets = Vec::new();
        self.rules = Vec::new();
        if transactions.is_empty() {
            return;
        }

        // Level-wise search: the frequent k-itemsets seed the (k+1)-candidates.
        let mut level = self.find_frequent_1_itemsets(transactions);
        while !level.is_empty() {
            let candidates = Self::generate_candidates(&level);
            self.frequent_itemsets.append(&mut level);
            level = self.prune_candidates(candidates, transactions);
        }

        self.generate_rules(transactions);

        // Supports and confidences are finite ratios of positive counts, so NaN cannot occur.
        self.frequent_itemsets.sort_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .expect("Support values must be valid f64 (not NaN)")
        });
        self.rules.sort_by(|a, b| {
            b.confidence
                .partial_cmp(&a.confidence)
                .expect("Confidence values must be valid f64 (not NaN)")
        });
    }

    /// Returns the frequent itemsets from the last `fit` as `(items, support)` pairs.
    ///
    /// The items are sorted ascending. The result is empty before `fit` is called.
    pub fn get_frequent_itemsets(&self) -> Vec<(Vec<usize>, f64)> {
        self.frequent_itemsets
            .iter()
            .map(|(itemset, support)| (Self::sorted_items(itemset), *support))
            .collect()
    }

    /// Returns a copy of the association rules found by the last `fit`.
    pub fn get_rules(&self) -> Vec<AssociationRule> {
        self.rules.clone()
    }

    /// Returns the fraction of `transactions` that contain every item of `itemset`.
    ///
    /// Returns 0.0 for an empty transaction list. An empty `itemset` is contained in
    /// every transaction, so its support is 1.0.
    pub fn calculate_support(itemset: &HashSet<usize>, transactions: &[Vec<usize>]) -> f64 {
        if transactions.is_empty() {
            return 0.0;
        }

        let count = transactions
            .iter()
            .filter(|transaction| itemset.iter().all(|item| transaction.contains(item)))
            .count();

        count as f64 / transactions.len() as f64
    }
}

impl Default for Apriori {
    fn default() -> Self {
        Self::new()
    }
}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_apriori_new() {
        let apriori = Apriori::new();
        assert_eq!(apriori.min_support, 0.1);
        assert_eq!(apriori.min_confidence, 0.5);
        assert!(apriori.frequent_itemsets.is_empty());
        assert!(apriori.rules.is_empty());
    }

    #[test]
    fn test_apriori_with_min_support() {
        let apriori = Apriori::new().with_min_support(0.3);
        assert_eq!(apriori.min_support, 0.3);
    }

    #[test]
    fn test_apriori_with_min_confidence() {
        let apriori = Apriori::new().with_min_confidence(0.7);
        assert_eq!(apriori.min_confidence, 0.7);
    }

    #[test]
    fn test_apriori_fit_basic() {
        // Supports: {1}, {2}, {3} = 0.8; {4} = 0.2; each pair = 0.6; {1, 2, 3} = 0.4.
        let transactions = vec![
            vec![1, 2, 3],
            vec![1, 2],
            vec![1, 3],
            vec![2, 3],
            vec![1, 2, 3, 4],
        ];

        let mut apriori = Apriori::new()
            .with_min_support(0.4)
            .with_min_confidence(0.6);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();
        // Three singletons, three pairs and the triple; item 4 is infrequent.
        assert_eq!(itemsets.len(), 7);
        assert!(itemsets.iter().all(|(items, _)| !items.contains(&4)));

        // Six `single => single` rules (confidence 0.75) plus three `pair => single`
        // rules from the triple (confidence 2/3); `single => pair` only reaches 0.5.
        assert_eq!(apriori.get_rules().len(), 9);
    }

    #[test]
    fn test_frequent_itemsets() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]];

        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Singletons have support 0.75 and pairs 0.5; the triple (0.25) is pruned.
        assert_eq!(itemsets.len(), 6);
        assert_eq!(itemsets.iter().filter(|(items, _)| items.len() == 1).count(), 3);
        assert_eq!(itemsets.iter().filter(|(items, _)| items.len() == 2).count(), 3);

        for pair in itemsets.windows(2) {
            assert!(pair[0].1 >= pair[1].1, "itemsets must be sorted by support");
        }
    }

    #[test]
    fn test_association_rules() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]];

        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.6);
        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        // Every ordered pair of distinct items yields a rule with confidence 2/3.
        assert_eq!(rules.len(), 6);
        assert!(rules.iter().all(|rule| rule.confidence >= 0.6));
        for pair in rules.windows(2) {
            assert!(
                pair[0].confidence >= pair[1].confidence,
                "rules must be sorted by confidence"
            );
        }
    }

    #[test]
    fn test_support_calculation() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]];

        let itemset: HashSet<usize> = vec![1, 2].into_iter().collect();
        let support = Apriori::calculate_support(&itemset, &transactions);
        assert!((support - 0.5).abs() < 1e-10);

        let itemset: HashSet<usize> = vec![1].into_iter().collect();
        let support = Apriori::calculate_support(&itemset, &transactions);
        assert!((support - 0.75).abs() < 1e-10);
    }

    #[test]
    fn test_confidence_calculation() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]];

        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0);
        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // support({1, 2}) / support({1}) = 0.5 / 0.75.
        assert!((rule.confidence - 2.0 / 3.0).abs() < 1e-5);
        assert!((rule.support - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_lift_calculation() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![2, 3]];

        let mut apriori = Apriori::new()
            .with_min_support(0.5)
            .with_min_confidence(0.0);
        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        let rule = rules
            .iter()
            .find(|r| r.antecedent == vec![1] && r.consequent == vec![2])
            .expect("Should have rule {1} => {2}");

        // confidence / support({2}) = (2/3) / 0.75.
        assert!((rule.lift - 8.0 / 9.0).abs() < 1e-5);
    }

    #[test]
    fn test_min_support_filter() {
        let transactions = vec![vec![1, 2], vec![1, 2], vec![1, 2], vec![3, 4]];

        let mut apriori = Apriori::new().with_min_support(0.5);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Only {1}, {2} and {1, 2} reach support 0.75.
        assert_eq!(itemsets.len(), 3);
        for (itemset, support) in itemsets {
            assert!(support >= 0.5, "All itemsets should meet min_support");
            assert!(
                !itemset.contains(&3) && !itemset.contains(&4),
                "Infrequent items should be pruned"
            );
        }
    }

    #[test]
    fn test_min_confidence_filter() {
        let transactions = vec![vec![1, 2, 3], vec![1, 2], vec![1, 3], vec![1]];

        let mut apriori = Apriori::new()
            .with_min_support(0.25)
            .with_min_confidence(0.8);
        apriori.fit(&transactions);
        let rules = apriori.get_rules();

        for rule in &rules {
            assert!(
                rule.confidence >= 0.8,
                "Rule {:?} => {:?} has confidence {:.2} < 0.8",
                rule.antecedent,
                rule.consequent,
                rule.confidence
            );
        }
    }

    #[test]
    fn test_empty_transactions() {
        let transactions: Vec<Vec<usize>> = vec![];

        let mut apriori = Apriori::new();
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no frequent itemsets");

        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules");
    }

    #[test]
    fn test_single_item_transactions() {
        let transactions = vec![vec![1], vec![2], vec![3], vec![4]];

        let mut apriori = Apriori::new().with_min_support(0.25);
        apriori.fit(&transactions);

        let itemsets = apriori.get_frequent_itemsets();

        // Each item appears in exactly one transaction; no two items co-occur.
        assert_eq!(itemsets.len(), 4);
        for (itemset, _) in itemsets {
            assert_eq!(itemset.len(), 1);
        }

        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0);
    }

    #[test]
    fn test_get_rules_before_fit() {
        let apriori = Apriori::new();
        let rules = apriori.get_rules();
        assert_eq!(rules.len(), 0, "Should have no rules before fit");
    }

    #[test]
    fn test_get_itemsets_before_fit() {
        let apriori = Apriori::new();
        let itemsets = apriori.get_frequent_itemsets();
        assert_eq!(itemsets.len(), 0, "Should have no itemsets before fit");
    }
}
672}