use std::fs;
use std::path::Path;
use crate::error::Result;
use crate::types::TypeEnv;
use crate::utils::find_python_files;
/// Rewrites Python source files so that un-annotated `def` signatures carry
/// `Any` annotations.
///
/// `dead_code` is allowed because `type_env` is stored by the constructor but
/// not read by any method visible in this file.
#[allow(dead_code)]
pub struct Fixer {
/// Type environment handed to the constructor; only stored, never read here.
type_env: TypeEnv,
/// When `true`, `fix_file` writes the rewritten source back to disk.
in_place: bool,
}
impl Fixer {
/// Creates a fixer from a type environment and the in-place flag.
///
/// `in_place` decides whether `fix_file` writes rewritten sources back to
/// disk; `type_env` is stored unchanged.
pub fn new(type_env: TypeEnv, in_place: bool) -> Self {
Self { type_env, in_place }
}
/// Fixes a single file, or every file reported by `find_python_files` when
/// `path` is a directory.
///
/// A path that is neither a regular file nor a directory is ignored and
/// `Ok(())` is returned.
///
/// # Errors
///
/// Returns the first error produced by `fix_file`; files after the failing
/// one are not processed.
pub fn fix_path<P: AsRef<Path>>(&self, path: P) -> Result<()> {
    let target = path.as_ref();

    // Single file: delegate directly.
    if target.is_file() {
        return self.fix_file(target);
    }

    // Anything that is not a directory either is silently skipped.
    if !target.is_dir() {
        return Ok(());
    }

    for entry in find_python_files(target) {
        self.fix_file(entry)?;
    }
    Ok(())
}
/// Reads one Python file, runs `fix_source` over it and, in in-place mode,
/// writes the result back when it differs from what was read.
///
/// Files whose extension is not exactly `py` are skipped. When `in_place`
/// is `false` the rewritten text is computed and then discarded.
///
/// # Errors
///
/// Propagates I/O errors from reading or writing the file.
// NOTE(review): `fix_path` calls this method, so the `dead_code` allowance
// looks unnecessary — confirm before removing.
#[allow(dead_code)]
fn fix_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
    let file = path.as_ref();

    let is_python = matches!(file.extension().and_then(|ext| ext.to_str()), Some("py"));
    if !is_python {
        return Ok(());
    }

    let before = fs::read_to_string(file)?;
    let after = Self::fix_source(&before);

    // Only touch the disk when asked to and when something actually changed.
    if self.in_place && after != before {
        fs::write(file, after)?;
    }
    Ok(())
}
/// Placeholder for per-node annotation inference; currently always returns
/// `None`.
///
/// Both parameters are unused, and nothing visible in this file calls this
/// method, hence the `dead_code` allowance.
#[allow(dead_code)]
fn generate_annotation(&self, _node: &tree_sitter::Node, _source: &[u8]) -> Option<String> {
None
}
fn fix_source(source: &str) -> String {
let mut used_any = false;
let mut out = String::with_capacity(source.len() + 64);
for line in source.lines() {
let trimmed = line.trim_start();
if trimmed.starts_with("def ") {
if let Some(paren_start) = trimmed.find('(') {
if let Some(paren_end) = trimmed[paren_start..].find(')') {
let indent_len = line.len() - trimmed.len();
let before = &trimmed[..paren_start + 1];
let params = &trimmed[paren_start + 1..paren_start + paren_end];
let after_params = &trimmed[paren_start + paren_end + 1..];
let fixed_params = Self::annotate_params(params, &mut used_any);
let mut signature_tail = after_params.to_string();
if !after_params.contains("->") {
if let Some(colon_idx) = signature_tail.find(':') {
let mut new_tail = String::new();
new_tail.push_str(" -> Any");
new_tail.push_str(&signature_tail[colon_idx..]);
signature_tail = new_tail;
used_any = true;
}
}
let mut rebuilt = String::new();
rebuilt.push_str(&" ".repeat(indent_len));
rebuilt.push_str(before);
rebuilt.push_str(&fixed_params);
rebuilt.push(')');
rebuilt.push_str(&signature_tail);
out.push_str(&rebuilt);
out.push('\n');
continue;
}
}
}
out.push_str(line);
out.push('\n');
}
if used_any
&& !source.contains("from typing import Any")
&& !out.contains("from typing import Any")
{
let lines: Vec<&str> = out.lines().collect();
let mut insert_at = 0usize;
if !lines.is_empty() && (lines[0].starts_with("#!") || lines[0].contains("coding")) {
insert_at = 1;
}
let mut new_out = String::new();
for (i, l) in lines.iter().enumerate() {
if i == insert_at {
new_out.push_str("from typing import Any\n");
}
new_out.push_str(l);
new_out.push('\n');
}
return new_out;
}
out
}
/// Adds an `Any` annotation to every parameter in `params` that has none.
///
/// `params` is the text between the parentheses of a `def` signature. It is
/// split on commas that sit outside brackets and string literals, so
/// defaults such as `(1, 2)` or `"x,y"` and annotations such as
/// `Dict[str, int]` stay in one piece.
///
/// Left untouched: empty pieces (e.g. after a trailing comma), `*` / `*args`
/// / `**kwargs`, the positional-only marker `/`, and parameters that already
/// have a `:` before their default. Every other parameter, `self` and `cls`
/// included, becomes `name: Any` with its default and surrounding whitespace
/// kept as written.
///
/// `used_any` is set to `true` when at least one annotation is added and is
/// never reset to `false`.
fn annotate_params(params: &str, used_any: &mut bool) -> String {
    // Split on top-level commas only. All bytes inspected are ASCII, so every
    // slice boundary is a valid `char` boundary.
    let mut pieces: Vec<&str> = Vec::new();
    let mut depth = 0usize;
    let mut quote: Option<u8> = None;
    let mut escaped = false;
    let mut start = 0usize;
    for (idx, byte) in params.bytes().enumerate() {
        if let Some(delimiter) = quote {
            if escaped {
                escaped = false;
            } else if byte == b'\\' {
                escaped = true;
            } else if byte == delimiter {
                quote = None;
            }
            continue;
        }
        match byte {
            b'\'' | b'"' => quote = Some(byte),
            b'(' | b'[' | b'{' => depth += 1,
            b')' | b']' | b'}' => depth = depth.saturating_sub(1),
            b',' if depth == 0 => {
                pieces.push(&params[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    pieces.push(&params[start..]);

    let mut parts: Vec<String> = Vec::with_capacity(pieces.len());
    for raw in pieces {
        let trimmed = raw.trim();
        // Separators and markers are not parameters that can be annotated.
        if trimmed.is_empty() || trimmed == "/" || trimmed.starts_with('*') {
            parts.push(raw.to_string());
            continue;
        }

        let (name_part, default) = match trimmed.find('=') {
            Some(eq_idx) => (&trimmed[..eq_idx], &trimmed[eq_idx..]),
            None => (trimmed, ""),
        };
        let name = name_part.trim_end();
        // Only a colon before the default marks an existing annotation; a
        // colon inside the default (dict literal, lambda, string) does not.
        if name.is_empty() || name_part.contains(':') {
            parts.push(raw.to_string());
            continue;
        }

        // Whitespace between the name and `=` is kept so `a = 1` becomes
        // `a: Any = 1` rather than `a: Any= 1`.
        let gap = &name_part[name.len()..];
        let leading_len = raw.len() - raw.trim_start().len();
        let leading = &raw[..leading_len];
        let trailing = &raw[leading_len + trimmed.len()..];
        parts.push(format!("{}{}: Any{}{}{}", leading, name, gap, default, trailing));
        *used_any = true;
    }
    parts.join(",")
}
}
// Unit tests for `Fixer`; compiled only for test builds.
#[cfg(test)]
mod tests {
use super::*;
// Checks that the constructor stores the `in_place` flag it is given.
#[test]
fn test_fixer_initialization() {
let type_env = TypeEnv::new();
let fixer = Fixer::new(type_env, false);
assert!(!fixer.in_place);
}
}