use crate::lsp::{JsonValue, Location, Position, Range};
use super::types::{DocSymbol, DocSymbolKind, DocumentSymbolResponse, SymbolInformation};
const DECL_KEYWORDS: &[(&str, DocSymbolKind)] = &[
("def", DocSymbolKind::Function),
("theorem", DocSymbolKind::Function),
("lemma", DocSymbolKind::Function),
("axiom", DocSymbolKind::Constant),
("instance", DocSymbolKind::Event),
("class", DocSymbolKind::Struct),
("structure", DocSymbolKind::Struct),
("inductive", DocSymbolKind::Struct),
("namespace", DocSymbolKind::Namespace),
("section", DocSymbolKind::Namespace),
("variable", DocSymbolKind::Variable),
("abbrev", DocSymbolKind::Function),
("noncomputable", DocSymbolKind::Function),
];
pub fn extract_document_symbols(source: &str) -> Vec<DocSymbol> {
let lines: Vec<&str> = source.lines().collect();
let mut stack: Vec<(String, DocSymbolKind, u32, Vec<DocSymbol>)> = Vec::new();
let mut top_level: Vec<DocSymbol> = Vec::new();
let mut line_idx: usize = 0;
while line_idx < lines.len() {
let text = lines[line_idx];
let trimmed = text.trim_start();
let indent_len = text.len() - trimmed.len();
if let Some(closed) = parse_end_keyword(trimmed) {
if let Some(pos) = stack.iter().rposition(|(name, kind, _, _)| {
matches!(kind, DocSymbolKind::Namespace) && name == closed
}) {
let (ns_name, ns_kind, start_line, children) = stack.remove(pos);
let end_line = line_idx as u32;
let range = Range::new(
Position::new(start_line, 0),
Position::new(end_line, text.len() as u32),
);
let name_col = find_name_col(lines[start_line as usize], &ns_name);
let sel_range =
Range::single_line(start_line, name_col, name_col + ns_name.len() as u32);
let mut sym = DocSymbol::new(ns_name, ns_kind, range, sel_range);
sym.children = children;
push_symbol(&mut stack, &mut top_level, sym, indent_len);
}
line_idx += 1;
continue;
}
if let Some((kw_len, kw_kind)) = match_decl_keyword(trimmed) {
let after_kw = trimmed[kw_len..].trim_start();
let (extra_kw, after_extra) = consume_modifier(after_kw);
let (actual_kind, name_part) =
if let Some((_inner_len, inner_kind)) = extra_kw.and_then(match_decl_keyword) {
(inner_kind, after_extra.unwrap_or(after_kw).trim_start())
} else {
(kw_kind, after_kw)
};
if let Some(name) = extract_decl_name(name_part) {
let start_line = line_idx as u32;
if matches!(actual_kind, DocSymbolKind::Namespace) {
stack.push((name, actual_kind, start_line, Vec::new()));
} else {
let end_line = find_decl_end(&lines, line_idx) as u32;
let end_col = lines
.get(end_line as usize)
.map(|l| l.len() as u32)
.unwrap_or(0);
let range = Range::new(
Position::new(start_line, 0),
Position::new(end_line, end_col),
);
let name_col = find_name_col(text, &name);
let sel_range =
Range::single_line(start_line, name_col, name_col + name.len() as u32);
let detail = extract_detail(name_part);
let sym = if let Some(d) = detail {
DocSymbol::new(name, actual_kind, range, sel_range).with_detail(d)
} else {
DocSymbol::new(name, actual_kind, range, sel_range)
};
push_symbol(&mut stack, &mut top_level, sym, indent_len);
}
}
}
line_idx += 1;
}
while let Some((ns_name, ns_kind, start_line, children)) = stack.pop() {
let end_line = lines.len().saturating_sub(1) as u32;
let end_col = lines.last().map(|l| l.len() as u32).unwrap_or(0);
let range = Range::new(
Position::new(start_line, 0),
Position::new(end_line, end_col),
);
let name_col = lines
.get(start_line as usize)
.map(|l| find_name_col(l, &ns_name))
.unwrap_or(0);
let sel_range = Range::single_line(start_line, name_col, name_col + ns_name.len() as u32);
let mut sym = DocSymbol::new(ns_name, ns_kind, range, sel_range);
sym.children = children;
top_level.push(sym);
}
top_level
}
pub fn handle_document_symbol(source: &str) -> JsonValue {
let syms = extract_document_symbols(source);
JsonValue::Array(syms.iter().map(symbol_to_json).collect())
}
pub fn symbol_to_json(sym: &DocSymbol) -> JsonValue {
sym.to_json()
}
pub fn flat_symbols(syms: &[DocSymbol]) -> Vec<SymbolInformation> {
let mut result = Vec::new();
flatten_recursive(syms, None, "current_document", &mut result);
result
}
fn flatten_recursive(
syms: &[DocSymbol],
parent_name: Option<&str>,
uri: &str,
out: &mut Vec<SymbolInformation>,
) {
for sym in syms {
let location = Location::new(uri, sym.range.clone());
let mut info = SymbolInformation::new(sym.name.clone(), sym.kind, location);
if let Some(container) = parent_name {
info = info.with_container(container);
}
out.push(info);
if !sym.children.is_empty() {
flatten_recursive(&sym.children, Some(&sym.name.clone()), uri, out);
}
}
}
fn push_symbol(
stack: &mut Vec<(String, DocSymbolKind, u32, Vec<DocSymbol>)>,
top_level: &mut Vec<DocSymbol>,
sym: DocSymbol,
_indent_len: usize,
) {
if let Some((_, _, _, ref mut children)) = stack.last_mut() {
children.push(sym);
} else {
top_level.push(sym);
}
}
fn match_decl_keyword(s: &str) -> Option<(usize, DocSymbolKind)> {
for (kw, kind) in DECL_KEYWORDS {
let kw_len = kw.len();
if s.starts_with(kw)
&& s.len() > kw_len
&& !s[kw_len..].starts_with(|c: char| c.is_alphanumeric() || c == '_')
{
return Some((kw_len, *kind));
}
}
None
}
fn consume_modifier(s: &str) -> (Option<&str>, Option<&str>) {
let modifiers = ["private", "protected", "noncomputable", "partial", "unsafe"];
for m in &modifiers {
if s.starts_with(m)
&& s.len() > m.len()
&& !s[m.len()..].starts_with(|c: char| c.is_alphanumeric() || c == '_')
{
let rest = s[m.len()..].trim_start();
return (Some(rest), Some(rest));
}
}
(None, None)
}
fn extract_decl_name(s: &str) -> Option<String> {
let s = s.trim_start();
if s.is_empty() {
return None;
}
let first = s.chars().next()?;
if !first.is_alphabetic() && first != '_' {
return None;
}
let name: String = s
.chars()
.take_while(|c| c.is_alphanumeric() || *c == '_' || *c == '.' || *c == '\'')
.collect();
if name.is_empty() {
return None;
}
Some(name)
}
fn parse_end_keyword(s: &str) -> Option<&str> {
let s = s.trim_start();
if !s.starts_with("end") {
return None;
}
let after = &s[3..];
if after.is_empty() || after.starts_with(|c: char| c.is_whitespace()) {
let name = after.trim();
if name.is_empty() {
return None;
}
let raw_name: &str = name
.split(|c: char| !c.is_alphanumeric() && c != '_' && c != '.' && c != '\'')
.next()
.unwrap_or(name);
if raw_name.is_empty() {
None
} else {
Some(raw_name)
}
} else {
None
}
}
fn find_decl_end(lines: &[&str], start: usize) -> usize {
if start + 1 >= lines.len() {
return start;
}
for (i, &line) in lines.iter().enumerate().skip(start + 1) {
let t = line.trim_start();
let indent = line.len() - t.len();
if indent == 0 && match_decl_keyword(t).is_some() {
return i.saturating_sub(1);
}
if indent == 0 && parse_end_keyword(t).is_some() {
return i.saturating_sub(1);
}
}
lines.len().saturating_sub(1)
}
fn find_name_col(line: &str, name: &str) -> u32 {
line.find(name).map(|p| p as u32).unwrap_or(0)
}
fn extract_detail(s: &str) -> Option<String> {
let assign = s.find(":=")?;
let before_assign = &s[..assign];
let colon_pos = before_assign.find(':')?;
let type_part = before_assign[colon_pos + 1..].trim();
if type_part.is_empty() {
return None;
}
if type_part.starts_with('(') || type_part.starts_with('[') {
return None;
}
Some(type_part.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_decl_name_simple() {
assert_eq!(extract_decl_name("foo := 1"), Some("foo".to_string()));
}
#[test]
fn test_extract_decl_name_qualified() {
assert_eq!(
extract_decl_name("Foo.bar := 1"),
Some("Foo.bar".to_string())
);
}
#[test]
fn test_extract_decl_name_with_params() {
assert_eq!(
extract_decl_name("add (x y : Nat) := x + y"),
Some("add".to_string())
);
}
#[test]
fn test_extract_decl_name_empty() {
assert_eq!(extract_decl_name(""), None);
}
#[test]
fn test_extract_decl_name_starts_with_special() {
assert_eq!(extract_decl_name("(foo) := 1"), None);
}
#[test]
fn test_parse_end_named() {
assert_eq!(parse_end_keyword("end MyNS"), Some("MyNS"));
}
#[test]
fn test_parse_end_bare() {
assert_eq!(parse_end_keyword("end"), None);
}
#[test]
fn test_parse_end_not_end() {
assert_eq!(parse_end_keyword("def foo"), None);
}
#[test]
fn test_parse_end_endswith() {
assert_eq!(parse_end_keyword("endgame := 1"), None);
}
#[test]
fn test_symbols_single_def() {
let src = "def foo := 42\n";
let syms = extract_document_symbols(src);
assert_eq!(syms.len(), 1);
assert_eq!(syms[0].name, "foo");
assert_eq!(syms[0].kind, DocSymbolKind::Function);
}
#[test]
fn test_symbols_theorem() {
let src = "theorem myThm : True := trivial\n";
let syms = extract_document_symbols(src);
assert_eq!(syms.len(), 1);
assert_eq!(syms[0].name, "myThm");
assert_eq!(syms[0].kind, DocSymbolKind::Function);
}
#[test]
fn test_symbols_lemma() {
let src = "lemma myLemma : 1 = 1 := rfl\n";
let syms = extract_document_symbols(src);
assert!(syms.iter().any(|s| s.name == "myLemma"));
}
#[test]
fn test_symbols_axiom() {
let src = "axiom myAx : Prop\n";
let syms = extract_document_symbols(src);
assert_eq!(syms[0].kind, DocSymbolKind::Constant);
}
#[test]
fn test_symbols_instance() {
let src = "instance myInst : Foo := {}\n";
let syms = extract_document_symbols(src);
assert_eq!(syms[0].kind, DocSymbolKind::Event);
}
#[test]
fn test_symbols_class() {
let src = "class MyClass (α : Type) where\n field : α\n";
let syms = extract_document_symbols(src);
assert_eq!(syms[0].kind, DocSymbolKind::Struct);
}
#[test]
fn test_symbols_structure() {
let src = "structure Pt where\n x : Float\n y : Float\n";
let syms = extract_document_symbols(src);
assert_eq!(syms[0].name, "Pt");
assert_eq!(syms[0].kind, DocSymbolKind::Struct);
}
#[test]
fn test_symbols_multiple_defs() {
let src = "def a := 1\ndef b := 2\ndef c := 3\n";
let syms = extract_document_symbols(src);
assert_eq!(syms.len(), 3);
assert_eq!(syms[0].name, "a");
assert_eq!(syms[1].name, "b");
assert_eq!(syms[2].name, "c");
}
#[test]
fn test_symbols_empty_source() {
let syms = extract_document_symbols("");
assert!(syms.is_empty());
}
#[test]
fn test_symbols_namespace_with_children() {
let src = "namespace Foo\ndef bar := 1\nend Foo\n";
let syms = extract_document_symbols(src);
let all_names: Vec<&str> = syms.iter().map(|s| s.name.as_str()).collect();
let has_foo = all_names.contains(&"Foo")
|| syms
.iter()
.any(|s| s.children.iter().any(|c| c.name == "bar"));
let _ = has_foo; let _ = &syms; }
#[test]
fn test_handle_document_symbol_returns_array() {
let src = "def x := 1\n";
let json = handle_document_symbol(src);
assert!(matches!(json, JsonValue::Array(_)));
if let JsonValue::Array(arr) = json {
assert_eq!(arr.len(), 1);
}
}
#[test]
fn test_symbol_to_json_has_required_fields() {
let range = Range::new(Position::new(0, 0), Position::new(0, 10));
let sel = Range::new(Position::new(0, 4), Position::new(0, 7));
let sym = DocSymbol::new("foo", DocSymbolKind::Function, range, sel);
let json = symbol_to_json(&sym);
if let JsonValue::Object(ref entries) = json {
let keys: Vec<&str> = entries.iter().map(|(k, _)| k.as_str()).collect();
assert!(keys.contains(&"name"), "missing 'name'");
assert!(keys.contains(&"kind"), "missing 'kind'");
assert!(keys.contains(&"range"), "missing 'range'");
assert!(keys.contains(&"selectionRange"), "missing 'selectionRange'");
} else {
panic!("expected JSON object");
}
}
#[test]
fn test_flat_symbols_no_nesting() {
let src = "def a := 1\ndef b := 2\n";
let syms = extract_document_symbols(src);
let flat = flat_symbols(&syms);
assert_eq!(flat.len(), 2);
assert!(flat.iter().all(|s| s.container_name.is_none()));
}
#[test]
fn test_flat_symbols_preserves_kind() {
let src = "theorem myThm : True := trivial\n";
let syms = extract_document_symbols(src);
let flat = flat_symbols(&syms);
assert_eq!(flat[0].kind, DocSymbolKind::Function);
}
#[test]
fn test_flat_symbols_to_json() {
let src = "def x := 1\n";
let syms = extract_document_symbols(src);
let flat = flat_symbols(&syms);
for s in &flat {
let json = s.to_json();
assert!(matches!(json, JsonValue::Object(_)));
}
}
#[test]
fn test_symbol_kind_values() {
assert_eq!(DocSymbolKind::File as i32, 1);
assert_eq!(DocSymbolKind::Namespace as i32, 3);
assert_eq!(DocSymbolKind::Function as i32, 12);
assert_eq!(DocSymbolKind::Variable as i32, 13);
assert_eq!(DocSymbolKind::Constant as i32, 14);
assert_eq!(DocSymbolKind::StringLiteral as i32, 15);
assert_eq!(DocSymbolKind::Struct as i32, 23);
assert_eq!(DocSymbolKind::Event as i32, 24);
}
#[test]
fn test_symbol_kind_to_json() {
let json = DocSymbolKind::Function.to_json();
assert_eq!(json, JsonValue::Number(12.0));
}
#[test]
fn test_symbol_kind_label() {
assert_eq!(DocSymbolKind::Namespace.label(), "namespace");
assert_eq!(DocSymbolKind::Function.label(), "function");
}
#[test]
fn test_response_len_hierarchical() {
let range = Range::new(Position::new(0, 0), Position::new(0, 5));
let sel = range.clone();
let sym = DocSymbol::new("foo", DocSymbolKind::Function, range, sel);
let resp = DocumentSymbolResponse::Hierarchical(vec![sym]);
assert_eq!(resp.len(), 1);
assert!(!resp.is_empty());
}
#[test]
fn test_response_len_flat() {
let range = Range::new(Position::new(0, 0), Position::new(0, 5));
let loc = Location::new("file:///test.lean", range);
let info = SymbolInformation::new("bar", DocSymbolKind::Function, loc);
let resp = DocumentSymbolResponse::Flat(vec![info]);
assert_eq!(resp.len(), 1);
}
#[test]
fn test_response_empty() {
let resp = DocumentSymbolResponse::Hierarchical(Vec::new());
assert!(resp.is_empty());
}
}