use crate::error::{Result, TailwindError};
use std::collections::HashSet;
use std::path::Path;
#[derive(Debug, Clone)]
pub struct AstParser {
classes: HashSet<String>,
responsive_classes: std::collections::HashMap<String, HashSet<String>>,
conditional_classes: std::collections::HashMap<String, HashSet<String>>,
parsed_files: HashSet<String>,
}
impl AstParser {
pub fn new() -> Self {
Self {
classes: HashSet::new(),
responsive_classes: std::collections::HashMap::new(),
conditional_classes: std::collections::HashMap::new(),
parsed_files: HashSet::new(),
}
}
pub fn parse_file(&mut self, path: &Path) -> Result<()> {
let content = std::fs::read_to_string(path)
.map_err(|e| TailwindError::build(format!("Failed to read file {:?}: {}", path, e)))?;
self.parse_content(&content)?;
self.parsed_files.insert(path.to_string_lossy().to_string());
Ok(())
}
pub fn parse_content(&mut self, content: &str) -> Result<()> {
if let Ok(syntax_tree) = syn::parse_file(content) {
let mut visitor = ClassVisitor::new();
visitor.visit_file(&syntax_tree);
self.merge_visitor_results(visitor);
return Ok(());
}
if let Ok(expr) = syn::parse_str::<syn::Expr>(content) {
let mut visitor = ClassVisitor::new();
visitor.visit_expr(&expr);
self.merge_visitor_results(visitor);
return Ok(());
}
if let Ok(stmt) = syn::parse_str::<syn::Stmt>(content) {
let mut visitor = ClassVisitor::new();
visitor.visit_stmt(&stmt);
self.merge_visitor_results(visitor);
return Ok(());
}
Err(TailwindError::build(format!(
"Failed to parse Rust code: {}",
content
)))
}
fn merge_visitor_results(&mut self, visitor: ClassVisitor) {
self.classes.extend(visitor.classes);
for (breakpoint, classes) in visitor.responsive_classes {
self.responsive_classes
.entry(breakpoint)
.or_default()
.extend(classes);
}
for (condition, classes) in visitor.conditional_classes {
self.conditional_classes
.entry(condition)
.or_default()
.extend(classes);
}
}
pub fn get_classes(&self) -> &HashSet<String> {
&self.classes
}
pub fn get_responsive_classes(&self, breakpoint: &str) -> Option<&HashSet<String>> {
self.responsive_classes.get(breakpoint)
}
pub fn get_conditional_classes(&self, condition: &str) -> Option<&HashSet<String>> {
self.conditional_classes.get(condition)
}
pub fn get_all_responsive_classes(
&self,
) -> &std::collections::HashMap<String, HashSet<String>> {
&self.responsive_classes
}
pub fn get_all_conditional_classes(
&self,
) -> &std::collections::HashMap<String, HashSet<String>> {
&self.conditional_classes
}
pub fn parsed_file_count(&self) -> usize {
self.parsed_files.len()
}
pub fn class_count(&self) -> usize {
self.classes.len()
}
pub fn has_parsed_file(&self, path: &str) -> bool {
self.parsed_files.contains(path)
}
pub fn clear(&mut self) {
self.classes.clear();
self.responsive_classes.clear();
self.conditional_classes.clear();
self.parsed_files.clear();
}
}
impl Default for AstParser {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
struct ClassVisitor {
classes: HashSet<String>,
responsive_classes: std::collections::HashMap<String, HashSet<String>>,
conditional_classes: std::collections::HashMap<String, HashSet<String>>,
}
impl ClassVisitor {
fn new() -> Self {
Self {
classes: HashSet::new(),
responsive_classes: std::collections::HashMap::new(),
conditional_classes: std::collections::HashMap::new(),
}
}
fn visit_file(&mut self, file: &syn::File) {
for item in &file.items {
self.visit_item(item);
}
}
fn visit_item(&mut self, item: &syn::Item) {
match item {
syn::Item::Fn(func) => {
for stmt in &func.block.stmts {
self.visit_stmt(stmt);
}
}
syn::Item::Impl(impl_item) => {
for item in &impl_item.items {
self.visit_impl_item(item);
}
}
_ => {}
}
}
fn visit_impl_item(&mut self, item: &syn::ImplItem) {
if let syn::ImplItem::Fn(method) = item {
for stmt in &method.block.stmts {
self.visit_stmt(stmt);
}
}
}
fn visit_stmt(&mut self, stmt: &syn::Stmt) {
match stmt {
syn::Stmt::Expr(expr, _) => self.visit_expr(expr),
syn::Stmt::Local(local) => {
if let Some(init) = &local.init {
self.visit_expr(&init.expr);
}
}
_ => {}
}
}
fn visit_expr(&mut self, expr: &syn::Expr) {
match expr {
syn::Expr::MethodCall(method_call) => self.visit_method_call(method_call),
syn::Expr::Call(call) => self.visit_call(call),
syn::Expr::Block(block) => {
for stmt in &block.block.stmts {
self.visit_stmt(stmt);
}
}
syn::Expr::If(if_expr) => {
self.visit_expr(&if_expr.cond);
self.visit_block(&if_expr.then_branch);
if let Some(else_branch) = &if_expr.else_branch {
self.visit_expr(&else_branch.1);
}
}
syn::Expr::Match(match_expr) => {
self.visit_expr(&match_expr.expr);
for arm in &match_expr.arms {
self.visit_expr(&arm.body);
}
}
syn::Expr::Return(return_expr) => {
if let Some(expr) = &return_expr.expr {
self.visit_expr(expr);
}
}
syn::Expr::Assign(assign_expr) => {
self.visit_expr(&assign_expr.right);
}
_ => {}
}
}
fn visit_block(&mut self, block: &syn::Block) {
for stmt in &block.stmts {
self.visit_stmt(stmt);
}
}
fn visit_method_call(&mut self, method_call: &syn::ExprMethodCall) {
let method_name = method_call.method.to_string();
if self.is_class_builder_method(&method_name) {
self.extract_class_from_method_call(method_call, &method_name);
}
self.visit_expr(&method_call.receiver);
for arg in &method_call.args {
self.visit_expr(arg);
}
}
fn visit_call(&mut self, call: &syn::ExprCall) {
if let syn::Expr::Path(path) = &*call.func {
if let Some(ident) = path.path.get_ident() {
if ident == "class" {
self.extract_direct_class_call(call);
}
}
}
for arg in &call.args {
self.visit_expr(arg);
}
}
fn is_class_builder_method(&self, method_name: &str) -> bool {
matches!(
method_name,
"class"
| "padding"
| "margin"
| "background_color"
| "text_color"
| "border_color"
| "ring_color"
| "width"
| "height"
| "display"
| "flex"
| "grid"
| "responsive"
| "conditional"
| "custom"
)
}
fn extract_class_from_method_call(
&mut self,
method_call: &syn::ExprMethodCall,
method_name: &str,
) {
if let Some(arg) = method_call.args.first() {
match method_name {
"class" => {
if let Ok(class_name) = self.extract_string_literal(arg) {
self.classes.insert(class_name);
}
}
"padding" => {
if let Ok(spacing_value) = self.extract_spacing_value(arg) {
self.classes.insert(format!("p-{}", spacing_value));
}
}
"margin" => {
if let Ok(spacing_value) = self.extract_spacing_value(arg) {
self.classes.insert(format!("m-{}", spacing_value));
}
}
"background_color" => {
if let Ok(color_value) = self.extract_color_value(arg) {
self.classes.insert(format!("bg-{}", color_value));
}
}
"text_color" => {
if let Ok(color_value) = self.extract_color_value(arg) {
self.classes.insert(format!("text-{}", color_value));
}
}
"responsive" => {
self.extract_responsive_classes(method_call);
}
"conditional" => {
self.extract_conditional_classes(method_call);
}
_ => {}
}
}
}
fn extract_direct_class_call(&mut self, call: &syn::ExprCall) {
if let Some(arg) = call.args.first() {
if let Ok(class_name) = self.extract_string_literal(arg) {
self.classes.insert(class_name);
}
}
}
fn extract_string_literal(&self, expr: &syn::Expr) -> Result<String> {
match expr {
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) => Ok(lit_str.value()),
_ => Err(TailwindError::build("Expected string literal".to_string())),
}
}
fn extract_spacing_value(&self, expr: &syn::Expr) -> Result<String> {
match expr {
syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Int(lit_int),
..
}) => Ok(lit_int.to_string()),
syn::Expr::Path(path) => {
if let Some(ident) = path.path.get_ident() {
Ok(ident.to_string().to_lowercase())
} else {
Err(TailwindError::build("Expected identifier".to_string()))
}
}
_ => Err(TailwindError::build("Expected spacing value".to_string())),
}
}
fn extract_color_value(&self, expr: &syn::Expr) -> Result<String> {
match expr {
syn::Expr::Path(path) => {
if let Some(ident) = path.path.get_ident() {
Ok(ident.to_string().to_lowercase())
} else {
Err(TailwindError::build(
"Expected color identifier".to_string(),
))
}
}
_ => Err(TailwindError::build("Expected color value".to_string())),
}
}
fn extract_responsive_classes(&mut self, method_call: &syn::ExprMethodCall) {
if let Some(arg) = method_call.args.first() {
if let Ok(breakpoint) = self.extract_string_literal(arg) {
self.responsive_classes
.entry(breakpoint)
.or_default()
.insert("responsive-class".to_string());
}
}
}
fn extract_conditional_classes(&mut self, method_call: &syn::ExprMethodCall) {
if let Some(arg) = method_call.args.first() {
if let Ok(condition) = self.extract_string_literal(arg) {
self.conditional_classes
.entry(condition)
.or_default()
.insert("conditional-class".to_string());
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ast_parser_creation() {
let parser = AstParser::new();
assert_eq!(parser.class_count(), 0);
assert_eq!(parser.parsed_file_count(), 0);
}
#[test]
fn test_parse_content() {
let mut parser = AstParser::new();
let content = r#"
use tailwind_rs_core::ClassBuilder;
fn create_button() -> String {
ClassBuilder::new()
.class("px-4")
.class("py-2")
.class("bg-blue-500")
.build_string()
}
"#;
parser.parse_content(content).unwrap();
println!("Extracted classes: {:?}", parser.get_classes());
assert!(parser.get_classes().contains("px-4"));
assert!(parser.get_classes().contains("py-2"));
assert!(parser.get_classes().contains("bg-blue-500"));
}
#[test]
fn test_parse_file() {
let mut parser = AstParser::new();
let temp_file = std::env::temp_dir().join("test_rust_file.rs");
let content = r#"
use tailwind_rs_core::ClassBuilder;
fn test() -> String {
ClassBuilder::new().class("test-class").build_string()
}
"#;
std::fs::write(&temp_file, content).unwrap();
parser.parse_file(&temp_file).unwrap();
assert!(parser.get_classes().contains("test-class"));
assert_eq!(parser.parsed_file_count(), 1);
std::fs::remove_file(&temp_file).unwrap();
}
#[test]
fn test_clear() {
let mut parser = AstParser::new();
let content = r#"
ClassBuilder::new().class("test-class").to_string()
"#;
parser.parse_content(content).unwrap();
parser.clear();
assert_eq!(parser.class_count(), 0);
assert_eq!(parser.parsed_file_count(), 0);
}
#[test]
fn test_invalid_rust_code() {
let mut parser = AstParser::new();
let content = "invalid rust code {";
let result = parser.parse_content(content);
assert!(result.is_err());
}
}