alith_client/workflows/classify/
subject_of_text.rs

1use crate::components::cascade::CascadeFlow;
2use crate::{components::cascade::step::StepConfig, primitives::*};
3
4use alith_interface::requests::completion::CompletionRequest;
5
6#[derive(Clone)]
7pub struct ClassifySubjectOfText {
8    pub base_req: CompletionRequest,
9    pub content: String,
10    pub flow: CascadeFlow,
11    pub content_strings: Vec<String>,
12    pub subject: Option<String>,
13    subject_strings: Vec<String>,
14    default_grammar: SentencesPrimitive,
15    default_step_config: StepConfig,
16}
17
18impl ClassifySubjectOfText {
19    pub fn new<T: AsRef<str>>(base_req: CompletionRequest, content: T) -> Self {
20        let mut grammar: SentencesPrimitive = SentencesPrimitive::default();
21        grammar
22            .min_count(1)
23            .max_count(3)
24            .disallowed_char('\'')
25            .disallowed_char('(')
26            .disallowed_char(')')
27            .capitalize_first(false);
28        let mut step_config: StepConfig = StepConfig::default();
29        step_config.stop_word_done("\n").grammar(grammar.grammar());
30
31        Self {
32            base_req,
33            content: content.as_ref().to_owned(),
34            flow: CascadeFlow::new("ClassifySubjectOfText"),
35            content_strings: Vec::new(),
36            subject: None,
37            subject_strings: Vec::new(),
38            default_grammar: grammar,
39            default_step_config: step_config,
40        }
41    }
42
43    pub async fn run(mut self) -> crate::Result<Self> {
44        let mut count = 1;
45        while count <= self.base_req.config.retry_after_fail_n_times {
46            match self.run_cascade().await {
47                Ok(_) => break,
48                Err(e) => {
49                    crate::error!("Failed to classify entity: {}", e);
50                    count += 1;
51                    if count == self.base_req.config.retry_after_fail_n_times {
52                        crate::bail!("Failed to classify entity after {} attempts: {}", count, e);
53                    }
54                    self.base_req.reset_completion_request();
55                    self.flow = CascadeFlow::new("ClassifySubjectOfText");
56                }
57            }
58        }
59        Ok(self)
60    }
61
62    async fn run_cascade(&mut self) -> crate::Result<()> {
63        self.flow.open_cascade();
64        let task = indoc::formatdoc! {"
65        Explain like I'm five; what is the subject of the text:
66        '{}'",
67        self.content
68        };
69        self.flow.new_round(task).step_separator('\n');
70        self.flow.last_round()?.open_round(&mut self.base_req)?;
71
72        self.default_step_config
73            .step_prefix("In the text, the main thing a five-year-old would see is: \"")
74            .grammar(self.default_grammar.max_count(2).grammar());
75        let result = self.run_it().await?;
76        self.subject_strings.push(result);
77
78        self.default_step_config
79            .step_prefix(
80                "An english teacher would clarify that the person or thing that is being discussed, described, or dealt with, is: \"",
81            )
82            .grammar(self.default_grammar.max_count(2).capitalize_first(false).grammar());
83        let result = self.run_it().await?;
84        self.subject_strings.push(result);
85
86        self.ensure_options().await?;
87
88        self.default_step_config
89            .step_prefix(format!(
90                "So, the primary subject of the text '{}' is: \"",
91                self.content
92            ))
93            .grammar(self.default_grammar.max_count(1).grammar());
94        let result = self.run_it().await?;
95        self.subject_strings.push(result);
96
97        let possible_subjects = self.extract_quoted_text();
98        if possible_subjects.len() == 1 {
99            self.subject = Some(possible_subjects[0].clone());
100        } else {
101            self.default_step_config
102                .step_prefix(
103                    "To restate so a five-year-old could understand, the primary subject is: ",
104                )
105                .grammar(
106                    ExactStringPrimitive::default()
107                        .add_strings_to_allowed(&possible_subjects)
108                        .grammar(),
109                );
110
111            self.run_it().await?;
112            self.subject = self.flow.last_round()?.last_step()?.primitive_result();
113        }
114
115        self.flow.last_round()?.close_round(&mut self.base_req)?;
116        self.flow.close_cascade()?;
117
118        Ok(())
119    }
120
121    async fn run_it(&mut self) -> crate::Result<String> {
122        self.flow
123            .last_round()?
124            .add_inference_step(&self.default_step_config);
125        self.flow
126            .last_round()?
127            .run_next_step(&mut self.base_req)
128            .await?;
129        let result = self
130            .flow
131            .last_round()?
132            .last_step()?
133            .display_step_outcome()?;
134        Ok(result)
135    }
136
137    async fn ensure_options(&mut self) -> crate::Result<()> {
138        let mut possible_subjects = self.extract_quoted_text();
139
140        if !possible_subjects.is_empty() {
141            return Ok(());
142        };
143        self.default_step_config
144            .step_prefix("The nouns in the text are: \"")
145            .grammar(self.default_grammar.max_count(1).grammar());
146        let result = self.run_it().await?;
147        self.subject_strings.push(result);
148
149        self.default_step_config
150            .step_prefix("The proper nouns in the text are: \"")
151            .grammar(self.default_grammar.max_count(1).grammar());
152        let result = self.run_it().await?;
153        self.subject_strings.push(result);
154
155        self.default_step_config
156            .step_prefix("The common nouns in the text are: \"")
157            .grammar(self.default_grammar.max_count(1).grammar());
158        let result = self.run_it().await?;
159        self.subject_strings.push(result);
160
161        possible_subjects = self.extract_quoted_text();
162        if possible_subjects.is_empty() {
163            crate::bail!("Failed to classify subject: no qouted subject returned");
164        }
165        Ok(())
166    }
167
168    fn extract_quoted_text(&self) -> Vec<String> {
169        let mut result = Vec::new();
170
171        for input in &self.subject_strings {
172            let mut current_quote = None;
173            let mut start_index = 0;
174            for (i, c) in input.char_indices() {
175                match (current_quote, c) {
176                    (None, '"') => {
177                        current_quote = Some(c);
178                        start_index = i + 1;
179                    }
180                    (Some(quote), c) if c == quote => {
181                        result.push(input[start_index..i].to_string());
182                        current_quote = None;
183                    }
184                    _ => {}
185                }
186            }
187        }
188        let mut cleaned = vec![];
189        for res in result.iter_mut() {
190            let new_res = res
191                .trim_start_matches(|c: char| !c.is_alphanumeric())
192                .trim_end_matches(|c: char| !c.is_alphanumeric())
193                .to_lowercase()
194                .split_whitespace()
195                .filter(|word| word.len() > 1 && *word != "the")
196                .collect::<Vec<_>>()
197                .join(" ")
198                .to_owned();
199            cleaned.push(new_res);
200        }
201        cleaned.sort();
202        cleaned.dedup();
203        let lower_content = self.content.to_lowercase();
204        cleaned.retain(|x| lower_content.contains(x));
205
206        cleaned
207    }
208}
209
210impl std::fmt::Display for ClassifySubjectOfText {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        writeln!(f)?;
213        writeln!(f, "ClassifySubjectOfText:")?;
214        crate::i_nln(f, format_args!("content: \"{}\"", self.content))?;
215        crate::i_nln(f, format_args!("subject: {:?}", self.subject))?;
216        crate::i_nln(f, format_args!("duration: {:?}", self.flow.duration))?;
217        Ok(())
218    }
219}