langchain_rust/agent/open_ai_tools/
agent.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use crate::{
7 agent::{Agent, AgentError},
8 chain::Chain,
9 fmt_message, fmt_placeholder, fmt_template, message_formatter,
10 prompt::{HumanMessagePromptTemplate, MessageFormatterStruct, PromptArgs},
11 schemas::{
12 agent::{AgentAction, AgentEvent, AgentFinish, LogTools},
13 messages::Message,
14 FunctionCallResponse,
15 },
16 template_jinja2,
17 tools::Tool,
18};
19
20pub struct OpenAiToolAgent {
21 pub(crate) chain: Box<dyn Chain>,
22 pub(crate) tools: Vec<Arc<dyn Tool>>,
23}
24
25impl OpenAiToolAgent {
26 pub fn create_prompt(prefix: &str) -> Result<MessageFormatterStruct, AgentError> {
27 let prompt = message_formatter![
28 fmt_message!(Message::new_system_message(prefix)),
29 fmt_placeholder!("chat_history"),
30 fmt_template!(HumanMessagePromptTemplate::new(template_jinja2!(
31 "{{input}}",
32 "input"
33 ))),
34 fmt_placeholder!("agent_scratchpad")
35 ];
36
37 Ok(prompt)
38 }
39
40 fn construct_scratchpad(
41 &self,
42 intermediate_steps: &[(AgentAction, String)],
43 ) -> Result<Vec<Message>, AgentError> {
44 let mut thoughts: Vec<Message> = Vec::new();
45
46 for (action, observation) in intermediate_steps {
47 let LogTools { tool_id, tools } = serde_json::from_str(&action.log)?;
50 let tools: Vec<FunctionCallResponse> = serde_json::from_str(&tools)?;
51
52 if thoughts.is_empty() {
54 thoughts.push(Message::new_ai_message("").with_tool_calls(json!(tools)));
55 }
56
57 thoughts.push(Message::new_tool_message(observation, tool_id));
60 }
61
62 Ok(thoughts)
63 }
64}
65
66#[async_trait]
67impl Agent for OpenAiToolAgent {
68 async fn plan(
69 &self,
70 intermediate_steps: &[(AgentAction, String)],
71 inputs: PromptArgs,
72 ) -> Result<AgentEvent, AgentError> {
73 let mut inputs = inputs.clone();
74 let scratchpad = self.construct_scratchpad(intermediate_steps)?;
75 inputs.insert("agent_scratchpad".to_string(), json!(scratchpad));
76 let output = self.chain.call(inputs).await?.generation;
77 match serde_json::from_str::<Vec<FunctionCallResponse>>(&output) {
78 Ok(tools) => {
79 let mut actions: Vec<AgentAction> = Vec::new();
80 for tool in tools {
81 let log: LogTools = LogTools {
83 tool_id: tool.id.clone(),
84 tools: output.clone(), };
87 actions.push(AgentAction {
88 tool: tool.function.name.clone(),
89 tool_input: tool.function.arguments.clone(),
90 log: serde_json::to_string(&log)?, });
92 }
93 return Ok(AgentEvent::Action(actions));
94 }
95 Err(_) => return Ok(AgentEvent::Finish(AgentFinish { output })),
96 }
97 }
98
99 fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
100 self.tools.clone()
101 }
102}