use anyhow::{Context, Result};
use tree_sitter::{Node, Parser};
use crate::scanner::Finding;
/// Finds Django function-based views in `source`.
///
/// All matching logic lives in `crate::ast::python::detect_django_fbv`; this wrapper only
/// forwards its arguments. `file` is passed through unchanged (presumably used to label the
/// returned findings — confirm in `detect_django_fbv`).
///
/// # Errors
/// Propagates any error returned by `detect_django_fbv`.
pub fn detect(file: &str, source: &[u8], parser: &mut Parser) -> Result<Vec<Finding>> {
    use crate::ast::python::detect_django_fbv;
    detect_django_fbv(file, source, parser)
}
pub fn apply(path: &str, parser: &mut Parser) -> Result<ApplyResult> {
let source = std::fs::read(path).with_context(|| format!("cannot read {}", path))?;
let tree = parser
.parse(&source, None)
.ok_or_else(|| anyhow::anyhow!("tree-sitter failed to parse {}", path))?;
let mut rewrites: Vec<Rewrite> = Vec::new();
let mut warnings: Vec<String> = Vec::new();
let root = tree.root_node();
let mut cursor = root.walk();
for child in root.children(&mut cursor) {
match child.kind() {
"function_definition" => {
if let Some(rw) = plan_rewrite(child, &source, &mut warnings) {
rewrites.push(rw);
}
}
"decorated_definition" => {
let decorators = collect_decorators(child, &source);
let mut ic = child.walk();
for grandchild in child.children(&mut ic) {
if grandchild.kind() == "function_definition" {
if let Some(mut rw) = plan_rewrite(grandchild, &source, &mut warnings) {
rw.byte_start = child.start_byte();
rw.decorators = decorators.clone();
rewrites.push(rw);
}
break;
}
}
}
_ => {}
}
}
if rewrites.is_empty() {
return Ok(ApplyResult {
rewrites_applied: Vec::new(),
warnings,
backup_path: None,
});
}
let backup_path = format!("{}.bak", path);
std::fs::write(&backup_path, &source)
.with_context(|| format!("cannot write backup {}", backup_path))?;
rewrites.sort_by_key(|r| std::cmp::Reverse(r.byte_start));
let mut result = source.clone();
let mut applied = Vec::new();
for rw in &rewrites {
let new_bytes = rw.replacement.as_bytes();
result.splice(rw.byte_start..rw.byte_end, new_bytes.iter().copied());
applied.push(rw.description.clone());
}
std::fs::write(path, &result).with_context(|| format!("cannot write {}", path))?;
Ok(ApplyResult {
rewrites_applied: applied,
warnings,
backup_path: Some(backup_path),
})
}
/// One planned in-place edit of the source file.
struct Rewrite {
    /// Human-readable summary, reported back through `ApplyResult::rewrites_applied`.
    description: String,
    /// Start of the replaced byte range (inclusive).
    byte_start: usize,
    /// End of the replaced byte range (exclusive; used as `byte_start..byte_end`).
    byte_end: usize,
    /// Text spliced in place of `byte_start..byte_end`.
    replacement: String,
    /// Decorator source text associated with the rewritten function, if any.
    /// NOTE(review): no code in this module reads this field — consider removing it.
    decorators: Vec<String>,
}
/// Outcome of running [`apply`] on a single file.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ApplyResult {
    /// One human-readable description per function that was rewritten.
    pub rewrites_applied: Vec<String>,
    /// Human-readable notes about views that were not converted automatically.
    pub warnings: Vec<String>,
    /// Path of the `.bak` copy of the original file; `None` when nothing was rewritten.
    pub backup_path: Option<String>,
}
/// Plans the FBV → CBV rewrite of a single `function_definition` node.
///
/// Returns `None`, leaving the function untouched, in any of these cases:
/// - `is_fbv` rejects the function;
/// - a required child node (name, parameters, body) is missing;
/// - the function's source is not valid UTF-8;
/// - the body mentions `request.method`.
///
/// The last two cases also push a warning. A view that branches on the HTTP method needs a
/// hand-written `get`/`post`/… split.
///
/// The replacement covers `func.start_byte()..func.end_byte()`. It begins at the `def` (or
/// `async`) token itself, so it must not repeat the indentation that already precedes that
/// token on its line. The method and its body are indented relative to the function's own
/// indentation.
fn plan_rewrite(func: Node, source: &[u8], warnings: &mut Vec<String>) -> Option<Rewrite> {
    if !is_fbv(func, source) {
        return None;
    }
    // `node_text` maps invalid UTF-8 to "", which would silently delete the body; refuse instead.
    if std::str::from_utf8(&source[func.byte_range()]).is_err() {
        warnings.push(format!(
            "function at byte {}: source is not valid UTF-8 — skipped",
            func.start_byte()
        ));
        return None;
    }
    let name_node = func.child_by_field_name("name")?;
    let func_name = node_text(name_node, source);
    let params_node = func.child_by_field_name("parameters")?;
    let body_node = func.child_by_field_name("body")?;
    let body_text = node_text(body_node, source);
    if body_text.contains("request.method") {
        warnings.push(format!(
            "`{}`: multi-method dispatch detected — manual CBV conversion required",
            func_name
        ));
        return None;
    }
    let class_name = to_pascal_case(func_name) + "View";
    let extra_params = collect_extra_params(params_node, source);
    let indent = leading_spaces(func, source);
    // Hand the re-indenter the body *with* its original first-line indentation so it can
    // preserve the relative indentation of nested blocks.
    let body_lines = reindent_body(body_with_leading_indent(body_node, source), &indent);
    let get_params = if extra_params.is_empty() {
        "self, request".to_string()
    } else {
        format!("self, request, {}", extra_params.join(", "))
    };
    // An `async def` view must stay a coroutine; otherwise any `await` in the body is a
    // SyntaxError.
    let async_prefix = if func.child(0).map_or(false, |c| c.kind() == "async") {
        "async "
    } else {
        ""
    };
    // Keep a `-> HttpResponse`-style annotation instead of silently dropping it.
    let return_annotation = func
        .child_by_field_name("return_type")
        .map(|t| format!(" -> {}", node_text(t, source)))
        .unwrap_or_default();
    let replacement = format!(
        "class {class_name}(View):\n{indent}    {async_prefix}def get({get_params}){ret}:\n{body_lines}",
        class_name = class_name,
        indent = indent,
        async_prefix = async_prefix,
        get_params = get_params,
        ret = return_annotation,
        body_lines = body_lines,
    );
    Some(Rewrite {
        byte_start: func.start_byte(),
        byte_end: func.end_byte(),
        replacement,
        description: format!(
            "`def {}()` → `class {}(View)` at byte {}",
            func_name,
            class_name,
            func.start_byte()
        ),
        decorators: Vec::new(),
    })
}

/// Returns the body's source text, including the spaces/tabs before its first line when the
/// body starts on its own line, so the original body column can be recovered downstream.
fn body_with_leading_indent<'a>(body: Node, source: &'a [u8]) -> &'a str {
    let start = body.start_byte();
    let line_start = source[..start]
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |p| p + 1);
    let only_indent = source[line_start..start]
        .iter()
        .all(|&b| b == b' ' || b == b'\t');
    let from = if only_indent { line_start } else { start };
    std::str::from_utf8(&source[from..body.end_byte()]).unwrap_or("")
}
/// Returns `true` when `func`'s first parameter is a bare identifier named `request`.
///
/// To find the "first" parameter, only plain identifiers, typed parameters and `*`/`**`
/// splats are considered; other parameter kinds (e.g. `x=1` defaults or a bare `*`) are
/// skipped during the search. A typed `request: HttpRequest`, a leading `*args`, or a missing
/// parameter list therefore all yield `false`.
fn is_fbv(func: Node, source: &[u8]) -> bool {
    func.child_by_field_name("parameters").map_or(false, |params| {
        let mut walker = params.walk();
        let leading = params.children(&mut walker).find(|candidate| {
            matches!(
                candidate.kind(),
                "identifier" | "typed_parameter" | "list_splat_pattern" | "dictionary_splat_pattern"
            )
        });
        matches!(
            leading,
            Some(first) if first.kind() == "identifier" && node_text(first, source) == "request"
        )
    })
}
/// Returns the source text of every parameter after the leading `request` parameter, in
/// order, ready to be re-emitted as `get(self, request, ...)`.
///
/// Defaults, annotations, `*args`, `**kwargs` and the bare `*` / `/` separators are all kept.
/// Django's `View.dispatch` calls the handler as `handler(request, *args, **kwargs)`, so
/// dropping any of them would make the generated `get` raise `TypeError` for URL captures the
/// function view accepted. Comments inside the parameter list are skipped.
fn collect_extra_params(params: Node, source: &[u8]) -> Vec<String> {
    let mut result = Vec::new();
    let mut cursor = params.walk();
    let mut skipped_request = false;
    for child in params.named_children(&mut cursor) {
        let kind = child.kind();
        let is_named_param = matches!(
            kind,
            "identifier" | "typed_parameter" | "default_parameter" | "typed_default_parameter"
        );
        if !skipped_request && is_named_param {
            // The first named parameter is `request` (guaranteed by `is_fbv`); `get` already
            // declares it.
            skipped_request = true;
            continue;
        }
        let is_other_param = matches!(
            kind,
            "list_splat_pattern"
                | "dictionary_splat_pattern"
                | "keyword_separator"
                | "positional_separator"
                | "tuple_pattern"
        );
        if is_named_param || is_other_param {
            result.push(node_text(child, source).to_string());
        }
    }
    // A bare `*` with nothing after it (only possible when `request` itself was keyword-only)
    // would be a SyntaxError in the generated signature.
    if result.last().map(String::as_str) == Some("*") {
        result.pop();
    }
    result
}
/// Returns the source text of each `decorator` child of a `decorated_definition`, in source
/// order (each entry is the decorator node's full text).
fn collect_decorators(decorated: Node, source: &[u8]) -> Vec<String> {
    let mut walker = decorated.walk();
    let decorators: Vec<String> = decorated
        .children(&mut walker)
        .filter(|child| child.kind() == "decorator")
        .map(|child| node_text(child, source).to_owned())
        .collect();
    decorators
}
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// Each `_`-separated word has its first character upper-cased (Unicode-aware, so one char
/// may expand to several) and the rest kept as-is. Empty segments from leading, trailing or
/// repeated underscores are dropped.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_').filter(|w| !w.is_empty()) {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
/// Returns the indentation (spaces, tabs, form feeds) at the start of the line on which
/// `node` begins.
///
/// Works on raw bytes. The previous version decoded the whole file with
/// `from_utf8(..).unwrap_or("")` and then sliced that string at the node's offset, which
/// panicked for any file containing a single invalid UTF-8 byte anywhere.
fn leading_spaces(node: Node, source: &[u8]) -> String {
    let start = node.start_byte().min(source.len());
    let line_start = source[..start]
        .iter()
        .rposition(|&b| b == b'\n')
        .map_or(0, |p| p + 1);
    source[line_start..]
        .iter()
        .take_while(|&&b| b == b' ' || b == b'\t' || b == b'\x0c')
        .map(|&b| char::from(b))
        .collect()
}
/// Re-indents a function body so it sits inside `class X(View): def get(...):`, i.e. two
/// levels (8 spaces) deeper than `base_indent`, the indentation of the original `def`.
///
/// Relative indentation is preserved, so nested `if`/`for`/`with` blocks stay nested. The
/// previous version stripped all leading whitespace from every line and so flattened them.
/// The amount removed from each line (never more than the ASCII spaces/tabs it has) is:
/// - if the first line is indented, its indentation (callers that pass the body together with
///   its original first-line indentation get an exact result);
/// - otherwise, the smallest indentation among the remaining non-blank lines (the first line
///   is assumed to be a tree-sitter block text whose own indentation was cut off).
///
/// Whitespace-only lines become empty. Lines inside multi-line string literals are shifted
/// like any other line.
fn reindent_body(body_text: &str, base_indent: &str) -> String {
    /// Number of leading ASCII spaces/tabs (byte count, always on a char boundary).
    fn indent_width(line: &str) -> usize {
        line.bytes().take_while(|&b| b == b' ' || b == b'\t').count()
    }

    // One level for the class body plus one for the `get` method body.
    let target_indent = format!("{}        ", base_indent);
    let first_width = body_text.lines().next().map_or(0, indent_width);
    let dedent = if first_width > 0 {
        first_width
    } else {
        body_text
            .lines()
            .skip(1)
            .filter(|line| !line.trim().is_empty())
            .map(indent_width)
            .min()
            .unwrap_or(0)
    };
    body_text
        .lines()
        .map(|line| {
            if line.trim().is_empty() {
                String::new()
            } else {
                let strip = indent_width(line).min(dedent);
                format!("{}{}", target_indent, &line[strip..])
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
}
/// Borrows the source text spanned by `node`.
///
/// Returns `""` when that byte span is not valid UTF-8 (callers cannot distinguish this from
/// a genuinely empty node).
fn node_text<'a>(node: Node, source: &'a [u8]) -> &'a str {
    let span = &source[node.start_byte()..node.end_byte()];
    match std::str::from_utf8(span) {
        Ok(text) => text,
        Err(_) => "",
    }
}
// Unit tests for FBV detection (`detect`) and in-place FBV → CBV rewriting (`apply`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::python::make_parser;
    use std::io::Write;

    /// Builds a fresh parser for each test via `crate::ast::python::make_parser`.
    fn parser() -> Parser {
        make_parser().unwrap()
    }

    // A function whose first parameter is `request` yields one `django-fbv` finding on line 1.
    #[test]
    fn detect_simple_fbv() {
        let src = b"def home(request):\n return HttpResponse('hello')\n";
        let findings = detect("views.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "django-fbv");
        assert_eq!(findings[0].line, 1);
    }

    // Extra URL parameters after `request` still count as an FBV.
    #[test]
    fn detect_fbv_with_url_params() {
        let src = b"def detail(request, pk):\n return HttpResponse(pk)\n";
        let findings = detect("views.py", src, &mut parser()).unwrap();
        assert_eq!(findings.len(), 1);
    }

    // A `get(self, request)` method inside a class is already a CBV handler.
    #[test]
    fn detect_skips_cbv_method() {
        let src = b"class MyView(View):\n def get(self, request):\n pass\n";
        let findings = detect("views.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty(), "CBV method must not be detected");
    }

    // Functions without a `request` parameter are not views.
    #[test]
    fn detect_skips_non_view_function() {
        let src = b"def helper(x, y):\n return x + y\n";
        let findings = detect("views.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty());
    }

    // A commented-out `def home(request):` must not match; `real(x: int)` has no `request`.
    #[test]
    fn detect_skips_comment_lookalike() {
        let src = b"# def home(request):\ndef real(x: int):\n return x\n";
        let findings = detect("views.py", src, &mut parser()).unwrap();
        assert!(findings.is_empty(), "comment must not be detected as FBV");
    }

    /// Writes `content` to a new temporary file and returns its handle. tempfile deletes the
    /// file when the handle is dropped, so callers keep it alive for the whole test.
    fn write_tmp(content: &str) -> tempfile::NamedTempFile {
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(content.as_bytes()).unwrap();
        f
    }

    // A single FBV becomes `class HomeView(View)` with a `get(self, request)` method, keeps
    // its body, and a backup path is reported.
    // NOTE(review): this test leaves `<path>.bak` behind; apply_creates_bak_file cleans up.
    #[test]
    fn apply_simple_fbv_produces_cbv() {
        let tmp = write_tmp("def home(request):\n return HttpResponse('hello')\n");
        let path = tmp.path().to_str().unwrap();
        let result = apply(path, &mut parser()).unwrap();
        assert_eq!(result.rewrites_applied.len(), 1, "one rewrite expected");
        assert!(result.backup_path.is_some());
        let written = std::fs::read_to_string(path).unwrap();
        assert!(
            written.contains("class HomeView(View):"),
            "class not found:\n{}",
            written
        );
        assert!(
            written.contains("def get(self, request):"),
            "get method not found:\n{}",
            written
        );
        assert!(
            written.contains("HttpResponse"),
            "body not preserved:\n{}",
            written
        );
    }

    // `<path>.bak` exists after a rewrite and still holds the original source. The backup
    // is not owned by the NamedTempFile, so it is removed manually at the end.
    #[test]
    fn apply_creates_bak_file() {
        let tmp = write_tmp("def list_items(request):\n return HttpResponse('ok')\n");
        let path = tmp.path().to_str().unwrap();
        apply(path, &mut parser()).unwrap();
        let bak = format!("{}.bak", path);
        assert!(std::path::Path::new(&bak).exists(), ".bak file must exist");
        let bak_content = std::fs::read_to_string(&bak).unwrap();
        assert!(
            bak_content.contains("def list_items"),
            "bak must have original"
        );
        let _ = std::fs::remove_file(&bak);
    }

    // With nothing to rewrite, no rewrites are reported and no backup path is returned.
    #[test]
    fn apply_no_fbv_returns_empty() {
        let tmp = write_tmp("def helper(x):\n return x\n");
        let path = tmp.path().to_str().unwrap();
        let result = apply(path, &mut parser()).unwrap();
        assert!(result.rewrites_applied.is_empty());
        assert!(result.backup_path.is_none());
    }

    // A body mentioning `request.method` is left unconverted and produces a warning.
    #[test]
    fn apply_multi_method_emits_warning() {
        let src = "def view(request):\n if request.method == 'POST':\n pass\n return HttpResponse('ok')\n";
        let tmp = write_tmp(src);
        let path = tmp.path().to_str().unwrap();
        let result = apply(path, &mut parser()).unwrap();
        assert!(
            !result.warnings.is_empty(),
            "multi-method must emit warning"
        );
        assert!(
            result.rewrites_applied.is_empty(),
            "multi-method must not be auto-converted"
        );
    }
}