use llm_toolkit::{IntentError, IntentExtractor, IntentFrame, define_intent};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
/// Top-level intent categories exercised by the prompt/extractor tests in this file.
///
/// The `///` text on each variant is part of the fixture: the prompt test asserts
/// that these descriptions appear in the rendered prompt, and the template's only
/// place for per-variant text is the `{{ intents_doc }}` section.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[define_intent]
#[intent(
    prompt = r#"
You are a helpful AI assistant. Analyze the user's query and determine their intent.
User Query: {{ user_query }}
Based on the query above, classify the user's intent into one of the following categories:
{{ intents_doc }}
Respond with your classification wrapped in the appropriate tags.
"#,
    extractor_tag = "intent"
)]
enum UserIntent {
    /// The user wants to search for information or find answers to questions
    SearchQuery,
    /// The user wants to create or generate new content
    CreateContent,
    /// The user needs help with a problem or has a support request
    RequestHelp,
    /// The user wants to analyze or interpret data
    AnalyzeData,
    /// The user is greeting or making small talk
    Greeting,
}
impl FromStr for UserIntent {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"SearchQuery" => Ok(UserIntent::SearchQuery),
"CreateContent" => Ok(UserIntent::CreateContent),
"RequestHelp" => Ok(UserIntent::RequestHelp),
"AnalyzeData" => Ok(UserIntent::AnalyzeData),
"Greeting" => Ok(UserIntent::Greeting),
_ => Err(format!("Unknown UserIntent variant: {}", s)),
}
}
}
/// The generated `build_user_intent_prompt` must render the user's query, the
/// static template text, and the `{{ intents_doc }}` listing of every variant.
#[test]
fn test_generated_prompt_function() {
    let query = "How can I learn Rust programming?";
    let prompt = build_user_intent_prompt(query);

    let expected_fragments = [
        // The substituted `{{ user_query }}` value.
        query,
        // Variant names and descriptions, which can only come from `{{ intents_doc }}`.
        "The user wants to search for information",
        "SearchQuery",
        "The user wants to create or generate new content",
        "CreateContent",
        "The user needs help with a problem",
        "RequestHelp",
        "AnalyzeData",
        "Greeting",
        // Static text from the prompt template itself.
        "User Query:",
        "classify the user's intent",
    ];

    for fragment in expected_fragments {
        // Report which fragment is missing and the full prompt, so a failure is diagnosable.
        assert!(
            prompt.contains(fragment),
            "generated prompt is missing {:?}:\n{}",
            fragment,
            prompt
        );
    }
}
/// A well-formed `<intent>` payload must be parsed into the matching `UserIntent`
/// variant, whether the tag stands alone or is surrounded by free text.
#[test]
fn test_generated_extractor_success() {
    let extractor = UserIntentExtractor;

    let cases = [
        (
            concat!(
                "\n",
                "Based on the user's query about learning Rust programming, I can see they are looking for information and resources.\n",
                "<intent>SearchQuery</intent>\n",
                "This is clearly a search for educational information.\n",
            ),
            UserIntent::SearchQuery,
        ),
        (
            concat!(
                "\n",
                "The user wants to generate something new.\n",
                "<intent>CreateContent</intent>\n",
            ),
            UserIntent::CreateContent,
        ),
        ("<intent>RequestHelp</intent>", UserIntent::RequestHelp),
        (
            "Let me analyze this... <intent>Greeting</intent> Yes, this is just a greeting.",
            UserIntent::Greeting,
        ),
    ];

    for (response, expected) in cases {
        match extractor.extract_intent(response) {
            Ok(intent) => assert_eq!(intent, expected),
            Err(err) => panic!("extraction failed for {:?}: {:?}", response, err),
        }
    }
}
/// The generated extractor must reject a response that lacks the `intent` tag,
/// uses a different tag, or carries a payload that is not a `UserIntent` variant.
#[test]
fn test_generated_extractor_failure() {
    let extractor = UserIntentExtractor;

    // No `<intent>` tag anywhere must be reported as `TagNotFound` for `intent`.
    // This includes a response that only uses some other tag name.
    for response in [
        "This response doesn't contain the required tag at all.",
        "<wrong_tag>SearchQuery</wrong_tag>",
    ] {
        match extractor.extract_intent(response) {
            Err(IntentError::TagNotFound { tag }) => assert_eq!(tag, "intent"),
            other => panic!(
                "Expected TagNotFound error for {:?}, got: {:?}",
                response, other
            ),
        }
    }

    // The tag is present, but its payload is not a known variant name.
    match extractor.extract_intent("<intent>InvalidVariant</intent>") {
        Err(IntentError::ParseFailed { .. }) => {}
        other => panic!("Expected ParseFailed error, got: {:?}", other),
    }

    // An empty payload must not silently map to any variant.
    assert!(extractor.extract_intent("<intent></intent>").is_err());
}
/// A generic `IntentFrame` (built here with `"input"` / `"intent"`) must extract
/// a `UserIntent` from a response containing an `<intent>` tag.
#[test]
fn test_intent_frame_integration() {
    let frame = IntentFrame::new("input", "intent");
    let response =
        "Analysis complete. <intent>AnalyzeData</intent> The user wants to analyze data.";

    let extracted: Result<UserIntent, _> = frame.extract_intent(response);
    let intent = extracted.expect("IntentFrame should extract the AnalyzeData intent");

    assert_eq!(intent, UserIntent::AnalyzeData);
}
/// `UserIntent` serializes to a JSON string holding the bare variant name,
/// and deserializes from one.
#[test]
fn test_enum_variant_serialization() {
    let encoded = serde_json::to_string(&UserIntent::SearchQuery)
        .expect("UserIntent should serialize to JSON");
    assert_eq!(encoded, r#""SearchQuery""#);

    let decoded = serde_json::from_str::<UserIntent>(r#""CreateContent""#)
        .expect("variant name should deserialize into UserIntent");
    assert_eq!(decoded, UserIntent::CreateContent);
}
/// `FromStr` must accept every exact variant name and reject anything else.
#[test]
fn test_enum_variant_from_str() {
    let known_variants = [
        ("SearchQuery", UserIntent::SearchQuery),
        ("CreateContent", UserIntent::CreateContent),
        ("RequestHelp", UserIntent::RequestHelp),
        ("AnalyzeData", UserIntent::AnalyzeData),
        ("Greeting", UserIntent::Greeting),
    ];

    for (name, variant) in known_variants {
        // `str::parse` dispatches to `<UserIntent as FromStr>::from_str`.
        assert_eq!(name.parse::<UserIntent>().unwrap(), variant);
    }

    assert!("InvalidVariant".parse::<UserIntent>().is_err());
}
/// A second `#[define_intent]` enum can coexist with `UserIntent`. It gets its own
/// prompt builder and its own extractor bound to the `sentiment` tag.
#[test]
fn test_multiple_intents_in_same_module() {
    /// Sentiment categories. The variant doc text is part of the fixture: the
    /// prompt assertion below expects "Positive sentiment" in the rendered
    /// `{{ intents_doc }}` section.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    #[define_intent]
    #[intent(
        prompt = "Classify the sentiment: {{ text }}\n\n{{ intents_doc }}",
        extractor_tag = "sentiment"
    )]
    enum SentimentIntent {
        /// Positive sentiment: the text expresses approval, joy, or enthusiasm
        Positive,
        /// Negative sentiment: the text expresses disapproval, frustration, or sadness
        Negative,
        /// Neutral sentiment: the text is factual or expresses no clear emotion
        Neutral,
    }

    impl FromStr for SentimentIntent {
        type Err = String;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "Positive" => Ok(SentimentIntent::Positive),
                "Negative" => Ok(SentimentIntent::Negative),
                "Neutral" => Ok(SentimentIntent::Neutral),
                _ => Err(format!("Unknown SentimentIntent variant: {}", s)),
            }
        }
    }

    let prompt = build_sentiment_intent_prompt("This is amazing!");
    assert!(prompt.contains("This is amazing!"));
    assert!(prompt.contains("Positive sentiment"));

    let extractor = SentimentIntentExtractor;
    let response = "<sentiment>Positive</sentiment>";
    assert_eq!(
        extractor.extract_intent(response).unwrap(),
        SentimentIntent::Positive
    );

    // Each generated extractor is bound to its own tag, so an `<intent>` tag
    // (UserIntent's tag) must not satisfy the `sentiment` extractor.
    assert!(extractor
        .extract_intent("<intent>Positive</intent>")
        .is_err());
}