use std::collections::HashMap;
use cha_core::{TypeOrigin, TypeRef};
/// Lookup table from a bare type name to the [`TypeOrigin`] recorded for it.
///
/// [`resolve`] queries this map with the name produced by [`strip_decor`],
/// so keys are expected to be undecorated names (e.g. `Node`, not
/// `&tree_sitter::Node`).
pub type ImportsMap = HashMap<String, TypeOrigin>;
/// Builds a [`TypeRef`] for a raw type expression.
///
/// The raw text is reduced to a bare name with [`strip_decor`]; that name is
/// looked up in `imports`. When the map has no entry, the origin is decided
/// by `fallback_origin` from the bare name and the raw text.
///
/// The returned value keeps both the stripped `name` and the original `raw`.
pub fn resolve(raw: impl Into<String>, imports: &ImportsMap) -> TypeRef {
    let raw: String = raw.into();
    let name = strip_decor(&raw);
    let origin = match imports.get(name.as_str()) {
        Some(known) => known.clone(),
        None => fallback_origin(&name, &raw),
    };
    TypeRef { name, raw, origin }
}
/// Picks an origin for a type that has no entry in the imports map.
///
/// Returns [`TypeOrigin::Primitive`] when the stripped `name` is in the
/// universal-primitive table or when `raw` is a subscripted builtin
/// container expression; otherwise returns [`TypeOrigin::Unknown`].
fn fallback_origin(name: &str, raw: &str) -> TypeOrigin {
    if !is_universal_primitive(name) && !is_container_expression(raw) {
        return TypeOrigin::Unknown;
    }
    TypeOrigin::Primitive
}
/// Reports whether `raw` starts with a subscripted lowercase builtin
/// container: `dict[`, `list[`, `set[`, `tuple[` or `frozenset[`.
///
/// Leading `&`, `*` and whitespace are skipped first, followed by any
/// leading `mut ` markers, before the prefix is checked.
fn is_container_expression(raw: &str) -> bool {
    const CONTAINER_PREFIXES: [&str; 5] = ["dict[", "list[", "set[", "tuple[", "frozenset["];

    let is_sigil = |c: char| matches!(c, '&' | '*') || c.is_whitespace();
    let body = raw
        .trim_start_matches(is_sigil)
        .trim_start_matches("mut ")
        .trim();
    CONTAINER_PREFIXES
        .iter()
        .any(|&prefix| body.starts_with(prefix))
}
/// Returns `true` when `name` exactly matches an entry of
/// [`UNIVERSAL_PRIMITIVES`] (case-sensitive).
fn is_universal_primitive(name: &str) -> bool {
    UNIVERSAL_PRIMITIVES.iter().any(|&known| known == name)
}
/// Type names that are treated as primitives when the imports map has no
/// entry for them. Matching is exact and case-sensitive.
#[rustfmt::skip]
const UNIVERSAL_PRIMITIVES: &[&str] = &[
    // Names matching Rust scalar and string-slice types.
    "i8", "i16", "i32", "i64", "i128", "isize",
    "u8", "u16", "u32", "u64", "u128", "usize",
    "f32", "f64", "bool", "char", "str",
    // Names matching common Rust standard-library types.
    "String", "Vec", "Option", "Result", "Box", "Arc", "Rc", "Cell", "RefCell",
    "HashMap", "HashSet", "BTreeMap", "BTreeSet",
    "Path", "PathBuf", "OsStr", "OsString",
    // Names matching Python builtins and typing helpers.
    "int", "float", "bytes", "bytearray", "list", "dict", "set", "tuple", "None", "Any",
    // Names matching TypeScript/JavaScript primitives and global objects.
    "number", "string", "boolean", "null", "undefined", "void", "never", "unknown", "any",
    "Array", "Promise", "Map", "Set", "Date", "Error", "RegExp", "Function", "Object",
];
/// Reduces a raw type expression to the bare name of its innermost type.
///
/// Each round trims the input, drops leading references/lifetimes, and then:
/// 1. if the text is bracketed like `[T]`, continues with the inner text;
/// 2. otherwise strips basic decoration (`mut`, `dyn`, `impl`, `*`, `?`, `[]`);
/// 3. if the result is a known wrapper (`Vec<T>`, `Option<T>`, `List[T]`, ...),
///    continues with the wrapped text;
/// 4. otherwise returns the last `::` / `.` path segment.
///
/// For example `&mut Vec<Finding>` yields `Finding` and
/// `tree_sitter::Node` yields `Node`.
pub fn strip_decor(raw: &str) -> String {
    let mut pending = raw;
    loop {
        let unreffed = strip_rust_refs_and_slice(pending.trim());
        if let Some(element) = rust_slice_inner(unreffed) {
            pending = element;
            continue;
        }
        let bare = strip_basic_decor(unreffed);
        match peel_any_container(bare) {
            Some(inner) => pending = inner,
            None => return take_last_segment(bare),
        }
    }
}
/// Drops leading `&` markers and, if present, a leading lifetime.
///
/// After removing every leading `&` and trimming whitespace, text that
/// begins with `'` is cut at its first whitespace and the trimmed remainder
/// is returned (`&'a String` -> `String`). If the lifetime-like token is not
/// followed by whitespace, the text is returned as is.
fn strip_rust_refs_and_slice(s: &str) -> &str {
    let bare = s.trim_start_matches('&').trim();
    if bare.starts_with('\'') {
        if let Some((_lifetime, rest)) = bare.split_once(char::is_whitespace) {
            return rest.trim();
        }
    }
    bare
}
/// Returns the element type text of a bracketed Rust slice or array.
///
/// * `[T]` yields `T` (returned untrimmed, exactly the text between brackets).
/// * `[T; N]` yields `T`: the text is cut at the first `;` that is not nested
///   inside `[]`, `<>` or `()`, so the array length never leaks into the
///   resolved name. Nested forms such as `[[u8; 4]; 2]` yield `[u8; 4]`.
///
/// Returns `None` when `s` is not wrapped in a leading `[` and trailing `]`
/// (e.g. `Foo[]` or `[]Foo`).
fn rust_slice_inner(s: &str) -> Option<&str> {
    let inner = s.strip_prefix('[')?.strip_suffix(']')?;

    let mut depth = 0usize;
    for (idx, ch) in inner.char_indices() {
        match ch {
            '[' | '<' | '(' => depth += 1,
            // Saturating so a stray closer (e.g. the `>` of `->`) cannot underflow.
            ']' | '>' | ')' => depth = depth.saturating_sub(1),
            ';' if depth == 0 => return Some(inner[..idx].trim_end()),
            _ => {}
        }
    }
    Some(inner)
}
/// Strips simple prefix/suffix decoration from a type expression.
///
/// Leading decoration is peeled repeatedly until nothing more matches, so
/// the order in which markers appear does not matter:
/// * pointer stars `*`,
/// * the keywords `mut `, `const `, `dyn `, `impl `,
/// * leading `[]` markers.
///
/// This makes Rust raw pointers (`*mut T`, `*const T`) and C-style
/// `const char *` reduce to the pointee name instead of leaving a keyword
/// glued to the front.
///
/// Trailing decoration is then removed in a fixed order: `*`, then `?`,
/// then `[]`, and the result is whitespace-trimmed.
fn strip_basic_decor(s: &str) -> &str {
    const LEADING_KEYWORDS: [&str; 4] = ["mut ", "const ", "dyn ", "impl "];

    let mut rest = s.trim_start();
    loop {
        let mut peeled = rest.trim_start_matches('*').trim_start();
        for keyword in LEADING_KEYWORDS {
            peeled = peeled.trim_start_matches(keyword).trim_start();
        }
        peeled = peeled.trim_start_matches("[]").trim_start();
        // Every step only removes a prefix, so an unchanged length means a fixpoint.
        if peeled.len() == rest.len() {
            break;
        }
        rest = peeled;
    }

    rest.trim_end_matches('*')
        .trim_end_matches('?')
        .trim_end_matches("[]")
        .trim()
}
/// Removes one layer of a recognised single-wrapper generic.
///
/// Angle-bracket wrappers (`Vec<T>`, `Option<T>`, `Box<T>`, `Arc<T>`,
/// `Rc<T>`, `Cell<T>`, `RefCell<T>`) are tried first, then square-bracket
/// wrappers (`List[T]`, `Optional[T]`, `Set[T]`, `Dict[T]`, `Iterable[T]`,
/// `Tuple[T]`, `Sequence[T]` and lowercase `list`/`dict`/`set`/`tuple`).
///
/// Returns the text between the brackets, or `None` when no wrapper matches.
fn peel_any_container(s: &str) -> Option<&str> {
    const ANGLE_WRAPPERS: &[&str] = &["Vec", "Option", "Box", "Arc", "Rc", "Cell", "RefCell"];
    const SQUARE_WRAPPERS: &[&str] = &[
        "List", "Optional", "Set", "Dict", "Iterable", "Tuple", "Sequence", "list", "dict", "set",
        "tuple",
    ];

    ANGLE_WRAPPERS
        .iter()
        .find_map(|&wrapper| peel_rust_generic(s, wrapper))
        .or_else(|| {
            SQUARE_WRAPPERS
                .iter()
                .find_map(|&wrapper| peel_py_generic(s, wrapper))
        })
}
/// Returns the final segment of a `::`- or `.`-separated path.
///
/// `::` takes precedence over `.`. Only separators at bracket depth 0 are
/// considered, so separators inside generic arguments do not split the
/// outer type: `std::collections::HashMap<String, foo::Bar>` yields
/// `HashMap<String, foo::Bar>` rather than the fragment `Bar>`.
///
/// When no usable separator exists (none present, or nothing follows the
/// last one) the input is returned unchanged.
fn take_last_segment(s: &str) -> String {
    // Text following the last depth-0 occurrence of `sep`, if it is non-empty.
    fn tail_after<'a>(s: &'a str, sep: &str) -> Option<&'a str> {
        let mut depth = 0usize;
        let mut tail = None;
        for (idx, ch) in s.char_indices() {
            match ch {
                '<' | '[' | '(' => depth += 1,
                // Saturating so a stray closer (e.g. the `>` of `->`) cannot underflow.
                '>' | ']' | ')' => depth = depth.saturating_sub(1),
                _ if depth == 0 && s[idx..].starts_with(sep) => {
                    tail = Some(&s[idx + sep.len()..]);
                }
                _ => {}
            }
        }
        tail.filter(|rest| !rest.is_empty())
    }

    tail_after(s, "::")
        .or_else(|| tail_after(s, "."))
        .unwrap_or(s)
        .to_string()
}
/// Returns the text between `wrapper<` and the final `>` of `s`.
///
/// Yields `None` unless `s` starts with exactly `wrapper` followed by `<`
/// and ends with `>`. Only the outermost layer is removed, so
/// `Vec<Box<T>>` with wrapper `Vec` yields `Box<T>`.
fn peel_rust_generic<'a>(s: &'a str, wrapper: &str) -> Option<&'a str> {
    s.strip_prefix(wrapper)?
        .strip_prefix('<')?
        .strip_suffix('>')
}
/// Returns the text between `wrapper[` and the final `]` of `s`.
///
/// Yields `None` unless `s` starts with exactly `wrapper` followed by `[`
/// and ends with `]`. Only the outermost layer is removed, so
/// `list[list[int]]` with wrapper `list` yields `list[int]`.
fn peel_py_generic<'a>(s: &'a str, wrapper: &str) -> Option<&'a str> {
    s.strip_prefix(wrapper)?
        .strip_prefix('[')?
        .strip_suffix(']')
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `strip_decor` maps every `(raw, expected)` pair as listed.
    fn check_strip(cases: &[(&str, &str)]) {
        for &(raw, expected) in cases {
            assert_eq!(strip_decor(raw), expected, "input: {raw:?}");
        }
    }

    #[test]
    fn strip_decor_primitives() {
        check_strip(&[("i32", "i32"), ("bool", "bool")]);
    }

    #[test]
    fn strip_decor_rust_refs() {
        check_strip(&[
            ("&Node", "Node"),
            ("&mut Vec<Finding>", "Finding"),
            ("&'a String", "String"),
        ]);
    }

    #[test]
    fn strip_decor_rust_paths() {
        check_strip(&[
            ("tree_sitter::Node", "Node"),
            ("std::collections::HashMap", "HashMap"),
        ]);
    }

    #[test]
    fn strip_decor_rust_generics() {
        check_strip(&[
            ("Option<String>", "String"),
            ("Vec<Box<MyTrait>>", "MyTrait"),
            ("Arc<RefCell<State>>", "State"),
        ]);
    }

    #[test]
    fn strip_decor_c_pointers() {
        check_strip(&[("char *", "char"), ("cmark_node_t *", "cmark_node_t")]);
    }

    #[test]
    fn strip_decor_ts_and_py_containers() {
        check_strip(&[
            ("Foo[]", "Foo"),
            ("Foo?", "Foo"),
            ("List[int]", "int"),
            ("Optional[MyType]", "MyType"),
        ]);
    }

    #[test]
    fn strip_decor_dotted_paths() {
        check_strip(&[("package.module.Type", "Type")]);
    }

    #[test]
    fn strip_decor_rust_slice() {
        check_strip(&[
            ("[PathBuf]", "PathBuf"),
            ("&[Finding]", "Finding"),
            ("&[&DetailClass]", "DetailClass"),
        ]);
    }

    // An imports-map entry for the stripped name decides the origin.
    #[test]
    fn resolve_uses_imports_map() {
        let mut imports = ImportsMap::new();
        imports.insert("Node".into(), TypeOrigin::External("tree_sitter".into()));

        let resolved = resolve("&tree_sitter::Node", &imports);

        assert_eq!(resolved.name, "Node");
        assert_eq!(resolved.origin, TypeOrigin::External("tree_sitter".into()));
    }

    // With no map entry and no primitive match, the origin is `Unknown`.
    #[test]
    fn resolve_falls_back_to_unknown() {
        let imports = ImportsMap::new();

        let resolved = resolve("SomeUnknownType", &imports);

        assert_eq!(resolved.origin, TypeOrigin::Unknown);
    }
}