langchain_rust/agent/chat/
chat_agent.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::json;
5
6use crate::{
7 agent::{agent::Agent, chat::prompt::FORMAT_INSTRUCTIONS, AgentError},
8 chain::chain_trait::Chain,
9 message_formatter,
10 prompt::{
11 HumanMessagePromptTemplate, MessageFormatterStruct, MessageOrTemplate, PromptArgs,
12 PromptFromatter,
13 },
14 prompt_args,
15 schemas::{
16 agent::{AgentAction, AgentEvent},
17 messages::Message,
18 },
19 template_jinja2,
20 tools::Tool,
21};
22
23use super::{output_parser::ChatOutputParser, prompt::TEMPLATE_TOOL_RESPONSE};
24
25pub struct ConversationalAgent {
26 pub(crate) chain: Box<dyn Chain>,
27 pub(crate) tools: Vec<Arc<dyn Tool>>,
28 pub(crate) output_parser: ChatOutputParser,
29}
30
31impl ConversationalAgent {
32 pub fn create_prompt(
33 tools: &[Arc<dyn Tool>],
34 suffix: &str,
35 prefix: &str,
36 ) -> Result<MessageFormatterStruct, AgentError> {
37 let tool_string = tools
38 .iter()
39 .map(|tool| format!("> {}: {}", tool.name(), tool.description()))
40 .collect::<Vec<_>>()
41 .join("\n");
42 let tool_names = tools
43 .iter()
44 .map(|tool| tool.name())
45 .collect::<Vec<_>>()
46 .join(", ");
47
48 let sufix_prompt = template_jinja2!(suffix, "tools", "format_instructions");
49
50 let input_variables_fstring = prompt_args! {
51 "tools" => tool_string,
52 "format_instructions" => FORMAT_INSTRUCTIONS,
53 "tool_names"=>tool_names
54 };
55
56 let sufix_prompt = sufix_prompt.format(input_variables_fstring)?;
57 let formatter = message_formatter![
58 MessageOrTemplate::Message(Message::new_system_message(prefix)),
59 MessageOrTemplate::MessagesPlaceholder("chat_history".to_string()),
60 MessageOrTemplate::Template(
61 HumanMessagePromptTemplate::new(template_jinja2!(
62 &sufix_prompt.to_string(),
63 "input"
64 ))
65 .into()
66 ),
67 MessageOrTemplate::MessagesPlaceholder("agent_scratchpad".to_string()),
68 ];
69 Ok(formatter)
70 }
71
72 fn construct_scratchpad(
73 &self,
74 intermediate_steps: &[(AgentAction, String)],
75 ) -> Result<Vec<Message>, AgentError> {
76 let mut thoughts: Vec<Message> = Vec::new();
77 for (action, observation) in intermediate_steps.iter() {
78 thoughts.push(Message::new_ai_message(&action.log));
79 let tool_response = template_jinja2!(TEMPLATE_TOOL_RESPONSE, "observation")
80 .format(prompt_args!("observation"=>observation))?;
81 thoughts.push(Message::new_human_message(&tool_response));
82 }
83 Ok(thoughts)
84 }
85}
86
87#[async_trait]
88impl Agent for ConversationalAgent {
89 async fn plan(
90 &self,
91 intermediate_steps: &[(AgentAction, String)],
92 inputs: PromptArgs,
93 ) -> Result<AgentEvent, AgentError> {
94 let scratchpad = self.construct_scratchpad(intermediate_steps)?;
95 let mut inputs = inputs.clone();
96 inputs.insert("agent_scratchpad".to_string(), json!(scratchpad));
97 let output = self.chain.call(inputs.clone()).await?.generation;
98 let parsed_output = self.output_parser.parse(&output)?;
99 Ok(parsed_output)
100 }
101
102 fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
103 self.tools.clone()
104 }
105}
106
#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use async_trait::async_trait;
    use serde_json::Value;

    use crate::{
        agent::{chat::builder::ConversationalAgentBuilder, executor::AgentExecutor},
        chain::chain_trait::Chain,
        llm::openai::{OpenAI, OpenAIModel},
        memory::SimpleMemory,
        prompt_args,
        tools::Tool,
    };

    /// Stub calculator tool: ignores its input and always answers "25".
    struct Calc {}

    #[async_trait]
    impl Tool for Calc {
        fn name(&self) -> String {
            "Calculator".to_string()
        }
        fn description(&self) -> String {
            "Usefull to make calculations".to_string()
        }
        async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> {
            Ok("25".to_string())
        }
    }

    /// Runs a two-turn conversation through an `AgentExecutor` backed by GPT-4 and
    /// `SimpleMemory`; the second turn asks about details only given in the first.
    // NOTE(review): ignored by default, presumably because it calls the live OpenAI API
    // (network access + credentials) — confirm before enabling in CI.
    #[tokio::test]
    #[ignore]
    async fn test_invoke_agent() {
        let llm = OpenAI::default().with_model(OpenAIModel::Gpt4.to_string());
        let memory = SimpleMemory::new();
        let tool_calc = Calc {};
        let agent = ConversationalAgentBuilder::new()
            .tools(&[Arc::new(tool_calc)])
            .build(llm)
            .unwrap();
        let executor = AgentExecutor::from_agent(agent).with_memory(memory.into());

        let turns = [
            "hola,Me llamo luis, y tengo 10 anos, y estudio Computer scinence",
            "cuanta es la edad de luis +10 y que estudia",
        ];
        for &input in &turns {
            let input_variables = prompt_args! {
                "input" => input,
            };
            match executor.invoke(input_variables).await {
                Ok(result) => println!("Result: {:?}", result),
                Err(e) => panic!("Error invoking AgentExecutor: {:?}", e),
            }
        }
    }
}