//! Builder for assembling an [`OpenAiToolAgent`] from tools, a prompt prefix, and chain options.
use std::sync::Arc;

use crate::{
    agent::AgentError,
    chain::{options::ChainCallOptions, LLMChainBuilder},
    language_models::{llm::LLM, options::CallOptions},
    schemas::FunctionDefinition,
    tools::Tool,
};

use super::{prompt::PREFIX, OpenAiToolAgent};

/// Builder for [`OpenAiToolAgent`].
///
/// Every setting is optional. Anything left unset is replaced by a fallback
/// when [`OpenAiToolAgentBuilder::build`] runs: an empty tool list, the
/// module's `PREFIX` prompt prefix, and default chain options capped at
/// 1000 max tokens.
#[derive(Default)]
pub struct OpenAiToolAgentBuilder {
    /// Tools handed to the agent; `None` means "no tools".
    tools: Option<Vec<Arc<dyn Tool>>>,
    /// Prompt prefix override; `None` means "use `PREFIX`".
    prefix: Option<String>,
    /// Chain call options override; `None` means "use the built-in default".
    options: Option<ChainCallOptions>,
}

impl OpenAiToolAgentBuilder {
    /// Creates a builder with no tools, prefix, or chain options set.
    pub fn new() -> Self {
        Self {
            tools: None,
            prefix: None,
            options: None,
        }
    }

    /// Sets the tools the agent will hold and advertise to the model.
    ///
    /// The slice is copied into the builder, cloning each `Arc` handle, and
    /// replaces any tools set by an earlier call.
    pub fn tools(mut self, tools: &[Arc<dyn Tool>]) -> Self {
        self.tools = Some(tools.to_vec());
        self
    }

    /// Overrides the prompt prefix used in place of the default `PREFIX`.
    pub fn prefix<S: Into<String>>(mut self, prefix: S) -> Self {
        self.prefix = Some(prefix.into());
        self
    }

    /// Overrides the options passed to the underlying LLM chain.
    pub fn options(mut self, options: ChainCallOptions) -> Self {
        self.options = Some(options);
        self
    }

    /// Consumes the builder and assembles an [`OpenAiToolAgent`] around `llm`.
    ///
    /// The configured tools are converted to function definitions and added
    /// to `llm` as call options before the chain is built.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `OpenAiToolAgent::create_prompt` or
    /// by building the LLM chain.
    pub fn build<L: LLM + 'static>(self, mut llm: L) -> Result<OpenAiToolAgent, AgentError> {
        let tools = self.tools.unwrap_or_default();
        let prefix = self.prefix.unwrap_or_else(|| PREFIX.to_string());

        let prompt = OpenAiToolAgent::create_prompt(&prefix)?;

        // Describe every tool to the model so it can request tool calls.
        let functions = tools
            .iter()
            .map(|tool| FunctionDefinition::from_langchain_tool(tool))
            .collect::<Vec<FunctionDefinition>>();
        llm.add_options(CallOptions::new().with_functions(functions));

        // Only construct the fallback options when the caller supplied none.
        let options = self
            .options
            .unwrap_or_else(|| ChainCallOptions::default().with_max_tokens(1000));
        let chain = Box::new(
            LLMChainBuilder::new()
                .prompt(prompt)
                .llm(llm)
                .options(options)
                .build()?,
        );

        Ok(OpenAiToolAgent { chain, tools })
    }
}