use codegraph_parser_api::{
CallRelation, ClassEntity, FunctionEntity, ImplementationRelation, ImportRelation,
InheritanceRelation, Parameter, ParserConfig, TraitEntity,
};
use tree_sitter::Node;
/// Walks a tree-sitter Swift syntax tree and accumulates code-graph entities
/// (types, protocols, functions) and relations (imports, calls, superclasses,
/// protocol conformances) into its public collections.
///
/// Construct with [`SwiftVisitor::new`], call [`SwiftVisitor::visit_node`] on the
/// root node, then read the collections.
pub struct SwiftVisitor<'a> {
    /// Bytes of the source the visited tree was parsed from; node text is sliced from here.
    pub source: &'a [u8],
    // Stored by the constructor but not read anywhere in this file, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    pub config: ParserConfig,

    // ---- extracted entities ----
    /// Classes, structs (attribute `"struct"`) and enums (attribute `"enum"`).
    pub classes: Vec<ClassEntity>,
    /// Protocols, recorded as traits.
    pub traits: Vec<TraitEntity>,
    /// Free functions, methods, initializers and deinitializers.
    pub functions: Vec<FunctionEntity>,

    // ---- extracted relations ----
    /// One entry per visited `import` declaration.
    pub imports: Vec<ImportRelation>,
    /// Call sites found while walking function-like bodies.
    pub calls: Vec<CallRelation>,
    /// Superclass relations.
    pub inheritance: Vec<InheritanceRelation>,
    /// Protocol conformances.
    pub implementations: Vec<ImplementationRelation>,

    // ---- traversal state ----
    /// Name of the type whose members are currently being visited, if any.
    current_class: Option<String>,
    /// Name of the function-like declaration whose body is being walked; used as the caller of calls.
    current_function: Option<String>,
}
impl<'a> SwiftVisitor<'a> {
pub fn new(source: &'a [u8], config: ParserConfig) -> Self {
Self {
source,
config,
functions: Vec::new(),
classes: Vec::new(),
traits: Vec::new(),
imports: Vec::new(),
calls: Vec::new(),
inheritance: Vec::new(),
implementations: Vec::new(),
current_class: None,
current_function: None,
}
}
/// Returns the source text spanned by `node` as an owned string, or an empty
/// string when that span is not valid UTF-8.
fn node_text(&self, node: Node) -> String {
    match node.utf8_text(self.source) {
        Ok(text) => text.to_owned(),
        Err(_) => String::new(),
    }
}
/// Classifies a `class_declaration` node by the declaration keyword found
/// among its direct children.
///
/// The Swift grammar reports several declaration forms under the single
/// `class_declaration` kind. Extensions are expected to arrive here as well
/// (keyword token `extension`). Returns `"struct"`, `"enum"`, `"class"` or
/// `"extension"` for the first such keyword child. It falls back to `"class"`
/// when none is present (for example `actor`).
fn get_declaration_type(&self, node: Node) -> String {
    let mut cursor = node.walk();
    let keyword = node
        .children(&mut cursor)
        .find_map(|child| match child.kind() {
            kind @ ("struct" | "enum" | "class" | "extension") => Some(kind),
            _ => None,
        });
    keyword.unwrap_or("class").to_string()
}
/// Visits `node` and, unless it is a declaration handled by a dedicated
/// visitor, recurses into its children.
///
/// Declaration visitors walk their own bodies, so recursion stops at them.
/// Call expressions are recorded and then recursed into, so nested calls are
/// found too. Top-level `function_declaration`s are only visited when no
/// enclosing type is active.
pub fn visit_node(&mut self, node: Node) {
    let should_recurse = match node.kind() {
        "class_declaration" => {
            match self.get_declaration_type(node).as_str() {
                "struct" => self.visit_struct(node),
                "enum" => self.visit_enum(node),
                // An extension adds members to an existing type. Visiting it as a
                // class would emit a duplicate ClassEntity and mistake its first
                // conformance for a superclass.
                "extension" => self.visit_extension(node),
                _ => self.visit_class(node),
            }
            false
        }
        "struct_declaration" => {
            self.visit_struct(node);
            false
        }
        "protocol_declaration" => {
            self.visit_protocol(node);
            false
        }
        "function_declaration" => {
            if self.current_class.is_none() {
                self.visit_function(node);
            }
            false
        }
        "import_declaration" => {
            self.visit_import(node);
            false
        }
        "enum_declaration" => {
            self.visit_enum(node);
            false
        }
        "extension_declaration" => {
            self.visit_extension(node);
            false
        }
        "call_expression" => {
            self.visit_call_expression(node);
            true
        }
        _ => true,
    };
    if should_recurse {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.visit_node(child);
        }
    }
}
/// Records a class declaration as a [`ClassEntity`], then visits its members
/// with the class as the enclosing type.
///
/// Superclass and conformance relations are recorded through
/// `extract_inheritance`. Declarations without a resolvable name are skipped
/// entirely.
fn visit_class(&mut self, node: Node) {
    let Some(name) = self.extract_type_name(node) else {
        return;
    };

    let (mut base_classes, mut implemented_traits) = (Vec::new(), Vec::new());
    self.extract_inheritance(node, &name, &mut base_classes, &mut implemented_traits);

    let type_parameters = self.extract_generic_params(node);
    let doc_comment = self.extract_doc_comment(node);
    let visibility = self.extract_visibility(node);
    let (line_start, line_end) = (node.start_position().row + 1, node.end_position().row + 1);

    self.classes.push(ClassEntity {
        name: name.clone(),
        visibility,
        line_start,
        line_end,
        is_abstract: false,
        is_interface: false,
        base_classes,
        implemented_traits,
        methods: Vec::new(),
        fields: Vec::new(),
        doc_comment,
        attributes: Vec::new(),
        type_parameters,
    });

    // Members are attributed to this class while its body is walked.
    let enclosing = std::mem::replace(&mut self.current_class, Some(name));
    match node.child_by_field_name("body") {
        Some(body) => self.visit_class_body(body),
        None => {
            let mut cursor = node.walk();
            for body in node
                .children(&mut cursor)
                .filter(|candidate| candidate.kind() == "class_body")
            {
                self.visit_class_body(body);
            }
        }
    }
    self.current_class = enclosing;
}
fn visit_struct(&mut self, node: Node) {
let name = self.extract_type_name(node);
if name.is_none() {
return;
}
let name = name.unwrap();
let previous_class = self.current_class.take();
self.current_class = Some(name.clone());
let mut base_classes = Vec::new();
let mut implemented_traits = Vec::new();
self.extract_inheritance(node, &name, &mut base_classes, &mut implemented_traits);
let type_params = self.extract_generic_params(node);
let doc_comment = self.extract_doc_comment(node);
let visibility = self.extract_visibility(node);
let class_entity = ClassEntity {
name: name.clone(),
visibility,
line_start: node.start_position().row + 1,
line_end: node.end_position().row + 1,
is_abstract: false,
is_interface: false,
base_classes,
implemented_traits,
methods: Vec::new(),
fields: Vec::new(),
doc_comment,
attributes: vec!["struct".to_string()],
type_parameters: type_params,
};
self.classes.push(class_entity);
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "class_body" || child.kind() == "struct_body" {
self.visit_class_body(child);
}
}
self.current_class = previous_class;
}
/// Records a protocol declaration as a [`TraitEntity`].
///
/// Method requirements (`protocol_function_declaration` / `function_declaration`
/// members of the `protocol_body`) become `required_methods`. Protocols listed
/// after `:` (`protocol A: B, C`) become `parent_traits`. They are found in
/// `inheritance_specifier` or `type_inheritance_clause` children, matching the
/// node kinds `extract_inheritance` accepts. Declarations without a resolvable
/// name are skipped.
fn visit_protocol(&mut self, node: Node) {
    let Some(name) = self.extract_type_name(node) else {
        return;
    };
    let doc_comment = self.extract_doc_comment(node);
    let visibility = self.extract_visibility(node);
    let mut required_methods = Vec::new();
    let mut parent_traits = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "protocol_body" => {
                let mut body_cursor = child.walk();
                for member in child.children(&mut body_cursor) {
                    if matches!(
                        member.kind(),
                        "protocol_function_declaration" | "function_declaration"
                    ) {
                        if let Some(method) = self.extract_function_signature(member) {
                            required_methods.push(method);
                        }
                    }
                }
            }
            "inheritance_specifier" | "type_inheritance_clause" => {
                let mut inherit_cursor = child.walk();
                for parent in child.children(&mut inherit_cursor) {
                    if matches!(
                        parent.kind(),
                        "type_identifier" | "user_type" | "simple_identifier"
                    ) {
                        parent_traits.push(self.node_text(parent));
                    }
                }
            }
            _ => {}
        }
    }
    let trait_entity = TraitEntity {
        name,
        visibility,
        line_start: node.start_position().row + 1,
        line_end: node.end_position().row + 1,
        doc_comment,
        required_methods,
        parent_traits,
        attributes: Vec::new(),
    };
    self.traits.push(trait_entity);
}
/// Visits the members of a type body (class/struct/enum/extension body).
///
/// Methods and subscripts go to `visit_method`, initializers to `visit_init`
/// and deinitializers to `visit_deinit`. Nested type declarations are
/// dispatched to the matching type visitor. Other members (properties, enum
/// cases, ...) are ignored.
fn visit_class_body(&mut self, node: Node) {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "function_declaration" | "subscript_declaration" => self.visit_method(child),
            "init_declaration" | "initializer_declaration" => self.visit_init(child),
            "deinit_declaration" => self.visit_deinit(child),
            // Nested structs/enums share the `class_declaration` kind. Dispatch on
            // the declaration keyword as `visit_node` does, so they keep their
            // `"struct"` / `"enum"` tagging instead of being recorded as classes.
            "class_declaration" => match self.get_declaration_type(child).as_str() {
                "struct" => self.visit_struct(child),
                "enum" => self.visit_enum(child),
                _ => self.visit_class(child),
            },
            "struct_declaration" => self.visit_struct(child),
            "enum_declaration" => self.visit_enum(child),
            "protocol_declaration" => self.visit_protocol(child),
            _ => {}
        }
    }
}
/// Records a free (top-level) function as a [`FunctionEntity`].
///
/// It then walks the function's body with the function as the current
/// caller, so that call sites inside it are attributed to it. The signature
/// stored is the first source line of the declaration. Unnamed declarations
/// are skipped.
fn visit_function(&mut self, node: Node) {
    let Some(name) = self.extract_function_name(node) else {
        return;
    };
    let declaration_text = self.node_text(node);
    let first_line = declaration_text.lines().next().unwrap_or("");
    let visibility = self.extract_visibility(node);

    let mut entity = FunctionEntity::new(
        &name,
        node.start_position().row + 1,
        node.end_position().row + 1,
    )
    .with_visibility(&visibility)
    .with_signature(first_line);
    entity.parameters = self.extract_parameters(node);
    entity.return_type = self.extract_return_type(node);
    entity.is_static = self.has_modifier(node, "static");
    entity.is_async = self.has_modifier(node, "async");
    entity.doc_comment = self.extract_doc_comment(node);
    entity.parent_class = self.current_class.clone();
    self.functions.push(entity);

    let outer_function = std::mem::replace(&mut self.current_function, Some(name));
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(child.kind(), "function_body" | "code_block") {
            self.visit_function_body(child);
        }
    }
    self.current_function = outer_function;
}
/// Records a method (or subscript) declared inside a type body as a
/// [`FunctionEntity`] whose `parent_class` is the enclosing type.
///
/// `static` and `class` methods are flagged `is_static`. `override` and
/// `mutating` modifiers are copied into `attributes`, in that order. The body
/// is then walked with the method as the current caller. Unnamed declarations
/// are skipped.
fn visit_method(&mut self, node: Node) {
    let Some(name) = self.extract_function_name(node) else {
        return;
    };
    let declaration_text = self.node_text(node);
    let visibility = self.extract_visibility(node);

    let mut method = FunctionEntity::new(
        &name,
        node.start_position().row + 1,
        node.end_position().row + 1,
    )
    .with_visibility(&visibility)
    .with_signature(declaration_text.lines().next().unwrap_or(""));
    method.parameters = self.extract_parameters(node);
    method.return_type = self.extract_return_type(node);
    method.is_static = self.has_modifier(node, "static") || self.has_modifier(node, "class");
    method.is_async = self.has_modifier(node, "async");
    method.doc_comment = self.extract_doc_comment(node);
    method.parent_class = self.current_class.clone();
    for keyword in ["override", "mutating"] {
        if self.has_modifier(node, keyword) {
            method.attributes.push(keyword.to_string());
        }
    }
    self.functions.push(method);

    let outer_function = std::mem::replace(&mut self.current_function, Some(name));
    let mut cursor = node.walk();
    for body in node
        .children(&mut cursor)
        .filter(|child| matches!(child.kind(), "function_body" | "code_block"))
    {
        self.visit_function_body(body);
    }
    self.current_function = outer_function;
}
/// Records an initializer as a [`FunctionEntity`] named `"init"`.
///
/// The entity is tagged with the `"init"` attribute, plus `"convenience"` and
/// `"required"` when those modifiers are present. The body is then walked so
/// that calls inside the initializer are attributed to `"init"`.
fn visit_init(&mut self, node: Node) {
    const INIT: &str = "init";

    let declaration_text = self.node_text(node);
    let visibility = self.extract_visibility(node);
    let mut initializer = FunctionEntity::new(
        INIT,
        node.start_position().row + 1,
        node.end_position().row + 1,
    )
    .with_visibility(&visibility)
    .with_signature(declaration_text.lines().next().unwrap_or(""));
    initializer.parameters = self.extract_parameters(node);
    initializer.doc_comment = self.extract_doc_comment(node);
    initializer.parent_class = self.current_class.clone();
    initializer.attributes.push(INIT.to_string());
    for keyword in ["convenience", "required"] {
        if self.has_modifier(node, keyword) {
            initializer.attributes.push(keyword.to_string());
        }
    }
    self.functions.push(initializer);

    let outer_function = std::mem::replace(&mut self.current_function, Some(INIT.to_string()));
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(child.kind(), "function_body" | "code_block") {
            self.visit_function_body(child);
        }
    }
    self.current_function = outer_function;
}
/// Records a deinitializer as a [`FunctionEntity`] named `"deinit"`.
///
/// The entity gets `"internal"` visibility and the `"deinit"` attribute. Like
/// `visit_init`, it then walks the body with `"deinit"` as the current
/// function, so calls made during teardown are recorded. Previously the body
/// was never visited and those calls were lost.
fn visit_deinit(&mut self, node: Node) {
    let name = "deinit".to_string();
    let doc_comment = self.extract_doc_comment(node);
    let mut func = FunctionEntity::new(
        &name,
        node.start_position().row + 1,
        node.end_position().row + 1,
    )
    .with_visibility("internal")
    .with_signature(self.node_text(node).lines().next().unwrap_or(""));
    func.doc_comment = doc_comment;
    func.parent_class = self.current_class.clone();
    func.attributes.push("deinit".to_string());
    self.functions.push(func);

    let previous_function = self.current_function.replace(name);
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(child.kind(), "function_body" | "code_block") {
            self.visit_function_body(child);
        }
    }
    self.current_function = previous_function;
}
/// Walks every descendant of a function-like body in pre-order and records
/// each `call_expression` found, via `visit_call_expression`.
///
/// `node` itself is not inspected, only its descendants. An explicit stack
/// replaces recursion. Children are pushed in reverse so they pop in source
/// order, so outer calls are still recorded before the calls nested inside
/// them.
fn visit_function_body(&mut self, node: Node) {
    let mut pending: Vec<Node> = {
        let mut cursor = node.walk();
        node.children(&mut cursor).collect()
    };
    pending.reverse();
    while let Some(current) = pending.pop() {
        if current.kind() == "call_expression" {
            self.visit_call_expression(current);
        }
        let mut cursor = current.walk();
        let children: Vec<Node> = current.children(&mut cursor).collect();
        pending.extend(children.into_iter().rev());
    }
}
/// Records a [`CallRelation`] from the current function to the callee of
/// `node`.
///
/// Nothing is recorded outside a function-like body (no current function), or
/// when no callee name can be extracted.
fn visit_call_expression(&mut self, node: Node) {
    let Some(caller) = self.current_function.clone() else {
        return;
    };
    let callee = self.extract_callee_name(node);
    if !callee.is_empty() {
        self.calls.push(CallRelation {
            caller,
            callee,
            call_site_line: node.start_position().row + 1,
            is_direct: true,
        });
    }
}
/// Returns the name of the function called by a `call_expression`, or an
/// empty string if none can be determined.
///
/// For a plain call (`foo()`) this is the identifier child. For a member call
/// (`a.b.foo()`) it is the last path component (`foo`). The `suffix` field of
/// a `navigation_expression` may be a `navigation_suffix` node whose text
/// still carries the leading dot. Its own `suffix` field is used when present,
/// and any leading `.` is stripped, so callees are never recorded as `.foo`.
fn extract_callee_name(&self, node: Node) -> String {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "simple_identifier" | "identifier" => {
                return self.node_text(child);
            }
            "navigation_expression" => {
                if let Some(suffix) = child.child_by_field_name("suffix") {
                    let name_node = suffix.child_by_field_name("suffix").unwrap_or(suffix);
                    let text = self.node_text(name_node);
                    let name = text.trim_start_matches('.').trim();
                    if !name.is_empty() {
                        return name.to_string();
                    }
                }
                let mut inner_cursor = child.walk();
                let last_identifier = child
                    .children(&mut inner_cursor)
                    .filter(|inner| matches!(inner.kind(), "simple_identifier" | "identifier"))
                    .last();
                if let Some(identifier) = last_identifier {
                    let text = self.node_text(identifier);
                    if !text.is_empty() {
                        return text;
                    }
                }
            }
            _ => {}
        }
    }
    String::new()
}
/// Records an enum declaration as a [`ClassEntity`] tagged with the `"enum"`
/// attribute, then visits its body with the enum as the enclosing type.
///
/// Visiting the body means methods, initializers and nested types declared
/// inside the enum are recorded with the enum as their parent. Before, they
/// were silently dropped. Types listed after `:` are not recorded: an enum's
/// raw type (`enum E: String`) cannot be told apart from a protocol
/// syntactically. Unnamed declarations are skipped.
fn visit_enum(&mut self, node: Node) {
    let Some(name) = self.extract_type_name(node) else {
        return;
    };
    let doc_comment = self.extract_doc_comment(node);
    let visibility = self.extract_visibility(node);
    let type_params = self.extract_generic_params(node);
    let class_entity = ClassEntity {
        name: name.clone(),
        visibility,
        line_start: node.start_position().row + 1,
        line_end: node.end_position().row + 1,
        is_abstract: false,
        is_interface: false,
        base_classes: Vec::new(),
        implemented_traits: Vec::new(),
        methods: Vec::new(),
        fields: Vec::new(),
        doc_comment,
        attributes: vec!["enum".to_string()],
        type_parameters: type_params,
    };
    self.classes.push(class_entity);

    let previous_class = self.current_class.replace(name);
    if let Some(body) = node.child_by_field_name("body") {
        self.visit_class_body(body);
    } else {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if matches!(child.kind(), "enum_class_body" | "enum_body" | "class_body") {
                self.visit_class_body(child);
            }
        }
    }
    self.current_class = previous_class;
}
/// Visits an extension: records its protocol conformances and visits its
/// members with the extended type as the enclosing type.
///
/// No [`ClassEntity`] is created for the extension itself. Extensions whose
/// extended type cannot be resolved are skipped.
fn visit_extension(&mut self, node: Node) {
    let Some(extended_type) = self.extract_extension_type(node) else {
        return;
    };

    // Conformances are recorded into `self.implementations`; the local list is
    // not used afterwards.
    let mut conformances = Vec::new();
    self.extract_extension_protocols(node, &extended_type, &mut conformances);

    let enclosing = std::mem::replace(&mut self.current_class, Some(extended_type));
    let mut cursor = node.walk();
    for body in node
        .children(&mut cursor)
        .filter(|child| matches!(child.kind(), "extension_body" | "class_body"))
    {
        self.visit_class_body(body);
    }
    self.current_class = enclosing;
}
/// Records one [`ImportRelation`] for an `import` declaration.
///
/// The imported name is the text of the first `identifier` /
/// `simple_identifier` child. The importer is the fixed string `"file"`, and
/// every import is marked wildcard with no symbols or alias. Nothing is
/// recorded when no identifier child exists.
fn visit_import(&mut self, node: Node) {
    let mut cursor = node.walk();
    let module = node
        .children(&mut cursor)
        .find(|child| matches!(child.kind(), "identifier" | "simple_identifier"));
    let Some(module) = module else {
        return;
    };
    let imported = self.node_text(module);
    self.imports.push(ImportRelation {
        importer: "file".to_string(),
        imported,
        symbols: Vec::new(),
        is_wildcard: true,
        alias: None,
    });
}
/// Returns the declared name of a type declaration.
///
/// The `name` field is preferred. Otherwise the first `type_identifier`,
/// `simple_identifier` or `identifier` child is used. Returns `None` if
/// neither exists.
fn extract_type_name(&self, node: Node) -> Option<String> {
    if let Some(named) = node.child_by_field_name("name") {
        return Some(self.node_text(named));
    }
    let mut cursor = node.walk();
    let fallback = node.children(&mut cursor).find(|child| {
        matches!(
            child.kind(),
            "type_identifier" | "simple_identifier" | "identifier"
        )
    });
    fallback.map(|found| self.node_text(found))
}
/// Returns the declared name of a function-like declaration.
///
/// The `name` field is preferred. Otherwise the first `simple_identifier` /
/// `identifier` child is used. Returns `None` if neither exists.
fn extract_function_name(&self, node: Node) -> Option<String> {
    match node.child_by_field_name("name") {
        Some(named) => Some(self.node_text(named)),
        None => {
            let mut cursor = node.walk();
            let fallback = node
                .children(&mut cursor)
                .find(|child| matches!(child.kind(), "simple_identifier" | "identifier"));
            fallback.map(|found| self.node_text(found))
        }
    }
}
/// Builds a [`FunctionEntity`] describing a protocol requirement, without
/// adding it to `self.functions`.
///
/// Name, parameters, return type and the declaration's first source line
/// (as the signature) are filled in. The entity is always marked
/// `is_abstract`. Returns `None` if the declaration has no name.
fn extract_function_signature(&self, node: Node) -> Option<FunctionEntity> {
    let name = self.extract_function_name(node)?;
    let declaration_text = self.node_text(node);
    let mut requirement = FunctionEntity::new(
        &name,
        node.start_position().row + 1,
        node.end_position().row + 1,
    )
    .with_signature(declaration_text.lines().next().unwrap_or(""));
    requirement.parameters = self.extract_parameters(node);
    requirement.return_type = self.extract_return_type(node);
    requirement.is_abstract = true;
    Some(requirement)
}
/// Collects the parameters of a function-like declaration, in source order.
///
/// Direct `parameter` / `function_parameter` children are used as-is. A
/// `parameter_clause` child contributes its `parameter` children. Parameters
/// without a recoverable name are skipped.
fn extract_parameters(&self, node: Node) -> Vec<Parameter> {
    let mut collected = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        match child.kind() {
            "parameter_clause" => {
                let mut clause_cursor = child.walk();
                collected.extend(
                    child
                        .children(&mut clause_cursor)
                        .filter(|entry| entry.kind() == "parameter")
                        .filter_map(|entry| self.extract_single_parameter(entry)),
                );
            }
            "parameter" | "function_parameter" => {
                collected.extend(self.extract_single_parameter(child));
            }
            _ => {}
        }
    }
    collected
}
/// Builds a [`Parameter`] from a single parameter node.
///
/// A Swift parameter may carry an argument label before its local name
/// (`func move(to point: Point)`). The local name (`point`) is what the
/// grammar exposes as the `name` field, and it is used as the parameter name;
/// the `type` field supplies the type. When those fields are absent, the
/// first identifier child and the first non-`:` child of a `type_annotation`
/// are used instead. Returns `None` when no name can be found.
fn extract_single_parameter(&self, node: Node) -> Option<Parameter> {
    let mut name = node
        .child_by_field_name("name")
        .map(|named| self.node_text(named))
        .unwrap_or_default();
    let mut param_type = node
        .child_by_field_name("type")
        .map(|typed| self.node_text(typed))
        .unwrap_or_default();

    if name.is_empty() || param_type.is_empty() {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            match child.kind() {
                "simple_identifier" | "identifier" if name.is_empty() => {
                    name = self.node_text(child);
                }
                "type_annotation" if param_type.is_empty() => {
                    let mut type_cursor = child.walk();
                    if let Some(type_child) = child
                        .children(&mut type_cursor)
                        .find(|type_child| type_child.kind() != ":")
                    {
                        param_type = self.node_text(type_child);
                    }
                }
                _ => {}
            }
        }
    }

    if name.is_empty() {
        return None;
    }
    let mut param = Parameter::new(name);
    if !param_type.is_empty() {
        param = param.with_type(param_type);
    }
    Some(param)
}
/// Returns the declared return type of a function-like declaration.
///
/// The grammar's `return_type` field (the type after `->`) is consulted
/// first. Otherwise a `function_result` / `type_annotation` child is scanned
/// for its first type token. `Void`, `()` and empty text mean "returns
/// nothing" and yield `None`.
fn extract_return_type(&self, node: Node) -> Option<String> {
    let is_unit = |text: &str| text.is_empty() || text == "Void" || text == "()";

    if let Some(declared) = node.child_by_field_name("return_type") {
        let text = self.node_text(declared);
        return if is_unit(&text) { None } else { Some(text) };
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(child.kind(), "function_result" | "type_annotation") {
            let mut type_cursor = child.walk();
            for type_child in child.children(&mut type_cursor) {
                if type_child.kind() != "->" && type_child.kind() != ":" {
                    let text = self.node_text(type_child);
                    if !is_unit(&text) {
                        return Some(text);
                    }
                }
            }
        }
    }
    None
}
/// Records the types a class-like declaration lists after `:`.
///
/// The type names inside `inheritance_specifier` / `type_inheritance_clause`
/// children are gathered in source order. The first is treated as the
/// superclass: it is pushed to `base_classes` and recorded as an
/// [`InheritanceRelation`] with order 0. Every later one is pushed to
/// `implemented_traits` and recorded as an [`ImplementationRelation`].
fn extract_inheritance(
    &mut self,
    node: Node,
    class_name: &str,
    base_classes: &mut Vec<String>,
    implemented_traits: &mut Vec<String>,
) {
    let mut parents = Vec::new();
    let mut cursor = node.walk();
    for clause in node.children(&mut cursor) {
        if !matches!(
            clause.kind(),
            "inheritance_specifier" | "type_inheritance_clause"
        ) {
            continue;
        }
        let mut clause_cursor = clause.walk();
        for entry in clause.children(&mut clause_cursor) {
            if matches!(
                entry.kind(),
                "type_identifier" | "user_type" | "simple_identifier"
            ) {
                parents.push(self.node_text(entry));
            }
        }
    }

    let mut parents = parents.into_iter();
    if let Some(superclass) = parents.next() {
        base_classes.push(superclass.clone());
        self.inheritance.push(InheritanceRelation {
            child: class_name.to_string(),
            parent: superclass,
            order: 0,
        });
    }
    for protocol in parents {
        implemented_traits.push(protocol.clone());
        self.implementations.push(ImplementationRelation {
            implementor: class_name.to_string(),
            trait_name: protocol,
        });
    }
}
/// Returns the name of the type an extension extends.
///
/// This is the text of the first direct `type_identifier`, `user_type` or
/// `simple_identifier` child, or `None` if there is none.
fn extract_extension_type(&self, node: Node) -> Option<String> {
    let mut cursor = node.walk();
    let target = node.children(&mut cursor).find(|child| {
        matches!(
            child.kind(),
            "type_identifier" | "user_type" | "simple_identifier"
        )
    });
    target.map(|found| self.node_text(found))
}
/// Records the protocols an extension adds conformance to.
///
/// Each name is pushed to `implemented_traits` and recorded as an
/// [`ImplementationRelation`] from `extended_type`. Conformances may appear as
/// direct `inheritance_specifier` children (one per protocol), or grouped
/// under a `type_inheritance_clause`, possibly itself holding
/// `inheritance_specifier`s. All three shapes are handled. Previously only the
/// clause form was recognised, so conformances emitted as
/// `inheritance_specifier`s were dropped.
fn extract_extension_protocols(
    &mut self,
    node: Node,
    extended_type: &str,
    implemented_traits: &mut Vec<String>,
) {
    let is_type_name = |kind: &str| kind == "type_identifier" || kind == "user_type";
    let mut protocols = Vec::new();
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if !matches!(
            child.kind(),
            "type_inheritance_clause" | "inheritance_specifier"
        ) {
            continue;
        }
        let mut inherit_cursor = child.walk();
        for entry in child.children(&mut inherit_cursor) {
            if is_type_name(entry.kind()) {
                protocols.push(self.node_text(entry));
            } else if entry.kind() == "inheritance_specifier" {
                let mut spec_cursor = entry.walk();
                for inner in entry.children(&mut spec_cursor) {
                    if is_type_name(inner.kind()) {
                        protocols.push(self.node_text(inner));
                    }
                }
            }
        }
    }
    for protocol_name in protocols {
        implemented_traits.push(protocol_name.clone());
        self.implementations.push(ImplementationRelation {
            implementor: extended_type.to_string(),
            trait_name: protocol_name,
        });
    }
}
/// Returns the source text of each generic parameter declared by `node`.
///
/// These are the `type_parameter` / `simple_identifier` children of any
/// `type_parameters` or `generic_parameter_clause` child, in source order.
fn extract_generic_params(&self, node: Node) -> Vec<String> {
    let clauses: Vec<Node> = {
        let mut cursor = node.walk();
        node.children(&mut cursor)
            .filter(|child| {
                matches!(
                    child.kind(),
                    "type_parameters" | "generic_parameter_clause"
                )
            })
            .collect()
    };
    let mut generics = Vec::new();
    for clause in clauses {
        let mut clause_cursor = clause.walk();
        generics.extend(
            clause
                .children(&mut clause_cursor)
                .filter(|param| matches!(param.kind(), "type_parameter" | "simple_identifier"))
                .map(|param| self.node_text(param)),
        );
    }
    generics
}
/// Returns the declaration's access level as a string, defaulting to
/// `"internal"`.
///
/// A `modifiers` / `modifier` child is scanned for the access keywords
/// `public`, `private`, `fileprivate`, `internal` and `open`, in that order. A
/// direct `visibility_modifier` child's text is returned as-is. Whole
/// whitespace-separated tokens are compared. A substring test could never
/// report `fileprivate` (it contains "private"), and it would mistake a
/// setter-only restriction such as `private(set)` for the declaration's own
/// access level.
fn extract_visibility(&self, node: Node) -> String {
    const LEVELS: [&str; 5] = ["public", "private", "fileprivate", "internal", "open"];
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(child.kind(), "modifiers" | "modifier") {
            let text = self.node_text(child);
            let found = LEVELS
                .iter()
                .copied()
                .find(|level| text.split_whitespace().any(|token| token == *level));
            if let Some(level) = found {
                return level.to_string();
            }
        }
        if child.kind() == "visibility_modifier" {
            return self.node_text(child);
        }
    }
    "internal".to_string()
}
/// Reports whether the declaration carries the keyword `modifier`.
///
/// The keyword matches when it appears as a whole word in a `modifiers` /
/// `modifier` child. The old substring test let `nonmutating` match
/// `mutating`. It also matches when it is an anonymous keyword token sitting
/// directly on the declaration. The grammar is assumed to place some
/// keywords there rather than inside `modifiers` (e.g. `class` in
/// `class func`, or `async` after the parameter list); confirm against the
/// grammar in use.
fn has_modifier(&self, node: Node, modifier: &str) -> bool {
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        if matches!(child.kind(), "modifiers" | "modifier") {
            let text = self.node_text(child);
            if text
                .split(|c: char| !(c.is_alphanumeric() || c == '_'))
                .any(|word| word == modifier)
            {
                return true;
            }
        } else if !child.is_named() && child.kind() == modifier {
            return true;
        }
    }
    false
}
/// Returns the doc comment attached to a declaration, if any.
///
/// The sibling immediately before `node` must be a `comment` or
/// `multiline_comment` starting with `///` or `/**`. A `/** ... */` block is
/// returned as-is. For `///` comments, each source line is expected to be its
/// own `comment` node. The contiguous run of `///` lines (no blank line in
/// between) ending at that sibling is joined with `\n`, so multi-line doc
/// comments are no longer cut down to their last line.
fn extract_doc_comment(&self, node: Node) -> Option<String> {
    let nearest = node.prev_sibling()?;
    if !matches!(nearest.kind(), "comment" | "multiline_comment") {
        return None;
    }
    let text = self.node_text(nearest);
    if text.starts_with("/**") {
        return Some(text);
    }
    if !text.starts_with("///") {
        return None;
    }

    let mut lines = vec![text];
    let mut upper = nearest;
    while let Some(prev) = upper.prev_sibling() {
        // Stop at anything that is not an adjacent `///` line.
        if prev.kind() != "comment"
            || upper.start_position().row > prev.end_position().row + 1
        {
            break;
        }
        let prev_text = self.node_text(prev);
        if !prev_text.starts_with("///") {
            break;
        }
        lines.push(prev_text);
        upper = prev;
    }
    lines.reverse();
    Some(lines.join("\n"))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Parses `source` with the Swift grammar and runs a fresh visitor over the whole tree.
    fn visit_swift(source: &[u8]) -> SwiftVisitor<'_> {
        let language = tree_sitter_swift::language();
        let mut parser = Parser::new();
        parser
            .set_language(&language)
            .expect("Swift grammar should load");
        let tree = parser
            .parse(source, None)
            .expect("parser should produce a tree");
        let mut visitor = SwiftVisitor::new(source, ParserConfig::default());
        visitor.visit_node(tree.root_node());
        visitor
    }

    /// True when `entity` carries the attribute `attribute`.
    fn has_attribute(entity: &ClassEntity, attribute: &str) -> bool {
        entity.attributes.iter().any(|existing| existing == attribute)
    }

    #[test]
    fn test_visitor_basics() {
        let visitor = SwiftVisitor::new(b"", ParserConfig::default());
        assert!(visitor.functions.is_empty());
        assert!(visitor.classes.is_empty());
        assert!(visitor.traits.is_empty());
    }

    #[test]
    fn test_visitor_class_extraction() {
        let visitor = visit_swift(b"class Person { var name: String = \"\" }");
        assert_eq!(visitor.classes.len(), 1);
        assert_eq!(visitor.classes[0].name, "Person");
    }

    #[test]
    fn test_visitor_struct_extraction() {
        let visitor = visit_swift(b"struct Point { var x: Int; var y: Int }");
        assert_eq!(visitor.classes.len(), 1);
        let point = &visitor.classes[0];
        assert_eq!(point.name, "Point");
        assert!(has_attribute(point, "struct"));
    }

    #[test]
    fn test_visitor_protocol_extraction() {
        let visitor = visit_swift(b"protocol Drawable { func draw() }");
        assert_eq!(visitor.traits.len(), 1);
        assert_eq!(visitor.traits[0].name, "Drawable");
    }

    #[test]
    fn test_visitor_function_extraction() {
        let visitor = visit_swift(b"func greet(name: String) -> String { return name }");
        assert!(!visitor.functions.is_empty());
    }

    #[test]
    fn test_visitor_import_extraction() {
        let visitor = visit_swift(b"import Foundation\nimport UIKit");
        assert_eq!(visitor.imports.len(), 2);
    }

    #[test]
    fn test_visitor_inheritance() {
        let visitor = visit_swift(b"class Animal {}\nclass Dog: Animal {}");
        assert_eq!(visitor.classes.len(), 2);
        assert_eq!(visitor.inheritance.len(), 1);
        let relation = &visitor.inheritance[0];
        assert_eq!(relation.child, "Dog");
        assert_eq!(relation.parent, "Animal");
    }

    #[test]
    fn test_visitor_enum() {
        let visitor = visit_swift(b"enum Color { case red, green, blue }");
        assert_eq!(visitor.classes.len(), 1);
        let color = &visitor.classes[0];
        assert_eq!(color.name, "Color");
        assert!(has_attribute(color, "enum"));
    }
}