use std::path::Path;
use tree_sitter::Node;
use crate::ast::parser::ParserPool;
use crate::types::{InheritanceNode, Language};
use crate::TldrResult;
/// Extracts every Swift type declaration in `source` as an [`InheritanceNode`].
///
/// The text is parsed as Swift using a parser obtained from `parser_pool`, and the
/// entire resulting syntax tree is searched (see `visit_node` for which node kinds
/// are recorded). `file_path` is attached to every node that is produced.
///
/// # Errors
///
/// Propagates any error returned by `parser_pool.parse` for this source.
pub fn extract_classes(
    source: &str,
    file_path: &Path,
    parser_pool: &ParserPool,
) -> TldrResult<Vec<InheritanceNode>> {
    let syntax_tree = parser_pool.parse(source, Language::Swift)?;
    let mut declarations = Vec::new();
    visit_node(&syntax_tree.root_node(), source, file_path, &mut declarations);
    Ok(declarations)
}
/// Walks the subtree rooted at `node` in pre-order and appends every Swift type
/// declaration that yields an [`InheritanceNode`] to `classes`.
///
/// Two node kinds are recorded:
/// - `class_declaration`, which also covers `struct` and `enum` declarations as
///   the tests exercise; `extract_class_declaration` rejects extensions.
/// - `protocol_declaration`.
///
/// The walk always continues into a node's children, including recorded
/// declarations and skipped extensions, so nested and local declarations are
/// found too. Results are appended in pre-order (document) order.
///
/// The traversal uses a `TreeCursor` loop rather than recursion. A very deeply
/// nested syntax tree, such as a long chained expression, therefore cannot
/// exhaust the call stack.
fn visit_node(node: &Node, source: &str, file_path: &Path, classes: &mut Vec<InheritanceNode>) {
    let mut cursor = node.walk();
    loop {
        let current = cursor.node();
        let found = match current.kind() {
            "class_declaration" => extract_class_declaration(&current, source, file_path),
            "protocol_declaration" => extract_protocol_declaration(&current, source, file_path),
            _ => None,
        };
        // `Option` is iterable, so this pushes only when something was extracted.
        classes.extend(found);

        // Descend first. Otherwise move to the next sibling, climbing back up
        // until a sibling exists. A cursor created from `node` cannot leave
        // that node's subtree, so failing to climb means the walk is complete.
        if cursor.goto_first_child() {
            continue;
        }
        while !cursor.goto_next_sibling() {
            if !cursor.goto_parent() {
                return;
            }
        }
    }
}
/// Builds an [`InheritanceNode`] from a Swift `class_declaration` node.
///
/// Returns `None` in these cases:
/// - the declaration is an `extension`, which adds no new type;
/// - the `declaration_kind` or `name` field is missing;
/// - the text of either field cannot be read.
///
/// The recorded line is 1-based, taken from the node's start row. The bases are
/// the declaration's inheritance specifiers.
fn extract_class_declaration(
    node: &Node,
    source: &str,
    file_path: &Path,
) -> Option<InheritanceNode> {
    let bytes = source.as_bytes();
    let declaration_kind = node
        .child_by_field_name("declaration_kind")?
        .utf8_text(bytes)
        .ok()?;
    // Extensions share this node kind but only augment an existing type.
    if declaration_kind == "extension" {
        return None;
    }

    let type_name = node
        .child_by_field_name("name")
        .and_then(|name| extract_type_name(&name, source))?;
    let start_line = node.start_position().row as u32 + 1;

    let mut declared = InheritanceNode::new(
        type_name,
        file_path.to_path_buf(),
        start_line,
        Language::Swift,
    );
    declared.bases = extract_inheritance_specifiers(node, source);
    Some(declared)
}
/// Builds an [`InheritanceNode`] from a Swift `protocol_declaration` node.
///
/// The result is flagged as both an interface and a protocol. Its bases are the
/// protocols listed after `:`, and its line is the 1-based start row of the node.
/// Returns `None` when the `name` field is missing or its text cannot be read.
fn extract_protocol_declaration(
    node: &Node,
    source: &str,
    file_path: &Path,
) -> Option<InheritanceNode> {
    let protocol_name = node
        .child_by_field_name("name")
        .and_then(|name| name.utf8_text(source.as_bytes()).ok())
        .map(String::from)?;

    let mut protocol = InheritanceNode::new(
        protocol_name,
        file_path.to_path_buf(),
        node.start_position().row as u32 + 1,
        Language::Swift,
    );
    protocol.protocol = Some(true);
    protocol.interface = Some(true);
    protocol.bases = extract_inheritance_specifiers(node, source);
    Some(protocol)
}
/// Returns the name text for a type node taken from a declaration's `name` field.
///
/// - `user_type`: returns the text of its *last* direct `type_identifier` child.
///   A qualified name such as `Outer.Inner` therefore yields `Inner`, the type
///   actually named, rather than the leading qualifier. If there is no
///   `type_identifier` child, the full node text is returned instead.
/// - Any other kind, including `type_identifier`: returns the node's full text.
///
/// Returns `None` if the selected node's text cannot be decoded from `source`.
fn extract_type_name(node: &Node, source: &str) -> Option<String> {
    let target = if node.kind() == "user_type" {
        let mut cursor = node.walk();
        // Only direct children are scanned, so identifiers inside generic
        // arguments (e.g. the `Int` in `Box<Int>`) are never chosen.
        let last_ident = node
            .children(&mut cursor)
            .filter(|child| child.kind() == "type_identifier")
            .last();
        last_ident.unwrap_or(*node)
    } else {
        *node
    };
    target.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
}
/// Collects the base-type names listed in a declaration's inheritance clause.
///
/// Only direct `inheritance_specifier` children of `node` are examined. Each one's
/// `inherits_from` field is resolved with `extract_inherited_type_name`, and the
/// resolved name is trimmed. Names that are empty or start with `@` are dropped.
/// Order follows the source.
fn extract_inheritance_specifiers(node: &Node, source: &str) -> Vec<String> {
    let mut cursor = node.walk();
    node.children(&mut cursor)
        .filter(|child| child.kind() == "inheritance_specifier")
        .filter_map(|spec| spec.child_by_field_name("inherits_from"))
        .filter_map(|target| extract_inherited_type_name(&target, source))
        .map(|raw| raw.trim().to_string())
        .filter(|base| !base.is_empty() && !base.starts_with('@'))
        .collect()
}
/// Resolves the name of a type referenced in an inheritance clause.
///
/// - `type_identifier`: returns its text.
/// - `user_type`: returns the text of its *last* direct `type_identifier` child.
///   `UIKit.UIView` therefore yields `UIView`, the actual base type, and
///   `Outer.Inner` yields `Inner`. These match the simple names that declarations
///   are recorded under. Taking the first segment would return the module or
///   outer-type qualifier instead. If there is no `type_identifier` child, the
///   full node text is returned.
/// - Any other node kind: returns `None`, so the reference is ignored.
///
/// Also returns `None` if the selected node's text cannot be decoded from `source`.
fn extract_inherited_type_name(node: &Node, source: &str) -> Option<String> {
    let target = match node.kind() {
        "type_identifier" => *node,
        "user_type" => {
            let mut cursor = node.walk();
            // Only direct children are scanned, so generic arguments such as
            // the `Int` in `Box<Int>` are never mistaken for the base name.
            let last_ident = node
                .children(&mut cursor)
                .filter(|child| child.kind() == "type_identifier")
                .last();
            last_ident.unwrap_or(*node)
        }
        _ => return None,
    };
    target.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Parses `source` as Swift and returns every extracted declaration.
    fn parse_and_extract(source: &str) -> Vec<InheritanceNode> {
        let pool = ParserPool::new();
        extract_classes(source, &PathBuf::from("Test.swift"), &pool).unwrap()
    }

    #[test]
    fn test_simple_class() {
        let source = r#"
class Animal {
func speak() -> String {
return "..."
}
}
"#;
        let classes = parse_and_extract(source);
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].name, "Animal");
        assert!(classes[0].bases.is_empty());
        assert_eq!(classes[0].language, Language::Swift);
    }

    #[test]
    fn test_class_inherits() {
        let source = r#"
class Animal {
func speak() -> String { return "..." }
}
class Dog: Animal {
override func speak() -> String { return "Woof" }
}
"#;
        let classes = parse_and_extract(source);
        assert_eq!(classes.len(), 2);
        let dog = classes.iter().find(|c| c.name == "Dog").unwrap();
        assert!(dog.bases.contains(&"Animal".to_string()));
    }

    #[test]
    fn test_protocol_declaration() {
        let source = r#"
protocol Serializable {
func serialize() -> String
}
"#;
        let classes = parse_and_extract(source);
        assert_eq!(classes.len(), 1);
        assert_eq!(classes[0].name, "Serializable");
        assert_eq!(classes[0].protocol, Some(true));
        assert_eq!(classes[0].interface, Some(true));
    }

    #[test]
    fn test_class_with_protocol() {
        let source = r#"
protocol ParameterEncoder {
func encode() -> String
}
class JSONParameterEncoder: ParameterEncoder {
func encode() -> String { return "{}" }
}
"#;
        let classes = parse_and_extract(source);
        let encoder = classes
            .iter()
            .find(|c| c.name == "JSONParameterEncoder")
            .unwrap();
        assert!(encoder.bases.contains(&"ParameterEncoder".to_string()));
    }

    #[test]
    fn test_class_inherits_and_conforms() {
        let source = r#"
class Request {
var url: String = ""
}
protocol Sendable {}
class DownloadRequest: Request, Sendable {
var destination: String = ""
}
"#;
        let classes = parse_and_extract(source);
        let download = classes
            .iter()
            .find(|c| c.name == "DownloadRequest")
            .unwrap();
        assert!(download.bases.contains(&"Request".to_string()));
        assert!(download.bases.contains(&"Sendable".to_string()));
        assert_eq!(download.bases.len(), 2);
    }

    #[test]
    fn test_protocol_inherits_protocol() {
        let source = r#"
protocol Base {
func id() -> String
}
protocol Extended: Base {
func name() -> String
}
"#;
        let classes = parse_and_extract(source);
        let extended = classes.iter().find(|c| c.name == "Extended").unwrap();
        assert!(extended.bases.contains(&"Base".to_string()));
        assert_eq!(extended.protocol, Some(true));
    }

    #[test]
    fn test_struct_with_protocol() {
        let source = r#"
protocol Encodable {}
struct Options: Encodable {
var rawValue: Int = 0
}
"#;
        let classes = parse_and_extract(source);
        let opts = classes.iter().find(|c| c.name == "Options").unwrap();
        assert!(opts.bases.contains(&"Encodable".to_string()));
    }

    #[test]
    fn test_enum_declaration() {
        let source = r#"
protocol RawRepresentable {}
enum Direction: RawRepresentable {
case north
case south
}
"#;
        let classes = parse_and_extract(source);
        let direction = classes.iter().find(|c| c.name == "Direction").unwrap();
        assert!(direction.bases.contains(&"RawRepresentable".to_string()));
    }

    /// Extensions must not produce a node of their own, and the conformance
    /// they add must not leak onto the original class.
    #[test]
    fn test_extension_is_not_recorded() {
        let source = r#"
class Dog {}
protocol Barker {}
extension Dog: Barker {
func bark() {}
}
"#;
        let classes = parse_and_extract(source);
        assert_eq!(classes.len(), 2);
        let dog = classes.iter().find(|c| c.name == "Dog").unwrap();
        assert!(dog.bases.is_empty());
    }

    /// Declarations nested inside another type body are still discovered.
    #[test]
    fn test_nested_declarations() {
        let source = r#"
class Outer {
class Inner: Outer {}
}
"#;
        let classes = parse_and_extract(source);
        assert_eq!(classes.len(), 2);
        let inner = classes.iter().find(|c| c.name == "Inner").unwrap();
        assert!(inner.bases.contains(&"Outer".to_string()));
    }

    /// Generic arguments on a base type are not reported as bases.
    #[test]
    fn test_generic_base_uses_type_name() {
        let source = r#"
class Box<T> {}
class IntBox: Box<Int> {}
"#;
        let classes = parse_and_extract(source);
        let int_box = classes.iter().find(|c| c.name == "IntBox").unwrap();
        assert_eq!(int_box.bases, vec!["Box".to_string()]);
    }
}