use std::collections::HashSet;
/// An association rule `antecedent => consequent`, as produced by [`Apriori`].
#[derive(Clone, Debug, PartialEq)]
pub struct AssociationRule {
    /// Items on the left-hand side of the rule.
    pub antecedent: Vec<usize>,
    /// Items on the right-hand side (the mined itemset minus the antecedent).
    pub consequent: Vec<usize>,
    /// Support of the whole itemset (antecedent together with consequent).
    pub support: f64,
    /// `support / support(antecedent)`.
    pub confidence: f64,
    /// `confidence / support(consequent)`.
    pub lift: f64,
}
/// Apriori frequent-itemset miner and association-rule generator.
///
/// Configure the thresholds with the `with_*` builders, call `fit`, then read
/// the results through the getters.
#[derive(Clone, Debug)]
pub struct Apriori {
    /// Minimum support (fraction of transactions) for an itemset to be frequent.
    min_support: f64,
    /// Minimum confidence for a generated rule to be kept.
    min_confidence: f64,
    /// Frequent itemsets paired with their support, filled in by `fit`.
    frequent_itemsets: Vec<(HashSet<usize>, f64)>,
    /// Association rules produced by the most recent `fit`.
    rules: Vec<AssociationRule>,
}
impl Apriori {
    /// Creates a miner with `min_support = 0.1`, `min_confidence = 0.5` and no
    /// mined results.
    #[must_use]
    pub fn new() -> Self {
        Self {
            min_support: 0.1,
            min_confidence: 0.5,
            frequent_itemsets: Vec::new(),
            rules: Vec::new(),
        }
    }

    /// Sets the minimum support an itemset needs to count as frequent.
    ///
    /// Support is the fraction of transactions that contain the itemset; an
    /// itemset is kept when its support is `>= min_support`.
    #[must_use]
    pub fn with_min_support(mut self, min_support: f64) -> Self {
        self.min_support = min_support;
        self
    }

    /// Sets the minimum confidence a rule needs to be kept (compared with `>=`).
    #[must_use]
    pub fn with_min_confidence(mut self, min_confidence: f64) -> Self {
        self.min_confidence = min_confidence;
        self
    }

    /// Returns every single-item itemset whose support is `>= min_support`,
    /// ordered by ascending item id.
    fn find_frequent_1_itemsets(&self, transactions: &[Vec<usize>]) -> Vec<(HashSet<usize>, f64)> {
        // An ordered map makes this level's order (and therefore the order of
        // every later level and of support ties after sorting) deterministic.
        use std::collections::BTreeMap;

        let mut item_counts: BTreeMap<usize, usize> = BTreeMap::new();
        for transaction in transactions {
            // Count each item at most once per transaction, matching
            // `calculate_support`; a repeated item must not inflate support.
            let distinct: HashSet<usize> = transaction.iter().copied().collect();
            for item in distinct {
                *item_counts.entry(item).or_insert(0) += 1;
            }
        }
        let n_transactions = transactions.len() as f64;
        item_counts
            .into_iter()
            .filter_map(|(item, count)| {
                let support = count as f64 / n_transactions;
                (support >= self.min_support).then(|| (HashSet::from([item]), support))
            })
            .collect()
    }

    /// Builds the (k+1)-item candidates from the frequent k-itemsets.
    ///
    /// Two k-itemsets are joined when their union has exactly k+1 items; the
    /// candidate is dropped if any of its k-subsets is not in `prev_itemsets`.
    /// Each candidate appears once, in first-generated order.
    fn generate_candidates(&self, prev_itemsets: &[(HashSet<usize>, f64)]) -> Vec<HashSet<usize>> {
        let mut candidates = Vec::new();
        // Sorted-item keys of already-seen unions: O(1) dedupe instead of a
        // linear `Vec::contains` over all candidates.
        let mut seen: HashSet<Vec<usize>> = HashSet::new();
        for (i, (set1, _)) in prev_itemsets.iter().enumerate() {
            for (set2, _) in &prev_itemsets[i + 1..] {
                let union: HashSet<usize> = set1.union(set2).copied().collect();
                // Join step: only pairs differing in exactly one item qualify.
                if union.len() != set1.len() + 1 {
                    continue;
                }
                let mut key: Vec<usize> = union.iter().copied().collect();
                key.sort_unstable();
                if !seen.insert(key) {
                    continue;
                }
                // Prune step: every k-subset of a frequent (k+1)-itemset is frequent.
                if self.has_infrequent_subset(&union, prev_itemsets) {
                    continue;
                }
                candidates.push(union);
            }
        }
        candidates
    }

    /// Returns `true` if removing any single item from `itemset` yields a set
    /// that is not among `prev_itemsets`.
    // Kept as a method (rather than an associated fn) for a stable private
    // call shape; it reads no instance state.
    #[allow(clippy::unused_self)]
    fn has_infrequent_subset(
        &self,
        itemset: &HashSet<usize>,
        prev_itemsets: &[(HashSet<usize>, f64)],
    ) -> bool {
        itemset.iter().any(|item| {
            let mut subset = itemset.clone();
            subset.remove(item);
            !prev_itemsets.iter().any(|(freq_set, _)| *freq_set == subset)
        })
    }

    /// Keeps the candidates whose support over `transactions` is
    /// `>= min_support`, paired with that support, in input order.
    fn prune_candidates(
        &self,
        candidates: Vec<HashSet<usize>>,
        transactions: &[Vec<usize>],
    ) -> Vec<(HashSet<usize>, f64)> {
        candidates
            .into_iter()
            .filter_map(|candidate| {
                let support = Self::calculate_support(&candidate, transactions);
                (support >= self.min_support).then_some((candidate, support))
            })
            .collect()
    }

    /// Replaces `self.rules` with every rule derivable from the frequent
    /// itemsets of size >= 2 whose confidence is `>= min_confidence`.
    ///
    /// For each split of an itemset into non-empty antecedent A and consequent C:
    /// `confidence = support(A ∪ C) / support(A)`, `lift = confidence / support(C)`.
    /// Antecedent and consequent vectors are in ascending item order.
    fn generate_rules(&mut self, transactions: &[Vec<usize>]) {
        let mut rules = Vec::new();
        for (itemset, itemset_support) in &self.frequent_itemsets {
            if itemset.len() < 2 {
                continue;
            }
            // Sorted so antecedents, and rule order among ties, are deterministic.
            let mut items: Vec<usize> = itemset.iter().copied().collect();
            items.sort_unstable();
            for antecedent_items in self.generate_subsets(&items) {
                // Both sides of a rule must be non-empty.
                if antecedent_items.is_empty() || antecedent_items.len() == items.len() {
                    continue;
                }
                let antecedent_set: HashSet<usize> = antecedent_items.iter().copied().collect();
                let antecedent_support = Self::calculate_support(&antecedent_set, transactions);
                let confidence = itemset_support / antecedent_support;
                // Written as `>=` (not `< … continue`) so a NaN confidence is rejected.
                if confidence >= self.min_confidence {
                    let consequent_set: HashSet<usize> =
                        itemset.difference(&antecedent_set).copied().collect();
                    let consequent_support = Self::calculate_support(&consequent_set, transactions);
                    let mut consequent: Vec<usize> = consequent_set.into_iter().collect();
                    consequent.sort_unstable();
                    rules.push(AssociationRule {
                        antecedent: antecedent_items,
                        consequent,
                        support: *itemset_support,
                        confidence,
                        lift: confidence / consequent_support,
                    });
                }
            }
        }
        self.rules = rules;
    }

    /// Returns every non-empty subset of `items` (including `items` itself).
    ///
    /// Subset `mask` (1..2^n, ascending) holds `items[i]` when bit `i` is set;
    /// items keep their relative order.
    ///
    /// # Panics
    ///
    /// Panics if `items` has 64 or more elements (the bitmask cannot represent
    /// them, and 2^64 subsets could not be enumerated anyway).
    // Kept as a method for a stable private call shape; it reads no instance state.
    #[allow(clippy::unused_self)]
    fn generate_subsets(&self, items: &[usize]) -> Vec<Vec<usize>> {
        let n = items.len();
        // A u64 mask avoids the i32 overflow that `1 << n` hit at 31+ items.
        assert!(n < 64, "cannot enumerate subsets of an itemset with {n} items");
        (1u64..(1u64 << n))
            .map(|mask| {
                items
                    .iter()
                    .enumerate()
                    .filter(|&(i, _)| (mask >> i) & 1 == 1)
                    .map(|(_, &item)| item)
                    .collect::<Vec<usize>>()
            })
            .collect()
    }

    /// Mines `transactions`, replacing any results from a previous call.
    ///
    /// Afterwards the frequent itemsets are sorted by support (descending) and
    /// the rules by confidence (descending); ties keep generation order.
    /// Empty input yields no itemsets and no rules.
    pub fn fit(&mut self, transactions: &[Vec<usize>]) {
        self.frequent_itemsets = Vec::new();
        self.rules = Vec::new();
        if transactions.is_empty() {
            return;
        }
        let mut current_itemsets = self.find_frequent_1_itemsets(transactions);
        while !current_itemsets.is_empty() {
            let candidates = self.generate_candidates(&current_itemsets);
            // Move this level into the results; no clone needed.
            self.frequent_itemsets.append(&mut current_itemsets);
            if candidates.is_empty() {
                break;
            }
            current_itemsets = self.prune_candidates(candidates, transactions);
        }
        self.generate_rules(transactions);
        // `total_cmp` gives the same order as `partial_cmp` for these
        // non-negative values without a panic path; `sort_by` is stable.
        self.frequent_itemsets.sort_by(|a, b| b.1.total_cmp(&a.1));
        self.rules
            .sort_by(|a, b| b.confidence.total_cmp(&a.confidence));
    }

    /// Returns the frequent itemsets from the last `fit` with their support.
    /// Items within each itemset are in ascending order.
    #[must_use]
    pub fn get_frequent_itemsets(&self) -> Vec<(Vec<usize>, f64)> {
        self.frequent_itemsets
            .iter()
            .map(|(itemset, support)| {
                let mut items: Vec<usize> = itemset.iter().copied().collect();
                items.sort_unstable();
                (items, *support)
            })
            .collect()
    }

    /// Returns a copy of the rules produced by the last `fit`.
    #[must_use]
    pub fn get_rules(&self) -> Vec<AssociationRule> {
        self.rules.clone()
    }

    /// Fraction of `transactions` that contain every item of `itemset`.
    ///
    /// Returns `0.0` for no transactions; an empty itemset has support `1.0`.
    #[must_use]
    pub fn calculate_support(itemset: &HashSet<usize>, transactions: &[Vec<usize>]) -> f64 {
        if transactions.is_empty() {
            return 0.0;
        }
        let count = transactions
            .iter()
            .filter(|transaction| itemset.iter().all(|item| transaction.contains(item)))
            .count();
        count as f64 / transactions.len() as f64
    }
}
impl Default for Apriori {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
#[path = "mining_tests.rs"]
mod tests;