alith_client/workflows/nlp/extract/
urls.rs

1use crate::{
2    components::{
3        cascade::{CascadeFlow, step::StepConfig},
4        instruct_prompt::{InstructPrompt, InstructPromptTrait},
5    },
6    primitives::*,
7};
8use alith_interface::requests::{
9    completion::CompletionRequest,
10    req_components::{RequestConfig, RequestConfigTrait},
11};
12use anyhow::Result;
13use linkify::{LinkFinder, LinkKind};
14use std::{collections::HashSet, str::FromStr};
15use url::Url;
16
17pub fn extract_urls<T: AsRef<str>>(input: T) -> Vec<Url> {
18    let mut unique_urls = HashSet::new();
19
20    LinkFinder::new()
21        .kinds(&[LinkKind::Url])
22        .links(input.as_ref())
23        .filter_map(|link| Url::from_str(link.as_str()).ok())
24        .filter(|url| unique_urls.insert(url.clone()))
25        .collect()
26}
27
28#[derive(Clone)]
29pub struct ExtractUrls {
30    pub base_req: CompletionRequest,
31    pub instruct_prompt: InstructPrompt,
32    pub criteria: Option<String>,
33    pub results: Vec<String>,
34    pub max_urls: usize,
35}
36
37impl ExtractUrls {
38    pub fn new(base_req: CompletionRequest) -> Self {
39        ExtractUrls {
40            instruct_prompt: InstructPrompt::new(),
41            base_req,
42            criteria: None,
43            results: Vec::new(),
44            max_urls: 5,
45        }
46    }
47
48    pub fn max_urls(mut self, max_urls: usize) -> Self {
49        self.max_urls = max_urls;
50        self
51    }
52
53    pub async fn run_return_urls(&mut self) -> Result<Option<Vec<Url>>> {
54        Ok(self.run_return_result().await?.results)
55    }
56
57    pub async fn run_return_result(&mut self) -> Result<ExtractUrlResult> {
58        let flow = self.run_backend().await?;
59        if self.results.is_empty() {
60            Ok(ExtractUrlResult::new(
61                flow,
62                None,
63                self.criteria.as_ref().unwrap(),
64            ))
65        } else {
66            Ok(ExtractUrlResult::new(
67                flow,
68                Some(
69                    self.results
70                        .iter()
71                        .map(|url| Url::parse(url).unwrap())
72                        .collect(),
73                ),
74                self.criteria.as_ref().unwrap(),
75            ))
76        }
77    }
78
79    async fn run_backend(&mut self) -> Result<CascadeFlow> {
80        let mut primitive = ExactStringPrimitive::default();
81
82        let mut urls_from_instructions: Vec<Url> = Vec::new();
83        if let Some(instructions) = self.instruct_prompt.build_instructions() {
84            urls_from_instructions.extend(extract_urls(instructions));
85        }
86        if let Some(supporting_material) = self.instruct_prompt.build_supporting_material() {
87            urls_from_instructions.extend(extract_urls(supporting_material));
88        }
89        if urls_from_instructions.is_empty() {
90            return Err(anyhow::anyhow!("No URLs found in the instructions"));
91        }
92
93        primitive.add_strings_to_allowed(&urls_from_instructions);
94
95        let mut flow = self.set_criteria().await?;
96
97        self.run_cascade(&mut flow, &mut primitive).await?;
98        Ok(flow)
99    }
100
101    async fn set_criteria(&mut self) -> Result<CascadeFlow> {
102        let mut flow = CascadeFlow::new("ExtractUrls");
103        flow.open_cascade();
104        flow.new_round(
105            "We are extracting URLs from text. Please provide examples of extracting URLs with the instructions: 'Which of these URLs are commonly used in webdev tutorials?'").add_guidance_step(
106            &StepConfig::default(),
107            "`https://www.example.com is commonly used in webdev tutorials: true.` In this example, the URL satisfies the criteria: 'is commonly used in webdev tutorials.' Therefore, the URL should be extracted from the text.\n`https://www.zombo.com is commonly used in webdev tutorials: false.`. In this example, the URL does not satisfy the criteria: 'is commonly used in webdev tutorials.' Therefore, the URL should not be extracted from the text.",
108        );
109        flow.last_round()?.run_all_steps(&mut self.base_req).await?;
110
111        let initial_qualities_task = format!(
112            "We are extracting URLs from text using the instructions:\n{} Briefly describe the criteria of the URLs to be extracted.",
113            self.instruct_prompt.build_instructions().unwrap()
114        );
115        let config = StepConfig {
116            step_prefix: Some("Criteria: ".to_owned()),
117            grammar: TextPrimitive::default().text_token_length(200).grammar(),
118            ..StepConfig::default()
119        };
120        flow.new_round(initial_qualities_task)
121            .add_inference_step(&config);
122        flow.last_round()?.run_all_steps(&mut self.base_req).await?;
123
124        let refine_criteria_task = format!(
125            "Reframe the instructions and criteria into a statment used to evaluate if a URL should be extracted. This statement should have a boolean answer. The answer should represent whether or not the URL satisfies the criteria. This should be a single sentence 'is' statment; as in, 'The URL is <criteria>: true or false'.\nCriteria:\n{}\nInstructions:\n{}",
126            flow.primitive_result().unwrap(),
127            self.instruct_prompt.build_instructions().unwrap()
128        );
129        let config = StepConfig {
130            step_prefix: Some("The URL is ".to_owned()),
131            stop_word_done: ": true or false".to_owned(),
132            grammar: TextPrimitive::default().text_token_length(200).grammar(),
133            ..StepConfig::default()
134        };
135        flow.new_round(refine_criteria_task)
136            .add_inference_step(&config);
137        flow.last_round()?.run_all_steps(&mut self.base_req).await?;
138        self.criteria = Some(flow.primitive_result().unwrap());
139        Ok(flow)
140    }
141
142    async fn extract_step(
143        &mut self,
144        flow: &mut CascadeFlow,
145        primitive: &mut ExactStringPrimitive,
146    ) -> Result<()> {
147        let config = StepConfig {
148            cache_prompt: true,
149            stop_word_no_result: Some("No qualifying URLs.".to_owned()),
150            grammar: primitive.grammar(),
151            ..StepConfig::default()
152        };
153
154        flow.last_round()?.add_inference_step(&config);
155        flow.last_round()?.run_next_step(&mut self.base_req).await
156    }
157
158    async fn validate_step(&mut self, flow: &mut CascadeFlow) -> Result<bool> {
159        let config = StepConfig {
160            cache_prompt: true,
161            step_prefix: Some(format!(" is {}: ", self.criteria.as_ref().unwrap())),
162            grammar: BooleanPrimitive::default().grammar(),
163            ..StepConfig::default()
164        };
165
166        flow.last_round()?.add_inference_step(&config);
167        flow.last_round()?.run_next_step(&mut self.base_req).await?;
168        if flow.primitive_result().unwrap().parse().unwrap() {
169            Ok(true)
170        } else {
171            flow.last_round()?
172                .last_step()?
173                .set_dynamic_suffix(". I apologize. This URL does not meet the criteria and was returned by mistake. In the future, we'll only return URLs that satisfy the criteria.\n".to_owned());
174            Ok(false)
175        }
176    }
177
178    async fn check_for_remaining(
179        &mut self,
180        flow: &mut CascadeFlow,
181        primitive: &mut ExactStringPrimitive,
182    ) -> Result<bool> {
183        let remaining_urls = primitive.allowed_strings.join(", ");
184        let config = StepConfig {
185            cache_prompt: true,
186            step_prefix: Some(format!(
187                "At least one of the remaining URLs, {remaining_urls}, is {}: ",
188                self.criteria.as_ref().unwrap()
189            )),
190            grammar: BooleanPrimitive::default().grammar(),
191            ..StepConfig::default()
192        };
193
194        flow.last_round()?.add_inference_step(&config);
195        flow.last_round()?.run_next_step(&mut self.base_req).await?;
196        flow.last_round()?
197            .last_step()?
198            .set_dynamic_suffix(".\n".to_owned());
199        if flow.primitive_result().unwrap().parse().unwrap() {
200            Ok(true)
201        } else {
202            Ok(false)
203        }
204    }
205
206    async fn run_cascade(
207        &mut self,
208        flow: &mut CascadeFlow,
209        primitive: &mut ExactStringPrimitive,
210    ) -> Result<()> {
211        let task = format!(
212            "Text with URLs to extract:\n{}\nReturn the URL that is most likely relevant to the criteria. If you are certain the text contains no qualifying URLs say 'No qualifying URLs.'.\nCriteria:\n This URL is {}.",
213            self.instruct_prompt.build_supporting_material().unwrap(),
214            self.criteria.as_ref().unwrap()
215        );
216        flow.new_round(task).step_separator = None;
217        flow.last_round()?.open_round(&mut self.base_req)?;
218        for i in 1..=primitive.allowed_strings.len() {
219            if self.results.len() >= self.max_urls {
220                break;
221            }
222            if i > 1 {
223                flow.new_round("Return the next URL that is likely to satisfy the criteria, or if there are no more URLs to extract say 'No qualifying URLs.'.").step_separator = None;
224                flow.last_round()?.open_round(&mut self.base_req)?;
225            }
226            self.extract_step(flow, primitive).await?;
227            match flow.primitive_result() {
228                Some(url_result) => {
229                    primitive.remove_string_from_allowed(&url_result);
230                    if self.validate_step(flow).await? {
231                        self.results.push(url_result);
232                    } else if !self.check_for_remaining(flow, primitive).await? {
233                        flow.last_round()?.close_round(&mut self.base_req)?;
234                        break;
235                    }
236                    flow.last_round()?.close_round(&mut self.base_req)?;
237                }
238                None => {
239                    flow.last_round()?.close_round(&mut self.base_req)?;
240                    break;
241                }
242            };
243        }
244
245        flow.close_cascade()?;
246        Ok(())
247    }
248}
249
250impl RequestConfigTrait for ExtractUrls {
251    fn config(&mut self) -> &mut RequestConfig {
252        &mut self.base_req.config
253    }
254
255    fn reset_request(&mut self) {
256        self.instruct_prompt.reset_instruct_prompt();
257        self.base_req.reset_completion_request();
258    }
259}
260
261impl InstructPromptTrait for ExtractUrls {
262    fn instruct_prompt_mut(&mut self) -> &mut InstructPrompt {
263        &mut self.instruct_prompt
264    }
265}
266
267#[derive(Clone)]
268pub struct ExtractUrlResult {
269    pub results: Option<Vec<Url>>,
270    pub criteria: String,
271    pub duration: std::time::Duration,
272    pub workflow: CascadeFlow,
273}
274
275impl ExtractUrlResult {
276    fn new(flow: CascadeFlow, results: Option<Vec<Url>>, criteria: &str) -> Self {
277        ExtractUrlResult {
278            results,
279            criteria: criteria.to_owned(),
280            duration: flow.duration,
281            workflow: flow,
282        }
283    }
284}
285
286impl std::fmt::Display for ExtractUrlResult {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        writeln!(f)?;
289        writeln!(
290            f,
291            "\x1b[38;5;45m\x1b[1m{}\x1b[0m",
292            self.workflow.cascade_name
293        )?;
294        writeln!(f)?;
295        for (i, round) in self.workflow.rounds.iter().enumerate() {
296            writeln!(f, "\x1b[38;5;44mRound {}\x1b[0m", i + 1)?;
297            writeln!(f, "{round}",)?;
298        }
299        writeln!(f, "\x1b[38;5;42mcriteria\x1b[0m: {:?}", self.criteria)?;
300        writeln!(f, "\x1b[38;5;43mduration\x1b[0m: {:?}", self.duration)?;
301        Ok(())
302    }
303}