use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};
/// Difficulty tier assigned to a prompt by the `TaskComplexity` classifiers.
///
/// The enum is fieldless, so it is `Copy` and can be compared with `==` (`Eq`)
/// or used as a `HashMap`/`HashSet` key (`Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskComplexity {
    /// Lowest tier.
    Simple,
    /// Intermediate tier.
    Medium,
    /// Highest tier.
    Complex,
}
impl TaskComplexity {
    /// Classifies a prompt by its token count alone.
    ///
    /// Up to 200 tokens is [`TaskComplexity::Simple`], 201–1000 tokens is
    /// [`TaskComplexity::Medium`], and anything larger is [`TaskComplexity::Complex`].
    /// Token counting is delegated to `count_tokens`.
    pub fn from_prompt(prompt: &str) -> Self {
        match count_tokens(prompt) {
            0..=200 => TaskComplexity::Simple,
            201..=1000 => TaskComplexity::Medium,
            _ => TaskComplexity::Complex,
        }
    }

    /// Classifies a prompt with a keyword and length heuristic.
    ///
    /// The score is the sum of:
    /// - `+2.0` for each distinct "complex" keyword present (e.g. `architecture`, `debug`);
    /// - `-1.0` for each distinct "simple" keyword or phrase present (e.g. `what is`, `bug`);
    /// - `+1.0` per 1000 characters of prompt, capped at `+5.0`;
    /// - `+0.5` if the prompt mentions `fn `, `function` or `class` (any case);
    /// - `+0.3` if the prompt contains both `{` and `}`.
    ///
    /// A score below 1.0 is Simple, below 3.0 is Medium, and anything else is Complex.
    ///
    /// Keywords are matched case-insensitively and only where they begin a word.
    /// So `debug` does not also count as the simple keyword `bug`, and `prefix` does
    /// not count as `fix`. Inflected forms such as `debugging` or `bugs` still match.
    pub fn from_content_analysis(prompt: &str) -> Self {
        const COMPLEX_KEYWORDS: &[&str] = &[
            "architecture",
            "design",
            "system",
            "refactor",
            "debug",
            "performance",
            "optimization",
            "security",
            "scaling",
            "distributed",
            "patterns",
            "algorithm",
        ];
        const SIMPLE_KEYWORDS: &[&str] = &[
            "what is", "how to", "define", "explain", "syntax", "fix", "bug",
        ];
        const CODE_MARKERS: &[&str] = &["fn ", "function", "class"];

        let prompt_lower = prompt.to_lowercase();

        let complex_hits = Self::count_mentions(&prompt_lower, COMPLEX_KEYWORDS) as f64;
        let simple_hits = Self::count_mentions(&prompt_lower, SIMPLE_KEYWORDS) as f64;
        let mut complexity_score = 2.0 * complex_hits - simple_hits;

        // Measure length in characters, not UTF-8 bytes. Otherwise non-ASCII prompts
        // would be scored as up to 4x longer than they really are.
        complexity_score += (prompt.chars().count() as f64 / 1000.0).min(5.0);

        // Checked on the lower-cased text like the keywords, so a sentence-initial
        // "Function" or "Class" still counts as a code marker.
        if CODE_MARKERS
            .iter()
            .any(|&marker| prompt_lower.contains(marker))
        {
            complexity_score += 0.5;
        }
        if prompt.contains('{') && prompt.contains('}') {
            complexity_score += 0.3;
        }

        match complexity_score {
            score if score < 1.0 => TaskComplexity::Simple,
            score if score < 3.0 => TaskComplexity::Medium,
            _ => TaskComplexity::Complex,
        }
    }

    /// Combines [`Self::from_prompt`] and [`Self::from_content_analysis`] and returns
    /// the higher of the two tiers (Complex over Medium over Simple).
    pub fn analyze(prompt: &str) -> Self {
        match (Self::from_prompt(prompt), Self::from_content_analysis(prompt)) {
            (TaskComplexity::Complex, _) | (_, TaskComplexity::Complex) => TaskComplexity::Complex,
            (TaskComplexity::Medium, _) | (_, TaskComplexity::Medium) => TaskComplexity::Medium,
            _ => TaskComplexity::Simple,
        }
    }

    /// Counts how many of `terms` are mentioned in `text` (see `Self::mentions`).
    /// Each term counts at most once, however often it occurs.
    fn count_mentions(text: &str, terms: &[&str]) -> usize {
        terms
            .iter()
            .filter(|&&term| Self::mentions(text, term))
            .count()
    }

    /// Returns `true` if `term` occurs in `text` at the start of a word. "Start of a
    /// word" means not immediately preceded by an alphanumeric character.
    fn mentions(text: &str, term: &str) -> bool {
        text.match_indices(term).any(|(start, _)| {
            // `start` comes from `match_indices`, so it is always a char boundary.
            !matches!(text[..start].chars().next_back(), Some(prev) if prev.is_alphanumeric())
        })
    }
}
/// Counts the tokens in `text` with `tiktoken_rs::num_tokens_from_messages`.
/// The text is sent as a single `"user"` chat message using the `gpt-3.5-turbo` model.
///
/// If the tokenizer returns an error, this is best-effort: it falls back to a rough
/// estimate of one token per four characters.
///
/// NOTE(review): the tiktoken_rs message count presumably includes per-message chat
/// framing, so it may be slightly larger than the token count of `text` alone. Verify
/// against the tiktoken_rs version in use.
fn count_tokens(text: &str) -> usize {
    let user_message = ChatCompletionRequestMessage {
        role: "user".to_string(),
        content: Some(text.to_string()),
        name: None,
        function_call: None,
    };
    num_tokens_from_messages("gpt-3.5-turbo", &[user_message])
        .unwrap_or_else(|_| text.chars().count() / 4)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short factual question: few tokens and a "simple" phrase (`what is`).
    #[test]
    fn test_simple_task() {
        let prompt = "What is 2 + 2?";
        assert_eq!(TaskComplexity::analyze(prompt), TaskComplexity::Simple);
    }

    /// Short code-change request with a single "complex" keyword (`refactor`).
    ///
    /// A short prompt needs at least one weighted keyword to reach Medium. The
    /// code-marker bonus (+0.5) plus a small length factor alone stays below the 1.0
    /// threshold, so a keyword-free "write a function" request is classified Simple.
    #[test]
    fn test_medium_task() {
        let prompt = "Refactor this Rust function to use iterators instead of index loops.";
        assert_eq!(TaskComplexity::analyze(prompt), TaskComplexity::Medium);
    }

    /// Several "complex" keywords (`design`, `distributed`, `system`) push the content
    /// score well past 3.0.
    #[test]
    fn test_complex_task() {
        let prompt = r#"
Design a distributed caching system that handles high throughput requests,
provides cache invalidation, and supports multiple cache eviction strategies.
Consider CAP theorem tradeoffs and provide implementation details for each component.
"#;
        assert_eq!(TaskComplexity::analyze(prompt), TaskComplexity::Complex);
    }

    /// Prompt length contributes +1.0 per 1000 characters (capped at +5.0), so a long
    /// keyword-free prompt is still Complex.
    #[test]
    fn test_length_alone_can_make_complex() {
        let prompt = "a ".repeat(1600);
        assert_eq!(
            TaskComplexity::from_content_analysis(&prompt),
            TaskComplexity::Complex
        );
    }
}