1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
use std::{error::Error, sync::Arc};

use crate::{
    agent::agent::AgentOutputParser, chain::llm_chain::LLMChainBuilder, language_models::llm::LLM,
    tools::Tool,
};

use super::{
    prompt::{PREFIX, SUFFIX},
    ConversationalAgent,
};

/// Builder for [`ConversationalAgent`].
///
/// Every setting starts unset. [`build`](Self::build) fills in defaults for the
/// tool list, prompt prefix and prompt suffix. It fails if no output parser was
/// supplied.
#[derive(Default)]
pub struct ConversationalAgentBuilder {
    /// Tools exposed to the agent; treated as an empty list when unset.
    tools: Option<Vec<Arc<dyn Tool>>>,
    /// Parser applied to the LLM's output; `build` errors when this is unset.
    output_parser: Option<Box<dyn AgentOutputParser>>,
    /// Prompt prefix; falls back to the `PREFIX` constant when unset.
    prefix: Option<String>,
    /// Prompt suffix; falls back to the `SUFFIX` constant when unset.
    suffix: Option<String>,
}

impl ConversationalAgentBuilder {
    /// Creates a builder with every setting unset (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the tools the agent is allowed to call, replacing any previous value.
    pub fn tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
        self.tools = Some(tools);
        self
    }

    /// Sets the parser used to interpret the LLM's output. This is required by `build`.
    pub fn output_parser(mut self, output_parser: Box<dyn AgentOutputParser>) -> Self {
        self.output_parser = Some(output_parser);
        self
    }

    /// Overrides the prompt prefix (defaults to `PREFIX`).
    pub fn prefix(mut self, prefix: String) -> Self {
        self.prefix = Some(prefix);
        self
    }

    /// Overrides the prompt suffix (defaults to `SUFFIX`).
    pub fn suffix(mut self, suffix: String) -> Self {
        self.suffix = Some(suffix);
        self
    }

    /// Consumes the builder and constructs a [`ConversationalAgent`].
    ///
    /// A prompt is built from the tools, prefix and suffix. `llm` is then wrapped
    /// in an LLM chain that uses that prompt.
    ///
    /// # Errors
    ///
    /// Returns an error if no output parser was set. Errors from prompt creation
    /// and from chain construction are propagated.
    pub fn build<L: LLM + 'static>(self, llm: L) -> Result<ConversationalAgent, Box<dyn Error>> {
        let tools = self.tools.unwrap_or_default();
        let output_parser = self.output_parser.ok_or("Output parser must be set")?;
        let prefix = self.prefix.unwrap_or_else(|| PREFIX.to_string());
        let suffix = self.suffix.unwrap_or_else(|| SUFFIX.to_string());

        // NOTE(review): suffix is passed before prefix here. Keep this order
        // unless `create_prompt`'s signature says otherwise; confirm it there.
        let prompt = ConversationalAgent::create_prompt(&tools, &suffix, &prefix)?;
        let chain = Box::new(LLMChainBuilder::new().prompt(prompt).llm(llm).build()?);

        Ok(ConversationalAgent {
            chain,
            tools,
            output_parser,
        })
    }
}