Skip to main content

compiler/
symbol.rs

1use dynamic::{ConstIntOp, Dynamic, Type};
2use parser::Stmt;
3use smol_str::SmolStr;
4use std::{rc::Rc, sync::Arc};
5
6use super::Capture;
7
8#[derive(Debug, Clone, Default)]
9pub enum Symbol {
10    #[default]
11    Null,
12    Const {
13        value: Dynamic,
14        ty: Type,
15        is_pub: bool,
16    },
17    Static {
18        value: Option<Dynamic>,
19        ty: Type,
20        is_pub: bool,
21    },
22    Struct(Type, bool),
23    Fn {
24        ty: Type,
25        args: Vec<SmolStr>,
26        generic_params: Vec<Type>,
27        cap: Capture,
28        body: Arc<Stmt>,
29        is_pub: bool,
30    },
31    Native(Type),
32}
33
34impl Symbol {
35    pub fn native(tys: Vec<Type>, ret: Type) -> Self {
36        Self::Native(Type::Fn { tys, ret: Rc::new(ret) })
37    }
38
39    pub fn is_pub(&self) -> bool {
40        match self {
41            Self::Const { value: _, ty: _, is_pub } => *is_pub,
42            Self::Static { value: _, ty: _, is_pub } => *is_pub,
43            Self::Struct(_, is_pub) => *is_pub,
44            Self::Fn { ty: _, args: _, generic_params: _, cap: _, body: _, is_pub } => *is_pub,
45            _ => true,
46        }
47    }
48
49    pub fn is_fn(&self) -> bool {
50        match self {
51            Self::Fn { ty: _, args: _, generic_params: _, cap: _, body: _, is_pub: _ } => true,
52            Self::Native(_) => true,
53            _ => false,
54        }
55    }
56}
57
58use anyhow::{Result, anyhow};
59use indexmap::IndexMap;
60use std::{cell::RefCell, collections::HashMap};
61
62pub fn eval_const_int_type(ty: &Type) -> Option<i64> {
63    match ty {
64        Type::ConstInt(value) => Some(*value),
65        Type::ConstBinary { op, left, right } => {
66            let left = eval_const_int_type(left)?;
67            let right = eval_const_int_type(right)?;
68            match op {
69                ConstIntOp::Add => Some(left + right),
70                ConstIntOp::Sub => Some(left - right),
71                ConstIntOp::Mul => Some(left * right),
72                ConstIntOp::Div => (right != 0).then_some(left / right),
73                ConstIntOp::Mod => (right != 0).then_some(left % right),
74            }
75        }
76        _ => None,
77    }
78}
79
80pub fn substitute_type(ty: &Type, params: &[Type], args: &[Type]) -> Type {
81    match ty {
82        Type::Ident { name, params: nested } if nested.is_empty() => {
83            params.iter().position(|param| matches!(param, Type::Ident { name: param_name, params } if params.is_empty() && param_name == name)).map(|idx| args[idx].clone()).unwrap_or_else(|| ty.clone())
84        }
85        Type::Ident { name, params: nested } => Type::Ident { name: name.clone(), params: nested.iter().map(|param| substitute_type(param, params, args)).collect() },
86        Type::Struct { params: struct_params, fields } => Type::Struct {
87            params: struct_params.iter().map(|param| substitute_type(param, params, args)).collect(),
88            fields: fields.iter().map(|(name, field_ty)| (name.clone(), substitute_type(field_ty, params, args))).collect(),
89        },
90        Type::List(elem) => Type::List(Rc::new(substitute_type(elem, params, args))),
91        Type::Vec(elem, len) => Type::Vec(Rc::new(substitute_type(elem, params, args)), *len),
92        Type::Array(elem, len) => Type::Array(Rc::new(substitute_type(elem, params, args)), *len),
93        Type::ArrayParam(elem, len) => Type::ArrayParam(Rc::new(substitute_type(elem, params, args)), Rc::new(substitute_type(len, params, args))),
94        Type::ConstBinary { op, left, right } => {
95            let left = substitute_type(left, params, args);
96            let right = substitute_type(right, params, args);
97            let ty = Type::ConstBinary { op: *op, left: Rc::new(left), right: Rc::new(right) };
98            eval_const_int_type(&ty).map(Type::ConstInt).unwrap_or(ty)
99        }
100        Type::Fn { tys, ret } => Type::Fn { tys: tys.iter().map(|ty| substitute_type(ty, params, args)).collect(), ret: Rc::new(substitute_type(ret, params, args)) },
101        Type::Symbol { id, params: nested } => Type::Symbol { id: *id, params: nested.iter().map(|param| substitute_type(param, params, args)).collect() },
102        Type::Tuple(items) => Type::Tuple(items.iter().map(|item| substitute_type(item, params, args)).collect()),
103        _ => ty.clone(),
104    }
105}
106
107#[derive(Clone, Default)]
108pub struct SymbolTable {
109    pub symbols: IndexMap<SmolStr, Symbol>,
110    // 双 IndexMap:按模块名 + 模块内符号名 O(1) 查找;按插入顺序遍历。
111    modules: IndexMap<SmolStr, IndexMap<SmolStr, u32>>,
112    pub roots: Vec<SmolStr>,
113    lookup_cache: RefCell<HashMap<SmolStr, (u64, u32)>>,
114    lookup_epoch: u64,
115}
116
117impl SymbolTable {
118    fn invalidate_lookup_cache(&mut self) {
119        self.lookup_epoch = self.lookup_epoch.wrapping_add(1);
120        self.lookup_cache.get_mut().clear();
121    }
122
123    fn cached_id(&self, name: &str) -> Option<u32> {
124        let key = SmolStr::new(name);
125        self.lookup_cache.borrow().get(&key).and_then(|(epoch, id)| (*epoch == self.lookup_epoch).then_some(*id))
126    }
127
128    fn cache_id(&self, name: &str, id: u32) {
129        self.lookup_cache.borrow_mut().insert(SmolStr::new(name), (self.lookup_epoch, id));
130    }
131
132    pub fn add_to_module(&mut self, module: &str, name: SmolStr, s: Symbol) -> Result<u32> {
133        self.invalidate_lookup_cache();
134        let full_name: SmolStr = format!("{}::{}", module, name).into();
135        let id = self.symbols.insert_full(full_name, s).0 as u32;
136        let module_symbols = self.modules.get_mut(module).ok_or_else(|| anyhow!("模块 {} 不存在", module))?;
137        module_symbols.insert(name, id);
138        Ok(id)
139    }
140    pub fn get_symbol(&self, idx: u32) -> Result<(&SmolStr, &Symbol)> {
141        self.symbols.get_index(idx as usize).ok_or(anyhow!("未发现符号 {}", idx))
142    }
143
144    pub fn get_symbol_mut(&mut self, idx: u32) -> Option<(&SmolStr, &mut Symbol)> {
145        self.symbols.get_index_mut(idx as usize)
146    }
147
148    pub fn symbol(&self, name: &str) -> Vec<(SmolStr, u32)> {
149        self.modules.get(name).map(|m| m.iter().map(|(name, id)| (name.clone(), *id)).collect()).unwrap_or(Vec::new())
150    }
151
152    pub fn disassemble(&self, name: &str) -> Result<String> {
153        let id = self.get_id(name)?;
154        let (name, s) = self.get_symbol(id)?;
155        if let Symbol::Fn { ty, args, generic_params: _, cap, body, is_pub } = s {
156            if *is_pub { Ok(format!("pub {} {:?} {:?} {:?}\n{}", name, ty, args, cap, body)) } else { Ok(format!("{} {:?} {:?} {:?}\n{}", name, ty, args, cap, body)) }
157        } else {
158            Err(anyhow!("未发现符号 {}", name))
159        }
160    }
161
162    pub fn get_field(&self, ty: &Type, name: &str) -> Result<(usize, Type)> {
163        //原生类型的函数 is_map is_list 或者 sqrt
164        let id = match ty {
165            Type::Any => {
166                if let Ok(id) = self.get_id("Any")
167                    && let Ok((_, Symbol::Struct(any_ty, _))) = self.get_symbol(id)
168                    && let Ok((idx, field_ty)) = any_ty.get_field(name)
169                {
170                    return Ok((idx, field_ty.clone()));
171                }
172                match name {
173                    "is_map" | "is_list" | "is_string" | "is_null" | "contains" | "starts_with" => return Ok((usize::MAX, Type::Bool)),
174                    "len" => return Ok((usize::MAX, Type::I32)),
175                    _ => return Ok((usize::MAX, Type::Any)),
176                }
177            }
178            Type::Struct { params: _, fields: _ } => {
179                return ty.get_field(name).map(|(idx, ty)| (idx, ty.clone()));
180            }
181            Type::Str => {
182                let any_method = match name {
183                    "len" | "contains" | "split" | "starts_with" | "is_string" | "is_null" => format!("Any::{}", name),
184                    _ => return Err(anyhow!("未发现 symbol {:?} {}", ty, name)),
185                };
186                return Ok((usize::MAX, Type::Symbol { id: self.get_id(&any_method)?, params: Vec::new() }));
187            }
188            Type::List(_) | Type::Array(_, _) => {
189                let any_method = match name {
190                    "len" | "push" | "pop" | "get_idx" | "set_idx" | "slice" | "is_list" | "is_null" => format!("Any::{}", name),
191                    _ => return Err(anyhow!("未发现 symbol {:?} {}", ty, name)),
192                };
193                return Ok((usize::MAX, Type::Symbol { id: self.get_id(&any_method)?, params: Vec::new() }));
194            }
195            Type::Symbol { id, params: _ } => *id,
196            Type::Vec(_, _) => self.get_id("Vec")?,
197            Type::Fn { tys: _, ret } => {
198                return self.get_field(ret, name);
199            }
200            _ => {
201                //增加一个外部函数定义
202                if matches!(name, "is_map" | "is_list" | "is_string" | "is_null") {
203                    return Ok((usize::MAX, Type::Symbol { id: self.get_id(&format!("Any::{}", name))?, params: Vec::new() }));
204                }
205                return Err(anyhow!("未发现 symbol {:?} {}", ty, name));
206            }
207        };
208        let (_, s) = self.get_symbol(id)?;
209        if let Symbol::Struct(s, _) = s {
210            return s.get_field(name).and_then(|(idx, ty)| Ok((idx, ty.clone())));
211        };
212        Err(anyhow!("未发现 field {:?} {}", ty, name))
213    }
214
215    pub fn get_type(&self, ty: &Type) -> Result<Type> {
216        match ty {
217            Type::Ident { name, params } => {
218                let params = params.iter().map(|param| self.get_type(param)).collect::<Result<Vec<_>>>()?;
219                if name.as_str() == "Vec" && params.len() == 1 {
220                    return Ok(Type::Vec(Rc::new(params[0].clone()), 0));
221                }
222                if name.as_str() == "List" {
223                    return Ok(if params.is_empty() { Type::list_any() } else { Type::List(Rc::new(params[0].clone())) });
224                }
225                let id = self.get_id(&name)?;
226                if let (_, Symbol::Struct(ty, _)) = self.get_symbol(id)? {
227                    if let Type::Struct { params: generic_params, .. } = ty
228                        && !generic_params.is_empty()
229                        && generic_params.len() == params.len()
230                    {
231                        return self.get_type(&substitute_type(ty, generic_params, &params));
232                    }
233                    return self.get_type(ty);
234                }
235                return Ok(Type::Symbol { id, params });
236            }
237            Type::Symbol { id, params } => {
238                return match self.get_symbol(*id)? {
239                    (_, Symbol::Fn { ty, args: _, generic_params: _, cap: _, body: _, is_pub: _ }) => self.get_type(ty),
240                    (_, Symbol::Native(ty)) => self.get_type(ty),
241                    (_, Symbol::Struct(ty, _)) => {
242                        let params = params.iter().map(|param| self.get_type(param)).collect::<Result<Vec<_>>>()?;
243                        if let Type::Struct { params: generic_params, .. } = ty
244                            && !generic_params.is_empty()
245                            && generic_params.len() == params.len()
246                        {
247                            self.get_type(&substitute_type(ty, generic_params, &params))
248                        } else {
249                            self.get_type(ty)
250                        }
251                    }
252                    (_, s) => {
253                        log::debug!("s-> {:?}", s);
254                        Ok(Type::Symbol { id: *id, params: params.clone() })
255                    }
256                };
257            }
258            Type::Vec(elem, len) => {
259                return Ok(Type::Vec(Rc::new(self.get_type(elem)?), *len));
260            }
261            Type::List(elem) => {
262                return Ok(Type::List(Rc::new(self.get_type(elem)?)));
263            }
264            Type::Array(elem, len) => {
265                return Ok(Type::Array(Rc::new(self.get_type(elem)?), *len));
266            }
267            Type::ArrayParam(elem, len) => {
268                let elem = self.get_type(elem)?;
269                let len = self.get_type(len)?;
270                if let Some(len) = eval_const_int_type(&len) {
271                    let len = u32::try_from(len).map_err(|_| anyhow!("数组长度超出 u32 范围"))?;
272                    return Ok(Type::Array(Rc::new(elem), len));
273                }
274                return Ok(Type::ArrayParam(Rc::new(elem), Rc::new(len)));
275            }
276            Type::ConstBinary { op, left, right } => {
277                let left = self.get_type(left)?;
278                let right = self.get_type(right)?;
279                let ty = Type::ConstBinary { op: *op, left: Rc::new(left), right: Rc::new(right) };
280                return Ok(eval_const_int_type(&ty).map(Type::ConstInt).unwrap_or(ty));
281            }
282            Type::Fn { tys, ret } => {
283                return Ok(Type::Fn { tys: tys.iter().map(|ty| self.get_type(ty)).collect::<Result<Vec<_>>>()?, ret: Rc::new(self.get_type(ret)?) });
284            }
285            Type::Struct { params, fields } => {
286                return Ok(Type::Struct {
287                    params: params.iter().map(|param| self.get_type(param)).collect::<Result<Vec<_>>>()?,
288                    fields: fields.iter().map(|(name, ty)| if matches!(ty, Type::Symbol { .. }) { Ok((name.clone(), ty.clone())) } else { self.get_type(ty).map(|ty| (name.clone(), ty)) }).collect::<Result<Vec<_>>>()?,
289                });
290            }
291            _ => {}
292        }
293        Ok(ty.clone())
294    }
295
296    pub fn add_module(&mut self, name: SmolStr) {
297        self.invalidate_lookup_cache();
298        let len = self.roots.len();
299        if let Some(pos) = self.roots.iter().position(|r| r.as_str() == name.as_str()) {
300            if pos != len - 1 {
301                self.roots.swap(pos, len - 1);
302            }
303        } else {
304            self.roots.push(name.clone());
305        }
306        self.modules.insert(name, IndexMap::new());
307    }
308
309    pub fn push_module_scope(&mut self, name: SmolStr) {
310        self.invalidate_lookup_cache();
311        self.roots.push(name);
312    }
313
314    pub fn pop_module_scope(&mut self) {
315        self.invalidate_lookup_cache();
316        self.roots.pop();
317    }
318
319    pub fn pop_module(&mut self) {
320        self.invalidate_lookup_cache();
321        //如果不想模块成为全局的 add_module 之后调用 pop_module
322        if let Some(last) = self.roots.pop() {
323            if let Some(names) = self.modules.get(&last).map(|m| {
324                let kvs: Vec<(SmolStr, u32)> = m.iter().map(|kv| (kv.0.clone(), *kv.1)).collect();
325                kvs.iter().filter_map(|kv| if !self.get_symbol(kv.1).map(|s| s.1.is_pub()).unwrap_or(false) { Some(kv.0.clone()) } else { None }).collect::<Vec<_>>()
326            }) {
327                if let Some(m) = self.modules.get_mut(&last) {
328                    for name in names {
329                        m.shift_remove(&name); //删除非 pub 的符号;保持插入顺序
330                    }
331                }
332            }
333        }
334    }
335
336    pub fn get_id(&self, name: &str) -> Result<u32> {
337        if let Some(id) = self.cached_id(name) {
338            return Ok(id);
339        }
340        // 1) 全名命中 (含 :: 的 id 或 add_global 注册过的全局符号)。
341        if let Some(idx) = self.symbols.get_index_of(name) {
342            let id = idx as u32;
343            self.cache_id(name, id);
344            return Ok(id);
345        }
346        // 2) 含 :: 的名字拆分按模块查找,O(1)。
347        if let Some((mod_name, sym_name)) = name.split_once("::") {
348            if let Some(&id) = self.modules.get(mod_name).and_then(|m| m.get(sym_name)) {
349                self.cache_id(name, id);
350                return Ok(id);
351            }
352            // 即使 modules 里被 pop_module 移除,完整名仍可能在 symbols IndexMap 里。
353            let full = format!("{mod_name}::{sym_name}");
354            if let Some(idx) = self.symbols.get_index_of(full.as_str()) {
355                let id = idx as u32;
356                self.cache_id(name, id);
357                return Ok(id);
358            }
359        }
360        // 3) 不含 :: 的短名,按 roots 倒序查找第一个匹配的模块。O(M)。
361        //    pop_module 会清掉非 pub 项,所以 modules 可能没有;再退回按 root 拼全名查 symbols。
362        let short_name = name;
363        for root in self.roots.iter().rev() {
364            if let Some(&id) = self.modules.get(root).and_then(|m| m.get(short_name)) {
365                self.cache_id(name, id);
366                return Ok(id);
367            }
368            let full = format!("{root}::{short_name}");
369            if let Some(idx) = self.symbols.get_index_of(full.as_str()) {
370                let id = idx as u32;
371                self.cache_id(name, id);
372                return Ok(id);
373            }
374        }
375        // 4) 兜底:任意模块下的同名短符号。O(M*K),K 为模块内符号数。
376        for (_, m) in &self.modules {
377            if let Some(&id) = m.get(short_name) {
378                self.cache_id(name, id);
379                return Ok(id);
380            }
381        }
382        Err(anyhow!("{} 未发现", name))
383    }
384
385    pub fn add(&mut self, name: SmolStr, s: Symbol) -> u32 {
386        self.invalidate_lookup_cache();
387        let root = self.roots.last().cloned().unwrap();
388        let id = self.symbols.insert_full(format!("{}::{}", root, name).into(), s).0 as u32;
389        self.modules.get_mut(&root).map(|m| m.insert(name, id));
390        id
391    }
392
393    pub fn add_global(&mut self, name: SmolStr, s: Symbol) -> u32 {
394        if let Some(idx) = self.symbols.get_index_of(name.as_str()) {
395            return idx as u32;
396        }
397        if let Some((mod_name, symbol_name)) = name.as_str().split_once("::") {
398            if let Some(m) = self.modules.get_mut(mod_name) {
399                if let Some(&id) = m.get(symbol_name) {
400                    return id;
401                }
402            }
403        }
404        self.invalidate_lookup_cache();
405        let id = self.symbols.insert_full(name.clone(), s).0 as u32;
406        if let Some((mod_name, symbol_name)) = name.as_str().split_once("::") {
407            if let Some(m) = self.modules.get_mut(mod_name) {
408                m.insert(symbol_name.into(), id);
409            }
410        }
411        id
412    }
413
414    pub fn take(&mut self, id: u32) -> Option<Symbol> {
415        self.invalidate_lookup_cache();
416        self.symbols.get_index_mut(id as usize).map(|(_, s)| std::mem::take(s))
417    }
418}