1use crate::extract::FlexibleExtractor;
4use crate::extract::core::ContentExtractor;
5use std::str::FromStr;
6use thiserror::Error;
7
/// Errors that can occur while extracting a typed intent from text.
#[derive(Debug, Error)]
pub enum IntentError {
    /// The tag extractor did not find the configured tag in the input text.
    /// `tag` is the name of the tag that was searched for.
    #[error("Extraction failed: Tag '{tag}' not found in response")]
    TagNotFound { tag: String },

    /// The tagged content was found but could not be parsed into the target
    /// intent type. `value` is the extracted text that was rejected; the
    /// parser's own error value is not retained.
    #[error("Parsing failed: Could not parse '{value}' into a valid intent")]
    ParseFailed { value: String },

    /// Any other failure, carried as an `anyhow::Error`. Display and source are
    /// forwarded to the wrapped error (`#[error(transparent)]`), and `#[from]`
    /// allows conversion via `?`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
20
21pub trait IntentExtractor<T>
25where
26 T: FromStr,
27{
28 fn extract_intent(&self, text: &str) -> Result<T, IntentError>;
30}
31
/// An intent extractor that reads the intent from the contents of a named
/// tag in the input text (for tag `"intent"`, e.g. `<intent>Login</intent>`)
/// and parses that content with [`FromStr`].
pub struct PromptBasedExtractor {
    /// Helper used to locate the tagged content in the input text.
    extractor: FlexibleExtractor,
    /// Name of the tag whose contents hold the intent.
    tag: String,
}
41
42impl PromptBasedExtractor {
43 pub fn new(tag: &str) -> Self {
45 Self {
46 extractor: FlexibleExtractor::new(),
47 tag: tag.to_string(),
48 }
49 }
50}
51
52impl<T> IntentExtractor<T> for PromptBasedExtractor
53where
54 T: FromStr,
55{
56 fn extract_intent(&self, text: &str) -> Result<T, IntentError> {
57 let extracted_str = self
59 .extractor
60 .extract_tagged(text, &self.tag)
61 .ok_or_else(|| IntentError::TagNotFound {
62 tag: self.tag.clone(),
63 })?;
64
65 T::from_str(&extracted_str).map_err(|_| IntentError::ParseFailed {
67 value: extracted_str.to_string(),
68 })
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal intent type used to exercise the generic extractor.
    #[derive(Debug, PartialEq)]
    enum TestIntent {
        Login,
        Logout,
    }

    impl FromStr for TestIntent {
        type Err = String;

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            if s == "Login" {
                Ok(Self::Login)
            } else if s == "Logout" {
                Ok(Self::Logout)
            } else {
                Err(format!("Unknown intent: {}", s))
            }
        }
    }

    /// Runs a freshly built `"intent"`-tag extractor over `text`.
    fn extract(text: &str) -> Result<TestIntent, IntentError> {
        let extractor = PromptBasedExtractor::new("intent");
        IntentExtractor::extract_intent(&extractor, text)
    }

    #[test]
    fn test_extract_intent_success() {
        let intent = extract("<intent>Login</intent>").unwrap();
        assert_eq!(intent, TestIntent::Login);
    }

    #[test]
    fn test_extract_intent_tag_not_found() {
        if let Err(IntentError::TagNotFound { tag }) = extract("No intent tag here") {
            assert_eq!(tag, "intent");
        } else {
            panic!("Expected TagNotFound error");
        }
    }

    #[test]
    fn test_extract_intent_parse_failed() {
        if let Err(IntentError::ParseFailed { value }) = extract("<intent>Invalid</intent>") {
            assert_eq!(value, "Invalid");
        } else {
            panic!("Expected ParseFailed error");
        }
    }
}