use std::collections::HashMap;
/// Category assigned to a query by `classify_query`.
///
/// When a query matches more than one pattern list, `classify_query` picks the
/// category whose list is checked first: `Repair`, then `ConstraintCheck`,
/// then `Temporal`, then `Aggregation`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryComplexity {
    /// Fallback used when no pattern list matches the query.
    SimpleLookup,
    /// The query matched an entry of `CONSTRAINT_PATTERNS`.
    ConstraintCheck,
    /// The query matched an entry of `REPAIR_PATTERNS`.
    Repair,
    /// The query matched an entry of `TEMPORAL_PATTERNS`.
    Temporal,
    /// The query matched an entry of `AGGREGATION_PATTERNS`.
    Aggregation,
}
/// Outcome of classifying a single query, as produced by `classify_query`.
#[derive(Debug, Clone)]
pub struct QueryClassification {
    /// Category chosen for the query.
    pub complexity: QueryComplexity,
    /// Pattern strings from the winning list that were found in the query;
    /// empty when the query fell through to `QueryComplexity::SimpleLookup`.
    pub matched_patterns: Vec<String>,
    /// Weight per layer, keyed by layer number. `classify_query` fills this
    /// from `default_layer_weights`, which populates keys 1 through 4.
    pub layer_weights: HashMap<u8, f64>,
}
/// Lower-case trigger phrases for `QueryComplexity::Repair`.
///
/// `classify_query` searches the lower-cased query for these and consults this
/// list before any other. Hits are reported in the order listed here, so the
/// order of entries is part of the observable output.
const REPAIR_PATTERNS: &[&str] = &[
    "now that",
    "given the change",
    "after the update",
    "recalculate",
    "revised",
    "updated",
    "what is the new",
    "what's the new",
    "how does this affect",
    "impact of the change",
    "corrected",
    "adjusted",
    "changed from",
    "cascade",
    "propagate",
    "downstream",
];
/// Lower-case trigger phrases for `QueryComplexity::ConstraintCheck`.
///
/// `classify_query` consults this list second, after `REPAIR_PATTERNS` and
/// before `TEMPORAL_PATTERNS`. Hits are reported in the order listed here.
const CONSTRAINT_PATTERNS: &[&str] = &[
    "can we",
    "should we",
    "is it allowed",
    "is this allowed",
    "do we have approval",
    "does this comply",
    "within budget",
    "approved",
    "permitted",
    "authorize",
    "feasible",
    "proceed with",
    "go ahead",
    "move forward",
    "violate",
    "breach",
    "exceed",
    "comply",
    "eligible",
    "qualified",
    "meets the requirement",
    "policy allows",
    "allowed to",
];
/// Lower-case trigger phrases for `QueryComplexity::Temporal`.
///
/// `classify_query` consults this list third, after `CONSTRAINT_PATTERNS` and
/// before `AGGREGATION_PATTERNS`. Hits are reported in the order listed here.
const TEMPORAL_PATTERNS: &[&str] = &[
    "when is",
    "when does",
    "when will",
    "deadline",
    "due date",
    "schedule",
    "how long",
    "timeline",
    "by when",
    "expired",
    "still valid",
    "current status",
    "latest",
    "most recent",
    "what changed",
];
/// Lower-case trigger phrases for `QueryComplexity::Aggregation`.
///
/// `classify_query` consults this list last, after `TEMPORAL_PATTERNS`.
/// Hits are reported in the order listed here.
const AGGREGATION_PATTERNS: &[&str] = &[
    "how many",
    "how much",
    "total",
    "summarize",
    "summary",
    "list all",
    "overview",
    "across all",
    "combined",
    "count",
    "aggregate",
    "everything about",
];
/// Builds the default weight table for a query category.
///
/// The returned map has exactly four entries, keyed by layer number 1 through
/// 4. Each category's four weights add up to 1.0.
fn default_layer_weights(complexity: QueryComplexity) -> HashMap<u8, f64> {
    // One row per category; columns are layer 1, layer 2, layer 3, layer 4.
    let row: [f64; 4] = match complexity {
        QueryComplexity::SimpleLookup => [0.05, 0.60, 0.20, 0.15],
        QueryComplexity::ConstraintCheck => [0.05, 0.65, 0.20, 0.10],
        QueryComplexity::Repair => [0.05, 0.60, 0.25, 0.10],
        QueryComplexity::Temporal => [0.05, 0.35, 0.25, 0.35],
        QueryComplexity::Aggregation => [0.05, 0.55, 0.25, 0.15],
    };

    // Pair each column with its 1-based layer number.
    (1u8..).zip(row.iter().copied()).collect()
}
/// Classifies a free-text query by scanning it for known trigger phrases.
///
/// The query is lower-cased before matching, so matching is case-insensitive,
/// and the typographic apostrophe (U+2019) is folded to `'` so that
/// "what’s the new" is treated like "what's the new".
///
/// Pattern lists are tried in priority order (repair, constraint, temporal,
/// aggregation); the first list with at least one hit decides the category.
/// Hits from lower-priority lists are not collected.
///
/// A pattern only counts when it starts at a word boundary, so `"count"` does
/// not fire on "account" and `"changed from"` does not fire on
/// "unchanged from". The end of the pattern is not anchored, so inflected
/// forms such as "authorized" or "exceeds" still match.
///
/// # Returns
///
/// A [`QueryClassification`] holding the chosen category, the patterns from
/// the winning list that matched (in list order), and the default layer
/// weights for that category. When nothing matches, the category is
/// [`QueryComplexity::SimpleLookup`] and `matched_patterns` is empty.
pub fn classify_query(query: &str) -> QueryClassification {
    let normalized = query.to_lowercase().replace('\u{2019}', "'");

    // Order encodes priority: the first list with a hit wins.
    let categories = [
        (REPAIR_PATTERNS, QueryComplexity::Repair),
        (CONSTRAINT_PATTERNS, QueryComplexity::ConstraintCheck),
        (TEMPORAL_PATTERNS, QueryComplexity::Temporal),
        (AGGREGATION_PATTERNS, QueryComplexity::Aggregation),
    ];

    for (patterns, complexity) in categories {
        let matched_patterns: Vec<String> = patterns
            .iter()
            .copied()
            .filter(|pattern| contains_at_word_start(&normalized, pattern))
            .map(str::to_owned)
            .collect();

        if !matched_patterns.is_empty() {
            return QueryClassification {
                complexity,
                matched_patterns,
                layer_weights: default_layer_weights(complexity),
            };
        }
    }

    QueryClassification {
        complexity: QueryComplexity::SimpleLookup,
        matched_patterns: Vec::new(),
        layer_weights: default_layer_weights(QueryComplexity::SimpleLookup),
    }
}

/// Reports whether `pattern` occurs in `haystack` at a position that is not
/// immediately preceded by an alphanumeric character.
fn contains_at_word_start(haystack: &str, pattern: &str) -> bool {
    let mut search_from = 0;
    while let Some(offset) = haystack[search_from..].find(pattern) {
        let start = search_from + offset;
        let inside_word = haystack[..start]
            .chars()
            .next_back()
            .map_or(false, char::is_alphanumeric);
        if !inside_word {
            return true;
        }
        // Step one character (not one whole match) so an occurrence that
        // overlaps the rejected one is still examined.
        search_from = start + haystack[start..].chars().next().map_or(1, char::len_utf8);
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `query` and returns only the chosen category.
    fn complexity_of(query: &str) -> QueryComplexity {
        classify_query(query).complexity
    }

    /// A query with no trigger phrase falls back to `SimpleLookup`.
    #[test]
    fn simple_lookup() {
        assert_eq!(
            complexity_of("What is the project budget?"),
            QueryComplexity::SimpleLookup
        );
    }

    /// A constraint phrase selects `ConstraintCheck` and is reported as a hit.
    #[test]
    fn constraint_check() {
        let QueryClassification {
            complexity,
            matched_patterns,
            ..
        } = classify_query("Can we proceed with this purchase?");
        assert_eq!(complexity, QueryComplexity::ConstraintCheck);
        assert!(!matched_patterns.is_empty());
    }

    /// A repair phrase selects `Repair`.
    #[test]
    fn repair_query() {
        assert_eq!(
            complexity_of("Now that the budget changed, what is the new headcount?"),
            QueryComplexity::Repair
        );
    }

    /// A temporal phrase selects `Temporal`.
    #[test]
    fn temporal_query() {
        assert_eq!(
            complexity_of("When is the project deadline?"),
            QueryComplexity::Temporal
        );
    }

    /// An aggregation phrase selects `Aggregation`.
    #[test]
    fn aggregation_query() {
        assert_eq!(
            complexity_of("How many engineers are on the team total?"),
            QueryComplexity::Aggregation
        );
    }

    /// When both repair and constraint phrases are present, repair wins.
    #[test]
    fn repair_takes_priority_over_constraint() {
        assert_eq!(
            complexity_of("Now that the policy changed, can we proceed?"),
            QueryComplexity::Repair
        );
    }

    /// Every category's default weights add up to 1 (within 0.01).
    #[test]
    fn layer_weights_sum_to_one() {
        let every_category = [
            QueryComplexity::SimpleLookup,
            QueryComplexity::ConstraintCheck,
            QueryComplexity::Repair,
            QueryComplexity::Temporal,
            QueryComplexity::Aggregation,
        ];
        for complexity in every_category.iter().copied() {
            let sum: f64 = default_layer_weights(complexity).values().sum();
            let drift = (sum - 1.0).abs();
            assert!(drift < 0.01, "{:?} weights sum to {}", complexity, sum);
        }
    }

    /// Layer 4 gets a larger share for temporal queries than for simple lookups.
    #[test]
    fn temporal_gives_more_to_environment() {
        let layer_four = |complexity: QueryComplexity| default_layer_weights(complexity)[&4];
        assert!(
            layer_four(QueryComplexity::Temporal) > layer_four(QueryComplexity::SimpleLookup),
            "temporal should give more budget to environment layer"
        );
    }
}