use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use parking_lot::RwLock;

use super::registry::{SchemaRegistry, DataTemperature, WorkloadType};
/// Classifier that learns per-table access patterns from recorded queries.
///
/// It keeps one `QueryHistory` per table name and uses a shared
/// `ClassificationModel` to derive a data temperature and workload type,
/// which `reclassify` pushes into the `SchemaRegistry`.
#[derive(Debug)]
pub struct LearningClassifier {
/// Per-table query statistics, keyed by table name.
history: DashMap<String, QueryHistory>,
/// Classification model behind a read/write lock so `update_thresholds`
/// can replace its thresholds through `&self`.
model: Arc<RwLock<ClassificationModel>>,
/// Registry that receives the classifications computed by `reclassify`.
schema: Arc<SchemaRegistry>,
/// Tuning parameters supplied at construction time.
config: ClassifierConfig,
}
impl LearningClassifier {
    /// Creates a classifier that uses [`ClassifierConfig::default`].
    pub fn new(schema: Arc<SchemaRegistry>) -> Self {
        Self::with_config(schema, ClassifierConfig::default())
    }

    /// Creates a classifier with an explicit configuration.
    pub fn with_config(schema: Arc<SchemaRegistry>, config: ClassifierConfig) -> Self {
        Self {
            history: DashMap::new(),
            model: Arc::new(RwLock::new(ClassificationModel::new())),
            schema,
            config,
        }
    }

    /// Records one executed query against `table`, creating the table's
    /// history on first use.
    ///
    /// Each time the table's query count reaches a multiple of
    /// `config.reclassification_threshold` the table is reclassified.
    /// A threshold of `0` disables automatic reclassification instead of
    /// panicking on a modulo by zero.
    pub fn record(&self, table: &str, query_type: QueryType, latency: Duration) {
        // The entry guard holds a write lock on the map shard. It has to be
        // released before `reclassify` looks the same key up again, because
        // DashMap is not re-entrant and `get` would block forever.
        let count = {
            let mut history = self
                .history
                .entry(table.to_string())
                .or_insert_with(QueryHistory::new);
            history.record(query_type, latency);
            history.count()
        };

        let threshold = self.config.reclassification_threshold;
        if threshold != 0 && count % threshold == 0 {
            self.reclassify(table);
        }
    }

    /// Recomputes the classification of `table` and stores it in the schema
    /// registry. Does nothing when no queries were recorded for the table.
    pub fn reclassify(&self, table: &str) {
        // Classify while borrowing the history in place (no deep copy), and
        // release both the map guard and the model lock before calling into
        // the registry.
        let (temperature, workload) = {
            let history = match self.history.get(table) {
                Some(history) => history,
                None => return,
            };
            let model = self.model.read();
            (
                model.classify_temperature(&history),
                model.classify_workload(&history),
            )
        };
        self.schema.update_classification(table, temperature, workload);
    }

    /// Returns the current classification of `table`, or `None` when the
    /// table has no recorded history.
    pub fn get_classification(&self, table: &str) -> Option<TableClassification> {
        let history = self.history.get(table)?;
        let model = self.model.read();
        Some(TableClassification {
            table: table.to_string(),
            temperature: model.classify_temperature(&history),
            workload: model.classify_workload(&history),
            confidence: model.classification_confidence(&history),
            query_count: history.count(),
            last_updated: history.last_updated(),
        })
    }

    /// Returns the current classification of every table with recorded history.
    pub fn all_classifications(&self) -> Vec<TableClassification> {
        self.history
            .iter()
            .map(|entry| {
                let history = entry.value();
                // The model lock is taken per entry, after the map guard, so
                // the lock order (map shard, then model) matches every other
                // method of this type.
                let model = self.model.read();
                TableClassification {
                    table: entry.key().clone(),
                    temperature: model.classify_temperature(history),
                    workload: model.classify_workload(history),
                    confidence: model.classification_confidence(history),
                    query_count: history.count(),
                    last_updated: history.last_updated(),
                }
            })
            .collect()
    }

    /// Replaces the thresholds used by the classification model.
    pub fn update_thresholds(&self, thresholds: ModelThresholds) {
        self.model.write().thresholds = thresholds;
    }

    /// Returns a copy of the history recorded for `table`, if any.
    pub fn get_history(&self, table: &str) -> Option<QueryHistory> {
        self.history.get(table).map(|entry| entry.value().clone())
    }

    /// Forgets everything recorded for `table`.
    pub fn clear_history(&self, table: &str) {
        self.history.remove(table);
    }

    /// Forgets the history of every table.
    pub fn clear_all(&self) {
        self.history.clear();
    }

    /// Total number of queries recorded across all tables.
    pub fn query_count(&self) -> u64 {
        self.history.iter().map(|entry| entry.value().count()).sum()
    }

    /// Temperature the model currently assigns to `table`, or `None` when
    /// the table has no recorded history.
    pub fn suggest_temperature(&self, table: &str) -> Option<DataTemperature> {
        let history = self.history.get(table)?;
        let model = self.model.read();
        Some(model.classify_temperature(&history))
    }

    /// Workload type the model currently assigns to `table`, or `None` when
    /// the table has no recorded history.
    pub fn suggest_workload(&self, table: &str) -> Option<WorkloadType> {
        let history = self.history.get(table)?;
        let model = self.model.read();
        Some(model.classify_workload(&history))
    }

    /// Confidence of the current classification of `table`, or `None` when
    /// the table has no recorded history.
    pub fn get_confidence(&self, table: &str) -> Option<f64> {
        let history = self.history.get(table)?;
        let model = self.model.read();
        Some(model.classification_confidence(&history))
    }

    /// Maps a single SQL statement to the workload type it is typical for.
    ///
    /// Every [`QueryType`] has a mapping, so this currently always returns
    /// `Some`; the `Option` is kept for interface compatibility.
    pub fn classify_query(&self, sql: &str) -> Option<WorkloadType> {
        let workload = match QueryType::from_sql(sql) {
            QueryType::VectorSearch => WorkloadType::Vector,
            QueryType::AggregateSelect | QueryType::JoinSelect => WorkloadType::OLAP,
            QueryType::SimpleSelect
            | QueryType::Insert
            | QueryType::Update
            | QueryType::Delete => WorkloadType::OLTP,
        };
        Some(workload)
    }
}
/// Tuning parameters for `LearningClassifier`.
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// `LearningClassifier::record` reclassifies a table each time its
    /// recorded-query count reaches a multiple of this value.
    pub reclassification_threshold: u64,
    /// NOTE(review): not read by the code visible in this file; presumably
    /// the window over which query rates are meant to be measured — confirm.
    pub rate_window: Duration,
    /// NOTE(review): not read by the code visible in this file; presumably
    /// the minimum sample size before a classification is trusted — confirm.
    pub min_queries: u64,
}

impl Default for ClassifierConfig {
    /// Defaults: reclassify every 1000 queries, 60-second rate window,
    /// 100 queries minimum.
    fn default() -> Self {
        let reclassification_threshold = 1000;
        let rate_window = Duration::from_secs(60);
        let min_queries = 100;
        Self {
            reclassification_threshold,
            rate_window,
            min_queries,
        }
    }
}
/// Rolling statistics about the queries recorded against one table.
///
/// The counters are lifetime totals. Latency and rate figures are computed
/// over bounded windows so that memory use and the cost of `record` stay
/// constant no matter how much traffic a table receives.
#[derive(Debug, Clone)]
pub struct QueryHistory {
    /// Number of queries recorded since creation.
    total_count: u64,
    /// Number of recorded queries for which `QueryType::is_read` is true.
    read_count: u64,
    /// Number of recorded queries that are not reads.
    write_count: u64,
    /// Lifetime count per query type.
    type_counts: HashMap<QueryType, u64>,
    /// Most recent latencies, oldest first; at most `MAX_LATENCY_SAMPLES`.
    latencies: VecDeque<Duration>,
    /// `(timestamp, total_count)` checkpoints, oldest first, covering at most
    /// `QPM_WINDOW` and spaced roughly `QPM_SAMPLE_INTERVAL` apart.
    qpm_samples: VecDeque<(Instant, u64)>,
    /// Creation time of this history.
    created: Instant,
    /// Time of the most recent `record` call (creation time if none yet).
    last_updated: Instant,
}

impl QueryHistory {
    /// Number of latencies kept for `avg_latency` / `p95_latency`.
    const MAX_LATENCY_SAMPLES: usize = 1000;
    /// Time span covered by the rate checkpoints.
    const QPM_WINDOW: Duration = Duration::from_secs(300);
    /// Minimum spacing between retained rate checkpoints.
    const QPM_SAMPLE_INTERVAL: Duration = Duration::from_secs(1);

    /// Creates an empty history stamped with the current time.
    pub fn new() -> Self {
        let now = Instant::now();
        Self {
            total_count: 0,
            read_count: 0,
            write_count: 0,
            type_counts: HashMap::new(),
            latencies: VecDeque::new(),
            qpm_samples: VecDeque::new(),
            created: now,
            last_updated: now,
        }
    }

    /// Records one query of the given type together with its latency.
    pub fn record(&mut self, query_type: QueryType, latency: Duration) {
        let now = Instant::now();
        self.total_count += 1;
        self.last_updated = now;
        *self.type_counts.entry(query_type).or_insert(0) += 1;
        if query_type.is_read() {
            self.read_count += 1;
        } else {
            self.write_count += 1;
        }

        // Sliding window: evict the oldest latency in O(1) once full.
        if self.latencies.len() >= Self::MAX_LATENCY_SAMPLES {
            self.latencies.pop_front();
        }
        self.latencies.push_back(latency);

        self.update_qpm(now);
    }

    /// Maintains the rate checkpoints after a query was counted at `now`.
    fn update_qpm(&mut self, now: Instant) {
        // Checkpoints are appended in time order, so expired ones can only
        // sit at the front; no full scan is needed.
        while let Some(&(taken_at, _)) = self.qpm_samples.front() {
            if now.duration_since(taken_at) < Self::QPM_WINDOW {
                break;
            }
            self.qpm_samples.pop_front();
        }

        // Keep the newest checkpoint current, but only start a new slot once
        // the previous retained checkpoint is at least one interval old.
        // This bounds the deque to a few hundred entries instead of one
        // entry per query, while `qpm` still sees the true latest count.
        let sample = (now, self.total_count);
        let len = self.qpm_samples.len();
        let coalesce = len >= 2
            && now.duration_since(self.qpm_samples[len - 2].0) < Self::QPM_SAMPLE_INTERVAL;
        if coalesce {
            self.qpm_samples[len - 1] = sample;
        } else {
            self.qpm_samples.push_back(sample);
        }
    }

    /// Total number of queries recorded.
    pub fn count(&self) -> u64 {
        self.total_count
    }

    /// Queries per minute over the retained checkpoints.
    ///
    /// Returns `0.0` when fewer than two checkpoints exist or when they span
    /// less than one whole second (too short to extrapolate from).
    pub fn qpm(&self) -> f64 {
        match (self.qpm_samples.front(), self.qpm_samples.back()) {
            (Some(&(start, start_count)), Some(&(end, end_count))) => {
                // With a single checkpoint, front and back coincide and the
                // elapsed time is zero, so that case also yields 0.0.
                let elapsed = end.duration_since(start);
                if elapsed.as_secs() == 0 {
                    return 0.0;
                }
                let queries = end_count - start_count;
                (queries as f64 / elapsed.as_secs_f64()) * 60.0
            }
            _ => 0.0,
        }
    }

    /// Ratio of reads to writes; `f64::INFINITY` when no write was recorded.
    pub fn read_write_ratio(&self) -> f64 {
        if self.write_count == 0 {
            return f64::INFINITY;
        }
        self.read_count as f64 / self.write_count as f64
    }

    /// Mean of the retained latencies, or `Duration::ZERO` when there are none.
    pub fn avg_latency(&self) -> Duration {
        if self.latencies.is_empty() {
            return Duration::ZERO;
        }
        let sum: Duration = self.latencies.iter().sum();
        // The window never exceeds MAX_LATENCY_SAMPLES, so the cast is lossless.
        sum / self.latencies.len() as u32
    }

    /// 95th-percentile of the retained latencies, or `Duration::ZERO` when
    /// there are none.
    pub fn p95_latency(&self) -> Duration {
        if self.latencies.is_empty() {
            return Duration::ZERO;
        }
        let mut sorted: Vec<Duration> = self.latencies.iter().copied().collect();
        sorted.sort_unstable();
        // Clamp so that a single-element window still yields a valid index.
        let idx = ((sorted.len() as f64 * 0.95) as usize).min(sorted.len() - 1);
        sorted[idx]
    }

    /// Time of the most recent `record` call (creation time if none yet).
    pub fn last_updated(&self) -> Instant {
        self.last_updated
    }

    /// Number of recorded queries of the given type.
    pub fn type_count(&self, query_type: QueryType) -> u64 {
        self.type_counts.get(&query_type).copied().unwrap_or(0)
    }

    /// Share of all recorded queries that have the given type (`0.0` when empty).
    pub fn type_fraction(&self, query_type: QueryType) -> f64 {
        if self.total_count == 0 {
            return 0.0;
        }
        self.type_count(query_type) as f64 / self.total_count as f64
    }
}

impl Default for QueryHistory {
    fn default() -> Self {
        Self::new()
    }
}

/// Coarse category of a SQL statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    SimpleSelect,
    AggregateSelect,
    JoinSelect,
    VectorSearch,
    Insert,
    Update,
    Delete,
}

impl QueryType {
    /// `true` for the four select/search variants.
    pub fn is_read(&self) -> bool {
        matches!(
            self,
            QueryType::SimpleSelect
                | QueryType::AggregateSelect
                | QueryType::JoinSelect
                | QueryType::VectorSearch
        )
    }

    /// `true` for every variant that is not a read.
    pub fn is_write(&self) -> bool {
        !self.is_read()
    }

    /// `true` for aggregate and join selects.
    pub fn is_olap(&self) -> bool {
        matches!(self, QueryType::AggregateSelect | QueryType::JoinSelect)
    }

    /// Classifies a SQL statement with keyword heuristics (no real parsing).
    ///
    /// Matching is case-insensitive. Leading whitespace is ignored when
    /// looking for the `INSERT` / `UPDATE` / `DELETE` keyword, and `JOIN` is
    /// recognised when delimited by any whitespace (spaces, tabs, newlines),
    /// so multi-line statements classify the same as single-line ones.
    /// Anything that matches no rule is a `SimpleSelect`.
    pub fn from_sql(sql: &str) -> Self {
        let upper = sql.trim_start().to_uppercase();
        if upper.starts_with("INSERT") {
            QueryType::Insert
        } else if upper.starts_with("UPDATE") {
            QueryType::Update
        } else if upper.starts_with("DELETE") {
            QueryType::Delete
        } else if upper.contains("<->") || upper.contains("VECTOR") || upper.contains("EMBEDDING") {
            QueryType::VectorSearch
        } else if upper.contains("COUNT(") || upper.contains("SUM(") || upper.contains("AVG(") {
            QueryType::AggregateSelect
        } else if upper.split_whitespace().any(|token| token == "JOIN") {
            QueryType::JoinSelect
        } else {
            QueryType::SimpleSelect
        }
    }
}
/// Threshold-based model that maps a table's `QueryHistory` to a data
/// temperature, a workload type and a confidence score.
#[derive(Debug)]
pub struct ClassificationModel {
/// Cut-off values consulted by the `classify_*` methods; public so callers
/// can replace them.
pub thresholds: ModelThresholds,
}
impl ClassificationModel {
    /// Creates a model with [`ModelThresholds::default`].
    pub fn new() -> Self {
        Self {
            thresholds: ModelThresholds::default(),
        }
    }

    /// Buckets a table by its queries-per-minute rate.
    ///
    /// Each boundary is exclusive: a rate exactly equal to a threshold falls
    /// into the cooler bucket.
    pub fn classify_temperature(&self, history: &QueryHistory) -> DataTemperature {
        let qpm = history.qpm();
        let limits = &self.thresholds;
        if qpm > limits.hot_qpm {
            DataTemperature::Hot
        } else if qpm > limits.warm_qpm {
            DataTemperature::Warm
        } else if qpm > limits.cold_qpm {
            DataTemperature::Cold
        } else {
            DataTemperature::Frozen
        }
    }

    /// Derives the dominant workload type from the recorded query mix.
    ///
    /// Rules are checked in order; the first match wins:
    /// 1. more than 30% vector searches → `Vector`
    /// 2. read/write ratio above `olap_ratio` and more than 20% analytic
    ///    queries → `OLAP`
    /// 3. read/write ratio below `oltp_ratio` → `OLTP`
    /// 4. more than 100 queries per minute with a ratio between 1 and 10 → `HTAP`
    /// 5. otherwise → `Mixed`
    pub fn classify_workload(&self, history: &QueryHistory) -> WorkloadType {
        if history.type_fraction(QueryType::VectorSearch) > 0.3 {
            return WorkloadType::Vector;
        }

        let rw_ratio = history.read_write_ratio();

        // Analytic share counts aggregates and joins, the same pair that
        // `QueryType::is_olap` treats as OLAP, so per-table classification
        // agrees with per-query classification.
        let analytic_fraction = history.type_fraction(QueryType::AggregateSelect)
            + history.type_fraction(QueryType::JoinSelect);
        if rw_ratio > self.thresholds.olap_ratio && analytic_fraction > 0.2 {
            return WorkloadType::OLAP;
        }

        if rw_ratio < self.thresholds.oltp_ratio {
            return WorkloadType::OLTP;
        }

        if history.qpm() > 100.0 && rw_ratio > 1.0 && rw_ratio < 10.0 {
            return WorkloadType::HTAP;
        }

        WorkloadType::Mixed
    }

    /// Confidence in the classification.
    ///
    /// The score is the product of a sample-size factor (query count / 1000,
    /// capped at 1.0) and a pattern factor: 0.8 when the read/write ratio
    /// lies clearly on the OLAP or OLTP side of the configured thresholds,
    /// 0.5 when it falls in between.
    pub fn classification_confidence(&self, history: &QueryHistory) -> f64 {
        let query_factor = (history.count() as f64 / 1000.0).min(1.0);
        let rw_ratio = history.read_write_ratio();
        // Use the same configurable cut-offs as `classify_workload`, so that
        // replacing the thresholds keeps both methods in agreement.
        let decisive =
            rw_ratio > self.thresholds.olap_ratio || rw_ratio < self.thresholds.oltp_ratio;
        let pattern_factor = if decisive { 0.8 } else { 0.5 };
        query_factor * pattern_factor
    }
}

impl Default for ClassificationModel {
    fn default() -> Self {
        Self::new()
    }
}
/// Cut-off values used by `ClassificationModel`.
#[derive(Debug, Clone)]
pub struct ModelThresholds {
    /// Queries per minute above which a table is classified as hot.
    pub hot_qpm: f64,
    /// Queries per minute above which a table is classified as warm.
    pub warm_qpm: f64,
    /// Queries per minute above which a table is classified as cold;
    /// at or below this rate it is frozen.
    pub cold_qpm: f64,
    /// Read/write ratio above which a table can be classified as OLAP.
    pub olap_ratio: f64,
    /// Read/write ratio below which a table is classified as OLTP.
    pub oltp_ratio: f64,
}

impl Default for ModelThresholds {
    /// Defaults: 1000 / 100 / 10 queries per minute for hot / warm / cold,
    /// and read/write ratios of 10 (OLAP) and 2 (OLTP).
    fn default() -> Self {
        let (hot_qpm, warm_qpm, cold_qpm) = (1000.0, 100.0, 10.0);
        let (olap_ratio, oltp_ratio) = (10.0, 2.0);
        Self {
            hot_qpm,
            warm_qpm,
            cold_qpm,
            olap_ratio,
            oltp_ratio,
        }
    }
}
/// Snapshot of a table's derived classification together with the
/// statistics it was computed from.
#[derive(Debug, Clone)]
pub struct TableClassification {
/// Name of the table the snapshot describes.
pub table: String,
/// Result of `ClassificationModel::classify_temperature` for the table.
pub temperature: DataTemperature,
/// Result of `ClassificationModel::classify_workload` for the table.
pub workload: WorkloadType,
/// Result of `ClassificationModel::classification_confidence` for the table.
pub confidence: f64,
/// Total number of queries recorded for the table.
pub query_count: u64,
/// Instant at which the table's history was last updated.
pub last_updated: Instant,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `n` identical queries into `history`.
    fn record_many(history: &mut QueryHistory, n: usize, kind: QueryType, latency: Duration) {
        for _ in 0..n {
            history.record(kind, latency);
        }
    }

    /// Counters and the read/write ratio follow the recorded query mix.
    #[test]
    fn test_query_history() {
        let mut stats = QueryHistory::new();
        let samples = [
            (QueryType::SimpleSelect, 10u64),
            (QueryType::SimpleSelect, 20u64),
            (QueryType::Insert, 30u64),
        ];
        for &(kind, millis) in &samples {
            stats.record(kind, Duration::from_millis(millis));
        }

        assert_eq!(stats.count(), 3);
        assert_eq!(stats.read_count, 2);
        assert_eq!(stats.write_count, 1);
        assert_eq!(stats.read_write_ratio(), 2.0);
    }

    /// One representative statement per detectable query type.
    #[test]
    fn test_query_type_detection() {
        let insert = QueryType::from_sql("INSERT INTO users VALUES (1)");
        assert_eq!(insert, QueryType::Insert);

        let update = QueryType::from_sql("UPDATE users SET name = 'x'");
        assert_eq!(update, QueryType::Update);

        let delete = QueryType::from_sql("DELETE FROM users");
        assert_eq!(delete, QueryType::Delete);

        let aggregate = QueryType::from_sql("SELECT COUNT(*) FROM users");
        assert_eq!(aggregate, QueryType::AggregateSelect);

        let simple = QueryType::from_sql("SELECT * FROM users");
        assert_eq!(simple, QueryType::SimpleSelect);

        let join = QueryType::from_sql("SELECT * FROM a JOIN b ON a.id = b.id");
        assert_eq!(join, QueryType::JoinSelect);
    }

    /// A read-dominated mix without aggregates is OLAP or Mixed.
    #[test]
    fn test_classification_model() {
        let model = ClassificationModel::new();
        let mut stats = QueryHistory::new();
        record_many(&mut stats, 1000, QueryType::SimpleSelect, Duration::from_millis(5));
        record_many(&mut stats, 50, QueryType::Insert, Duration::from_millis(10));

        let workload = model.classify_workload(&stats);
        let accepted = [WorkloadType::OLAP, WorkloadType::Mixed];
        assert!(accepted.contains(&workload));
    }

    /// Recording through the classifier makes the table classifiable and
    /// keeps an exact query count.
    #[test]
    fn test_learning_classifier() {
        let classifier = LearningClassifier::new(Arc::new(SchemaRegistry::new()));
        for _ in 0..100 {
            classifier.record("users", QueryType::SimpleSelect, Duration::from_millis(5));
        }

        let snapshot = classifier.get_classification("users");
        assert!(snapshot.is_some());
        let recorded = snapshot.as_ref().map(|c| c.query_count);
        assert_eq!(recorded, Some(100));
    }

    /// Temperature classification yields one of the four known buckets.
    /// The exact bucket depends on wall-clock timing, so any is accepted.
    #[test]
    fn test_temperature_classification() {
        let model = ClassificationModel::new();
        let mut stats = QueryHistory::new();
        record_many(&mut stats, 1000, QueryType::SimpleSelect, Duration::from_millis(1));

        let temp = model.classify_temperature(&stats);
        let buckets = [
            DataTemperature::Hot,
            DataTemperature::Warm,
            DataTemperature::Cold,
            DataTemperature::Frozen,
        ];
        assert!(buckets.contains(&temp));
    }

    /// The average is positive and the 95th percentile is at least the
    /// average for latencies of 0..100 ms.
    #[test]
    fn test_latency_tracking() {
        let mut stats = QueryHistory::new();
        (0..100)
            .map(Duration::from_millis)
            .for_each(|latency| stats.record(QueryType::SimpleSelect, latency));

        let avg = stats.avg_latency();
        assert!(avg.as_millis() > 0);
        assert!(stats.p95_latency() >= avg);
    }
}