use std::collections::HashSet;
use super::{LanguageExtractor, SemanticUnit, UnitKind};
use crate::context::FileContext;
#[derive(Debug, Clone, Copy)]
pub struct PythonExtractor;
impl LanguageExtractor for PythonExtractor {
fn language_id(&self) -> &str {
"python"
}
fn extensions(&self) -> &[&str] {
&["py", "pyi"]
}
fn extract_units(
&self,
content: &str,
file_path: &str,
changed_lines: &HashSet<u32>,
) -> Vec<SemanticUnit> {
if changed_lines.is_empty() {
return vec![];
}
let lines: Vec<&str> = content.lines().collect();
let constructs = detect_python_constructs(&lines);
let mut units = Vec::new();
let mut covered_lines: HashSet<u32> = HashSet::new();
for construct in &constructs {
let overlaps = changed_lines
.iter()
.any(|&l| l >= construct.start_line && l <= construct.end_line);
if overlaps {
let changed_in_unit: Vec<u32> = changed_lines
.iter()
.copied()
.filter(|&l| l >= construct.start_line && l <= construct.end_line)
.collect();
for &l in &changed_in_unit {
covered_lines.insert(l);
}
let start_idx = (construct.start_line - 1) as usize;
let end_idx = construct.end_line as usize;
let unit_content = lines
.get(start_idx..end_idx.min(lines.len()))
.unwrap_or(&[])
.join("\n");
units.push(SemanticUnit {
kind: construct.kind,
name: construct.name.clone(),
file: file_path.to_string(),
start_line: construct.start_line,
end_line: construct.end_line,
content: unit_content,
changed_lines: changed_in_unit,
context: FileContext::default(),
});
}
}
let orphans: Vec<u32> = changed_lines
.iter()
.copied()
.filter(|l| !covered_lines.contains(l))
.collect();
if !orphans.is_empty() {
let total = u32::try_from(lines.len()).unwrap_or(u32::MAX);
let min_orphan = *orphans.iter().min().unwrap_or(&1);
let max_orphan = *orphans.iter().max().unwrap_or(&total);
let start = min_orphan.saturating_sub(5).max(1);
let end = (max_orphan + 5).min(total);
let start_idx = (start - 1) as usize;
let end_idx = end as usize;
let unit_content = lines
.get(start_idx..end_idx.min(lines.len()))
.unwrap_or(&[])
.join("\n");
let mut sorted = orphans;
sorted.sort_unstable();
units.push(SemanticUnit {
kind: UnitKind::TopLevel,
name: format!("lines {}-{}", start, end),
file: file_path.to_string(),
start_line: start,
end_line: end,
content: unit_content,
changed_lines: sorted,
context: FileContext::default(),
});
}
units
}
}
struct PythonConstruct {
kind: UnitKind,
name: String,
start_line: u32,
end_line: u32,
}
fn detect_python_constructs(lines: &[&str]) -> Vec<PythonConstruct> {
let mut constructs = Vec::new();
let mut i = 0;
while i < lines.len() {
let trimmed = lines[i].trim();
let indent = indentation(lines[i]);
let line_num = u32::try_from(i + 1).unwrap_or(u32::MAX);
if (trimmed.starts_with("def ") || trimmed.starts_with("async def ")) && indent == 0 {
let name = extract_python_name(trimmed, "def ");
let start = scan_back_decorators(lines, i);
let end = find_block_end(lines, i);
constructs.push(PythonConstruct {
kind: UnitKind::Function,
name,
start_line: u32::try_from(start + 1).unwrap_or(1),
end_line: u32::try_from(end + 1).unwrap_or(line_num),
});
i = end + 1;
continue;
}
if trimmed.starts_with("class ") && indent == 0 {
let name = extract_python_name(trimmed, "class ");
let start = scan_back_decorators(lines, i);
let end = find_block_end(lines, i);
constructs.push(PythonConstruct {
kind: UnitKind::Class,
name,
start_line: u32::try_from(start + 1).unwrap_or(1),
end_line: u32::try_from(end + 1).unwrap_or(line_num),
});
i = end + 1;
continue;
}
if (trimmed.starts_with("import ") || trimmed.starts_with("from ")) && indent == 0 {
let group_end = find_import_group_end(lines, i);
constructs.push(PythonConstruct {
kind: UnitKind::DeclarationGroup,
name: "imports".to_string(),
start_line: line_num,
end_line: u32::try_from(group_end + 1).unwrap_or(line_num),
});
i = group_end + 1;
continue;
}
if indent == 0
&& !trimmed.is_empty()
&& !trimmed.starts_with('#')
&& !trimmed.starts_with("def ")
&& !trimmed.starts_with("async def ")
&& !trimmed.starts_with("class ")
&& !trimmed.starts_with("import ")
&& !trimmed.starts_with("from ")
&& !trimmed.starts_with('@')
&& (trimmed.contains('=') || trimmed.contains(':'))
{
let group_end = find_assignment_group_end(lines, i);
let name = trimmed
.split(['=', ':', ' '])
.next()
.unwrap_or("variable")
.to_string();
constructs.push(PythonConstruct {
kind: UnitKind::DeclarationGroup,
name,
start_line: line_num,
end_line: u32::try_from(group_end + 1).unwrap_or(line_num),
});
i = group_end + 1;
continue;
}
i += 1;
}
constructs
}
fn indentation(line: &str) -> usize {
line.len() - line.trim_start().len()
}
fn find_block_end(lines: &[&str], def_line: usize) -> usize {
let def_indent = indentation(lines[def_line]);
let mut end = def_line;
for (j, line) in lines.iter().enumerate().skip(def_line + 1) {
let trimmed = line.trim();
if trimmed.is_empty() {
continue; }
let ind = indentation(line);
if ind <= def_indent {
break;
}
end = j;
}
end
}
fn find_import_group_end(lines: &[&str], start: usize) -> usize {
let mut end = start;
for j in (start + 1)..lines.len() {
let trimmed = lines[j].trim();
if trimmed.starts_with("import ") || trimmed.starts_with("from ") {
end = j;
} else if trimmed.is_empty() {
if j + 1 < lines.len() {
let next = lines[j + 1].trim();
if next.starts_with("import ") || next.starts_with("from ") {
continue;
}
}
break;
} else {
break;
}
}
end
}
fn find_assignment_group_end(lines: &[&str], start: usize) -> usize {
let mut end = start;
for (j, line) in lines.iter().enumerate().skip(start + 1) {
let trimmed = line.trim();
let indent = indentation(line);
if indent == 0 && !trimmed.is_empty() && !trimmed.starts_with('#') {
if trimmed.contains('=') || trimmed.contains(':') {
if !trimmed.starts_with("def ")
&& !trimmed.starts_with("async def ")
&& !trimmed.starts_with("class ")
&& !trimmed.starts_with("import ")
&& !trimmed.starts_with("from ")
{
end = j;
continue;
}
}
break;
} else if trimmed.is_empty() {
continue;
} else {
break;
}
}
end
}
fn scan_back_decorators(lines: &[&str], idx: usize) -> usize {
let mut start = idx;
while start > 0 {
let prev = lines[start - 1].trim();
if prev.starts_with('@') {
start -= 1;
} else {
break;
}
}
start
}
fn extract_python_name(trimmed: &str, prefix: &str) -> String {
let rest = if let Some(r) = trimmed.strip_prefix("async def ") {
r
} else if let Some(r) = trimmed.strip_prefix(prefix) {
r
} else {
trimmed
};
rest.chars()
.take_while(|c| c.is_alphanumeric() || *c == '_')
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn extract(code: &str, changed: &[u32]) -> Vec<SemanticUnit> {
let ext = PythonExtractor;
let changed_set: HashSet<u32> = changed.iter().copied().collect();
ext.extract_units(code, "test.py", &changed_set)
}
#[test]
fn python_extracts_function() {
let code = "def foo():\n x = 1\n return x\n";
let units = extract(code, &[2]);
assert_eq!(units.len(), 1, "Should extract one function");
assert_eq!(units[0].kind, UnitKind::Function);
assert_eq!(units[0].name, "foo");
}
#[test]
fn python_extracts_class() {
let code = "class Foo:\n def __init__(self):\n self.x = 1\n";
let units = extract(code, &[3]);
assert_eq!(units.len(), 1, "Should extract one class");
assert_eq!(units[0].kind, UnitKind::Class);
assert_eq!(units[0].name, "Foo");
}
#[test]
fn python_extracts_module_level_assignment() {
let code = "MAX_SIZE = 100\nMIN_SIZE = 10\n\ndef foo():\n pass\n";
let units = extract(code, &[1]);
assert_eq!(units.len(), 1, "Should extract one declaration group");
assert_eq!(units[0].kind, UnitKind::DeclarationGroup);
}
#[test]
fn python_extracts_import_group() {
let code = "import os\nimport sys\nfrom pathlib import Path\n\ndef main():\n pass\n";
let units = extract(code, &[2]);
assert_eq!(units.len(), 1, "Should extract one import group");
assert_eq!(units[0].kind, UnitKind::DeclarationGroup);
assert_eq!(units[0].name, "imports");
}
}