Skip to main content

reflex/semantic/
agentic.rs

1//! Agentic loop orchestrator for multi-step query generation
2//!
3//! This module implements the main agentic workflow:
4//! 1. Phase 1: Assess if more context is needed
5//! 2. Phase 2: Gather context using tools
6//! 3. Phase 3: Generate final queries
7//! 4. Phase 4: Execute queries
8//! 5. Phase 5: Evaluate results
9//! 6. Phase 6: Refine if needed
10
11use crate::cache::CacheManager;
12use anyhow::{Context as AnyhowContext, Result};
13
14use super::config;
15use super::evaluator::{EvaluationConfig, evaluate_results};
16use super::providers::{LlmProvider, create_provider};
17use super::reporter::AgenticReporter;
18use super::schema::{AgenticQueryResponse, QueryResponse};
19use super::schema_agentic::{AgenticResponse, Phase, ToolCall};
20use super::tools::{ToolResult, execute_tool, format_tool_results};
21
22/// Configuration for agentic loop
23#[derive(Debug, Clone)]
24pub struct AgenticConfig {
25    /// Maximum iterations for refinement (default: 2)
26    pub max_iterations: usize,
27
28    /// Maximum tool calls per gathering phase (default: 5)
29    pub max_tools_per_phase: usize,
30
31    /// Enable result evaluation phase
32    pub enable_evaluation: bool,
33
34    /// Evaluation configuration
35    pub eval_config: EvaluationConfig,
36
37    /// Provider name override
38    pub provider_override: Option<String>,
39
40    /// Model override
41    pub model_override: Option<String>,
42
43    /// Show LLM reasoning blocks (default: false)
44    pub show_reasoning: bool,
45
46    /// Verbose output (show tool results, etc.) (default: false)
47    pub verbose: bool,
48
49    /// Debug mode: output full LLM prompts (default: false)
50    pub debug: bool,
51}
52
53impl Default for AgenticConfig {
54    fn default() -> Self {
55        Self {
56            max_iterations: 2,
57            max_tools_per_phase: 5,
58            enable_evaluation: true,
59            eval_config: EvaluationConfig::default(),
60            provider_override: None,
61            model_override: None,
62            show_reasoning: false,
63            verbose: false,
64            debug: false,
65        }
66    }
67}
68
69/// Run the full agentic loop
70pub async fn run_agentic_loop(
71    question: &str,
72    cache: &CacheManager,
73    config: AgenticConfig,
74    reporter: &dyn AgenticReporter,
75) -> Result<AgenticQueryResponse> {
76    log::info!("Starting agentic loop for question: {}", question);
77
78    // Validate cache before starting - auto-reindex if schema mismatch detected
79    if let Err(e) = cache.validate() {
80        let error_msg = e.to_string();
81
82        // Check if this is a schema mismatch error
83        if error_msg.contains("Cache schema version mismatch") {
84            log::warn!("Cache schema mismatch detected, auto-reindexing...");
85
86            // Create progress callback that reports to the reporter
87            use std::sync::Arc;
88            let progress_callback: crate::indexer::ProgressCallback = Arc::new({
89                // Clone reporter reference for the callback closure
90                // Note: We can't capture `reporter` directly since it's a trait object,
91                // so we'll just log progress and rely on the indexer's built-in progress bar
92                move |current: usize, total: usize, message: String| {
93                    log::debug!("Reindex progress: [{}/{}] {}", current, total, message);
94                }
95            });
96
97            // Trigger reindexing
98            let workspace_root = cache.workspace_root();
99            let index_config = crate::IndexConfig::default();
100            let indexer = crate::indexer::Indexer::new(cache.clone(), index_config);
101
102            log::info!("Auto-reindexing cache at {:?}", workspace_root);
103            indexer.index_with_callback(&workspace_root, false, Some(progress_callback))?;
104
105            log::info!("Cache reindexing completed successfully");
106        } else {
107            // Other validation errors should propagate up
108            return Err(e);
109        }
110    }
111
112    // Initialize provider
113    let provider = initialize_provider(&config, cache)?;
114
115    // Phase 1: Initial assessment - does the LLM need more context?
116    let (needs_context, initial_response) =
117        phase_1_assess(question, cache, &*provider, reporter, config.debug).await?;
118
119    // Phase 2: Context gathering (if needed)
120    let (gathered_context, tools_executed) = if needs_context {
121        phase_2_gather(
122            question,
123            initial_response,
124            cache,
125            &*provider,
126            &config,
127            reporter,
128        )
129        .await?
130    } else {
131        (String::new(), Vec::new())
132    };
133
134    // Phase 3: Generate final queries
135    let (query_response, query_confidence) = phase_3_generate(
136        question,
137        &gathered_context,
138        cache,
139        &*provider,
140        reporter,
141        config.debug,
142    )
143    .await?;
144
145    // Phase 4: Execute queries
146    let (results, total_count, count_only) =
147        super::executor::execute_queries(query_response.queries.clone(), cache).await?;
148
149    log::info!(
150        "Executed queries: {} file groups, {} total matches",
151        results.len(),
152        total_count
153    );
154
155    // Phase 5: Evaluate results (if enabled and not count-only)
156    if config.enable_evaluation && !count_only {
157        let evaluation = evaluate_results(
158            &results,
159            total_count,
160            question,
161            &config.eval_config,
162            if !gathered_context.is_empty() {
163                Some(gathered_context.as_str())
164            } else {
165                None
166            },
167            query_response.queries.len(),
168            Some(query_confidence),
169        );
170
171        log::info!(
172            "Evaluation: success={}, score={:.2}",
173            evaluation.success,
174            evaluation.score
175        );
176
177        // Report evaluation
178        reporter.report_evaluation(&evaluation);
179
180        // Phase 6: Refinement (if needed and iterations remaining)
181        if !evaluation.success && config.max_iterations > 1 {
182            log::info!("Results unsatisfactory, attempting refinement");
183
184            return phase_6_refine(
185                question,
186                &gathered_context,
187                &query_response,
188                &evaluation,
189                cache,
190                &*provider,
191                &config,
192                reporter,
193                config.debug,
194            )
195            .await;
196        }
197    }
198
199    // Return enhanced response with both queries and results
200    Ok(AgenticQueryResponse {
201        queries: query_response.queries,
202        results,
203        total_count: if count_only { None } else { Some(total_count) },
204        gathered_context: if !gathered_context.is_empty() {
205            Some(gathered_context)
206        } else {
207            None
208        },
209        tools_executed: if !tools_executed.is_empty() {
210            Some(tools_executed)
211        } else {
212            None
213        },
214        answer: None, // No answer generation in agentic mode (handled in CLI)
215    })
216}
217
218/// Phase 1: Assess if more context is needed
219async fn phase_1_assess(
220    question: &str,
221    cache: &CacheManager,
222    provider: &dyn LlmProvider,
223    reporter: &dyn AgenticReporter,
224    debug: bool,
225) -> Result<(bool, AgenticResponse)> {
226    log::info!("Phase 1: Assessing context needs");
227
228    // Build assessment prompt
229    let prompt = super::prompt_agentic::build_assessment_prompt(question, cache)?;
230
231    // Debug mode: output full prompt
232    if debug {
233        eprintln!("\n{}", "=".repeat(80));
234        eprintln!("DEBUG: Full LLM Prompt (Phase 1: Assessment)");
235        eprintln!("{}", "=".repeat(80));
236        eprintln!("{}", prompt);
237        eprintln!("{}\n", "=".repeat(80));
238    }
239
240    // Call LLM — validate against AgenticResponse (requires phase + reasoning)
241    let json_response =
242        call_with_retry(provider, &prompt, 2, super::validate_agentic_response).await?;
243
244    // Parse response
245    let response: AgenticResponse =
246        serde_json::from_str(&json_response).context("Failed to parse LLM assessment response")?;
247
248    // Validate phase
249    if response.phase != Phase::Assessment && response.phase != Phase::Final {
250        anyhow::bail!(
251            "Expected 'assessment' or 'final' phase, got {:?}",
252            response.phase
253        );
254    }
255
256    let needs_context = response.needs_context && !response.tool_calls.is_empty();
257
258    log::info!(
259        "Assessment complete: needs_context={}, tool_calls={}",
260        needs_context,
261        response.tool_calls.len()
262    );
263
264    // Report assessment
265    reporter.report_assessment(&response.reasoning, needs_context, &response.tool_calls);
266
267    Ok((needs_context, response))
268}
269
270/// Phase 2: Gather context using tools
271async fn phase_2_gather(
272    _question: &str,
273    initial_response: AgenticResponse,
274    cache: &CacheManager,
275    _provider: &dyn LlmProvider,
276    config: &AgenticConfig,
277    reporter: &dyn AgenticReporter,
278) -> Result<(String, Vec<String>)> {
279    log::info!("Phase 2: Gathering context via tools");
280
281    let mut all_tool_results = Vec::new();
282    let mut tool_descriptions = Vec::new();
283
284    // Limit tool calls to prevent excessive execution
285    let tool_calls: Vec<ToolCall> = initial_response
286        .tool_calls
287        .into_iter()
288        .take(config.max_tools_per_phase)
289        .collect();
290
291    log::info!("Executing {} tool calls", tool_calls.len());
292
293    // Execute all tool calls
294    for (idx, tool) in tool_calls.iter().enumerate() {
295        log::debug!(
296            "Executing tool {}/{}: {:?}",
297            idx + 1,
298            tool_calls.len(),
299            tool
300        );
301
302        // Get tool description for UI display
303        let tool_desc = describe_tool_for_ui(tool);
304        tool_descriptions.push(tool_desc);
305
306        // Report tool start
307        reporter.report_tool_start(idx + 1, tool);
308
309        match execute_tool(tool, cache).await {
310            Ok(result) => {
311                log::info!("Tool {} succeeded: {}", idx + 1, result.description);
312                reporter.report_tool_complete(idx + 1, &result);
313                all_tool_results.push(result);
314            }
315            Err(e) => {
316                log::warn!("Tool {} failed: {}", idx + 1, e);
317                // Continue with other tools even if one fails
318                let failed_result = ToolResult {
319                    description: format!("Tool {} (failed)", idx + 1),
320                    output: format!("Error: {}", e),
321                    success: false,
322                };
323                reporter.report_tool_complete(idx + 1, &failed_result);
324                all_tool_results.push(failed_result);
325            }
326        }
327    }
328
329    // Format all tool results into context string
330    let gathered_context = format_tool_results(&all_tool_results);
331
332    log::info!(
333        "Context gathering complete: {} chars",
334        gathered_context.len()
335    );
336
337    Ok((gathered_context, tool_descriptions))
338}
339
340/// Generate a user-friendly description of a tool call
341fn describe_tool_for_ui(tool: &ToolCall) -> String {
342    match tool {
343        ToolCall::GatherContext { params } => {
344            let mut parts = Vec::new();
345            if params.structure {
346                parts.push("structure");
347            }
348            if params.file_types {
349                parts.push("file types");
350            }
351            if params.project_type {
352                parts.push("project type");
353            }
354            if params.framework {
355                parts.push("frameworks");
356            }
357            if params.entry_points {
358                parts.push("entry points");
359            }
360            if params.test_layout {
361                parts.push("test layout");
362            }
363            if params.config_files {
364                parts.push("config files");
365            }
366
367            if parts.is_empty() {
368                "gather_context: General codebase context".to_string()
369            } else {
370                format!("gather_context: {}", parts.join(", "))
371            }
372        }
373        ToolCall::ExploreCodebase { description, .. } => {
374            format!("explore_codebase: {}", description)
375        }
376        ToolCall::AnalyzeStructure { analysis_type } => {
377            format!("analyze_structure: {:?}", analysis_type)
378        }
379        ToolCall::SearchDocumentation { query, files } => {
380            if let Some(file_list) = files {
381                format!("search_documentation: '{}' in files {:?}", query, file_list)
382            } else {
383                format!("search_documentation: '{}'", query)
384            }
385        }
386        ToolCall::GetStatistics => {
387            "get_statistics: Retrieved file counts and language stats".to_string()
388        }
389        ToolCall::GetDependencies { file_path, reverse } => {
390            if *reverse {
391                format!("get_dependencies: What depends on '{}'", file_path)
392            } else {
393                format!("get_dependencies: Dependencies of '{}'", file_path)
394            }
395        }
396        ToolCall::GetAnalysisSummary { .. } => {
397            "get_analysis_summary: Dependency health overview".to_string()
398        }
399        ToolCall::FindIslands { .. } => "find_islands: Disconnected component analysis".to_string(),
400    }
401}
402
403/// Phase 3: Generate final queries
404///
405/// Returns (QueryResponse, confidence_score)
406async fn phase_3_generate(
407    question: &str,
408    gathered_context: &str,
409    cache: &CacheManager,
410    provider: &dyn LlmProvider,
411    reporter: &dyn AgenticReporter,
412    debug: bool,
413) -> Result<(QueryResponse, f32)> {
414    log::info!("Phase 3: Generating final queries");
415
416    // Build generation prompt with gathered context
417    let prompt = super::prompt_agentic::build_generation_prompt(question, gathered_context, cache)?;
418
419    // Debug mode: output full prompt
420    if debug {
421        eprintln!("\n{}", "=".repeat(80));
422        eprintln!("DEBUG: Full LLM Prompt (Phase 3: Query Generation)");
423        eprintln!("{}", "=".repeat(80));
424        eprintln!("{}", prompt);
425        eprintln!("{}\n", "=".repeat(80));
426    }
427
428    // Call LLM — accepts either AgenticResponse or QueryResponse (fallback path)
429    let json_response = call_with_retry(
430        provider,
431        &prompt,
432        2,
433        super::validate_agentic_or_query_response,
434    )
435    .await?;
436
437    // Parse response - could be AgenticResponse or QueryResponse
438    // Try AgenticResponse first (for agentic mode)
439    if let Ok(agentic_response) = serde_json::from_str::<AgenticResponse>(&json_response) {
440        if agentic_response.phase == Phase::Final {
441            let confidence = agentic_response.confidence;
442
443            // Report generation with reasoning
444            reporter.report_generation(
445                Some(&agentic_response.reasoning),
446                agentic_response.queries.len(),
447                confidence,
448            );
449
450            // Convert to QueryResponse and return with confidence
451            return Ok((
452                QueryResponse {
453                    queries: agentic_response.queries,
454                    message: None,
455                },
456                confidence,
457            ));
458        }
459    }
460
461    // Fallback: try direct QueryResponse
462    let query_response: QueryResponse = serde_json::from_str(&json_response)
463        .context("Failed to parse LLM query generation response")?;
464
465    log::info!("Generated {} queries", query_response.queries.len());
466
467    // Report generation without reasoning (fallback mode)
468    reporter.report_generation(None, query_response.queries.len(), 1.0);
469
470    // Default confidence of 1.0 for fallback mode
471    Ok((query_response, 1.0))
472}
473
474/// Phase 6: Refine queries based on evaluation
475async fn phase_6_refine(
476    question: &str,
477    gathered_context: &str,
478    previous_response: &QueryResponse,
479    evaluation: &super::schema_agentic::EvaluationReport,
480    cache: &CacheManager,
481    provider: &dyn LlmProvider,
482    config: &AgenticConfig,
483    reporter: &dyn AgenticReporter,
484    debug: bool,
485) -> Result<AgenticQueryResponse> {
486    log::info!("Phase 6: Refining queries based on evaluation");
487
488    // Report refinement start
489    reporter.report_refinement_start();
490
491    // Build refinement prompt with evaluation feedback
492    let prompt = super::prompt_agentic::build_refinement_prompt(
493        question,
494        gathered_context,
495        previous_response,
496        evaluation,
497        cache,
498    )?;
499
500    // Debug mode: output full prompt
501    if debug {
502        eprintln!("\n{}", "=".repeat(80));
503        eprintln!("DEBUG: Full LLM Prompt (Phase 6: Refinement)");
504        eprintln!("{}", "=".repeat(80));
505        eprintln!("{}", prompt);
506        eprintln!("{}\n", "=".repeat(80));
507    }
508
509    // Call LLM for refinement — expects QueryResponse format
510    let json_response =
511        call_with_retry(provider, &prompt, 2, super::validate_query_response).await?;
512
513    // Parse refined response
514    let refined_response: QueryResponse =
515        serde_json::from_str(&json_response).context("Failed to parse LLM refinement response")?;
516
517    log::info!(
518        "Refinement complete: {} refined queries",
519        refined_response.queries.len()
520    );
521
522    // Execute refined queries
523    let (results, total_count, count_only) =
524        super::executor::execute_queries(refined_response.queries.clone(), cache).await?;
525
526    // Evaluate refined results (one final time)
527    let refined_evaluation = evaluate_results(
528        &results,
529        total_count,
530        question,
531        &config.eval_config,
532        if !gathered_context.is_empty() {
533            Some(gathered_context)
534        } else {
535            None
536        },
537        refined_response.queries.len(),
538        None, // No confidence available in refinement
539    );
540
541    log::info!(
542        "Refined evaluation: success={}, score={:.2}",
543        refined_evaluation.success,
544        refined_evaluation.score
545    );
546
547    // Return enhanced response with both queries and results
548    Ok(AgenticQueryResponse {
549        queries: refined_response.queries,
550        results,
551        total_count: if count_only { None } else { Some(total_count) },
552        gathered_context: if !gathered_context.is_empty() {
553            Some(gathered_context.to_string())
554        } else {
555            None
556        },
557        tools_executed: None, // No new tools executed during refinement
558        answer: None,         // No answer generation in agentic mode (handled in CLI)
559    })
560}
561
562/// Initialize LLM provider based on configuration
563fn initialize_provider(
564    config: &AgenticConfig,
565    cache: &CacheManager,
566) -> Result<Box<dyn LlmProvider>> {
567    // Load semantic config
568    let mut semantic_config = config::load_config(cache.path())?;
569
570    // Apply overrides
571    if let Some(provider) = &config.provider_override {
572        semantic_config.provider = provider.clone();
573    }
574
575    // Get API key
576    let api_key = config::get_api_key(&semantic_config.provider)?;
577
578    let model = config::resolve_model(&semantic_config, config.model_override.as_deref());
579
580    // Create provider
581    create_provider(
582        &semantic_config.provider,
583        api_key,
584        model,
585        config::get_provider_options(&semantic_config.provider),
586        semantic_config.timeout_seconds,
587    )
588}
589
590/// Call LLM provider with retry logic (from semantic/mod.rs)
591async fn call_with_retry(
592    provider: &dyn LlmProvider,
593    prompt: &str,
594    max_retries: usize,
595    validator: impl Fn(&str) -> Result<(), String>,
596) -> Result<String> {
597    super::call_with_retry(provider, prompt, max_retries, validator).await
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agentic_config_defaults() {
        let config = AgenticConfig::default();
        assert_eq!(config.max_iterations, 2);
        assert_eq!(config.max_tools_per_phase, 5);
        assert!(config.enable_evaluation);
        // Overrides and diagnostic output are opt-in.
        assert!(config.provider_override.is_none());
        assert!(config.model_override.is_none());
        assert!(!config.show_reasoning);
        assert!(!config.verbose);
        assert!(!config.debug);
    }

    #[test]
    fn test_agentic_config_custom() {
        let config = AgenticConfig {
            max_iterations: 3,
            max_tools_per_phase: 10,
            enable_evaluation: false,
            ..Default::default()
        };

        assert_eq!(config.max_iterations, 3);
        assert_eq!(config.max_tools_per_phase, 10);
        assert!(!config.enable_evaluation);
        // Fields not mentioned keep their defaults.
        assert!(config.provider_override.is_none());
        assert!(!config.debug);
    }

    #[test]
    fn test_describe_tool_for_ui_simple_variants() {
        assert_eq!(
            describe_tool_for_ui(&ToolCall::GetStatistics),
            "get_statistics: Retrieved file counts and language stats"
        );

        let forward = ToolCall::GetDependencies {
            file_path: "src/lib.rs".into(),
            reverse: false,
        };
        assert_eq!(
            describe_tool_for_ui(&forward),
            "get_dependencies: Dependencies of 'src/lib.rs'"
        );

        let reverse = ToolCall::GetDependencies {
            file_path: "src/lib.rs".into(),
            reverse: true,
        };
        assert_eq!(
            describe_tool_for_ui(&reverse),
            "get_dependencies: What depends on 'src/lib.rs'"
        );
    }
}
625}