1use crate::{
2 options::Options,
3 parameters,
4 prompt::{PromptTemplate, StringTemplateError},
5 tools::{Tool, ToolError},
6 traits::{Executor, ExecutorError},
7 Parameters,
8};
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
/// Few-shot "self-ask with search" prompt.
///
/// Four worked examples demonstrate the expected completion format
/// (`Follow up:` / `Intermediate answer:` / `So the final answer is:`).
/// `{{input}}` is replaced with the user's question and `{{agent_scratchpad}}` with the
/// rendered intermediate steps taken so far.
///
/// NOTE(review): the final line says "Are followup questions needed here:" while the
/// examples say "Are follow up questions needed here:"; confirm whether this difference
/// is intentional before changing it, since it alters what the model sees.
const PROMPT: &str = "Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
So the final answer is: No

Question: {{input}}
Are followup questions needed here:{{agent_scratchpad}}";
56
57#[derive(Debug, PartialEq, Eq)]
58pub struct AgentAction {
59 pub tool: String,
60 pub tool_input: serde_yaml::Value,
61 pub log: String,
62}
63#[derive(Debug, PartialEq)]
64pub struct AgentFinish {
65 pub return_values: Parameters,
66 pub log: String,
67}
68
69#[derive(Debug)]
70pub struct AgentIntermediateStep {
71 pub action: AgentAction,
72 pub observation: serde_yaml::Value,
73}
74
75pub enum AgentIntermediateStepOutput {
76 Step(AgentIntermediateStep),
77 Finish(AgentFinish),
78}
79
80#[derive(Debug, PartialEq)]
81pub enum AgentDecision {
82 Action(AgentAction),
83 Finish(AgentFinish),
84}
85pub trait AgentOutputParser {
86 type Error;
87 fn parse(&self, text: String) -> Result<AgentDecision, Self::Error>;
88}
89
90#[derive(Debug, Error)]
91pub enum SelfAskWithSearchAgentError<T>
92where
93 T: std::fmt::Debug + std::error::Error + ToolError,
94{
95 #[error("Search tool input yaml was not of type string: {0:?}")]
96 ToolInputNotString(serde_yaml::Value),
97 #[error(transparent)]
98 SearchToolError(T),
99 #[error(transparent)]
100 ExecutorError(ExecutorError),
101 #[error(transparent)]
102 ParserError(#[from] ParserError),
103 #[error(transparent)]
104 YamlError(#[from] serde_yaml::Error),
105 #[error(transparent)]
106 StringTemplateError(#[from] StringTemplateError),
107 #[error("Model response was empty or contained no choices")]
108 NoChoicesReturned,
109 #[error("Max number of iterations or timeout exceeded. Elapsed: {time_elapsed_seconds}s, {iterations_elapsed} iterations")]
110 RuntimeExceeded {
111 time_elapsed_seconds: f64,
112 iterations_elapsed: u32,
113 },
114}
115
/// Parser for completions produced from the self-ask-with-search prompt.
///
/// A completion containing the follow-up prefix becomes an [`AgentDecision::Action`].
/// Otherwise, a completion containing any accepted finish prefix becomes an
/// [`AgentDecision::Finish`].
pub struct SelfAskWithSearchAgentOutputParser {
    /// Marks a follow-up question, e.g. `"Follow up:"`.
    followup_prefix: String,
    /// Marks the answer to a follow-up question, if the model writes one itself.
    intermediate_answer_prefix: String,
    /// Prefixes that introduce the final answer, tried in order.
    acceptable_finish_prefixes: Vec<String>,
}

impl SelfAskWithSearchAgentOutputParser {
    /// Builds a parser from the given prefixes (each is copied into an owned `String`).
    pub fn new(
        followup_prefix: &str,
        intermediate_answer_prefix: &str,
        acceptable_finish_prefixes: &[&str],
    ) -> Self {
        let owned_finish_prefixes: Vec<String> = acceptable_finish_prefixes
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            followup_prefix: String::from(followup_prefix),
            intermediate_answer_prefix: String::from(intermediate_answer_prefix),
            acceptable_finish_prefixes: owned_finish_prefixes,
        }
    }
}

impl Default for SelfAskWithSearchAgentOutputParser {
    /// Prefixes for the few-shot format used by the self-ask prompt.
    fn default() -> Self {
        const FINISH_PREFIXES: [&str; 3] = [
            "Final answer:",
            "So the final answer is:",
            "So the final answer could be:",
        ];
        // NOTE(review): the prompt and the agent's observation prefix spell this
        // "Intermediate answer:" (lower-case `a`). With this default, completions in the
        // prompt's own format are therefore handled by the parser's newline-terminated
        // follow-up path rather than the intermediate-answer path. Confirm this is intended.
        Self::new("Follow up:", "Intermediate Answer:", &FINISH_PREFIXES)
    }
}
152
153#[derive(Debug, Error)]
154#[error("No finish line or follow up question was returned by the model: {0}")]
155pub struct ParserError(String);
156
157impl AgentOutputParser for SelfAskWithSearchAgentOutputParser {
158 type Error = ParserError;
159 fn parse(&self, text: String) -> Result<AgentDecision, Self::Error> {
160 if let Some(followup_idx) = text.find(&self.followup_prefix) {
161 let (followup_question, log) = if let Some(intermediate_answer_idx) =
162 text.find(&self.intermediate_answer_prefix)
163 {
164 let followup_question = text
165 .chars()
166 .skip(followup_idx + self.followup_prefix.len())
167 .take(intermediate_answer_idx - (followup_idx + self.followup_prefix.len()))
168 .collect::<String>()
169 .trim()
170 .to_owned();
171
172 let log = text.chars().take(intermediate_answer_idx).collect();
173 (followup_question, log)
174 } else {
175 let followup_question = text
176 .chars()
177 .skip(followup_idx + self.followup_prefix.len())
178 .take_while(|&c| c != '\n')
179 .collect::<String>()
180 .trim()
181 .to_owned();
182
183 let log = text
184 .char_indices()
185 .map_while(|(idx, c)| {
186 if c != '\n' || idx < followup_idx {
187 Some(c)
188 } else {
189 None
190 }
191 })
192 .collect();
193 (followup_question, log)
194 };
195 Ok(AgentDecision::Action(AgentAction {
196 tool: "Intermediate Answer".into(),
197 tool_input: followup_question.into(),
198 log,
199 }))
200 } else if let Some((idx, prefix)) = self
201 .acceptable_finish_prefixes
202 .iter()
203 .find_map(|prefix| text.find(prefix).map(|idx| (idx, prefix)))
204 {
205 let final_answer = text.chars().skip(idx + prefix.len()).collect::<String>();
206 Ok(AgentDecision::Finish(AgentFinish {
207 return_values: parameters!("output" => final_answer.trim()),
208 log: text,
209 }))
210 } else {
211 Err(ParserError(text))
212 }
213 }
214}
215
/// Limits that stop an agent run which has not yet produced a final answer.
///
/// A `None` limit is not enforced; the default has no limits at all.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EarlyStoppingConfig {
    /// Iteration limit the agent checks before each step.
    pub max_iterations: Option<u32>,
    /// Wall-clock limit, in seconds, the agent checks before each step.
    pub max_time_elapsed_seconds: Option<f64>,
}
221
222pub struct Agent<E, T>
223where
224 E: Executor,
225 T: Tool,
226 T::Input: From<String>,
227 T::Output: Into<String>,
228{
229 executor: E,
230 search_tool: T,
231 early_stopping_config: EarlyStoppingConfig,
232 observation_prefix: String,
233 llm_prefix: String,
234 output_parser: SelfAskWithSearchAgentOutputParser,
235}
236
237impl<E, T> Agent<E, T>
238where
239 E: Executor,
240 T: Tool,
241 T::Input: From<String>,
242 T::Output: Into<String>,
243{
244 pub fn new(executor: E, search_tool: T, early_stopping_config: EarlyStoppingConfig) -> Self {
245 Self {
246 executor,
247 search_tool,
248 early_stopping_config,
249 observation_prefix: "Intermediate answer: ".to_string(),
250 llm_prefix: "".to_string(),
251 output_parser: SelfAskWithSearchAgentOutputParser::default(),
252 }
253 }
254
255 fn should_continue(&self, iterations_elapsed: u32, time_elapsed_seconds: f64) -> bool {
256 match (
257 self.early_stopping_config.max_iterations,
258 self.early_stopping_config.max_time_elapsed_seconds,
259 ) {
260 (None, None) => true,
261 (None, Some(max_time_elapsed_seconds)) => {
262 max_time_elapsed_seconds >= time_elapsed_seconds
263 }
264 (Some(max_iterations), None) => max_iterations >= iterations_elapsed,
265 (Some(max_iterations), Some(max_time_elapsed_seconds)) => {
266 max_iterations >= iterations_elapsed
267 && max_time_elapsed_seconds >= time_elapsed_seconds
268 }
269 }
270 }
271
272 async fn take_next_step(
276 &self,
277 intermediate_steps: &Vec<AgentIntermediateStep>,
278 query: &str,
279 ) -> Result<AgentIntermediateStepOutput, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
280 let output = self.plan(intermediate_steps, query).await?;
281
282 let decision = self.output_parser.parse(output)?;
283 match decision {
284 AgentDecision::Action(action) => {
285 let observation = self
286 .search_tool
287 .invoke_typed(
288 &action
289 .tool_input
290 .as_str()
291 .ok_or(SelfAskWithSearchAgentError::ToolInputNotString(
292 action.tool_input.clone(),
293 ))?
294 .to_string()
295 .into(),
296 )
297 .await
298 .map_err(SelfAskWithSearchAgentError::SearchToolError)?;
299
300 Ok(AgentIntermediateStepOutput::Step(AgentIntermediateStep {
301 action,
302 observation: serde_yaml::to_value(Into::<String>::into(observation))?,
303 }))
304 }
305 AgentDecision::Finish(finish) => Ok(AgentIntermediateStepOutput::Finish(finish)),
306 }
307 }
308
309 pub fn build_agent_scratchpad(
311 &self,
312 intermediate_steps: &Vec<AgentIntermediateStep>,
313 ) -> String {
314 let mut scratchpad = "".to_string();
315 for intermediate_step in intermediate_steps {
316 scratchpad += &intermediate_step.action.log;
317 scratchpad += &format!(
318 "\n{}{}\n{}",
319 self.observation_prefix,
320 intermediate_step.observation.as_str().unwrap_or_default(),
321 self.llm_prefix
322 );
323 }
324 scratchpad
325 }
326
327 async fn plan(
331 &self,
332 intermediate_steps: &Vec<AgentIntermediateStep>,
333 query: &str,
334 ) -> Result<String, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
335 let scratchpad = self.build_agent_scratchpad(intermediate_steps);
336 let template_parameters = parameters!("input" => query, "agent_scratchpad" => scratchpad);
337 let prompt = PromptTemplate::Text(PROMPT.into()).format(&template_parameters)?;
338 let plan = self
339 .executor
340 .execute(Options::empty(), &prompt)
341 .await
342 .map_err(SelfAskWithSearchAgentError::ExecutorError)?;
343 plan.to_immediate()
344 .await
345 .map_err(SelfAskWithSearchAgentError::ExecutorError)?
346 .as_content()
347 .extract_last_body()
348 .cloned()
349 .ok_or(SelfAskWithSearchAgentError::NoChoicesReturned)
350 }
351
352 pub async fn run(
353 &self,
354 query: &str,
355 ) -> Result<
356 (AgentFinish, Vec<AgentIntermediateStep>),
357 SelfAskWithSearchAgentError<<T as Tool>::Error>,
358 > {
359 let mut intermediate_steps = vec![];
360
361 let mut iterations = 0;
362 let start = Instant::now();
363 let mut full_duration = Duration::from_nanos(0);
364 while self.should_continue(iterations, full_duration.as_secs_f64()) {
365 let decision = self.take_next_step(&intermediate_steps, query).await?;
366 full_duration = start.elapsed();
367 iterations += 1;
368 match decision {
369 AgentIntermediateStepOutput::Step(step) => intermediate_steps.push(step),
370 AgentIntermediateStepOutput::Finish(finish) => {
371 return Ok((finish, intermediate_steps))
372 }
373 }
374 }
375 Err(SelfAskWithSearchAgentError::RuntimeExceeded {
376 time_elapsed_seconds: full_duration.as_secs_f64(),
377 iterations_elapsed: iterations,
378 })
379 }
380}
381
#[cfg(test)]
mod tests {

    use async_trait::async_trait;

    use thiserror::Error;

    use crate::{
        agents::self_ask_with_search::{AgentIntermediateStep, EarlyStoppingConfig},
        options::Options,
        output::Output,
        parameters,
        prompt::Prompt,
        tokens::{TokenCollection, Tokenizer},
        tools::{Tool, ToolError},
        traits::{Executor, ExecutorError},
    };

    use super::{
        Agent, AgentAction, AgentDecision, AgentFinish, AgentOutputParser,
        SelfAskWithSearchAgentOutputParser,
    };

    /// A trailing follow-up line becomes an action; the whole text is the log.
    #[test]
    fn test_parses_followup() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        Follow up: my follow up question abc?";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: "my follow up question abc?".into(),
                log: text.into()
            })
        );
    }

    /// The log stops at the end of the follow-up line, dropping trailing whitespace.
    #[test]
    fn test_parses_follow_up_trims_trailing_whitespace() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        Follow up: my follow up question abc?
        ";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: "my follow up question abc?".into(),
                log: text.trim_end().into()
            })
        );
    }

    #[test]
    fn test_parses_final_answer() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        So the final answer is: yes abc!";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "yes abc!"),
                log: text.into()
            })
        );
    }

    #[test]
    fn test_parses_final_answer_ignores_trailing_whitespace() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        So the final answer is: yes abc!
        ";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "yes abc!"),
                log: text.into()
            })
        );
    }

    /// Colons inside the answer must not be treated as another prefix.
    #[test]
    fn test_parses_final_answer_with_colons() {
        let parser = SelfAskWithSearchAgentOutputParser::default();
        let text = "
        Whatever
        Whatever
        So the final answer is: Mad Max: Fury road";
        let decision = parser.parse(text.into()).unwrap();
        assert_eq!(
            decision,
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "Mad Max: Fury road"),
                log: text.into()
            })
        );
    }

    /// `build_agent_scratchpad` interleaves action logs with prefixed observations.
    /// The executor and tool are never called, so their mocks are `todo!()` stubs.
    #[test]
    fn test_builds_agent_scratchpad() {
        #[derive(Debug, Error)]
        #[error("Mocked executor error")]
        struct MockError;

        impl ToolError for MockError {}

        impl From<serde_yaml::Error> for MockError {
            fn from(_: serde_yaml::Error) -> Self {
                Self
            }
        }

        struct MockTokenizer;

        impl Tokenizer for MockTokenizer {
            fn tokenize_str(
                &self,
                _: &str,
            ) -> Result<TokenCollection, crate::tokens::TokenizerError> {
                todo!()
            }

            fn to_string(
                &self,
                _: TokenCollection,
            ) -> Result<String, crate::tokens::TokenizerError> {
                todo!()
            }
        }

        struct MockExecutor;

        #[async_trait]
        impl Executor for MockExecutor {
            type StepTokenizer<'a> = MockTokenizer;

            fn new_with_options(_: Options) -> Result<Self, crate::traits::ExecutorCreationError> {
                todo!()
            }

            async fn execute(
                &self,
                _: &Options,
                _: &crate::prompt::Prompt,
            ) -> Result<Output, ExecutorError> {
                todo!()
            }

            fn tokens_used(
                &self,
                _: &Options,
                _: &crate::prompt::Prompt,
            ) -> Result<crate::tokens::TokenCount, crate::tokens::PromptTokensError> {
                todo!()
            }

            fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
                todo!()
            }

            fn max_tokens_allowed(&self, _: &Options) -> i32 {
                todo!()
            }

            fn get_tokenizer(
                &self,
                _: &Options,
            ) -> Result<MockTokenizer, crate::tokens::TokenizerError> {
                todo!()
            }
        }
        struct MockSearch;

        #[async_trait]
        impl Tool for MockSearch {
            type Input = String;

            type Output = String;

            type Error = MockError;

            async fn invoke_typed(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> {
                todo!()
            }

            fn description(&self) -> crate::tools::ToolDescription {
                todo!()
            }
        }
        let mock_executor = MockExecutor;
        let mock_search = MockSearch;
        let agent = Agent::new(
            mock_executor,
            mock_search,
            EarlyStoppingConfig {
                max_iterations: None,
                max_time_elapsed_seconds: None,
            },
        );
        let intermediate_steps = vec![
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Muhammad Ali when he died?".into(),
                    log: "Yes.
Follow up: How old was Muhammad Ali when he died?"
                        .into(),
                },
                observation: "Muhammad Ali was 74 years old when he died.".into(),
            },
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Alan Turing when he died?".into(),
                    log: "Follow up: How old was Alan Turing when he died?".into(),
                },
                observation: "Alan Turing was 41 years old when he died.".into(),
            },
        ];

        let expected_scratchpad = "Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.\n";

        let scratchpad = agent.build_agent_scratchpad(&intermediate_steps);

        assert_eq!(scratchpad, expected_scratchpad);
    }
}