use crate::ast::r#type::Type;
use crate::llvm::global::Global;
use crate::llvm::value::bool::BoolValue;
use crate::llvm::value::double::DoubleValue;
use crate::llvm::value::float::FloatValue;
use crate::llvm::value::fucntion::FunctionValue;
use crate::llvm::value::int::{Int16Value, Int32Value, Int64Value, Int8Value};
use crate::llvm::value::pointer::PointerValue;
use crate::llvm::value::pstruct::StructValue;
use crate::llvm::value::reference::ReferenceValue;
use crate::llvm::value::LLVMValue;
use llvm_sys::core::{
    LLVMDisposeMessage, LLVMGetElementType, LLVMGetIntTypeWidth, LLVMGetTypeKind, LLVMGetUndef,
    LLVMInt32Type, LLVMIsFunctionVarArg, LLVMPrintTypeToString, LLVMTypeOf,
};
use llvm_sys::prelude::{LLVMTypeRef, LLVMValueRef};
use llvm_sys::LLVMTypeKind;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};
21
/// A compiler-side description of an LLVM type.
///
/// Most variants carry the raw LLVM type handle (`LLVMTypeRef`) they describe;
/// composite variants additionally keep the component types so they can be
/// inspected without going back through the LLVM C API.
#[derive(Clone, Debug, PartialEq)]
pub enum LLVMType {
    /// 1-bit LLVM integer; mapped to `Type::Bool` by `get_type`.
    Int1(LLVMTypeRef),
    /// 8-bit LLVM integer.
    Int8(LLVMTypeRef),
    /// 16-bit LLVM integer.
    Int16(LLVMTypeRef),
    /// 32-bit LLVM integer.
    Int32(LLVMTypeRef),
    /// 64-bit LLVM integer.
    Int64(LLVMTypeRef),
    /// LLVM `float`.
    Float(LLVMTypeRef),
    /// LLVM `double`.
    Double(LLVMTypeRef),
    /// Named struct: struct name, ordered `(field name, field type)` pairs, type handle.
    Struct(String, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Array: element type, type handle.
    Array(Box<LLVMType>, LLVMTypeRef),
    /// Function: return type, ordered `(parameter name, parameter type)` pairs,
    /// and the LLVM function-type handle.
    Function(Box<LLVMType>, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Pointer: pointee type, type handle.
    Pointer(Box<LLVMType>, LLVMTypeRef),
    /// Reference: referenced type, type handle (mapped to `Type::Ref`).
    Ref(Box<LLVMType>, LLVMTypeRef),
    /// Wrapper with no handle of its own; `as_llvm_type_ref` forwards to the
    /// wrapped type. NOTE(review): language-level meaning of "inject" not visible here.
    Inject(Box<LLVMType>),
    /// Language string type; NOTE(review): underlying LLVM representation is not
    /// visible here (`size` treats it as 8 bytes, presumably a pointer).
    String(LLVMTypeRef),
    /// Unit / `void` (LLVM `void` type kind converts to this variant).
    Unit(LLVMTypeRef),
}
40
41impl From<LLVMValueRef> for LLVMType {
42 #[allow(clippy::not_unsafe_ptr_arg_deref)]
43 fn from(value: LLVMValueRef) -> Self {
44 let ty = unsafe { LLVMTypeOf(value) };
45 let type_kind = unsafe { LLVMGetTypeKind(ty) };
46 match type_kind {
47 LLVMTypeKind::LLVMIntegerTypeKind => {
48 let width = unsafe { LLVMGetIntTypeWidth(ty) };
49 let width = width as i8;
50 match width {
51 8 => LLVMType::Int8(ty),
52 32 => LLVMType::Int32(ty),
53 _ => {
54 todo!()
55 }
56 }
57 }
58 _ => {
59 todo!()
60 }
61 }
62 }
63}
64
65impl From<LLVMTypeRef> for LLVMType {
66 #[allow(clippy::not_unsafe_ptr_arg_deref)]
67 fn from(ty: LLVMTypeRef) -> Self {
68 let type_kind = unsafe { LLVMGetTypeKind(ty) };
69 match type_kind {
70 LLVMTypeKind::LLVMIntegerTypeKind => {
71 let width = unsafe { LLVMGetIntTypeWidth(ty) };
72 let width = width as i8;
73 match width {
74 1 => LLVMType::Int1(ty),
75 8 => LLVMType::Int8(ty),
76 16 => LLVMType::Int16(ty),
77 32 => LLVMType::Int32(ty),
78 64 => LLVMType::Int64(ty),
79 _ => {
80 todo!()
81 }
82 }
83 }
84 LLVMTypeKind::LLVMArrayTypeKind => {
85 let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
86 LLVMType::Array(Box::new(element_ty), ty)
87 }
88 LLVMTypeKind::LLVMPointerTypeKind => {
89 let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
90 LLVMType::Pointer(Box::new(element_ty), ty)
91 }
92 LLVMTypeKind::LLVMVoidTypeKind => LLVMType::Unit(ty),
93 LLVMTypeKind::LLVMFloatTypeKind => LLVMType::Float(ty),
94 LLVMTypeKind::LLVMDoubleTypeKind => LLVMType::Double(ty),
95 t => {
96 println!("{t:?}");
97 todo!()
98 }
99 }
100 }
101}
102
103impl LLVMType {
104 pub fn as_llvm_type_ref(&self) -> LLVMTypeRef {
105 match self {
106 LLVMType::Int1(i) => *i,
107 LLVMType::Int8(i) => *i,
108 LLVMType::Int16(i) => *i,
109 LLVMType::Int32(i) => *i,
110 LLVMType::Int64(i) => *i,
111 LLVMType::Float(i) => *i,
112 LLVMType::Double(i) => *i,
113 LLVMType::Array(_, i) => *i,
114 LLVMType::Function(_, _, i) => *i,
115 LLVMType::Pointer(_, i) => *i,
116 LLVMType::Ref(_, i) => *i,
117 LLVMType::Unit(i) => *i,
118 LLVMType::Struct(_, _, i) => *i,
119 LLVMType::String(i) => *i,
120 LLVMType::Inject(i) => i.as_llvm_type_ref(),
121 }
122 }
123 pub fn get_element_type(&self) -> LLVMType {
124 match self {
125 LLVMType::Array(e_t, _) => *e_t.clone(),
126 LLVMType::Pointer(e, _) => *e.clone(),
127 t => {
128 println!("{t:?}");
129 todo!()
130 }
131 }
132 }
133 pub fn get_type(&self) -> Type {
134 match self {
135 LLVMType::Int1(_) => Type::Bool,
136 LLVMType::Int8(_) => Type::Int8,
137 LLVMType::Int16(_) => Type::Int16,
138 LLVMType::Int32(_) => Type::Int32,
139 LLVMType::Int64(_) => Type::Int64,
140 LLVMType::Float(_) => Type::Float,
141 LLVMType::Double(_) => Type::Double,
142 LLVMType::String(_) => Type::String,
143 LLVMType::Unit(_) => Type::Unit,
144 LLVMType::Pointer(element_type, _) => Type::Pointer(Box::new(element_type.get_type())),
145 LLVMType::Ref(element_type, _) => Type::Ref(Box::new(element_type.get_type())),
146 LLVMType::Array(element_type, _) => Type::Array(Box::new(element_type.get_type())),
147 LLVMType::Struct(name, fields, _) => {
148 let mut field_types = vec![];
149 for (field_name, field_type) in fields {
150 field_types.push((field_name.clone(), field_type.get_type()));
151 }
152 Type::Struct(Some(name.clone()), field_types)
153 }
154 LLVMType::Function(return_type, args, _) => {
155 let mut arg_types = vec![];
156 for (arg_name, arg_type) in args {
157 arg_types.push((arg_name.clone(), arg_type.get_type()));
158 }
159 Type::Function(Box::new(return_type.get_type()), arg_types, false)
160 }
161 LLVMType::Inject(i) => Type::Inject(Box::new(i.get_type())),
162 }
163 }
164 pub fn get_function_param_type(&self, index: usize) -> LLVMType {
165 match self {
166 LLVMType::Function(_, v, f) => {
167 let r = unsafe { LLVMIsFunctionVarArg(*f) };
168 if r == 1 {
169 return LLVMType::Unit(Global::unit_type().as_llvm_type_ref());
170 }
171 println!("{}", index);
172 v.get(index).cloned().unwrap().1
173 }
174 _ => panic!("Not a function"),
175 }
176 }
177 pub fn get_struct_field_type(&self, index: usize) -> LLVMType {
178 match self {
179 LLVMType::Struct(_, v, _) => v.get(index).cloned().unwrap().1,
180 _ => panic!("Not a struct"),
181 }
182 }
183 pub fn get_function_return_type(&self) -> LLVMType {
184 match self {
185 LLVMType::Function(_, _, f) => LLVMType::from(*f),
186 _ => panic!("Not a function"),
187 }
188 }
189 pub fn get_undef(&self) -> LLVMValue {
190 let reference = unsafe { LLVMGetUndef(self.as_llvm_type_ref()) };
191 match self {
192 LLVMType::Int1(_) => LLVMValue::Bool(BoolValue::new(reference)),
193 LLVMType::Pointer(element, _) => match &**element {
194 LLVMType::Function(ret, args, _) => {
195 let undef_args = args
196 .iter()
197 .map(|arg| (arg.0.clone(), arg.1.get_undef()))
198 .collect();
199 LLVMValue::Function(FunctionValue::new(
200 reference,
201 "".into(),
202 Box::new(ret.get_undef()),
203 undef_args,
204 ))
205 }
206 _ => LLVMValue::Pointer(PointerValue::new(reference, element.get_undef())),
207 },
208 LLVMType::Int8(_) => LLVMValue::Int8(Int8Value::new(reference)),
209 LLVMType::Int16(_) => LLVMValue::Int16(Int16Value::new(reference)),
210 LLVMType::Int32(_) => LLVMValue::Int32(Int32Value::new(reference)),
211 LLVMType::Int64(_) => LLVMValue::Int64(Int64Value::new(reference)),
212 LLVMType::Unit(_) => LLVMValue::Unit,
214 LLVMType::Struct(name, field, _) => {
215 let mut filed_index = HashMap::new();
216 for (i, f) in field.iter().enumerate() {
217 filed_index.insert(f.0.clone(), i);
218 }
219 LLVMValue::Struct(StructValue::new(
220 reference,
221 name.clone(),
222 filed_index.clone(),
223 field.iter().map(|f| f.1.get_undef()).collect(),
224 ))
225 }
226 LLVMType::String(_) => LLVMValue::String(reference),
227 LLVMType::Function(ret, args, _) => {
228 let undef_args = args
229 .iter()
230 .map(|arg| (arg.0.clone(), arg.1.get_undef()))
231 .collect();
232 LLVMValue::Function(FunctionValue::new(
233 reference,
234 "".into(),
235 Box::new(ret.get_undef()),
236 undef_args,
237 ))
238 }
239 LLVMType::Ref(element_type, type_reference) => {
240 let reference = unsafe { LLVMGetUndef(*type_reference) };
241 LLVMValue::Reference(ReferenceValue::new(reference, element_type.get_undef()))
242 }
243 LLVMType::Float(_) => LLVMValue::Float(FloatValue::new(reference)),
244 LLVMType::Double(_) => LLVMValue::Double(DoubleValue::new(reference)),
245 LLVMType::Inject(i) => LLVMValue::Inject(Box::new(i.get_undef())),
246 t => {
247 println!("{t:?}");
248 todo!()
249 }
250 }
251 }
252 pub fn is_array(&self) -> bool {
253 matches!(self, LLVMType::Array(_, _))
254 }
255 pub fn is_struct(&self) -> bool {
256 matches!(self, LLVMType::Struct(_, _, _))
257 }
258 pub fn is_special_struct(&self, name: impl AsRef<str>) -> bool {
259 let name = name.as_ref();
260 if let LLVMType::Struct(name0, ..) = self {
261 if name0.eq(name) {
262 return true;
263 }
264 }
265 false
266 }
267 pub fn i32() -> Self {
268 let t = unsafe { LLVMInt32Type() };
269 LLVMType::Int32(t)
270 }
271 pub fn is_function(&self) -> bool {
272 matches!(self, LLVMType::Function(_, _, _))
273 }
274 pub fn is_float(&self) -> bool {
275 matches!(self, LLVMType::Float(_))
276 }
277 pub fn is_double(&self) -> bool {
278 matches!(self, LLVMType::Double(_))
279 }
280 pub fn is_pointer(&self) -> bool {
281 matches!(self, LLVMType::Pointer(_, _))
282 }
283 pub fn is_i32(&self) -> bool {
284 matches!(self, LLVMType::Int32(_))
285 }
286 pub fn size(&self) -> usize {
287 match self {
288 LLVMType::Int1(_) => 1,
289 LLVMType::Int8(_) => 1,
290 LLVMType::Int16(_) => 2,
291 LLVMType::Int32(_) => 4,
292 LLVMType::Int64(_) => 8,
293 LLVMType::Float(_) => 4,
294 LLVMType::Double(_) => 8,
295 LLVMType::Pointer(_, _) => 8, LLVMType::Inject(i) => i.size(), LLVMType::Ref(_, _) => 8, LLVMType::Struct(_, fields, _) => {
299 fields.iter().map(|f| f.1.size()).sum()
301 }
302 LLVMType::Array(element_type, _) => {
303 element_type.size()
305 }
306 LLVMType::Function(_, _, _) => 8, LLVMType::Unit(_) => 0, LLVMType::String(_) => 8, }
310 }
311}
312
313impl Display for LLVMType {
314 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
315 let str = unsafe {
316 let c = LLVMPrintTypeToString(self.as_llvm_type_ref());
317 let c_str = CStr::from_ptr(c);
318 let str_slice = c_str
319 .to_str()
320 .expect("Failed to convert C string to Rust string slice");
321 str_slice
322 };
323 write!(f, "{str}")
324 }
325}