use anyhow::Result;
use std::cell::RefCell;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
thread_local! {
    /// Per-thread parser installed by [`init`] and consulted by the free-function
    /// wrappers in this module; holds `None` until `init` runs on this thread.
    // `RefCell::new(None)` is a const expression, so the `const` initializer form
    // avoids the lazy-initialization check on every access.
    static PARSER: RefCell<Option<PyProjectParser>> = const { RefCell::new(None) };
}
/// A package entry from `[tool.poetry].packages` in `pyproject.toml`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PackageInfo {
    /// Package name, taken from the entry's `include` value.
    pub name: String,
    /// Directory from the entry's `from` value, or the `include` value when
    /// `from` is absent; may end with a trailing `/`.
    pub directory: String,
}
/// Reads Poetry metadata from a project's `pyproject.toml` and its
/// `.used-externals.txt`, caching the parsed package list.
#[derive(Debug, Clone)]
pub struct PyProjectParser {
    /// Directory that holds `pyproject.toml` and `.used-externals.txt`.
    project_root: PathBuf,
    /// Package list, parsed on first access and then reused.
    package_info: OnceLock<Vec<PackageInfo>>,
}
/// Drops packages whose directory lies strictly inside another package's directory.
///
/// Directories are compared after removing trailing `/`. A package is dropped only
/// when its directory extends another one at a `/` boundary: `a/b` is inside `a`,
/// but `ab` is not. Packages with identical directories (e.g. several packages
/// sharing `from = "src"`) are all kept. Survivors come back ordered by directory
/// string length, shortest first; ties keep their input order.
fn filter_contained_packages(mut packages: Vec<PackageInfo>) -> Vec<PackageInfo> {
    // Stable sort, shortest first, so a parent directory is always examined
    // before anything nested under it.
    packages.sort_by_key(|package| package.directory.len());
    let mut kept: Vec<PackageInfo> = Vec::with_capacity(packages.len());
    for package in packages {
        let package_path = package.directory.trim_end_matches('/');
        let is_contained = kept.iter().any(|existing| {
            let existing_path = existing.directory.trim_end_matches('/');
            // A bare string prefix is not containment: `common2` shares the
            // prefix `common` but is a sibling, so require a `/` right after it.
            matches!(
                package_path.strip_prefix(existing_path),
                Some(rest) if rest.starts_with('/')
            )
        });
        if !is_contained {
            kept.push(package);
        }
    }
    kept
}
impl PyProjectParser {
    /// Creates a parser for the project rooted at `project_root`.
    ///
    /// Nothing is read from disk here; `pyproject.toml` is parsed lazily the
    /// first time package information is requested.
    pub fn new(project_root: &Path) -> Self {
        Self {
            project_root: project_root.to_path_buf(),
            package_info: OnceLock::new(),
        }
    }

    /// Reads and parses `<project_root>/pyproject.toml`.
    ///
    /// Returns `Ok(None)` when the file does not exist.
    ///
    /// # Errors
    ///
    /// Fails if the file exists but cannot be read, or is not valid TOML.
    fn read_pyproject(&self) -> Result<Option<toml::Value>> {
        let pyproject_path = self.project_root.join("pyproject.toml");
        if !pyproject_path.exists() {
            return Ok(None);
        }
        let content = std::fs::read_to_string(&pyproject_path)?;
        let document: toml::Value = toml::from_str(&content)?;
        Ok(Some(document))
    }

    /// Returns the `[tool.poetry]` table of a parsed `pyproject.toml`, if present.
    fn poetry_section(document: &toml::Value) -> Option<&toml::Value> {
        document.get("tool").and_then(|tool| tool.get("poetry"))
    }

    /// Parses `[tool.poetry].packages` into [`PackageInfo`] entries.
    ///
    /// Entries without a string `include` are skipped. When `from` is absent or
    /// not a string, the `include` value doubles as the directory. Packages nested
    /// inside another package's directory are removed by
    /// [`filter_contained_packages`].
    fn load_package_info(&self) -> Result<Vec<PackageInfo>> {
        let Some(document) = self.read_pyproject()? else {
            return Ok(Vec::new());
        };
        let packages: Vec<PackageInfo> = Self::poetry_section(&document)
            .and_then(|poetry| poetry.get("packages"))
            .and_then(|packages| packages.as_array())
            .map(|entries| {
                entries
                    .iter()
                    .filter_map(|entry| {
                        let include = entry.get("include")?.as_str()?;
                        let directory = entry
                            .get("from")
                            .and_then(|from| from.as_str())
                            .unwrap_or(include);
                        Some(PackageInfo {
                            name: include.to_string(),
                            directory: directory.to_string(),
                        })
                    })
                    .collect()
            })
            .unwrap_or_default();
        Ok(filter_contained_packages(packages))
    }

    /// Returns the packages declared in `pyproject.toml`, loading them on first call.
    ///
    /// The result is cached for the lifetime of this parser. A missing, unreadable,
    /// or malformed `pyproject.toml` is treated as declaring no packages; the
    /// underlying error is discarded.
    pub fn get_package_info(&self) -> &Vec<PackageInfo> {
        self.package_info
            .get_or_init(|| self.load_package_info().unwrap_or_default())
    }

    /// Returns `true` if the first dotted segment of `module_name` equals the
    /// name of a declared package.
    pub fn is_internal_module(&self, module_name: &str) -> bool {
        let top_level = module_name
            .split_once('.')
            .map_or(module_name, |(head, _)| head);
        self.get_package_info()
            .iter()
            .any(|pkg| pkg.name == top_level)
    }

    /// Maps a path-derived dotted module name onto its package-relative name.
    ///
    /// Each package's `directory` is turned into a dotted prefix: trailing `/` is
    /// removed and `/` is replaced by `.`. For a module under that prefix:
    /// - if the remainder already starts with the package name, the remainder is
    ///   returned (`common.common.utils` -> `common.utils`);
    /// - otherwise the package name is prepended (`MyModule.utils` ->
    ///   `mymodule.utils`).
    ///
    /// A module equal to the dotted directory itself maps to the package name.
    /// A package whose name the remainder already carries wins over the
    /// name-prepending fallback. This way several packages sharing one `from`
    /// directory (a `src` layout) each resolve to their own name. Modules that
    /// match no package are returned unchanged.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn normalize_module_name(&self, module_name: &str) -> Result<String> {
        // First weaker match (directory-equal or name-prepended), used only if
        // no package already qualifies the remainder.
        let mut fallback: Option<String> = None;
        for package in self.get_package_info() {
            let from_dotted = package.directory.trim_end_matches('/').replace('/', ".");
            if module_name == from_dotted {
                if fallback.is_none() {
                    fallback = Some(package.name.clone());
                }
                continue;
            }
            let Some(remainder) = module_name
                .strip_prefix(from_dotted.as_str())
                .and_then(|rest| rest.strip_prefix('.'))
            else {
                continue;
            };
            let name = package.name.as_str();
            let already_qualified = remainder == name
                || remainder
                    .strip_prefix(name)
                    .is_some_and(|rest| rest.starts_with('.'));
            if already_qualified {
                return Ok(remainder.to_string());
            }
            if fallback.is_none() {
                fallback = Some(format!("{name}.{remainder}"));
            }
        }
        Ok(fallback.unwrap_or_else(|| module_name.to_string()))
    }

    /// Lists the dependencies declared for Poetry in `pyproject.toml`.
    ///
    /// Collects the keys of `[tool.poetry.dependencies]` (excluding `python`) and of
    /// every `[tool.poetry.group.<name>.dependencies]` table. Names are normalized
    /// with `normalize_dependency_name`, sorted, and de-duplicated. A missing file
    /// yields an empty list.
    ///
    /// # Errors
    ///
    /// Fails if `pyproject.toml` exists but cannot be read or parsed.
    pub fn get_declared_dependencies(&self) -> Result<Vec<String>> {
        let Some(document) = self.read_pyproject()? else {
            return Ok(Vec::new());
        };
        let mut dependencies: Vec<String> = Vec::new();
        if let Some(poetry) = Self::poetry_section(&document) {
            if let Some(main) = poetry.get("dependencies").and_then(|d| d.as_table()) {
                dependencies.extend(
                    main.keys()
                        .filter(|name| name.as_str() != "python")
                        .map(|name| normalize_dependency_name(name)),
                );
            }
            if let Some(groups) = poetry.get("group").and_then(|g| g.as_table()) {
                for group in groups.values() {
                    if let Some(group_deps) =
                        group.get("dependencies").and_then(|d| d.as_table())
                    {
                        dependencies
                            .extend(group_deps.keys().map(|name| normalize_dependency_name(name)));
                    }
                }
            }
        }
        dependencies.sort();
        dependencies.dedup();
        Ok(dependencies)
    }

    /// Reads `<project_root>/.used-externals.txt`: one package name per line.
    ///
    /// Text after `#` is a comment. Blank and comment-only lines are ignored.
    /// Names are normalized with `normalize_dependency_name`, sorted, and
    /// de-duplicated. A missing file yields an empty list.
    ///
    /// # Errors
    ///
    /// Fails if the file exists but cannot be read.
    pub fn get_used_externals(&self) -> Result<Vec<String>> {
        let used_externals_path = self.project_root.join(".used-externals.txt");
        if !used_externals_path.exists() {
            return Ok(Vec::new());
        }
        let content = std::fs::read_to_string(&used_externals_path)?;
        let mut externals: Vec<String> = content
            .lines()
            .filter_map(|line| {
                // Cut any comment first; a comment-only line then trims to "".
                let name = line
                    .split_once('#')
                    .map_or(line, |(before, _)| before)
                    .trim();
                (!name.is_empty()).then(|| normalize_dependency_name(name))
            })
            .collect();
        externals.sort();
        externals.dedup();
        Ok(externals)
    }
}
/// Normalizes a Python distribution name as specified by PEP 503.
///
/// The name is lowercased and every run of `-`, `_` and `.` collapses to a
/// single `-`. For example, `Django_REST_Framework` becomes
/// `django-rest-framework`, and `ruamel.yaml` and `ruamel-yaml` normalize to the
/// same string.
fn normalize_dependency_name(dep_name: &str) -> String {
    let lowered = dep_name.to_lowercase();
    let mut normalized = String::with_capacity(lowered.len());
    let mut in_separator_run = false;
    for ch in lowered.chars() {
        if matches!(ch, '-' | '_' | '.') {
            // Emit one `-` per run of separators, however long the run is.
            if !in_separator_run {
                normalized.push('-');
                in_separator_run = true;
            }
        } else {
            normalized.push(ch);
            in_separator_run = false;
        }
    }
    normalized
}
/// Installs a fresh [`PyProjectParser`] for `project_root` on the current thread.
///
/// Any parser previously installed on this thread, together with its cached
/// package list, is replaced.
pub fn init(project_root: &Path) {
    let fresh = PyProjectParser::new(project_root);
    PARSER.with(move |slot| {
        *slot.borrow_mut() = Some(fresh);
    });
}
/// Test-only alias for [`init`]: installs a parser for `project_root` on the
/// current thread.
#[cfg(test)]
pub fn init_for_test(project_root: &Path) {
    init(project_root)
}
/// Test-only: removes the current thread's parser so the free functions fall
/// back to their "uninitialized" results.
#[cfg(test)]
pub fn reset_for_test() {
    PARSER.with(|slot| {
        let previous = slot.borrow_mut().take();
        // Discard the old parser (and its cached package list) explicitly.
        drop(previous);
    });
}
/// Reports whether `module_name` belongs to a package declared by the project
/// registered with [`init`] on this thread.
///
/// Returns `false` when no parser has been initialized.
pub fn is_internal_module(module_name: &str) -> bool {
    PARSER.with(|slot| {
        slot.borrow()
            .as_ref()
            .is_some_and(|parser| parser.is_internal_module(module_name))
    })
}
/// Normalizes `module_name` using the parser registered with [`init`] on this thread.
///
/// When no parser has been initialized, the name is returned unchanged.
pub fn normalize_module_name(module_name: &str) -> Result<String> {
    PARSER.with(|slot| match slot.borrow().as_ref() {
        Some(parser) => parser.normalize_module_name(module_name),
        None => Ok(module_name.to_string()),
    })
}
/// Returns the declared dependencies of the project registered with [`init`]
/// on this thread, or an empty list when no parser has been initialized.
pub fn get_declared_dependencies() -> Result<Vec<String>> {
    PARSER.with(|slot| {
        slot.borrow().as_ref().map_or_else(
            || Ok(Vec::new()),
            |parser| parser.get_declared_dependencies(),
        )
    })
}
/// Returns the externals listed in `.used-externals.txt` for the project
/// registered with [`init`] on this thread, or an empty list when no parser
/// has been initialized.
pub fn get_used_externals() -> Result<Vec<String>> {
    PARSER.with(|slot| match slot.borrow().as_ref() {
        None => Ok(Vec::new()),
        Some(parser) => parser.get_used_externals(),
    })
}
/// Derives a dotted Python module name for `file_path` relative to `project_root`.
///
/// Each normal path component becomes one segment. A trailing `.py` is stripped,
/// and `__init__.py` adds no segment, so it names its directory. The joined name
/// is then passed through [`normalize_module_name`].
///
/// # Errors
///
/// - `file_path` is not inside `project_root`;
/// - a path component is not valid UTF-8. It used to be skipped silently,
///   which produced a wrong module name;
/// - no segments remain (e.g. `project_root/__init__.py`);
/// - any error from [`normalize_module_name`].
pub fn compute_module_name(file_path: &Path, project_root: &Path) -> Result<String> {
    let relative_path = file_path.strip_prefix(project_root).map_err(|_| {
        anyhow::anyhow!(
            "File path '{}' is not within project root '{}'",
            file_path.display(),
            project_root.display()
        )
    })?;
    let mut parts: Vec<&str> = Vec::new();
    for component in relative_path.components() {
        let std::path::Component::Normal(raw) = component else {
            continue;
        };
        let name = raw.to_str().ok_or_else(|| {
            anyhow::anyhow!(
                "Path component {:?} in '{}' is not valid UTF-8",
                raw,
                file_path.display()
            )
        })?;
        match name.strip_suffix(".py") {
            // A package's `__init__.py` is named by its directory alone.
            Some("__init__") => {}
            Some(stem) => parts.push(stem),
            None => parts.push(name),
        }
    }
    if parts.is_empty() {
        anyhow::bail!("Could not determine module name from file path");
    }
    normalize_module_name(&parts.join("."))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// `pyproject.toml` declaring two Poetry packages; shared by several tests.
    const TWO_PACKAGES_PYPROJECT: &str = r#"
[tool.poetry]
packages = [
{ include = "common", from = "common/" },
{ include = "mymodule", from = "MyModule/" },
]
"#;

    /// Creates a temporary project containing `file_name` with `contents`.
    ///
    /// Keep the returned `TempDir` alive while the project is in use: dropping it
    /// deletes the directory, and `PyProjectParser` reads `pyproject.toml` lazily.
    fn project_with(file_name: &str, contents: &str) -> TempDir {
        let project = TempDir::new().unwrap();
        fs::write(project.path().join(file_name), contents).unwrap();
        project
    }

    /// Parses `.used-externals.txt` with the given contents in a throwaway project.
    fn used_externals_for(contents: &str) -> Vec<String> {
        let project = project_with(".used-externals.txt", contents);
        let parser = PyProjectParser::new(project.path());
        parser.get_used_externals().unwrap()
    }

    /// Shorthand constructor for a `PackageInfo` fixture.
    fn package(name: &str, directory: &str) -> PackageInfo {
        PackageInfo {
            name: name.to_string(),
            directory: directory.to_string(),
        }
    }

    /// Looks up the directory recorded for the package called `name`.
    fn directory_of<'a>(packages: &'a [PackageInfo], name: &str) -> Option<&'a str> {
        packages
            .iter()
            .find(|p| p.name == name)
            .map(|p| p.directory.as_str())
    }

    #[test]
    fn test_get_package_info() {
        let project = project_with("pyproject.toml", TWO_PACKAGES_PYPROJECT);
        let parser = PyProjectParser::new(project.path());
        let packages = parser.get_package_info();
        assert_eq!(packages.len(), 2);
        assert_eq!(directory_of(packages, "common"), Some("common/"));
        assert_eq!(directory_of(packages, "mymodule"), Some("MyModule/"));
    }

    #[test]
    fn test_is_internal_module() {
        let project = project_with("pyproject.toml", TWO_PACKAGES_PYPROJECT);
        let parser = PyProjectParser::new(project.path());
        for internal in ["common", "common.utils"] {
            assert!(parser.is_internal_module(internal), "{internal} should be internal");
        }
        assert!(!parser.is_internal_module("numpy"));
    }

    #[test]
    fn test_filter_contained_packages() {
        let filtered = filter_contained_packages(vec![
            package("medcat", "ehr_data_formatter/medcat/"),
            package("ehr_data_formatter", "ehr_data_formatter/"),
            package("other", "other/"),
        ]);
        let names: Vec<&str> = filtered.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"ehr_data_formatter"));
        assert!(names.contains(&"other"));
        assert!(!names.contains(&"medcat"));
    }

    #[test]
    fn test_compute_module_name() {
        let project = TempDir::new().unwrap();
        reset_for_test();
        init(project.path());
        let root = project.path();
        fs::create_dir_all(root.join("package")).unwrap();
        let cases = [
            ("main.py", "main"),
            ("package/module.py", "package.module"),
            ("package/__init__.py", "package"),
        ];
        for (relative, expected) in cases {
            let file_path = root.join(relative);
            fs::write(&file_path, "").unwrap();
            assert_eq!(compute_module_name(&file_path, root).unwrap(), expected);
        }
    }

    #[test]
    fn test_get_declared_dependencies() {
        let project = project_with(
            "pyproject.toml",
            r#"
[tool.poetry.dependencies]
python = ">=3.10,<3.11"
numpy = "^1.24.3"
pandas = "^2.0.3"
torch = { version = "2.3.0"}
[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
jupyter = "^1.0.0"
[tool.poetry.group.optional.dependencies]
matplotlib = "^3.8.2"
"#,
        );
        let parser = PyProjectParser::new(project.path());
        let deps = parser.get_declared_dependencies().unwrap();
        for expected in ["numpy", "pandas", "torch", "pytest", "jupyter", "matplotlib"] {
            assert!(deps.iter().any(|d| d == expected), "missing dependency {expected}");
        }
        assert!(!deps.iter().any(|d| d == "python"));
        assert_eq!(deps.len(), 6);
    }

    #[test]
    fn test_get_used_externals_empty_file() {
        let project = TempDir::new().unwrap();
        let parser = PyProjectParser::new(project.path());
        assert!(parser.get_used_externals().unwrap().is_empty());
    }

    #[test]
    fn test_get_used_externals_with_content() {
        let externals = used_externals_for(
            r#"# Build tools
setuptools
wheel
# Database drivers
psycopg2-binary
SQLAlchemy # inline comment
# Testing frameworks
pytest-asyncio
# Empty lines and comments should be ignored
Django_REST_Framework # Should be normalized to django-rest-framework
"#,
        );
        assert_eq!(externals.len(), 6);
        for expected in [
            "setuptools",
            "wheel",
            "psycopg2-binary",
            "sqlalchemy",
            "pytest-asyncio",
            "django-rest-framework",
        ] {
            assert!(externals.iter().any(|e| e == expected), "missing {expected}");
        }
        assert_eq!(externals.first().map(String::as_str), Some("django-rest-framework"));
        assert_eq!(externals.get(1).map(String::as_str), Some("psycopg2-binary"));
    }

    #[test]
    fn test_get_used_externals_comments_and_whitespace() {
        let externals = used_externals_for(
            r#"
# This is a comment at the start
numpy # Trailing comment with spaces
pandas
# Another comment
redis
matplotlib # Comment at end
"#,
        );
        assert_eq!(externals.len(), 4);
        for expected in ["numpy", "pandas", "redis", "matplotlib"] {
            assert!(externals.iter().any(|e| e == expected), "missing {expected}");
        }
    }

    #[test]
    fn test_get_used_externals_deduplication() {
        let externals = used_externals_for(
            r#"numpy
NumPy # Should normalize to same as above
NUMPY # Should normalize to same as above
requests
"#,
        );
        assert_eq!(externals.len(), 2);
        for expected in ["numpy", "requests"] {
            assert!(externals.iter().any(|e| e == expected), "missing {expected}");
        }
    }
}