use crate::common::time_compat::{SystemTime, UNIX_EPOCH};
use rustc_hash::{FxHashMap, FxHasher};
use std::hash::{Hash, Hasher};
use std::sync::RwLock;
use crate::parser::ast::Expression;
/// Weight given to the newest observation when a correction factor is
/// updated as an exponential moving average; used by `FeedbackCache::new`.
pub const DEFAULT_DECAY_FACTOR: f64 = 0.3;
/// Number of recorded samples a feedback entry needs before
/// `CardinalityFeedback::is_reliable` reports it as reliable.
pub const MIN_SAMPLE_COUNT: u64 = 2;
/// Upper bound applied to every computed correction factor.
pub const MAX_CORRECTION_FACTOR: f64 = 100.0;
/// Lower bound applied to every computed correction factor.
pub const MIN_CORRECTION_FACTOR: f64 = 0.01;
/// Feedback about how far a row-count estimate was from the observed row
/// count for one predicate on one table.
#[derive(Debug, Clone)]
pub struct CardinalityFeedback {
/// Caller-supplied hash identifying the predicate this feedback belongs to.
pub predicate_hash: u64,
/// Name of the table the predicate was evaluated against.
pub table_name: String,
/// Column associated with the predicate, when the caller supplied one.
pub column_name: Option<String>,
/// Estimated row count from the most recent observation.
pub estimated_rows: u64,
/// Actual row count from the most recent observation.
pub actual_rows: u64,
/// Multiplier to apply to future estimates; always kept within
/// `[MIN_CORRECTION_FACTOR, MAX_CORRECTION_FACTOR]` by `new` and `update`.
pub correction_factor: f64,
/// Number of observations folded into this entry (starts at 1).
pub sample_count: u64,
/// Time of the last observation, as produced by `get_current_timestamp`
/// (nanoseconds since the Unix epoch).
pub last_updated: i64,
}
impl CardinalityFeedback {
    /// Ratio of observed to estimated rows, bounded to
    /// `[MIN_CORRECTION_FACTOR, MAX_CORRECTION_FACTOR]`.
    ///
    /// A zero estimate cannot be scaled, so it yields the neutral factor `1.0`.
    fn observed_ratio(estimated_rows: u64, actual_rows: u64) -> f64 {
        if estimated_rows == 0 {
            return 1.0;
        }
        let raw = actual_rows as f64 / estimated_rows as f64;
        raw.clamp(MIN_CORRECTION_FACTOR, MAX_CORRECTION_FACTOR)
    }

    /// Creates a feedback entry from a first observation.
    ///
    /// The initial correction factor is the clamped `actual / estimated`
    /// ratio (or `1.0` when `estimated_rows` is zero), the sample count
    /// starts at 1, and the entry is stamped with the current time.
    pub fn new(
        predicate_hash: u64,
        table_name: impl Into<String>,
        column_name: Option<String>,
        estimated_rows: u64,
        actual_rows: u64,
    ) -> Self {
        Self {
            predicate_hash,
            table_name: table_name.into(),
            column_name,
            estimated_rows,
            actual_rows,
            correction_factor: Self::observed_ratio(estimated_rows, actual_rows),
            sample_count: 1,
            last_updated: get_current_timestamp(),
        }
    }

    /// Folds a new observation into this entry.
    ///
    /// The correction factor becomes an exponential moving average:
    /// `decay_factor * new_ratio + (1 - decay_factor) * old_factor`, clamped
    /// to the allowed range. The stored row counts are replaced by the new
    /// observation, the sample count is incremented and the timestamp is
    /// refreshed.
    pub fn update(&mut self, estimated_rows: u64, actual_rows: u64, decay_factor: f64) {
        let sample = Self::observed_ratio(estimated_rows, actual_rows);
        let blended = decay_factor * sample + (1.0 - decay_factor) * self.correction_factor;

        self.correction_factor = blended.clamp(MIN_CORRECTION_FACTOR, MAX_CORRECTION_FACTOR);
        self.estimated_rows = estimated_rows;
        self.actual_rows = actual_rows;
        self.sample_count += 1;
        self.last_updated = get_current_timestamp();
    }

    /// Returns `true` once at least `MIN_SAMPLE_COUNT` observations have
    /// been recorded.
    pub fn is_reliable(&self) -> bool {
        self.sample_count >= MIN_SAMPLE_COUNT
    }

    /// Scales `base_estimate` by the correction factor.
    ///
    /// Unreliable entries return `base_estimate` unchanged. Reliable entries
    /// return the rounded product, never less than 1.
    pub fn apply_correction(&self, base_estimate: u64) -> u64 {
        if self.is_reliable() {
            let scaled = (base_estimate as f64 * self.correction_factor).round() as u64;
            scaled.max(1)
        } else {
            base_estimate
        }
    }
}
/// Lock-protected store of `CardinalityFeedback` entries keyed by
/// `(table name, predicate hash)`.
#[derive(Debug)]
pub struct FeedbackCache {
// Feedback entries behind a reader/writer lock so the cache can be used
// through a shared reference.
entries: RwLock<FxHashMap<(String, u64), CardinalityFeedback>>,
// Weight passed to `CardinalityFeedback::update` for each new observation.
decay_factor: f64,
// Entry count at which inserting a new key triggers eviction.
max_entries: usize,
}
/// The default cache is the one built by `FeedbackCache::new`.
impl Default for FeedbackCache {
fn default() -> Self {
Self::new()
}
}
impl FeedbackCache {
    /// Capacity used by [`FeedbackCache::new`].
    const DEFAULT_MAX_ENTRIES: usize = 10_000;

    /// Creates a cache with `DEFAULT_DECAY_FACTOR` and room for 10 000 entries.
    pub fn new() -> Self {
        Self::with_settings(DEFAULT_DECAY_FACTOR, Self::DEFAULT_MAX_ENTRIES)
    }

    /// Creates a cache with an explicit decay factor and capacity.
    ///
    /// `decay_factor` is the weight of the newest observation in the moving
    /// average. `max_entries` is the entry count at which inserting a new key
    /// first evicts the least recently updated entries.
    pub fn with_settings(decay_factor: f64, max_entries: usize) -> Self {
        Self {
            entries: RwLock::new(FxHashMap::default()),
            decay_factor,
            max_entries,
        }
    }

    /// Records one estimated-vs-actual observation for a predicate.
    ///
    /// An existing entry for `(table_name, predicate_hash)` is updated in
    /// place (its `column_name` is left as originally recorded). Otherwise a
    /// new entry is inserted, evicting old entries first if the cache is full.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn record_feedback(
        &self,
        table_name: &str,
        predicate_hash: u64,
        column_name: Option<String>,
        estimated_rows: u64,
        actual_rows: u64,
    ) {
        let key = (table_name.to_string(), predicate_hash);
        let mut entries = self.entries.write().unwrap();

        if let Some(existing) = entries.get_mut(&key) {
            existing.update(estimated_rows, actual_rows, self.decay_factor);
            return;
        }

        if entries.len() >= self.max_entries {
            self.evict_oldest(&mut entries);
        }
        let feedback = CardinalityFeedback::new(
            predicate_hash,
            table_name,
            column_name,
            estimated_rows,
            actual_rows,
        );
        entries.insert(key, feedback);
    }

    /// Returns a copy of the entry for `(table_name, predicate_hash)`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn lookup(&self, table_name: &str, predicate_hash: u64) -> Option<CardinalityFeedback> {
        let key = (table_name.to_string(), predicate_hash);
        let entries = self.entries.read().unwrap();
        entries.get(&key).cloned()
    }

    /// Returns the correction factor for a predicate, or `1.0` when there is
    /// no entry or the entry is not yet reliable.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get_correction(&self, table_name: &str, predicate_hash: u64) -> f64 {
        let key = (table_name.to_string(), predicate_hash);
        let entries = self.entries.read().unwrap();
        // Read the factor under the lock instead of cloning the whole record
        // (which would copy its strings) just to extract one number.
        entries
            .get(&key)
            .filter(|feedback| feedback.is_reliable())
            .map_or(1.0, |feedback| feedback.correction_factor)
    }

    /// Scales `estimate` by the predicate's correction factor.
    ///
    /// The result is rounded and never less than 1, even when no feedback
    /// exists for the predicate.
    pub fn apply_correction(&self, table_name: &str, predicate_hash: u64, estimate: u64) -> u64 {
        let correction = self.get_correction(table_name, predicate_hash);
        ((estimate as f64 * correction).round() as u64).max(1)
    }

    /// Removes every entry.
    pub fn clear(&self) {
        self.entries.write().unwrap().clear();
    }

    /// Number of entries currently stored.
    pub fn len(&self) -> usize {
        self.entries.read().unwrap().len()
    }

    /// Returns `true` when the cache holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.read().unwrap().is_empty()
    }

    /// Evicts the least recently updated entries to make room for an insert.
    ///
    /// Roughly 10% of the capacity is evicted per pass so the sort is
    /// amortized over many inserts. At least one entry is always evicted:
    /// with a capacity below 10 the integer division yields 0, which would
    /// otherwise evict nothing and let the cache grow without bound.
    fn evict_oldest(&self, entries: &mut FxHashMap<(String, u64), CardinalityFeedback>) {
        let evict_count = (self.max_entries / 10).max(1);

        // Keys are cloned because entries cannot be removed while the map is
        // borrowed by the iterator.
        let mut timestamps: Vec<_> = entries
            .iter()
            .map(|(key, feedback)| (key.clone(), feedback.last_updated))
            .collect();
        // Map iteration order is arbitrary, so a stable sort buys nothing.
        timestamps.sort_unstable_by_key(|(_, ts)| *ts);

        for (key, _) in timestamps.into_iter().take(evict_count) {
            entries.remove(&key);
        }
    }

    /// Returns copies of all entries recorded for `table_name`.
    pub fn get_table_feedback(&self, table_name: &str) -> Vec<CardinalityFeedback> {
        let entries = self.entries.read().unwrap();
        entries
            .iter()
            .filter(|((name, _), _)| name == table_name)
            .map(|(_, fb)| fb.clone())
            .collect()
    }
}
pub fn fingerprint_predicate(table_name: &str, expr: &Expression) -> u64 {
let mut hasher = FxHasher::default();
table_name.hash(&mut hasher);
hash_expression_structure(expr, &mut hasher);
hasher.finish()
}
/// Feeds the structural shape of `expr` into `hasher`.
///
/// The variant discriminant is hashed first, then variant-specific details.
/// Identifier names and operators are included; literal values are not (each
/// literal kind contributes only a fixed tag), so predicates that differ only
/// in their constants hash identically. The exact sequence of `hash` calls
/// determines the fingerprint, so reordering them changes every fingerprint.
fn hash_expression_structure(expr: &Expression, hasher: &mut FxHasher) {
std::mem::discriminant(expr).hash(hasher);
match expr {
Expression::Identifier(id) => {
id.value.hash(hasher);
}
Expression::QualifiedIdentifier(qid) => {
qid.qualifier.value.hash(hasher);
qid.name.value.hash(hasher);
}
Expression::Infix(infix) => {
// Operator, then both operands recursively (left before right).
infix.op_type.hash(hasher);
hash_expression_structure(&infix.left, hasher);
hash_expression_structure(&infix.right, hasher);
}
Expression::Prefix(prefix) => {
prefix.op_type.hash(hasher);
hash_expression_structure(&prefix.right, hasher);
}
Expression::Between(between) => {
"BETWEEN".hash(hasher);
between.not.hash(hasher);
hash_expression_structure(&between.expr, hasher);
hash_expression_structure(&between.lower, hasher);
hash_expression_structure(&between.upper, hasher);
}
Expression::In(in_expr) => {
"IN".hash(hasher);
in_expr.not.hash(hasher);
hash_expression_structure(&in_expr.left, hasher);
hash_expression_structure(&in_expr.right, hasher);
}
Expression::Like(like) => {
// Only the operator and the left-hand side are hashed here.
"LIKE".hash(hasher);
like.operator.hash(hasher);
hash_expression_structure(&like.left, hasher);
}
Expression::FunctionCall(func) => {
// Function identity, argument count, then each argument in order.
"FUNCTION".hash(hasher);
func.function.hash(hasher);
func.arguments.len().hash(hasher);
for arg in &func.arguments {
hash_expression_structure(arg, hasher);
}
}
Expression::Case(case) => {
// Only the number of WHEN clauses and the presence of an ELSE value
// are hashed; the clause contents are not visited.
"CASE".hash(hasher);
case.when_clauses.len().hash(hasher);
case.else_value.is_some().hash(hasher);
}
Expression::Cast(cast) => {
"CAST".hash(hasher);
cast.type_name.hash(hasher);
hash_expression_structure(&cast.expr, hasher);
}
// Subqueries contribute only a fixed tag; their contents are ignored.
Expression::ScalarSubquery(_) => {
"SUBQUERY".hash(hasher);
}
Expression::Exists(_) => {
"EXISTS".hash(hasher);
}
// Literals contribute only a per-kind tag, never their value.
Expression::IntegerLiteral(_) => {
"INTEGER_LITERAL".hash(hasher);
}
Expression::FloatLiteral(_) => {
"FLOAT_LITERAL".hash(hasher);
}
Expression::StringLiteral(_) => {
"STRING_LITERAL".hash(hasher);
}
Expression::BooleanLiteral(_) => {
"BOOLEAN_LITERAL".hash(hasher);
}
Expression::NullLiteral(_) => {
"NULL_LITERAL".hash(hasher);
}
Expression::List(list) => {
// Only the element count is hashed, not the elements themselves.
"LIST".hash(hasher);
list.elements.len().hash(hasher);
}
Expression::Star(_) => {
"STAR".hash(hasher);
}
// Any other variant is distinguished only by its discriminant (hashed
// above) plus this shared tag.
_ => {
"OTHER".hash(hasher);
}
}
}
/// Returns the column name when `expr` is a plain identifier or a qualified
/// identifier (the qualifier is dropped), and `None` for anything else.
fn column_name_of(expr: &Expression) -> Option<String> {
    match expr {
        Expression::Identifier(id) => Some(id.value.to_string()),
        Expression::QualifiedIdentifier(qid) => Some(qid.name.value.to_string()),
        _ => None,
    }
}

/// Extracts the column a predicate is about, if one can be identified.
///
/// - Infix expressions: the left operand is checked first, then the right.
/// - `BETWEEN`, `IN` and `LIKE`: only the tested (left-hand) expression is
///   checked.
/// - Any other expression yields `None`.
///
/// An operand counts as a column only when it is directly an identifier or
/// qualified identifier; nested expressions are not searched.
pub fn extract_column_from_predicate(expr: &Expression) -> Option<String> {
    match expr {
        Expression::Infix(infix) => {
            column_name_of(&*infix.left).or_else(|| column_name_of(&*infix.right))
        }
        Expression::Between(between) => column_name_of(&*between.expr),
        Expression::In(in_expr) => column_name_of(&*in_expr.left),
        Expression::Like(like) => column_name_of(&*like.left),
        _ => None,
    }
}
/// Returns the current time as nanoseconds since the Unix epoch.
///
/// Returns `0` when the system clock reports a time before the epoch. The
/// nanosecond count is converted with a checked conversion and saturates at
/// `i64::MAX` rather than silently wrapping to a negative value, which would
/// make a fresh entry look like the oldest one during eviction.
fn get_current_timestamp() -> i64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            i64::try_from(elapsed.as_nanos()).unwrap_or(i64::MAX)
        })
}
// Process-wide cache instance, created lazily on first access.
static FEEDBACK_CACHE: std::sync::OnceLock<FeedbackCache> = std::sync::OnceLock::new();
/// Returns the shared process-wide feedback cache, initializing it with
/// `FeedbackCache::new` the first time it is requested.
pub fn global_feedback_cache() -> &'static FeedbackCache {
FEEDBACK_CACHE.get_or_init(FeedbackCache::new)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::ast::{InfixExpression, InfixOperator};
use crate::parser::{Identifier, IntegerLiteral, Position, Token, TokenType};
// Builds an identifier token at a fixed dummy position.
fn make_token(literal: &str) -> Token {
Token::new(TokenType::Identifier, literal, Position::new(0, 1, 1))
}
// Builds an identifier expression named `name`.
fn make_identifier(name: &str) -> Expression {
Expression::Identifier(Identifier::new(make_token(name), name.to_string()))
}
// Builds an integer literal expression with value `val`.
fn make_literal_int(val: i64) -> Expression {
Expression::IntegerLiteral(IntegerLiteral {
token: Token::new(TokenType::Integer, val.to_string(), Position::new(0, 1, 1)),
value: val,
})
}
// Builds the predicate `col = val`.
fn make_equality(col: &str, val: i64) -> Expression {
Expression::Infix(InfixExpression {
token: Token::new(TokenType::Operator, "=", Position::new(0, 1, 1)),
left: Box::new(make_identifier(col)),
operator: "=".into(),
op_type: InfixOperator::Equal,
right: Box::new(make_literal_int(val)),
})
}
// A first observation of 1000 actual vs 100 estimated gives factor 10 and is
// not yet reliable (only one sample).
#[test]
fn test_feedback_entry_creation() {
let fb = CardinalityFeedback::new(12345, "users", None, 100, 1000);
assert_eq!(fb.correction_factor, 10.0);
assert_eq!(fb.sample_count, 1);
assert!(!fb.is_reliable()); }
// The second observation has ratio 1.0, so the moving average is
// 0.3 * 1.0 + 0.7 * 10.0 = 7.3.
#[test]
fn test_feedback_update_ema() {
let mut fb = CardinalityFeedback::new(12345, "users", None, 100, 1000);
fb.update(100, 100, DEFAULT_DECAY_FACTOR);
assert!((fb.correction_factor - 7.3).abs() < 0.001);
assert_eq!(fb.sample_count, 2);
assert!(fb.is_reliable());
}
// The cache reports the neutral factor until an entry has two samples.
#[test]
fn test_feedback_cache() {
let cache = FeedbackCache::new();
cache.record_feedback("users", 12345, Some("status".to_string()), 100, 1000);
assert_eq!(cache.get_correction("users", 12345), 1.0);
cache.record_feedback("users", 12345, Some("status".to_string()), 100, 1000);
let correction = cache.get_correction("users", 12345);
assert!(correction > 1.0);
}
// Predicates that differ only in the literal value share a fingerprint.
#[test]
fn test_fingerprint_same_structure() {
let pred1 = make_equality("status", 1);
let pred2 = make_equality("status", 2);
let hash1 = fingerprint_predicate("users", &pred1);
let hash2 = fingerprint_predicate("users", &pred2);
assert_eq!(hash1, hash2);
}
// A different column name produces a different fingerprint.
#[test]
fn test_fingerprint_different_columns() {
let pred1 = make_equality("status", 1);
let pred2 = make_equality("role", 1);
let hash1 = fingerprint_predicate("users", &pred1);
let hash2 = fingerprint_predicate("users", &pred2);
assert_ne!(hash1, hash2);
}
// The same predicate on a different table produces a different fingerprint.
#[test]
fn test_fingerprint_different_tables() {
let pred = make_equality("status", 1);
let hash1 = fingerprint_predicate("users", &pred);
let hash2 = fingerprint_predicate("orders", &pred);
assert_ne!(hash1, hash2);
}
// The column is taken from the left-hand identifier of an equality.
#[test]
fn test_extract_column() {
let pred = make_equality("status", 1);
let col = extract_column_from_predicate(&pred);
assert_eq!(col, Some("status".to_string()));
}
// Two identical observations with ratio 5.0 make the entry reliable, so an
// estimate of 200 is scaled upwards.
#[test]
fn test_apply_correction() {
let cache = FeedbackCache::new();
cache.record_feedback("users", 12345, None, 100, 500);
cache.record_feedback("users", 12345, None, 100, 500);
let corrected = cache.apply_correction("users", 12345, 200);
assert!(corrected > 200);
}
}