1use super::block::Block;
6use super::register::Register;
7use super::types::MirType;
8use std::fmt;
9
10#[derive(Debug, Clone, PartialEq)]
12pub struct Parameter {
13 pub reg: Register,
15
16 pub ty: MirType,
18}
19
20impl Parameter {
21 pub fn new(reg: Register, ty: MirType) -> Self {
22 Self { reg, ty }
23 }
24}
25
26#[derive(Debug, Clone, PartialEq)]
28pub struct Signature {
29 pub name: String,
31
32 pub params: Vec<Parameter>,
34
35 pub ret_ty: Option<MirType>,
37}
38
39impl Signature {
40 pub fn new(name: impl Into<String>) -> Self {
41 Self {
42 name: name.into(),
43 params: Vec::new(),
44 ret_ty: None,
45 }
46 }
47
48 pub fn with_params(mut self, params: Vec<Parameter>) -> Self {
49 self.params = params;
50 self
51 }
52
53 pub fn with_return(mut self, ty: MirType) -> Self {
54 self.ret_ty = Some(ty);
55 self
56 }
57
58 pub fn param_count(&self) -> usize {
59 self.params.len()
60 }
61
62 pub fn is_void(&self) -> bool {
63 self.ret_ty.is_none()
64 }
65}
66
67#[derive(Debug, Clone, PartialEq)]
69pub struct Function {
70 pub sig: Signature,
72
73 pub blocks: Vec<Block>,
75
76 pub entry: String,
78}
79
80impl Function {
81 pub fn new(sig: Signature) -> Self {
83 Self {
84 sig,
85 blocks: Vec::new(),
86 entry: "entry".to_string(),
87 }
88 }
89
90 pub fn with_entry(mut self, entry: impl Into<String>) -> Self {
92 self.entry = entry.into();
93 self
94 }
95
96 pub fn add_block(&mut self, block: Block) {
98 self.blocks.push(block);
99 }
100
101 pub fn get_block(&self, label: &str) -> Option<&Block> {
103 self.blocks.iter().find(|b| b.label == label)
104 }
105
106 pub fn get_block_mut(&mut self, label: &str) -> Option<&mut Block> {
108 self.blocks.iter_mut().find(|b| b.label == label)
109 }
110
111 pub fn entry_block(&self) -> Option<&Block> {
113 self.get_block(&self.entry)
114 }
115
116 pub fn entry_block_mut(&mut self) -> Option<&mut Block> {
118 let entry = self.entry.clone();
119 self.get_block_mut(&entry)
120 }
121
122 pub fn block_labels(&self) -> Vec<&str> {
124 self.blocks.iter().map(|b| b.label.as_str()).collect()
125 }
126
127 pub fn instruction_count(&self) -> usize {
129 self.blocks.iter().map(|b| b.len()).sum()
130 }
131
132 pub fn validate(&self) -> Result<(), String> {
134 if self.entry_block().is_none() {
136 return Err(format!("Entry block '{}' not found", self.entry));
137 }
138
139 let mut seen_labels = std::collections::HashSet::new();
141 for block in &self.blocks {
142 if !seen_labels.insert(&block.label) {
143 return Err(format!("Duplicate block label: {}", block.label));
144 }
145 }
146
147 for block in &self.blocks {
149 if !block.has_terminator() {
150 return Err(format!("Block '{}' has no terminator", block.label));
151 }
152 }
153
154 for block in &self.blocks {
156 for label in block.successors() {
157 if !seen_labels.contains(&label) {
158 return Err(format!(
159 "Block '{}' references undefined target '{}'",
160 block.label, label
161 ));
162 }
163 }
164 }
165
166 Ok(())
167 }
168}
169
170impl fmt::Display for Function {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 write!(f, "fn {}(", self.sig.name)?;
174 for (i, p) in self.sig.params.iter().enumerate() {
175 if i > 0 {
176 write!(f, ", ")?;
177 }
178 write!(f, "{} {}", p.reg, p.ty)?;
179 }
180 write!(f, ")")?;
181 if let Some(ret) = &self.sig.ret_ty {
182 write!(f, " -> {}", ret)?;
183 }
184 writeln!(f, " {{")?;
185
186 for block in &self.blocks {
188 writeln!(f, "{}", block)?;
189 }
190
191 write!(f, "}}")
192 }
193}
194
195pub struct FunctionBuilder {
197 function: Function,
198 current_block: Option<String>,
199}
200
201impl FunctionBuilder {
202 pub fn new(name: impl Into<String>) -> Self {
204 Self {
205 function: Function::new(Signature::new(name)),
206 current_block: None,
207 }
208 }
209
210 pub fn param(mut self, reg: Register, ty: MirType) -> Self {
212 self.function.sig.params.push(Parameter::new(reg, ty));
213 self
214 }
215
216 pub fn returns(mut self, ty: MirType) -> Self {
218 self.function.sig.ret_ty = Some(ty);
219 self
220 }
221
222 pub fn block(mut self, label: impl Into<String>) -> Self {
224 let label = label.into();
225 self.function.add_block(Block::new(label.clone()));
226 self.current_block = Some(label);
227 self
228 }
229
230 pub fn instr(mut self, instr: super::instruction::Instruction) -> Self {
232 if let Some(ref label) = self.current_block
233 && let Some(block) = self.function.get_block_mut(label)
234 {
235 block.push(instr);
236 }
237 self
238 }
239
240 pub fn build(self) -> Function {
242 self.function
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::instruction::{Instruction, IntBinOp, Operand};
    use crate::register::VirtualReg;
    use crate::types::ScalarType;

    /// A signature built with only a return type keeps its name, has a return
    /// type, and has no parameters.
    #[test]
    fn test_signature_creation() {
        let ret = MirType::Scalar(ScalarType::I64);
        let sig = Signature::new("test_func").with_return(ret);

        assert_eq!(sig.name, "test_func");
        assert!(sig.ret_ty.is_some());
        assert_eq!(sig.param_count(), 0);
    }

    /// The builder records parameters, the block, and both instructions.
    #[test]
    fn test_function_builder() {
        // Small constructors to keep the operands below readable.
        let i64_ty = || MirType::Scalar(ScalarType::I64);
        let gpr = |index| Register::Virtual(VirtualReg::gpr(index));

        let add = Instruction::IntBinary {
            op: IntBinOp::Add,
            ty: i64_ty(),
            dst: gpr(2),
            lhs: Operand::Register(gpr(0)),
            rhs: Operand::Register(gpr(1)),
        };
        let ret = Instruction::Ret {
            value: Some(Operand::Register(gpr(2))),
        };

        let func = FunctionBuilder::new("add")
            .param(gpr(0), i64_ty())
            .param(gpr(1), i64_ty())
            .returns(i64_ty())
            .block("entry")
            .instr(add)
            .instr(ret)
            .build();

        assert_eq!(func.sig.name, "add");
        assert_eq!(func.sig.param_count(), 2);
        assert_eq!(func.blocks.len(), 1);
        assert_eq!(func.instruction_count(), 2);
    }

    /// Validation fails without an entry block and succeeds once a terminated
    /// entry block is present.
    #[test]
    fn test_function_validation() {
        let mut func = Function::new(Signature::new("test"));

        // No blocks yet, so the default entry label cannot be resolved.
        assert!(func.validate().is_err());

        let mut entry_block = Block::new("entry");
        entry_block.push(Instruction::Ret { value: None });
        func.add_block(entry_block);

        assert!(func.validate().is_ok());
    }
}