use crate::types::*;
use crate::Result;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ComplexityLevel {
Trivial,
Simple,
Moderate,
Complex,
VeryComplex,
}
impl std::fmt::Display for ComplexityLevel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ComplexityLevel::Trivial => write!(f, "TRIVIAL"),
ComplexityLevel::Simple => write!(f, "SIMPLE"),
ComplexityLevel::Moderate => write!(f, "MODERATE"),
ComplexityLevel::Complex => write!(f, "COMPLEX"),
ComplexityLevel::VeryComplex => write!(f, "VERY_COMPLEX"),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TokenEstimate {
pub input_tokens: usize,
pub output_tokens: usize,
pub total_tokens: usize,
pub confidence: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ComplexityAssessment {
pub level: ComplexityLevel,
pub reasoning: String,
pub estimated_steps: usize,
pub estimated_tokens: TokenEstimate,
pub confidence: f32,
pub factors: ComplexityFactors,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ComplexityFactors {
pub length_score: f32,
pub question_score: f32,
pub domain_score: f32,
pub context_score: f32,
pub reasoning_score: f32,
}
impl ComplexityFactors {
pub fn average(&self) -> f32 {
(self.length_score
+ self.question_score
+ self.domain_score
+ self.context_score
+ self.reasoning_score)
/ 5.0
}
}
pub struct ComplexityAnalyzer;
impl ComplexityAnalyzer {
pub fn new() -> Self {
Self
}
pub async fn assess(&self, message: &Memory, state: &State) -> Result<ComplexityAssessment> {
let text = &message.content.text;
let factors = ComplexityFactors {
length_score: self.analyze_length(text),
question_score: self.analyze_questions(text),
domain_score: self.analyze_domain(text),
context_score: self.analyze_context_needed(text, state),
reasoning_score: self.analyze_reasoning_depth(text),
};
let level = self.determine_level(&factors);
let estimated_steps = self.estimate_steps(&level, &factors);
let estimated_tokens = self.estimate_tokens(&level, &factors, text);
let confidence = self.calculate_confidence(&factors);
let reasoning = self.build_reasoning(&factors, &level);
Ok(ComplexityAssessment {
level,
reasoning,
estimated_steps,
estimated_tokens,
confidence,
factors,
})
}
fn analyze_length(&self, text: &str) -> f32 {
let words = text.split_whitespace().count();
match words {
0..=5 => 0.1, 6..=15 => 0.2, 16..=50 => 0.4, 51..=150 => 0.6, 151..=300 => 0.8, _ => 1.0, }
}
fn analyze_questions(&self, text: &str) -> f32 {
let lower = text.to_lowercase();
let mut score = 0.0;
let question_marks = text.matches('?').count();
score += (question_marks as f32 * 0.2).min(0.4);
let complex_patterns = [
"how do i",
"how can i",
"how would",
"why does",
"why is",
"why would",
"what's the difference",
"what is the best way",
"can you explain",
"could you help me understand",
"multiple",
"several",
"various",
];
for pattern in &complex_patterns {
if lower.contains(pattern) {
score += 0.15;
}
}
if lower.contains(" and ") || lower.contains(" or ") {
score += 0.2;
}
score.min(1.0)
}
fn analyze_domain(&self, text: &str) -> f32 {
let lower = text.to_lowercase();
let mut score: f32 = 0.0;
let technical_keywords = [
"algorithm",
"implement",
"code",
"function",
"system",
"architecture",
"database",
"optimization",
"performance",
"security",
"encryption",
"network",
"protocol",
"api",
"machine learning",
"neural network",
"blockchain",
"distributed",
"concurrent",
"async",
"runtime",
];
let mut hits = 0;
for keyword in &technical_keywords {
if lower.contains(keyword) {
hits += 1;
}
}
let academic_keywords = ["research", "study", "analysis", "theory", "hypothesis"];
score = 0.2 + 0.15 * hits as f32;
if lower.contains("consensus") {
score += 0.15;
}
if lower.contains("raft") {
score += 0.15;
}
if lower.contains("leader election") {
score += 0.1;
}
if lower.contains("log replication") {
score += 0.1;
}
score.min(1.0)
}
fn analyze_context_needed(&self, text: &str, state: &State) -> f32 {
let lower = text.to_lowercase();
let mut score: f32 = 0.0;
let context_patterns = [
"previous",
"earlier",
"before",
"last time",
"you said",
"you mentioned",
"as discussed",
"continue",
"following up",
"regarding",
];
for pattern in &context_patterns {
if lower.contains(pattern) {
score += 0.2;
}
}
let pronouns = ["it", "this", "that", "these", "those", "they"];
for pronoun in &pronouns {
if lower.contains(&format!(" {} ", pronoun)) {
score += 0.1;
}
}
if let Some(recent_messages) = state.data.get("recentMessages") {
if let Some(arr) = recent_messages.as_array() {
if arr.len() > 3 {
score += 0.2;
}
}
}
score.min(1.0)
}
fn analyze_reasoning_depth(&self, text: &str) -> f32 {
let lower = text.to_lowercase();
let mut score: f32 = 0.2;
let reasoning_patterns = [
"step by step",
"first",
"then",
"finally",
"process",
"explain how",
"explain why",
"reasoning",
"logic",
"compare",
"contrast",
"analyze",
"evaluate",
"pros and cons",
"advantages",
"disadvantages",
"consider",
"think about",
"take into account",
];
for pattern in &reasoning_patterns {
if lower.contains(pattern) {
score += 0.15;
}
}
if lower.contains("because") || lower.contains("therefore") || lower.contains("thus") {
score += 0.1;
}
if lower.matches(" if ").count() > 1 {
score += 0.2;
}
score.min(1.0)
}
fn determine_level(&self, factors: &ComplexityFactors) -> ComplexityLevel {
let avg = factors.average();
let max_score = factors
.length_score
.max(factors.question_score)
.max(factors.domain_score)
.max(factors.context_score)
.max(factors.reasoning_score);
let weighted = avg * 0.7 + max_score * 0.3;
match weighted {
x if x < 0.2 => ComplexityLevel::Trivial,
x if x < 0.4 => ComplexityLevel::Simple,
x if x < 0.6 => ComplexityLevel::Moderate,
x if x < 0.8 => ComplexityLevel::Complex,
_ => ComplexityLevel::VeryComplex,
}
}
fn estimate_steps(&self, level: &ComplexityLevel, factors: &ComplexityFactors) -> usize {
let base_steps = match level {
ComplexityLevel::Trivial => 1,
ComplexityLevel::Simple => 2,
ComplexityLevel::Moderate => 4,
ComplexityLevel::Complex => 7,
ComplexityLevel::VeryComplex => 12,
};
let adjustment = (factors.reasoning_score * 3.0) as usize;
base_steps + adjustment
}
fn estimate_tokens(
&self,
level: &ComplexityLevel,
factors: &ComplexityFactors,
text: &str,
) -> TokenEstimate {
let message_tokens = (text.len() / 4).max(1);
let context_tokens = (factors.context_score * 300.0) as usize;
let system_tokens = match level {
ComplexityLevel::Trivial => 50,
ComplexityLevel::Simple => 100,
ComplexityLevel::Moderate => 150,
ComplexityLevel::Complex => 200,
ComplexityLevel::VeryComplex => 300,
};
let input_tokens = message_tokens + context_tokens + system_tokens;
let base_output = match level {
ComplexityLevel::Trivial => 50,
ComplexityLevel::Simple => 150,
ComplexityLevel::Moderate => 300,
ComplexityLevel::Complex => 500,
ComplexityLevel::VeryComplex => 1000,
};
let domain_adjustment = (factors.domain_score * 200.0) as usize;
let output_tokens = base_output + domain_adjustment;
let buffered_output = (output_tokens as f32 * 1.2) as usize;
TokenEstimate {
input_tokens,
output_tokens: buffered_output,
total_tokens: input_tokens + buffered_output,
confidence: self.calculate_confidence(factors),
}
}
fn calculate_confidence(&self, factors: &ComplexityFactors) -> f32 {
let avg = factors.average();
let variance = [
(factors.length_score - avg).powi(2),
(factors.question_score - avg).powi(2),
(factors.domain_score - avg).powi(2),
(factors.context_score - avg).powi(2),
(factors.reasoning_score - avg).powi(2),
]
.iter()
.sum::<f32>()
/ 5.0;
let confidence = 1.0 - variance.sqrt().min(0.5);
confidence.max(0.5).min(0.95)
}
fn build_reasoning(&self, factors: &ComplexityFactors, level: &ComplexityLevel) -> String {
format!(
"Complexity: {} | Factors: length={:.2}, questions={:.2}, domain={:.2}, context={:.2}, reasoning={:.2} | Average: {:.2}",
level,
factors.length_score,
factors.question_score,
factors.domain_score,
factors.context_score,
factors.reasoning_score,
factors.average()
)
}
}
impl Default for ComplexityAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use uuid::Uuid;
fn create_test_message(text: &str) -> Memory {
Memory {
id: Uuid::new_v4(),
entity_id: Uuid::new_v4(),
agent_id: Uuid::new_v4(),
room_id: Uuid::new_v4(),
content: Content {
text: text.to_string(),
..Default::default()
},
embedding: None,
metadata: None,
created_at: chrono::Utc::now().timestamp(),
unique: None,
similarity: None,
}
}
#[tokio::test]
async fn test_trivial_complexity() {
let analyzer = ComplexityAnalyzer::new();
let message = create_test_message("Hi");
let state = State::new();
let assessment = analyzer.assess(&message, &state).await.unwrap();
assert!(matches!(
assessment.level,
ComplexityLevel::Trivial | ComplexityLevel::Simple
));
assert!(assessment.estimated_tokens.total_tokens <= 300);
}
#[tokio::test]
async fn test_simple_complexity() {
let analyzer = ComplexityAnalyzer::new();
let message = create_test_message("What's the weather like today?");
let state = State::new();
let assessment = analyzer.assess(&message, &state).await.unwrap();
assert!(matches!(
assessment.level,
ComplexityLevel::Simple | ComplexityLevel::Trivial
));
}
#[tokio::test]
async fn test_complex_technical() {
let analyzer = ComplexityAnalyzer::new();
let message = create_test_message(
"Can you explain how to implement a distributed consensus algorithm \
using Raft protocol, including the leader election process and log replication?",
);
let state = State::new();
let assessment = analyzer.assess(&message, &state).await.unwrap();
assert!(matches!(
assessment.level,
ComplexityLevel::Moderate | ComplexityLevel::Complex | ComplexityLevel::VeryComplex
));
assert!(assessment.factors.domain_score > 0.5);
assert!(assessment.factors.question_score > 0.3);
}
#[tokio::test]
async fn test_token_estimation() {
let analyzer = ComplexityAnalyzer::new();
let short_msg = create_test_message("Hello");
let state = State::new();
let short_assessment = analyzer.assess(&short_msg, &state).await.unwrap();
let long_msg = create_test_message(
"This is a much longer message that contains many words and will require \
more tokens to process and respond to appropriately.",
);
let long_assessment = analyzer.assess(&long_msg, &state).await.unwrap();
assert!(
long_assessment.estimated_tokens.total_tokens
> short_assessment.estimated_tokens.total_tokens
);
}
#[test]
fn test_complexity_level_display() {
assert_eq!(ComplexityLevel::Trivial.to_string(), "TRIVIAL");
assert_eq!(ComplexityLevel::Complex.to_string(), "COMPLEX");
}
}