use std::path::Path;
use tree_sitter::Node;
use crate::ast::parser::ParserPool;
use crate::types::{InheritanceNode, Language};
use crate::TldrResult;
/// Parses Scala `source` and collects an [`InheritanceNode`] for every class,
/// trait, and object definition found anywhere in the syntax tree.
///
/// Nodes are returned in the order the depth-first walk encounters them.
///
/// # Errors
/// Propagates any error returned by `ParserPool::parse`.
pub fn extract_classes(
    source: &str,
    file_path: &Path,
    parser_pool: &ParserPool,
) -> TldrResult<Vec<InheritanceNode>> {
    let tree = parser_pool.parse(source, Language::Scala)?;
    let mut found = Vec::new();
    visit_node(&tree.root_node(), source, file_path, &mut found);
    Ok(found)
}
/// Depth-first walk that appends an [`InheritanceNode`] for each
/// `class_definition`, `trait_definition`, or `object_definition` it reaches.
///
/// A definition node is recorded before its children are visited, so nested
/// definitions appear after their enclosing definition in `classes`.
fn visit_node(node: &Node, source: &str, file_path: &Path, classes: &mut Vec<InheritanceNode>) {
    let extracted = match node.kind() {
        "class_definition" => extract_class_definition(node, source, file_path),
        "trait_definition" => extract_trait_definition(node, source, file_path),
        "object_definition" => extract_object_definition(node, source, file_path),
        _ => None,
    };
    // `Option` is iterable, so this pushes only when extraction succeeded.
    classes.extend(extracted);

    // Keep descending: definitions may be nested inside other definitions.
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        visit_node(&child, source, file_path, classes);
    }
}
/// Builds an [`InheritanceNode`] for a Scala `class` definition.
///
/// The node's `extend` field, when present, supplies the base types.
/// `is_abstract` is set to `Some(true)` only when an `abstract` modifier is
/// found; otherwise it keeps the value `InheritanceNode::new` gave it.
///
/// Returns `None` if the definition has no `name` field or the name is not valid UTF-8.
fn extract_class_definition(
    node: &Node,
    source: &str,
    file_path: &Path,
) -> Option<InheritanceNode> {
    let bytes = source.as_bytes();
    let class_name = node.child_by_field_name("name")?.utf8_text(bytes).ok()?.to_owned();
    // tree-sitter rows are zero-based; report a 1-based line number.
    let start_line = node.start_position().row as u32 + 1;
    let mut result =
        InheritanceNode::new(class_name, file_path.to_path_buf(), start_line, Language::Scala);

    if let Some(clause) = node.child_by_field_name("extend") {
        result.bases = extract_types_from_extends_clause(&clause, source);
    }
    if has_modifier(node, source, "abstract") {
        result.is_abstract = Some(true);
    }
    Some(result)
}
/// Builds an [`InheritanceNode`] for a Scala `trait` definition.
///
/// Traits are flagged with `interface = Some(true)`. Any types in the trait's
/// `extend` field become its bases.
///
/// Returns `None` if the definition has no `name` field or the name is not valid UTF-8.
fn extract_trait_definition(
    node: &Node,
    source: &str,
    file_path: &Path,
) -> Option<InheritanceNode> {
    let trait_name = node
        .child_by_field_name("name")?
        .utf8_text(source.as_bytes())
        .ok()?
        .to_owned();
    // tree-sitter rows are zero-based; report a 1-based line number.
    let line = node.start_position().row as u32 + 1;
    let mut result =
        InheritanceNode::new(trait_name, file_path.to_path_buf(), line, Language::Scala);
    result.interface = Some(true);

    if let Some(clause) = node.child_by_field_name("extend") {
        result.bases = extract_types_from_extends_clause(&clause, source);
    }
    Some(result)
}
/// Builds an [`InheritanceNode`] for a Scala `object` (singleton) definition.
///
/// Only the name, position, and the base types from the `extend` field are
/// recorded; modifiers are not inspected.
///
/// Returns `None` if the definition has no `name` field or the name is not valid UTF-8.
fn extract_object_definition(
    node: &Node,
    source: &str,
    file_path: &Path,
) -> Option<InheritanceNode> {
    let bytes = source.as_bytes();
    let object_name = node.child_by_field_name("name")?.utf8_text(bytes).ok()?.to_owned();
    let mut object = InheritanceNode::new(
        object_name,
        file_path.to_path_buf(),
        // tree-sitter rows are zero-based; report a 1-based line number.
        node.start_position().row as u32 + 1,
        Language::Scala,
    );

    if let Some(clause) = node.child_by_field_name("extend") {
        object.bases = extract_types_from_extends_clause(&clause, source);
    }
    Some(object)
}
/// Returns the simple (unqualified, un-parameterised) names of the parent types
/// listed in a Scala `extends` clause, in source order.
///
/// Constructor arguments (e.g. the `(name)` in `extends Animal(name)`) and the
/// `extends`/`with` keywords are ignored because only type nodes are matched.
fn extract_types_from_extends_clause(node: &Node, source: &str) -> Vec<String> {
    let mut bases = Vec::new();
    collect_base_types(node, source, &mut bases);
    bases
}

/// Appends the name of every parent type among `node`'s children to `bases`,
/// descending into wrapper nodes that group several types or decorate one.
fn collect_base_types(node: &Node, source: &str, bases: &mut Vec<String>) {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "type_identifier" => {
                if let Ok(text) = child.utf8_text(source.as_bytes()) {
                    bases.push(text.to_string());
                }
            }
            // `Foo[T]`: record only the base name.
            "generic_type" => bases.extend(extract_type_name_from_generic(&child, source)),
            // `pkg.Foo`: record only the final segment.
            "stable_type_identifier" => bases.extend(extract_last_identifier(&child, source)),
            // Depending on the grammar version, `extends A with B` may be parsed as a single
            // `compound_type`, and `@ann Foo` wraps `Foo` in an `annotated_type`. Previously
            // these were skipped, silently losing every parent inside them.
            "compound_type" | "annotated_type" => collect_base_types(&child, source, bases),
            _ => {}
        }
    }
}
fn extract_type_name_from_generic(node: &Node, source: &str) -> Option<String> {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "type_identifier" {
return child
.utf8_text(source.as_bytes())
.ok()
.map(|s| s.to_string());
}
}
None
}
/// Returns the text of the last direct child that is an `identifier` or
/// `type_identifier`. For a qualified path like `a.b.Foo` this is `Foo`.
///
/// Returns `None` if there is no such child or its text is not valid UTF-8.
fn extract_last_identifier(node: &Node, source: &str) -> Option<String> {
    let mut cursor = node.walk();
    let last_name_node = node
        .children(&mut cursor)
        .filter(|child| matches!(child.kind(), "identifier" | "type_identifier"))
        .last()?;
    last_name_node
        .utf8_text(source.as_bytes())
        .ok()
        .map(|s| s.to_string())
}
/// Reports whether a definition node carries `modifier` (e.g. `"abstract"`).
///
/// Children are scanned in order. The first `modifiers` child decides the
/// answer by searching its subtree. A direct child whose node kind equals
/// `modifier` also counts as a match.
fn has_modifier(node: &Node, source: &str, modifier: &str) -> bool {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "modifiers" => return check_modifier_recursive(&child, source, modifier),
            kind if kind == modifier => return true,
            _ => {}
        }
    }
    false
}
/// Depth-first search of `node`'s descendants (not `node` itself) for one
/// whose source text is exactly `modifier`.
///
/// Descendants whose text is not valid UTF-8 are not matched but are still
/// searched beneath.
fn check_modifier_recursive(node: &Node, source: &str, modifier: &str) -> bool {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        let text_matches =
            matches!(child.utf8_text(source.as_bytes()), Ok(text) if text == modifier);
        if text_matches || check_modifier_recursive(&child, source, modifier) {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Runs the extractor over `source` as though it were the file `Test.scala`.
    fn parse_and_extract(source: &str) -> Vec<InheritanceNode> {
        let pool = ParserPool::new();
        let path = PathBuf::from("Test.scala");
        extract_classes(source, &path, &pool).unwrap()
    }

    /// Looks up the extracted node called `name`, panicking if it is missing.
    fn by_name<'a>(nodes: &'a [InheritanceNode], name: &str) -> &'a InheritanceNode {
        nodes
            .iter()
            .find(|n| n.name == name)
            .unwrap_or_else(|| panic!("no node named {}", name))
    }

    /// True when `node` lists `base` among its parents.
    fn has_base(node: &InheritanceNode, base: &str) -> bool {
        node.bases.iter().any(|b| b == base)
    }

    #[test]
    fn test_simple_class() {
        let source = r#"
class Animal {
def speak(): String = "..."
}
"#;
        let nodes = parse_and_extract(source);
        assert_eq!(nodes.len(), 1);
        let animal = &nodes[0];
        assert_eq!(animal.name, "Animal");
        assert!(animal.bases.is_empty());
        assert_eq!(animal.language, Language::Scala);
    }

    #[test]
    fn test_class_extends() {
        let source = r#"
class Animal(val name: String)
class Dog(name: String) extends Animal(name) {
def bark(): String = "Woof"
}
"#;
        let nodes = parse_and_extract(source);
        assert_eq!(nodes.len(), 2);
        assert!(has_base(by_name(&nodes, "Dog"), "Animal"));
    }

    #[test]
    fn test_trait_definition() {
        let source = r#"
trait Serializable {
def serialize(): String
}
"#;
        let nodes = parse_and_extract(source);
        assert_eq!(nodes.len(), 1);
        let serializable = &nodes[0];
        assert_eq!(serializable.name, "Serializable");
        assert_eq!(serializable.interface, Some(true));
    }

    #[test]
    fn test_trait_extends_trait() {
        let source = r#"
trait Base {
def id: String
}
trait Extended extends Base {
def name: String
}
"#;
        let nodes = parse_and_extract(source);
        let extended = by_name(&nodes, "Extended");
        assert!(has_base(extended, "Base"));
        assert_eq!(extended.interface, Some(true));
    }

    #[test]
    fn test_object_extends() {
        let source = r#"
trait Compress {
def headers: Seq[String]
}
object Gzip extends Compress {
def headers = Seq("gzip")
}
"#;
        let nodes = parse_and_extract(source);
        assert!(has_base(by_name(&nodes, "Gzip"), "Compress"));
    }

    #[test]
    fn test_class_with_mixin() {
        let source = r#"
class Animal(val name: String)
trait Serializable {
def serialize(): String
}
class Dog(name: String) extends Animal(name) with Serializable {
def serialize() = s"Dog($name)"
}
"#;
        let nodes = parse_and_extract(source);
        let dog = by_name(&nodes, "Dog");
        assert!(has_base(dog, "Animal"));
        assert!(has_base(dog, "Serializable"));
        assert_eq!(dog.bases.len(), 2);
    }

    #[test]
    fn test_case_class() {
        let source = r#"
case class Request(url: String, method: String)
"#;
        let nodes = parse_and_extract(source);
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].name, "Request");
    }

    #[test]
    fn test_abstract_class() {
        let source = r#"
abstract class Shape {
def area(): Double
}
"#;
        let nodes = parse_and_extract(source);
        assert_eq!(nodes.len(), 1);
        let shape = &nodes[0];
        assert_eq!(shape.name, "Shape");
        assert_eq!(shape.is_abstract, Some(true));
    }
}