reflex/semantic/
agentic.rs

1//! Agentic loop orchestrator for multi-step query generation
2//!
3//! This module implements the main agentic workflow:
4//! 1. Phase 1: Assess if more context is needed
5//! 2. Phase 2: Gather context using tools
6//! 3. Phase 3: Generate final queries
7//! 4. Phase 4: Execute queries
8//! 5. Phase 5: Evaluate results
9//! 6. Phase 6: Refine if needed
10
11use anyhow::{Context as AnyhowContext, Result};
12use crate::cache::CacheManager;
13
14use super::providers::{LlmProvider, create_provider};
15use super::config;
16use super::schema::{QueryResponse, AgenticQueryResponse};
17use super::schema_agentic::{AgenticResponse, Phase, ToolCall};
18use super::tools::{execute_tool, format_tool_results, ToolResult};
19use super::evaluator::{evaluate_results, EvaluationConfig};
20use super::reporter::AgenticReporter;
21
/// Configuration for agentic loop
///
/// Controls iteration/tool budgets, the optional evaluation phase, and
/// provider/model selection. Construct via `Default` and override only the
/// fields you need using struct-update syntax (see the tests below).
#[derive(Debug, Clone)]
pub struct AgenticConfig {
    /// Maximum iterations for refinement (default: 2)
    ///
    /// NOTE(review): `run_agentic_loop` only checks `max_iterations > 1` and
    /// then performs a single refinement pass — values above 2 do not add
    /// extra rounds today. Confirm intended semantics.
    pub max_iterations: usize,

    /// Maximum tool calls per gathering phase (default: 5)
    /// Excess tool calls requested by the LLM are silently dropped.
    pub max_tools_per_phase: usize,

    /// Enable result evaluation phase (Phase 5; evaluation is also the
    /// gate for the refinement phase — disabling it disables refinement)
    pub enable_evaluation: bool,

    /// Evaluation configuration passed through to `evaluate_results`
    pub eval_config: EvaluationConfig,

    /// Provider name override (takes precedence over the persisted
    /// semantic config when initializing the provider)
    pub provider_override: Option<String>,

    /// Model override (takes precedence over both the persisted config
    /// and the per-user default model)
    pub model_override: Option<String>,

    /// Show LLM reasoning blocks (default: false)
    pub show_reasoning: bool,

    /// Verbose output (show tool results, etc.) (default: false)
    pub verbose: bool,

    /// Debug mode: output full LLM prompts (default: false)
    pub debug: bool,
}
52
53impl Default for AgenticConfig {
54    fn default() -> Self {
55        Self {
56            max_iterations: 2,
57            max_tools_per_phase: 5,
58            enable_evaluation: true,
59            eval_config: EvaluationConfig::default(),
60            provider_override: None,
61            model_override: None,
62            show_reasoning: false,
63            verbose: false,
64            debug: false,
65        }
66    }
67}
68
/// Run the full agentic loop
///
/// Orchestrates the phases described in the module docs:
/// assessment → context gathering → query generation → execution →
/// evaluation → (at most one) refinement pass. Returns the final query
/// set together with its execution results and any gathered context.
///
/// # Errors
/// Propagates cache-validation failures (other than a recoverable schema
/// mismatch, which triggers an automatic reindex), provider initialization
/// errors, LLM call/parse failures, and query execution errors.
pub async fn run_agentic_loop(
    question: &str,
    cache: &CacheManager,
    config: AgenticConfig,
    reporter: &dyn AgenticReporter,
) -> Result<AgenticQueryResponse> {
    log::info!("Starting agentic loop for question: {}", question);

    // Validate cache before starting - auto-reindex if schema mismatch detected
    if let Err(e) = cache.validate() {
        let error_msg = e.to_string();

        // Check if this is a schema mismatch error.
        // NOTE(review): this matches on the error's display text, which is
        // brittle if the message wording ever changes — a typed error variant
        // from CacheManager::validate would be safer. Confirm exact wording.
        if error_msg.contains("Cache schema version mismatch") {
            log::warn!("Cache schema mismatch detected, auto-reindexing...");

            // Create progress callback that reports to the reporter
            use std::sync::Arc;
            let progress_callback: crate::indexer::ProgressCallback = Arc::new({
                // Clone reporter reference for the callback closure
                // Note: We can't capture `reporter` directly since it's a trait object,
                // so we'll just log progress and rely on the indexer's built-in progress bar
                move |current: usize, total: usize, message: String| {
                    log::debug!("Reindex progress: [{}/{}] {}", current, total, message);
                }
            });

            // Trigger reindexing over the whole workspace (non-incremental:
            // the `false` argument presumably disables force/incremental mode
            // — TODO confirm against Indexer::index_with_callback).
            let workspace_root = cache.workspace_root();
            let index_config = crate::IndexConfig::default();
            let indexer = crate::indexer::Indexer::new(cache.clone(), index_config);

            log::info!("Auto-reindexing cache at {:?}", workspace_root);
            indexer.index_with_callback(&workspace_root, false, Some(progress_callback))?;

            log::info!("Cache reindexing completed successfully");
        } else {
            // Other validation errors should propagate up
            return Err(e);
        }
    }

    // Initialize provider (applies config overrides over persisted settings)
    let provider = initialize_provider(&config, cache)?;

    // Phase 1: Initial assessment - does the LLM need more context?
    let (needs_context, initial_response) = phase_1_assess(
        question,
        cache,
        &*provider,
        reporter,
        config.debug,
    ).await?;

    // Phase 2: Context gathering (if needed). When skipped, downstream
    // phases receive an empty context string and no tool descriptions.
    let (gathered_context, tools_executed) = if needs_context {
        phase_2_gather(
            question,
            initial_response,
            cache,
            &*provider,
            &config,
            reporter,
        ).await?
    } else {
        (String::new(), Vec::new())
    };

    // Phase 3: Generate final queries (confidence is used by the evaluator)
    let (query_response, query_confidence) = phase_3_generate(
        question,
        &gathered_context,
        cache,
        &*provider,
        reporter,
        config.debug,
    ).await?;

    // Phase 4: Execute queries
    let (results, total_count, count_only) = super::executor::execute_queries(
        query_response.queries.clone(),
        cache,
    ).await?;

    log::info!("Executed queries: {} file groups, {} total matches", results.len(), total_count);

    // Phase 5: Evaluate results (if enabled and not count-only)
    if config.enable_evaluation && !count_only {
        let evaluation = evaluate_results(
            &results,
            total_count,
            question,
            &config.eval_config,
            if !gathered_context.is_empty() { Some(gathered_context.as_str()) } else { None },
            query_response.queries.len(),
            Some(query_confidence),
        );

        log::info!("Evaluation: success={}, score={:.2}", evaluation.success, evaluation.score);

        // Report evaluation
        reporter.report_evaluation(&evaluation);

        // Phase 6: Refinement (if needed and iterations remaining).
        // NOTE(review): despite `max_iterations`, refinement runs at most
        // once — phase_6_refine returns directly and does not loop back here.
        if !evaluation.success && config.max_iterations > 1 {
            log::info!("Results unsatisfactory, attempting refinement");

            return phase_6_refine(
                question,
                &gathered_context,
                &query_response,
                &evaluation,
                cache,
                &*provider,
                &config,
                reporter,
                config.debug,
            ).await;
        }
    }

    // Return enhanced response with both queries and results.
    // Empty context/tool lists are normalized to None for the response.
    Ok(AgenticQueryResponse {
        queries: query_response.queries,
        results,
        total_count: if count_only { None } else { Some(total_count) },
        gathered_context: if !gathered_context.is_empty() {
            Some(gathered_context)
        } else {
            None
        },
        tools_executed: if !tools_executed.is_empty() {
            Some(tools_executed)
        } else {
            None
        },
        answer: None,  // No answer generation in agentic mode (handled in CLI)
    })
}
209
210/// Phase 1: Assess if more context is needed
211async fn phase_1_assess(
212    question: &str,
213    cache: &CacheManager,
214    provider: &dyn LlmProvider,
215    reporter: &dyn AgenticReporter,
216    debug: bool,
217) -> Result<(bool, AgenticResponse)> {
218    log::info!("Phase 1: Assessing context needs");
219
220    // Build assessment prompt
221    let prompt = super::prompt_agentic::build_assessment_prompt(question, cache)?;
222
223    // Debug mode: output full prompt
224    if debug {
225        eprintln!("\n{}", "=".repeat(80));
226        eprintln!("DEBUG: Full LLM Prompt (Phase 1: Assessment)");
227        eprintln!("{}", "=".repeat(80));
228        eprintln!("{}", prompt);
229        eprintln!("{}\n", "=".repeat(80));
230    }
231
232    // Call LLM
233    let json_response = call_with_retry(provider, &prompt, 2).await?;
234
235    // Parse response
236    let response: AgenticResponse = serde_json::from_str(&json_response)
237        .context("Failed to parse LLM assessment response")?;
238
239    // Validate phase
240    if response.phase != Phase::Assessment && response.phase != Phase::Final {
241        anyhow::bail!("Expected 'assessment' or 'final' phase, got {:?}", response.phase);
242    }
243
244    let needs_context = response.needs_context && !response.tool_calls.is_empty();
245
246    log::info!(
247        "Assessment complete: needs_context={}, tool_calls={}",
248        needs_context,
249        response.tool_calls.len()
250    );
251
252    // Report assessment
253    reporter.report_assessment(&response.reasoning, needs_context, &response.tool_calls);
254
255    Ok((needs_context, response))
256}
257
258/// Phase 2: Gather context using tools
259async fn phase_2_gather(
260    _question: &str,
261    initial_response: AgenticResponse,
262    cache: &CacheManager,
263    _provider: &dyn LlmProvider,
264    config: &AgenticConfig,
265    reporter: &dyn AgenticReporter,
266) -> Result<(String, Vec<String>)> {
267    log::info!("Phase 2: Gathering context via tools");
268
269    let mut all_tool_results = Vec::new();
270    let mut tool_descriptions = Vec::new();
271
272    // Limit tool calls to prevent excessive execution
273    let tool_calls: Vec<ToolCall> = initial_response.tool_calls
274        .into_iter()
275        .take(config.max_tools_per_phase)
276        .collect();
277
278    log::info!("Executing {} tool calls", tool_calls.len());
279
280    // Execute all tool calls
281    for (idx, tool) in tool_calls.iter().enumerate() {
282        log::debug!("Executing tool {}/{}: {:?}", idx + 1, tool_calls.len(), tool);
283
284        // Get tool description for UI display
285        let tool_desc = describe_tool_for_ui(tool);
286        tool_descriptions.push(tool_desc);
287
288        // Report tool start
289        reporter.report_tool_start(idx + 1, tool);
290
291        match execute_tool(tool, cache).await {
292            Ok(result) => {
293                log::info!("Tool {} succeeded: {}", idx + 1, result.description);
294                reporter.report_tool_complete(idx + 1, &result);
295                all_tool_results.push(result);
296            }
297            Err(e) => {
298                log::warn!("Tool {} failed: {}", idx + 1, e);
299                // Continue with other tools even if one fails
300                let failed_result = ToolResult {
301                    description: format!("Tool {} (failed)", idx + 1),
302                    output: format!("Error: {}", e),
303                    success: false,
304                };
305                reporter.report_tool_complete(idx + 1, &failed_result);
306                all_tool_results.push(failed_result);
307            }
308        }
309    }
310
311    // Format all tool results into context string
312    let gathered_context = format_tool_results(&all_tool_results);
313
314    log::info!("Context gathering complete: {} chars", gathered_context.len());
315
316    Ok((gathered_context, tool_descriptions))
317}
318
319/// Generate a user-friendly description of a tool call
320fn describe_tool_for_ui(tool: &ToolCall) -> String {
321    match tool {
322        ToolCall::GatherContext { params } => {
323            let mut parts = Vec::new();
324            if params.structure { parts.push("structure"); }
325            if params.file_types { parts.push("file types"); }
326            if params.project_type { parts.push("project type"); }
327            if params.framework { parts.push("frameworks"); }
328            if params.entry_points { parts.push("entry points"); }
329            if params.test_layout { parts.push("test layout"); }
330            if params.config_files { parts.push("config files"); }
331
332            if parts.is_empty() {
333                "gather_context: General codebase context".to_string()
334            } else {
335                format!("gather_context: {}", parts.join(", "))
336            }
337        }
338        ToolCall::ExploreCodebase { description, .. } => {
339            format!("explore_codebase: {}", description)
340        }
341        ToolCall::AnalyzeStructure { analysis_type } => {
342            format!("analyze_structure: {:?}", analysis_type)
343        }
344        ToolCall::SearchDocumentation { query, files } => {
345            if let Some(file_list) = files {
346                format!("search_documentation: '{}' in files {:?}", query, file_list)
347            } else {
348                format!("search_documentation: '{}'", query)
349            }
350        }
351        ToolCall::GetStatistics => {
352            "get_statistics: Retrieved file counts and language stats".to_string()
353        }
354        ToolCall::GetDependencies { file_path, reverse } => {
355            if *reverse {
356                format!("get_dependencies: What depends on '{}'", file_path)
357            } else {
358                format!("get_dependencies: Dependencies of '{}'", file_path)
359            }
360        }
361        ToolCall::GetAnalysisSummary { .. } => {
362            "get_analysis_summary: Dependency health overview".to_string()
363        }
364        ToolCall::FindIslands { .. } => {
365            "find_islands: Disconnected component analysis".to_string()
366        }
367    }
368}
369
370/// Phase 3: Generate final queries
371///
372/// Returns (QueryResponse, confidence_score)
373async fn phase_3_generate(
374    question: &str,
375    gathered_context: &str,
376    cache: &CacheManager,
377    provider: &dyn LlmProvider,
378    reporter: &dyn AgenticReporter,
379    debug: bool,
380) -> Result<(QueryResponse, f32)> {
381    log::info!("Phase 3: Generating final queries");
382
383    // Build generation prompt with gathered context
384    let prompt = super::prompt_agentic::build_generation_prompt(
385        question,
386        gathered_context,
387        cache,
388    )?;
389
390    // Debug mode: output full prompt
391    if debug {
392        eprintln!("\n{}", "=".repeat(80));
393        eprintln!("DEBUG: Full LLM Prompt (Phase 3: Query Generation)");
394        eprintln!("{}", "=".repeat(80));
395        eprintln!("{}", prompt);
396        eprintln!("{}\n", "=".repeat(80));
397    }
398
399    // Call LLM
400    let json_response = call_with_retry(provider, &prompt, 2).await?;
401
402    // Parse response - could be AgenticResponse or QueryResponse
403    // Try AgenticResponse first (for agentic mode)
404    if let Ok(agentic_response) = serde_json::from_str::<AgenticResponse>(&json_response) {
405        if agentic_response.phase == Phase::Final {
406            let confidence = agentic_response.confidence;
407
408            // Report generation with reasoning
409            reporter.report_generation(
410                Some(&agentic_response.reasoning),
411                agentic_response.queries.len(),
412                confidence,
413            );
414
415            // Convert to QueryResponse and return with confidence
416            return Ok((
417                QueryResponse {
418                    queries: agentic_response.queries,
419                },
420                confidence,
421            ));
422        }
423    }
424
425    // Fallback: try direct QueryResponse
426    let query_response: QueryResponse = serde_json::from_str(&json_response)
427        .context("Failed to parse LLM query generation response")?;
428
429    log::info!("Generated {} queries", query_response.queries.len());
430
431    // Report generation without reasoning (fallback mode)
432    reporter.report_generation(None, query_response.queries.len(), 1.0);
433
434    // Default confidence of 1.0 for fallback mode
435    Ok((query_response, 1.0))
436}
437
438/// Phase 6: Refine queries based on evaluation
439async fn phase_6_refine(
440    question: &str,
441    gathered_context: &str,
442    previous_response: &QueryResponse,
443    evaluation: &super::schema_agentic::EvaluationReport,
444    cache: &CacheManager,
445    provider: &dyn LlmProvider,
446    config: &AgenticConfig,
447    reporter: &dyn AgenticReporter,
448    debug: bool,
449) -> Result<AgenticQueryResponse> {
450    log::info!("Phase 6: Refining queries based on evaluation");
451
452    // Report refinement start
453    reporter.report_refinement_start();
454
455    // Build refinement prompt with evaluation feedback
456    let prompt = super::prompt_agentic::build_refinement_prompt(
457        question,
458        gathered_context,
459        previous_response,
460        evaluation,
461        cache,
462    )?;
463
464    // Debug mode: output full prompt
465    if debug {
466        eprintln!("\n{}", "=".repeat(80));
467        eprintln!("DEBUG: Full LLM Prompt (Phase 6: Refinement)");
468        eprintln!("{}", "=".repeat(80));
469        eprintln!("{}", prompt);
470        eprintln!("{}\n", "=".repeat(80));
471    }
472
473    // Call LLM for refinement
474    let json_response = call_with_retry(provider, &prompt, 2).await?;
475
476    // Parse refined response
477    let refined_response: QueryResponse = serde_json::from_str(&json_response)
478        .context("Failed to parse LLM refinement response")?;
479
480    log::info!("Refinement complete: {} refined queries", refined_response.queries.len());
481
482    // Execute refined queries
483    let (results, total_count, count_only) = super::executor::execute_queries(
484        refined_response.queries.clone(),
485        cache,
486    ).await?;
487
488    // Evaluate refined results (one final time)
489    let refined_evaluation = evaluate_results(
490        &results,
491        total_count,
492        question,
493        &config.eval_config,
494        if !gathered_context.is_empty() { Some(gathered_context) } else { None },
495        refined_response.queries.len(),
496        None,  // No confidence available in refinement
497    );
498
499    log::info!(
500        "Refined evaluation: success={}, score={:.2}",
501        refined_evaluation.success,
502        refined_evaluation.score
503    );
504
505    // Return enhanced response with both queries and results
506    Ok(AgenticQueryResponse {
507        queries: refined_response.queries,
508        results,
509        total_count: if count_only { None } else { Some(total_count) },
510        gathered_context: if !gathered_context.is_empty() {
511            Some(gathered_context.to_string())
512        } else {
513            None
514        },
515        tools_executed: None,  // No new tools executed during refinement
516        answer: None,  // No answer generation in agentic mode (handled in CLI)
517    })
518}
519
520/// Initialize LLM provider based on configuration
521fn initialize_provider(
522    config: &AgenticConfig,
523    cache: &CacheManager,
524) -> Result<Box<dyn LlmProvider>> {
525    // Load semantic config
526    let mut semantic_config = config::load_config(cache.path())?;
527
528    // Apply overrides
529    if let Some(provider) = &config.provider_override {
530        semantic_config.provider = provider.clone();
531    }
532
533    // Get API key
534    let api_key = config::get_api_key(&semantic_config.provider)?;
535
536    // Determine model
537    let model = if let Some(model_override) = &config.model_override {
538        Some(model_override.clone())
539    } else if semantic_config.model.is_some() {
540        semantic_config.model.clone()
541    } else {
542        config::get_user_model(&semantic_config.provider)
543    };
544
545    // Create provider
546    create_provider(&semantic_config.provider, api_key, model)
547}
548
/// Call LLM provider with retry logic (from semantic/mod.rs)
///
/// Thin module-local alias that delegates directly to
/// `super::call_with_retry`, so the phase functions in this file can call
/// it without the `super::` path. The retry policy itself lives in the
/// parent module; `max_retries` is passed through unchanged.
async fn call_with_retry(
    provider: &dyn LlmProvider,
    prompt: &str,
    max_retries: usize,
) -> Result<String> {
    super::call_with_retry(provider, prompt, max_retries).await
}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must agree with the documented default values.
    #[test]
    fn test_agentic_config_defaults() {
        let cfg = AgenticConfig::default();
        assert_eq!(
            (cfg.max_iterations, cfg.max_tools_per_phase, cfg.enable_evaluation),
            (2, 5, true)
        );
    }

    /// Struct-update syntax overrides only the named fields and keeps
    /// the defaults for everything else.
    #[test]
    fn test_agentic_config_custom() {
        let cfg = AgenticConfig {
            max_iterations: 3,
            max_tools_per_phase: 10,
            enable_evaluation: false,
            ..AgenticConfig::default()
        };

        assert_eq!(cfg.max_iterations, 3);
        assert_eq!(cfg.max_tools_per_phase, 10);
        assert!(!cfg.enable_evaluation);
    }
}