use crate::parser::language::{
Export, Import, LanguageSupport, ParseResult, Symbol, SymbolKind, Visibility,
};
use tree_sitter::Language as TsLanguage;
pub struct OcamlLanguage;
pub struct OcamlInterfaceLanguage;
fn extract_common(source: &str, tree: &tree_sitter::Tree) -> ParseResult {
let source_bytes = source.as_bytes();
let root = tree.root_node();
let mut symbols: Vec<Symbol> = Vec::new();
let mut imports: Vec<Import> = Vec::new();
let mut exports: Vec<Export> = Vec::new();
let mut stack: Vec<tree_sitter::Node> = root.children(&mut root.walk()).collect();
while let Some(node) = stack.pop() {
match node.kind() {
"value_definition" | "let_binding" => {
let name = extract_name(&node, source_bytes);
if !name.is_empty() {
let signature = first_line(&node, source_bytes);
let body = node_text(&node, source_bytes).to_string();
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
exports.push(Export {
name: name.clone(),
kind: SymbolKind::Function,
});
symbols.push(Symbol {
name,
kind: SymbolKind::Function,
visibility: Visibility::Public,
signature,
body,
start_line,
end_line,
});
}
}
"type_definition" => {
let name = extract_type_name(&node, source_bytes);
if !name.is_empty() {
let signature = first_line(&node, source_bytes);
let body = node_text(&node, source_bytes).to_string();
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
exports.push(Export {
name: name.clone(),
kind: SymbolKind::TypeAlias,
});
symbols.push(Symbol {
name,
kind: SymbolKind::TypeAlias,
visibility: Visibility::Public,
signature,
body,
start_line,
end_line,
});
}
}
"module_definition" => {
let name = extract_module_name(&node, source_bytes);
if !name.is_empty() {
let signature = first_line(&node, source_bytes);
let body = node_text(&node, source_bytes).to_string();
let start_line = node.start_position().row + 1;
let end_line = node.end_position().row + 1;
exports.push(Export {
name: name.clone(),
kind: SymbolKind::Class,
});
symbols.push(Symbol {
name,
kind: SymbolKind::Class,
visibility: Visibility::Public,
signature,
body,
start_line,
end_line,
});
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
stack.push(child);
}
}
"open_module" => {
let text = node_text(&node, source_bytes);
let module_name = text.trim_start_matches("open").trim().to_string();
if !module_name.is_empty() {
imports.push(Import {
source: module_name.clone(),
names: vec![module_name],
});
}
}
"let_expression" | "structure_item" => {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
stack.push(child);
}
}
_ => {}
}
}
ParseResult {
symbols,
imports,
exports,
}
}
fn node_text<'a>(node: &tree_sitter::Node, source: &'a [u8]) -> &'a str {
node.utf8_text(source).unwrap_or("")
}
fn first_line(node: &tree_sitter::Node, source: &[u8]) -> String {
let text = node_text(node, source);
text.lines().next().unwrap_or("").trim().to_string()
}
fn extract_name(node: &tree_sitter::Node, source: &[u8]) -> String {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "value_name" || child.kind() == "identifier" {
return node_text(&child, source).to_string();
}
if child.kind() == "let_binding" {
let mut inner = child.walk();
for inner_child in child.children(&mut inner) {
if inner_child.kind() == "value_name" || inner_child.kind() == "identifier" {
return node_text(&inner_child, source).to_string();
}
}
}
}
String::new()
}
fn extract_type_name(node: &tree_sitter::Node, source: &[u8]) -> String {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "type_binding" {
let mut inner = child.walk();
for inner_child in child.children(&mut inner) {
if inner_child.kind() == "type_constructor" || inner_child.kind() == "identifier" {
return node_text(&inner_child, source).to_string();
}
}
return node_text(&child, source).to_string();
}
if child.kind() == "type_constructor" || child.kind() == "identifier" {
return node_text(&child, source).to_string();
}
}
String::new()
}
fn extract_module_name(node: &tree_sitter::Node, source: &[u8]) -> String {
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if child.kind() == "module_name" {
return node_text(&child, source).to_string();
}
if child.kind() == "module_binding" {
let mut inner = child.walk();
for inner_child in child.children(&mut inner) {
if inner_child.kind() == "module_name" {
return node_text(&inner_child, source).to_string();
}
}
return node_text(&child, source).to_string();
}
}
String::new()
}
impl LanguageSupport for OcamlLanguage {
fn ts_language(&self) -> TsLanguage {
tree_sitter_ocaml::LANGUAGE_OCAML.into()
}
fn name(&self) -> &str {
"ocaml"
}
fn extract(&self, source: &str, tree: &tree_sitter::Tree) -> ParseResult {
extract_common(source, tree)
}
}
impl LanguageSupport for OcamlInterfaceLanguage {
fn ts_language(&self) -> TsLanguage {
tree_sitter_ocaml::LANGUAGE_OCAML_INTERFACE.into()
}
fn name(&self) -> &str {
"ocaml_interface"
}
fn extract(&self, source: &str, tree: &tree_sitter::Tree) -> ParseResult {
extract_common(source, tree)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::language::{SymbolKind, Visibility};
fn make_parser() -> tree_sitter::Parser {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter_ocaml::LANGUAGE_OCAML.into())
.expect("failed to set language");
parser
}
fn make_interface_parser() -> tree_sitter::Parser {
let mut parser = tree_sitter::Parser::new();
parser
.set_language(&tree_sitter_ocaml::LANGUAGE_OCAML_INTERFACE.into())
.expect("failed to set language");
parser
}
#[test]
fn test_extract_function() {
let source = r#"let greet name =
Printf.printf "Hello, %s!\n" name
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(
!result.symbols.is_empty(),
"expected at least one symbol from let binding"
);
}
#[test]
fn test_extract_imports() {
let source = r#"open Printf
open List
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(
!result.imports.is_empty(),
"expected imports from open statements"
);
}
#[test]
fn test_empty_source() {
let source = "";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(result.symbols.is_empty());
assert!(result.imports.is_empty());
assert!(result.exports.is_empty());
}
#[test]
fn test_complex_snippet() {
let source = r#"open Printf
type point = { x : float; y : float }
let distance p1 p2 =
let dx = p1.x -. p2.x in
let dy = p1.y -. p2.y in
sqrt (dx *. dx +. dy *. dy)
let origin = { x = 0.0; y = 0.0 }
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(
!result.symbols.is_empty(),
"expected symbols in complex snippet"
);
assert!(!result.imports.is_empty(), "expected open import");
}
#[test]
fn test_interface_language() {
let source = r#"val greet : string -> unit
"#;
let mut parser = make_interface_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlInterfaceLanguage;
let result = lang.extract(source, &tree);
let _ = result;
}
#[test]
fn test_ocaml_type_definition() {
let source = r#"type color = Red | Green | Blue
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let types: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::TypeAlias)
.collect();
assert!(
!types.is_empty(),
"expected type alias from type definition"
);
}
#[test]
fn test_ocaml_all_public() {
let source = r#"let add a b = a + b
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
for sym in &result.symbols {
assert_eq!(
sym.visibility,
Visibility::Public,
"OCaml symbols should be public"
);
}
}
#[test]
fn test_coverage_type_record() {
let source = "type person = { name : string; age : int }\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let types: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::TypeAlias)
.collect();
assert!(
!types.is_empty(),
"expected type definition, got: {:?}",
result
.symbols
.iter()
.map(|s| (&s.name, &s.kind))
.collect::<Vec<_>>()
);
let exported = result
.exports
.iter()
.any(|e| e.kind == SymbolKind::TypeAlias);
assert!(exported, "type should be exported");
}
#[test]
fn test_coverage_module_definition() {
let source = r#"module MyModule = struct
let helper x = x + 1
end
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let modules: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Class)
.collect();
assert!(
!modules.is_empty(),
"expected module as Class, got: {:?}",
result
.symbols
.iter()
.map(|s| (&s.name, &s.kind))
.collect::<Vec<_>>()
);
let exported = result.exports.iter().any(|e| e.kind == SymbolKind::Class);
assert!(exported, "module should be exported");
}
#[test]
fn test_coverage_multiple_let_bindings() {
let source = r#"let x = 42
let y = "hello"
let add a b = a + b
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(
result.symbols.len() >= 2,
"expected at least 2 let bindings, got: {:?}",
result
.symbols
.iter()
.map(|s| (&s.name, &s.kind))
.collect::<Vec<_>>()
);
assert!(
result.exports.len() >= 2,
"expected exports for let bindings"
);
}
#[test]
fn test_coverage_open_imports() {
let source = r#"open Printf
open List
open String
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(
result.imports.len() >= 2,
"expected at least 2 open imports, got: {:?}",
result.imports
);
}
#[test]
fn test_coverage_interface_language_separately() {
let lang = OcamlInterfaceLanguage;
assert_eq!(lang.name(), "ocaml_interface");
let source = r#"val add : int -> int -> int
type t = int
"#;
let mut parser = make_interface_parser();
let tree = parser.parse(source, None).expect("parse failed");
let result = lang.extract(source, &tree);
let _ = result;
}
#[test]
fn test_coverage_variant_type() {
let source = "type shape = Circle of float | Rectangle of float * float\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let types: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::TypeAlias)
.collect();
assert!(
!types.is_empty(),
"expected variant type definition, got: {:?}",
result
.symbols
.iter()
.map(|s| (&s.name, &s.kind))
.collect::<Vec<_>>()
);
}
#[test]
fn test_coverage_module_binding_drilldown() {
let source = r#"module type S = sig
val x : int
end
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let _ = result;
}
#[test]
fn test_coverage_empty_name_branches() {
let source = "(* just a comment *)\n";
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let _ = result;
}
#[test]
fn test_coverage_nested_let_in_module() {
let source = r#"module Helpers = struct
let id x = x
let double x = x + x
end
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
let modules: Vec<_> = result
.symbols
.iter()
.filter(|s| s.kind == SymbolKind::Class)
.collect();
assert!(
!modules.is_empty(),
"expected module symbol, got: {:?}",
result
.symbols
.iter()
.map(|s| (&s.name, &s.kind))
.collect::<Vec<_>>()
);
}
#[test]
fn test_coverage_combined_module() {
let source = r#"open Printf
type color = Red | Green | Blue
module Utils = struct
let format s = sprintf "<%s>" s
end
let main () =
printf "Hello\n"
"#;
let mut parser = make_parser();
let tree = parser.parse(source, None).expect("parse failed");
let lang = OcamlLanguage;
let result = lang.extract(source, &tree);
assert!(!result.imports.is_empty(), "expected open import");
assert!(
result.symbols.len() >= 2,
"expected multiple symbols (type, module, let), got: {:?}",
result
.symbols
.iter()
.map(|s| (&s.name, &s.kind))
.collect::<Vec<_>>()
);
}
}