Skip to main content

reflex/semantic/
agentic.rs

1//! Agentic loop orchestrator for multi-step query generation
2//!
3//! This module implements the main agentic workflow:
4//! 1. Phase 1: Assess if more context is needed
5//! 2. Phase 2: Gather context using tools
6//! 3. Phase 3: Generate final queries
7//! 4. Phase 4: Execute queries
8//! 5. Phase 5: Evaluate results
9//! 6. Phase 6: Refine if needed
10
11use anyhow::{Context as AnyhowContext, Result};
12use crate::cache::CacheManager;
13
14use super::providers::{LlmProvider, create_provider};
15use super::config;
16use super::schema::{QueryResponse, AgenticQueryResponse};
17use super::schema_agentic::{AgenticResponse, Phase, ToolCall};
18use super::tools::{execute_tool, format_tool_results, ToolResult};
19use super::evaluator::{evaluate_results, EvaluationConfig};
20use super::reporter::AgenticReporter;
21
22/// Configuration for agentic loop
23#[derive(Debug, Clone)]
24pub struct AgenticConfig {
25    /// Maximum iterations for refinement (default: 2)
26    pub max_iterations: usize,
27
28    /// Maximum tool calls per gathering phase (default: 5)
29    pub max_tools_per_phase: usize,
30
31    /// Enable result evaluation phase
32    pub enable_evaluation: bool,
33
34    /// Evaluation configuration
35    pub eval_config: EvaluationConfig,
36
37    /// Provider name override
38    pub provider_override: Option<String>,
39
40    /// Model override
41    pub model_override: Option<String>,
42
43    /// Show LLM reasoning blocks (default: false)
44    pub show_reasoning: bool,
45
46    /// Verbose output (show tool results, etc.) (default: false)
47    pub verbose: bool,
48
49    /// Debug mode: output full LLM prompts (default: false)
50    pub debug: bool,
51}
52
53impl Default for AgenticConfig {
54    fn default() -> Self {
55        Self {
56            max_iterations: 2,
57            max_tools_per_phase: 5,
58            enable_evaluation: true,
59            eval_config: EvaluationConfig::default(),
60            provider_override: None,
61            model_override: None,
62            show_reasoning: false,
63            verbose: false,
64            debug: false,
65        }
66    }
67}
68
69/// Run the full agentic loop
70pub async fn run_agentic_loop(
71    question: &str,
72    cache: &CacheManager,
73    config: AgenticConfig,
74    reporter: &dyn AgenticReporter,
75) -> Result<AgenticQueryResponse> {
76    log::info!("Starting agentic loop for question: {}", question);
77
78    // Validate cache before starting - auto-reindex if schema mismatch detected
79    if let Err(e) = cache.validate() {
80        let error_msg = e.to_string();
81
82        // Check if this is a schema mismatch error
83        if error_msg.contains("Cache schema version mismatch") {
84            log::warn!("Cache schema mismatch detected, auto-reindexing...");
85
86            // Create progress callback that reports to the reporter
87            use std::sync::Arc;
88            let progress_callback: crate::indexer::ProgressCallback = Arc::new({
89                // Clone reporter reference for the callback closure
90                // Note: We can't capture `reporter` directly since it's a trait object,
91                // so we'll just log progress and rely on the indexer's built-in progress bar
92                move |current: usize, total: usize, message: String| {
93                    log::debug!("Reindex progress: [{}/{}] {}", current, total, message);
94                }
95            });
96
97            // Trigger reindexing
98            let workspace_root = cache.workspace_root();
99            let index_config = crate::IndexConfig::default();
100            let indexer = crate::indexer::Indexer::new(cache.clone(), index_config);
101
102            log::info!("Auto-reindexing cache at {:?}", workspace_root);
103            indexer.index_with_callback(&workspace_root, false, Some(progress_callback))?;
104
105            log::info!("Cache reindexing completed successfully");
106        } else {
107            // Other validation errors should propagate up
108            return Err(e);
109        }
110    }
111
112    // Initialize provider
113    let provider = initialize_provider(&config, cache)?;
114
115    // Phase 1: Initial assessment - does the LLM need more context?
116    let (needs_context, initial_response) = phase_1_assess(
117        question,
118        cache,
119        &*provider,
120        reporter,
121        config.debug,
122    ).await?;
123
124    // Phase 2: Context gathering (if needed)
125    let (gathered_context, tools_executed) = if needs_context {
126        phase_2_gather(
127            question,
128            initial_response,
129            cache,
130            &*provider,
131            &config,
132            reporter,
133        ).await?
134    } else {
135        (String::new(), Vec::new())
136    };
137
138    // Phase 3: Generate final queries
139    let (query_response, query_confidence) = phase_3_generate(
140        question,
141        &gathered_context,
142        cache,
143        &*provider,
144        reporter,
145        config.debug,
146    ).await?;
147
148    // Phase 4: Execute queries
149    let (results, total_count, count_only) = super::executor::execute_queries(
150        query_response.queries.clone(),
151        cache,
152    ).await?;
153
154    log::info!("Executed queries: {} file groups, {} total matches", results.len(), total_count);
155
156    // Phase 5: Evaluate results (if enabled and not count-only)
157    if config.enable_evaluation && !count_only {
158        let evaluation = evaluate_results(
159            &results,
160            total_count,
161            question,
162            &config.eval_config,
163            if !gathered_context.is_empty() { Some(gathered_context.as_str()) } else { None },
164            query_response.queries.len(),
165            Some(query_confidence),
166        );
167
168        log::info!("Evaluation: success={}, score={:.2}", evaluation.success, evaluation.score);
169
170        // Report evaluation
171        reporter.report_evaluation(&evaluation);
172
173        // Phase 6: Refinement (if needed and iterations remaining)
174        if !evaluation.success && config.max_iterations > 1 {
175            log::info!("Results unsatisfactory, attempting refinement");
176
177            return phase_6_refine(
178                question,
179                &gathered_context,
180                &query_response,
181                &evaluation,
182                cache,
183                &*provider,
184                &config,
185                reporter,
186                config.debug,
187            ).await;
188        }
189    }
190
191    // Return enhanced response with both queries and results
192    Ok(AgenticQueryResponse {
193        queries: query_response.queries,
194        results,
195        total_count: if count_only { None } else { Some(total_count) },
196        gathered_context: if !gathered_context.is_empty() {
197            Some(gathered_context)
198        } else {
199            None
200        },
201        tools_executed: if !tools_executed.is_empty() {
202            Some(tools_executed)
203        } else {
204            None
205        },
206        answer: None,  // No answer generation in agentic mode (handled in CLI)
207    })
208}
209
210/// Phase 1: Assess if more context is needed
211async fn phase_1_assess(
212    question: &str,
213    cache: &CacheManager,
214    provider: &dyn LlmProvider,
215    reporter: &dyn AgenticReporter,
216    debug: bool,
217) -> Result<(bool, AgenticResponse)> {
218    log::info!("Phase 1: Assessing context needs");
219
220    // Build assessment prompt
221    let prompt = super::prompt_agentic::build_assessment_prompt(question, cache)?;
222
223    // Debug mode: output full prompt
224    if debug {
225        eprintln!("\n{}", "=".repeat(80));
226        eprintln!("DEBUG: Full LLM Prompt (Phase 1: Assessment)");
227        eprintln!("{}", "=".repeat(80));
228        eprintln!("{}", prompt);
229        eprintln!("{}\n", "=".repeat(80));
230    }
231
232    // Call LLM — validate against AgenticResponse (requires phase + reasoning)
233    let json_response = call_with_retry(
234        provider, &prompt, 2, super::validate_agentic_response,
235    ).await?;
236
237    // Parse response
238    let response: AgenticResponse = serde_json::from_str(&json_response)
239        .context("Failed to parse LLM assessment response")?;
240
241    // Validate phase
242    if response.phase != Phase::Assessment && response.phase != Phase::Final {
243        anyhow::bail!("Expected 'assessment' or 'final' phase, got {:?}", response.phase);
244    }
245
246    let needs_context = response.needs_context && !response.tool_calls.is_empty();
247
248    log::info!(
249        "Assessment complete: needs_context={}, tool_calls={}",
250        needs_context,
251        response.tool_calls.len()
252    );
253
254    // Report assessment
255    reporter.report_assessment(&response.reasoning, needs_context, &response.tool_calls);
256
257    Ok((needs_context, response))
258}
259
260/// Phase 2: Gather context using tools
261async fn phase_2_gather(
262    _question: &str,
263    initial_response: AgenticResponse,
264    cache: &CacheManager,
265    _provider: &dyn LlmProvider,
266    config: &AgenticConfig,
267    reporter: &dyn AgenticReporter,
268) -> Result<(String, Vec<String>)> {
269    log::info!("Phase 2: Gathering context via tools");
270
271    let mut all_tool_results = Vec::new();
272    let mut tool_descriptions = Vec::new();
273
274    // Limit tool calls to prevent excessive execution
275    let tool_calls: Vec<ToolCall> = initial_response.tool_calls
276        .into_iter()
277        .take(config.max_tools_per_phase)
278        .collect();
279
280    log::info!("Executing {} tool calls", tool_calls.len());
281
282    // Execute all tool calls
283    for (idx, tool) in tool_calls.iter().enumerate() {
284        log::debug!("Executing tool {}/{}: {:?}", idx + 1, tool_calls.len(), tool);
285
286        // Get tool description for UI display
287        let tool_desc = describe_tool_for_ui(tool);
288        tool_descriptions.push(tool_desc);
289
290        // Report tool start
291        reporter.report_tool_start(idx + 1, tool);
292
293        match execute_tool(tool, cache).await {
294            Ok(result) => {
295                log::info!("Tool {} succeeded: {}", idx + 1, result.description);
296                reporter.report_tool_complete(idx + 1, &result);
297                all_tool_results.push(result);
298            }
299            Err(e) => {
300                log::warn!("Tool {} failed: {}", idx + 1, e);
301                // Continue with other tools even if one fails
302                let failed_result = ToolResult {
303                    description: format!("Tool {} (failed)", idx + 1),
304                    output: format!("Error: {}", e),
305                    success: false,
306                };
307                reporter.report_tool_complete(idx + 1, &failed_result);
308                all_tool_results.push(failed_result);
309            }
310        }
311    }
312
313    // Format all tool results into context string
314    let gathered_context = format_tool_results(&all_tool_results);
315
316    log::info!("Context gathering complete: {} chars", gathered_context.len());
317
318    Ok((gathered_context, tool_descriptions))
319}
320
321/// Generate a user-friendly description of a tool call
322fn describe_tool_for_ui(tool: &ToolCall) -> String {
323    match tool {
324        ToolCall::GatherContext { params } => {
325            let mut parts = Vec::new();
326            if params.structure { parts.push("structure"); }
327            if params.file_types { parts.push("file types"); }
328            if params.project_type { parts.push("project type"); }
329            if params.framework { parts.push("frameworks"); }
330            if params.entry_points { parts.push("entry points"); }
331            if params.test_layout { parts.push("test layout"); }
332            if params.config_files { parts.push("config files"); }
333
334            if parts.is_empty() {
335                "gather_context: General codebase context".to_string()
336            } else {
337                format!("gather_context: {}", parts.join(", "))
338            }
339        }
340        ToolCall::ExploreCodebase { description, .. } => {
341            format!("explore_codebase: {}", description)
342        }
343        ToolCall::AnalyzeStructure { analysis_type } => {
344            format!("analyze_structure: {:?}", analysis_type)
345        }
346        ToolCall::SearchDocumentation { query, files } => {
347            if let Some(file_list) = files {
348                format!("search_documentation: '{}' in files {:?}", query, file_list)
349            } else {
350                format!("search_documentation: '{}'", query)
351            }
352        }
353        ToolCall::GetStatistics => {
354            "get_statistics: Retrieved file counts and language stats".to_string()
355        }
356        ToolCall::GetDependencies { file_path, reverse } => {
357            if *reverse {
358                format!("get_dependencies: What depends on '{}'", file_path)
359            } else {
360                format!("get_dependencies: Dependencies of '{}'", file_path)
361            }
362        }
363        ToolCall::GetAnalysisSummary { .. } => {
364            "get_analysis_summary: Dependency health overview".to_string()
365        }
366        ToolCall::FindIslands { .. } => {
367            "find_islands: Disconnected component analysis".to_string()
368        }
369    }
370}
371
372/// Phase 3: Generate final queries
373///
374/// Returns (QueryResponse, confidence_score)
375async fn phase_3_generate(
376    question: &str,
377    gathered_context: &str,
378    cache: &CacheManager,
379    provider: &dyn LlmProvider,
380    reporter: &dyn AgenticReporter,
381    debug: bool,
382) -> Result<(QueryResponse, f32)> {
383    log::info!("Phase 3: Generating final queries");
384
385    // Build generation prompt with gathered context
386    let prompt = super::prompt_agentic::build_generation_prompt(
387        question,
388        gathered_context,
389        cache,
390    )?;
391
392    // Debug mode: output full prompt
393    if debug {
394        eprintln!("\n{}", "=".repeat(80));
395        eprintln!("DEBUG: Full LLM Prompt (Phase 3: Query Generation)");
396        eprintln!("{}", "=".repeat(80));
397        eprintln!("{}", prompt);
398        eprintln!("{}\n", "=".repeat(80));
399    }
400
401    // Call LLM — accepts either AgenticResponse or QueryResponse (fallback path)
402    let json_response = call_with_retry(
403        provider, &prompt, 2, super::validate_agentic_or_query_response,
404    ).await?;
405
406    // Parse response - could be AgenticResponse or QueryResponse
407    // Try AgenticResponse first (for agentic mode)
408    if let Ok(agentic_response) = serde_json::from_str::<AgenticResponse>(&json_response) {
409        if agentic_response.phase == Phase::Final {
410            let confidence = agentic_response.confidence;
411
412            // Report generation with reasoning
413            reporter.report_generation(
414                Some(&agentic_response.reasoning),
415                agentic_response.queries.len(),
416                confidence,
417            );
418
419            // Convert to QueryResponse and return with confidence
420            return Ok((
421                QueryResponse {
422                    queries: agentic_response.queries,
423                },
424                confidence,
425            ));
426        }
427    }
428
429    // Fallback: try direct QueryResponse
430    let query_response: QueryResponse = serde_json::from_str(&json_response)
431        .context("Failed to parse LLM query generation response")?;
432
433    log::info!("Generated {} queries", query_response.queries.len());
434
435    // Report generation without reasoning (fallback mode)
436    reporter.report_generation(None, query_response.queries.len(), 1.0);
437
438    // Default confidence of 1.0 for fallback mode
439    Ok((query_response, 1.0))
440}
441
442/// Phase 6: Refine queries based on evaluation
443async fn phase_6_refine(
444    question: &str,
445    gathered_context: &str,
446    previous_response: &QueryResponse,
447    evaluation: &super::schema_agentic::EvaluationReport,
448    cache: &CacheManager,
449    provider: &dyn LlmProvider,
450    config: &AgenticConfig,
451    reporter: &dyn AgenticReporter,
452    debug: bool,
453) -> Result<AgenticQueryResponse> {
454    log::info!("Phase 6: Refining queries based on evaluation");
455
456    // Report refinement start
457    reporter.report_refinement_start();
458
459    // Build refinement prompt with evaluation feedback
460    let prompt = super::prompt_agentic::build_refinement_prompt(
461        question,
462        gathered_context,
463        previous_response,
464        evaluation,
465        cache,
466    )?;
467
468    // Debug mode: output full prompt
469    if debug {
470        eprintln!("\n{}", "=".repeat(80));
471        eprintln!("DEBUG: Full LLM Prompt (Phase 6: Refinement)");
472        eprintln!("{}", "=".repeat(80));
473        eprintln!("{}", prompt);
474        eprintln!("{}\n", "=".repeat(80));
475    }
476
477    // Call LLM for refinement — expects QueryResponse format
478    let json_response = call_with_retry(
479        provider, &prompt, 2, super::validate_query_response,
480    ).await?;
481
482    // Parse refined response
483    let refined_response: QueryResponse = serde_json::from_str(&json_response)
484        .context("Failed to parse LLM refinement response")?;
485
486    log::info!("Refinement complete: {} refined queries", refined_response.queries.len());
487
488    // Execute refined queries
489    let (results, total_count, count_only) = super::executor::execute_queries(
490        refined_response.queries.clone(),
491        cache,
492    ).await?;
493
494    // Evaluate refined results (one final time)
495    let refined_evaluation = evaluate_results(
496        &results,
497        total_count,
498        question,
499        &config.eval_config,
500        if !gathered_context.is_empty() { Some(gathered_context) } else { None },
501        refined_response.queries.len(),
502        None,  // No confidence available in refinement
503    );
504
505    log::info!(
506        "Refined evaluation: success={}, score={:.2}",
507        refined_evaluation.success,
508        refined_evaluation.score
509    );
510
511    // Return enhanced response with both queries and results
512    Ok(AgenticQueryResponse {
513        queries: refined_response.queries,
514        results,
515        total_count: if count_only { None } else { Some(total_count) },
516        gathered_context: if !gathered_context.is_empty() {
517            Some(gathered_context.to_string())
518        } else {
519            None
520        },
521        tools_executed: None,  // No new tools executed during refinement
522        answer: None,  // No answer generation in agentic mode (handled in CLI)
523    })
524}
525
526/// Initialize LLM provider based on configuration
527fn initialize_provider(
528    config: &AgenticConfig,
529    cache: &CacheManager,
530) -> Result<Box<dyn LlmProvider>> {
531    // Load semantic config
532    let mut semantic_config = config::load_config(cache.path())?;
533
534    // Apply overrides
535    if let Some(provider) = &config.provider_override {
536        semantic_config.provider = provider.clone();
537    }
538
539    // Get API key
540    let api_key = config::get_api_key(&semantic_config.provider)?;
541
542    // Determine model
543    let model = if let Some(model_override) = &config.model_override {
544        Some(model_override.clone())
545    } else if semantic_config.model.is_some() {
546        semantic_config.model.clone()
547    } else {
548        config::get_user_model(&semantic_config.provider)
549    };
550
551    // Create provider
552    create_provider(&semantic_config.provider, api_key, model, config::get_provider_options(&semantic_config.provider))
553}
554
555/// Call LLM provider with retry logic (from semantic/mod.rs)
556async fn call_with_retry(
557    provider: &dyn LlmProvider,
558    prompt: &str,
559    max_retries: usize,
560    validator: impl Fn(&str) -> Result<(), String>,
561) -> Result<String> {
562    super::call_with_retry(provider, prompt, max_retries, validator).await
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default on `AgenticConfig` should hold.
    #[test]
    fn test_agentic_config_defaults() {
        let config = AgenticConfig::default();
        assert_eq!(config.max_iterations, 2);
        assert_eq!(config.max_tools_per_phase, 5);
        assert!(config.enable_evaluation);
        assert!(config.provider_override.is_none());
        assert!(config.model_override.is_none());
        assert!(!config.show_reasoning);
        assert!(!config.verbose);
        assert!(!config.debug);
    }

    /// Struct-update syntax overrides only the named fields.
    #[test]
    fn test_agentic_config_custom() {
        let config = AgenticConfig {
            max_iterations: 3,
            max_tools_per_phase: 10,
            enable_evaluation: false,
            ..Default::default()
        };

        assert_eq!(config.max_iterations, 3);
        assert_eq!(config.max_tools_per_phase, 10);
        assert!(!config.enable_evaluation);
        // Untouched fields keep their defaults.
        assert!(config.provider_override.is_none());
        assert!(config.model_override.is_none());
        assert!(!config.debug);
    }

    /// The statistics tool has no arguments, so its label is fixed.
    #[test]
    fn test_describe_tool_for_ui_statistics() {
        assert_eq!(
            describe_tool_for_ui(&ToolCall::GetStatistics),
            "get_statistics: Retrieved file counts and language stats"
        );
    }
}