1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
use super::bitpacked::BloomFilter;
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::tree::decision_tree_classifier::DecisionTreeClassifier;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Type alias for the learned model: a smartcore decision tree classifying
/// f64 feature rows into u32 labels (1 = in set, 0 = not in set).
type Model = DecisionTreeClassifier<f64, u32, DenseMatrix<f64>, Vec<u32>>;
/// Learned Bloom Filter
///
/// Uses a machine learning model to predict set membership, with a backup
/// traditional bloom filter for uncertain cases. The backup filter guarantees
/// no false negatives; the model reduces the space needed for the filter.
///
/// Space savings: Model + small backup filter vs large traditional filter
pub struct LearnedBloomFilter {
/// Trained decision tree model; `None` until `train()` has been called
model: Option<Model>,
/// Backup bloom filter holding positives the model is uncertain about
backup_filter: BloomFilter,
/// Confidence threshold (0.0-1.0); predictions at or above this are trusted
threshold: f64,
/// Number of positive elements seen during training
count: usize,
/// Feature dimension (number of hash features extracted from keys)
feature_dim: usize,
}
impl LearnedBloomFilter {
/// Create a new learned bloom filter.
///
/// # Arguments
/// * `expected_elements` - Expected number of elements
/// * `false_positive_rate` - Target false positive rate
/// * `threshold` - Confidence threshold (0.0-1.0, higher = trust model more)
#[must_use]
pub fn new(expected_elements: usize, false_positive_rate: f64, threshold: f64) -> Self {
// Backup filter is much smaller since the model handles most queries.
// Use a higher FPR for the backup since it only serves uncertain cases.
let backup_fpr = false_positive_rate * 2.0;
// 30% of expected capacity, clamped to at least 1: for
// `expected_elements < 4` the truncated product is 0, and a
// zero-element bloom filter has degenerate sizing.
let backup_elements = ((expected_elements as f64 * 0.3) as usize).max(1);
Self {
model: None,
backup_filter: BloomFilter::new(backup_elements, backup_fpr),
threshold,
count: 0,
feature_dim: 8, // Use 8 hash features per key
}
}
/// Train the model on positive and negative examples.
///
/// Positive examples the model misclassifies (or classifies with low
/// confidence) are inserted into the backup filter, which is what
/// guarantees `contains` never returns a false negative.
///
/// # Arguments
/// * `positive_examples` - Keys that are in the set
/// * `negative_examples` - Keys that are NOT in the set
///
/// # Panics
/// Panics if the underlying decision tree fails to fit the data.
pub fn train<T: Hash>(&mut self, positive_examples: &[T], negative_examples: &[T]) {
let n_positive = positive_examples.len();
let n_negative = negative_examples.len();
let n_total = n_positive + n_negative;
if n_total == 0 {
return;
}
// Extract hash features and build labels in parallel vectors.
let mut features = Vec::with_capacity(n_total);
let mut labels = Vec::with_capacity(n_total);
// Positive examples (label = 1)
for key in positive_examples {
features.push(self.extract_features(key));
labels.push(1);
self.count += 1;
}
// Negative examples (label = 0)
for key in negative_examples {
features.push(self.extract_features(key));
labels.push(0);
}
// Convert to 2D matrix (row-major: each row is one sample)
let x = DenseMatrix::from_2d_vec(&features);
let y: Vec<u32> = labels;
// Train decision tree. A fit failure here is a programming/data bug
// (e.g. inconsistent dimensions), so panicking is acceptable.
let model = DecisionTreeClassifier::fit(&x, &y, Default::default())
.expect("Failed to train decision tree");
self.model = Some(model);
// Only add uncertain/misclassified positives to backup filter for space savings.
// The model handles confident predictions, backup catches uncertain cases.
for key in positive_examples {
let should_add_to_backup =
if let Some((prediction, confidence)) = self.predict_with_confidence(key) {
// Add to backup if model predicts "not in set" or has low confidence
!prediction || confidence < self.threshold
} else {
// No prediction available - add to be safe
true
};
if should_add_to_backup {
self.backup_filter.insert(key);
}
}
}
/// Check if an element might be in the set.
///
/// May return false positives (like any bloom filter) but never false
/// negatives: every trained positive is either confidently recognized by
/// the model or was inserted into the backup filter during `train()`.
#[inline]
pub fn contains<T: Hash>(&self, item: &T) -> bool {
// Check backup filter first for positive cases (guarantees no false negatives)
if self.backup_filter.contains(item) {
return true;
}
// If not in backup filter, use model for negative predictions
if let Some((prediction, confidence)) = self.predict_with_confidence(item) {
if confidence >= self.threshold {
// High confidence: trust the model's verdict either way
return prediction;
}
}
// Low confidence or no model: assume not in set (conservative)
false
}
/// Predict if item is in set, with confidence score.
///
/// Returns `(prediction, confidence)` where prediction is true/false and confidence is 0.0-1.0.
/// Returns `None` if no model has been trained or prediction fails.
///
/// # Confidence Model Limitation
///
/// Decision trees provide binary predictions without probability estimates.
/// We return a fixed confidence of 0.9 for all predictions. This is acceptable because:
///
/// 1. The backup bloom filter catches all false negatives (no correctness impact)
/// 2. Uncertain positives are added to backup during training (see `train()`)
/// 3. The confidence threshold primarily affects space optimization, not correctness
///
/// For true probability estimates, consider using random forests or gradient boosting.
fn predict_with_confidence<T: Hash>(&self, item: &T) -> Option<(bool, f64)> {
let model = self.model.as_ref()?;
let features = self.extract_features(item);
// Single-row matrix: the model predicts on batches.
let rows = vec![features];
let x = DenseMatrix::from_2d_vec(&rows);
// Predict (1 = in set, 0 = not in set)
let prediction = model.predict(&x).ok()?;
// Fixed high confidence; the backup filter ensures correctness.
// See doc comment above for rationale.
const FIXED_CONFIDENCE: f64 = 0.9;
Some((prediction[0] == 1, FIXED_CONFIDENCE))
}
/// Predict confidence that item is in set (deprecated - use `predict_with_confidence`)
#[allow(dead_code)]
fn predict_confidence<T: Hash>(&self, item: &T) -> Option<f64> {
self.predict_with_confidence(item).map(|(_, conf)| conf)
}
/// Extract hash-based features from a key.
///
/// Produces `feature_dim` values in [0.0, 1.0) by hashing the key together
/// with each feature index, giving the model a fixed-width numeric view of
/// arbitrary hashable keys.
fn extract_features<T: Hash>(&self, item: &T) -> Vec<f64> {
let mut features = Vec::with_capacity(self.feature_dim);
for i in 0..self.feature_dim {
let mut hasher = DefaultHasher::new();
// Seed with the feature index so each feature is a distinct hash.
i.hash(&mut hasher);
item.hash(&mut hasher);
let hash = hasher.finish();
// Normalize to 0.0-1.0 range
features.push((hash % 10000) as f64 / 10000.0);
}
features
}
/// Get the number of elements added during training
#[must_use]
pub const fn len(&self) -> usize {
self.count
}
/// Check if empty
#[must_use]
pub const fn is_empty(&self) -> bool {
self.count == 0
}
/// Get size in bytes (for benchmarking).
///
/// The model contribution is a rough fixed estimate (~1KB), since the
/// actual tree size depends on depth and node count, which smartcore
/// does not expose directly.
#[must_use]
pub fn size_bytes(&self) -> usize {
let backup_size = self.backup_filter.size_bytes();
let model_size = if self.model.is_some() { 1024 } else { 0 };
backup_size + model_size + std::mem::size_of::<Self>()
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Trains on disjoint key ranges and checks membership guarantees:
/// every trained positive must be reported present; false positives on
/// the negative set are tolerated but bounded.
#[test]
fn test_learned_bloom_filter() {
let mut filter = LearnedBloomFilter::new(1000, 0.01, 0.5);
// Disjoint training sets: keys 0..100 are members, 1000..1100 are not.
let positive: Vec<String> = (0..100).map(|i| format!("key_{}", i)).collect();
let negative: Vec<String> = (1000..1100).map(|i| format!("key_{}", i)).collect();
filter.train(&positive, &negative);
// No false negatives allowed: every positive must be found.
for key in &positive {
assert!(
filter.contains(key),
"Positive example should be in set: {}",
key
);
}
// Count false positives over the known-negative keys.
let false_positives = negative
.iter()
.filter(|key| filter.contains(*key))
.count();
println!("False positives: {}/{}", false_positives, negative.len());
assert!(
false_positives < 5,
"Too many false positives: {}",
false_positives
);
}
/// Compares the memory footprint of a traditional bloom filter against
/// a trained learned filter (informational; no hard assertion since the
/// savings depend on data distribution and model complexity).
#[test]
fn test_size_comparison() {
// Baseline: traditional bloom filter sized for the same workload.
let bf_size = BloomFilter::new(1000, 0.01).size_bytes();
// Learned variant trained on disjoint integer ranges.
let mut learned = LearnedBloomFilter::new(1000, 0.01, 0.7);
let positive: Vec<i32> = (0..1000).collect();
let negative: Vec<i32> = (10000..11000).collect();
learned.train(&positive, &negative);
let lbf_size = learned.size_bytes();
println!("Traditional Bloom Filter: {} bytes", bf_size);
println!("Learned Bloom Filter: {} bytes", lbf_size);
let reduction = (1.0 - lbf_size as f64 / bf_size as f64) * 100.0;
println!("Space reduction: {:.1}%", reduction);
}
}