use serde::{Deserialize, Serialize};
/// Coarse estimate of how much output a user message calls for.
///
/// Produced by [`classify_query`] and converted into a generation budget by
/// [`QueryComplexity::max_tokens`]. Keep the variant order stable: serde
/// formats that encode enum variants by index rely on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum QueryComplexity {
    /// Yes/no or single-fact question, e.g. "Is the sky blue?".
    VeryShort,
    /// Lookup-style request: who/what/where/when, definitions, translation.
    Short,
    /// Explanations and summaries; also the fallback when no rule matches.
    Medium,
    /// Multi-question, in-depth, analytical, or drafting requests.
    Long,
    /// Code or long-form writing; budgeted with the model's full maximum.
    Unlimited,
}
impl QueryComplexity {
    /// Returns the maximum number of tokens to generate for this tier.
    ///
    /// Each bounded tier has a fixed cap (64 / 256 / 512 / 1024 tokens), and
    /// `Unlimited` defers entirely to `model_max`. The result never exceeds
    /// `model_max`, so a model with a small limit is never asked for more
    /// tokens than it supports.
    ///
    /// # Examples
    ///
    /// ```ignore
    /// assert_eq!(QueryComplexity::Long.max_tokens(2048), 1024);
    /// assert_eq!(QueryComplexity::Long.max_tokens(300), 300);
    /// ```
    #[must_use]
    pub const fn max_tokens(self, model_max: u32) -> u32 {
        let tier_cap = match self {
            Self::VeryShort => 64,
            Self::Short => 256,
            Self::Medium => 512,
            Self::Long => 1024,
            Self::Unlimited => return model_max,
        };
        // `Ord::min` is not callable from a `const fn`, so clamp by hand.
        if tier_cap < model_max {
            tier_cap
        } else {
            model_max
        }
    }
}
#[must_use]
pub fn classify_query(message: &str) -> QueryComplexity {
let lower = message.to_lowercase();
let trimmed = lower.trim();
if is_unlimited(trimmed) {
return QueryComplexity::Unlimited;
}
if is_long(trimmed) {
return QueryComplexity::Long;
}
if is_very_short(trimmed) {
return QueryComplexity::VeryShort;
}
if is_short(trimmed) {
return QueryComplexity::Short;
}
if is_medium(trimmed) {
return QueryComplexity::Medium;
}
QueryComplexity::Medium
}
/// Returns `true` when the (already lower-cased) query asks for code or for
/// long-form writing.
///
/// Matching is plain substring search, so e.g. "implement" also matches
/// "implementation".
fn is_unlimited(query: &str) -> bool {
    const CODE_REQUESTS: &[&str] = &[
        "write code",
        "write a code",
        "write the code",
        "implement",
        "create a function",
        "write a function",
        "write a program",
        "write a script",
        "write a class",
        "write a module",
        "code example",
        "code snippet",
        "refactor",
        "debug this",
        "fix this code",
    ];
    const LONGFORM_REQUESTS: &[&str] = &[
        "write an essay",
        "write a story",
        "write an article",
        "write a report",
        "write a tutorial",
        "write a guide",
        "write a blog",
        "generate a document",
        "full implementation",
    ];

    CODE_REQUESTS
        .iter()
        .chain(LONGFORM_REQUESTS)
        .copied()
        .any(|request| query.contains(request))
}
/// Returns `true` when the (already lower-cased) query looks like it needs an
/// extended answer.
///
/// That means either several questions in one message (two or more `?`), or a
/// cue for depth, analysis, comparison, or drafting.
fn is_long(query: &str) -> bool {
    const DEPTH_CUES: &[&str] = &[
        "and also",
        "in detail",
        "in depth",
        "step by step",
        "analyze",
        "analyse",
        "compare and contrast",
        "pros and cons",
        "advantages and disadvantages",
        "write a",
        "draft a",
        "compose",
        "create a plan",
        "design a",
    ];

    let asks_several_questions = query.matches('?').count() >= 2;
    asks_several_questions || DEPTH_CUES.iter().copied().any(|cue| query.contains(cue))
}
/// Returns `true` for a (lower-cased, trimmed) question expecting a one-word
/// or single-fact answer.
///
/// The query must end in `?` and either open with a yes/no auxiliary verb or
/// contain a single-fact phrase such as "how many".
fn is_very_short(query: &str) -> bool {
    const YES_NO_OPENERS: &[&str] = &[
        "is ", "are ", "was ", "were ", "do ", "does ", "did ", "can ", "could ", "will ",
        "would ", "should ", "has ", "have ", "had ",
    ];
    const SINGLE_FACT_PHRASES: &[&str] = &[
        "what time",
        "what day",
        "what date",
        "how old",
        "how many",
        "how much",
        "what year",
        "what color",
        "what colour",
        "how tall",
        "how far",
        "how long is",
        "how long does",
        "what is the capital",
        "what is the population",
        "who won",
        "true or false",
    ];

    // Only explicit questions qualify.
    if !query.ends_with('?') {
        return false;
    }

    YES_NO_OPENERS.iter().copied().any(|opener| query.starts_with(opener))
        || SINGLE_FACT_PHRASES
            .iter()
            .copied()
            .any(|phrase| query.contains(phrase))
}
/// Returns `true` when the (already lower-cased) query is a lookup-style
/// request: identity, definition, time/place, translation, conversion, or
/// calculation.
fn is_short(query: &str) -> bool {
    const LOOKUP_CUES: &[&str] = &[
        "who is",
        "who was",
        "what is",
        "what are",
        "what was",
        "define ",
        "definition of",
        "meaning of",
        "when did",
        "when was",
        "when is",
        "where is",
        "where was",
        "where are",
        "translate",
        "convert ",
        "calculate ",
        "what does",
    ];

    LOOKUP_CUES.iter().copied().any(|cue| query.contains(cue))
}
/// Returns `true` when the (already lower-cased) query asks for an
/// explanation, description, summary, list, or overview.
fn is_medium(query: &str) -> bool {
    const EXPLANATION_CUES: &[&str] = &[
        "explain",
        "how does",
        "how do",
        "how can",
        "how to",
        "why ",
        "describe",
        "summarize",
        "summarise",
        "list ",
        "name ",
        "what happens",
        "tell me about",
        "give me",
        "overview",
        "difference between",
    ];

    EXPLANATION_CUES.iter().copied().any(|cue| query.contains(cue))
}
#[cfg(test)]
mod tests {
    use super::QueryComplexity::{Long, Medium, Short, Unlimited, VeryShort};
    use super::*;

    /// Classifies `query` and checks the resulting tier, echoing the query on failure.
    fn assert_tier(query: &str, expected: QueryComplexity) {
        assert_eq!(classify_query(query), expected, "query: {:?}", query);
    }

    // --- VeryShort ---------------------------------------------------------

    #[test]
    fn yes_no_question_is_very_short() {
        assert_tier("Is the sky blue?", VeryShort);
    }

    #[test]
    fn can_question_is_very_short() {
        assert_tier("Can dogs swim?", VeryShort);
    }

    #[test]
    fn what_time_is_very_short() {
        assert_tier("What time is it in Tokyo?", VeryShort);
    }

    #[test]
    fn how_many_is_very_short() {
        assert_tier("How many planets are in the solar system?", VeryShort);
    }

    #[test]
    fn capital_question_is_very_short() {
        assert_tier("What is the capital of France?", VeryShort);
    }

    // --- Short -------------------------------------------------------------

    #[test]
    fn who_is_query_is_short() {
        assert_tier("Who is Albert Einstein?", Short);
    }

    #[test]
    fn definition_query_is_short() {
        assert_tier("Define photosynthesis", Short);
    }

    #[test]
    fn when_did_is_short() {
        assert_tier("When did World War 2 end?", Short);
    }

    #[test]
    fn translate_is_short() {
        assert_tier("Translate 'hello' to German", Short);
    }

    #[test]
    fn where_is_query_is_short() {
        assert_tier("Where is the Eiffel Tower?", Short);
    }

    // --- Medium ------------------------------------------------------------

    #[test]
    fn explain_query_is_medium() {
        assert_tier("Explain how gravity works", Medium);
    }

    #[test]
    fn how_does_query_is_medium() {
        assert_tier("How does a combustion engine work?", Medium);
    }

    #[test]
    fn list_query_is_medium() {
        assert_tier("List the top 5 programming languages", Medium);
    }

    #[test]
    fn why_query_is_medium() {
        assert_tier("Why is the ocean salty?", Medium);
    }

    #[test]
    fn summarize_query_is_medium() {
        assert_tier("Summarize the plot of Hamlet", Medium);
    }

    // --- Long --------------------------------------------------------------

    #[test]
    fn multi_question_is_long() {
        assert_tier("What is Rust? Why is it popular?", Long);
    }

    #[test]
    fn in_detail_is_long() {
        assert_tier("Explain quantum physics in detail", Long);
    }

    #[test]
    fn step_by_step_is_long() {
        assert_tier("Show me step by step how to bake bread", Long);
    }

    #[test]
    fn analyze_is_long() {
        assert_tier("Analyze the themes in Macbeth", Long);
    }

    #[test]
    fn pros_and_cons_is_long() {
        assert_tier("What are the pros and cons of electric cars?", Long);
    }

    // --- Unlimited ---------------------------------------------------------

    #[test]
    fn write_code_is_unlimited() {
        assert_tier("Write code to sort a list in Python", Unlimited);
    }

    #[test]
    fn implement_is_unlimited() {
        assert_tier("Implement a binary search tree", Unlimited);
    }

    #[test]
    fn write_essay_is_unlimited() {
        assert_tier("Write an essay about climate change", Unlimited);
    }

    #[test]
    fn create_function_is_unlimited() {
        assert_tier("Create a function that validates emails", Unlimited);
    }

    #[test]
    fn refactor_is_unlimited() {
        assert_tier("Refactor this code to use async/await", Unlimited);
    }

    // --- Token budgets -----------------------------------------------------

    #[test]
    fn max_tokens_very_short() {
        assert_eq!(VeryShort.max_tokens(2048), 64);
    }

    #[test]
    fn max_tokens_short() {
        assert_eq!(Short.max_tokens(2048), 256);
    }

    #[test]
    fn max_tokens_medium() {
        assert_eq!(Medium.max_tokens(2048), 512);
    }

    #[test]
    fn max_tokens_long() {
        assert_eq!(Long.max_tokens(2048), 1024);
    }

    #[test]
    fn max_tokens_unlimited_uses_model_max() {
        assert_eq!(Unlimited.max_tokens(2048), 2048);
        assert_eq!(Unlimited.max_tokens(4096), 4096);
    }

    // --- Fallbacks and normalization ---------------------------------------

    #[test]
    fn ambiguous_query_defaults_to_medium() {
        assert_tier("Tell me something", Medium);
    }

    #[test]
    fn empty_query_defaults_to_medium() {
        assert_tier("", Medium);
    }

    #[test]
    fn case_insensitive_classification() {
        assert_tier("EXPLAIN how DNS works", Medium);
        assert_tier("WRITE CODE for a web server", Unlimited);
    }
}