Skip to main content

leann_core/
react_agent.rs

1use anyhow::Result;
2use tracing::info;
3
4use crate::chat::{LlmConfig, LlmParams, LlmProvider, get_llm};
5use crate::search_result::SearchResult;
6use crate::searcher::LeannSearcher;
7
8/// Simple ReAct agent for multi-turn retrieval.
9///
10/// Follows the pattern:
11/// 1. Thought: LLM reasons about what information is needed
12/// 2. Action: Performs a search query
13/// 3. Observation: Gets search results
14/// 4. Repeat until LLM decides it has enough information to answer
15pub struct ReActAgent {
16    searcher: LeannSearcher,
17    llm: Box<dyn LlmProvider>,
18    max_iterations: usize,
19    search_history: Vec<SearchHistoryEntry>,
20}
21
/// One recorded Thought/Action round of a run.
// The fields are written when a round is recorded but are not read back
// anywhere in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct SearchHistoryEntry {
    /// 1-based index of the round.
    iteration: usize,
    /// Reasoning text extracted from the LLM response.
    thought: String,
    /// Search query that was executed in this round.
    action: String,
    /// Number of results the search returned.
    results_count: usize,
}
30
31impl ReActAgent {
32    pub fn new(
33        searcher: LeannSearcher,
34        llm_config: Option<&LlmConfig>,
35        max_iterations: usize,
36    ) -> Result<Self> {
37        let config = llm_config.cloned().unwrap_or_default();
38        let llm = get_llm(&config)?;
39
40        Ok(Self {
41            searcher,
42            llm,
43            max_iterations,
44            search_history: Vec::new(),
45        })
46    }
47
48    /// Run the ReAct agent to answer a question.
49    pub fn run(&mut self, question: &str, top_k: usize) -> Result<String> {
50        info!("Starting ReAct agent for question: {}", question);
51        self.search_history.clear();
52        let mut previous_observations: Vec<String> = Vec::new();
53        let mut all_context: Vec<String> = Vec::new();
54
55        for iteration in 1..=self.max_iterations {
56            info!("--- Iteration {}/{} ---", iteration, self.max_iterations);
57
58            let prompt = self.create_react_prompt(question, iteration, &previous_observations);
59            let response = self.llm.ask(&prompt, &LlmParams::default())?;
60
61            let (thought, action) = parse_llm_response(&response);
62            info!("Thought: {}", thought);
63
64            if action.is_none() {
65                // Extract final answer
66                let final_answer = if response.contains("Final Answer:") {
67                    response.split("Final Answer:").nth(1).unwrap_or("").trim()
68                } else {
69                    response.trim()
70                };
71                info!("Final answer: {}", final_answer);
72                return Ok(final_answer.to_string());
73            }
74
75            let search_query = action.unwrap();
76            info!("Action: search(\"{}\")", search_query);
77
78            let results = self.searcher.search(&search_query, top_k)?;
79            let observation = format_search_results(&results);
80            previous_observations.push(observation.clone());
81            all_context.push(format!("Search: {}\n{}", search_query, observation));
82
83            self.search_history.push(SearchHistoryEntry {
84                iteration,
85                thought,
86                action: search_query,
87                results_count: results.len(),
88            });
89
90            // Early stop if no results after iteration 2
91            if results.is_empty() && iteration >= 2 {
92                let final_prompt = format!(
93                    "Based on the previous searches, provide your best answer.\n\n\
94                     Question: {}\n\n\
95                     Previous searches:\n{}\n\n\
96                     Provide your final answer.",
97                    question,
98                    all_context.join("\n")
99                );
100                return self.llm.ask(&final_prompt, &LlmParams::default());
101            }
102        }
103
104        // Max iterations reached
105        let final_prompt = format!(
106            "Based on all searches, provide your final answer.\n\n\
107             Question: {}\n\n\
108             All results:\n{}\n\n\
109             Provide your final answer now.",
110            question,
111            all_context.join("\n")
112        );
113        self.llm.ask(&final_prompt, &LlmParams::default())
114    }
115
116    fn create_react_prompt(
117        &self,
118        question: &str,
119        iteration: usize,
120        previous_observations: &[String],
121    ) -> String {
122        let mut prompt = format!(
123            "You are a helpful assistant that answers questions by searching through a knowledge base.\n\n\
124             Question: {}\n\n\
125             You can search the knowledge base by using the action: search(\"query\")\n\n\
126             Previous observations:\n",
127            question
128        );
129
130        if previous_observations.is_empty() {
131            prompt.push_str("None yet.\n");
132        } else {
133            for (i, obs) in previous_observations.iter().enumerate() {
134                prompt.push_str(&format!("\nObservation {}:\n{}\n", i + 1, obs));
135            }
136        }
137
138        prompt.push_str(&format!(
139            "\nCurrent iteration: {}/{}\n\n\
140             Think step by step:\n\
141             1. If you need more information, use search(\"your search query\")\n\
142             2. If you have enough information, provide your final answer\n\n\
143             Format your response as:\n\
144             Thought: [your reasoning]\n\
145             Action: search(\"query\") OR Final Answer: [your answer]\n",
146            iteration, self.max_iterations
147        ));
148
149        prompt
150    }
151}
152
153fn format_search_results(results: &[SearchResult]) -> String {
154    if results.is_empty() {
155        return "No results found.".to_string();
156    }
157
158    results
159        .iter()
160        .enumerate()
161        .map(|(i, r)| {
162            let text_preview = if r.text.len() > 500 {
163                format!("{}...", &r.text[..500])
164            } else {
165                r.text.clone()
166            };
167            let mut entry = format!(
168                "[Result {}] (Score: {:.3})\n{}",
169                i + 1,
170                r.score,
171                text_preview
172            );
173            if let Some(source) = r.metadata.get("source").and_then(|v| v.as_str()) {
174                entry.push_str(&format!("\nSource: {}", source));
175            }
176            entry
177        })
178        .collect::<Vec<_>>()
179        .join("\n\n")
180}
181
/// Splits an LLM response into its thought and, if present, its search query.
///
/// The thought is the trimmed text after the first `Thought:` marker, up to
/// the next `Action:` marker, else the next `Final Answer:` marker, else the
/// end of the response. It is empty when there is no `Thought:` marker.
///
/// The second element is `None` whenever the response contains
/// `Final Answer:` or no parsable `search(...)` call follows `Action:`;
/// otherwise it holds the extracted query.
fn parse_llm_response(response: &str) -> (String, Option<String>) {
    const THOUGHT_MARKER: &str = "Thought:";
    const ACTION_MARKER: &str = "Action:";
    const FINAL_MARKER: &str = "Final Answer:";

    let thought = response
        .split_once(THOUGHT_MARKER)
        .map(|(_, after)| {
            let end = after
                .find(ACTION_MARKER)
                .or_else(|| after.find(FINAL_MARKER))
                .unwrap_or(after.len());
            after[..end].trim().to_string()
        })
        .unwrap_or_default();

    // A final answer takes precedence over any action in the same response.
    if response.contains(FINAL_MARKER) {
        return (thought, None);
    }

    let action = response
        .split_once(ACTION_MARKER)
        .and_then(|(_, after)| extract_search_query(after));

    (thought, action)
}

/// Extracts the query from the first usable `search(...)` call in `text`.
///
/// The strict form `search("...")` is tried first and returns the text
/// between `search("` and the next `")` unchanged, so a `)` inside the quoted
/// query survives. If the strict form does not match completely (for example
/// `search("q" )` or `search( 'q' )`), the text between `search(` and the next
/// `)` is used with surrounding whitespace and quote characters stripped.
fn extract_search_query(text: &str) -> Option<String> {
    if let Some((_, rest)) = text.split_once("search(\"") {
        if let Some((query, _)) = rest.split_once("\")") {
            return Some(query.to_string());
        }
    }

    let (_, rest) = text.split_once("search(")?;
    let (raw, _) = rest.split_once(')')?;
    // Trim whitespace first; otherwise padding inside the parentheses would
    // shield the quotes from `trim_matches`.
    Some(
        raw.trim()
            .trim_matches('"')
            .trim_matches('\'')
            .to_string(),
    )
}
225
/// Unit tests for the response parser and the search-result formatter.
#[cfg(test)]
mod tests {
    use super::*;

    /// A thought followed by a double-quoted search action yields both parts.
    #[test]
    fn test_parse_llm_response_with_action() {
        let response =
            "Thought: I need to find information about HNSW.\nAction: search(\"HNSW algorithm\")";
        let (thought, action) = parse_llm_response(response);
        assert_eq!(thought, "I need to find information about HNSW.");
        assert_eq!(action, Some("HNSW algorithm".to_string()));
    }

    /// A final answer yields the thought and no action.
    #[test]
    fn test_parse_llm_response_final_answer() {
        let response = "Thought: I have enough info.\nFinal Answer: HNSW is an approximate nearest neighbor algorithm.";
        let (thought, action) = parse_llm_response(response);
        assert_eq!(thought, "I have enough info.");
        assert!(action.is_none());
    }

    /// Single-quoted queries are accepted and the quotes are stripped.
    #[test]
    fn test_parse_llm_response_single_quotes() {
        let response = "Thought: Looking for data.\nAction: search('vector database')";
        let (thought, action) = parse_llm_response(response);
        assert_eq!(thought, "Looking for data.");
        assert_eq!(action, Some("vector database".to_string()));
    }

    /// An action without a preceding thought leaves the thought empty.
    #[test]
    fn test_parse_llm_response_no_thought_prefix() {
        let response = "Action: search(\"embeddings\")";
        let (thought, action) = parse_llm_response(response);
        assert!(thought.is_empty());
        assert_eq!(action, Some("embeddings".to_string()));
    }

    /// A bare final answer yields neither a thought nor an action.
    #[test]
    fn test_parse_llm_response_final_answer_without_thought() {
        let response = "Final Answer: The answer is 42.";
        let (thought, action) = parse_llm_response(response);
        assert!(thought.is_empty());
        assert!(action.is_none());
    }

    /// Empty input yields an empty thought and no action.
    #[test]
    fn test_parse_llm_response_empty_string() {
        let (thought, action) = parse_llm_response("");
        assert!(thought.is_empty());
        assert!(action.is_none());
    }

    /// A thought with neither action nor final answer is returned as-is.
    #[test]
    fn test_parse_llm_response_no_action_or_final() {
        let response = "Thought: I'm just thinking out loud.";
        let (thought, action) = parse_llm_response(response);
        assert_eq!(thought, "I'm just thinking out loud.");
        assert!(action.is_none());
    }

    /// An empty result list is rendered as a fixed placeholder.
    #[test]
    fn test_format_search_results_empty() {
        let results = format_search_results(&[]);
        assert_eq!(results, "No results found.");
    }

    /// A single result shows its index, its score with three decimals, and its text.
    #[test]
    fn test_format_search_results_single() {
        let results = vec![SearchResult {
            id: "1".to_string(),
            score: 0.95,
            text: "HNSW is a graph-based algorithm.".to_string(),
            metadata: Default::default(),
        }];
        let formatted = format_search_results(&results);
        assert!(formatted.contains("[Result 1]"));
        assert!(formatted.contains("0.950"));
        assert!(formatted.contains("HNSW is a graph-based algorithm."));
    }

    /// A string `source` entry in the metadata is appended as a `Source:` line.
    #[test]
    fn test_format_search_results_with_source_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert(
            "source".to_string(),
            serde_json::Value::String("docs/readme.md".to_string()),
        );
        let results = vec![SearchResult {
            id: "1".to_string(),
            score: 0.8,
            text: "Some text".to_string(),
            metadata,
        }];
        let formatted = format_search_results(&results);
        assert!(formatted.contains("Source: docs/readme.md"));
    }

    /// Text longer than 500 bytes is cut to 500 bytes and suffixed with `...`.
    #[test]
    fn test_format_search_results_truncates_long_text() {
        let long_text = "x".repeat(600);
        let results = vec![SearchResult {
            id: "1".to_string(),
            score: 0.5,
            text: long_text,
            metadata: Default::default(),
        }];
        let formatted = format_search_results(&results);
        assert_eq!(
            formatted,
            format!("[Result 1] (Score: 0.500)\n{}...", "x".repeat(500))
        );
    }
}
335}