use std::{
path::{Path, PathBuf},
sync::{Arc, Mutex},
};
use clap::Parser;
use rayon::prelude::{ParallelBridge, ParallelIterator};
use walkdir::DirEntry;
// Numeric tree-sitter node-kind ids for the Python grammar, compared against
// `Node::kind_id()` in `find_missing_types_positions`. Such ids are generated
// per grammar build, so these values only hold for the tree-sitter-python
// version this crate was built against.
// NOTE(review): re-verify them (e.g. with `Language::id_for_node_kind`)
// whenever that dependency is upgraded.
/// Kind id of a function's `parameters` node (the parameter list).
const PARAMETERS_KIND: u16 = 147;
/// Currently unused (underscore-prefixed to silence the dead-code lint);
/// presumably the id of `typed_parameter`.
const _TYPED_PARAMETER: u16 = 206;
/// Currently unused (underscore-prefixed to silence the dead-code lint);
/// presumably the id of `typed_default_parameter`.
const _TYPED_DEFAULT_PARAMETER: u16 = 183;
/// Kind id of a bare `identifier`; inside a parameter list this is an
/// untyped parameter without a default.
const IDENTIFIER: u16 = 1;
/// Kind id of `default_parameter` (`name=value` without a type annotation).
const DEFAULT_PARAMETER: u16 = 182;
// Command-line interface. The `///` comments on the fields below are rendered
// by clap as the per-argument `--help` text.
#[derive(clap::Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Python file, or directory to scan recursively for `.py` files.
    path: String,
    /// Skip hidden files and directories (names starting with `.`).
    #[arg(alias = "ih", long, default_value_t = false)]
    ignore_hidden: bool,
    /// Skip test files and directories (named `tests` or starting with `test_`).
    #[arg(alias = "it", long, default_value_t = false)]
    ignore_tests: bool,
    /// Do not report functions that lack a return type annotation.
    #[arg(alias = "ir", long, default_value_t = false)]
    ignore_return: bool,
}
/// Reads the file at `path` and parses it with `parser`.
///
/// Returns the syntax tree together with the file's bytes. The bytes are needed
/// later to extract node text (`Node::utf8_text`).
///
/// # Panics
/// Panics if the file cannot be read as UTF-8 text (the I/O error is included
/// in the message), or if `parser.parse` returns no tree.
pub fn get_tree_from_file(
    parser: &mut tree_sitter::Parser,
    path: &str,
) -> (tree_sitter::Tree, Vec<u8>) {
    let contents = std::fs::read_to_string(path)
        .unwrap_or_else(|err| panic!("Could not read file {path}: {err}"));
    // Parse from a borrow, then hand the same buffer back without copying it.
    let tree = parser
        .parse(&contents, None)
        .expect("Parser should have a language set and no timeout or cancellation.");
    (tree, contents.into_bytes())
}
/// Builds a fresh [`tree_sitter::Parser`] configured with the Python grammar.
///
/// # Panics
/// Panics if the Python grammar cannot be set on the parser.
pub fn create_python_parser() -> tree_sitter::Parser {
    let mut python_parser = tree_sitter::Parser::new();
    let python_grammar = tree_sitter_python::language();
    python_parser
        .set_language(python_grammar)
        .expect("Error loading Python grammar");
    python_parser
}
/// Location of a missing type annotation found in a parsed Python file.
#[derive(Debug)]
struct Position {
    /// Tree-sitter start point of the offending node; reports print
    /// `row + 1` and `column + 1`.
    start: tree_sitter::Point,
    /// Tree-sitter end point of the offending node. Underscore-prefixed
    /// because nothing reads it yet (it only appears in `Debug` output).
    _end: tree_sitter::Point,
    /// Which annotation is missing, and the name it belongs to.
    missing_type: MissingType,
}
/// Kind of type annotation that is missing.
#[derive(Debug)]
enum MissingType {
    /// A function has no `-> T` return annotation; holds the function name.
    Return(String),
    /// A parameter has no type hint; holds the text used to identify the
    /// parameter in messages.
    Parameter(String),
}
/// Walks `tree` and records every missing type annotation in Python function
/// definitions.
///
/// For each `function_definition` node this reports:
/// * every untyped parameter (bare identifiers and `name=default` parameters),
///   except one named `self`;
/// * the missing `-> T` return annotation, unless `ignore_return` is set or the
///   function is named `main`.
///
/// `source_code` must be the exact bytes `tree` was parsed from; names are read
/// out of it.
///
/// # Panics
/// Panics if a parameter's name is not valid UTF-8 in `source_code`.
fn find_missing_types_positions(
    source_code: &[u8],
    tree: tree_sitter::Tree,
    ignore_return: bool,
) -> Vec<Position> {
    let mut results = Vec::new();
    let nodes = tree_sitter_traversal::traverse(tree.walk(), tree_sitter_traversal::Order::Pre);
    for node in nodes {
        if node.kind() != "function_definition" {
            continue;
        }
        let mut has_return_type = false;
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            // A direct `type` child is the `-> T` return annotation; parameter
            // annotations sit inside the `parameters` node instead.
            if child.kind() == "type" {
                has_return_type = true;
            }
            if child.kind_id() == PARAMETERS_KIND {
                collect_untyped_parameters(child, source_code, &mut results);
            }
        }
        if has_return_type || ignore_return {
            continue;
        }
        // Look the name up through the grammar's `name` field rather than a fixed
        // child index. This handles `def f` and `async def f` alike. A definition
        // without a name (possible in error-recovered trees) is skipped instead of
        // panicking, which would abort every other file being checked in parallel.
        let function_name = match node
            .child_by_field_name("name")
            .and_then(|name| name.utf8_text(source_code).ok())
        {
            Some(name) => name,
            None => continue,
        };
        // Functions named `main` are deliberately exempt from the return-type check.
        if function_name == "main" {
            continue;
        }
        results.push(Position {
            start: node.start_position(),
            _end: node.end_position(),
            missing_type: MissingType::Return(function_name.to_string()),
        });
    }
    results
}

/// Appends a [`MissingType::Parameter`] entry for every untyped parameter that
/// is a direct child of the `parameters` node.
///
/// Only bare identifiers (`x`) and untyped defaults (`x=1`) are reported. A
/// parameter named `self` is skipped.
fn collect_untyped_parameters(
    parameters: tree_sitter::Node<'_>,
    source_code: &[u8],
    results: &mut Vec<Position>,
) {
    let mut cursor = parameters.walk();
    for parameter in parameters.children(&mut cursor) {
        if !matches!(parameter.kind_id(), IDENTIFIER | DEFAULT_PARAMETER) {
            continue;
        }
        // For `x=1` the node's text is the whole `x=1`. Report only its `name`
        // field so the message reads "Parameter 'x'". A bare identifier has no
        // such field and is itself the name.
        let name_node = parameter.child_by_field_name("name").unwrap_or(parameter);
        let name = name_node
            .utf8_text(source_code)
            .expect("Parameter should have name");
        if name == "self" {
            continue;
        }
        results.push(Position {
            start: parameter.start_position(),
            _end: parameter.end_position(),
            missing_type: MissingType::Parameter(name.to_string()),
        });
    }
}
/// Renders `positions` as a human-readable report with one line per entry.
///
/// Line and column numbers are printed as `row + 1` and `column + 1`. Every
/// line, including the last, ends with `\n`. An empty slice yields an empty
/// string.
fn get_message_from_positions(positions: &[Position]) -> String {
    positions
        .iter()
        .map(|position| {
            let line = position.start.row + 1;
            let column = position.start.column + 1;
            match &position.missing_type {
                MissingType::Return(name) => format!(
                    "Function '{name}' in line {line} and column {column} is missing a return type.\n"
                ),
                MissingType::Parameter(name) => format!(
                    "Parameter '{name}' in line {line} and column {column} is missing a type hint.\n"
                ),
            }
        })
        .collect()
}
fn main() {
let args = Args::parse();
let path = args.path;
let ignore_hidden = args.ignore_hidden;
let ignore_tests = args.ignore_tests;
let ignore_return = args.ignore_return;
let path = PathBuf::from(&path);
if path.is_dir() {
let message = Arc::new(Mutex::from(String::new()));
let walkdir = walkdir::WalkDir::new(path);
let mut filters: Vec<Box<dyn Filter + Sync>> = Vec::new();
if ignore_hidden {
filters.push(Box::new(NotHidden));
}
if ignore_tests {
filters.push(Box::new(NotTest));
}
walkdir
.into_iter()
.filter_entry(|x| filters.iter().all(|filter| filter.should_be_processed(x)))
.flatten()
.par_bridge()
.for_each(|entry| add_to_message_from_file(entry, Arc::clone(&message), ignore_return));
let message = message
.as_ref()
.lock()
.expect("Should be able to retrieve message at end of parallelization.");
if message.is_empty() {
println!("✨ All good!");
} else {
print!("{}", message);
}
} else {
let message = get_message_from_file(path.as_path(), ignore_return);
if message.is_empty() {
println!("✨ All good!");
} else {
print!("{}", message);
}
}
}
/// Predicate deciding whether a directory-walk entry should be kept.
///
/// `main` passes these to walkdir's `filter_entry`. Rejecting a directory there
/// also prunes everything beneath it.
trait Filter {
    /// Returns `true` if `candidate` should be visited.
    fn should_be_processed(&self, candidate: &DirEntry) -> bool;
}
/// [`Filter`] that skips hidden entries, i.e. files and directories whose name
/// starts with `.`. Names that are not valid UTF-8 are also skipped.
struct NotHidden;
impl Filter for NotHidden {
    fn should_be_processed(&self, entry: &DirEntry) -> bool {
        // The walk root (depth 0) is the path the user passed on the command line.
        // For paths without a final normal component (`.`, `./`, `..`), walkdir
        // reports the whole path as the file name. A leading-dot check would then
        // discard the entire tree, so the root is always kept.
        if entry.depth() == 0 {
            return true;
        }
        entry
            .file_name()
            .to_str()
            .map(|name| !name.starts_with('.'))
            .unwrap_or(false)
    }
}
/// [`Filter`] that skips test code: entries named exactly `tests` or whose name
/// starts with `test_`. Names that are not valid UTF-8 are also skipped.
struct NotTest;
impl Filter for NotTest {
    fn should_be_processed(&self, entry: &DirEntry) -> bool {
        match entry.file_name().to_str() {
            Some(name) => !(name == "tests" || name.starts_with("test_")),
            None => false,
        }
    }
}
/// Checks one walk entry and appends a report section to `message` if it is a
/// `.py` file with findings.
///
/// A section is a `File: <path>` header followed by every line of the file's
/// report, each prefixed with a space. Directories and non-`.py` entries are
/// ignored. `.py` entries whose path is not valid UTF-8 are skipped with a
/// warning on stderr.
///
/// # Panics
/// Panics if the file cannot be read or parsed, or if the `message` mutex is
/// poisoned.
fn add_to_message_from_file(
    entry: walkdir::DirEntry,
    message: Arc<Mutex<String>>,
    ignore_return: bool,
) {
    // `file_type()` was captured during the directory read and cannot fail. By
    // contrast, a `metadata()` error would panic inside a rayon worker and abort
    // the whole run.
    if entry.file_type().is_dir() || !entry.file_name().to_string_lossy().ends_with(".py") {
        return;
    }
    let path = match entry.path().to_str() {
        Some(path) => path,
        None => {
            eprintln!(
                "Warning: skipping file with non UTF-8 path: {}",
                entry.path().display()
            );
            return;
        }
    };
    let messages_from_file = get_message_from_file(entry.path(), ignore_return);
    if messages_from_file.is_empty() {
        return;
    }
    // Build the whole section before locking, so the critical section is one append.
    // The report ends with '\n', so `split` (unlike `lines`) yields a final empty
    // piece. Each section therefore ends with a separator line holding a single space.
    let mut section = format!("File: {path}\n");
    for line in messages_from_file.split('\n') {
        section.push(' ');
        section.push_str(line);
        section.push('\n');
    }
    message
        .lock()
        .expect("Should be able to get a lock on the message.")
        .push_str(&section);
}
/// Parses the Python file at `file` and returns its missing-annotation report.
/// The result is an empty string when nothing is missing.
///
/// # Panics
/// Panics if `file` is not valid UTF-8, or if reading or parsing it fails.
fn get_message_from_file(file: &Path, ignore_return: bool) -> String {
    let mut python_parser = create_python_parser();
    let file_name = file.to_str().expect("Should be valid path name.");
    let (syntax_tree, source_bytes) = get_tree_from_file(&mut python_parser, file_name);
    let missing = find_missing_types_positions(&source_bytes, syntax_tree, ignore_return);
    get_message_from_positions(&missing)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: `test_file.py` (resolved against the test working directory)
    /// can be read and parsed without panicking.
    #[test]
    fn tree_from_test_file() {
        let mut python_parser = create_python_parser();
        let _ = get_tree_from_file(&mut python_parser, "test_file.py");
    }

    /// Smoke test: the finder runs over `test_file.py`. Results are printed for
    /// manual inspection (`cargo test -- --nocapture`).
    #[test]
    fn find_args_test() {
        let mut python_parser = create_python_parser();
        let (syntax_tree, source_bytes) = get_tree_from_file(&mut python_parser, "test_file.py");
        let positions = find_missing_types_positions(&source_bytes, syntax_tree, false);
        println!("{positions:?}");
    }
}