1use anyhow::Result;
2use tracing::info;
3
4use crate::chat::{LlmConfig, LlmParams, LlmProvider, get_llm};
5use crate::search_result::SearchResult;
6use crate::searcher::LeannSearcher;
7
/// Question-answering agent that alternates LLM reasoning steps with knowledge-base
/// searches (a ReAct-style loop driven by `ReActAgent::run`).
pub struct ReActAgent {
    /// Executes the `search("...")` queries requested by the LLM.
    searcher: LeannSearcher,
    /// LLM used for the per-iteration reasoning prompts and for the final answer.
    llm: Box<dyn LlmProvider>,
    /// Maximum number of reason/search rounds before a final answer is forced.
    max_iterations: usize,
    /// One entry per search performed during the most recent `run`; cleared at the
    /// start of every call.
    search_history: Vec<SearchHistoryEntry>,
}
21
/// Record of a single search step taken by `ReActAgent::run`.
#[derive(Debug, Clone)]
// The fields are written by `ReActAgent::run` but never read in the code visible here,
// which would otherwise trigger `dead_code` warnings.
#[allow(dead_code)]
struct SearchHistoryEntry {
    /// 1-based iteration in which the search ran.
    iteration: usize,
    /// Reasoning text parsed from that iteration's LLM response.
    thought: String,
    /// The search query that was executed.
    action: String,
    /// Number of results the search returned.
    results_count: usize,
}
30
31impl ReActAgent {
32 pub fn new(
33 searcher: LeannSearcher,
34 llm_config: Option<&LlmConfig>,
35 max_iterations: usize,
36 ) -> Result<Self> {
37 let config = llm_config.cloned().unwrap_or_default();
38 let llm = get_llm(&config)?;
39
40 Ok(Self {
41 searcher,
42 llm,
43 max_iterations,
44 search_history: Vec::new(),
45 })
46 }
47
48 pub fn run(&mut self, question: &str, top_k: usize) -> Result<String> {
50 info!("Starting ReAct agent for question: {}", question);
51 self.search_history.clear();
52 let mut previous_observations: Vec<String> = Vec::new();
53 let mut all_context: Vec<String> = Vec::new();
54
55 for iteration in 1..=self.max_iterations {
56 info!("--- Iteration {}/{} ---", iteration, self.max_iterations);
57
58 let prompt = self.create_react_prompt(question, iteration, &previous_observations);
59 let response = self.llm.ask(&prompt, &LlmParams::default())?;
60
61 let (thought, action) = parse_llm_response(&response);
62 info!("Thought: {}", thought);
63
64 if action.is_none() {
65 let final_answer = if response.contains("Final Answer:") {
67 response.split("Final Answer:").nth(1).unwrap_or("").trim()
68 } else {
69 response.trim()
70 };
71 info!("Final answer: {}", final_answer);
72 return Ok(final_answer.to_string());
73 }
74
75 let search_query = action.unwrap();
76 info!("Action: search(\"{}\")", search_query);
77
78 let results = self.searcher.search(&search_query, top_k)?;
79 let observation = format_search_results(&results);
80 previous_observations.push(observation.clone());
81 all_context.push(format!("Search: {}\n{}", search_query, observation));
82
83 self.search_history.push(SearchHistoryEntry {
84 iteration,
85 thought,
86 action: search_query,
87 results_count: results.len(),
88 });
89
90 if results.is_empty() && iteration >= 2 {
92 let final_prompt = format!(
93 "Based on the previous searches, provide your best answer.\n\n\
94 Question: {}\n\n\
95 Previous searches:\n{}\n\n\
96 Provide your final answer.",
97 question,
98 all_context.join("\n")
99 );
100 return self.llm.ask(&final_prompt, &LlmParams::default());
101 }
102 }
103
104 let final_prompt = format!(
106 "Based on all searches, provide your final answer.\n\n\
107 Question: {}\n\n\
108 All results:\n{}\n\n\
109 Provide your final answer now.",
110 question,
111 all_context.join("\n")
112 );
113 self.llm.ask(&final_prompt, &LlmParams::default())
114 }
115
116 fn create_react_prompt(
117 &self,
118 question: &str,
119 iteration: usize,
120 previous_observations: &[String],
121 ) -> String {
122 let mut prompt = format!(
123 "You are a helpful assistant that answers questions by searching through a knowledge base.\n\n\
124 Question: {}\n\n\
125 You can search the knowledge base by using the action: search(\"query\")\n\n\
126 Previous observations:\n",
127 question
128 );
129
130 if previous_observations.is_empty() {
131 prompt.push_str("None yet.\n");
132 } else {
133 for (i, obs) in previous_observations.iter().enumerate() {
134 prompt.push_str(&format!("\nObservation {}:\n{}\n", i + 1, obs));
135 }
136 }
137
138 prompt.push_str(&format!(
139 "\nCurrent iteration: {}/{}\n\n\
140 Think step by step:\n\
141 1. If you need more information, use search(\"your search query\")\n\
142 2. If you have enough information, provide your final answer\n\n\
143 Format your response as:\n\
144 Thought: [your reasoning]\n\
145 Action: search(\"query\") OR Final Answer: [your answer]\n",
146 iteration, self.max_iterations
147 ));
148
149 prompt
150 }
151}
152
153fn format_search_results(results: &[SearchResult]) -> String {
154 if results.is_empty() {
155 return "No results found.".to_string();
156 }
157
158 results
159 .iter()
160 .enumerate()
161 .map(|(i, r)| {
162 let text_preview = if r.text.len() > 500 {
163 format!("{}...", &r.text[..500])
164 } else {
165 r.text.clone()
166 };
167 let mut entry = format!(
168 "[Result {}] (Score: {:.3})\n{}",
169 i + 1,
170 r.score,
171 text_preview
172 );
173 if let Some(source) = r.metadata.get("source").and_then(|v| v.as_str()) {
174 entry.push_str(&format!("\nSource: {}", source));
175 }
176 entry
177 })
178 .collect::<Vec<_>>()
179 .join("\n\n")
180}
181
/// Splits a ReAct-style LLM response into its reasoning and an optional search query.
///
/// Returns `(thought, action)`:
///
/// * `thought`: the text after the first `Thought:` marker, up to whichever of
///   `Action:` or `Final Answer:` follows it first, trimmed. It is empty when there is
///   no `Thought:` marker.
/// * `action`: `None` whenever the response contains `Final Answer:`, because a final
///   answer takes precedence over any action. Otherwise it is the query of the first
///   `search(...)` call after the first `Action:` marker, or `None` if there is no such
///   call.
fn parse_llm_response(response: &str) -> (String, Option<String>) {
    const THOUGHT: &str = "Thought:";
    const ACTION: &str = "Action:";
    const FINAL_ANSWER: &str = "Final Answer:";

    let thought = response
        .find(THOUGHT)
        .map(|idx| {
            let after = &response[idx + THOUGHT.len()..];
            // Cut at the *earliest* following marker. Preferring `Action:` would let a
            // final answer that merely mentions "Action:" leak into the thought.
            let end = after
                .find(ACTION)
                .into_iter()
                .chain(after.find(FINAL_ANSWER))
                .min()
                .unwrap_or(after.len());
            after[..end].trim().to_string()
        })
        .unwrap_or_default();

    if response.contains(FINAL_ANSWER) {
        return (thought, None);
    }

    let action = response
        .find(ACTION)
        .and_then(|idx| parse_search_action(&response[idx + ACTION.len()..]));

    (thought, action)
}

/// Extracts the query from the first `search(...)` call in `text`.
///
/// The canonical `search("query")` form is read up to its closing `")`, so the query
/// may itself contain `)`. Any other form falls back to the text up to the first `)`,
/// with surrounding whitespace and quote characters stripped. This covers single
/// quotes, unquoted queries, whitespace inside the parentheses, and a `search("` whose
/// closing `")` is missing.
fn parse_search_action(text: &str) -> Option<String> {
    const QUOTED_OPEN: &str = "search(\"";
    const OPEN: &str = "search(";

    let quoted = text.find(QUOTED_OPEN).and_then(|start| {
        let rest = &text[start + QUOTED_OPEN.len()..];
        rest.find("\")").map(|end| rest[..end].to_string())
    });

    quoted.or_else(|| {
        let rest = &text[text.find(OPEN)? + OPEN.len()..];
        let end = rest.find(')')?;
        // Trim whitespace first so `search( "q" )` still has its quotes removed.
        Some(
            rest[..end]
                .trim()
                .trim_matches('"')
                .trim_matches('\'')
                .to_string(),
        )
    })
}
225
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_llm_response` and `format_search_results`.

    use super::*;

    #[test]
    fn test_parse_llm_response_with_action() {
        let parsed = parse_llm_response(
            "Thought: I need to find information about HNSW.\nAction: search(\"HNSW algorithm\")",
        );
        assert_eq!(parsed.0, "I need to find information about HNSW.");
        assert_eq!(parsed.1.as_deref(), Some("HNSW algorithm"));
    }

    #[test]
    fn test_parse_llm_response_final_answer() {
        let parsed = parse_llm_response(
            "Thought: I have enough info.\nFinal Answer: HNSW is an approximate nearest neighbor algorithm.",
        );
        assert_eq!(parsed.0, "I have enough info.");
        assert_eq!(parsed.1, None);
    }

    #[test]
    fn test_parse_llm_response_single_quotes() {
        assert_eq!(
            parse_llm_response("Thought: Looking for data.\nAction: search('vector database')"),
            (
                "Looking for data.".to_string(),
                Some("vector database".to_string())
            )
        );
    }

    #[test]
    fn test_parse_llm_response_no_thought_prefix() {
        let (reasoning, query) = parse_llm_response("Action: search(\"embeddings\")");
        assert_eq!(reasoning, "");
        assert_eq!(query.as_deref(), Some("embeddings"));
    }

    #[test]
    fn test_parse_llm_response_final_answer_without_thought() {
        assert!(parse_llm_response("Final Answer: The answer is 42.")
            .1
            .is_none());
    }

    #[test]
    fn test_parse_llm_response_empty_string() {
        assert_eq!(parse_llm_response(""), (String::new(), None));
    }

    #[test]
    fn test_parse_llm_response_no_action_or_final() {
        assert_eq!(
            parse_llm_response("Thought: I'm just thinking out loud."),
            ("I'm just thinking out loud.".to_string(), None)
        );
    }

    #[test]
    fn test_format_search_results_empty() {
        assert_eq!(format_search_results(&[]), "No results found.");
    }

    #[test]
    fn test_format_search_results_single() {
        let rendered = format_search_results(&[SearchResult {
            id: "1".to_string(),
            score: 0.95,
            text: "HNSW is a graph-based algorithm.".to_string(),
            metadata: Default::default(),
        }]);
        for needle in ["[Result 1]", "0.950", "HNSW is a graph-based algorithm."].iter() {
            assert!(
                rendered.contains(*needle),
                "missing {:?} in {:?}",
                needle,
                rendered
            );
        }
    }

    #[test]
    fn test_format_search_results_with_source_metadata() {
        let mut source_meta = std::collections::HashMap::new();
        source_meta.insert(
            "source".to_string(),
            serde_json::Value::String("docs/readme.md".to_string()),
        );
        let rendered = format_search_results(&[SearchResult {
            id: "1".to_string(),
            score: 0.8,
            text: "Some text".to_string(),
            metadata: source_meta,
        }]);
        assert!(rendered.contains("Source: docs/readme.md"));
    }

    #[test]
    fn test_format_search_results_truncates_long_text() {
        let rendered = format_search_results(&[SearchResult {
            id: "1".to_string(),
            score: 0.5,
            text: "x".repeat(600),
            metadata: Default::default(),
        }]);
        assert!(rendered.contains("..."));
        assert!(rendered.len() < 700);
    }
}
335}