use std::process::Command;
use car_inference::{GenerateParams, GenerateRequest};
use crate::handle::ReasoningInferenceHandle;
use crate::types::*;
use crate::ReasonError;
/// Upper bound, in bytes, on the symbol/file source text that `execute_locate`
/// copies into its output. File headers, import summaries, signatures and doc
/// lines that it also emits are not counted against this budget.
const MAX_CODE_CONTEXT_BYTES: usize = 30_000;
/// Maximum number of AST symbols `execute_locate` describes, summed over all files.
const MAX_CONTEXT_SYMBOLS: usize = 50;
/// Working state threaded through a reasoning run: the problem being solved plus
/// the recorded output of each action executed so far.
///
/// The text fields are substituted into prompt templates by `render_prompt`;
/// `memory_context` and `source_code` are additionally sent to the model as
/// request context by `execute_action`.
#[derive(Debug, Clone)]
pub struct AccumulatedContext {
    /// The problem statement (`{problem}` in prompts).
    pub problem: String,
    /// Category of the problem (`{problem_class}`); also decides which model tier
    /// `execute_action` uses for `Explain`.
    pub problem_class: ProblemClass,
    /// Caller-supplied background text (`{context}`); prepended to the request
    /// context sent to the model.
    pub memory_context: String,
    /// File summary from the latest `Locate` action (`{locations}`).
    pub locations: String,
    /// Patterns for this problem (`{patterns}`); echoed back by `RetrievePatterns`.
    pub patterns: String,
    /// Output of the latest `Diagnose` action (`{diagnosis}`).
    pub diagnosis: String,
    /// Output of the latest `GenerateFix` action (`{code}`).
    pub code: String,
    /// Output of the latest `Explain` action.
    pub explanation: String,
    /// Source excerpts gathered by `Locate`; sent to the model under a
    /// `## Source Code` heading in the request context.
    pub source_code: String,
}

impl AccumulatedContext {
    /// Creates a context for `problem` with no action outputs recorded yet.
    pub fn new(problem: &str, memory_context: &str, problem_class: ProblemClass) -> Self {
        Self {
            problem: problem.to_string(),
            problem_class,
            memory_context: memory_context.to_string(),
            locations: String::new(),
            patterns: String::new(),
            diagnosis: String::new(),
            code: String::new(),
            explanation: String::new(),
            source_code: String::new(),
        }
    }

    /// Records an action's output in the field that corresponds to its kind.
    ///
    /// `Locate` output is split at the `---SOURCE_CODE_START---` marker that
    /// `execute_locate` writes: the file summary before it goes to `locations`
    /// and the source excerpts after it go to `source_code`, so later actions
    /// receive the code as request context rather than inlined into
    /// `{locations}`. `Locate` output without the marker is stored whole in
    /// `locations` and leaves `source_code` untouched.
    ///
    /// `Classify` and `VerifyFix` outputs are not stored.
    pub fn integrate(&mut self, outcome: &ActionOutcome) {
        // Separator `execute_locate` places between its summary and the source.
        const SOURCE_CODE_MARKER: &str = "---SOURCE_CODE_START---";

        let output = &outcome.output;
        match outcome.action {
            ActionKind::Locate => match output.split_once(SOURCE_CODE_MARKER) {
                Some((summary, source)) => {
                    self.locations = summary.trim_end().to_string();
                    // The marker is followed by a newline before the first file.
                    self.source_code = source.strip_prefix('\n').unwrap_or(source).to_string();
                }
                None => self.locations = output.clone(),
            },
            ActionKind::RetrievePatterns => self.patterns = output.clone(),
            ActionKind::Diagnose => self.diagnosis = output.clone(),
            ActionKind::GenerateFix => self.code = output.clone(),
            ActionKind::Explain => self.explanation = output.clone(),
            ActionKind::Classify | ActionKind::VerifyFix => {}
        }
    }
}
pub fn execute_locate_only(problem: &str) -> Result<ActionOutcome, ReasonError> {
let ctx = AccumulatedContext::new(problem, "", crate::types::ProblemClass::BugFix);
execute_locate(&ctx)
}
pub async fn execute_action(
engine: &dyn ReasoningInferenceHandle,
action: ActionKind,
config: &ActionConfig,
ctx: &AccumulatedContext,
) -> Result<ActionOutcome, ReasonError> {
match action {
ActionKind::RetrievePatterns => {
return Ok(ActionOutcome {
action,
model_used: "memgine".into(),
trace_id: String::new(),
latency_ms: 0,
output: ctx.patterns.clone(),
confidence: 1.0,
success: true,
});
}
ActionKind::Locate => {
return execute_locate(ctx);
}
_ => {}
}
let prompt = render_prompt(action, config, ctx);
let full_context = if ctx.source_code.is_empty() {
if ctx.memory_context.is_empty() {
None
} else {
Some(ctx.memory_context.clone())
}
} else {
let mut c = String::new();
if !ctx.memory_context.is_empty() {
c.push_str(&ctx.memory_context);
c.push_str("\n\n");
}
c.push_str("## Source Code\n\n");
c.push_str(&ctx.source_code);
Some(c)
};
let model = match action {
ActionKind::Classify => {
Some("Qwen3-0.6B".to_string())
}
ActionKind::Diagnose | ActionKind::GenerateFix => {
pick_model_for_tier(engine, ModelTier::Frontier).await
}
ActionKind::Explain => {
match ctx.problem_class {
ProblemClass::Architecture | ProblemClass::Explanation => {
pick_model_for_tier(engine, ModelTier::Frontier).await
}
_ => pick_model_for_tier(engine, ModelTier::Fast).await,
}
}
ActionKind::VerifyFix => {
pick_model_for_tier(engine, ModelTier::Fast).await
}
_ => None,
};
let req = GenerateRequest {
prompt,
model,
params: GenerateParams {
temperature: action.temperature(),
max_tokens: action.max_tokens(),
..Default::default()
},
context: full_context,
tools: None,
images: None,
messages: None,
cache_control: false,
response_format: None,
intent: None,
};
let result = engine
.generate_tracked(req)
.await
.map_err(|e| ReasonError::InferenceFailed(e.to_string()))?;
let confidence = assess_confidence(&result.text, action);
Ok(ActionOutcome {
action,
model_used: result.model_used,
trace_id: result.trace_id,
latency_ms: result.latency_ms,
output: result.text,
confidence,
success: confidence > 0.2,
})
}
/// Capability tier requested when choosing a hosted model for an action.
#[derive(Clone, Copy, Debug)]
enum ModelTier {
    /// Strongest available models (used for diagnosis and fix generation).
    Frontier,
    /// Cheaper models (used for verification and routine explanations).
    Fast,
}
/// Returns the first model in `tier`'s preference list that is usable right now.
///
/// A candidate is usable when its provider's API-key environment variable holds a
/// non-blank value and `engine.find_model_by_name` knows the model. Candidates are
/// tried in the listed order; the engine is only queried for candidates whose key
/// is present. Returns `None` when no candidate qualifies.
async fn pick_model_for_tier(
    engine: &dyn ReasoningInferenceHandle,
    tier: ModelTier,
) -> Option<String> {
    let candidates: &[(&str, &str)] = match tier {
        ModelTier::Frontier => &[
            ("claude-opus-4-7", "ANTHROPIC_API_KEY"),
            ("gpt-5.4", "OPENAI_API_KEY"),
            ("o3", "OPENAI_API_KEY"),
            ("claude-sonnet-4-6", "ANTHROPIC_API_KEY"),
            ("gpt-5.3-codex", "OPENAI_API_KEY"),
            ("gemini-2.5-pro", "GOOGLE_API_KEY"),
        ],
        ModelTier::Fast => &[
            ("gpt-5.4-mini", "OPENAI_API_KEY"),
            ("claude-haiku-4-5", "ANTHROPIC_API_KEY"),
            ("gpt-4.1-mini", "OPENAI_API_KEY"),
            ("gemini-2.5-flash", "GOOGLE_API_KEY"),
            ("claude-sonnet-4-6", "ANTHROPIC_API_KEY"),
            ("o4-mini", "OPENAI_API_KEY"),
        ],
    };
    for &(model_name, env_var) in candidates {
        // An exported-but-empty key (`FOO_API_KEY=`) is not a usable credential;
        // skip it so a provider that actually has credentials can be chosen.
        let has_key = std::env::var(env_var).is_ok_and(|key| !key.trim().is_empty());
        if has_key && engine.find_model_by_name(model_name).await.is_some() {
            return Some(model_name.to_string());
        }
    }
    None
}
/// Gathers source context for `ctx.problem` from the current working directory,
/// without calling a model.
///
/// Candidate files come from [`locate_candidate_files`]. Each is read in order:
/// - When `car_ast` can parse it, the symbols that fuzzily match a search term (or,
///   if none match, all parsed symbols) are listed with their signature, first
///   doc-comment line, children and source text.
/// - Otherwise the file's raw text is included.
///
/// Collection stops once [`MAX_CONTEXT_SYMBOLS`] symbols or
/// [`MAX_CODE_CONTEXT_BYTES`] bytes of source text have been emitted. Headers,
/// signatures and doc lines are not counted toward the byte budget. Truncated
/// source is only emitted when roughly 200 or more bytes of budget remain.
///
/// The returned `output` is the location summary, a `---SOURCE_CODE_START---`
/// line, then the collected source. This never returns `Err`: missing tools,
/// unreadable files and an unavailable working directory just yield fewer
/// results. It spawns child processes and reads files synchronously.
fn execute_locate(ctx: &AccumulatedContext) -> Result<ActionOutcome, ReasonError> {
    let start = std::time::Instant::now();
    let cwd = std::env::current_dir().unwrap_or_default();
    let search_terms = extract_search_terms(&ctx.problem);
    let found_files = locate_candidate_files(&cwd, &ctx.problem, &search_terms);

    let mut source_content = String::new();
    let mut total_bytes = 0;
    let mut files_read = 0;
    let mut total_symbols = 0;
    for file_path in &found_files {
        if total_bytes >= MAX_CODE_CONTEXT_BYTES || total_symbols >= MAX_CONTEXT_SYMBOLS {
            break;
        }
        let Ok(content) = std::fs::read_to_string(cwd.join(file_path)) else {
            continue;
        };
        if let Some(parsed) = car_ast::parse_file(&content, file_path) {
            let mut relevant: Vec<&car_ast::Symbol> = Vec::new();
            for term in &search_terms {
                relevant.extend(parsed.find_symbol_fuzzy(term));
            }
            relevant.sort_by(|a, b| a.span.start_byte.cmp(&b.span.start_byte));
            // Each matching term returns the same symbol again; after sorting those
            // repeats are adjacent. Position is part of the key so that distinct
            // symbols that merely share a name and kind (e.g. `new` in two impl
            // blocks) are all kept.
            relevant.dedup_by(|a, b| {
                a.span.start_byte == b.span.start_byte && a.name == b.name && a.kind == b.kind
            });
            if relevant.is_empty() {
                // Nothing matched a term: fall back to every symbol the parser found.
                for sym in &parsed.symbols {
                    relevant.push(sym);
                }
            }
            if relevant.is_empty() {
                continue;
            }
            // Cannot underflow: the loop exits once `total_symbols` reaches the cap.
            relevant.truncate(MAX_CONTEXT_SYMBOLS - total_symbols);
            source_content.push_str(&format!(
                "### {} (AST: {} symbols)\n",
                file_path,
                relevant.len()
            ));
            if !parsed.imports.is_empty() {
                let import_summary: Vec<_> = parsed
                    .imports
                    .iter()
                    .take(5)
                    .map(|i| i.path.as_str())
                    .collect();
                source_content.push_str("Imports: ");
                source_content.push_str(&import_summary.join(", "));
                if parsed.imports.len() > 5 {
                    source_content.push_str(&format!(" (+{} more)", parsed.imports.len() - 5));
                }
                source_content.push('\n');
            }
            for sym in &relevant {
                source_content.push_str(&format!("\n[{:?}] {}\n", sym.kind, sym.signature));
                if let Some(doc) = &sym.doc_comment {
                    source_content
                        .push_str(&format!(" doc: {}\n", doc.lines().next().unwrap_or("")));
                }
                for child in &sym.children {
                    source_content
                        .push_str(&format!(" [{:?}] {}\n", child.kind, child.signature));
                }
                let sym_source = car_ast::extract_source(sym, &content);
                let remaining_bytes = MAX_CODE_CONTEXT_BYTES - total_bytes;
                if sym_source.len() <= remaining_bytes {
                    source_content.push_str("```\n");
                    source_content.push_str(&sym_source);
                    source_content.push_str("\n```\n");
                    total_bytes += sym_source.len();
                } else if remaining_bytes > 200 {
                    let trunc = &sym_source[..sym_source.floor_char_boundary(remaining_bytes)];
                    source_content.push_str("```\n");
                    source_content.push_str(trunc);
                    source_content.push_str("\n// ... (truncated)\n```\n");
                    total_bytes += trunc.len();
                }
                total_symbols += 1;
            }
            source_content.push('\n');
            files_read += 1;
        } else {
            // Not parseable as a supported language: include the raw text.
            let remaining = MAX_CODE_CONTEXT_BYTES - total_bytes;
            if remaining < 200 {
                continue;
            }
            let truncated = if content.len() > remaining {
                &content[..content.floor_char_boundary(remaining)]
            } else {
                content.as_str()
            };
            source_content.push_str(&format!(
                "### {} (raw)\n```\n{}\n```\n\n",
                file_path, truncated
            ));
            total_bytes += truncated.len();
            files_read += 1;
        }
    }

    let latency_ms = start.elapsed().as_millis() as u64;
    let locations_summary = format!(
        "Found {} relevant files ({} read, {} symbols, {} bytes):\n{}",
        found_files.len(),
        files_read,
        total_symbols,
        total_bytes,
        found_files
            .iter()
            .enumerate()
            .map(|(i, f)| format!(" {}. {}", i + 1, f))
            .collect::<Vec<_>>()
            .join("\n"),
    );
    let output = format!(
        "{}\n\n---SOURCE_CODE_START---\n{}",
        locations_summary, source_content
    );
    Ok(ActionOutcome {
        action: ActionKind::Locate,
        model_used: "codebase-search+ast".into(),
        trace_id: String::new(),
        latency_ms,
        output,
        confidence: if files_read > 0 { 0.9 } else { 0.3 },
        success: files_read > 0,
    })
}

/// Lists up to 30 candidate source files for `problem`, as the `./`-relative paths
/// printed by `find`/`grep` when run in `cwd`.
///
/// Ordering:
/// 1. Files whose basename appears (case-insensitively) in `problem`, in `find`
///    output order. They come first so the 30-file cap can never push out a file
///    the problem names explicitly.
/// 2. Files containing any of `search_terms` (case-insensitive `grep -rl`), in
///    term order.
///
/// Only `.rs`, `.py`, `.ts`, `.js` and `.go` files are considered, and each path is
/// listed once. `grep` only runs for as many terms as are needed to fill the slots.
fn locate_candidate_files(
    cwd: &std::path::Path,
    problem: &str,
    search_terms: &[String],
) -> Vec<String> {
    let problem_lower = problem.to_lowercase();
    let named = command_stdout_lines(
        Command::new("find").args([
            ".", "-name", "*.rs", "-o", "-name", "*.py", "-o", "-name", "*.ts", "-o", "-name",
            "*.go", "-o", "-name", "*.js",
        ]),
        cwd,
    )
    .into_iter()
    .filter(|path| {
        let basename = path.rsplit('/').next().unwrap_or(path.as_str());
        problem_lower.contains(&basename.to_lowercase())
    });

    let grep_hits = search_terms.iter().flat_map(|term| {
        command_stdout_lines(
            Command::new("grep").args([
                "-rl",
                "--include=*.rs",
                "--include=*.py",
                "--include=*.ts",
                "--include=*.js",
                "--include=*.go",
                "-i",
                term.as_str(),
                ".",
            ]),
            cwd,
        )
    });

    let mut seen = std::collections::HashSet::new();
    named
        .chain(grep_hits)
        .filter(|path| seen.insert(path.clone()))
        .take(30)
        .collect()
}

/// Runs `cmd` in `cwd` and returns its non-blank stdout lines, trimmed.
///
/// The exit status is deliberately ignored: `grep -r` and `find` exit non-zero when
/// they hit *any* error (such as an unreadable directory) yet still print every
/// match they found. A command that cannot be spawned yields no lines.
fn command_stdout_lines(cmd: &mut Command, cwd: &std::path::Path) -> Vec<String> {
    match cmd.current_dir(cwd).output() {
        Ok(out) => String::from_utf8_lossy(&out.stdout)
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(str::to_owned)
            .collect(),
        Err(_) => Vec::new(),
    }
}
/// Picks up to 10 search terms from a free-text problem description.
///
/// Filtering:
/// - Each whitespace-separated word is reduced to its alphanumeric and `_`
///   characters.
/// - Results shorter than 3 bytes are dropped.
/// - Results matching a small list of filler words (case-insensitively) are
///   dropped.
/// - Exact repeats of an already-taken term are dropped, so one word cannot use up
///   several slots or trigger duplicate searches.
///
/// Ordering: words that look like identifiers (containing `_`, mixing upper and
/// lower case, or longer than 6 bytes) are prioritised. Each one is placed at the
/// front of the list, so later identifiers precede earlier ones. All other words
/// are appended in order of appearance.
fn extract_search_terms(problem: &str) -> Vec<String> {
    // Filler words that are never used as search terms.
    const NOISE: &[&str] = &[
        "the", "and", "for", "with", "this", "that", "from", "what", "how", "would", "should",
        "could", "does", "not", "all", "are", "but", "its", "has", "have", "was", "were",
        "will", "can", "use", "using", "used", "add", "fix", "bug", "error",
    ];

    let mut seen = std::collections::HashSet::new();
    let mut terms = std::collections::VecDeque::new();
    for word in problem.split_whitespace() {
        let clean: String = word
            .chars()
            .filter(|c| c.is_alphanumeric() || *c == '_')
            .collect();
        if clean.len() < 3 || NOISE.contains(&clean.to_lowercase().as_str()) {
            continue;
        }
        if !seen.insert(clean.clone()) {
            continue;
        }
        let has_underscore = clean.contains('_');
        let has_camel =
            clean.chars().any(char::is_uppercase) && clean.chars().any(char::is_lowercase);
        if has_underscore || has_camel || clean.len() > 6 {
            terms.push_front(clean);
        } else {
            terms.push_back(clean);
        }
    }
    terms.into_iter().take(10).collect()
}
/// Fills the action's prompt template with values from `ctx`.
///
/// Uses `config.prompt_template`, or `default_prompt(action)` when that is empty.
///
/// Recognised placeholders:
/// - `{problem}`, `{problem_class}`, `{context}` (the memory context),
///   `{locations}`, `{patterns}`, `{diagnosis}`, `{code}`;
/// - the optional blocks `{diagnosis_section}` and `{fix_section}`, which expand
///   to a labelled paragraph, or to nothing when the corresponding field is empty.
///
/// Any other `{...}` text is left untouched.
///
/// Substitution is a single pass over the template; inserted values are never
/// rescanned. Chained `str::replace` calls would re-expand placeholder-like text
/// inside earlier values. For example, located source containing
/// `format!("{code}")` would have `{code}` swapped for the proposed fix.
fn render_prompt(action: ActionKind, config: &ActionConfig, ctx: &AccumulatedContext) -> String {
    let default_template;
    let template: &str = if config.prompt_template.is_empty() {
        default_template = default_prompt(action);
        &default_template
    } else {
        &config.prompt_template
    };
    let diagnosis_section = if ctx.diagnosis.is_empty() {
        String::new()
    } else {
        format!("Diagnosis: {}\n\n", ctx.diagnosis)
    };
    let fix_section = if ctx.code.is_empty() {
        String::new()
    } else {
        format!("Proposed fix:\n{}\n\n", ctx.code)
    };
    let problem_class = ctx.problem_class.to_string();
    let vars: [(&str, &str); 9] = [
        ("problem", ctx.problem.as_str()),
        ("problem_class", problem_class.as_str()),
        ("context", ctx.memory_context.as_str()),
        ("locations", ctx.locations.as_str()),
        ("patterns", ctx.patterns.as_str()),
        ("diagnosis", ctx.diagnosis.as_str()),
        ("code", ctx.code.as_str()),
        ("diagnosis_section", diagnosis_section.as_str()),
        ("fix_section", fix_section.as_str()),
    ];
    fill_template(template, &vars)
}

/// Replaces every `{name}` in `template` whose `name` is a key in `vars` with the
/// paired value, scanning left to right once.
///
/// Unknown `{...}` sequences and unmatched braces are copied through verbatim, and
/// substituted values are never scanned for further placeholders.
fn fill_template(template: &str, vars: &[(&str, &str)]) -> String {
    let mut out = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(open) = rest.find('{') {
        out.push_str(&rest[..open]);
        let after_brace = &rest[open + 1..];
        let hit = after_brace.find('}').and_then(|close| {
            let name = &after_brace[..close];
            vars.iter()
                .find(|(key, _)| *key == name)
                .map(|&(_, value)| (close, value))
        });
        match hit {
            Some((close, value)) => {
                out.push_str(value);
                rest = &after_brace[close + 1..];
            }
            None => {
                // Not a known placeholder: keep the brace and resume just after it.
                out.push('{');
                rest = after_brace;
            }
        }
    }
    out.push_str(rest);
    out
}
/// Built-in prompt template for `action`, used when its `ActionConfig` supplies none.
///
/// Templates use the placeholders that `render_prompt` fills in (`{problem}`,
/// `{locations}`, `{diagnosis}`, `{code}`, `{diagnosis_section}`, `{fix_section}`).
/// `Locate` and `RetrievePatterns` get an empty template because `execute_action`
/// handles them without a model.
fn default_prompt(action: ActionKind) -> String {
    let template: &'static str = match action {
        ActionKind::Locate | ActionKind::RetrievePatterns => "",
        ActionKind::Classify => "Classify this code problem into one category: bug_fix, refactor, architecture, new_feature, performance, test_writing, or explanation.\n\nProblem: {problem}\n\nCategory:",
        ActionKind::Diagnose => "You are an expert code analyst. Analyze the root cause of this problem using the source code provided in the context.\n\nProblem: {problem}\n\nLocated files:\n{locations}\n\nProvide a precise root cause analysis referencing specific functions, types, and line numbers from the source code.",
        ActionKind::GenerateFix => "You are an expert programmer. Generate a precise code fix for this problem. Reference the actual source code from the context.\n\nProblem: {problem}\n\nDiagnosis: {diagnosis}\n\nGenerate the fix as a code block with the file path. Only change what's necessary.",
        ActionKind::VerifyFix => "Review this fix for correctness. Check for: off-by-one errors, null/None handling, type mismatches, missing edge cases.\n\nProposed fix:\n{code}\n\nVerification (PASS or FAIL with reason):",
        ActionKind::Explain => "Explain clearly and concisely what was wrong and how the fix works.\n\nProblem: {problem}\n{diagnosis_section}{fix_section}Explanation:",
    };
    String::from(template)
}
/// Heuristic confidence (0.0–1.0) in a model's `output` for `action`.
///
/// Blank output scores 0.0. Otherwise each action has its own rule (lengths are
/// in bytes):
/// - `Classify`: 0.9 under 50 bytes, else 0.5.
/// - `Locate`: 0.8 if the text contains `/` or `.`, else 0.4.
/// - `Diagnose`: 0.8 above 200 bytes, 0.6 above 50, else 0.3.
/// - `GenerateFix`: 0.7 if it contains a code fence, `fn ` or `def `, else 0.3.
/// - `VerifyFix`: 0.9 for a pass verdict, 0.3 for a fail verdict, 0.5 when neither
///   is found (see [`verify_verdict`]).
/// - `Explain`: 0.8 above 100 bytes, else 0.5.
/// - `RetrievePatterns`: always 1.0.
fn assess_confidence(output: &str, action: ActionKind) -> f64 {
    if output.trim().is_empty() {
        return 0.0;
    }
    let length = output.len();
    match action {
        ActionKind::Classify => {
            if length < 50 {
                0.9
            } else {
                0.5
            }
        }
        ActionKind::Locate => {
            let has_paths = output.contains('/') || output.contains('.');
            if has_paths {
                0.8
            } else {
                0.4
            }
        }
        ActionKind::Diagnose => {
            if length > 200 {
                0.8
            } else if length > 50 {
                0.6
            } else {
                0.3
            }
        }
        ActionKind::GenerateFix => {
            let has_code =
                output.contains("```") || output.contains("fn ") || output.contains("def ");
            if has_code {
                0.7
            } else {
                0.3
            }
        }
        ActionKind::VerifyFix => match verify_verdict(output) {
            Some(true) => 0.9,
            Some(false) => 0.3,
            None => 0.5,
        },
        ActionKind::Explain => {
            if length > 100 {
                0.8
            } else {
                0.5
            }
        }
        ActionKind::RetrievePatterns => 1.0,
    }
}

/// Reads a PASS/FAIL verdict from a verification response.
///
/// Only whole words count (letters, digits and `_` form a word), so `BYPASS`,
/// `PASSWORD` or `COMPASS` are not mistaken for a pass.
///
/// An all-caps verdict word (`PASS`, `FAILED`, ...) takes precedence, because the
/// verification prompt asks for exactly that form. Incidental prose such as "this
/// would not pass" is typically not all-caps. Failing that, the first verdict word
/// in any case wins.
///
/// Returns `Some(true)` for pass, `Some(false)` for fail, and `None` when neither
/// appears.
fn verify_verdict(output: &str) -> Option<bool> {
    // Maps an upper-case word to the verdict it expresses, if any.
    fn verdict_of(word: &str) -> Option<bool> {
        match word {
            "PASS" | "PASSED" | "PASSES" | "PASSING" => Some(true),
            "FAIL" | "FAILED" | "FAILS" | "FAILING" | "FAILURE" | "FAILURES" => Some(false),
            _ => None,
        }
    }
    let words = || {
        output
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .filter(|word| !word.is_empty())
    };
    words()
        .find_map(verdict_of)
        .or_else(|| words().find_map(|word| verdict_of(&word.to_uppercase())))
}