pipeline_script/llvm/
types.rs

use crate::ast::r#type::Type;
use crate::llvm::global::Global;
use crate::llvm::value::bool::BoolValue;
use crate::llvm::value::double::DoubleValue;
use crate::llvm::value::float::FloatValue;
use crate::llvm::value::fucntion::FunctionValue;
use crate::llvm::value::int::{Int16Value, Int32Value, Int64Value, Int8Value};
use crate::llvm::value::pointer::PointerValue;
use crate::llvm::value::pstruct::StructValue;
use crate::llvm::value::reference::ReferenceValue;
use crate::llvm::value::LLVMValue;
use llvm_sys::core::{
    LLVMDisposeMessage, LLVMGetElementType, LLVMGetIntTypeWidth, LLVMGetTypeKind, LLVMGetUndef,
    LLVMInt32Type, LLVMIsFunctionVarArg, LLVMPrintTypeToString, LLVMTypeOf,
};
use llvm_sys::prelude::{LLVMTypeRef, LLVMValueRef};
use llvm_sys::LLVMTypeKind;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};
21
/// Rust-side description of an LLVM type.
///
/// Every variant except `Inject` carries the raw `LLVMTypeRef` handle. Compound variants
/// also keep their component types, so the value can be mapped back to the front-end
/// [`Type`] (see `get_type`).
#[derive(Clone, Debug, PartialEq)]
pub enum LLVMType {
    /// 1-bit integer; `get_type` maps it to `Type::Bool`.
    Int1(LLVMTypeRef),
    /// 8-bit integer.
    Int8(LLVMTypeRef),
    /// 16-bit integer.
    Int16(LLVMTypeRef),
    /// 32-bit integer.
    Int32(LLVMTypeRef),
    /// 64-bit integer.
    Int64(LLVMTypeRef),
    /// LLVM `float` (32-bit floating point).
    Float(LLVMTypeRef),
    /// LLVM `double` (64-bit floating point).
    Double(LLVMTypeRef),
    /// Struct: its name, its ordered `(field name, field type)` pairs, and the handle.
    Struct(String, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Array: the element type and the handle.
    Array(Box<LLVMType>, LLVMTypeRef),
    /// Function: the return type, ordered `(parameter name, parameter type)` pairs, and the
    /// function-type handle (it is passed to `LLVMIsFunctionVarArg`).
    Function(Box<LLVMType>, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Pointer: the pointee type and the handle.
    Pointer(Box<LLVMType>, LLVMTypeRef),
    /// Reference: the referenced type and the handle; `get_type` maps it to `Type::Ref`.
    Ref(Box<LLVMType>, LLVMTypeRef),
    /// Wrapper without a handle of its own; `as_llvm_type_ref`, `size`, `get_type` and
    /// `get_undef` all delegate to the inner type.
    Inject(Box<LLVMType>),
    /// Front-end string type. `size` reports 8 bytes, so it is presumably lowered to a
    /// pointer — NOTE(review): confirm against the code that constructs it.
    String(LLVMTypeRef),
    /// Unit type; `From<LLVMTypeRef>` maps LLVM's void type kind here.
    Unit(LLVMTypeRef),
}
40
41impl From<LLVMValueRef> for LLVMType {
42    #[allow(clippy::not_unsafe_ptr_arg_deref)]
43    fn from(value: LLVMValueRef) -> Self {
44        let ty = unsafe { LLVMTypeOf(value) };
45        let type_kind = unsafe { LLVMGetTypeKind(ty) };
46        match type_kind {
47            LLVMTypeKind::LLVMIntegerTypeKind => {
48                let width = unsafe { LLVMGetIntTypeWidth(ty) };
49                let width = width as i8;
50                match width {
51                    8 => LLVMType::Int8(ty),
52                    32 => LLVMType::Int32(ty),
53                    _ => {
54                        todo!()
55                    }
56                }
57            }
58            _ => {
59                todo!()
60            }
61        }
62    }
63}
64
65impl From<LLVMTypeRef> for LLVMType {
66    #[allow(clippy::not_unsafe_ptr_arg_deref)]
67    fn from(ty: LLVMTypeRef) -> Self {
68        let type_kind = unsafe { LLVMGetTypeKind(ty) };
69        match type_kind {
70            LLVMTypeKind::LLVMIntegerTypeKind => {
71                let width = unsafe { LLVMGetIntTypeWidth(ty) };
72                let width = width as i8;
73                match width {
74                    1 => LLVMType::Int1(ty),
75                    8 => LLVMType::Int8(ty),
76                    16 => LLVMType::Int16(ty),
77                    32 => LLVMType::Int32(ty),
78                    64 => LLVMType::Int64(ty),
79                    _ => {
80                        todo!()
81                    }
82                }
83            }
84            LLVMTypeKind::LLVMArrayTypeKind => {
85                let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
86                LLVMType::Array(Box::new(element_ty), ty)
87            }
88            LLVMTypeKind::LLVMPointerTypeKind => {
89                let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
90                LLVMType::Pointer(Box::new(element_ty), ty)
91            }
92            LLVMTypeKind::LLVMVoidTypeKind => LLVMType::Unit(ty),
93            LLVMTypeKind::LLVMFloatTypeKind => LLVMType::Float(ty),
94            LLVMTypeKind::LLVMDoubleTypeKind => LLVMType::Double(ty),
95            t => {
96                println!("{t:?}");
97                todo!()
98            }
99        }
100    }
101}
102
103impl LLVMType {
104    pub fn as_llvm_type_ref(&self) -> LLVMTypeRef {
105        match self {
106            LLVMType::Int1(i) => *i,
107            LLVMType::Int8(i) => *i,
108            LLVMType::Int16(i) => *i,
109            LLVMType::Int32(i) => *i,
110            LLVMType::Int64(i) => *i,
111            LLVMType::Float(i) => *i,
112            LLVMType::Double(i) => *i,
113            LLVMType::Array(_, i) => *i,
114            LLVMType::Function(_, _, i) => *i,
115            LLVMType::Pointer(_, i) => *i,
116            LLVMType::Ref(_, i) => *i,
117            LLVMType::Unit(i) => *i,
118            LLVMType::Struct(_, _, i) => *i,
119            LLVMType::String(i) => *i,
120            LLVMType::Inject(i) => i.as_llvm_type_ref(),
121        }
122    }
123    pub fn get_element_type(&self) -> LLVMType {
124        match self {
125            LLVMType::Array(e_t, _) => *e_t.clone(),
126            LLVMType::Pointer(e, _) => *e.clone(),
127            t => {
128                println!("{t:?}");
129                todo!()
130            }
131        }
132    }
133    pub fn get_type(&self) -> Type {
134        match self {
135            LLVMType::Int1(_) => Type::Bool,
136            LLVMType::Int8(_) => Type::Int8,
137            LLVMType::Int16(_) => Type::Int16,
138            LLVMType::Int32(_) => Type::Int32,
139            LLVMType::Int64(_) => Type::Int64,
140            LLVMType::Float(_) => Type::Float,
141            LLVMType::Double(_) => Type::Double,
142            LLVMType::String(_) => Type::String,
143            LLVMType::Unit(_) => Type::Unit,
144            LLVMType::Pointer(element_type, _) => Type::Pointer(Box::new(element_type.get_type())),
145            LLVMType::Ref(element_type, _) => Type::Ref(Box::new(element_type.get_type())),
146            LLVMType::Array(element_type, _) => Type::Array(Box::new(element_type.get_type())),
147            LLVMType::Struct(name, fields, _) => {
148                let mut field_types = vec![];
149                for (field_name, field_type) in fields {
150                    field_types.push((field_name.clone(), field_type.get_type()));
151                }
152                Type::Struct(Some(name.clone()), field_types)
153            }
154            LLVMType::Function(return_type, args, _) => {
155                let mut arg_types = vec![];
156                for (arg_name, arg_type) in args {
157                    arg_types.push((arg_name.clone(), arg_type.get_type()));
158                }
159                Type::Function(Box::new(return_type.get_type()), arg_types, false)
160            }
161            LLVMType::Inject(i) => Type::Inject(Box::new(i.get_type())),
162        }
163    }
164    pub fn get_function_param_type(&self, index: usize) -> LLVMType {
165        match self {
166            LLVMType::Function(_, v, f) => {
167                let r = unsafe { LLVMIsFunctionVarArg(*f) };
168                if r == 1 {
169                    return LLVMType::Unit(Global::unit_type().as_llvm_type_ref());
170                }
171                println!("{}", index);
172                v.get(index).cloned().unwrap().1
173            }
174            _ => panic!("Not a function"),
175        }
176    }
177    pub fn get_struct_field_type(&self, index: usize) -> LLVMType {
178        match self {
179            LLVMType::Struct(_, v, _) => v.get(index).cloned().unwrap().1,
180            _ => panic!("Not a struct"),
181        }
182    }
183    pub fn get_function_return_type(&self) -> LLVMType {
184        match self {
185            LLVMType::Function(_, _, f) => LLVMType::from(*f),
186            _ => panic!("Not a function"),
187        }
188    }
189    pub fn get_undef(&self) -> LLVMValue {
190        let reference = unsafe { LLVMGetUndef(self.as_llvm_type_ref()) };
191        match self {
192            LLVMType::Int1(_) => LLVMValue::Bool(BoolValue::new(reference)),
193            LLVMType::Pointer(element, _) => match &**element {
194                LLVMType::Function(ret, args, _) => {
195                    let undef_args = args
196                        .iter()
197                        .map(|arg| (arg.0.clone(), arg.1.get_undef()))
198                        .collect();
199                    LLVMValue::Function(FunctionValue::new(
200                        reference,
201                        "".into(),
202                        Box::new(ret.get_undef()),
203                        undef_args,
204                    ))
205                }
206                _ => LLVMValue::Pointer(PointerValue::new(reference, element.get_undef())),
207            },
208            LLVMType::Int8(_) => LLVMValue::Int8(Int8Value::new(reference)),
209            LLVMType::Int16(_) => LLVMValue::Int16(Int16Value::new(reference)),
210            LLVMType::Int32(_) => LLVMValue::Int32(Int32Value::new(reference)),
211            LLVMType::Int64(_) => LLVMValue::Int64(Int64Value::new(reference)),
212            // LLVMType::Struct(_, _) => LLVMValue::Struct(reference),
213            LLVMType::Unit(_) => LLVMValue::Unit,
214            LLVMType::Struct(name, field, _) => {
215                let mut filed_index = HashMap::new();
216                for (i, f) in field.iter().enumerate() {
217                    filed_index.insert(f.0.clone(), i);
218                }
219                LLVMValue::Struct(StructValue::new(
220                    reference,
221                    name.clone(),
222                    filed_index.clone(),
223                    field.iter().map(|f| f.1.get_undef()).collect(),
224                ))
225            }
226            LLVMType::String(_) => LLVMValue::String(reference),
227            LLVMType::Function(ret, args, _) => {
228                let undef_args = args
229                    .iter()
230                    .map(|arg| (arg.0.clone(), arg.1.get_undef()))
231                    .collect();
232                LLVMValue::Function(FunctionValue::new(
233                    reference,
234                    "".into(),
235                    Box::new(ret.get_undef()),
236                    undef_args,
237                ))
238            }
239            LLVMType::Ref(element_type, type_reference) => {
240                let reference = unsafe { LLVMGetUndef(*type_reference) };
241                LLVMValue::Reference(ReferenceValue::new(reference, element_type.get_undef()))
242            }
243            LLVMType::Float(_) => LLVMValue::Float(FloatValue::new(reference)),
244            LLVMType::Double(_) => LLVMValue::Double(DoubleValue::new(reference)),
245            LLVMType::Inject(i) => LLVMValue::Inject(Box::new(i.get_undef())),
246            t => {
247                println!("{t:?}");
248                todo!()
249            }
250        }
251    }
252    pub fn is_array(&self) -> bool {
253        matches!(self, LLVMType::Array(_, _))
254    }
255    pub fn is_struct(&self) -> bool {
256        matches!(self, LLVMType::Struct(_, _, _))
257    }
258    pub fn is_special_struct(&self, name: impl AsRef<str>) -> bool {
259        let name = name.as_ref();
260        if let LLVMType::Struct(name0, ..) = self {
261            if name0.eq(name) {
262                return true;
263            }
264        }
265        false
266    }
267    pub fn i32() -> Self {
268        let t = unsafe { LLVMInt32Type() };
269        LLVMType::Int32(t)
270    }
271    pub fn is_function(&self) -> bool {
272        matches!(self, LLVMType::Function(_, _, _))
273    }
274    pub fn is_float(&self) -> bool {
275        matches!(self, LLVMType::Float(_))
276    }
277    pub fn is_double(&self) -> bool {
278        matches!(self, LLVMType::Double(_))
279    }
280    pub fn is_pointer(&self) -> bool {
281        matches!(self, LLVMType::Pointer(_, _))
282    }
283    pub fn is_i32(&self) -> bool {
284        matches!(self, LLVMType::Int32(_))
285    }
286    pub fn size(&self) -> usize {
287        match self {
288            LLVMType::Int1(_) => 1,
289            LLVMType::Int8(_) => 1,
290            LLVMType::Int16(_) => 2,
291            LLVMType::Int32(_) => 4,
292            LLVMType::Int64(_) => 8,
293            LLVMType::Float(_) => 4,
294            LLVMType::Double(_) => 8,
295            LLVMType::Pointer(_, _) => 8, // 指针大小为8字节(64位系统)
296            LLVMType::Inject(i) => i.size(), // 指针大小为8字节(64位系统)
297            LLVMType::Ref(_, _) => 8,     // 引用大小为8字节(64位系统)
298            LLVMType::Struct(_, fields, _) => {
299                // 结构体大小为所有字段大小之和
300                fields.iter().map(|f| f.1.size()).sum()
301            }
302            LLVMType::Array(element_type, _) => {
303                // 数组大小为元素大小
304                element_type.size()
305            }
306            LLVMType::Function(_, _, _) => 8, // 函数指针大小为8字节
307            LLVMType::Unit(_) => 0,           // Unit类型大小为0
308            LLVMType::String(_) => 8,         // 字符串指针大小为8字节
309        }
310    }
311}
312
313impl Display for LLVMType {
314    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
315        let str = unsafe {
316            let c = LLVMPrintTypeToString(self.as_llvm_type_ref());
317            let c_str = CStr::from_ptr(c);
318            let str_slice = c_str
319                .to_str()
320                .expect("Failed to convert C string to Rust string slice");
321            str_slice
322        };
323        write!(f, "{str}")
324    }
325}