use codegraph::extraction::ScalaExtractor;
use codegraph::types::{EdgeKind, NodeKind};
/// Runs the Scala extractor over `source`, attributing it to the fixed file name `test.scala`.
fn extract(source: &str) -> codegraph::types::ExtractionResult {
    let file_name = "test.scala";
    ScalaExtractor::extract_scala(file_name, source)
}
/// Exactly one `File` node is produced, and it is named after the input path.
#[test]
fn test_scala_file_node_is_root() {
    let extracted = extract("object Main");
    let files: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::File)
        .collect();
    assert_eq!(files.len(), 1);
    assert_eq!(files[0].name, "test.scala");
}
/// A `package` clause yields a single `ScalaPackage` node holding the dotted package path.
#[test]
fn test_scala_extract_package() {
    let src = "package com.example.app\n\nobject Main";
    let extracted = extract(src);
    let packages: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::ScalaPackage)
        .collect();
    assert_eq!(packages.len(), 1);
    assert_eq!(packages[0].name, "com.example.app");
}
/// Each `import` line becomes a `Use` node; the `ListBuffer` import keeps its imported name.
#[test]
fn test_scala_extract_import() {
    let src = "import scala.collection.mutable.ListBuffer\nimport java.io._";
    let extracted = extract(src);
    let uses: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::Use)
        .collect();
    assert_eq!(uses.len(), 2);
    assert!(uses.iter().any(|node| node.name.contains("ListBuffer")));
}
/// A plain `class` declaration produces one `Class` node with the declared name.
#[test]
fn test_scala_extract_class() {
    let src = "class MyClass(val x: Int) {\n def hello(): String = \"hi\"\n}";
    let extracted = extract(src);
    let class_nodes: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::Class)
        .collect();
    assert_eq!(class_nodes.len(), 1);
    assert_eq!(class_nodes[0].name, "MyClass");
}
/// A `case class` is reported as a distinct `CaseClass` node rather than a plain `Class`.
#[test]
fn test_scala_extract_case_class() {
    let extracted = extract("case class Person(name: String, age: Int)");
    let found: Vec<_> = extracted.nodes.iter().filter(|node| node.kind == NodeKind::CaseClass).collect();
    assert_eq!(found.len(), 1);
    assert_eq!(found[0].name, "Person");
}
/// A `trait` declaration produces one `Trait` node with the declared name.
#[test]
fn test_scala_extract_trait() {
    let src = "trait Greeter {\n def greet(name: String): String\n}";
    let extracted = extract(src);
    let trait_nodes: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::Trait)
        .collect();
    assert_eq!(trait_nodes.len(), 1);
    assert_eq!(trait_nodes[0].name, "Greeter");
}
/// A body-less `def` inside a trait is classified as an `AbstractMethod`.
#[test]
fn test_scala_extract_abstract_method_in_trait() {
    let src = "trait Greeter {\n def greet(name: String): String\n}";
    let extracted = extract(src);
    let abstract_defs: Vec<_> =
        extracted.nodes.iter().filter(|node| node.kind == NodeKind::AbstractMethod).collect();
    assert_eq!(abstract_defs.len(), 1);
    assert_eq!(abstract_defs[0].name, "greet");
}
/// An `object` declaration produces one `ScalaObject` node with the declared name.
#[test]
fn test_scala_extract_object() {
    let src = "object Main {\n def main(args: Array[String]): Unit = println(\"hi\")\n}";
    let extracted = extract(src);
    let object_nodes: Vec<_> =
        extracted.nodes.iter().filter(|node| node.kind == NodeKind::ScalaObject).collect();
    assert_eq!(object_nodes.len(), 1);
    assert_eq!(object_nodes[0].name, "Main");
}
/// A `def` with a body inside an object is reported as a `Method`.
#[test]
fn test_scala_extract_method() {
    let src = "object Main {\n def hello(name: String): String = s\"Hello $name\"\n}";
    let extracted = extract(src);
    let method_nodes: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::Method)
        .collect();
    assert_eq!(method_nodes.len(), 1);
    assert_eq!(method_nodes[0].name, "hello");
}
/// A top-level `def` (outside any class/object/trait) is reported as a `Function`.
#[test]
fn test_scala_extract_function() {
    let extracted = extract("def topLevel(x: Int): Int = x + 1");
    let functions: Vec<_> =
        extracted.nodes.iter().filter(|node| node.kind == NodeKind::Function).collect();
    assert_eq!(functions.len(), 1);
    assert_eq!(functions[0].name, "topLevel");
}
/// A `val` member of an object is reported as a `ValField`.
#[test]
fn test_scala_extract_val() {
    let src = "object Config {\n val name: String = \"app\"\n}";
    let extracted = extract(src);
    let val_fields: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::ValField)
        .collect();
    assert_eq!(val_fields.len(), 1);
    assert_eq!(val_fields[0].name, "name");
}
/// A `var` member of an object is reported as a `VarField`.
#[test]
fn test_scala_extract_var() {
    let src = "object State {\n var count: Int = 0\n}";
    let extracted = extract(src);
    let var_fields: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::VarField)
        .collect();
    assert_eq!(var_fields.len(), 1);
    assert_eq!(var_fields[0].name, "count");
}
/// A `type X = ...` member is reported as a `TypeAlias` named after the alias.
#[test]
fn test_scala_extract_type_alias() {
    let src = "object Types {\n type StringMap = Map[String, String]\n}";
    let extracted = extract(src);
    let aliases: Vec<_> =
        extracted.nodes.iter().filter(|node| node.kind == NodeKind::TypeAlias).collect();
    assert_eq!(aliases.len(), 1);
    assert_eq!(aliases[0].name, "StringMap");
}
/// Constructor parameters declared with `val` surface as `ValField` nodes on the class.
#[test]
fn test_scala_extract_class_params_as_fields() {
    let extracted = extract("class Point(val x: Int, val y: Int, z: Int)");
    let val_fields: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::ValField)
        .collect();
    assert!(val_fields.len() >= 2);
    for expected in ["x", "y"] {
        assert!(val_fields.iter().any(|node| node.name == expected));
    }
}
/// Nesting produces at least two `Contains` edges for an object holding a method.
#[test]
fn test_scala_contains_edges() {
    let src = "object Main {\n def hello(): Unit = ()\n}";
    let extracted = extract(src);
    let contains_count = extracted
        .edges
        .iter()
        .filter(|edge| edge.kind == EdgeKind::Contains)
        .count();
    assert!(contains_count >= 2);
}
/// Call expressions inside a method body are recorded as unresolved `Calls` references.
#[test]
fn test_scala_extract_call_sites() {
    let src = "object Main {\n def run(): Unit = {\n println(\"hello\")\n foo()\n }\n}";
    let extracted = extract(src);
    let call_refs: Vec<_> = extracted
        .unresolved_refs
        .iter()
        .filter(|reference| reference.reference_kind == EdgeKind::Calls)
        .collect();
    assert!(call_refs.len() >= 2);
    for callee in ["println", "foo"] {
        assert!(call_refs.iter().any(|reference| reference.reference_name == callee));
    }
}
/// A `private def` is reported with `Visibility::Private`.
#[test]
fn test_scala_visibility_private() {
    let src = "class Foo {\n private def secret(): Unit = ()\n}";
    let extracted = extract(src);
    let method_nodes: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::Method)
        .collect();
    assert_eq!(method_nodes.len(), 1);
    let expected = codegraph::types::Visibility::Private;
    assert_eq!(method_nodes[0].visibility, expected);
}
/// A `def` with no access modifier defaults to `Visibility::Pub`.
#[test]
fn test_scala_visibility_default_is_public() {
    let src = "class Foo {\n def open(): Unit = ()\n}";
    let extracted = extract(src);
    let method_nodes: Vec<_> = extracted
        .nodes
        .iter()
        .filter(|node| node.kind == NodeKind::Method)
        .collect();
    assert_eq!(method_nodes.len(), 1);
    let expected = codegraph::types::Visibility::Pub;
    assert_eq!(method_nodes[0].visibility, expected);
}
/// A method's qualified name mentions both its enclosing object and its own name.
///
/// Uses `expect` and assertion messages so that a failure reports which node was
/// missing or which qualified name was actually produced.
#[test]
fn test_scala_qualified_names() {
    let result = extract("object Main {\n def hello(): Unit = ()\n}");
    let method = result
        .nodes
        .iter()
        .find(|n| n.kind == NodeKind::Method)
        .expect("expected a Method node for `hello`");
    assert!(
        method.qualified_name.contains("Main"),
        "qualified name {:?} should mention the enclosing object `Main`",
        method.qualified_name
    );
    assert!(
        method.qualified_name.contains("hello"),
        "qualified name {:?} should mention the method name `hello`",
        method.qualified_name
    );
}
/// A `/** ... */` Scaladoc comment preceding an object is attached as that object's docstring.
///
/// The two lookups use separate `expect` messages, so a failure distinguishes a missing
/// object node from a missing docstring.
#[test]
fn test_scala_scaladoc() {
    let result = extract(
        "/** A greeting object. */\nobject Greeter {\n /** Says hi. */\n def hi(): String = \"hi\"\n}",
    );
    let obj = result
        .nodes
        .iter()
        .find(|n| n.kind == NodeKind::ScalaObject)
        .expect("expected a ScalaObject node for `Greeter`");
    let doc = obj
        .docstring
        .as_ref()
        .expect("`Greeter` should carry its preceding Scaladoc comment");
    assert!(
        doc.contains("greeting"),
        "docstring {:?} should contain the Scaladoc text",
        doc
    );
}