1use crate::extract::FlexibleExtractor;
4use crate::extract::core::ContentExtractor;
5use std::str::FromStr;
6use thiserror::Error;
7
8#[derive(Debug, Error)]
10pub enum IntentError {
11    #[error("Extraction failed: Tag '{tag}' not found in response")]
12    TagNotFound { tag: String },
13
14    #[error("Parsing failed: Could not parse '{value}' into a valid intent")]
15    ParseFailed { value: String },
16
17    #[error(transparent)]
18    Other(#[from] anyhow::Error),
19}
20
21pub trait IntentExtractor<T>
25where
26    T: FromStr,
27{
28    fn extract_intent(&self, text: &str) -> Result<T, IntentError>;
30}
31
/// An [`IntentExtractor`] that reads the intent from the content of a single
/// named tag (e.g. `<intent>Login</intent>` for the tag `"intent"`).
pub struct PromptBasedExtractor {
    // Used to locate the tagged section of the input text (via `extract_tagged`).
    extractor: FlexibleExtractor,
    // Name of the tag whose content is parsed as the intent.
    tag: String,
}
41
42impl PromptBasedExtractor {
43    pub fn new(tag: &str) -> Self {
45        Self {
46            extractor: FlexibleExtractor::new(),
47            tag: tag.to_string(),
48        }
49    }
50}
51
52impl<T> IntentExtractor<T> for PromptBasedExtractor
53where
54    T: FromStr,
55{
56    fn extract_intent(&self, text: &str) -> Result<T, IntentError> {
57        let extracted_str = self
59            .extractor
60            .extract_tagged(text, &self.tag)
61            .ok_or_else(|| IntentError::TagNotFound {
62                tag: self.tag.clone(),
63            })?;
64
65        T::from_str(&extracted_str).map_err(|_| IntentError::ParseFailed {
67            value: extracted_str.to_string(),
68        })
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    /// Small intent type used to drive the generic extractor in tests.
    #[derive(Debug, PartialEq)]
    enum TestIntent {
        Login,
        Logout,
    }

    impl FromStr for TestIntent {
        type Err = String;

        /// Accepts exactly `"Login"` or `"Logout"`; anything else is an error.
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            let intent = match s {
                "Login" => Self::Login,
                "Logout" => Self::Logout,
                unknown => return Err(format!("Unknown intent: {}", unknown)),
            };
            Ok(intent)
        }
    }

    /// Runs a fresh `"intent"` extractor over `text`, requesting a `TestIntent`.
    fn run(text: &str) -> Result<TestIntent, IntentError> {
        let extractor = PromptBasedExtractor::new("intent");
        IntentExtractor::<TestIntent>::extract_intent(&extractor, text)
    }

    #[test]
    fn test_extract_intent_success() {
        let intent = run("<intent>Login</intent>").unwrap();
        assert_eq!(intent, TestIntent::Login);
    }

    #[test]
    fn test_extract_intent_tag_not_found() {
        if let Err(IntentError::TagNotFound { tag }) = run("No intent tag here") {
            assert_eq!(tag, "intent");
        } else {
            panic!("Expected TagNotFound error");
        }
    }

    #[test]
    fn test_extract_intent_parse_failed() {
        if let Err(IntentError::ParseFailed { value }) = run("<intent>Invalid</intent>") {
            assert_eq!(value, "Invalid");
        } else {
            panic!("Expected ParseFailed error");
        }
    }
}
129}