// pipeline_script/llvm/types.rs

use crate::ast::r#type::Type;
use crate::llvm::global::Global;
use crate::llvm::value::bool::BoolValue;
use crate::llvm::value::double::DoubleValue;
use crate::llvm::value::float::FloatValue;
use crate::llvm::value::fucntion::FunctionValue;
use crate::llvm::value::int::{Int16Value, Int32Value, Int64Value, Int8Value};
use crate::llvm::value::pointer::PointerValue;
use crate::llvm::value::pstruct::StructValue;
use crate::llvm::value::reference::ReferenceValue;
use crate::llvm::value::LLVMValue;
use llvm_sys::core::{
    LLVMDisposeMessage, LLVMGetElementType, LLVMGetIntTypeWidth, LLVMGetTypeKind, LLVMGetUndef,
    LLVMInt32Type, LLVMIsFunctionVarArg, LLVMPrintTypeToString, LLVMTypeOf,
};
use llvm_sys::prelude::{LLVMTypeRef, LLVMValueRef};
use llvm_sys::LLVMTypeKind;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};
21
/// Compiler-side description of an LLVM type.
///
/// Every variant except [`LLVMType::Inject`] carries the raw `LLVMTypeRef` it describes.
/// Composite variants also keep their element / field / parameter descriptions, so they can
/// be inspected without going back through the LLVM C API.
#[derive(Clone, Debug, PartialEq)]
pub enum LLVMType {
    /// 1-bit integer; mapped to `Type::Bool` by `get_type`.
    Int1(LLVMTypeRef),
    /// 8-bit integer.
    Int8(LLVMTypeRef),
    /// 16-bit integer.
    Int16(LLVMTypeRef),
    /// 32-bit integer.
    Int32(LLVMTypeRef),
    /// 64-bit integer.
    Int64(LLVMTypeRef),
    /// Single-precision float.
    Float(LLVMTypeRef),
    /// Double-precision float.
    Double(LLVMTypeRef),
    /// Struct: its name, the ordered `(field name, field type)` pairs, and the LLVM type handle.
    Struct(String, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Array: the element type and the LLVM type handle (the length is not stored here).
    Array(Box<LLVMType>, LLVMTypeRef),
    /// Function: the return type, the ordered `(parameter name, parameter type)` pairs, and the
    /// LLVM function type handle.
    Function(Box<LLVMType>, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Pointer: the pointee type and the LLVM type handle.
    Pointer(Box<LLVMType>, LLVMTypeRef),
    /// Reference: the referenced type and the LLVM type handle.
    Ref(Box<LLVMType>, LLVMTypeRef),
    /// Wrapper around another type. It has no LLVM handle of its own; `as_llvm_type_ref`
    /// forwards to the wrapped type.
    Inject(Box<LLVMType>),
    /// String type; mapped to `Type::String` by `get_type`.
    String(LLVMTypeRef),
    /// Unit type; mapped to `Type::Unit` and to `LLVMValue::Unit` by `get_undef`.
    Unit(LLVMTypeRef),
}
40
41impl From<LLVMValueRef> for LLVMType {
42    #[allow(clippy::not_unsafe_ptr_arg_deref)]
43    fn from(value: LLVMValueRef) -> Self {
44        let ty = unsafe { LLVMTypeOf(value) };
45        let type_kind = unsafe { LLVMGetTypeKind(ty) };
46        match type_kind {
47            LLVMTypeKind::LLVMIntegerTypeKind => {
48                let width = unsafe { LLVMGetIntTypeWidth(ty) };
49                let width = width as i8;
50                match width {
51                    8 => LLVMType::Int8(ty),
52                    32 => LLVMType::Int32(ty),
53                    _ => {
54                        todo!()
55                    }
56                }
57            }
58            _ => {
59                todo!()
60            }
61        }
62    }
63}
64
65impl From<LLVMTypeRef> for LLVMType {
66    #[allow(clippy::not_unsafe_ptr_arg_deref)]
67    fn from(ty: LLVMTypeRef) -> Self {
68        let type_kind = unsafe { LLVMGetTypeKind(ty) };
69        match type_kind {
70            LLVMTypeKind::LLVMIntegerTypeKind => {
71                let width = unsafe { LLVMGetIntTypeWidth(ty) };
72                let width = width as i8;
73                match width {
74                    8 => LLVMType::Int8(ty),
75                    32 => LLVMType::Int32(ty),
76                    64 => LLVMType::Int64(ty),
77                    _ => {
78                        todo!()
79                    }
80                }
81            }
82            LLVMTypeKind::LLVMArrayTypeKind => {
83                let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
84                LLVMType::Array(Box::new(element_ty), ty)
85            }
86            LLVMTypeKind::LLVMDoubleTypeKind => LLVMType::Double(ty),
87            t => {
88                println!("{t:?}");
89                todo!()
90            }
91        }
92    }
93}
94
95impl LLVMType {
96    pub fn as_llvm_type_ref(&self) -> LLVMTypeRef {
97        match self {
98            LLVMType::Int1(i) => *i,
99            LLVMType::Int8(i) => *i,
100            LLVMType::Int16(i) => *i,
101            LLVMType::Int32(i) => *i,
102            LLVMType::Int64(i) => *i,
103            LLVMType::Float(i) => *i,
104            LLVMType::Double(i) => *i,
105            LLVMType::Array(_, i) => *i,
106            LLVMType::Function(_, _, i) => *i,
107            LLVMType::Pointer(_, i) => *i,
108            LLVMType::Ref(_, i) => *i,
109            LLVMType::Unit(i) => *i,
110            LLVMType::Struct(_, _, i) => *i,
111            LLVMType::String(i) => *i,
112            LLVMType::Inject(i) => i.as_llvm_type_ref(),
113        }
114    }
115    pub fn get_element_type(&self) -> LLVMType {
116        match self {
117            LLVMType::Array(e_t, _) => *e_t.clone(),
118            LLVMType::Pointer(e, _) => *e.clone(),
119            t => {
120                println!("{t:?}");
121                todo!()
122            }
123        }
124    }
125    pub fn get_type(&self) -> Type {
126        match self {
127            LLVMType::Int1(_) => Type::Bool,
128            LLVMType::Int8(_) => Type::Int8,
129            LLVMType::Int16(_) => Type::Int16,
130            LLVMType::Int32(_) => Type::Int32,
131            LLVMType::Int64(_) => Type::Int64,
132            LLVMType::Float(_) => Type::Float,
133            LLVMType::Double(_) => Type::Double,
134            LLVMType::String(_) => Type::String,
135            LLVMType::Unit(_) => Type::Unit,
136            LLVMType::Pointer(element_type, _) => Type::Pointer(Box::new(element_type.get_type())),
137            LLVMType::Ref(element_type, _) => Type::Ref(Box::new(element_type.get_type())),
138            LLVMType::Array(element_type, _) => Type::Array(Box::new(element_type.get_type())),
139            LLVMType::Struct(name, fields, _) => {
140                let mut field_types = vec![];
141                for (field_name, field_type) in fields {
142                    field_types.push((field_name.clone(), field_type.get_type()));
143                }
144                Type::Struct(Some(name.clone()), field_types)
145            }
146            LLVMType::Function(return_type, args, _) => {
147                let mut arg_types = vec![];
148                for (arg_name, arg_type) in args {
149                    arg_types.push((arg_name.clone(), arg_type.get_type()));
150                }
151                Type::Function(Box::new(return_type.get_type()), arg_types, false)
152            }
153            LLVMType::Inject(i) => Type::Inject(Box::new(i.get_type())),
154        }
155    }
156    pub fn get_function_param_type(&self, index: usize) -> LLVMType {
157        match self {
158            LLVMType::Function(_, v, f) => {
159                let r = unsafe { LLVMIsFunctionVarArg(*f) };
160                if r == 1 {
161                    return LLVMType::Unit(Global::unit_type().as_llvm_type_ref());
162                }
163                println!("{}", index);
164                v.get(index).cloned().unwrap().1
165            }
166            _ => panic!("Not a function"),
167        }
168    }
169    pub fn get_struct_field_type(&self, index: usize) -> LLVMType {
170        match self {
171            LLVMType::Struct(_, v, _) => v.get(index).cloned().unwrap().1,
172            _ => panic!("Not a struct"),
173        }
174    }
175    pub fn get_function_return_type(&self) -> LLVMType {
176        match self {
177            LLVMType::Function(_, _, f) => LLVMType::from(*f),
178            _ => panic!("Not a function"),
179        }
180    }
181    pub fn get_undef(&self) -> LLVMValue {
182        let reference = unsafe { LLVMGetUndef(self.as_llvm_type_ref()) };
183        match self {
184            LLVMType::Int1(_) => LLVMValue::Bool(BoolValue::new(reference)),
185            LLVMType::Pointer(element, _) => match &**element {
186                LLVMType::Function(ret, args, _) => {
187                    let undef_args = args
188                        .iter()
189                        .map(|arg| (arg.0.clone(), arg.1.get_undef()))
190                        .collect();
191                    LLVMValue::Function(FunctionValue::new(
192                        reference,
193                        "".into(),
194                        Box::new(ret.get_undef()),
195                        undef_args,
196                    ))
197                }
198                _ => LLVMValue::Pointer(PointerValue::new(reference, element.get_undef())),
199            },
200            LLVMType::Int8(_) => LLVMValue::Int8(Int8Value::new(reference)),
201            LLVMType::Int16(_) => LLVMValue::Int16(Int16Value::new(reference)),
202            LLVMType::Int32(_) => LLVMValue::Int32(Int32Value::new(reference)),
203            LLVMType::Int64(_) => LLVMValue::Int64(Int64Value::new(reference)),
204            // LLVMType::Struct(_, _) => LLVMValue::Struct(reference),
205            LLVMType::Unit(_) => LLVMValue::Unit,
206            LLVMType::Struct(name, field, _) => {
207                let mut filed_index = HashMap::new();
208                for (i, f) in field.iter().enumerate() {
209                    filed_index.insert(f.0.clone(), i);
210                }
211                LLVMValue::Struct(StructValue::new(
212                    reference,
213                    name.clone(),
214                    filed_index.clone(),
215                    field.iter().map(|f| f.1.get_undef()).collect(),
216                ))
217            }
218            LLVMType::String(_) => LLVMValue::String(reference),
219            LLVMType::Function(ret, args, _) => {
220                let undef_args = args
221                    .iter()
222                    .map(|arg| (arg.0.clone(), arg.1.get_undef()))
223                    .collect();
224                LLVMValue::Function(FunctionValue::new(
225                    reference,
226                    "".into(),
227                    Box::new(ret.get_undef()),
228                    undef_args,
229                ))
230            }
231            LLVMType::Ref(element_type, type_reference) => {
232                let reference = unsafe { LLVMGetUndef(*type_reference) };
233                LLVMValue::Reference(ReferenceValue::new(reference, element_type.get_undef()))
234            }
235            LLVMType::Float(_) => LLVMValue::Float(FloatValue::new(reference)),
236            LLVMType::Double(_) => LLVMValue::Double(DoubleValue::new(reference)),
237            LLVMType::Inject(i) => LLVMValue::Inject(Box::new(i.get_undef())),
238            t => {
239                println!("{t:?}");
240                todo!()
241            }
242        }
243    }
244    pub fn is_array(&self) -> bool {
245        matches!(self, LLVMType::Array(_, _))
246    }
247    pub fn is_struct(&self) -> bool {
248        matches!(self, LLVMType::Struct(_, _, _))
249    }
250    pub fn is_special_struct(&self, name: impl AsRef<str>) -> bool {
251        let name = name.as_ref();
252        if let LLVMType::Struct(name0, ..) = self {
253            if name0.eq(name) {
254                return true;
255            }
256        }
257        false
258    }
259    pub fn i32() -> Self {
260        let t = unsafe { LLVMInt32Type() };
261        LLVMType::Int32(t)
262    }
263    pub fn is_function(&self) -> bool {
264        matches!(self, LLVMType::Function(_, _, _))
265    }
266    pub fn is_float(&self) -> bool {
267        matches!(self, LLVMType::Float(_))
268    }
269    pub fn is_double(&self) -> bool {
270        matches!(self, LLVMType::Double(_))
271    }
272    pub fn is_pointer(&self) -> bool {
273        matches!(self, LLVMType::Pointer(_, _))
274    }
275    pub fn is_i32(&self) -> bool {
276        matches!(self, LLVMType::Int32(_))
277    }
278    pub fn size(&self) -> usize {
279        match self {
280            LLVMType::Int1(_) => 1,
281            LLVMType::Int8(_) => 1,
282            LLVMType::Int16(_) => 2,
283            LLVMType::Int32(_) => 4,
284            LLVMType::Int64(_) => 8,
285            LLVMType::Float(_) => 4,
286            LLVMType::Double(_) => 8,
287            LLVMType::Pointer(_, _) => 8, // 指针大小为8字节(64位系统)
288            LLVMType::Inject(i) => i.size(), // 指针大小为8字节(64位系统)
289            LLVMType::Ref(_, _) => 8,     // 引用大小为8字节(64位系统)
290            LLVMType::Struct(_, fields, _) => {
291                // 结构体大小为所有字段大小之和
292                fields.iter().map(|f| f.1.size()).sum()
293            }
294            LLVMType::Array(element_type, _) => {
295                // 数组大小为元素大小
296                element_type.size()
297            }
298            LLVMType::Function(_, _, _) => 8, // 函数指针大小为8字节
299            LLVMType::Unit(_) => 0,           // Unit类型大小为0
300            LLVMType::String(_) => 8,         // 字符串指针大小为8字节
301        }
302    }
303}
304
305impl Display for LLVMType {
306    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
307        let str = unsafe {
308            let c = LLVMPrintTypeToString(self.as_llvm_type_ref());
309            let c_str = CStr::from_ptr(c);
310            let str_slice = c_str
311                .to_str()
312                .expect("Failed to convert C string to Rust string slice");
313            str_slice
314        };
315        write!(f, "{str}")
316    }
317}