langchain_rust/agent/chat/chat_agent.rs

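//! Conversational chat agent: builds the chat prompt from the registered
//! tools, reconstructs the agent scratchpad from intermediate steps, and
//! plans the next action by running the underlying chain and parsing its
//! output.
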
use std::sync::Arc;

use async_trait::async_trait;
use serde_json::json;

use crate::{
    agent::{agent::Agent, chat::prompt::FORMAT_INSTRUCTIONS, AgentError},
    chain::chain_trait::Chain,
    message_formatter,
    prompt::{
        HumanMessagePromptTemplate, MessageFormatterStruct, MessageOrTemplate, PromptArgs,
        PromptFromatter,
    },
    prompt_args,
    schemas::{
        agent::{AgentAction, AgentEvent},
        messages::Message,
    },
    template_jinja2,
    tools::Tool,
};

use super::{output_parser::ChatOutputParser, prompt::TEMPLATE_TOOL_RESPONSE};

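/// A conversational (chat) agent: `chain` produces the model output, `tools`
/// are the tools exposed to the model, and `output_parser` turns the raw
/// completion into an `AgentEvent`.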
pub struct ConversationalAgent {
    pub(crate) chain: Box<dyn Chain>,
    pub(crate) tools: Vec<Arc<dyn Tool>>,
    pub(crate) output_parser: ChatOutputParser,
}

impl ConversationalAgent {
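    /// Assembles the prompt formatter: the system `prefix`, a `chat_history`
    /// placeholder, the `suffix` template (pre-filled with the tool list, tool
    /// names and format instructions) as a human message template over
    /// `input`, and an `agent_scratchpad` placeholder.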
    pub fn create_prompt(
        tools: &[Arc<dyn Tool>],
        suffix: &str,
        prefix: &str,
    ) -> Result<MessageFormatterStruct, AgentError> {
        let tool_string = tools
            .iter()
            .map(|tool| format!("> {}: {}", tool.name(), tool.description()))
            .collect::<Vec<_>>()
            .join("\n");
        let tool_names = tools
            .iter()
            .map(|tool| tool.name())
            .collect::<Vec<_>>()
            .join(", ");

        let suffix_prompt = template_jinja2!(suffix, "tools", "format_instructions");

        let input_variables_fstring = prompt_args! {
            "tools" => tool_string,
            "format_instructions" => FORMAT_INSTRUCTIONS,
            "tool_names" => tool_names
        };

        let suffix_prompt = suffix_prompt.format(input_variables_fstring)?;
        let formatter = message_formatter![
            MessageOrTemplate::Message(Message::new_system_message(prefix)),
            MessageOrTemplate::MessagesPlaceholder("chat_history".to_string()),
            MessageOrTemplate::Template(
                HumanMessagePromptTemplate::new(template_jinja2!(
                    &suffix_prompt.to_string(),
                    "input"
                ))
                .into()
            ),
            MessageOrTemplate::MessagesPlaceholder("agent_scratchpad".to_string()),
        ];
        Ok(formatter)
    }

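    /// Replays the intermediate steps as messages: each action's log becomes
    /// an AI message, followed by a human message carrying the tool
    /// observation rendered through `TEMPLATE_TOOL_RESPONSE`.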
    fn construct_scratchpad(
        &self,
        intermediate_steps: &[(AgentAction, String)],
    ) -> Result<Vec<Message>, AgentError> {
        let mut thoughts: Vec<Message> = Vec::new();
        for (action, observation) in intermediate_steps.iter() {
            thoughts.push(Message::new_ai_message(&action.log));
            let tool_response = template_jinja2!(TEMPLATE_TOOL_RESPONSE, "observation")
                .format(prompt_args!("observation" => observation))?;
            thoughts.push(Message::new_human_message(&tool_response));
        }
        Ok(thoughts)
    }
}

#[async_trait]
impl Agent for ConversationalAgent {
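    // Plan the next step: rebuild the scratchpad from prior (action, observation)
    // pairs, call the chain with it, and parse the completion into an AgentEvent.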
    async fn plan(
        &self,
        intermediate_steps: &[(AgentAction, String)],
        inputs: PromptArgs,
    ) -> Result<AgentEvent, AgentError> {
        let scratchpad = self.construct_scratchpad(intermediate_steps)?;
        let mut inputs = inputs.clone();
        inputs.insert("agent_scratchpad".to_string(), json!(scratchpad));
        let output = self.chain.call(inputs.clone()).await?.generation;
        let parsed_output = self.output_parser.parse(&output)?;
        Ok(parsed_output)
    }

    fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
        self.tools.clone()
    }
}

#[cfg(test)]
mod tests {
    use std::{error::Error, sync::Arc};

    use async_trait::async_trait;
    use serde_json::Value;

    use crate::{
        agent::{chat::builder::ConversationalAgentBuilder, executor::AgentExecutor},
        chain::chain_trait::Chain,
        llm::openai::{OpenAI, OpenAIModel},
        memory::SimpleMemory,
        prompt_args,
        tools::Tool,
    };

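    // Minimal stub tool that always answers "25"; just enough to exercise the
    // agent's tool-calling loop in the test below.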
    struct Calc {}

    #[async_trait]
    impl Tool for Calc {
        fn name(&self) -> String {
            "Calculator".to_string()
        }
        fn description(&self) -> String {
            "Useful for making calculations".to_string()
        }
        async fn run(&self, _input: Value) -> Result<String, Box<dyn Error>> {
            Ok("25".to_string())
        }
    }

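    // Ignored by default: this test issues real calls to the OpenAI API and
    // needs valid OpenAI credentials in the environment.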
    #[tokio::test]
    #[ignore]
    async fn test_invoke_agent() {
        let llm = OpenAI::default().with_model(OpenAIModel::Gpt4.to_string());
        let memory = SimpleMemory::new();
        let tool_calc = Calc {};
        let agent = ConversationalAgentBuilder::new()
            .tools(&[Arc::new(tool_calc)])
            .build(llm)
            .unwrap();
        let input_variables = prompt_args! {
            "input" => "Hello, my name is Luis, I am 10 years old, and I study computer science",
        };
        let executor = AgentExecutor::from_agent(agent).with_memory(memory.into());
        match executor.invoke(input_variables).await {
            Ok(result) => {
                println!("Result: {:?}", result);
            }
            Err(e) => panic!("Error invoking agent: {:?}", e),
        }
        let input_variables = prompt_args! {
            "input" => "What is Luis's age plus 10, and what does he study?",
        };
        match executor.invoke(input_variables).await {
            Ok(result) => {
                println!("Result: {:?}", result);
            }
            Err(e) => panic!("Error invoking agent: {:?}", e),
        }
    }
}