// ai_chain/agents/self_ask_with_search.rs

//! Agent inspired by [self-ask](https://github.com/ofirpress/self-ask)
//!
//! The prompt implemented from the paper is designed for GPT-3, therefore it might not work well
//! with other models.
//!
//! These are the limitations and inconsistencies of the prompt:
//! - models do not always format their output correctly, e.g. respond with "So the final answer could be: ..." instead of "So the final answer is: ..."
//! - some models have safety measures against asking about events which are in the future (from the point of view of the model); they will not even attempt to use the search tool
//! - models sometimes finish on "Intermediate answer: ..." if it contains the final answer to the question
//! - models sometimes immediately answer with "Yes, ..." or "No, ..."; they should always structure their final answer with "So the final answer is: ..." (or equivalent)
11use crate::{
12    options::Options,
13    parameters,
14    prompt::{PromptTemplate, StringTemplateError},
15    tools::{Tool, ToolError},
16    traits::{Executor, ExecutorError},
17    Parameters,
18};
19use std::time::{Duration, Instant};
20use thiserror::Error;
21
/// This prompt is from the paper and is designed for GPT-3.
/// See limitations above.
///
/// Template placeholders: `{{input}}` receives the user's question and
/// `{{agent_scratchpad}}` receives the text produced by
/// `build_agent_scratchpad` from the steps taken so far.
///
/// NOTE(review): the last line reads "followup" while every example reads
/// "follow up" — presumably copied verbatim from the paper's prompt; verify
/// against the self-ask repository before changing it.
const PROMPT: &str = "Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
So the final answer is: No

Question: {{input}}
Are followup questions needed here:{{agent_scratchpad}}";
62
/// A struct representing the action the agent should take
///
/// This structure is heavily inspired from LangChain.
#[derive(Debug, PartialEq, Eq)]
pub struct AgentAction {
    /// Name of the tool to invoke (this agent always uses "Intermediate Answer")
    pub tool: String,
    /// Input to pass to the tool; for this agent, a YAML string holding the
    /// follow-up question
    pub tool_input: serde_yaml::Value,
    /// Additional information to log about the action.
    /// This log can be used in a few ways. First, it can be used to audit
    /// what exactly the LLM predicted to lead to this (tool, tool_input).
    /// Second, it can be used in future iterations to show the LLMs prior
    /// thoughts. This is useful when (tool, tool_input) does not contain
    /// full information about the LLM prediction (for example, any 'thought'
    /// before the tool/tool_input).
    pub log: String,
}
81
/// Final output of the agent
///
/// This structure is heavily inspired from LangChain.
#[derive(Debug, PartialEq)]
pub struct AgentFinish {
    /// Parsed return values; this agent stores the final answer under the
    /// "output" key (see the finish branch of the output parser).
    pub return_values: Parameters,

    /// additional information for observability
    /// This is used to pass along the full LLM prediction, not just the parsed out return value.
    pub log: String,
}
93
/// One completed plan/act cycle: the action the agent chose plus the
/// observation the tool returned for it.
#[derive(Debug)]
pub struct AgentIntermediateStep {
    /// The action (follow-up question) the agent decided to take.
    pub action: AgentAction,
    /// The search tool's answer, serialized to YAML (a string value for this agent).
    pub observation: serde_yaml::Value,
}
99
100pub enum AgentIntermediateStepOutput {
101    Step(AgentIntermediateStep),
102    Finish(AgentFinish),
103}
104
/// Decision parsed from one model completion: either take an action (ask a
/// follow-up question via the search tool) or finish with the final answer.
#[derive(Debug, PartialEq)]
pub enum AgentDecision {
    Action(AgentAction),
    Finish(AgentFinish),
}
/// Turns a raw model completion into an [`AgentDecision`].
pub trait AgentOutputParser {
    /// Error type returned when the text matches neither an action nor a finish.
    type Error;
    /// Parse the model's completion text into a decision.
    fn parse(&self, text: String) -> Result<AgentDecision, Self::Error>;
}
114
/// Errors that can occur while running the self-ask-with-search agent.
///
/// `T` is the error type of the search tool the agent was constructed with.
#[derive(Debug, Error)]
pub enum SelfAskWithSearchAgentError<T>
where
    T: std::fmt::Debug + std::error::Error + ToolError,
{
    /// The parsed tool input was not a YAML string (parser invariant violated).
    #[error("Search tool input yaml was not of type string: {0:?}")]
    ToolInputNotString(serde_yaml::Value),
    /// The search tool failed; wrapped manually via `map_err` at the call site
    /// because `T` is generic.
    #[error(transparent)]
    SearchToolError(T),
    /// The underlying executor (LLM driver) failed; also wrapped manually.
    #[error(transparent)]
    ExecutorError(ExecutorError),
    /// The completion contained neither a follow-up question nor a finish line.
    #[error(transparent)]
    ParserError(#[from] ParserError),
    /// Serializing a tool observation to YAML failed.
    #[error(transparent)]
    YamlError(#[from] serde_yaml::Error),
    /// Filling the prompt template failed.
    #[error(transparent)]
    StringTemplateError(#[from] StringTemplateError),
    /// The model response contained no completion to extract.
    #[error("Model response was empty or contained no choices")]
    NoChoicesReturned,
    /// The early-stopping budget (iterations and/or wall-clock time) ran out
    /// before the agent produced a final answer.
    #[error("Max number of iterations or timeout exceeded. Elapsed: {time_elapsed_seconds}s, {iterations_elapsed} iterations")]
    RuntimeExceeded {
        time_elapsed_seconds: f64,
        iterations_elapsed: u32,
    },
}
140
/// Parser for self-ask model output.
///
/// Recognizes either a follow-up question (introduced by `followup_prefix`)
/// or a final answer (introduced by any of `acceptable_finish_prefixes`).
pub struct SelfAskWithSearchAgentOutputParser {
    // Marks a question to forward to the search tool, e.g. "Follow up:".
    followup_prefix: String,
    // Marks an answer the model hallucinated itself; used to truncate the
    // follow-up question in `parse`.
    intermediate_answer_prefix: String,
    // Any of these introduces the final answer, e.g. "So the final answer is:".
    acceptable_finish_prefixes: Vec<String>,
}
146
147impl SelfAskWithSearchAgentOutputParser {
148    pub fn new(
149        followup_prefix: &str,
150        intermediate_answer_prefix: &str,
151        acceptable_finish_prefixes: &[&str],
152    ) -> Self {
153        Self {
154            followup_prefix: followup_prefix.into(),
155            intermediate_answer_prefix: intermediate_answer_prefix.into(),
156            acceptable_finish_prefixes: acceptable_finish_prefixes
157                .iter()
158                .map(|s| s.to_string())
159                .collect(),
160        }
161    }
162}
163
164impl Default for SelfAskWithSearchAgentOutputParser {
165    fn default() -> Self {
166        Self::new(
167            "Follow up:",
168            "Intermediate Answer:",
169            &[
170                "Final answer:",
171                "So the final answer is:",
172                "So the final answer could be:",
173            ],
174        )
175    }
176}
177
/// Returned when the model output contains neither a follow-up question nor
/// any acceptable finish prefix; wraps the full, unparseable completion text.
#[derive(Debug, Error)]
#[error("No finish line or follow up question was returned by the model: {0}")]
pub struct ParserError(String);
181
182impl AgentOutputParser for SelfAskWithSearchAgentOutputParser {
183    type Error = ParserError;
184    fn parse(&self, text: String) -> Result<AgentDecision, Self::Error> {
185        // If there is a followup question, we need to extract it
186        if let Some(followup_idx) = text.find(&self.followup_prefix) {
187            // If there is an intermediate answer, extract it
188            let (followup_question, log) = if let Some(intermediate_answer_idx) =
189                text.find(&self.intermediate_answer_prefix)
190            {
191                let followup_question = text
192                    .chars()
193                    .skip(followup_idx + self.followup_prefix.len())
194                    .take(intermediate_answer_idx - (followup_idx + self.followup_prefix.len()))
195                    .collect::<String>()
196                    .trim()
197                    .to_owned();
198
199                let log = text.chars().take(intermediate_answer_idx).collect();
200                (followup_question, log)
201            } else {
202                // If there is no intermediate answer, extract the followup question
203                let followup_question = text
204                    .chars()
205                    .skip(followup_idx + self.followup_prefix.len())
206                    .take_while(|&c| c != '\n')
207                    .collect::<String>()
208                    .trim()
209                    .to_owned();
210
211                let log = text
212                    .char_indices()
213                    .map_while(|(idx, c)| {
214                        if c != '\n' || idx < followup_idx {
215                            Some(c)
216                        } else {
217                            None
218                        }
219                    })
220                    .collect();
221                (followup_question, log)
222            };
223            Ok(AgentDecision::Action(AgentAction {
224                tool: "Intermediate Answer".into(),
225                tool_input: followup_question.into(),
226                log,
227            }))
228        } else if let Some((idx, prefix)) = self
229            .acceptable_finish_prefixes
230            .iter()
231            .find_map(|prefix| text.find(prefix).map(|idx| (idx, prefix)))
232        {
233            let final_answer = text.chars().skip(idx + prefix.len()).collect::<String>();
234            Ok(AgentDecision::Finish(AgentFinish {
235                return_values: parameters!("output" => final_answer.trim()),
236                log: text,
237            }))
238        } else {
239            Err(ParserError(text))
240        }
241    }
242}
243
/// Budget limits for an agent run.
///
/// A `None` field means that particular limit is not enforced; the derived
/// `Default` therefore means "run without limits".
#[derive(Default)]
pub struct EarlyStoppingConfig {
    /// Maximum number of plan/act iterations.
    pub max_iterations: Option<u32>,
    /// Maximum wall-clock time in seconds.
    pub max_time_elapsed_seconds: Option<f64>,
}
249
/// An agent that answers a question by repeatedly asking itself follow-up
/// questions, answering them with a search tool, and composing the results,
/// in the style of the self-ask paper.
pub struct Agent<E, T>
where
    E: Executor,
    T: Tool,
    T::Input: From<String>,
    T::Output: Into<String>,
{
    // Drives the LLM that plans each step.
    executor: E,
    // Tool used to answer intermediate follow-up questions.
    search_tool: T,
    // Iteration / wall-clock budget; `None` fields mean unlimited.
    early_stopping_config: EarlyStoppingConfig,
    // Prepended to each observation in the scratchpad ("Intermediate answer: ").
    observation_prefix: String,
    // Inserted after each observation before the model continues (empty here).
    llm_prefix: String,
    // Parses each completion into an action or a finish.
    output_parser: SelfAskWithSearchAgentOutputParser,
}
264
265impl<E, T> Agent<E, T>
266where
267    E: Executor,
268    T: Tool,
269    T::Input: From<String>,
270    T::Output: Into<String>,
271{
272    pub fn new(executor: E, search_tool: T, early_stopping_config: EarlyStoppingConfig) -> Self {
273        Self {
274            executor,
275            search_tool,
276            early_stopping_config,
277            observation_prefix: "Intermediate answer: ".to_string(),
278            llm_prefix: "".to_string(),
279            output_parser: SelfAskWithSearchAgentOutputParser::default(),
280        }
281    }
282
283    fn should_continue(&self, iterations_elapsed: u32, time_elapsed_seconds: f64) -> bool {
284        match (
285            self.early_stopping_config.max_iterations,
286            self.early_stopping_config.max_time_elapsed_seconds,
287        ) {
288            (None, None) => true,
289            (None, Some(max_time_elapsed_seconds)) => {
290                max_time_elapsed_seconds >= time_elapsed_seconds
291            }
292            (Some(max_iterations), None) => max_iterations >= iterations_elapsed,
293            (Some(max_iterations), Some(max_time_elapsed_seconds)) => {
294                max_iterations >= iterations_elapsed
295                    && max_time_elapsed_seconds >= time_elapsed_seconds
296            }
297        }
298    }
299
300    /// Ask a model for a decision on what to do next, e.x. which tool to use
301    ///
302    /// Perform the action
303    async fn take_next_step(
304        &self,
305        intermediate_steps: &Vec<AgentIntermediateStep>,
306        query: &str,
307    ) -> Result<AgentIntermediateStepOutput, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
308        let output = self.plan(intermediate_steps, query).await?;
309
310        let decision = self.output_parser.parse(output)?;
311        match decision {
312            AgentDecision::Action(action) => {
313                let observation = self
314                    .search_tool
315                    .invoke_typed(
316                        &action
317                            .tool_input
318                            .as_str()
319                            .ok_or(SelfAskWithSearchAgentError::ToolInputNotString(
320                                action.tool_input.clone(),
321                            ))?
322                            .to_string()
323                            .into(),
324                    )
325                    .await
326                    .map_err(SelfAskWithSearchAgentError::SearchToolError)?;
327
328                Ok(AgentIntermediateStepOutput::Step(AgentIntermediateStep {
329                    action,
330                    observation: serde_yaml::to_value(Into::<String>::into(observation))?,
331                }))
332            }
333            AgentDecision::Finish(finish) => Ok(AgentIntermediateStepOutput::Finish(finish)),
334        }
335    }
336
337    /// Convert the intermediate steps into a single text to pass to the agent so he can continue his thought process
338    pub fn build_agent_scratchpad(
339        &self,
340        intermediate_steps: &Vec<AgentIntermediateStep>,
341    ) -> String {
342        let mut scratchpad = "".to_string();
343        for intermediate_step in intermediate_steps {
344            scratchpad += &intermediate_step.action.log;
345            scratchpad += &format!(
346                "\n{}{}\n{}",
347                self.observation_prefix,
348                intermediate_step.observation.as_str().unwrap_or_default(),
349                self.llm_prefix
350            );
351        }
352        scratchpad
353    }
354
355    /// Ask a model for a decision on what to do next, e.x. which tool to use
356    ///
357    /// Fills in the prompt template then calls the model to complete it
358    async fn plan(
359        &self,
360        intermediate_steps: &Vec<AgentIntermediateStep>,
361        query: &str,
362    ) -> Result<String, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
363        let scratchpad = self.build_agent_scratchpad(intermediate_steps);
364        let template_parameters = parameters!("input" => query, "agent_scratchpad" => scratchpad);
365        let prompt = PromptTemplate::Text(PROMPT.into()).format(&template_parameters)?;
366        let plan = self
367            .executor
368            .execute(Options::empty(), &prompt)
369            .await
370            .map_err(SelfAskWithSearchAgentError::ExecutorError)?;
371        plan.to_immediate()
372            .await
373            .map_err(SelfAskWithSearchAgentError::ExecutorError)?
374            .as_content()
375            .extract_last_body()
376            .cloned()
377            .ok_or(SelfAskWithSearchAgentError::NoChoicesReturned)
378    }
379
380    pub async fn run(
381        &self,
382        query: &str,
383    ) -> Result<
384        (AgentFinish, Vec<AgentIntermediateStep>),
385        SelfAskWithSearchAgentError<<T as Tool>::Error>,
386    > {
387        let mut intermediate_steps = vec![];
388
389        let mut iterations = 0;
390        let start = Instant::now();
391        let mut full_duration = Duration::from_nanos(0);
392        while self.should_continue(iterations, full_duration.as_secs_f64()) {
393            let decision = self.take_next_step(&intermediate_steps, query).await?;
394            full_duration = start.elapsed();
395            iterations += 1;
396            match decision {
397                AgentIntermediateStepOutput::Step(step) => intermediate_steps.push(step),
398                AgentIntermediateStepOutput::Finish(finish) => {
399                    return Ok((finish, intermediate_steps))
400                }
401            }
402        }
403        Err(SelfAskWithSearchAgentError::RuntimeExceeded {
404            time_elapsed_seconds: full_duration.as_secs_f64(),
405            iterations_elapsed: iterations,
406        })
407    }
408}
409
#[cfg(test)]
mod tests {

    use async_trait::async_trait;

    use thiserror::Error;

    use crate::{
        agents::self_ask_with_search::{AgentIntermediateStep, EarlyStoppingConfig},
        options::Options,
        output::Output,
        parameters,
        prompt::Prompt,
        tokens::{TokenCollection, Tokenizer},
        tools::{Tool, ToolError},
        traits::{Executor, ExecutorError},
    };

    use super::{
        Agent, AgentAction, AgentDecision, AgentFinish, AgentOutputParser,
        SelfAskWithSearchAgentOutputParser,
    };

    // A "Follow up:" line becomes an Action whose tool_input is the question
    // and whose log runs to the end of that line (here: end of text).
    #[test]
    fn test_parses_followup() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        Follow up: my follow up question abc?";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: "my follow up question abc?".into(),
                log: text.into()
            })
        );
    }

    // Trailing whitespace after the follow-up line is excluded from the log
    // (the log stops at the newline ending the follow-up line).
    #[test]
    fn test_parses_follow_up_trims_trailing_whitespace() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        Follow up: my follow up question abc?
        ";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: "my follow up question abc?".into(),
                log: text.trim_end().into()
            })
        );
    }

    // A finish prefix yields Finish with the trimmed answer under "output"
    // and the complete text as the log.
    #[test]
    fn test_parses_final_answer() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        So the final answer is: yes abc!";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "yes abc!"),
                log: text.into()
            })
        );
    }

    // Whitespace after the final answer is trimmed from "output" but the log
    // keeps the full, untrimmed text.
    #[test]
    fn test_parses_final_answer_ignores_trailing_whitespace() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        So the final answer is: yes abc!
        ";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "yes abc!"),
                log: text.into()
            })
        );
    }

    // Colons inside the answer itself must not confuse the prefix matching.
    #[test]
    fn test_parses_final_answer_with_colons() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        So the final answer is: Mad Max: Fury road";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "Mad Max: Fury road"),
                log: text.into()
            })
        );
    }

    // Scratchpad building only: the mocks below are minimal stubs whose
    // methods all panic via todo!() — none of them is ever called, because
    // build_agent_scratchpad touches neither the executor nor the tool.
    // NOTE(review): test name has a typo ("sratchpad"); consider renaming.
    #[test]
    fn test_builds_agent_sratchpad() {
        // NOTE(review): MockOutput is declared but never used in this test.
        #[derive(Clone)]
        struct MockOutput;

        #[derive(Debug, Error)]
        #[error("Mocked executor error")]
        struct MockError;

        impl ToolError for MockError {}

        impl From<serde_yaml::Error> for MockError {
            fn from(_: serde_yaml::Error) -> Self {
                Self
            }
        }

        struct MockTokenizer;

        impl Tokenizer for MockTokenizer {
            fn tokenize_str(
                &self,
                _: &str,
            ) -> Result<TokenCollection, crate::tokens::TokenizerError> {
                todo!()
            }

            fn to_string(
                &self,
                _: TokenCollection,
            ) -> Result<String, crate::tokens::TokenizerError> {
                todo!()
            }
        }

        struct MockExecutor;

        #[async_trait]
        impl Executor for MockExecutor {
            type StepTokenizer<'a> = MockTokenizer;

            fn new_with_options(_: Options) -> Result<Self, crate::traits::ExecutorCreationError> {
                todo!()
            }

            async fn execute(
                &self,
                _: &Options,
                _: &crate::prompt::Prompt,
            ) -> Result<Output, ExecutorError> {
                todo!()
            }

            fn tokens_used(
                &self,
                _: &Options,
                _: &crate::prompt::Prompt,
            ) -> Result<crate::tokens::TokenCount, crate::tokens::PromptTokensError> {
                todo!()
            }

            fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
                todo!()
            }

            fn max_tokens_allowed(&self, _: &Options) -> i32 {
                todo!()
            }

            fn get_tokenizer(
                &self,
                _: &Options,
            ) -> Result<MockTokenizer, crate::tokens::TokenizerError> {
                todo!()
            }
        }
        struct MockSearch;

        #[async_trait]
        impl Tool for MockSearch {
            type Input = String;

            type Output = String;

            type Error = MockError;

            async fn invoke_typed(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> {
                todo!()
            }

            fn description(&self) -> crate::tools::ToolDescription {
                todo!()
            }
        }
        let mock_executor = MockExecutor;
        let mock_search = MockSearch;
        let agent = Agent::new(
            mock_executor,
            mock_search,
            EarlyStoppingConfig {
                max_iterations: None,
                max_time_elapsed_seconds: None,
            },
        );
        // Two completed steps: each contributes its action log, then
        // "Intermediate answer: <observation>" on the following line.
        let intermediate_steps = vec![
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Muhammad Ali when he died?".into(),
                    log: "Yes.
Follow up: How old was Muhammad Ali when he died?"
                        .into(),
                },
                observation: "Muhammad Ali was 74 years old when he died.".into(),
            },
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Alan Turing when he died?".into(),
                    log: "Follow up: How old was Alan Turing when he died?".into(),
                },
                observation: "Alan Turing was 41 years old when he died.".into(),
            },
        ];

        let expected_scratchpad = "Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.\n";

        let scratchpad = agent.build_agent_scratchpad(&intermediate_steps);

        assert_eq!(scratchpad, expected_scratchpad);
    }
657}