alith_client/workflows/classify/
subject_of_text.rs1use crate::components::cascade::CascadeFlow;
2use crate::{components::cascade::step::StepConfig, primitives::*};
3
4use alith_interface::requests::completion::CompletionRequest;
5
6#[derive(Clone)]
7pub struct ClassifySubjectOfText {
8 pub base_req: CompletionRequest,
9 pub content: String,
10 pub flow: CascadeFlow,
11 pub content_strings: Vec<String>,
12 pub subject: Option<String>,
13 subject_strings: Vec<String>,
14 default_grammar: SentencesPrimitive,
15 default_step_config: StepConfig,
16}
17
18impl ClassifySubjectOfText {
19 pub fn new<T: AsRef<str>>(base_req: CompletionRequest, content: T) -> Self {
20 let mut grammar: SentencesPrimitive = SentencesPrimitive::default();
21 grammar
22 .min_count(1)
23 .max_count(3)
24 .disallowed_char('\'')
25 .disallowed_char('(')
26 .disallowed_char(')')
27 .capitalize_first(false);
28 let mut step_config: StepConfig = StepConfig::default();
29 step_config.stop_word_done("\n").grammar(grammar.grammar());
30
31 Self {
32 base_req,
33 content: content.as_ref().to_owned(),
34 flow: CascadeFlow::new("ClassifySubjectOfText"),
35 content_strings: Vec::new(),
36 subject: None,
37 subject_strings: Vec::new(),
38 default_grammar: grammar,
39 default_step_config: step_config,
40 }
41 }
42
43 pub async fn run(mut self) -> crate::Result<Self> {
44 let mut count = 1;
45 while count <= self.base_req.config.retry_after_fail_n_times {
46 match self.run_cascade().await {
47 Ok(_) => break,
48 Err(e) => {
49 crate::error!("Failed to classify entity: {}", e);
50 count += 1;
51 if count == self.base_req.config.retry_after_fail_n_times {
52 crate::bail!("Failed to classify entity after {} attempts: {}", count, e);
53 }
54 self.base_req.reset_completion_request();
55 self.flow = CascadeFlow::new("ClassifySubjectOfText");
56 }
57 }
58 }
59 Ok(self)
60 }
61
62 async fn run_cascade(&mut self) -> crate::Result<()> {
63 self.flow.open_cascade();
64 let task = indoc::formatdoc! {"
65 Explain like I'm five; what is the subject of the text:
66 '{}'",
67 self.content
68 };
69 self.flow.new_round(task).step_separator('\n');
70 self.flow.last_round()?.open_round(&mut self.base_req)?;
71
72 self.default_step_config
73 .step_prefix("In the text, the main thing a five-year-old would see is: \"")
74 .grammar(self.default_grammar.max_count(2).grammar());
75 let result = self.run_it().await?;
76 self.subject_strings.push(result);
77
78 self.default_step_config
79 .step_prefix(
80 "An english teacher would clarify that the person or thing that is being discussed, described, or dealt with, is: \"",
81 )
82 .grammar(self.default_grammar.max_count(2).capitalize_first(false).grammar());
83 let result = self.run_it().await?;
84 self.subject_strings.push(result);
85
86 self.ensure_options().await?;
87
88 self.default_step_config
89 .step_prefix(format!(
90 "So, the primary subject of the text '{}' is: \"",
91 self.content
92 ))
93 .grammar(self.default_grammar.max_count(1).grammar());
94 let result = self.run_it().await?;
95 self.subject_strings.push(result);
96
97 let possible_subjects = self.extract_quoted_text();
98 if possible_subjects.len() == 1 {
99 self.subject = Some(possible_subjects[0].clone());
100 } else {
101 self.default_step_config
102 .step_prefix(
103 "To restate so a five-year-old could understand, the primary subject is: ",
104 )
105 .grammar(
106 ExactStringPrimitive::default()
107 .add_strings_to_allowed(&possible_subjects)
108 .grammar(),
109 );
110
111 self.run_it().await?;
112 self.subject = self.flow.last_round()?.last_step()?.primitive_result();
113 }
114
115 self.flow.last_round()?.close_round(&mut self.base_req)?;
116 self.flow.close_cascade()?;
117
118 Ok(())
119 }
120
121 async fn run_it(&mut self) -> crate::Result<String> {
122 self.flow
123 .last_round()?
124 .add_inference_step(&self.default_step_config);
125 self.flow
126 .last_round()?
127 .run_next_step(&mut self.base_req)
128 .await?;
129 let result = self
130 .flow
131 .last_round()?
132 .last_step()?
133 .display_step_outcome()?;
134 Ok(result)
135 }
136
137 async fn ensure_options(&mut self) -> crate::Result<()> {
138 let mut possible_subjects = self.extract_quoted_text();
139
140 if !possible_subjects.is_empty() {
141 return Ok(());
142 };
143 self.default_step_config
144 .step_prefix("The nouns in the text are: \"")
145 .grammar(self.default_grammar.max_count(1).grammar());
146 let result = self.run_it().await?;
147 self.subject_strings.push(result);
148
149 self.default_step_config
150 .step_prefix("The proper nouns in the text are: \"")
151 .grammar(self.default_grammar.max_count(1).grammar());
152 let result = self.run_it().await?;
153 self.subject_strings.push(result);
154
155 self.default_step_config
156 .step_prefix("The common nouns in the text are: \"")
157 .grammar(self.default_grammar.max_count(1).grammar());
158 let result = self.run_it().await?;
159 self.subject_strings.push(result);
160
161 possible_subjects = self.extract_quoted_text();
162 if possible_subjects.is_empty() {
163 crate::bail!("Failed to classify subject: no qouted subject returned");
164 }
165 Ok(())
166 }
167
168 fn extract_quoted_text(&self) -> Vec<String> {
169 let mut result = Vec::new();
170
171 for input in &self.subject_strings {
172 let mut current_quote = None;
173 let mut start_index = 0;
174 for (i, c) in input.char_indices() {
175 match (current_quote, c) {
176 (None, '"') => {
177 current_quote = Some(c);
178 start_index = i + 1;
179 }
180 (Some(quote), c) if c == quote => {
181 result.push(input[start_index..i].to_string());
182 current_quote = None;
183 }
184 _ => {}
185 }
186 }
187 }
188 let mut cleaned = vec![];
189 for res in result.iter_mut() {
190 let new_res = res
191 .trim_start_matches(|c: char| !c.is_alphanumeric())
192 .trim_end_matches(|c: char| !c.is_alphanumeric())
193 .to_lowercase()
194 .split_whitespace()
195 .filter(|word| word.len() > 1 && *word != "the")
196 .collect::<Vec<_>>()
197 .join(" ")
198 .to_owned();
199 cleaned.push(new_res);
200 }
201 cleaned.sort();
202 cleaned.dedup();
203 let lower_content = self.content.to_lowercase();
204 cleaned.retain(|x| lower_content.contains(x));
205
206 cleaned
207 }
208}
209
210impl std::fmt::Display for ClassifySubjectOfText {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 writeln!(f)?;
213 writeln!(f, "ClassifySubjectOfText:")?;
214 crate::i_nln(f, format_args!("content: \"{}\"", self.content))?;
215 crate::i_nln(f, format_args!("subject: {:?}", self.subject))?;
216 crate::i_nln(f, format_args!("duration: {:?}", self.flow.duration))?;
217 Ok(())
218 }
219}