Skip to main content

lamina_mir/
function.rs

//! Function representation in LUMIR.
//!
//! This module defines the structures for representing functions in LUMIR,
//! including function signatures, parameters, and basic blocks, together with
//! a builder for assembling a function step by step.
5use super::block::Block;
6use super::register::Register;
7use super::types::MirType;
8use std::fmt;
9
10/// Function parameter
11#[derive(Debug, Clone, PartialEq)]
12pub struct Parameter {
13    /// Virtual register assigned to this parameter
14    pub reg: Register,
15
16    /// Type of the parameter
17    pub ty: MirType,
18}
19
20impl Parameter {
21    pub fn new(reg: Register, ty: MirType) -> Self {
22        Self { reg, ty }
23    }
24}
25
26/// Function signature
27#[derive(Debug, Clone, PartialEq)]
28pub struct Signature {
29    /// Function name
30    pub name: String,
31
32    /// Parameters (in order)
33    pub params: Vec<Parameter>,
34
35    /// Return type (None for void)
36    pub ret_ty: Option<MirType>,
37}
38
39impl Signature {
40    pub fn new(name: impl Into<String>) -> Self {
41        Self {
42            name: name.into(),
43            params: Vec::new(),
44            ret_ty: None,
45        }
46    }
47
48    pub fn with_params(mut self, params: Vec<Parameter>) -> Self {
49        self.params = params;
50        self
51    }
52
53    pub fn with_return(mut self, ty: MirType) -> Self {
54        self.ret_ty = Some(ty);
55        self
56    }
57
58    pub fn param_count(&self) -> usize {
59        self.params.len()
60    }
61
62    pub fn is_void(&self) -> bool {
63        self.ret_ty.is_none()
64    }
65}
66
67/// LUMIR function
68#[derive(Debug, Clone, PartialEq)]
69pub struct Function {
70    /// Function signature
71    pub sig: Signature,
72
73    /// Basic blocks (ordered map for deterministic iteration)
74    pub blocks: Vec<Block>,
75
76    /// Entry block label
77    pub entry: String,
78}
79
80impl Function {
81    /// Create a new function with the given signature
82    pub fn new(sig: Signature) -> Self {
83        Self {
84            sig,
85            blocks: Vec::new(),
86            entry: "entry".to_string(),
87        }
88    }
89
90    /// Set the entry block label
91    pub fn with_entry(mut self, entry: impl Into<String>) -> Self {
92        self.entry = entry.into();
93        self
94    }
95
96    /// Add a basic block to this function
97    pub fn add_block(&mut self, block: Block) {
98        self.blocks.push(block);
99    }
100
101    /// Get a basic block by label
102    pub fn get_block(&self, label: &str) -> Option<&Block> {
103        self.blocks.iter().find(|b| b.label == label)
104    }
105
106    /// Get a mutable reference to a basic block by label
107    pub fn get_block_mut(&mut self, label: &str) -> Option<&mut Block> {
108        self.blocks.iter_mut().find(|b| b.label == label)
109    }
110
111    /// Get the entry block
112    pub fn entry_block(&self) -> Option<&Block> {
113        self.get_block(&self.entry)
114    }
115
116    /// Get a mutable reference to the entry block
117    pub fn entry_block_mut(&mut self) -> Option<&mut Block> {
118        let entry = self.entry.clone();
119        self.get_block_mut(&entry)
120    }
121
122    /// Get all block labels
123    pub fn block_labels(&self) -> Vec<&str> {
124        self.blocks.iter().map(|b| b.label.as_str()).collect()
125    }
126
127    /// Total number of instructions across all blocks
128    pub fn instruction_count(&self) -> usize {
129        self.blocks.iter().map(|b| b.len()).sum()
130    }
131
132    /// Check if this function is well-formed
133    pub fn validate(&self) -> Result<(), String> {
134        // Check that entry block exists
135        if self.entry_block().is_none() {
136            return Err(format!("Entry block '{}' not found", self.entry));
137        }
138
139        // Check that all blocks have unique labels
140        let mut seen_labels = std::collections::HashSet::new();
141        for block in &self.blocks {
142            if !seen_labels.insert(&block.label) {
143                return Err(format!("Duplicate block label: {}", block.label));
144            }
145        }
146
147        // Check that all blocks have terminators
148        for block in &self.blocks {
149            if !block.has_terminator() {
150                return Err(format!("Block '{}' has no terminator", block.label));
151            }
152        }
153
154        // Check that all branch/jump targets reference existing blocks
155        for block in &self.blocks {
156            for label in block.successors() {
157                if !seen_labels.contains(&label) {
158                    return Err(format!(
159                        "Block '{}' references undefined target '{}'",
160                        block.label, label
161                    ));
162                }
163            }
164        }
165
166        Ok(())
167    }
168}
169
170impl fmt::Display for Function {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        // Signature header
173        write!(f, "fn {}(", self.sig.name)?;
174        for (i, p) in self.sig.params.iter().enumerate() {
175            if i > 0 {
176                write!(f, ", ")?;
177            }
178            write!(f, "{} {}", p.reg, p.ty)?;
179        }
180        write!(f, ")")?;
181        if let Some(ret) = &self.sig.ret_ty {
182            write!(f, " -> {}", ret)?;
183        }
184        writeln!(f, " {{")?;
185
186        // Emit blocks in order
187        for block in &self.blocks {
188            writeln!(f, "{}", block)?;
189        }
190
191        write!(f, "}}")
192    }
193}
194
195/// Builder for constructing LUMIR functions
196pub struct FunctionBuilder {
197    function: Function,
198    current_block: Option<String>,
199}
200
201impl FunctionBuilder {
202    /// Create a new function builder
203    pub fn new(name: impl Into<String>) -> Self {
204        Self {
205            function: Function::new(Signature::new(name)),
206            current_block: None,
207        }
208    }
209
210    /// Add a parameter to the function
211    pub fn param(mut self, reg: Register, ty: MirType) -> Self {
212        self.function.sig.params.push(Parameter::new(reg, ty));
213        self
214    }
215
216    /// Set the return type
217    pub fn returns(mut self, ty: MirType) -> Self {
218        self.function.sig.ret_ty = Some(ty);
219        self
220    }
221
222    /// Create a new basic block and make it the current block
223    pub fn block(mut self, label: impl Into<String>) -> Self {
224        let label = label.into();
225        self.function.add_block(Block::new(label.clone()));
226        self.current_block = Some(label);
227        self
228    }
229
230    /// Add an instruction to the current block
231    pub fn instr(mut self, instr: super::instruction::Instruction) -> Self {
232        if let Some(ref label) = self.current_block
233            && let Some(block) = self.function.get_block_mut(label)
234        {
235            block.push(instr);
236        }
237        self
238    }
239
240    /// Build the function
241    pub fn build(self) -> Function {
242        self.function
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::instruction::{Instruction, IntBinOp, Operand};
    use crate::register::VirtualReg;
    use crate::types::ScalarType;

    /// A fresh signature with only a return type has the given name, a
    /// return type, and zero parameters.
    #[test]
    fn test_signature_creation() {
        let ret = MirType::Scalar(ScalarType::I64);
        let sig = Signature::new("test_func").with_return(ret);

        assert_eq!(sig.name, "test_func");
        assert!(sig.ret_ty.is_some());
        assert_eq!(sig.param_count(), 0);
    }

    /// Builds a two-parameter `add` function with one block holding an add
    /// and a return, then checks the resulting counts.
    #[test]
    fn test_function_builder() {
        // Name the pieces once so the builder chain below stays readable.
        let i64_ty = MirType::Scalar(ScalarType::I64);
        let lhs = Register::Virtual(VirtualReg::gpr(0));
        let rhs = Register::Virtual(VirtualReg::gpr(1));
        let sum = Register::Virtual(VirtualReg::gpr(2));

        let add = Instruction::IntBinary {
            op: IntBinOp::Add,
            ty: i64_ty.clone(),
            dst: sum.clone(),
            lhs: Operand::Register(lhs.clone()),
            rhs: Operand::Register(rhs.clone()),
        };
        let ret = Instruction::Ret {
            value: Some(Operand::Register(sum)),
        };

        let func = FunctionBuilder::new("add")
            .param(lhs, i64_ty.clone())
            .param(rhs, i64_ty.clone())
            .returns(i64_ty)
            .block("entry")
            .instr(add)
            .instr(ret)
            .build();

        assert_eq!(func.sig.name, "add");
        assert_eq!(func.sig.param_count(), 2);
        assert_eq!(func.blocks.len(), 1);
        assert_eq!(func.instruction_count(), 2);
    }

    /// Validation fails while the entry block is missing and succeeds once a
    /// terminated entry block has been added.
    #[test]
    fn test_function_validation() {
        let mut func = Function::new(Signature::new("test"));

        // No blocks yet, so the default "entry" label cannot be resolved.
        let before = func.validate();
        assert!(before.is_err());

        // A lone entry block that ends in a return.
        let mut entry = Block::new("entry");
        entry.push(Instruction::Ret { value: None });
        func.add_block(entry);

        let after = func.validate();
        assert!(after.is_ok());
    }
}