use std::collections::HashSet;
use syn::{Item, UseTree, parse_file};
/// Collects every name imported from `axum::routing` by the top-level `use`
/// items of `content`.
///
/// Only items directly at file level are inspected. If `content` does not
/// parse as a Rust file, an empty set is returned.
pub fn extract_existing_routing_imports(content: &str) -> HashSet<String> {
    let mut found = HashSet::new();
    let parsed = match parse_file(content) {
        Ok(file) => file,
        Err(_) => return found,
    };
    parsed
        .items
        .iter()
        .filter_map(|item| match item {
            Item::Use(use_item) => Some(&use_item.tree),
            _ => None,
        })
        .for_each(|tree| extract_from_use_tree(tree, &[], &mut found));
    found
}
/// Recursively walks a `use` tree and records in `methods` every imported
/// name whose enclosing path starts with `axum::routing`.
///
/// `path` holds the path segments already traversed above `tree`. For a
/// renamed import (`x as y`) the original name `x` is recorded. Glob imports
/// are ignored.
fn extract_from_use_tree(tree: &UseTree, path: &[String], methods: &mut HashSet<String>) {
    match tree {
        UseTree::Group(group) => group
            .items
            .iter()
            .for_each(|child| extract_from_use_tree(child, path, methods)),
        UseTree::Path(segment) => {
            let extended: Vec<String> = path
                .iter()
                .cloned()
                .chain(std::iter::once(segment.ident.to_string()))
                .collect();
            extract_from_use_tree(&segment.tree, &extended, methods);
        }
        UseTree::Name(name) if is_routing_path(path) => {
            methods.insert(name.ident.to_string());
        }
        UseTree::Rename(rename) if is_routing_path(path) => {
            methods.insert(rename.ident.to_string());
        }
        _ => {}
    }
}
/// Returns `true` when `path` begins with the segments `axum`, `routing`.
fn is_routing_path(path: &[String]) -> bool {
    match path {
        [first, second, ..] => first == "axum" && second == "routing",
        _ => false,
    }
}
/// Builds a single `use axum::routing::{…};` statement importing `methods`.
///
/// Names are emitted in sorted order, so the output does not depend on
/// `HashSet` iteration order. Braces are always used, even for a single name
/// (e.g. `use axum::routing::{get};`). Returns `None` when `methods` is empty.
pub fn generate_routing_import(methods: &HashSet<String>) -> Option<String> {
    if methods.is_empty() {
        return None;
    }
    let mut sorted: Vec<&str> = methods.iter().map(String::as_str).collect();
    // Names come from a set and are therefore unique, so stability is irrelevant.
    sorted.sort_unstable();
    // A single-element join is just that element, so one code path covers
    // both the single-name and multi-name cases.
    Some(format!("use axum::routing::{{{}}};", sorted.join(", ")))
}
/// Removes `axum::routing` imports from `content` so they can be regenerated
/// (for example with [`generate_routing_import`]).
///
/// Each `use` statement is handled as a unit. A statement may span several
/// lines and ends at the first line containing `;`.
///
/// * `use axum::routing::…` and `pub use axum::routing::…` statements are
///   dropped entirely, including their continuation lines.
/// * `use axum::{…}` statements have every top-level `routing::…` entry
///   removed from their brace group. The statement is dropped if nothing
///   remains.
/// * Everything else is kept verbatim, and a trailing newline on `content`
///   is preserved.
///
/// `use axum::routing;` (the module itself) is deliberately kept.
/// [`extract_existing_routing_imports`] does not record it, so regeneration
/// could not restore it.
///
/// This is a line-based scan. It does not understand comments or string
/// literals, and it assumes each `use` statement starts on its own line.
pub fn remove_routing_imports(content: &str) -> String {
    let mut kept_lines: Vec<String> = Vec::new();
    let mut lines = content.lines();
    while let Some(line) = lines.next() {
        let trimmed = line.trim_start();
        // Require the trailing `::` so `use axum::routing;` / `as` renames of
        // the module itself are not treated as method imports.
        let is_routing_use = trimmed.starts_with("use axum::routing::")
            || trimmed.starts_with("pub use axum::routing::");
        if !is_routing_use && !trimmed.starts_with("use axum::") {
            kept_lines.push(line.to_string());
            continue;
        }

        // Gather the complete statement: rustfmt splits long imports over
        // several lines, and handling only the first line would leave
        // dangling continuation lines behind.
        let mut statement = line.to_string();
        let mut complete = line.contains(';');
        while !complete {
            match lines.next() {
                Some(next) => {
                    statement.push('\n');
                    statement.push_str(next);
                    complete = next.contains(';');
                }
                None => break,
            }
        }

        if is_routing_use {
            continue;
        }
        if !statement.contains("routing::") {
            kept_lines.push(statement);
            continue;
        }
        let cleaned = remove_routing_patterns(&statement);
        let cleaned_trimmed = cleaned.trim();
        let is_empty_import = cleaned_trimmed.is_empty()
            || cleaned_trimmed.ends_with("use axum::{};")
            || cleaned_trimmed.ends_with("use axum::{}");
        if !is_empty_import {
            kept_lines.push(cleaned);
        }
    }
    let mut output = kept_lines.join("\n");
    if content.ends_with('\n') {
        output.push('\n');
    }
    output
}

/// Removes every top-level `routing::…` entry from the first `{…}` group of a
/// `use axum::{…}` statement, keeping the other entries and their layout.
///
/// Entries are split on commas at brace depth zero. As a result, nested
/// groups such as `routing::{get, post}` are removed as a whole, and names
/// that merely start with a method name (e.g. `routing::get_service`) are
/// never cut apart. If every entry is removed the group becomes `{}`. Text
/// without a `{`, or with unbalanced braces, is returned unchanged.
fn remove_routing_patterns(line: &str) -> String {
    let open = match line.find('{') {
        Some(idx) => idx,
        None => return line.to_string(),
    };
    // Locate the brace matching `open`; the scan starts on `{`, so `depth`
    // is at least 1 before any `}` is seen.
    let mut depth = 0usize;
    let mut close = None;
    for (offset, ch) in line[open..].char_indices() {
        match ch {
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    close = Some(open + offset);
                    break;
                }
            }
            _ => {}
        }
    }
    let close = match close {
        Some(idx) => idx,
        None => return line.to_string(),
    };

    let inner = &line[open + 1..close];
    let parts = split_top_level_commas(inner);
    // A trailing comma produces a whitespace-only final segment; keep it as
    // layout (e.g. the newline before `}` in a vertical import).
    let (items, trailing): (&[&str], Option<&str>) = match parts.split_last() {
        Some((last, rest)) if !rest.is_empty() && last.trim().is_empty() => (rest, Some(*last)),
        _ => (&parts[..], None),
    };
    // Reuse the original first entry's leading whitespace so removing the
    // first entry does not leave e.g. `{ Router}` behind.
    let lead = items
        .first()
        .copied()
        .map_or("", |first| &first[..first.len() - first.trim_start().len()]);
    let kept: Vec<&str> = items
        .iter()
        .copied()
        .filter(|item| !item.trim_start().starts_with("routing::"))
        .collect();

    let mut new_inner = String::with_capacity(inner.len());
    for (idx, item) in kept.iter().enumerate() {
        if idx == 0 {
            new_inner.push_str(lead);
            new_inner.push_str(item.trim_start());
        } else {
            new_inner.push(',');
            new_inner.push_str(item);
        }
    }
    if let Some(tail) = trailing {
        if !kept.is_empty() {
            new_inner.push(',');
            new_inner.push_str(tail);
        }
    }
    format!("{}{{{}}}{}", &line[..open], new_inner, &line[close + 1..])
}

/// Splits `s` on commas that are not nested inside `{…}`.
fn split_top_level_commas(s: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut start = 0;
    for (idx, ch) in s.char_indices() {
        match ch {
            '{' => depth += 1,
            '}' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                parts.push(&s[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    parts.push(&s[start..]);
    parts
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned set from string literals, so each expectation reads as
    /// one equality check instead of separate `contains`/`len` assertions.
    fn set_of(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_extract_simple_import() {
        let found = extract_existing_routing_imports("use axum::routing::get;");
        assert_eq!(found, set_of(&["get"]));
    }

    #[test]
    fn test_extract_grouped_import() {
        let found = extract_existing_routing_imports("use axum::routing::{get, post};");
        assert_eq!(found, set_of(&["get", "post"]));
    }

    #[test]
    fn test_extract_nested_import() {
        let found = extract_existing_routing_imports("use axum::{Router, routing::{get, post}};");
        assert_eq!(found, set_of(&["get", "post"]));
    }

    #[test]
    fn test_extract_deeply_nested() {
        let found = extract_existing_routing_imports("use axum::{something, routing::get};");
        assert_eq!(found, set_of(&["get"]));
    }

    #[test]
    fn test_generate_single_import() {
        let generated = generate_routing_import(&set_of(&["get"])).unwrap();
        assert_eq!(generated, "use axum::routing::{get};");
    }

    #[test]
    fn test_generate_multiple_imports_sorted() {
        // Insertion order differs from the expected alphabetical output order.
        let generated = generate_routing_import(&set_of(&["post", "get", "delete"])).unwrap();
        assert_eq!(generated, "use axum::routing::{delete, get, post};");
    }

    #[test]
    fn test_remove_simple_routing_import() {
        let source = "use axum::Router;\nuse axum::routing::{get, post};\npub fn router() -> Router {}";
        let cleaned = remove_routing_imports(source);
        assert!(!cleaned.contains("routing"));
        assert!(cleaned.contains("use axum::Router"));
    }

    #[test]
    fn test_remove_routing_keep_others() {
        let source = "use axum::{http::StatusCode, response::IntoResponse, routing::get, Router};";
        let cleaned = remove_routing_imports(source);
        assert!(!cleaned.contains("routing"));
        for expected in ["http::StatusCode", "response::IntoResponse", "Router"].iter() {
            assert!(cleaned.contains(*expected), "missing {}", expected);
        }
    }

    #[test]
    fn test_no_imports() {
        let found = extract_existing_routing_imports("use axum::Router;");
        assert!(found.is_empty());
    }
}