// pipeline_script/llvm/types.rs

use crate::ast::r#type::Type;
use crate::llvm::global::Global;
use crate::llvm::value::bool::BoolValue;
use crate::llvm::value::double::DoubleValue;
use crate::llvm::value::float::FloatValue;
use crate::llvm::value::fucntion::FunctionValue;
use crate::llvm::value::int::{Int16Value, Int32Value, Int64Value, Int8Value};
use crate::llvm::value::pointer::PointerValue;
use crate::llvm::value::pstruct::StructValue;
use crate::llvm::value::reference::ReferenceValue;
use crate::llvm::value::LLVMValue;
use llvm_sys::core::{
    LLVMDisposeMessage, LLVMGetElementType, LLVMGetIntTypeWidth, LLVMGetTypeKind, LLVMGetUndef,
    LLVMInt32Type, LLVMIsFunctionVarArg, LLVMPrintTypeToString, LLVMTypeOf,
};
use llvm_sys::prelude::{LLVMTypeRef, LLVMValueRef};
use llvm_sys::LLVMTypeKind;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};

22#[derive(Clone, Debug, PartialEq)]
23pub enum LLVMType {
24    Int1(LLVMTypeRef),
25    Int8(LLVMTypeRef),
26    Int16(LLVMTypeRef),
27    Int32(LLVMTypeRef),
28    Int64(LLVMTypeRef),
29    Float(LLVMTypeRef),
30    Double(LLVMTypeRef),
31    Struct(String, Vec<(String, LLVMType)>, LLVMTypeRef),
32    Array(Box<LLVMType>, LLVMTypeRef),
33    Function(Box<LLVMType>, Vec<(String, LLVMType)>, LLVMTypeRef),
34    Pointer(Box<LLVMType>, LLVMTypeRef),
35    Ref(Box<LLVMType>, LLVMTypeRef),
36    String(LLVMTypeRef),
37    Unit(LLVMTypeRef),
38}
39
40impl From<LLVMValueRef> for LLVMType {
41    fn from(value: LLVMValueRef) -> Self {
42        let ty = unsafe { LLVMTypeOf(value) };
43        let type_kind = unsafe { LLVMGetTypeKind(ty) };
44        match type_kind {
45            LLVMTypeKind::LLVMIntegerTypeKind => {
46                let width = unsafe { LLVMGetIntTypeWidth(ty) };
47                let width = width as i8;
48                match width {
49                    8 => LLVMType::Int8(ty),
50                    32 => LLVMType::Int32(ty),
51                    _ => {
52                        todo!()
53                    }
54                }
55            }
56            _ => {
57                todo!()
58            }
59        }
60    }
61}
62
63impl From<LLVMTypeRef> for LLVMType {
64    fn from(ty: LLVMTypeRef) -> Self {
65        let type_kind = unsafe { LLVMGetTypeKind(ty) };
66        match type_kind {
67            LLVMTypeKind::LLVMIntegerTypeKind => {
68                let width = unsafe { LLVMGetIntTypeWidth(ty) };
69                let width = width as i8;
70                match width {
71                    8 => LLVMType::Int8(ty),
72                    32 => LLVMType::Int32(ty),
73                    64 => LLVMType::Int64(ty),
74                    _ => {
75                        todo!()
76                    }
77                }
78            }
79            LLVMTypeKind::LLVMArrayTypeKind => {
80                let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
81                LLVMType::Array(Box::new(element_ty), ty)
82            }
83            LLVMTypeKind::LLVMDoubleTypeKind => LLVMType::Double(ty),
84            t => {
85                println!("{t:?}");
86                todo!()
87            }
88        }
89    }
90}
91
92impl LLVMType {
93    pub fn as_llvm_type_ref(&self) -> LLVMTypeRef {
94        match self {
95            LLVMType::Int1(i) => *i,
96            LLVMType::Int8(i) => *i,
97            LLVMType::Int16(i) => *i,
98            LLVMType::Int32(i) => *i,
99            LLVMType::Int64(i) => *i,
100            LLVMType::Float(i) => *i,
101            LLVMType::Double(i) => *i,
102            LLVMType::Array(_, i) => *i,
103            LLVMType::Function(_, _, i) => *i,
104            LLVMType::Pointer(_, i) => *i,
105            LLVMType::Ref(_, i) => *i,
106            LLVMType::Unit(i) => *i,
107            LLVMType::Struct(_, _, i) => *i,
108            LLVMType::String(i) => *i,
109        }
110    }
111    pub fn get_element_type(&self) -> LLVMType {
112        match self {
113            LLVMType::Array(e_t, _) => *e_t.clone(),
114            LLVMType::Pointer(e, _) => *e.clone(),
115            t => {
116                println!("{t:?}");
117                todo!()
118            }
119        }
120    }
121    pub fn get_type(&self) -> Type {
122        match self {
123            LLVMType::Int1(_) => Type::Bool,
124            LLVMType::Int8(_) => Type::Int8,
125            LLVMType::Int16(_) => Type::Int16,
126            LLVMType::Int32(_) => Type::Int32,
127            LLVMType::Int64(_) => Type::Int64,
128            LLVMType::Float(_) => Type::Float,
129            LLVMType::Double(_) => Type::Double,
130            LLVMType::String(_) => Type::String,
131            LLVMType::Unit(_) => Type::Unit,
132            LLVMType::Pointer(element_type, _) => Type::Pointer(Box::new(element_type.get_type())),
133            LLVMType::Ref(element_type, _) => Type::Ref(Box::new(element_type.get_type())),
134            LLVMType::Array(element_type, _) => Type::Array(Box::new(element_type.get_type())),
135            LLVMType::Struct(name, fields, _) => {
136                let mut field_types = vec![];
137                for (field_name, field_type) in fields {
138                    field_types.push((field_name.clone(), field_type.get_type()));
139                }
140                Type::Struct(Some(name.clone()), field_types)
141            }
142            LLVMType::Function(return_type, args, _) => {
143                let mut arg_types = vec![];
144                for (arg_name, arg_type) in args {
145                    arg_types.push((arg_name.clone(), arg_type.get_type()));
146                }
147                Type::Function(Box::new(return_type.get_type()), arg_types, false)
148            }
149        }
150    }
151    pub fn get_function_param_type(&self, index: usize) -> LLVMType {
152        match self {
153            LLVMType::Function(_, v, f) => {
154                let r = unsafe { LLVMIsFunctionVarArg(*f) };
155                if r == 1 {
156                    return LLVMType::Unit(Global::unit_type().as_llvm_type_ref());
157                }
158                println!("{}", index);
159                v.get(index).cloned().unwrap().1
160            }
161            _ => panic!("Not a function"),
162        }
163    }
164    pub fn get_struct_field_type(&self, index: usize) -> LLVMType {
165        match self {
166            LLVMType::Struct(_, v, _) => v.get(index).cloned().unwrap().1,
167            _ => panic!("Not a struct"),
168        }
169    }
170    pub fn get_function_return_type(&self) -> LLVMType {
171        match self {
172            LLVMType::Function(_, _, f) => LLVMType::from(*f),
173            _ => panic!("Not a function"),
174        }
175    }
176    pub fn get_undef(&self) -> LLVMValue {
177        let reference = unsafe { LLVMGetUndef(self.as_llvm_type_ref()) };
178        match self {
179            LLVMType::Int1(_) => LLVMValue::Bool(BoolValue::new(reference)),
180            LLVMType::Pointer(element, _) => match &**element {
181                LLVMType::Function(ret, args, _) => {
182                    let undef_args = args
183                        .iter()
184                        .map(|arg| (arg.0.clone(), arg.1.get_undef()))
185                        .collect();
186                    LLVMValue::Function(FunctionValue::new(
187                        reference,
188                        "".into(),
189                        Box::new(ret.get_undef()),
190                        undef_args,
191                    ))
192                }
193                _ => LLVMValue::Pointer(PointerValue::new(reference, element.get_undef())),
194            },
195            LLVMType::Int8(_) => LLVMValue::Int8(Int8Value::new(reference)),
196            LLVMType::Int16(_) => LLVMValue::Int16(Int16Value::new(reference)),
197            LLVMType::Int32(_) => LLVMValue::Int32(Int32Value::new(reference)),
198            LLVMType::Int64(_) => LLVMValue::Int64(Int64Value::new(reference)),
199            // LLVMType::Struct(_, _) => LLVMValue::Struct(reference),
200            LLVMType::Unit(_) => LLVMValue::Unit,
201            LLVMType::Struct(name, field, _) => {
202                let mut filed_index = HashMap::new();
203                for (i, f) in field.iter().enumerate() {
204                    filed_index.insert(f.0.clone(), i);
205                }
206                LLVMValue::Struct(StructValue::new(
207                    reference,
208                    name.clone(),
209                    filed_index.clone(),
210                    field.iter().map(|f| f.1.get_undef()).collect(),
211                ))
212            }
213            LLVMType::String(_) => LLVMValue::String(reference),
214            LLVMType::Function(ret, args, _) => {
215                let undef_args = args
216                    .iter()
217                    .map(|arg| (arg.0.clone(), arg.1.get_undef()))
218                    .collect();
219                LLVMValue::Function(FunctionValue::new(
220                    reference,
221                    "".into(),
222                    Box::new(ret.get_undef()),
223                    undef_args,
224                ))
225            }
226            LLVMType::Ref(element_type, type_reference) => {
227                let reference = unsafe { LLVMGetUndef(*type_reference) };
228                LLVMValue::Reference(ReferenceValue::new(reference, element_type.get_undef()))
229            }
230            LLVMType::Float(_) => LLVMValue::Float(FloatValue::new(reference)),
231            LLVMType::Double(_) => LLVMValue::Double(DoubleValue::new(reference)),
232            t => {
233                println!("{t:?}");
234                todo!()
235            }
236        }
237    }
238    pub fn is_array(&self) -> bool {
239        matches!(self, LLVMType::Array(_, _))
240    }
241    pub fn is_struct(&self) -> bool {
242        matches!(self, LLVMType::Struct(_, _, _))
243    }
244    pub fn i32() -> Self {
245        let t = unsafe { LLVMInt32Type() };
246        LLVMType::Int32(t)
247    }
248    pub fn is_function(&self) -> bool {
249        matches!(self, LLVMType::Function(_, _, _))
250    }
251    pub fn is_float(&self) -> bool {
252        matches!(self, LLVMType::Float(_))
253    }
254    pub fn is_double(&self) -> bool {
255        matches!(self, LLVMType::Double(_))
256    }
257    pub fn is_pointer(&self) -> bool {
258        matches!(self, LLVMType::Pointer(_, _))
259    }
260    pub fn is_i32(&self) -> bool {
261        matches!(self, LLVMType::Int32(_))
262    }
263    pub fn size(&self) -> usize {
264        match self {
265            LLVMType::Int1(_) => 1,
266            LLVMType::Int8(_) => 1,
267            LLVMType::Int16(_) => 2,
268            LLVMType::Int32(_) => 4,
269            LLVMType::Int64(_) => 8,
270            LLVMType::Float(_) => 4,
271            LLVMType::Double(_) => 8,
272            LLVMType::Pointer(_, _) => 8, // 指针大小为8字节(64位系统)
273            LLVMType::Ref(_, _) => 8,     // 引用大小为8字节(64位系统)
274            LLVMType::Struct(_, fields, _) => {
275                // 结构体大小为所有字段大小之和
276                fields.iter().map(|f| f.1.size()).sum()
277            }
278            LLVMType::Array(element_type, _) => {
279                // 数组大小为元素大小
280                element_type.size()
281            }
282            LLVMType::Function(_, _, _) => 8, // 函数指针大小为8字节
283            LLVMType::Unit(_) => 0,           // Unit类型大小为0
284            LLVMType::String(_) => 8,         // 字符串指针大小为8字节
285        }
286    }
287}
288
289impl Display for LLVMType {
290    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
291        let str = unsafe {
292            let c = LLVMPrintTypeToString(self.as_llvm_type_ref());
293            let c_str = CStr::from_ptr(c);
294            let str_slice = c_str
295                .to_str()
296                .expect("Failed to convert C string to Rust string slice");
297            str_slice
298        };
299        write!(f, "{str}")
300    }
301}