1use crate::{
12 options::Options,
13 parameters,
14 prompt::{PromptTemplate, StringTemplateError},
15 tools::{Tool, ToolError},
16 traits::{Executor, ExecutorError},
17 Parameters,
18};
19use std::time::{Duration, Instant};
20use thiserror::Error;
21
/// Few-shot prompt template used by the self-ask-with-search agent.
///
/// The template holds four worked examples written in the
/// `Question` / `Follow up` / `Intermediate answer` / `So the final answer is`
/// format, followed by two placeholders:
///
/// - `{{input}}` receives the question being asked;
/// - `{{agent_scratchpad}}` receives the transcript of the steps taken so far.
///
/// NOTE(review): the examples spell the lead-in as "Are follow up questions
/// needed here:" while the last line spells it "Are followup questions needed
/// here:" — confirm whether the difference is intentional.
const PROMPT: &str = "Question: Who lived longer, Muhammad Ali or Alan Turing?
Are follow up questions needed here: Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.
So the final answer is: Muhammad Ali

Question: When was the founder of craigslist born?
Are follow up questions needed here: Yes.
Follow up: Who was the founder of craigslist?
Intermediate answer: Craigslist was founded by Craig Newmark.
Follow up: When was Craig Newmark born?
Intermediate answer: Craig Newmark was born on December 6, 1952.
So the final answer is: December 6, 1952

Question: Who was the maternal grandfather of George Washington?
Are follow up questions needed here: Yes.
Follow up: Who was the mother of George Washington?
Intermediate answer: The mother of George Washington was Mary Ball Washington.
Follow up: Who was the father of Mary Ball Washington?
Intermediate answer: The father of Mary Ball Washington was Joseph Ball.
So the final answer is: Joseph Ball

Question: Are both the directors of Jaws and Casino Royale from the same country?
Are follow up questions needed here: Yes.
Follow up: Who is the director of Jaws?
Intermediate answer: The director of Jaws is Steven Spielberg.
Follow up: Where is Steven Spielberg from?
Intermediate answer: The United States.
Follow up: Who is the director of Casino Royale?
Intermediate answer: The director of Casino Royale is Martin Campbell.
Follow up: Where is Martin Campbell from?
Intermediate answer: New Zealand.
So the final answer is: No

Question: {{input}}
Are followup questions needed here:{{agent_scratchpad}}";
62
/// A tool invocation chosen by the agent.
#[derive(Debug, PartialEq, Eq)]
pub struct AgentAction {
    /// Name of the tool to invoke.
    pub tool: String,
    /// Input handed to the tool, stored as a YAML value.
    pub tool_input: serde_yaml::Value,
    /// Model output text associated with this action; `Agent::build_agent_scratchpad`
    /// appends it to the scratchpad.
    pub log: String,
}
81
/// The agent's final result.
#[derive(Debug, PartialEq)]
pub struct AgentFinish {
    /// Values returned by the agent; the parser in this file stores the final
    /// answer under the `"output"` key.
    pub return_values: Parameters,

    /// Model output text from which the final answer was parsed.
    pub log: String,
}
93
/// One completed step of the agent loop: the action taken and what the tool
/// returned for it.
#[derive(Debug)]
pub struct AgentIntermediateStep {
    /// The action that was executed.
    pub action: AgentAction,
    /// The tool's output for `action`, stored as a YAML value.
    pub observation: serde_yaml::Value,
}
99
100pub enum AgentIntermediateStepOutput {
101 Step(AgentIntermediateStep),
102 Finish(AgentFinish),
103}
104
/// What the agent decided to do after reading the model output.
#[derive(Debug, PartialEq)]
pub enum AgentDecision {
    /// Invoke a tool.
    Action(AgentAction),
    /// Stop and return the final result.
    Finish(AgentFinish),
}
/// Turns raw model output into an [`AgentDecision`].
pub trait AgentOutputParser {
    /// Error returned when the text cannot be interpreted.
    type Error;
    /// Parses `text` into either an action to take or a final result.
    fn parse(&self, text: String) -> Result<AgentDecision, Self::Error>;
}
114
/// Errors produced while running the self-ask-with-search agent.
///
/// `T` is the error type of the search tool.
#[derive(Debug, Error)]
pub enum SelfAskWithSearchAgentError<T>
where
    T: std::fmt::Debug + std::error::Error + ToolError,
{
    /// The parsed tool input was a YAML value other than a string.
    #[error("Search tool input yaml was not of type string: {0:?}")]
    ToolInputNotString(serde_yaml::Value),
    /// The search tool returned an error.
    #[error(transparent)]
    SearchToolError(T),
    /// The executor failed while producing the model output.
    #[error(transparent)]
    ExecutorError(ExecutorError),
    /// The model output contained neither a follow-up question nor a final answer.
    #[error(transparent)]
    ParserError(#[from] ParserError),
    /// A value could not be converted to YAML.
    #[error(transparent)]
    YamlError(#[from] serde_yaml::Error),
    /// The prompt template could not be formatted.
    #[error(transparent)]
    StringTemplateError(#[from] StringTemplateError),
    /// The executor output contained no body to read.
    #[error("Model response was empty or contained no choices")]
    NoChoicesReturned,
    /// The agent loop stopped because a configured limit was reached before
    /// a final answer was produced.
    #[error("Max number of iterations or timeout exceeded. Elapsed: {time_elapsed_seconds}s, {iterations_elapsed} iterations")]
    RuntimeExceeded {
        /// Seconds elapsed when the loop stopped.
        time_elapsed_seconds: f64,
        /// Iterations completed when the loop stopped.
        iterations_elapsed: u32,
    },
}
140
/// Parser for model output written in the self-ask format.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so a configured parser can
/// be logged, duplicated and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SelfAskWithSearchAgentOutputParser {
    /// Marker that introduces a follow-up question.
    followup_prefix: String,
    /// Marker that introduces an intermediate answer; text from this marker
    /// onwards is not part of the follow-up question.
    intermediate_answer_prefix: String,
    /// Markers that introduce the final answer, tried in order.
    acceptable_finish_prefixes: Vec<String>,
}
146
147impl SelfAskWithSearchAgentOutputParser {
148 pub fn new(
149 followup_prefix: &str,
150 intermediate_answer_prefix: &str,
151 acceptable_finish_prefixes: &[&str],
152 ) -> Self {
153 Self {
154 followup_prefix: followup_prefix.into(),
155 intermediate_answer_prefix: intermediate_answer_prefix.into(),
156 acceptable_finish_prefixes: acceptable_finish_prefixes
157 .iter()
158 .map(|s| s.to_string())
159 .collect(),
160 }
161 }
162}
163
164impl Default for SelfAskWithSearchAgentOutputParser {
165 fn default() -> Self {
166 Self::new(
167 "Follow up:",
168 "Intermediate Answer:",
169 &[
170 "Final answer:",
171 "So the final answer is:",
172 "So the final answer could be:",
173 ],
174 )
175 }
176}
177
/// Returned when the model output contains neither a follow-up question nor
/// an accepted final-answer marker. Carries the text that failed to parse.
#[derive(Debug, Error)]
#[error("No finish line or follow up question was returned by the model: {0}")]
pub struct ParserError(String);
181
182impl AgentOutputParser for SelfAskWithSearchAgentOutputParser {
183 type Error = ParserError;
184 fn parse(&self, text: String) -> Result<AgentDecision, Self::Error> {
185 if let Some(followup_idx) = text.find(&self.followup_prefix) {
187 let (followup_question, log) = if let Some(intermediate_answer_idx) =
189 text.find(&self.intermediate_answer_prefix)
190 {
191 let followup_question = text
192 .chars()
193 .skip(followup_idx + self.followup_prefix.len())
194 .take(intermediate_answer_idx - (followup_idx + self.followup_prefix.len()))
195 .collect::<String>()
196 .trim()
197 .to_owned();
198
199 let log = text.chars().take(intermediate_answer_idx).collect();
200 (followup_question, log)
201 } else {
202 let followup_question = text
204 .chars()
205 .skip(followup_idx + self.followup_prefix.len())
206 .take_while(|&c| c != '\n')
207 .collect::<String>()
208 .trim()
209 .to_owned();
210
211 let log = text
212 .char_indices()
213 .map_while(|(idx, c)| {
214 if c != '\n' || idx < followup_idx {
215 Some(c)
216 } else {
217 None
218 }
219 })
220 .collect();
221 (followup_question, log)
222 };
223 Ok(AgentDecision::Action(AgentAction {
224 tool: "Intermediate Answer".into(),
225 tool_input: followup_question.into(),
226 log,
227 }))
228 } else if let Some((idx, prefix)) = self
229 .acceptable_finish_prefixes
230 .iter()
231 .find_map(|prefix| text.find(prefix).map(|idx| (idx, prefix)))
232 {
233 let final_answer = text.chars().skip(idx + prefix.len()).collect::<String>();
234 Ok(AgentDecision::Finish(AgentFinish {
235 return_values: parameters!("output" => final_answer.trim()),
236 log: text,
237 }))
238 } else {
239 Err(ParserError(text))
240 }
241 }
242}
243
/// Limits that make the agent loop stop before a final answer is produced.
///
/// A `None` field disables the corresponding limit; the default value has
/// both limits disabled.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct EarlyStoppingConfig {
    /// Limit on the number of agent iterations.
    pub max_iterations: Option<u32>,
    /// Limit on the total run time, in seconds.
    pub max_time_elapsed_seconds: Option<f64>,
}
249
/// Self-ask-with-search agent.
///
/// `E` is the executor that runs the prompt and `T` is the search tool that
/// is invoked with each follow-up question.
pub struct Agent<E, T>
where
    E: Executor,
    T: Tool,
    T::Input: From<String>,
    T::Output: Into<String>,
{
    /// Executor used to run the prompt.
    executor: E,
    /// Tool invoked with each follow-up question.
    search_tool: T,
    /// Limits checked before each iteration of the agent loop.
    early_stopping_config: EarlyStoppingConfig,
    /// Text written before each observation in the scratchpad.
    observation_prefix: String,
    /// Text written after each observation line in the scratchpad.
    llm_prefix: String,
    /// Parser applied to the model output.
    output_parser: SelfAskWithSearchAgentOutputParser,
}
264
265impl<E, T> Agent<E, T>
266where
267 E: Executor,
268 T: Tool,
269 T::Input: From<String>,
270 T::Output: Into<String>,
271{
272 pub fn new(executor: E, search_tool: T, early_stopping_config: EarlyStoppingConfig) -> Self {
273 Self {
274 executor,
275 search_tool,
276 early_stopping_config,
277 observation_prefix: "Intermediate answer: ".to_string(),
278 llm_prefix: "".to_string(),
279 output_parser: SelfAskWithSearchAgentOutputParser::default(),
280 }
281 }
282
283 fn should_continue(&self, iterations_elapsed: u32, time_elapsed_seconds: f64) -> bool {
284 match (
285 self.early_stopping_config.max_iterations,
286 self.early_stopping_config.max_time_elapsed_seconds,
287 ) {
288 (None, None) => true,
289 (None, Some(max_time_elapsed_seconds)) => {
290 max_time_elapsed_seconds >= time_elapsed_seconds
291 }
292 (Some(max_iterations), None) => max_iterations >= iterations_elapsed,
293 (Some(max_iterations), Some(max_time_elapsed_seconds)) => {
294 max_iterations >= iterations_elapsed
295 && max_time_elapsed_seconds >= time_elapsed_seconds
296 }
297 }
298 }
299
300 async fn take_next_step(
304 &self,
305 intermediate_steps: &Vec<AgentIntermediateStep>,
306 query: &str,
307 ) -> Result<AgentIntermediateStepOutput, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
308 let output = self.plan(intermediate_steps, query).await?;
309
310 let decision = self.output_parser.parse(output)?;
311 match decision {
312 AgentDecision::Action(action) => {
313 let observation = self
314 .search_tool
315 .invoke_typed(
316 &action
317 .tool_input
318 .as_str()
319 .ok_or(SelfAskWithSearchAgentError::ToolInputNotString(
320 action.tool_input.clone(),
321 ))?
322 .to_string()
323 .into(),
324 )
325 .await
326 .map_err(SelfAskWithSearchAgentError::SearchToolError)?;
327
328 Ok(AgentIntermediateStepOutput::Step(AgentIntermediateStep {
329 action,
330 observation: serde_yaml::to_value(Into::<String>::into(observation))?,
331 }))
332 }
333 AgentDecision::Finish(finish) => Ok(AgentIntermediateStepOutput::Finish(finish)),
334 }
335 }
336
337 pub fn build_agent_scratchpad(
339 &self,
340 intermediate_steps: &Vec<AgentIntermediateStep>,
341 ) -> String {
342 let mut scratchpad = "".to_string();
343 for intermediate_step in intermediate_steps {
344 scratchpad += &intermediate_step.action.log;
345 scratchpad += &format!(
346 "\n{}{}\n{}",
347 self.observation_prefix,
348 intermediate_step.observation.as_str().unwrap_or_default(),
349 self.llm_prefix
350 );
351 }
352 scratchpad
353 }
354
355 async fn plan(
359 &self,
360 intermediate_steps: &Vec<AgentIntermediateStep>,
361 query: &str,
362 ) -> Result<String, SelfAskWithSearchAgentError<<T as Tool>::Error>> {
363 let scratchpad = self.build_agent_scratchpad(intermediate_steps);
364 let template_parameters = parameters!("input" => query, "agent_scratchpad" => scratchpad);
365 let prompt = PromptTemplate::Text(PROMPT.into()).format(&template_parameters)?;
366 let plan = self
367 .executor
368 .execute(Options::empty(), &prompt)
369 .await
370 .map_err(SelfAskWithSearchAgentError::ExecutorError)?;
371 plan.to_immediate()
372 .await
373 .map_err(SelfAskWithSearchAgentError::ExecutorError)?
374 .as_content()
375 .extract_last_body()
376 .cloned()
377 .ok_or(SelfAskWithSearchAgentError::NoChoicesReturned)
378 }
379
380 pub async fn run(
381 &self,
382 query: &str,
383 ) -> Result<
384 (AgentFinish, Vec<AgentIntermediateStep>),
385 SelfAskWithSearchAgentError<<T as Tool>::Error>,
386 > {
387 let mut intermediate_steps = vec![];
388
389 let mut iterations = 0;
390 let start = Instant::now();
391 let mut full_duration = Duration::from_nanos(0);
392 while self.should_continue(iterations, full_duration.as_secs_f64()) {
393 let decision = self.take_next_step(&intermediate_steps, query).await?;
394 full_duration = start.elapsed();
395 iterations += 1;
396 match decision {
397 AgentIntermediateStepOutput::Step(step) => intermediate_steps.push(step),
398 AgentIntermediateStepOutput::Finish(finish) => {
399 return Ok((finish, intermediate_steps))
400 }
401 }
402 }
403 Err(SelfAskWithSearchAgentError::RuntimeExceeded {
404 time_elapsed_seconds: full_duration.as_secs_f64(),
405 iterations_elapsed: iterations,
406 })
407 }
408}
409
#[cfg(test)]
mod tests {
    //! Unit tests for the output parser and the scratchpad builder.

    use async_trait::async_trait;

    use thiserror::Error;

    use crate::{
        agents::self_ask_with_search::{AgentIntermediateStep, EarlyStoppingConfig},
        options::Options,
        output::Output,
        parameters,
        prompt::Prompt,
        tokens::{TokenCollection, Tokenizer},
        tools::{Tool, ToolError},
        traits::{Executor, ExecutorError},
    };

    use super::{
        Agent, AgentAction, AgentDecision, AgentFinish, AgentOutputParser,
        SelfAskWithSearchAgentOutputParser,
    };

    /// Parses `text` with the default parser, panicking when parsing fails.
    fn parse_with_default_parser(text: &str) -> AgentDecision {
        SelfAskWithSearchAgentOutputParser::default()
            .parse(text.into())
            .unwrap()
    }

    #[test]
    fn test_parses_followup() {
        let text = "
 Whatever
 Whatever
 Follow up: my follow up question abc?";
        assert_eq!(
            parse_with_default_parser(text),
            AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: "my follow up question abc?".into(),
                log: text.into()
            })
        );
    }

    #[test]
    fn test_parses_follow_up_trims_trailing_whitespace() {
        let text = "
 Whatever
 Whatever
 Follow up: my follow up question abc?
 ";
        assert_eq!(
            parse_with_default_parser(text),
            AgentDecision::Action(AgentAction {
                tool: "Intermediate Answer".into(),
                tool_input: "my follow up question abc?".into(),
                log: text.trim_end().into()
            })
        );
    }

    #[test]
    fn test_parses_final_answer() {
        let text = "
 Whatever
 Whatever
 So the final answer is: yes abc!";
        assert_eq!(
            parse_with_default_parser(text),
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "yes abc!"),
                log: text.into()
            })
        );
    }

    #[test]
    fn test_parses_final_answer_ignores_trailing_whitespace() {
        let text = "
 Whatever
 Whatever
 So the final answer is: yes abc!
 ";
        assert_eq!(
            parse_with_default_parser(text),
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "yes abc!"),
                log: text.into()
            })
        );
    }

    #[test]
    fn test_parses_final_answer_with_colons() {
        let text = "
 Whatever
 Whatever
 So the final answer is: Mad Max: Fury road";
        assert_eq!(
            parse_with_default_parser(text),
            AgentDecision::Finish(AgentFinish {
                return_values: parameters!("output" => "Mad Max: Fury road"),
                log: text.into()
            })
        );
    }

    #[test]
    fn test_builds_agent_scratchpad() {
        // The mocks below only need to satisfy the trait bounds of `Agent`;
        // building the scratchpad never calls the executor or the tool, so
        // every method body is left as `todo!()`.

        #[derive(Debug, Error)]
        #[error("Mocked executor error")]
        struct MockError;

        impl ToolError for MockError {}

        impl From<serde_yaml::Error> for MockError {
            fn from(_: serde_yaml::Error) -> Self {
                Self
            }
        }

        struct MockTokenizer;

        impl Tokenizer for MockTokenizer {
            fn tokenize_str(
                &self,
                _: &str,
            ) -> Result<TokenCollection, crate::tokens::TokenizerError> {
                todo!()
            }

            fn to_string(
                &self,
                _: TokenCollection,
            ) -> Result<String, crate::tokens::TokenizerError> {
                todo!()
            }
        }

        struct MockExecutor;

        #[async_trait]
        impl Executor for MockExecutor {
            type StepTokenizer<'a> = MockTokenizer;

            fn new_with_options(_: Options) -> Result<Self, crate::traits::ExecutorCreationError> {
                todo!()
            }

            async fn execute(&self, _: &Options, _: &Prompt) -> Result<Output, ExecutorError> {
                todo!()
            }

            fn tokens_used(
                &self,
                _: &Options,
                _: &Prompt,
            ) -> Result<crate::tokens::TokenCount, crate::tokens::PromptTokensError> {
                todo!()
            }

            fn answer_prefix(&self, _prompt: &Prompt) -> Option<String> {
                todo!()
            }

            fn max_tokens_allowed(&self, _: &Options) -> i32 {
                todo!()
            }

            fn get_tokenizer(
                &self,
                _: &Options,
            ) -> Result<MockTokenizer, crate::tokens::TokenizerError> {
                todo!()
            }
        }

        struct MockSearch;

        #[async_trait]
        impl Tool for MockSearch {
            type Input = String;

            type Output = String;

            type Error = MockError;

            async fn invoke_typed(&self, _: &Self::Input) -> Result<Self::Output, Self::Error> {
                todo!()
            }

            fn description(&self) -> crate::tools::ToolDescription {
                todo!()
            }
        }

        let agent = Agent::new(
            MockExecutor,
            MockSearch,
            EarlyStoppingConfig {
                max_iterations: None,
                max_time_elapsed_seconds: None,
            },
        );
        let intermediate_steps = vec![
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Muhammad Ali when he died?".into(),
                    log: "Yes.
Follow up: How old was Muhammad Ali when he died?"
                        .into(),
                },
                observation: "Muhammad Ali was 74 years old when he died.".into(),
            },
            AgentIntermediateStep {
                action: AgentAction {
                    tool: "Intermediate Answer".into(),
                    tool_input: "How old was Alan Turing when he died?".into(),
                    log: "Follow up: How old was Alan Turing when he died?".into(),
                },
                observation: "Alan Turing was 41 years old when he died.".into(),
            },
        ];

        let expected_scratchpad = "Yes.
Follow up: How old was Muhammad Ali when he died?
Intermediate answer: Muhammad Ali was 74 years old when he died.
Follow up: How old was Alan Turing when he died?
Intermediate answer: Alan Turing was 41 years old when he died.\n";

        let scratchpad = agent.build_agent_scratchpad(&intermediate_steps);

        assert_eq!(scratchpad, expected_scratchpad);
    }
}