1use crate::types::Message;
6
/// Coarse kind of work a conversation is asking for, as returned by [`classify`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskClass {
    /// Labelling or yes/no style requests.
    Classification,
    /// Requests to pull structured data or entities out of text.
    Extraction,
    /// Requests about writing, changing, or implementing code.
    Code,
    /// Multi-turn conversations that involve tool messages.
    Agent,
    /// Anything else; the fallback category.
    Chat,
}
15
16pub fn classify(messages: &[Message]) -> TaskClass {
17 let last_user = messages
18 .iter()
19 .rev()
20 .find(|m| m.role == "user")
21 .and_then(|m| m.content.as_str())
22 .map(|s| s.to_lowercase())
23 .unwrap_or_default();
24
25 if last_user.contains("classify")
26 || last_user.contains("category")
27 || last_user.contains("label as")
28 || last_user.contains("is this")
29 || last_user.contains("yes or no")
30 {
31 return TaskClass::Classification;
32 }
33 if last_user.contains("extract")
34 || last_user.contains("parse")
35 || last_user.contains("structured output")
36 || last_user.contains("json schema")
37 || last_user.contains("entity")
38 || last_user.contains("pull out")
39 {
40 return TaskClass::Extraction;
41 }
42 if last_user.contains("function")
43 || last_user.contains("code")
44 || last_user.contains("```")
45 || last_user.contains("refactor")
46 || last_user.contains("implement")
47 {
48 return TaskClass::Code;
49 }
50 if messages.len() > 4 && messages.iter().any(|m| m.role == "tool") {
51 return TaskClass::Agent;
52 }
53 TaskClass::Chat
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a message with the given role and plain-string content.
    fn msg(role: &'static str, t: &str) -> Message {
        Message {
            role: role.into(),
            content: json!(t),
        }
    }

    /// Builds a `"user"` message with plain-string content.
    fn u(t: &str) -> Message {
        msg("user", t)
    }

    #[test]
    fn classify_classification() {
        assert_eq!(
            classify(&[u("Is this email spam? yes or no")]),
            TaskClass::Classification
        );
    }

    #[test]
    fn classify_extraction() {
        assert_eq!(
            classify(&[u("Extract the names from this text")]),
            TaskClass::Extraction
        );
    }

    #[test]
    fn classify_code() {
        assert_eq!(
            classify(&[u("write a function that adds two numbers")]),
            TaskClass::Code
        );
    }

    #[test]
    fn classify_chat_default() {
        assert_eq!(classify(&[u("Hi how are you")]), TaskClass::Chat);
    }

    #[test]
    fn classify_empty_conversation_is_chat() {
        assert_eq!(classify(&[]), TaskClass::Chat);
    }

    #[test]
    fn classify_only_inspects_latest_user_message() {
        // The assistant mentions "code", but only user text is inspected.
        let convo = [
            u("hello there"),
            msg("assistant", "Here is the code you asked for"),
        ];
        assert_eq!(classify(&convo), TaskClass::Chat);
    }

    #[test]
    fn classify_agent_with_tool_messages() {
        let convo = [
            u("What's the weather in Paris?"),
            msg("assistant", "Let me look that up."),
            msg("tool", "{\"temp_c\": 18}"),
            msg("assistant", "It's 18 degrees."),
            u("Thanks! And tomorrow?"),
        ];
        assert_eq!(classify(&convo), TaskClass::Agent);
    }

    #[test]
    fn classify_agent_requires_more_than_four_messages() {
        let convo = [
            u("What's the weather in Paris?"),
            msg("assistant", "Let me look that up."),
            msg("tool", "{\"temp_c\": 18}"),
            u("Thanks! And tomorrow?"),
        ];
        assert_eq!(classify(&convo), TaskClass::Chat);
    }
}