use omk::runtime::classifier::{
cache::cache_key, cache::new_session_cache, classify, heuristic::heuristic_classify,
heuristic::HeuristicOutcome, llm_backend::MockLlmClassifier, llm_backend::RawLlmClassification,
ClassificationSource, ClassifierInput, Intent, LlmClassifierBackend,
};
use std::path::PathBuf;
use std::sync::Arc;
/// Serialises every test in this file.
///
/// Several tests point the process-wide `XDG_STATE_HOME` variable at a
/// temporary directory, so tests must not overlap. An async (`tokio`) mutex is
/// used because each test holds the guard across `.await` points.
static TEST_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::<()>::const_new(());
/// Labelled evaluation set for `test_dataset_agreement_at_least_85_percent`.
///
/// Each entry pairs a user prompt with the `Intent` it is expected to be
/// classified as. The agreement test scripts the mock LLM to answer each prompt
/// with its label here, then measures how often `classify` ends up agreeing.
const PROMPTS_LABELED: &[(&str, Intent)] = &[
("what does build_task_graph do?", Intent::Trivial),
(
"rename build_task_graph to compile_task_graph everywhere",
Intent::Small,
),
(
"add input validation to /signup and write tests",
Intent::Medium,
),
(
"add OAuth login with Google and GitHub plus rate limiting",
Intent::Large,
),
("what is the meaning of life?", Intent::Trivial),
("explain quantum computing", Intent::Trivial),
("show me the main function", Intent::Trivial),
("how does the cache work?", Intent::Trivial),
("where is the config file?", Intent::Trivial),
("does this compile?", Intent::Trivial),
("define monad", Intent::Trivial),
("summarise this module", Intent::Trivial),
("summary of the changes", Intent::Trivial),
("fix the typo in README", Intent::Small),
("refactor utils.rs into smaller functions", Intent::Small),
("add a test for edge case", Intent::Small),
("update the error message", Intent::Small),
("extract helper function", Intent::Small),
("rename variable x to count", Intent::Small),
("move logic into separate module", Intent::Medium),
(
"implement rate limiting middleware with tests",
Intent::Medium,
),
("add database migration for users table", Intent::Medium),
("refactor auth module and add unit tests", Intent::Medium),
("create new CLI command with subcommands", Intent::Medium),
("redesign the API with breaking changes", Intent::Large),
("implement distributed consensus algorithm", Intent::Large),
("add support for multiple cloud providers", Intent::Large),
("rewrite the rendering engine", Intent::Large),
("migrate from REST to GraphQL", Intent::Large),
("introduce plugin architecture", Intent::Large),
("add end-to-end encryption", Intent::Large),
("build a new frontend framework", Intent::Large),
("integrate with external payment gateway", Intent::Medium),
("add caching layer with invalidation", Intent::Medium),
("optimise hot path in query engine", Intent::Small),
("document the public API", Intent::Trivial),
("what is this function doing?", Intent::Trivial),
("how do I run tests?", Intent::Trivial),
("explain the build system", Intent::Trivial),
("show dependencies", Intent::Trivial),
("is this thread-safe?", Intent::Trivial),
("where are the types defined?", Intent::Trivial),
("does this handle null?", Intent::Trivial),
("define the interface", Intent::Trivial),
("summarise errors.rs", Intent::Trivial),
("summary of PR #123", Intent::Trivial),
("fix compilation error in lib.rs", Intent::Small),
("add logging to debug flow", Intent::Small),
("remove unused import", Intent::Small),
("update version to 1.0", Intent::Small),
];
/// Returns the lowercase label that the mock LLM responses in this file place
/// in their JSON `"intent"` field for `intent`.
fn intent_to_str(intent: Intent) -> &'static str {
    // Deliberately exhaustive with no catch-all arm: a new `Intent` variant must
    // be given a label here before this file compiles again.
    match intent {
        Intent::Large => "large",
        Intent::Medium => "medium",
        Intent::Small => "small",
        Intent::Trivial => "trivial",
    }
}
/// Runs every prompt in `PROMPTS_LABELED` through `classify` and requires at
/// least 85% of the resulting intents to match their labels.
///
/// The mock backend is scripted to answer each prompt with its own label, so
/// any disagreement comes from the pipeline's own logic. For example, a
/// heuristic result could take precedence over the scripted answer.
#[tokio::test]
async fn test_dataset_agreement_at_least_85_percent() {
    let _guard = TEST_LOCK.lock().await;

    let scripted = PROMPTS_LABELED
        .iter()
        .fold(MockLlmClassifier::new(), |acc, &(prompt, label)| {
            let raw_json = format!(
                r#"{{"intent":"{}","confidence":0.9,"reasoning":"mock","signals":[],"suggested_action":null}}"#,
                intent_to_str(label)
            );
            let key = cache_key(prompt);
            acc.with_answer(
                key,
                RawLlmClassification {
                    raw_json,
                    model: "mock".to_string(),
                    tokens_in: 10,
                    tokens_out: 10,
                },
            )
        });
    let backend: Arc<dyn LlmClassifierBackend> = Arc::new(scripted);
    let mut cache = new_session_cache();

    let mut agreements = 0usize;
    for &(prompt, expected) in PROMPTS_LABELED {
        let request = ClassifierInput {
            prompt: prompt.to_owned(),
            recent_conversation: Vec::new(),
            project_root: PathBuf::from("."),
        };
        let verdict = classify(request, backend.as_ref(), &mut cache).await;
        agreements += usize::from(verdict.intent == expected);
    }

    let total = PROMPTS_LABELED.len() as f32;
    let ratio = agreements as f32 / total;
    assert!(ratio >= 0.85, "agreement {:.2}% < 85%", ratio * 100.0);
}
/// A question-style prompt ("what is ...?") should be labelled `Trivial` by
/// the heuristic stage, and quickly.
#[tokio::test]
async fn test_heuristic_catches_trivial_prefix() {
    let _guard = TEST_LOCK.lock().await;

    // The backend has no scripted answers; the result is expected to come from
    // the heuristic path.
    let backend: Arc<dyn LlmClassifierBackend> = Arc::new(MockLlmClassifier::new());
    let mut cache = new_session_cache();

    let verdict = classify(
        ClassifierInput {
            prompt: String::from("what is X?"),
            recent_conversation: Vec::new(),
            project_root: PathBuf::from("."),
        },
        backend.as_ref(),
        &mut cache,
    )
    .await;

    assert_eq!(verdict.intent, Intent::Trivial);
    assert_eq!(verdict.source, ClassificationSource::Heuristic);
    assert!(verdict.latency_ms < 5);
}
/// Input that starts with `/` is reported by the heuristic as a slash command
/// rather than being classified.
#[tokio::test]
async fn test_heuristic_rejects_slash_command() {
    let _guard = TEST_LOCK.lock().await;
    let is_slash_command = matches!(
        heuristic_classify("/classify foo"),
        HeuristicOutcome::SlashCommand
    );
    assert!(is_slash_command);
}
/// Both the empty string and whitespace-only input are reported as `Empty`.
#[tokio::test]
async fn test_heuristic_rejects_empty() {
    let _guard = TEST_LOCK.lock().await;
    for blank in ["", " \n"] {
        assert!(matches!(heuristic_classify(blank), HeuristicOutcome::Empty));
    }
}
/// Classifying the same prompt twice with one session cache should hit the
/// backend first and be served from the cache (within 5 ms) the second time.
#[tokio::test]
async fn test_cache_hit_returns_cached_result_under_5ms() {
    const PROMPT: &str = "repeat me";
    let _guard = TEST_LOCK.lock().await;

    let key = cache_key(PROMPT);
    let answer = RawLlmClassification {
        raw_json: r#"{"intent":"medium","confidence":0.82,"reasoning":"mock","signals":[],"suggested_action":null}"#.to_string(),
        model: "mock".to_string(),
        tokens_in: 10,
        tokens_out: 10,
    };
    let backend: Arc<dyn LlmClassifierBackend> =
        Arc::new(MockLlmClassifier::new().with_answer(key, answer));
    let mut cache = new_session_cache();

    let request = ClassifierInput {
        prompt: PROMPT.to_owned(),
        recent_conversation: Vec::new(),
        project_root: PathBuf::from("."),
    };

    // Cold call: the fresh cache is empty, so the backend answers.
    let cold = classify(request.clone(), backend.as_ref(), &mut cache).await;
    assert_eq!(cold.source, ClassificationSource::Llm);

    // Warm call: the identical request must now come from the cache.
    let warm = classify(request, backend.as_ref(), &mut cache).await;
    assert_eq!(warm.source, ClassificationSource::Cache);
    assert!(warm.latency_ms <= 5);
}
/// When the backend's reply is not JSON at all, `classify` must flag the
/// result as a fallback and must not answer `Large`.
#[tokio::test]
async fn test_llm_fallback_on_malformed_json_returns_heuristic_not_large() {
    const PROMPT: &str = "fix the auth flow rewrite security module";
    let _guard = TEST_LOCK.lock().await;

    let key = cache_key(PROMPT);
    let garbage = RawLlmClassification {
        raw_json: "not even json".to_string(),
        model: "mock".to_string(),
        tokens_in: 10,
        tokens_out: 10,
    };
    let backend: Arc<dyn LlmClassifierBackend> =
        Arc::new(MockLlmClassifier::new().with_answer(key, garbage));
    let mut cache = new_session_cache();

    let request = ClassifierInput {
        prompt: PROMPT.to_owned(),
        recent_conversation: Vec::new(),
        project_root: PathBuf::from("."),
    };
    let verdict = classify(request, backend.as_ref(), &mut cache).await;

    assert_ne!(verdict.intent, Intent::Large);
    assert!(verdict.fallback);
}
/// When the backend cannot produce an answer, `classify` must flag the result
/// as a fallback and must not answer `Large`.
#[tokio::test]
async fn test_llm_fallback_on_transport_failure_returns_heuristic_not_large() {
    let _guard = TEST_LOCK.lock().await;

    // The mock has no scripted answers. Per this test's name, the missing
    // answer stands in for a transport failure.
    let backend: Arc<dyn LlmClassifierBackend> = Arc::new(MockLlmClassifier::new());
    let mut cache = new_session_cache();

    let verdict = classify(
        ClassifierInput {
            prompt: "add new endpoint with tests".to_owned(),
            recent_conversation: Vec::new(),
            project_root: PathBuf::from("."),
        },
        backend.as_ref(),
        &mut cache,
    )
    .await;

    assert_ne!(verdict.intent, Intent::Large);
    assert!(verdict.fallback);
}
/// Classifies a prompt that contains a secret and checks that the telemetry
/// log under a temporary `XDG_STATE_HOME` records a `prompt_hash`, not the raw
/// prompt text.
#[tokio::test]
async fn test_telemetry_record_does_not_contain_raw_prompt() {
    /// Restores `XDG_STATE_HOME` to the value it had before the test when
    /// dropped. The override is therefore undone even if an assertion panics,
    /// and a developer's own setting is put back rather than deleted.
    struct XdgStateHomeGuard(Option<std::ffi::OsString>);
    impl Drop for XdgStateHomeGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(previous) => std::env::set_var("XDG_STATE_HOME", previous),
                None => std::env::remove_var("XDG_STATE_HOME"),
            }
        }
    }

    const PROMPT: &str = "private API key abc123";

    let _guard = TEST_LOCK.lock().await;
    let tmp = tempfile::tempdir().expect("failed to create temporary state dir");
    // Declared after `tmp`, so it is dropped first: the env var is restored
    // before the temp dir is deleted and before the test lock is released.
    let _xdg = XdgStateHomeGuard(std::env::var_os("XDG_STATE_HOME"));
    std::env::set_var("XDG_STATE_HOME", tmp.path());

    let mut mock = MockLlmClassifier::new();
    let hash = cache_key(PROMPT);
    mock = mock.with_answer(
        hash,
        RawLlmClassification {
            raw_json: r#"{"intent":"small","confidence":0.8,"reasoning":"mock","signals":[],"suggested_action":null}"#.to_string(),
            model: "mock".to_string(),
            tokens_in: 10,
            tokens_out: 10,
        },
    );
    let backend: Arc<dyn LlmClassifierBackend> = Arc::new(mock);
    let mut cache = new_session_cache();
    let input = ClassifierInput {
        prompt: PROMPT.to_string(),
        recent_conversation: vec![],
        project_root: PathBuf::from("."),
    };
    let _ = classify(input, backend.as_ref(), &mut cache).await;

    let telemetry_path = tmp.path().join("omk").join("telemetry.jsonl");
    let contents = tokio::fs::read_to_string(&telemetry_path)
        .await
        .expect("classify should have written a telemetry file");
    assert!(
        !contents.contains("abc123"),
        "raw prompt text leaked into telemetry: {}",
        contents
    );
    assert!(contents.contains("prompt_hash"));
}
/// Appends a telemetry record timestamped 100 days in the past, runs
/// `compact_if_stale(30)`, and checks that the record is gone from the file.
#[tokio::test]
async fn test_telemetry_compact_drops_stale_records() {
    /// Restores `XDG_STATE_HOME` to the value it had before the test when
    /// dropped. The override is therefore undone even if an assertion panics,
    /// and a developer's own setting is put back rather than deleted.
    struct XdgStateHomeGuard(Option<std::ffi::OsString>);
    impl Drop for XdgStateHomeGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(previous) => std::env::set_var("XDG_STATE_HOME", previous),
                None => std::env::remove_var("XDG_STATE_HOME"),
            }
        }
    }

    let _guard = TEST_LOCK.lock().await;
    let tmp = tempfile::tempdir().expect("failed to create temporary state dir");
    // Declared after `tmp`, so it is dropped first: the env var is restored
    // before the temp dir is deleted and before the test lock is released.
    let _xdg = XdgStateHomeGuard(std::env::var_os("XDG_STATE_HOME"));
    std::env::set_var("XDG_STATE_HOME", tmp.path());

    let record = omk::runtime::classifier::telemetry::TelemetryRecord {
        ts: chrono::Utc::now() - chrono::Duration::days(100),
        intent: Intent::Small,
        confidence: 0.8,
        source: ClassificationSource::Llm,
        latency_ms: 100,
        prompt_hash: "deadbeef".to_string(),
        fallback: false,
    };
    omk::runtime::classifier::telemetry::append(record)
        .await
        .expect("appending the stale record failed");
    // NOTE(review): 30 is presumably the maximum record age in days — confirm
    // against `compact_if_stale`.
    omk::runtime::classifier::telemetry::compact_if_stale(30)
        .await
        .expect("telemetry compaction failed");

    let telemetry_path = tmp.path().join("omk").join("telemetry.jsonl");
    let contents = tokio::fs::read_to_string(&telemetry_path)
        .await
        .expect("telemetry file should still exist after compaction");
    assert!(
        !contents.contains("deadbeef"),
        "stale record survived compaction: {}",
        contents
    );
}
/// Spawns several classifications as concurrent tasks, all sharing one
/// telemetry file. Checks that the file ends up with exactly one well-formed
/// JSON line per classification.
///
/// NOTE(review): `#[tokio::test]` defaults to a current-thread runtime, so these
/// tasks interleave at `.await` points rather than run in parallel on separate
/// threads. `#[tokio::test(flavor = "multi_thread")]` would exercise truly
/// parallel appends, but it needs tokio's `rt-multi-thread` feature. Confirm
/// the dev-dependency enables it before switching.
#[tokio::test]
async fn test_concurrent_classify_does_not_corrupt_telemetry() {
    /// Restores `XDG_STATE_HOME` to the value it had before the test when
    /// dropped. The override is therefore undone even if an assertion panics,
    /// and a developer's own setting is put back rather than deleted.
    struct XdgStateHomeGuard(Option<std::ffi::OsString>);
    impl Drop for XdgStateHomeGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(previous) => std::env::set_var("XDG_STATE_HOME", previous),
                None => std::env::remove_var("XDG_STATE_HOME"),
            }
        }
    }

    /// Number of concurrent classifications (and expected telemetry lines).
    const TASKS: usize = 10;

    /// Prompt for task `i`. Shared by the mock set-up and the tasks so that
    /// both sides derive the same cache key.
    fn prompt_for(i: usize) -> String {
        format!("prompt number {}", i)
    }

    let _guard = TEST_LOCK.lock().await;
    let tmp = tempfile::tempdir().expect("failed to create temporary state dir");
    // Declared after `tmp`, so it is dropped first: the env var is restored
    // before the temp dir is deleted and before the test lock is released.
    let _xdg = XdgStateHomeGuard(std::env::var_os("XDG_STATE_HOME"));
    std::env::set_var("XDG_STATE_HOME", tmp.path());

    let mut mock = MockLlmClassifier::new();
    for i in 0..TASKS {
        let hash = cache_key(&prompt_for(i));
        mock = mock.with_answer(
            hash,
            RawLlmClassification {
                raw_json: format!(
                    r#"{{"intent":"small","confidence":0.8,"reasoning":"mock {}","signals":[],"suggested_action":null}}"#,
                    i
                ),
                model: "mock".to_string(),
                tokens_in: 10,
                tokens_out: 10,
            },
        );
    }
    let backend = Arc::new(mock);

    // Each task builds its own session cache, so no task can be answered from
    // another task's cached result.
    let mut handles = Vec::with_capacity(TASKS);
    for i in 0..TASKS {
        let backend = Arc::clone(&backend);
        handles.push(tokio::spawn(async move {
            let mut cache = new_session_cache();
            let input = ClassifierInput {
                prompt: prompt_for(i),
                recent_conversation: vec![],
                project_root: PathBuf::from("."),
            };
            classify(input, backend.as_ref(), &mut cache).await
        }));
    }

    let mut results = Vec::with_capacity(TASKS);
    for handle in handles {
        results.push(handle.await.expect("classification task did not complete"));
    }
    assert_eq!(results.len(), TASKS);

    let telemetry_path = tmp.path().join("omk").join("telemetry.jsonl");
    let contents = tokio::fs::read_to_string(&telemetry_path)
        .await
        .expect("classify should have written a telemetry file");
    let lines: Vec<&str> = contents.lines().collect();
    assert_eq!(
        lines.len(),
        TASKS,
        "expected one telemetry line per classification, got:\n{}",
        contents
    );
    for (idx, line) in lines.iter().enumerate() {
        let _: serde_json::Value = serde_json::from_str(line).unwrap_or_else(|err| {
            panic!("telemetry line {} is not valid JSON ({}): {:?}", idx, err, line)
        });
    }
}