llm_chain/agents/self_ask_with_search.rs

1use crate::{
2    options::Options,
3    parameters,
4    prompt::{PromptTemplate, StringTemplateError},
5    tools::{Tool, ToolError},
6    traits::{Executor, ExecutorError},
7    Parameters,
8};
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
/// Few-shot prompt for the self-ask-with-search agent.
///
/// `{{input}}` is filled with the user's question and `{{agent_scratchpad}}` with the
/// follow-up questions and intermediate answers gathered so far; the model then continues
/// the text. Every question block, including the final one, uses the same wording
/// ("Are follow up questions needed here:") so the model sees one consistent pattern.
///
/// TODO: This prompt has some issues:
///
/// - models do not always format their output correctly, e.g. respond with
///   "So the final answer could be: ..." instead of "So the final answer is: ..."
/// - some models have safety measures against asking about events which are in the future
///   (from the point of view of the model); they will not even attempt to use the search tool
/// - models sometimes finish on "Intermediate answer: ..." if it contains the final answer to
///   the question
/// - models sometimes immediately answer with "Yes, ..." or "No, ..."; they should always
///   structure their final answer with "So the final answer is: ..." (or equivalent)
const PROMPT: &str = "Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
So the final answer is: No

Question: {{input}}
Are follow up questions needed here:{{agent_scratchpad}}";
56
57#[derive(Debug, PartialEq, Eq)]
58pub struct AgentAction {
59    pub tool: String,
60    pub tool_input: serde_yaml::Value,
61    pub log: String,
62}
63#[derive(Debug, PartialEq)]
64pub struct AgentFinish {
65    pub return_values: Parameters,
66    pub log: String,
67}
68
69#[derive(Debug)]
70pub struct AgentIntermediateStep {
71    pub action: AgentAction,
72    pub observation: serde_yaml::Value,
73}
74
75pub enum AgentIntermediateStepOutput {
76    Step(AgentIntermediateStep),
77    Finish(AgentFinish),
78}
79
80#[derive(Debug, PartialEq)]
81pub enum AgentDecision {
82    Action(AgentAction),
83    Finish(AgentFinish),
84}
85pub trait AgentOutputParser {
86    type Error;
87    fn parse(&self, text: String) -> Result<AgentDecision, Self::Error>;
88}
89
90#[derive(Debug, Error)]
91pub enum SelfAskWithSearchAgentError<T>
92where
93    T: std::fmt::Debug + std::error::Error + ToolError,
94{
95    #[error("Search tool input yaml was not of type string: {0:?}")]
96    ToolInputNotString(serde_yaml::Value),
97    #[error(transparent)]
98    SearchToolError(T),
99    #[error(transparent)]
100    ExecutorError(ExecutorError),
101    #[error(transparent)]
102    ParserError(#[from] ParserError),
103    #[error(transparent)]
104    YamlError(#[from] serde_yaml::Error),
105    #[error(transparent)]
106    StringTemplateError(#[from] StringTemplateError),
107    #[error("Model response was empty or contained no choices")]
108    NoChoicesReturned,
109    #[error("Max number of iterations or timeout exceeded. Elapsed: {time_elapsed_seconds}s, {iterations_elapsed} iterations")]
110    RuntimeExceeded {
111        time_elapsed_seconds: f64,
112        iterations_elapsed: u32,
113    },
114}
115
/// Output parser for the self-ask-with-search prompt format.
///
/// Markers are located with plain, case-sensitive substring search.
pub struct SelfAskWithSearchAgentOutputParser {
    /// Marker preceding a follow-up question, e.g. `"Follow up:"`.
    followup_prefix: String,
    /// Marker preceding an intermediate answer; a follow-up question ends where it begins.
    intermediate_answer_prefix: String,
    /// Markers preceding a final answer; the first listed one found in the output is used.
    acceptable_finish_prefixes: Vec<String>,
}
121
122impl SelfAskWithSearchAgentOutputParser {
123    pub fn new(
124        followup_prefix: &str,
125        intermediate_answer_prefix: &str,
126        acceptable_finish_prefixes: &[&str],
127    ) -> Self {
128        Self {
129            followup_prefix: followup_prefix.into(),
130            intermediate_answer_prefix: intermediate_answer_prefix.into(),
131            acceptable_finish_prefixes: acceptable_finish_prefixes
132                .iter()
133                .map(|s| s.to_string())
134                .collect(),
135        }
136    }
137}
138
139impl Default for SelfAskWithSearchAgentOutputParser {
140    fn default() -> Self {
141        Self::new(
142            "Follow up:",
143            "Intermediate Answer:",
144            &[
145                "Final answer:",
146                "So the final answer is:",
147                "So the final answer could be:",
148            ],
149        )
150    }
151}
152
153#[derive(Debug, Error)]
154#[error("No finish line or follow up question was returned by the model: {0}")]
155pub struct ParserError(String);
156
157impl AgentOutputParser for SelfAskWithSearchAgentOutputParser {
158    type Error = ParserError;
159    fn parse(&self, text: String) -> Result<AgentDecision, Self::Error> {
160        if let Some(followup_idx) = text.find(&self.followup_prefix) {
161            let (followup_question, log) = if let Some(intermediate_answer_idx) =
162                text.find(&self.intermediate_answer_prefix)
163            {
164                let followup_question = text
165                    .chars()
166                    .skip(followup_idx + self.followup_prefix.len())
167                    .take(intermediate_answer_idx - (followup_idx + self.followup_prefix.len()))
168                    .collect::<String>()
169                    .trim()
170                    .to_owned();
171
172                let log = text.chars().take(intermediate_answer_idx).collect();
173                (followup_question, log)
174            } else {
175                let followup_question = text
176                    .chars()
177                    .skip(followup_idx + self.followup_prefix.len())
178                    .take_while(|&c| c != '\n')
179                    .collect::<String>()
180                    .trim()
181                    .to_owned();
182
183                let log = text
184                    .char_indices()
185                    .map_while(|(idx, c)| {
186                        if c != '\n' || idx < followup_idx {
187                            Some(c)
188                        } else {
189                            None
190                        }
191                    })
192                    .collect();
193                (followup_question, log)
194            };
195            Ok(AgentDecision::Action(AgentAction {
196                tool: "Intermediate Answer".into(),
197                tool_input: followup_question.into(),
198                log,
199            }))
200        } else if let Some((idx, prefix)) = self
201            .acceptable_finish_prefixes
202            .iter()
203            .find_map(|prefix| text.find(prefix).map(|idx| (idx, prefix)))
204        {
205            let final_answer = text.chars().skip(idx + prefix.len()).collect::<String>();
206            Ok(AgentDecision::Finish(AgentFinish {
207                return_values: parameters!("output" => final_answer.trim()),
208                log: text,
209            }))
210        } else {
211            Err(ParserError(text))
212        }
213    }
214}
215
/// Limits that stop the agent before the model produces a final answer.
///
/// A `None` limit is not enforced; the default config has no limits at all.
/// The type is a small plain-data value, so it is `Copy` and comparable.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct EarlyStoppingConfig {
    /// Maximum number of agent steps to run.
    pub max_iterations: Option<u32>,
    /// Maximum wall-clock time, in seconds, before no further steps are started.
    pub max_time_elapsed_seconds: Option<f64>,
}
221
222pub struct Agent<E, T>
223where
224    E: Executor,
225    T: Tool,
226    T::Input: From<String>,
227    T::Output: Into<String>,
228{
229    executor: E,
230    search_tool: T,
231    early_stopping_config: EarlyStoppingConfig,
232    observation_prefix: String,
233    llm_prefix: String,
234    output_parser: SelfAskWithSearchAgentOutputParser,
235}
236
237impl<E, T> Agent<E, T>
238where
239    E: Executor,
240    T: Tool,
241    T::Input: From<String>,
242    T::Output: Into<String>,
243{
244    pub fn new(executor: E, search_tool: T, early_stopping_config: EarlyStoppingConfig) -> Self {
245        Self {
246            executor,
247            search_tool,
248            early_stopping_config,
249            observation_prefix: "Intermediate answer: ".to_string(),
250            llm_prefix: "".to_string(),
251            output_parser: SelfAskWithSearchAgentOutputParser::default(),
252        }
253    }
254
255    fn should_continue(&self, iterations_elapsed: u32, time_elapsed_seconds: f64) -> bool {
256        match (
257            self.early_stopping_config.max_iterations,
258            self.early_stopping_config.max_time_elapsed_seconds,
259        ) {
260            (None, None) => true,
261            (None, Some(max_time_elapsed_seconds)) => {
262                max_time_elapsed_seconds >= time_elapsed_seconds
263            }
264            (Some(max_iterations), None) => max_iterations >= iterations_elapsed,
265            (Some(max_iterations), Some(max_time_elapsed_seconds)) => {
266                max_iterations >= iterations_elapsed
267                    && max_time_elapsed_seconds >= time_elapsed_seconds
268            }
269        }
270    }
271
272    /// Ask a model for a decision on what to do next, e.x. which tool to use
273    ///
274    /// Perform the action
275    async fn take_next_step(
276        &self,
277        intermediate_steps: &Vec<AgentIntermediateStep>,
278        query: &str,
279    ) -> Result<AgentIntermediateStepOutput, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
280        let output = self.plan(intermediate_steps, query).await?;
281
282        let decision = self.output_parser.parse(output)?;
283        match decision {
284            AgentDecision::Action(action) => {
285                let observation = self
286                    .search_tool
287                    .invoke_typed(
288                        &action
289                            .tool_input
290                            .as_str()
291                            .ok_or(SelfAskWithSearchAgentError::ToolInputNotString(
292                                action.tool_input.clone(),
293                            ))?
294                            .to_string()
295                            .into(),
296                    )
297                    .await
298                    .map_err(SelfAskWithSearchAgentError::SearchToolError)?;
299
300                Ok(AgentIntermediateStepOutput::Step(AgentIntermediateStep {
301                    action,
302                    observation: serde_yaml::to_value(Into::<String>::into(observation))?,
303                }))
304            }
305            AgentDecision::Finish(finish) => Ok(AgentIntermediateStepOutput::Finish(finish)),
306        }
307    }
308
309    /// Convert the intermediate steps into a single text to pass to the agent so he can continue his thought process
310    pub fn build_agent_scratchpad(
311        &self,
312        intermediate_steps: &Vec<AgentIntermediateStep>,
313    ) -> String {
314        let mut scratchpad = "".to_string();
315        for intermediate_step in intermediate_steps {
316            scratchpad += &intermediate_step.action.log;
317            scratchpad += &format!(
318                "\n{}{}\n{}",
319                self.observation_prefix,
320                intermediate_step.observation.as_str().unwrap_or_default(),
321                self.llm_prefix
322            );
323        }
324        scratchpad
325    }
326
327    /// Ask a model for a decision on what to do next, e.x. which tool to use
328    ///
329    /// Fills in the prompt template then calls the model to complete it
330    async fn plan(
331        &self,
332        intermediate_steps: &Vec<AgentIntermediateStep>,
333        query: &str,
334    ) -> Result<String, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
335        let scratchpad = self.build_agent_scratchpad(intermediate_steps);
336        let template_parameters = parameters!("input" => query, "agent_scratchpad" => scratchpad);
337        let prompt = PromptTemplate::Text(PROMPT.into()).format(&template_parameters)?;
338        let plan = self
339            .executor
340            .execute(Options::empty(), &prompt)
341            .await
342            .map_err(SelfAskWithSearchAgentError::ExecutorError)?;
343        plan.to_immediate()
344            .await
345            .map_err(SelfAskWithSearchAgentError::ExecutorError)?
346            .as_content()
347            .extract_last_body()
348            .cloned()
349            .ok_or(SelfAskWithSearchAgentError::NoChoicesReturned)
350    }
351
352    pub async fn run(
353        &self,
354        query: &str,
355    ) -> Result<
356        (AgentFinish, Vec<AgentIntermediateStep>),
357        SelfAskWithSearchAgentError<<T as Tool>::Error>,
358    > {
359        let mut intermediate_steps = vec![];
360
361        let mut iterations = 0;
362        let start = Instant::now();
363        let mut full_duration = Duration::from_nanos(0);
364        while self.should_continue(iterations, full_duration.as_secs_f64()) {
365            let decision = self.take_next_step(&intermediate_steps, query).await?;
366            full_duration = start.elapsed();
367            iterations += 1;
368            match decision {
369                AgentIntermediateStepOutput::Step(step) => intermediate_steps.push(step),
370                AgentIntermediateStepOutput::Finish(finish) => {
371                    return Ok((finish, intermediate_steps))
372                }
373            }
374        }
375        Err(SelfAskWithSearchAgentError::RuntimeExceeded {
376            time_elapsed_seconds: full_duration.as_secs_f64(),
377            iterations_elapsed: iterations,
378        })
379    }
380}
381
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use thiserror::Error;

    use crate::{
        agents::self_ask_with_search::{AgentIntermediateStep, EarlyStoppingConfig},
        options::Options,
        output::Output,
        parameters,
        prompt::Prompt,
        tokens::{TokenCollection, Tokenizer},
        tools::{Tool, ToolError},
        traits::{Executor, ExecutorError},
    };

    use super::{
        Agent, AgentAction, AgentDecision, AgentFinish, AgentOutputParser,
        SelfAskWithSearchAgentOutputParser,
    };

    /// Parses `text` with the default parser, failing the test if parsing fails.
    fn parse_default(text: &str) -> AgentDecision {
        SelfAskWithSearchAgentOutputParser::default()
            .parse(text.to_owned())
            .expect("text should be parseable")
    }

    /// The decision the parser is expected to produce for a follow-up question.
    fn expected_followup(question: &str, log: &str) -> AgentDecision {
        AgentDecision::Action(AgentAction {
            tool: "Intermediate Answer".into(),
            tool_input: question.into(),
            log: log.into(),
        })
    }

    /// The decision the parser is expected to produce for a final answer.
    fn expected_finish(output: &str, log: &str) -> AgentDecision {
        AgentDecision::Finish(AgentFinish {
            return_values: parameters!("output" => output),
            log: log.into(),
        })
    }

    #[test]
    fn test_parses_followup() {
        let text = "
        Whatever
        Whatever
        Follow up: my follow up question abc?";
        assert_eq!(
            parse_default(text),
            expected_followup("my follow up question abc?", text)
        );
    }

    #[test]
    fn test_parses_follow_up_trims_trailing_whitespace() {
        let text = "
        Whatever
        Whatever
        Follow up: my follow up question abc?
        ";
        assert_eq!(
            parse_default(text),
            expected_followup("my follow up question abc?", text.trim_end())
        );
    }

    #[test]
    fn test_parses_final_answer() {
        let text = "
        Whatever
        Whatever
        So the final answer is: yes abc!";
        assert_eq!(parse_default(text), expected_finish("yes abc!", text));
    }

    #[test]
    fn test_parses_final_answer_ignores_trailing_whitespace() {
        let text = "
        Whatever
        Whatever
        So the final answer is: yes abc!
        ";
        assert_eq!(parse_default(text), expected_finish("yes abc!", text));
    }

    #[test]
    fn test_parses_final_answer_with_colons() {
        let text = "
        Whatever
        Whatever
        So the final answer is: Mad Max: Fury road";
        assert_eq!(parse_default(text), expected_finish("Mad Max: Fury road", text));
    }

    #[test]
    fn test_builds_agent_scratchpad() {
        #[derive(Debug, Error)]
        #[error("Mocked search tool error")]
        struct MockError;

        impl ToolError for MockError {}

        impl From<serde_yaml::Error> for MockError {
            fn from(_: serde_yaml::Error) -> Self {
                Self
            }
        }

        struct MockTokenizer;

        impl Tokenizer for MockTokenizer {
            fn tokenize_str(
                &self,
                _: &str,
            ) -> Result<TokenCollection, crate::tokens::TokenizerError> {
                todo!()
            }

            fn to_string(
                &self,
                _: TokenCollection,
            ) -> Result<String, crate::tokens::TokenizerError> {
                todo!()
            }
        }

        struct MockExecutor;

        #[async_trait]
        impl Executor for MockExecutor {
            type StepTokenizer<'a> = MockTokenizer;

            fn new_with_options(_: Options) -> Result<Self, crate::traits::ExecutorCreationError> {
                todo!()
            }

            async fn execute(
                &self,
                _: &Options,
                _: &crate::prompt::Prompt,
            ) -> Result<Output, ExecutorError> {
                todo!()
            }

            fn tokens_used(
                &self,
                _: &Options,
                _: &crate::prompt::Prompt,
            ) -> Result<crate::tokens::TokenCount, crate::tokens::PromptTokensError> {
                todo!()
            }

            fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
                todo!()
            }

            fn max_tokens_allowed(&self, _: &Options) -> i32 {
                todo!()
            }

            fn get_tokenizer(
                &self,
                _: &Options,
            ) -> Result<MockTokenizer, crate::tokens::TokenizerError> {
                todo!()
            }
        }

        struct MockSearch;

        #[async_trait]
        impl Tool for MockSearch {
            type Input = String;

            type Output = String;

            type Error = MockError;

            async fn invoke_typed(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> {
                todo!()
            }

            fn description(&self) -> crate::tools::ToolDescription {
                todo!()
            }
        }

        let agent = Agent::new(
            MockExecutor,
            MockSearch,
            EarlyStoppingConfig {
                max_iterations: None,
                max_time_elapsed_seconds: None,
            },
        );

        let intermediate_steps = vec![
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Muhammad Ali when he died?".into(),
                    log: "Yes.\nFollow up: How old was Muhammad Ali when he died?".into(),
                },
                observation: "Muhammad Ali was 74 years old when he died.".into(),
            },
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Alan Turing when he died?".into(),
                    log: "Follow up: How old was Alan Turing when he died?".into(),
                },
                observation: "Alan Turing was 41 years old when he died.".into(),
            },
        ];

        let expected_scratchpad = concat!(
            "Yes.\n",
            "Follow up: How old was Muhammad Ali when he died?\n",
            "Intermediate answer: Muhammad Ali was 74 years old when he died.\n",
            "Follow up: How old was Alan Turing when he died?\n",
            "Intermediate answer: Alan Turing was 41 years old when he died.\n",
        );

        assert_eq!(
            agent.build_agent_scratchpad(&intermediate_steps),
            expected_scratchpad
        );
    }
}