use crate::ast::r#type::Type;
use crate::llvm::global::Global;
use crate::llvm::value::bool::BoolValue;
use crate::llvm::value::double::DoubleValue;
use crate::llvm::value::float::FloatValue;
use crate::llvm::value::fucntion::FunctionValue;
use crate::llvm::value::int::{Int16Value, Int32Value, Int64Value, Int8Value};
use crate::llvm::value::pointer::PointerValue;
use crate::llvm::value::pstruct::StructValue;
use crate::llvm::value::reference::ReferenceValue;
use crate::llvm::value::LLVMValue;
use llvm_sys::core::{
    LLVMDisposeMessage, LLVMGetElementType, LLVMGetIntTypeWidth, LLVMGetTypeKind, LLVMGetUndef,
    LLVMInt32Type, LLVMIsFunctionVarArg, LLVMPrintTypeToString, LLVMTypeOf,
};
use llvm_sys::prelude::{LLVMTypeRef, LLVMValueRef};
use llvm_sys::LLVMTypeKind;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};
21
/// Typed wrapper around an LLVM type handle.
///
/// Each variant keeps the raw `LLVMTypeRef` used for code generation (except `Inject`).
/// Composite variants also keep the nested `LLVMType`s, so the front-end `Type` can be
/// rebuilt with `get_type`.
///
/// Equality is derived. Two values are equal only when their variants, nested types and
/// raw `LLVMTypeRef` pointers all match.
#[derive(Clone, Debug, PartialEq)]
pub enum LLVMType {
    /// 1-bit integer; `get_type` reports it as `Type::Bool`.
    Int1(LLVMTypeRef),
    /// 8-bit integer.
    Int8(LLVMTypeRef),
    /// 16-bit integer.
    Int16(LLVMTypeRef),
    /// 32-bit integer.
    Int32(LLVMTypeRef),
    /// 64-bit integer.
    Int64(LLVMTypeRef),
    /// 32-bit floating point (`size` reports 4 bytes).
    Float(LLVMTypeRef),
    /// 64-bit floating point (`size` reports 8 bytes).
    Double(LLVMTypeRef),
    /// Named struct: (struct name, ordered (field name, field type) pairs, LLVM type).
    Struct(String, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Array: (element type, LLVM array type). The element count is not stored here.
    Array(Box<LLVMType>, LLVMTypeRef),
    /// Function: (return type, ordered (parameter name, parameter type) pairs, LLVM type).
    /// The handle is passed to `LLVMIsFunctionVarArg`, so it is presumably the LLVM
    /// function type itself.
    Function(Box<LLVMType>, Vec<(String, LLVMType)>, LLVMTypeRef),
    /// Pointer: (pointee type, LLVM pointer type).
    Pointer(Box<LLVMType>, LLVMTypeRef),
    /// Reference: (referenced type, LLVM type); maps to `Type::Ref`.
    Ref(Box<LLVMType>, LLVMTypeRef),
    /// Wrapper with no LLVM handle of its own. It defers to the inner type for its handle,
    /// size and undef value, and maps to `Type::Inject`.
    /// NOTE(review): the language-level meaning of "inject" is not visible here.
    Inject(Box<LLVMType>),
    /// String type; maps to `Type::String`.
    /// NOTE(review): its LLVM representation is not visible here.
    String(LLVMTypeRef),
    /// Unit type; maps to `Type::Unit` and has size 0.
    Unit(LLVMTypeRef),
}
40
41impl From<LLVMValueRef> for LLVMType {
42 #[allow(clippy::not_unsafe_ptr_arg_deref)]
43 fn from(value: LLVMValueRef) -> Self {
44 let ty = unsafe { LLVMTypeOf(value) };
45 let type_kind = unsafe { LLVMGetTypeKind(ty) };
46 match type_kind {
47 LLVMTypeKind::LLVMIntegerTypeKind => {
48 let width = unsafe { LLVMGetIntTypeWidth(ty) };
49 let width = width as i8;
50 match width {
51 8 => LLVMType::Int8(ty),
52 32 => LLVMType::Int32(ty),
53 _ => {
54 todo!()
55 }
56 }
57 }
58 _ => {
59 todo!()
60 }
61 }
62 }
63}
64
65impl From<LLVMTypeRef> for LLVMType {
66 #[allow(clippy::not_unsafe_ptr_arg_deref)]
67 fn from(ty: LLVMTypeRef) -> Self {
68 let type_kind = unsafe { LLVMGetTypeKind(ty) };
69 match type_kind {
70 LLVMTypeKind::LLVMIntegerTypeKind => {
71 let width = unsafe { LLVMGetIntTypeWidth(ty) };
72 let width = width as i8;
73 match width {
74 8 => LLVMType::Int8(ty),
75 32 => LLVMType::Int32(ty),
76 64 => LLVMType::Int64(ty),
77 _ => {
78 todo!()
79 }
80 }
81 }
82 LLVMTypeKind::LLVMArrayTypeKind => {
83 let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
84 LLVMType::Array(Box::new(element_ty), ty)
85 }
86 LLVMTypeKind::LLVMDoubleTypeKind => LLVMType::Double(ty),
87 t => {
88 println!("{t:?}");
89 todo!()
90 }
91 }
92 }
93}
94
95impl LLVMType {
96 pub fn as_llvm_type_ref(&self) -> LLVMTypeRef {
97 match self {
98 LLVMType::Int1(i) => *i,
99 LLVMType::Int8(i) => *i,
100 LLVMType::Int16(i) => *i,
101 LLVMType::Int32(i) => *i,
102 LLVMType::Int64(i) => *i,
103 LLVMType::Float(i) => *i,
104 LLVMType::Double(i) => *i,
105 LLVMType::Array(_, i) => *i,
106 LLVMType::Function(_, _, i) => *i,
107 LLVMType::Pointer(_, i) => *i,
108 LLVMType::Ref(_, i) => *i,
109 LLVMType::Unit(i) => *i,
110 LLVMType::Struct(_, _, i) => *i,
111 LLVMType::String(i) => *i,
112 LLVMType::Inject(i) => i.as_llvm_type_ref(),
113 }
114 }
115 pub fn get_element_type(&self) -> LLVMType {
116 match self {
117 LLVMType::Array(e_t, _) => *e_t.clone(),
118 LLVMType::Pointer(e, _) => *e.clone(),
119 t => {
120 println!("{t:?}");
121 todo!()
122 }
123 }
124 }
125 pub fn get_type(&self) -> Type {
126 match self {
127 LLVMType::Int1(_) => Type::Bool,
128 LLVMType::Int8(_) => Type::Int8,
129 LLVMType::Int16(_) => Type::Int16,
130 LLVMType::Int32(_) => Type::Int32,
131 LLVMType::Int64(_) => Type::Int64,
132 LLVMType::Float(_) => Type::Float,
133 LLVMType::Double(_) => Type::Double,
134 LLVMType::String(_) => Type::String,
135 LLVMType::Unit(_) => Type::Unit,
136 LLVMType::Pointer(element_type, _) => Type::Pointer(Box::new(element_type.get_type())),
137 LLVMType::Ref(element_type, _) => Type::Ref(Box::new(element_type.get_type())),
138 LLVMType::Array(element_type, _) => Type::Array(Box::new(element_type.get_type())),
139 LLVMType::Struct(name, fields, _) => {
140 let mut field_types = vec![];
141 for (field_name, field_type) in fields {
142 field_types.push((field_name.clone(), field_type.get_type()));
143 }
144 Type::Struct(Some(name.clone()), field_types)
145 }
146 LLVMType::Function(return_type, args, _) => {
147 let mut arg_types = vec![];
148 for (arg_name, arg_type) in args {
149 arg_types.push((arg_name.clone(), arg_type.get_type()));
150 }
151 Type::Function(Box::new(return_type.get_type()), arg_types, false)
152 }
153 LLVMType::Inject(i) => Type::Inject(Box::new(i.get_type())),
154 }
155 }
156 pub fn get_function_param_type(&self, index: usize) -> LLVMType {
157 match self {
158 LLVMType::Function(_, v, f) => {
159 let r = unsafe { LLVMIsFunctionVarArg(*f) };
160 if r == 1 {
161 return LLVMType::Unit(Global::unit_type().as_llvm_type_ref());
162 }
163 println!("{}", index);
164 v.get(index).cloned().unwrap().1
165 }
166 _ => panic!("Not a function"),
167 }
168 }
169 pub fn get_struct_field_type(&self, index: usize) -> LLVMType {
170 match self {
171 LLVMType::Struct(_, v, _) => v.get(index).cloned().unwrap().1,
172 _ => panic!("Not a struct"),
173 }
174 }
175 pub fn get_function_return_type(&self) -> LLVMType {
176 match self {
177 LLVMType::Function(_, _, f) => LLVMType::from(*f),
178 _ => panic!("Not a function"),
179 }
180 }
181 pub fn get_undef(&self) -> LLVMValue {
182 let reference = unsafe { LLVMGetUndef(self.as_llvm_type_ref()) };
183 match self {
184 LLVMType::Int1(_) => LLVMValue::Bool(BoolValue::new(reference)),
185 LLVMType::Pointer(element, _) => match &**element {
186 LLVMType::Function(ret, args, _) => {
187 let undef_args = args
188 .iter()
189 .map(|arg| (arg.0.clone(), arg.1.get_undef()))
190 .collect();
191 LLVMValue::Function(FunctionValue::new(
192 reference,
193 "".into(),
194 Box::new(ret.get_undef()),
195 undef_args,
196 ))
197 }
198 _ => LLVMValue::Pointer(PointerValue::new(reference, element.get_undef())),
199 },
200 LLVMType::Int8(_) => LLVMValue::Int8(Int8Value::new(reference)),
201 LLVMType::Int16(_) => LLVMValue::Int16(Int16Value::new(reference)),
202 LLVMType::Int32(_) => LLVMValue::Int32(Int32Value::new(reference)),
203 LLVMType::Int64(_) => LLVMValue::Int64(Int64Value::new(reference)),
204 LLVMType::Unit(_) => LLVMValue::Unit,
206 LLVMType::Struct(name, field, _) => {
207 let mut filed_index = HashMap::new();
208 for (i, f) in field.iter().enumerate() {
209 filed_index.insert(f.0.clone(), i);
210 }
211 LLVMValue::Struct(StructValue::new(
212 reference,
213 name.clone(),
214 filed_index.clone(),
215 field.iter().map(|f| f.1.get_undef()).collect(),
216 ))
217 }
218 LLVMType::String(_) => LLVMValue::String(reference),
219 LLVMType::Function(ret, args, _) => {
220 let undef_args = args
221 .iter()
222 .map(|arg| (arg.0.clone(), arg.1.get_undef()))
223 .collect();
224 LLVMValue::Function(FunctionValue::new(
225 reference,
226 "".into(),
227 Box::new(ret.get_undef()),
228 undef_args,
229 ))
230 }
231 LLVMType::Ref(element_type, type_reference) => {
232 let reference = unsafe { LLVMGetUndef(*type_reference) };
233 LLVMValue::Reference(ReferenceValue::new(reference, element_type.get_undef()))
234 }
235 LLVMType::Float(_) => LLVMValue::Float(FloatValue::new(reference)),
236 LLVMType::Double(_) => LLVMValue::Double(DoubleValue::new(reference)),
237 LLVMType::Inject(i) => LLVMValue::Inject(Box::new(i.get_undef())),
238 t => {
239 println!("{t:?}");
240 todo!()
241 }
242 }
243 }
244 pub fn is_array(&self) -> bool {
245 matches!(self, LLVMType::Array(_, _))
246 }
247 pub fn is_struct(&self) -> bool {
248 matches!(self, LLVMType::Struct(_, _, _))
249 }
250 pub fn is_special_struct(&self, name: impl AsRef<str>) -> bool {
251 let name = name.as_ref();
252 if let LLVMType::Struct(name0, ..) = self {
253 if name0.eq(name) {
254 return true;
255 }
256 }
257 false
258 }
259 pub fn i32() -> Self {
260 let t = unsafe { LLVMInt32Type() };
261 LLVMType::Int32(t)
262 }
263 pub fn is_function(&self) -> bool {
264 matches!(self, LLVMType::Function(_, _, _))
265 }
266 pub fn is_float(&self) -> bool {
267 matches!(self, LLVMType::Float(_))
268 }
269 pub fn is_double(&self) -> bool {
270 matches!(self, LLVMType::Double(_))
271 }
272 pub fn is_pointer(&self) -> bool {
273 matches!(self, LLVMType::Pointer(_, _))
274 }
275 pub fn is_i32(&self) -> bool {
276 matches!(self, LLVMType::Int32(_))
277 }
278 pub fn size(&self) -> usize {
279 match self {
280 LLVMType::Int1(_) => 1,
281 LLVMType::Int8(_) => 1,
282 LLVMType::Int16(_) => 2,
283 LLVMType::Int32(_) => 4,
284 LLVMType::Int64(_) => 8,
285 LLVMType::Float(_) => 4,
286 LLVMType::Double(_) => 8,
287 LLVMType::Pointer(_, _) => 8, LLVMType::Inject(i) => i.size(), LLVMType::Ref(_, _) => 8, LLVMType::Struct(_, fields, _) => {
291 fields.iter().map(|f| f.1.size()).sum()
293 }
294 LLVMType::Array(element_type, _) => {
295 element_type.size()
297 }
298 LLVMType::Function(_, _, _) => 8, LLVMType::Unit(_) => 0, LLVMType::String(_) => 8, }
302 }
303}
304
305impl Display for LLVMType {
306 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
307 let str = unsafe {
308 let c = LLVMPrintTypeToString(self.as_llvm_type_ref());
309 let c_str = CStr::from_ptr(c);
310 let str_slice = c_str
311 .to_str()
312 .expect("Failed to convert C string to Rust string slice");
313 str_slice
314 };
315 write!(f, "{str}")
316 }
317}