langchain_rust/agent/open_ai_tools/agent.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use crate::{
7    agent::{Agent, AgentError},
8    chain::Chain,
9    fmt_message, fmt_placeholder, fmt_template, message_formatter,
10    prompt::{HumanMessagePromptTemplate, MessageFormatterStruct, PromptArgs},
11    schemas::{
12        agent::{AgentAction, AgentEvent, AgentFinish, LogTools},
13        messages::Message,
14        FunctionCallResponse,
15    },
16    template_jinja2,
17    tools::Tool,
18};
19
20pub struct OpenAiToolAgent {
21    pub(crate) chain: Box<dyn Chain>,
22    pub(crate) tools: Vec<Arc<dyn Tool>>,
23}
24
25impl OpenAiToolAgent {
26    pub fn create_prompt(prefix: &str) -> Result<MessageFormatterStruct, AgentError> {
27        let prompt = message_formatter![
28            fmt_message!(Message::new_system_message(prefix)),
29            fmt_placeholder!("chat_history"),
30            fmt_template!(HumanMessagePromptTemplate::new(template_jinja2!(
31                "{{input}}",
32                "input"
33            ))),
34            fmt_placeholder!("agent_scratchpad")
35        ];
36
37        Ok(prompt)
38    }
39
40    fn construct_scratchpad(
41        &self,
42        intermediate_steps: &[(AgentAction, String)],
43    ) -> Result<Vec<Message>, AgentError> {
44        let mut thoughts: Vec<Message> = Vec::new();
45
46        for (action, observation) in intermediate_steps {
47            // Deserialize directly and embed in method calls to streamline code.
48            // Extract the tool ID and tool calls from the log.
49            let LogTools { tool_id, tools } = serde_json::from_str(&action.log)?;
50            let tools: Vec<FunctionCallResponse> = serde_json::from_str(&tools)?;
51
52            // For the first action, add an AI message with all tools called in this session.
53            if thoughts.is_empty() {
54                thoughts.push(Message::new_ai_message("").with_tool_calls(json!(tools)));
55            }
56
57            // Add a tool message for each observation. Observation is the ouput of the tool call.
58            // tool_id is the id of the tool.
59            thoughts.push(Message::new_tool_message(observation, tool_id));
60        }
61
62        Ok(thoughts)
63    }
64}
65
66#[async_trait]
67impl Agent for OpenAiToolAgent {
68    async fn plan(
69        &self,
70        intermediate_steps: &[(AgentAction, String)],
71        inputs: PromptArgs,
72    ) -> Result<AgentEvent, AgentError> {
73        let mut inputs = inputs.clone();
74        let scratchpad = self.construct_scratchpad(intermediate_steps)?;
75        inputs.insert("agent_scratchpad".to_string(), json!(scratchpad));
76        let output = self.chain.call(inputs).await?.generation;
77        match serde_json::from_str::<Vec<FunctionCallResponse>>(&output) {
78            Ok(tools) => {
79                let mut actions: Vec<AgentAction> = Vec::new();
80                for tool in tools {
81                    //Log tools will be send as log
82                    let log: LogTools = LogTools {
83                        tool_id: tool.id.clone(),
84                        tools: output.clone(), //We send the complete tools ouput, we will need it in
85                                               //the open ai call
86                    };
87                    actions.push(AgentAction {
88                        tool: tool.function.name.clone(),
89                        tool_input: tool.function.arguments.clone(),
90                        log: serde_json::to_string(&log)?, //We send this as string to minimise changes
91                    });
92                }
93                return Ok(AgentEvent::Action(actions));
94            }
95            Err(_) => return Ok(AgentEvent::Finish(AgentFinish { output })),
96        }
97    }
98
99    fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
100        self.tools.clone()
101    }
102}