pipeline_script/context/
mod.rs

1use crate::ast::module::Module;
2use crate::ast::r#type::Type;
3use crate::ast::type_alias::TypeAlias;
4use crate::ast::NodeTrait;
5use crate::context::key::ContextKey;
6use crate::context::scope::Scope;
7use crate::context::value::ContextValue;
8use crate::llvm::builder::ir::IRBuilder;
9use crate::llvm::context::LLVMContext;
10use crate::llvm::module::LLVMModule;
11use crate::llvm::types::LLVMType;
12use crate::llvm::value::fucntion::FunctionValue;
13use crate::llvm::value::LLVMValue;
14use llvm_sys::prelude::LLVMBasicBlockRef;
15use slotmap::DefaultKey;
16use std::collections::HashMap;
17use std::rc::Rc;
18use std::sync::{Arc, Mutex, RwLock};
19
20pub mod key;
21pub mod scope;
22pub mod value;
23
24#[derive(Clone, Debug)]
25pub struct Context {
26    parent: Option<Box<Context>>,
27    key: ContextKey,
28    value: ContextValue,
29}
30
31impl Context {
32    pub fn background() -> Self {
33        Self {
34            parent: None,
35            key: ContextKey::Background,
36            value: ContextValue::Background,
37        }
38    }
39    pub fn with_builder(parent: &Context, builder: IRBuilder) -> Self {
40        Self {
41            parent: Some(Box::new(parent.clone())),
42            key: ContextKey::Builder,
43            value: ContextValue::Builder(Arc::new(builder)),
44        }
45    }
46    pub fn with_module_slot_map(parent: &Context, t: slotmap::SlotMap<DefaultKey, Module>) -> Self {
47        Self::with_value(
48            parent,
49            ContextKey::ModuleSlotMap,
50            ContextValue::ModuleSlotMap(Arc::new(RwLock::new(t))),
51        )
52    }
53    pub fn apply_mut_module(&self, key: DefaultKey, apply: impl Fn(&mut Module)) {
54        let slot_map = self.get(ContextKey::ModuleSlotMap).unwrap();
55        match slot_map {
56            ContextValue::ModuleSlotMap(slot_map) => {
57                let mut slot_map = slot_map.write().unwrap();
58                let module = slot_map.get_mut(key).unwrap();
59                apply(module)
60            }
61            _ => panic!("not a module slot map"),
62        }
63    }
64    pub fn get_module_slot_map(&self) -> Arc<RwLock<slotmap::SlotMap<DefaultKey, Module>>> {
65        match self.get(ContextKey::ModuleSlotMap) {
66            Some(ContextValue::ModuleSlotMap(slot_map)) => slot_map.clone(),
67            _ => Arc::new(RwLock::new(slotmap::SlotMap::new())),
68        }
69    }
70
71    pub fn apply_module(&self, key: DefaultKey, mut apply: impl FnMut(&Module)) {
72        let slot_map = self.get(ContextKey::ModuleSlotMap).unwrap();
73        match slot_map {
74            ContextValue::ModuleSlotMap(slot_map) => {
75                let slot_map = slot_map.read().unwrap();
76                let module = slot_map.get(key).unwrap();
77                apply(module)
78            }
79            _ => panic!("not a module slot map"),
80        }
81    }
82    pub fn register_module(&self, module: Module) -> DefaultKey {
83        let slot_map = self.get(ContextKey::ModuleSlotMap).unwrap();
84        match slot_map {
85            ContextValue::ModuleSlotMap(slot_map) => {
86                let mut slot_map = slot_map.write().unwrap();
87                slot_map.insert(module)
88            }
89            _ => panic!("not a module slot map"),
90        }
91    }
92    pub fn with_type(parent: &Context, name: String, ty: Type) -> Self {
93        Self::with_value(parent, ContextKey::Type(name), ContextValue::Type(ty))
94    }
95    pub fn get_builder(&self) -> Arc<IRBuilder> {
96        match self.get(ContextKey::Builder) {
97            Some(ContextValue::Builder(b)) => b.clone(),
98            _ => panic!("not a builder"),
99        }
100    }
101    pub fn get_current_function(&self) -> FunctionValue {
102        match self.get(ContextKey::Function) {
103            Some(ContextValue::Function(b)) => b.clone(),
104            _ => panic!("not a function"),
105        }
106    }
107    pub fn get_current_function_type(&self) -> &Type {
108        match self.get(ContextKey::Type("current_function".into())) {
109            Some(ContextValue::Type(ty)) => ty,
110            _ => panic!("not a function"),
111        }
112    }
113    pub fn with_scope(parent: &Context) -> Self {
114        Self::with_value(parent, ContextKey::Scope, ContextValue::Scope(Scope::new()))
115    }
116    pub fn with_type_table(parent: &Context, t: HashMap<Type, LLVMType>) -> Self {
117        Self::with_value(
118            parent,
119            ContextKey::TypeTable,
120            ContextValue::TypeTable(Rc::new(RwLock::new(t))),
121        )
122    }
123    pub fn create_llvm_context() -> Self {
124        let llvm_ctx = LLVMContext::new();
125        let module = llvm_ctx.create_module("main");
126        let ctx = Context {
127            parent: None,
128            key: ContextKey::LLVMContext,
129            value: ContextValue::LLVMContext(Rc::new(Mutex::new(LLVMContext::new()))),
130        };
131        Self {
132            parent: Some(Box::new(ctx)),
133            key: ContextKey::LLVMModule,
134            value: ContextValue::LLVMModule(Rc::new(RwLock::new(module))),
135        }
136    }
137    pub fn with_function(parent: &Context, f: FunctionValue) -> Self {
138        Self::with_value(parent, ContextKey::Function, ContextValue::Function(f))
139    }
140    pub fn with_flag(parent: &Context, key: impl Into<String>, flag: bool) -> Self {
141        Self::with_value(
142            parent,
143            ContextKey::Flag(key.into()),
144            ContextValue::Flag(Arc::new(RwLock::new(flag))),
145        )
146    }
147    pub fn get_llvm_module(&self) -> Rc<RwLock<LLVMModule>> {
148        match self.get(ContextKey::LLVMModule) {
149            Some(ContextValue::LLVMModule(m)) => m.clone(),
150            _ => panic!("not a llvm module"),
151        }
152    }
153    pub fn with_llvm_module(parent: &Context, m: LLVMModule) -> Self {
154        Self::with_value(
155            parent,
156            ContextKey::LLVMModule,
157            ContextValue::LLVMModule(Rc::new(RwLock::new(m))),
158        )
159    }
160    pub fn get_llvm_context(&self) -> Rc<Mutex<LLVMContext>> {
161        match self.get(ContextKey::LLVMContext) {
162            Some(ContextValue::LLVMContext(c)) => c.clone(),
163            _ => panic!("not a llvm context"),
164        }
165    }
166
167    pub fn with_local(parent: &Context, local: Vec<String>) -> Self {
168        Self::with_value(
169            parent,
170            ContextKey::LocalVariable,
171            ContextValue::LocalVariable(Arc::new(RwLock::new(local))),
172        )
173    }
174    pub fn get_type_binding_functions(&self, name: &str) -> Vec<crate::ast::function::Function> {
175        let slot_map = self.get_module_slot_map();
176        let slot_map = slot_map.read().unwrap();
177        let mut result = Vec::new();
178
179        for module in slot_map.values() {
180            for (_, function) in module.get_functions() {
181                if function.has_binding() && function.get_binding() == name {
182                    result.push(function);
183                }
184            }
185        }
186
187        result
188    }
189
190    pub fn with_capture(parent: &Context) -> Self {
191        Self::with_value(
192            parent,
193            ContextKey::CaptureVariable,
194            ContextValue::CaptureVariable(Arc::new(RwLock::new(vec![]))),
195        )
196    }
197    pub fn set_flag(&self, key: impl Into<String>, flag: bool) {
198        match self.get(ContextKey::Flag(key.into())) {
199            Some(ContextValue::Flag(f)) => {
200                let mut f = f.write().unwrap();
201                *f = flag;
202            }
203            _ => panic!("not a flag"),
204        }
205    }
206    pub fn get_flag(&self, key: impl Into<String>) -> Option<bool> {
207        match self.get(ContextKey::Flag(key.into())) {
208            Some(ContextValue::Flag(f)) => {
209                let f = f.read().unwrap();
210                Some(*f)
211            }
212            _ => None,
213        }
214    }
215    pub fn with_value(parent: &Context, key: ContextKey, value: ContextValue) -> Self {
216        Self {
217            parent: Some(Box::new(parent.clone())),
218            key,
219            value,
220        }
221    }
222
223    pub fn get(&self, key: ContextKey) -> Option<&ContextValue> {
224        if self.key == key {
225            return Some(&self.value);
226        }
227        match &self.parent {
228            None => None,
229            Some(parent) => parent.get(key),
230        }
231    }
232    pub fn get_scope(&self) -> Scope {
233        match self.get(ContextKey::Scope) {
234            Some(ContextValue::Scope(s)) => s.clone(),
235            _ => panic!("not a scope"),
236        }
237    }
238    pub fn set_symbol(&self, name: String, v: LLVMValue) {
239        let scope = self.get(ContextKey::Scope).unwrap().as_scope();
240        scope.set(name, v);
241    }
242    pub fn get_symbol(&self, name: impl AsRef<str>) -> Option<LLVMValue> {
243        if let Some(ContextValue::Scope(scope)) = self.get(ContextKey::Scope) {
244            match scope.has(name.as_ref()) {
245                true => return scope.get(name),
246                false => {
247                    // 当前Scope中没有,继续在父上下文中查找
248                }
249            }
250        }
251
252        match &self.parent {
253            None => None,
254            Some(parent) => parent.get_symbol(name),
255        }
256    }
257    pub fn get_type(&self, t: &Type) -> Option<LLVMType> {
258        match self.get(ContextKey::TypeTable) {
259            Some(ContextValue::TypeTable(tt0)) => {
260                let tt = tt0.read().unwrap();
261                let r = tt.get(t);
262                r.cloned()
263            }
264            _ => panic!("not a type table"),
265        }
266    }
267    pub fn register_type(&self, ty: &Type, llvm_ty: &LLVMType) {
268        match self.get(ContextKey::TypeTable) {
269            Some(ContextValue::TypeTable(tt0)) => {
270                let mut tt = tt0.write().unwrap();
271                tt.insert(ty.clone(), llvm_ty.clone());
272            }
273            _ => panic!("not a type table"),
274        }
275    }
276    pub fn get_alias_type(&self, name: impl AsRef<str>) -> Option<Type> {
277        match self.get(ContextKey::AliasType) {
278            Some(ContextValue::AliasType(t)) => {
279                let t = t.read().unwrap();
280                t.get(name.as_ref()).cloned()
281            }
282            _ => panic!("not a symbol type"),
283        }
284    }
285    pub fn set_alias_type(&self, name: String, t0: Type) {
286        match self.get(ContextKey::AliasType) {
287            Some(ContextValue::AliasType(t)) => {
288                let mut t = t.write().unwrap();
289                t.insert(name, t0);
290            }
291            _ => panic!("not a symbol type"),
292        }
293    }
294    pub fn get_symbol_type(&self, name: impl AsRef<str>) -> Option<Type> {
295        match self.get(ContextKey::SymbolType) {
296            Some(ContextValue::SymbolType(t)) => {
297                let t = t.read().unwrap();
298                let r = t.get(name.as_ref());
299                match r {
300                    None => self.parent.clone()?.parent?.get_symbol_type(name),
301                    Some(s) => Some(s.clone()),
302                }
303            }
304            _ => None,
305        }
306    }
307    pub fn try_add_local(&self, name: String) {
308        if let Some(ContextValue::LocalVariable(local)) = self.get(ContextKey::LocalVariable) {
309            let mut local = local.write().unwrap();
310            local.push(name);
311        }
312    }
313    pub fn add_capture(&self, name: String, ty: Type) {
314        match self.get(ContextKey::CaptureVariable) {
315            Some(ContextValue::CaptureVariable(local)) => {
316                let mut local = local.write().unwrap();
317                local.push((name, ty));
318            }
319            _ => panic!("not a local variable"),
320        }
321    }
322    pub fn is_local_variable(&self, name: impl AsRef<str>) -> bool {
323        match self.get(ContextKey::LocalVariable) {
324            Some(ContextValue::LocalVariable(local)) => {
325                let local = local.read().unwrap();
326                local.contains(&name.as_ref().to_string())
327            }
328            _ => panic!("not a local variable"),
329        }
330    }
331    pub fn get_captures(&self) -> Option<Vec<(String, Type)>> {
332        match self.get(ContextKey::CaptureVariable) {
333            Some(ContextValue::CaptureVariable(local)) => {
334                let local = local.read().unwrap();
335                Some(local.clone())
336            }
337            _ => panic!("not a local variable"),
338        }
339    }
340    pub fn get_function(&self, name: impl AsRef<str>) -> Option<crate::ast::function::Function> {
341        let slot_map = self.get_module_slot_map();
342        let slot_map = slot_map.read().unwrap();
343        for module in slot_map.values() {
344            if let Some(fun) = module.get_function(name.as_ref()) {
345                return Some(fun.clone());
346            }
347        }
348        None
349    }
350    pub fn set_symbol_type(&self, name: String, t0: Type) {
351        match self.get(ContextKey::SymbolType) {
352            Some(ContextValue::SymbolType(t)) => {
353                let mut t = t.write().unwrap();
354                t.insert(name, t0);
355            }
356            _ => panic!("not a symbol type"),
357        }
358    }
359    pub fn get_type_alias(&self, name: impl AsRef<str>) -> Option<TypeAlias> {
360        let slot_map = self.get_module_slot_map();
361        let slot_map = slot_map.read().unwrap();
362
363        for module in slot_map.values() {
364            if let Some(ty) = module.get_type_alias(name.as_ref()) {
365                return Some(ty.clone());
366            }
367        }
368
369        None
370    }
371    pub fn with_loop_block(parent: &Context, name: String, block: LLVMBasicBlockRef) -> Self {
372        Self::with_value(
373            parent,
374            ContextKey::LoopBlock(name),
375            ContextValue::LoopBlock(block),
376        )
377    }
378
379    pub fn get_loop_block(&self, name: &str) -> Option<LLVMBasicBlockRef> {
380        match self.get(ContextKey::LoopBlock(name.to_string())) {
381            Some(ContextValue::LoopBlock(block)) => Some(*block),
382            _ => None,
383        }
384    }
385
386    pub fn with_default_expr(parent: &Context) -> Self {
387        Self::with_value(
388            parent,
389            ContextKey::DefaultExpr,
390            ContextValue::DefaultExpr(Arc::new(RwLock::new(HashMap::new()))),
391        )
392    }
393    pub fn set_default_expr(&self, name: String, expr: Box<crate::ast::expr::ExprNode>) {
394        match self.get(ContextKey::DefaultExpr) {
395            Some(ContextValue::DefaultExpr(map)) => {
396                let mut map = map.write().unwrap();
397                map.insert(name, expr);
398            }
399            _ => panic!("not a default expr"),
400        }
401    }
402
403    pub fn get_default_expr(&self, name: &str) -> Option<Box<crate::ast::expr::ExprNode>> {
404        if let Some(ContextValue::DefaultExpr(map)) = self.get(ContextKey::DefaultExpr) {
405            map.read().unwrap().get(name).cloned()
406        } else {
407            None
408        }
409    }
410
411    pub fn is_global_variable(&self, name: impl AsRef<str>) -> bool {
412        let slot_map = self.get_module_slot_map();
413        let slot_map = slot_map.read().unwrap();
414        for module in slot_map.values() {
415            for stmt in module.get_global_block() {
416                if stmt.is_static_decl() {
417                    if let Some(var_name) = stmt.get_data("name") {
418                        if var_name.as_str().unwrap() == name.as_ref() {
419                            return true;
420                        }
421                    }
422                }
423            }
424        }
425        false
426    }
427}