use std::collections::HashMap;
use super::{SyntheticConfig, SyntheticGenerator};
use crate::error::Result;
/// Outcome of applying a single labeling function to a sample.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelVote {
    /// Vote for the positive class (label `1`).
    Positive,
    /// Vote for the negative class (label `0`).
    Negative,
    /// Vote for an explicit class id.
    Class(i32),
    /// No opinion about this sample.
    Abstain,
}

impl LabelVote {
    /// Returns `true` only for [`LabelVote::Abstain`].
    #[must_use]
    pub fn is_abstain(&self) -> bool {
        *self == Self::Abstain
    }

    /// Converts the vote to its integer label.
    ///
    /// `Positive` maps to `1`, `Negative` to `0`, `Class(c)` to `c`, and
    /// `Abstain` yields `None`.
    #[must_use]
    pub fn to_label(&self) -> Option<i32> {
        let label = match *self {
            Self::Abstain => return None,
            Self::Class(class_id) => class_id,
            Self::Negative => 0,
            Self::Positive => 1,
        };
        Some(label)
    }
}
/// A named rule that casts a [`LabelVote`] for samples of type `T`.
///
/// Implementors must be `Send + Sync`.
pub trait LabelingFunction<T>: Send + Sync {
/// Returns the name of this labeling function.
fn name(&self) -> &str;
/// Applies the function to `sample` and returns its vote.
fn apply(&self, sample: &T) -> LabelVote;
/// Returns the weight of this function's vote; defaults to `1.0`.
fn weight(&self) -> f32 {
1.0
}
}
/// How the non-abstaining votes for one sample are combined into a label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AggregationStrategy {
/// Pick the label with the most votes; confidence is that label's share
/// of the non-abstaining votes. This is the default variant.
#[default]
MajorityVote,
/// Pick the label with the largest summed vote weight; confidence is that
/// label's share of the total weight.
WeightedVote,
/// Produce a label only when every non-abstaining vote agrees
/// (confidence `1.0`).
Unanimous,
/// Take the label of the first non-abstaining vote; confidence is
/// `1 / number of non-abstaining votes`.
Any,
}
/// A sample paired with an aggregated label and the votes behind it.
#[derive(Debug, Clone)]
pub struct LabeledSample<T> {
    /// The labeled sample itself.
    pub sample: T,
    /// Integer label assigned to the sample.
    pub label: i32,
    /// Confidence attached to `label`.
    pub confidence: f32,
    /// Vote count recorded via [`LabeledSample::with_votes`]; `0` after `new`.
    // NOTE(review): presumably the number of non-abstaining votes; confirm against the caller.
    pub num_votes: usize,
    /// `(labeling function name, vote)` pairs; empty after `new`.
    pub votes: Vec<(String, LabelVote)>,
}

impl<T> LabeledSample<T> {
    /// Creates a labeled sample with no recorded votes (`num_votes == 0`,
    /// `votes` empty).
    pub fn new(sample: T, label: i32, confidence: f32) -> Self {
        let votes = Vec::new();
        Self {
            sample,
            label,
            confidence,
            num_votes: 0,
            votes,
        }
    }

    /// Returns `self` with `num_votes` and `votes` replaced by the given
    /// values; all other fields are kept.
    pub fn with_votes(self, num_votes: usize, votes: Vec<(String, LabelVote)>) -> Self {
        Self {
            num_votes,
            votes,
            ..self
        }
    }
}
/// Settings that control how labeling-function votes are aggregated.
#[derive(Debug, Clone)]
pub struct WeakSupervisionConfig {
    /// Strategy used to combine the non-abstaining votes.
    pub aggregation: AggregationStrategy,
    /// Minimum confidence a majority or weighted result must reach.
    pub min_confidence: f32,
    /// Minimum number of non-abstaining votes needed to aggregate at all.
    pub min_votes: usize,
    /// When `true`, samples with too few votes receive `default_label`
    /// (with confidence `0.0`) instead of no label.
    pub include_abstained: bool,
    /// Label used for under-voted samples when `include_abstained` is set.
    pub default_label: i32,
}

impl Default for WeakSupervisionConfig {
    /// Majority vote, `min_confidence = 0.5`, `min_votes = 1`, under-voted
    /// samples excluded, `default_label = 0`.
    fn default() -> Self {
        Self {
            aggregation: AggregationStrategy::MajorityVote,
            min_confidence: 0.5,
            min_votes: 1,
            include_abstained: false,
            default_label: 0,
        }
    }
}

impl WeakSupervisionConfig {
    /// Creates a configuration with the default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the aggregation strategy.
    #[must_use]
    pub fn with_aggregation(mut self, strategy: AggregationStrategy) -> Self {
        self.aggregation = strategy;
        self
    }

    /// Sets the minimum confidence, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN argument is ignored and the current threshold is kept.
    #[must_use]
    pub fn with_min_confidence(mut self, confidence: f32) -> Self {
        // `f32::clamp` passes NaN through unchanged, and a NaN threshold makes
        // every `confidence >= min_confidence` comparison false, which would
        // silently reject all labels. Keep the previous, valid threshold.
        if !confidence.is_nan() {
            self.min_confidence = confidence.clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the minimum number of non-abstaining votes; values below `1` are
    /// raised to `1`.
    #[must_use]
    pub fn with_min_votes(mut self, votes: usize) -> Self {
        self.min_votes = votes.max(1);
        self
    }

    /// Sets whether under-voted samples are kept, and the label they receive.
    #[must_use]
    pub fn with_include_abstained(mut self, include: bool, default_label: i32) -> Self {
        self.include_abstained = include;
        self.default_label = default_label;
        self
    }
}
/// Holds a list of labeling functions together with the configuration used
/// to aggregate their votes.
pub struct WeakSupervisionGenerator<T> {
    /// Registered labeling functions.
    labeling_functions: Vec<Box<dyn LabelingFunction<T>>>,
    /// Aggregation settings.
    config: WeakSupervisionConfig,
}

impl<T> std::fmt::Debug for WeakSupervisionGenerator<T> {
    /// Prints the number of labeling functions (as `num_lfs`) and the
    /// configuration; the boxed functions themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let lf_count = self.labeling_functions.len();
        let mut out = f.debug_struct("WeakSupervisionGenerator");
        out.field("num_lfs", &lf_count);
        out.field("config", &self.config);
        out.finish()
    }
}
impl<T> WeakSupervisionGenerator<T> {
    /// Creates a generator with no labeling functions and the default
    /// [`WeakSupervisionConfig`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            labeling_functions: Vec::new(),
            config: WeakSupervisionConfig::default(),
        }
    }

    /// Replaces the aggregation configuration and returns the generator.
    #[must_use]
    pub fn with_config(mut self, config: WeakSupervisionConfig) -> Self {
        self.config = config;
        self
    }

    /// Appends a labeling function; functions are applied in the order they
    /// were added.
    pub fn add_lf(&mut self, lf: Box<dyn LabelingFunction<T>>) {
        self.labeling_functions.push(lf);
    }

    /// Returns the number of registered labeling functions.
    #[must_use]
    pub fn num_lfs(&self) -> usize {
        self.labeling_functions.len()
    }

    /// Applies every labeling function to `sample` and returns one
    /// `(name, vote, weight)` triple per function, in registration order.
    fn collect_votes(&self, sample: &T) -> Vec<(String, LabelVote, f32)> {
        self.labeling_functions
            .iter()
            .map(|lf| (lf.name().to_string(), lf.apply(sample), lf.weight()))
            .collect()
    }

    /// Combines `votes` into `(label, confidence)` using the configured
    /// strategy.
    ///
    /// Abstentions are discarded first. If fewer than `min_votes` votes
    /// remain, the result is `(default_label, 0.0)` when `include_abstained`
    /// is set and `None` otherwise. `None` is also returned whenever the
    /// selected strategy produces no label.
    fn aggregate_votes(&self, votes: &[(String, LabelVote, f32)]) -> Option<(i32, f32)> {
        let non_abstain: Vec<_> = votes.iter().filter(|(_, v, _)| !v.is_abstain()).collect();
        if non_abstain.len() < self.config.min_votes {
            return if self.config.include_abstained {
                Some((self.config.default_label, 0.0))
            } else {
                None
            };
        }
        match self.config.aggregation {
            AggregationStrategy::MajorityVote => self.majority_vote(&non_abstain),
            AggregationStrategy::WeightedVote => self.weighted_vote(&non_abstain),
            AggregationStrategy::Unanimous => Self::unanimous_vote(&non_abstain),
            AggregationStrategy::Any => Self::any_vote(&non_abstain),
        }
    }

    /// Returns the most frequent label and its share of `votes`, or `None`
    /// when `votes` is empty or the share is below `min_confidence`.
    ///
    /// Ties are broken deterministically in favor of the smallest label id.
    fn majority_vote(&self, votes: &[&(String, LabelVote, f32)]) -> Option<(i32, f32)> {
        if votes.is_empty() {
            return None;
        }
        let mut counts: HashMap<i32, usize> = HashMap::new();
        for (_, vote, _) in votes {
            if let Some(label) = vote.to_label() {
                *counts.entry(label).or_insert(0) += 1;
            }
        }
        let total = votes.len();
        // `HashMap` iteration order is arbitrary, so comparing on the count
        // alone would let tied labels win at random. The secondary key
        // (reversed label comparison) makes the smallest label win a tie.
        let (label, count) = counts
            .into_iter()
            .max_by(|(l1, c1), (l2, c2)| c1.cmp(c2).then_with(|| l2.cmp(l1)))?;
        let confidence = count as f32 / total as f32;
        if confidence >= self.config.min_confidence {
            Some((label, confidence))
        } else {
            None
        }
    }

    /// Returns the label with the largest summed weight and its share of the
    /// total weight, or `None` when `votes` is empty, the total weight is
    /// below `f32::EPSILON`, or the share is below `min_confidence`.
    ///
    /// Ties are broken deterministically in favor of the smallest label id.
    fn weighted_vote(&self, votes: &[&(String, LabelVote, f32)]) -> Option<(i32, f32)> {
        if votes.is_empty() {
            return None;
        }
        let mut weights: HashMap<i32, f32> = HashMap::new();
        let mut total_weight = 0.0;
        for (_, vote, weight) in votes {
            if let Some(label) = vote.to_label() {
                *weights.entry(label).or_insert(0.0) += weight;
                total_weight += weight;
            }
        }
        if total_weight < f32::EPSILON {
            return None;
        }
        // Same tie-break as `majority_vote`: equal (or incomparable) weights
        // fall through to the label comparison so the result does not depend
        // on `HashMap` iteration order. A NaN weight makes `total_weight` NaN,
        // so the confidence check below rejects the result either way.
        let (label, weight) = weights.into_iter().max_by(|(l1, w1), (l2, w2)| {
            w1.partial_cmp(w2)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| l2.cmp(l1))
        })?;
        let confidence = weight / total_weight;
        if confidence >= self.config.min_confidence {
            Some((label, confidence))
        } else {
            None
        }
    }

    /// Returns `(label, 1.0)` when every vote carries the same label, and
    /// `None` when `votes` is empty or the votes disagree.
    fn unanimous_vote(votes: &[&(String, LabelVote, f32)]) -> Option<(i32, f32)> {
        let first_label = votes.first()?.1.to_label()?;
        let unanimous = votes
            .iter()
            .all(|(_, v, _)| v.to_label() == Some(first_label));
        if unanimous {
            Some((first_label, 1.0))
        } else {
            None
        }
    }

    /// Returns the label of the first vote that has one, with confidence
    /// `1 / votes.len()`; `None` when no vote carries a label.
    fn any_vote(votes: &[&(String, LabelVote, f32)]) -> Option<(i32, f32)> {
        votes
            .iter()
            .find_map(|(_, vote, _)| vote.to_label())
            .map(|label| (label, 1.0 / votes.len() as f32))
    }
}
/// `Default` delegates to [`WeakSupervisionGenerator::new`].
impl<T> Default for WeakSupervisionGenerator<T> {
/// Returns the same value as `Self::new()`.
fn default() -> Self {
Self::new()
}
}
include!("labeling.rs");
include!("weak_supervision_tests.rs");