use super::{Symbol, SymbolSearchResult};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
/// Source text captured around a symbol's declaration line.
///
/// Produced by [`SymbolContext::extract`]. Every field holds owned copies of file
/// lines with their line terminators removed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SymbolContext {
/// Up to `context_lines` lines directly above the symbol line.
pub before: Vec<String>,
/// The symbol's own line (empty when the file has no lines).
pub symbol_line: String,
/// Up to `context_lines` lines directly below the symbol line.
pub after: Vec<String>,
/// Contiguous comment-like lines (`//`, `#`, `/*`, `*`) immediately above the symbol,
/// searched only within `before`.
pub documentation: Vec<String>,
/// Contiguous lines starting with `@` or `#` immediately above the symbol,
/// searched only within `before`.
pub annotations: Vec<String>,
}
impl SymbolContext {
    /// Extracts the lines surrounding a symbol's declaration.
    ///
    /// `line` is **1-based**. A `line` of 0 is treated as 1, and a `line` past the end of
    /// the file is clamped to the last line, so an out-of-range position never fails.
    /// Up to `context_lines` lines are captured on each side of the symbol line.
    ///
    /// `documentation` is the contiguous run of comment-like lines (trimmed text starts with
    /// `//`, `#`, `/*` or `*`) directly above the symbol. `annotations` is the contiguous run
    /// of lines starting with `@` or `#` directly above it. Both are searched only within
    /// `before`, and both are returned in source (top-to-bottom) order.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read as UTF-8 text.
    pub fn extract(
        file_path: &Path,
        line: usize,
        context_lines: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let content = fs::read_to_string(file_path)?;
        let lines: Vec<&str> = content.lines().collect();

        // 0-based index of the symbol line, clamped into the file (0 for an empty file).
        let line_idx = line.saturating_sub(1).min(lines.len().saturating_sub(1));
        let start = line_idx.saturating_sub(context_lines);
        // Saturating so that a huge `context_lines` cannot overflow.
        let end = line_idx
            .saturating_add(context_lines)
            .saturating_add(1)
            .min(lines.len());

        let before: Vec<String> = lines[start..line_idx].iter().map(|s| s.to_string()).collect();
        let symbol_line = lines.get(line_idx).map(|s| s.to_string()).unwrap_or_default();
        // `get` yields `None` for an empty file, where `line_idx + 1 > end`.
        let after: Vec<String> = lines
            .get(line_idx + 1..end)
            .unwrap_or_default()
            .iter()
            .map(|s| s.to_string())
            .collect();

        let documentation = Self::trailing_run(&before, |t| {
            t.starts_with("//") || t.starts_with('#') || t.starts_with("/*") || t.starts_with('*')
        });
        let annotations =
            Self::trailing_run(&before, |t| t.starts_with('@') || t.starts_with('#'));

        Ok(Self {
            before,
            symbol_line,
            after,
            documentation,
            annotations,
        })
    }

    /// Returns the longest run of lines at the end of `lines` whose trimmed text satisfies
    /// `pred`, in original (top-to-bottom) order.
    fn trailing_run(lines: &[String], pred: impl Fn(&str) -> bool) -> Vec<String> {
        let run_len = lines.iter().rev().take_while(|l| pred(l.trim())).count();
        lines[lines.len() - run_len..].to_vec()
    }

    /// Collects the lines of a brace-delimited body, starting the scan at `start_line`.
    ///
    /// Unlike [`SymbolContext::extract`], `start_line` is **0-based**. Lines before the first
    /// line containing `{` are skipped. Collection then continues until the `{`/`}` counts
    /// balance out, or until the file ends. Braces are counted textually, so braces inside
    /// string literals or comments are counted too.
    ///
    /// Returns an empty vector if `start_line` is past the end of the file.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read as UTF-8 text.
    pub fn extract_function_body(
        file_path: &Path,
        start_line: usize,
    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
        let content = fs::read_to_string(file_path)?;
        let mut body = Vec::new();
        let mut depth: usize = 0;
        let mut in_function = false;

        for line in content.lines().skip(start_line) {
            let opens = line.matches('{').count();
            let closes = line.matches('}').count();
            if opens > 0 {
                in_function = true;
                depth += opens;
            }
            if !in_function {
                // A stray `}` before the body opens (e.g. the end of a previous item) must be
                // ignored; subtracting it from a zero depth would underflow.
                continue;
            }
            body.push(line.to_string());
            // Saturating: one line may close more braces than this body opened, e.g. a `}}`
            // that ends both a method and its enclosing block.
            depth = depth.saturating_sub(closes);
            if depth == 0 {
                break;
            }
        }
        Ok(body)
    }

    /// Returns every line whose trimmed text starts with a common import keyword
    /// (`import `, `use `, `from `, `require(`, `#include `), trimmed, in file order.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read as UTF-8 text.
    pub fn extract_imports(file_path: &Path) -> Result<Vec<String>, Box<dyn std::error::Error>> {
        const PREFIXES: [&str; 5] = ["import ", "use ", "from ", "require(", "#include "];
        let content = fs::read_to_string(file_path)?;
        Ok(content
            .lines()
            .map(str::trim)
            .filter(|t| PREFIXES.iter().any(|&p| t.starts_with(p)))
            .map(String::from)
            .collect())
    }

    /// Scans the file for `module <name>` and `package <name>` declaration lines.
    ///
    /// The name is the text after the keyword, cut at the first `;` or `//` and then
    /// trimmed. This handles Java's `package a.b;` and Go's
    /// `package math // import "..."`. When several declarations of the same kind exist,
    /// the last non-empty one wins.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read as UTF-8 text.
    pub fn extract_module_info(file_path: &Path) -> Result<ModuleInfo, Box<dyn std::error::Error>> {
        let content = fs::read_to_string(file_path)?;
        let mut module_name = None;
        let mut package_name = None;
        for line in content.lines() {
            let trimmed = line.trim();
            if let Some(name) = Self::declared_name(trimmed, "module ") {
                module_name = Some(name);
            }
            if let Some(name) = Self::declared_name(trimmed, "package ") {
                package_name = Some(name);
            }
        }
        Ok(ModuleInfo {
            path: file_path.to_string_lossy().to_string(),
            module_name,
            package_name,
        })
    }

    /// If `line` starts with `keyword`, returns the declared name. The name is the rest of
    /// the line, cut at the first `;` or `//` and trimmed. Returns `None` when the keyword
    /// is absent or the name is empty.
    fn declared_name(line: &str, keyword: &str) -> Option<String> {
        let rest = line.strip_prefix(keyword)?;
        let rest = rest.split(';').next().unwrap_or(rest);
        let rest = rest.split("//").next().unwrap_or(rest);
        let name = rest.trim();
        (!name.is_empty()).then(|| name.to_string())
    }
}
/// Module/package declarations found in a file by [`SymbolContext::extract_module_info`].
#[derive(Debug, Clone)]
pub struct ModuleInfo {
/// The inspected file's path, converted to a `String` lossily (invalid UTF-8 is replaced).
pub path: String,
/// Name taken from a line beginning with `module ` (e.g. `module foo;`), if any.
/// When several such lines exist, the last one wins.
pub module_name: Option<String>,
/// Name taken from a line beginning with `package ` (e.g. `package com.example;`), if any.
/// When several such lines exist, the last one wins.
pub package_name: Option<String>,
}
/// Builds a [`SymbolSearchResult`] for `symbol` with the given `score`.
///
/// Up to `context_lines` lines of surrounding source are attached on each side of
/// the symbol's line. Context extraction is best-effort: if the symbol's file cannot
/// be read, the result is still produced, with empty `context_before` and
/// `context_after`. `related_symbols` is always empty here.
pub fn create_search_result(
    symbol: &Symbol,
    score: f64,
    context_lines: usize,
) -> SymbolSearchResult {
    let extracted =
        SymbolContext::extract(Path::new(&symbol.file_path), symbol.line, context_lines);
    let (context_before, context_after) = match extracted {
        Ok(ctx) => (ctx.before, ctx.after),
        // Deliberately swallow the error: a missing file should not drop the hit.
        Err(_) => (Vec::new(), Vec::new()),
    };
    SymbolSearchResult {
        symbol: symbol.clone(),
        score,
        context_before,
        context_after,
        related_symbols: Vec::new(),
    }
}
/// Reads the declaration (signature) of a function whose first line is `line`.
///
/// `line` is **0-based**, matching [`SymbolContext::extract_function_body`]. Up to ten
/// continuation lines are appended verbatim (with no separator) until a top-level `{`
/// or `;` appears. The signature is cut at that terminator and trimmed. A `{` or `;`
/// nested inside `(...)` or `[...]` is ignored, so array types such as `[u8; 32]` and
/// destructured parameters are not mistaken for the end of the declaration.
///
/// Returns an empty string when `line` is past the end of the file.
///
/// # Errors
/// Returns an error if the file cannot be read as UTF-8 text.
pub fn get_function_signature(
    file_path: &Path,
    line: usize,
) -> Result<String, Box<dyn std::error::Error>> {
    /// Byte offset of the first `{` or `;` that is not nested inside `(...)` or `[...]`.
    fn signature_end(text: &str) -> Option<usize> {
        let mut depth = 0usize;
        for (idx, ch) in text.char_indices() {
            match ch {
                '(' | '[' => depth += 1,
                ')' | ']' => depth = depth.saturating_sub(1),
                '{' | ';' if depth == 0 => return Some(idx),
                _ => {}
            }
        }
        None
    }

    let content = fs::read_to_string(file_path)?;
    let lines: Vec<&str> = content.lines().collect();
    if line >= lines.len() {
        return Ok(String::new());
    }

    let mut signature = lines[line].to_string();
    // Join at most ten continuation lines until the declaration terminates.
    for next in lines.iter().skip(line + 1).take(10) {
        if signature_end(&signature).is_some() {
            break;
        }
        signature.push_str(next);
    }
    if let Some(end) = signature_end(&signature) {
        signature.truncate(end);
    }
    Ok(signature.trim().to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;
    use tempfile::{tempdir, TempDir};

    /// Creates `test.rs` in a fresh temporary directory and writes each entry of
    /// `lines` to it, each followed by a newline.
    ///
    /// The returned `TempDir` must stay alive while the file is in use; dropping it
    /// deletes the directory.
    fn write_fixture(lines: &[&str]) -> (TempDir, PathBuf) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.rs");
        let mut file = std::fs::File::create(&path).unwrap();
        for text in lines {
            writeln!(file, "{}", text).unwrap();
        }
        (dir, path)
    }

    #[test]
    fn test_context_extraction() {
        let (_dir, path) = write_fixture(&[
            "// This is documentation",
            "fn test_function() {",
            " let x = 42;",
            "}",
        ]);
        let ctx = SymbolContext::extract(&path, 2, 2).unwrap();
        assert!(!ctx.before.is_empty());
        assert_eq!(ctx.symbol_line, "fn test_function() {");
        assert!(!ctx.after.is_empty());
        assert_eq!(ctx.documentation.len(), 1);
    }

    #[test]
    fn test_function_body_extraction() {
        let (_dir, path) = write_fixture(&[
            "fn test() {",
            " let x = 1;",
            " let y = 2;",
            " x + y",
            "}",
        ]);
        let body = SymbolContext::extract_function_body(&path, 0).unwrap();
        assert_eq!(body.len(), 5);
        assert!(body[0].contains("fn test()"));
        assert!(body[4].contains("}"));
    }
}