use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
#[derive(Debug, Clone)]
pub struct ContextOptions {
pub lines_before: usize,
pub lines_after: usize,
pub max_context_chars: usize,
pub include_scope: bool,
pub include_imports: bool,
}
impl Default for ContextOptions {
fn default() -> Self {
Self {
lines_before: 10,
lines_after: 10,
max_context_chars: 4000,
include_scope: true,
include_imports: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeContext {
pub file_path: String,
pub language: String,
pub issue_lines: String,
pub line_number: u32,
pub before: String,
pub after: String,
pub scope: Option<String>,
pub imports: Option<String>,
pub full_snippet: String,
}
impl CodeContext {
pub fn new(file_path: &str, language: &str, line_number: u32) -> Self {
Self {
file_path: file_path.to_string(),
language: language.to_string(),
issue_lines: String::new(),
line_number,
before: String::new(),
after: String::new(),
scope: None,
imports: None,
full_snippet: String::new(),
}
}
pub fn char_count(&self) -> usize {
self.full_snippet.len()
}
}
pub fn extract_context(
path: &Path,
line_number: u32,
options: &ContextOptions,
) -> Result<CodeContext, String> {
let content = fs::read_to_string(path).map_err(|e| format!("Failed to read file: {}", e))?;
let lines: Vec<&str> = content.lines().collect();
let total_lines = lines.len();
if line_number == 0 || line_number as usize > total_lines {
return Err(format!(
"Line number {} out of range (file has {} lines)",
line_number, total_lines
));
}
let idx = line_number as usize - 1;
let language = detect_language(path);
let mut context = CodeContext::new(&path.to_string_lossy(), &language, line_number);
context.issue_lines = lines[idx].to_string();
let start = idx.saturating_sub(options.lines_before);
if start < idx {
context.before = lines[start..idx].join("\n");
}
let end = (idx + 1 + options.lines_after).min(total_lines);
if idx + 1 < end {
context.after = lines[idx + 1..end].join("\n");
}
context.full_snippet = format!(
"{}\n>>> {} <<<\n{}",
context.before, context.issue_lines, context.after
);
if options.include_imports {
context.imports = Some(extract_imports(&lines, &language));
}
if options.include_scope {
context.scope = extract_scope(&lines, idx, &language);
}
if context.char_count() > options.max_context_chars {
context.full_snippet = truncate_context(&context.full_snippet, options.max_context_chars);
}
Ok(context)
}
fn detect_language(path: &Path) -> String {
path.extension()
.and_then(|e| e.to_str())
.map(|ext| match ext {
"rs" => "rust",
"py" | "pyw" => "python",
"js" | "jsx" | "mjs" => "javascript",
"ts" | "tsx" | "mts" => "typescript",
"go" => "go",
"java" => "java",
"cpp" | "cc" | "cxx" | "c" | "h" | "hpp" => "cpp",
"swift" => "swift",
"kt" | "kts" => "kotlin",
"rb" => "ruby",
"php" => "php",
"sh" | "bash" => "shell",
_ => ext,
})
.unwrap_or("unknown")
.to_string()
}
fn extract_imports(lines: &[&str], language: &str) -> String {
let import_patterns: Vec<&str> = match language {
"rust" => vec!["use ", "extern crate ", "mod "],
"python" => vec!["import ", "from "],
"javascript" | "typescript" => vec!["import ", "require(", "export "],
"go" => vec!["import "],
"java" => vec!["import ", "package "],
"cpp" => vec!["#include ", "#pragma ", "using "],
"swift" => vec!["import "],
"kotlin" => vec!["import ", "package "],
"ruby" => vec!["require ", "require_relative ", "load "],
"php" => vec!["use ", "require ", "include ", "namespace "],
_ => vec![],
};
let mut imports = Vec::new();
for line in lines {
let trimmed = line.trim();
if import_patterns.iter().any(|p| trimmed.starts_with(p)) {
imports.push(*line);
} else if !trimmed.is_empty()
&& !trimmed.starts_with("//")
&& !trimmed.starts_with("#")
&& !trimmed.starts_with("/*")
&& !trimmed.starts_with("*")
{
break;
}
}
imports.join("\n")
}
fn extract_scope(lines: &[&str], issue_idx: usize, language: &str) -> Option<String> {
let start = find_scope_start(lines, issue_idx, language)?;
let end = find_scope_end(lines, start, language);
if start < lines.len() && end <= lines.len() && start < end {
Some(lines[start..end].join("\n"))
} else {
None
}
}
fn find_scope_start(lines: &[&str], idx: usize, language: &str) -> Option<usize> {
let function_patterns: Vec<&str> = match language {
"rust" => vec!["fn ", "pub fn ", "async fn ", "pub async fn "],
"python" => vec!["def ", "async def ", "class "],
"javascript" | "typescript" => vec!["function ", "const ", "let ", "class ", "async "],
"go" => vec!["func "],
"java" => vec!["public ", "private ", "protected ", "void ", "class "],
"cpp" => vec!["void ", "int ", "bool ", "class ", "struct "],
"swift" => vec!["func ", "class ", "struct "],
"kotlin" => vec!["fun ", "class ", "object "],
_ => vec![],
};
for i in (0..=idx).rev() {
let trimmed = lines[i].trim();
if function_patterns
.iter()
.any(|p| trimmed.starts_with(p) || trimmed.contains(p))
{
return Some(i);
}
}
None
}
fn find_scope_end(lines: &[&str], start: usize, language: &str) -> usize {
if language == "python" {
let base_indent = lines[start].len() - lines[start].trim_start().len();
for (i, line) in lines.iter().enumerate().skip(start + 1) {
if line.trim().is_empty() {
continue;
}
let indent = line.len() - line.trim_start().len();
if indent <= base_indent {
return i;
}
}
return lines.len();
}
let mut brace_count = 0;
let mut found_open = false;
for (i, line) in lines.iter().enumerate().skip(start) {
for ch in line.chars() {
if ch == '{' {
brace_count += 1;
found_open = true;
} else if ch == '}' {
brace_count -= 1;
if found_open && brace_count == 0 {
return i + 1;
}
}
}
}
(start + 50).min(lines.len())
}
fn truncate_context(content: &str, max_chars: usize) -> String {
if content.len() <= max_chars {
return content.to_string();
}
if let Some(marker_pos) = content.find(">>>") {
let start = marker_pos.saturating_sub(max_chars / 2);
let end = (marker_pos + max_chars / 2).min(content.len());
let mut truncated = String::new();
if start > 0 {
truncated.push_str("...\n");
}
truncated.push_str(&content[start..end]);
if end < content.len() {
truncated.push_str("\n...");
}
truncated
} else {
content[..max_chars].to_string() + "..."
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_context_options_default() {
let options = ContextOptions::default();
assert_eq!(options.lines_before, 10);
assert_eq!(options.lines_after, 10);
}
#[test]
fn test_detect_language() {
assert_eq!(detect_language(Path::new("test.rs")), "rust");
assert_eq!(detect_language(Path::new("test.py")), "python");
assert_eq!(detect_language(Path::new("test.ts")), "typescript");
assert_eq!(detect_language(Path::new("test.go")), "go");
}
#[test]
fn test_extract_context() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "fn main() {{").unwrap();
writeln!(file, " let x = 1;").unwrap();
writeln!(file, " let y = 2;").unwrap();
writeln!(file, " println!(\"{{}}\", x + y);").unwrap();
writeln!(file, "}}").unwrap();
let options = ContextOptions::default();
let context = extract_context(file.path(), 3, &options).unwrap();
assert_eq!(context.line_number, 3);
assert!(context.issue_lines.contains("let y = 2"));
}
#[test]
fn test_extract_imports() {
let lines = vec!["use std::io;", "use std::path::Path;", "", "fn main() {"];
let imports = extract_imports(&lines, "rust");
assert!(imports.contains("use std::io"));
assert!(imports.contains("use std::path::Path"));
}
}