use crate::ast::r#type::Type;
use crate::llvm::global::Global;
use crate::llvm::value::bool::BoolValue;
use crate::llvm::value::double::DoubleValue;
use crate::llvm::value::float::FloatValue;
use crate::llvm::value::fucntion::FunctionValue;
use crate::llvm::value::int::{Int16Value, Int32Value, Int64Value, Int8Value};
use crate::llvm::value::pointer::PointerValue;
use crate::llvm::value::pstruct::StructValue;
use crate::llvm::value::reference::ReferenceValue;
use crate::llvm::value::LLVMValue;
use llvm_sys::core::{
    LLVMDisposeMessage, LLVMGetElementType, LLVMGetIntTypeWidth, LLVMGetTypeKind, LLVMGetUndef,
    LLVMInt32Type, LLVMIsFunctionVarArg, LLVMPrintTypeToString, LLVMTypeOf,
};
use llvm_sys::prelude::{LLVMTypeRef, LLVMValueRef};
use llvm_sys::LLVMTypeKind;
use std::collections::HashMap;
use std::ffi::CStr;
use std::fmt::{Display, Formatter};
21
22#[derive(Clone, Debug, PartialEq)]
23pub enum LLVMType {
24 Int1(LLVMTypeRef),
25 Int8(LLVMTypeRef),
26 Int16(LLVMTypeRef),
27 Int32(LLVMTypeRef),
28 Int64(LLVMTypeRef),
29 Float(LLVMTypeRef),
30 Double(LLVMTypeRef),
31 Struct(String, Vec<(String, LLVMType)>, LLVMTypeRef),
32 Array(Box<LLVMType>, LLVMTypeRef),
33 Function(Box<LLVMType>, Vec<(String, LLVMType)>, LLVMTypeRef),
34 Pointer(Box<LLVMType>, LLVMTypeRef),
35 Ref(Box<LLVMType>, LLVMTypeRef),
36 String(LLVMTypeRef),
37 Unit(LLVMTypeRef),
38}
39
40impl From<LLVMValueRef> for LLVMType {
41 fn from(value: LLVMValueRef) -> Self {
42 let ty = unsafe { LLVMTypeOf(value) };
43 let type_kind = unsafe { LLVMGetTypeKind(ty) };
44 match type_kind {
45 LLVMTypeKind::LLVMIntegerTypeKind => {
46 let width = unsafe { LLVMGetIntTypeWidth(ty) };
47 let width = width as i8;
48 match width {
49 8 => LLVMType::Int8(ty),
50 32 => LLVMType::Int32(ty),
51 _ => {
52 todo!()
53 }
54 }
55 }
56 _ => {
57 todo!()
58 }
59 }
60 }
61}
62
63impl From<LLVMTypeRef> for LLVMType {
64 fn from(ty: LLVMTypeRef) -> Self {
65 let type_kind = unsafe { LLVMGetTypeKind(ty) };
66 match type_kind {
67 LLVMTypeKind::LLVMIntegerTypeKind => {
68 let width = unsafe { LLVMGetIntTypeWidth(ty) };
69 let width = width as i8;
70 match width {
71 8 => LLVMType::Int8(ty),
72 32 => LLVMType::Int32(ty),
73 64 => LLVMType::Int64(ty),
74 _ => {
75 todo!()
76 }
77 }
78 }
79 LLVMTypeKind::LLVMArrayTypeKind => {
80 let element_ty = LLVMType::from(unsafe { LLVMGetElementType(ty) });
81 LLVMType::Array(Box::new(element_ty), ty)
82 }
83 LLVMTypeKind::LLVMDoubleTypeKind => LLVMType::Double(ty),
84 t => {
85 println!("{t:?}");
86 todo!()
87 }
88 }
89 }
90}
91
92impl LLVMType {
93 pub fn as_llvm_type_ref(&self) -> LLVMTypeRef {
94 match self {
95 LLVMType::Int1(i) => *i,
96 LLVMType::Int8(i) => *i,
97 LLVMType::Int16(i) => *i,
98 LLVMType::Int32(i) => *i,
99 LLVMType::Int64(i) => *i,
100 LLVMType::Float(i) => *i,
101 LLVMType::Double(i) => *i,
102 LLVMType::Array(_, i) => *i,
103 LLVMType::Function(_, _, i) => *i,
104 LLVMType::Pointer(_, i) => *i,
105 LLVMType::Ref(_, i) => *i,
106 LLVMType::Unit(i) => *i,
107 LLVMType::Struct(_, _, i) => *i,
108 LLVMType::String(i) => *i,
109 }
110 }
111 pub fn get_element_type(&self) -> LLVMType {
112 match self {
113 LLVMType::Array(e_t, _) => *e_t.clone(),
114 LLVMType::Pointer(e, _) => *e.clone(),
115 t => {
116 println!("{t:?}");
117 todo!()
118 }
119 }
120 }
121 pub fn get_type(&self) -> Type {
122 match self {
123 LLVMType::Int1(_) => Type::Bool,
124 LLVMType::Int8(_) => Type::Int8,
125 LLVMType::Int16(_) => Type::Int16,
126 LLVMType::Int32(_) => Type::Int32,
127 LLVMType::Int64(_) => Type::Int64,
128 LLVMType::Float(_) => Type::Float,
129 LLVMType::Double(_) => Type::Double,
130 LLVMType::String(_) => Type::String,
131 LLVMType::Unit(_) => Type::Unit,
132 LLVMType::Pointer(element_type, _) => Type::Pointer(Box::new(element_type.get_type())),
133 LLVMType::Ref(element_type, _) => Type::Ref(Box::new(element_type.get_type())),
134 LLVMType::Array(element_type, _) => Type::Array(Box::new(element_type.get_type())),
135 LLVMType::Struct(name, fields, _) => {
136 let mut field_types = vec![];
137 for (field_name, field_type) in fields {
138 field_types.push((field_name.clone(), field_type.get_type()));
139 }
140 Type::Struct(Some(name.clone()), field_types)
141 }
142 LLVMType::Function(return_type, args, _) => {
143 let mut arg_types = vec![];
144 for (arg_name, arg_type) in args {
145 arg_types.push((arg_name.clone(), arg_type.get_type()));
146 }
147 Type::Function(Box::new(return_type.get_type()), arg_types, false)
148 }
149 }
150 }
151 pub fn get_function_param_type(&self, index: usize) -> LLVMType {
152 match self {
153 LLVMType::Function(_, v, f) => {
154 let r = unsafe { LLVMIsFunctionVarArg(*f) };
155 if r == 1 {
156 return LLVMType::Unit(Global::unit_type().as_llvm_type_ref());
157 }
158 println!("{}", index);
159 v.get(index).cloned().unwrap().1
160 }
161 _ => panic!("Not a function"),
162 }
163 }
164 pub fn get_struct_field_type(&self, index: usize) -> LLVMType {
165 match self {
166 LLVMType::Struct(_, v, _) => v.get(index).cloned().unwrap().1,
167 _ => panic!("Not a struct"),
168 }
169 }
170 pub fn get_function_return_type(&self) -> LLVMType {
171 match self {
172 LLVMType::Function(_, _, f) => LLVMType::from(*f),
173 _ => panic!("Not a function"),
174 }
175 }
176 pub fn get_undef(&self) -> LLVMValue {
177 let reference = unsafe { LLVMGetUndef(self.as_llvm_type_ref()) };
178 match self {
179 LLVMType::Int1(_) => LLVMValue::Bool(BoolValue::new(reference)),
180 LLVMType::Pointer(element, _) => match &**element {
181 LLVMType::Function(ret, args, _) => {
182 let undef_args = args
183 .iter()
184 .map(|arg| (arg.0.clone(), arg.1.get_undef()))
185 .collect();
186 LLVMValue::Function(FunctionValue::new(
187 reference,
188 "".into(),
189 Box::new(ret.get_undef()),
190 undef_args,
191 ))
192 }
193 _ => LLVMValue::Pointer(PointerValue::new(reference, element.get_undef())),
194 },
195 LLVMType::Int8(_) => LLVMValue::Int8(Int8Value::new(reference)),
196 LLVMType::Int16(_) => LLVMValue::Int16(Int16Value::new(reference)),
197 LLVMType::Int32(_) => LLVMValue::Int32(Int32Value::new(reference)),
198 LLVMType::Int64(_) => LLVMValue::Int64(Int64Value::new(reference)),
199 LLVMType::Unit(_) => LLVMValue::Unit,
201 LLVMType::Struct(name, field, _) => {
202 let mut filed_index = HashMap::new();
203 for (i, f) in field.iter().enumerate() {
204 filed_index.insert(f.0.clone(), i);
205 }
206 LLVMValue::Struct(StructValue::new(
207 reference,
208 name.clone(),
209 filed_index.clone(),
210 field.iter().map(|f| f.1.get_undef()).collect(),
211 ))
212 }
213 LLVMType::String(_) => LLVMValue::String(reference),
214 LLVMType::Function(ret, args, _) => {
215 let undef_args = args
216 .iter()
217 .map(|arg| (arg.0.clone(), arg.1.get_undef()))
218 .collect();
219 LLVMValue::Function(FunctionValue::new(
220 reference,
221 "".into(),
222 Box::new(ret.get_undef()),
223 undef_args,
224 ))
225 }
226 LLVMType::Ref(element_type, type_reference) => {
227 let reference = unsafe { LLVMGetUndef(*type_reference) };
228 LLVMValue::Reference(ReferenceValue::new(reference, element_type.get_undef()))
229 }
230 LLVMType::Float(_) => LLVMValue::Float(FloatValue::new(reference)),
231 LLVMType::Double(_) => LLVMValue::Double(DoubleValue::new(reference)),
232 t => {
233 println!("{t:?}");
234 todo!()
235 }
236 }
237 }
238 pub fn is_array(&self) -> bool {
239 matches!(self, LLVMType::Array(_, _))
240 }
241 pub fn is_struct(&self) -> bool {
242 matches!(self, LLVMType::Struct(_, _, _))
243 }
244 pub fn i32() -> Self {
245 let t = unsafe { LLVMInt32Type() };
246 LLVMType::Int32(t)
247 }
248 pub fn is_function(&self) -> bool {
249 matches!(self, LLVMType::Function(_, _, _))
250 }
251 pub fn is_float(&self) -> bool {
252 matches!(self, LLVMType::Float(_))
253 }
254 pub fn is_double(&self) -> bool {
255 matches!(self, LLVMType::Double(_))
256 }
257 pub fn is_pointer(&self) -> bool {
258 matches!(self, LLVMType::Pointer(_, _))
259 }
260 pub fn is_i32(&self) -> bool {
261 matches!(self, LLVMType::Int32(_))
262 }
263 pub fn size(&self) -> usize {
264 match self {
265 LLVMType::Int1(_) => 1,
266 LLVMType::Int8(_) => 1,
267 LLVMType::Int16(_) => 2,
268 LLVMType::Int32(_) => 4,
269 LLVMType::Int64(_) => 8,
270 LLVMType::Float(_) => 4,
271 LLVMType::Double(_) => 8,
272 LLVMType::Pointer(_, _) => 8, LLVMType::Ref(_, _) => 8, LLVMType::Struct(_, fields, _) => {
275 fields.iter().map(|f| f.1.size()).sum()
277 }
278 LLVMType::Array(element_type, _) => {
279 element_type.size()
281 }
282 LLVMType::Function(_, _, _) => 8, LLVMType::Unit(_) => 0, LLVMType::String(_) => 8, }
286 }
287}
288
289impl Display for LLVMType {
290 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
291 let str = unsafe {
292 let c = LLVMPrintTypeToString(self.as_llvm_type_ref());
293 let c_str = CStr::from_ptr(c);
294 let str_slice = c_str
295 .to_str()
296 .expect("Failed to convert C string to Rust string slice");
297 str_slice
298 };
299 write!(f, "{str}")
300 }
301}