use super::block::Block;
use super::register::Register;
use super::types::MirType;
use std::fmt;
/// A single function parameter: a register paired with its MIR type.
#[derive(Debug, Clone, PartialEq)]
pub struct Parameter {
    /// Register associated with the parameter.
    pub reg: Register,
    /// MIR type of the parameter.
    pub ty: MirType,
}

impl Parameter {
    /// Builds a parameter from the given register and type.
    pub fn new(reg: Register, ty: MirType) -> Self {
        Parameter { reg, ty }
    }
}
/// The signature of a function: its name, its parameter list and an optional
/// return type.
#[derive(Debug, Clone, PartialEq)]
pub struct Signature {
    /// Function name.
    pub name: String,
    /// Parameters, in declaration order.
    pub params: Vec<Parameter>,
    /// Return type; `None` when the function returns nothing.
    pub ret_ty: Option<MirType>,
}

impl Signature {
    /// Creates a signature called `name` with no parameters and no return type.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Signature {
            name,
            params: Vec::new(),
            ret_ty: None,
        }
    }

    /// Returns the signature with its whole parameter list replaced by `params`.
    pub fn with_params(self, params: Vec<Parameter>) -> Self {
        Signature { params, ..self }
    }

    /// Returns the signature with its return type set to `ty`.
    pub fn with_return(self, ty: MirType) -> Self {
        Signature {
            ret_ty: Some(ty),
            ..self
        }
    }

    /// Number of declared parameters.
    pub fn param_count(&self) -> usize {
        let Signature { params, .. } = self;
        params.len()
    }

    /// `true` when no return type has been set.
    pub fn is_void(&self) -> bool {
        let Signature { ret_ty, .. } = self;
        ret_ty.is_none()
    }
}
/// A function: a signature plus the blocks that make up its body.
///
/// Blocks are kept in insertion order and looked up by label. `entry` holds
/// the label of the block treated as the entry block.
#[derive(Debug, Clone, PartialEq)]
pub struct Function {
    /// Name, parameters and return type.
    pub sig: Signature,
    /// Blocks in the order they were added.
    pub blocks: Vec<Block>,
    /// Label of the entry block; `"entry"` unless changed via `with_entry`.
    pub entry: String,
}

impl Function {
    /// Creates a function with no blocks whose entry label is `"entry"`.
    pub fn new(sig: Signature) -> Self {
        Self {
            sig,
            blocks: Vec::new(),
            entry: "entry".to_string(),
        }
    }

    /// Returns the function with its entry label replaced by `entry`.
    ///
    /// No check is made that a block with that label exists; use
    /// [`Function::validate`] for that.
    pub fn with_entry(mut self, entry: impl Into<String>) -> Self {
        self.entry = entry.into();
        self
    }

    /// Appends `block` after all existing blocks.
    ///
    /// Labels are not checked for uniqueness here; duplicates are reported by
    /// [`Function::validate`].
    pub fn add_block(&mut self, block: Block) {
        self.blocks.push(block);
    }

    /// Returns the first block whose label equals `label`, if any.
    pub fn get_block(&self, label: &str) -> Option<&Block> {
        self.blocks.iter().find(|b| b.label == label)
    }

    /// Mutable counterpart of [`Function::get_block`].
    pub fn get_block_mut(&mut self, label: &str) -> Option<&mut Block> {
        self.blocks.iter_mut().find(|b| b.label == label)
    }

    /// Returns the block named by `self.entry`, if it exists.
    pub fn entry_block(&self) -> Option<&Block> {
        self.get_block(&self.entry)
    }

    /// Mutable counterpart of [`Function::entry_block`].
    pub fn entry_block_mut(&mut self) -> Option<&mut Block> {
        // `entry` and `blocks` are disjoint fields, so the label can be
        // borrowed while the blocks are iterated mutably; no clone of the
        // label is needed to satisfy the borrow checker.
        let entry = self.entry.as_str();
        self.blocks.iter_mut().find(|b| b.label == entry)
    }

    /// Labels of all blocks, in insertion order.
    pub fn block_labels(&self) -> Vec<&str> {
        self.blocks.iter().map(|b| b.label.as_str()).collect()
    }

    /// Sum of `len()` over every block.
    pub fn instruction_count(&self) -> usize {
        self.blocks.iter().map(|b| b.len()).sum()
    }

    /// Checks the structural consistency of the function.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first problem found. The checks run
    /// in this order, each over all blocks before the next one starts:
    ///
    /// 1. the entry label names an existing block;
    /// 2. no two blocks share a label;
    /// 3. every block reports a terminator;
    /// 4. every successor label of every block names an existing block.
    pub fn validate(&self) -> Result<(), String> {
        if self.entry_block().is_none() {
            return Err(format!("Entry block '{}' not found", self.entry));
        }

        // Every label is inserted at most once, so the block count is an
        // upper bound on the set size.
        let mut seen_labels = std::collections::HashSet::with_capacity(self.blocks.len());
        for block in &self.blocks {
            if !seen_labels.insert(&block.label) {
                return Err(format!("Duplicate block label: {}", block.label));
            }
        }

        if let Some(block) = self.blocks.iter().find(|b| !b.has_terminator()) {
            return Err(format!("Block '{}' has no terminator", block.label));
        }

        // `seen_labels` now holds every defined label, so it doubles as the
        // lookup table for branch targets.
        for block in &self.blocks {
            for label in block.successors() {
                if !seen_labels.contains(&label) {
                    return Err(format!(
                        "Block '{}' references undefined target '{}'",
                        block.label, label
                    ));
                }
            }
        }

        Ok(())
    }
}
/// Textual rendering of a function.
///
/// The output is a `fn name(reg ty, reg ty) -> ret {` header (the `-> ret`
/// part only when a return type is set), then every block followed by a
/// newline, then a closing `}` without a trailing newline.
impl fmt::Display for Function {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sig = &self.sig;

        write!(f, "fn {}(", sig.name)?;
        // Empty before the first parameter, ", " before every later one.
        let mut separator = "";
        for param in &sig.params {
            write!(f, "{}{} {}", separator, param.reg, param.ty)?;
            separator = ", ";
        }
        f.write_str(")")?;

        match &sig.ret_ty {
            Some(ret) => writeln!(f, " -> {} {{", ret)?,
            None => writeln!(f, " {{")?,
        }

        self.blocks
            .iter()
            .try_for_each(|block| writeln!(f, "{}", block))?;
        f.write_str("}")
    }
}
/// Assembles a [`Function`] step by step through chained, by-value calls.
///
/// Typical use: `FunctionBuilder::new(name)`, then any number of `param` /
/// `returns` / `block` / `instr` calls, then `build`.
pub struct FunctionBuilder {
    /// The function under construction.
    function: Function,
    /// Label of the block that `instr` appends to; `None` until `block` has
    /// been called.
    current_block: Option<String>,
}

impl FunctionBuilder {
    /// Starts building a function called `name` with no parameters, no return
    /// type and no blocks.
    pub fn new(name: impl Into<String>) -> Self {
        let function = Function::new(Signature::new(name));
        FunctionBuilder {
            function,
            current_block: None,
        }
    }

    /// Appends a parameter bound to `reg` with type `ty`.
    pub fn param(mut self, reg: Register, ty: MirType) -> Self {
        let parameter = Parameter::new(reg, ty);
        self.function.sig.params.push(parameter);
        self
    }

    /// Sets the return type to `ty`, replacing any earlier one.
    pub fn returns(mut self, ty: MirType) -> Self {
        let signature = &mut self.function.sig;
        signature.ret_ty = Some(ty);
        self
    }

    /// Adds a new, empty block labelled `label` and makes it the target of
    /// subsequent `instr` calls.
    ///
    /// The function's entry label is not touched, so it stays at the default
    /// chosen by `Function::new`. A label that is already in use is not
    /// rejected: a second block with the same label is pushed, while later
    /// `instr` calls resolve the label through `Function::get_block_mut`.
    pub fn block(mut self, label: impl Into<String>) -> Self {
        let label: String = label.into();
        let fresh = Block::new(label.clone());
        self.function.add_block(fresh);
        self.current_block = Some(label);
        self
    }

    /// Appends `instr` to the current block.
    ///
    /// If no block has been started yet, or the current label cannot be
    /// resolved, the instruction is dropped without any error.
    pub fn instr(mut self, instr: super::instruction::Instruction) -> Self {
        let Some(label) = self.current_block.as_deref() else {
            return self;
        };
        if let Some(block) = self.function.get_block_mut(label) {
            block.push(instr);
        }
        self
    }

    /// Finishes building and returns the function. No validation is run.
    pub fn build(self) -> Function {
        let FunctionBuilder { function, .. } = self;
        function
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::instruction::{Instruction, IntBinOp, Operand};
    use crate::register::VirtualReg;
    use crate::types::ScalarType;

    /// A signature built with a return type keeps its name, reports the
    /// return type as present and starts with no parameters.
    #[test]
    fn test_signature_creation() {
        let ret_ty = MirType::Scalar(ScalarType::I64);
        let sig = Signature::new("test_func").with_return(ret_ty);

        assert_eq!(sig.name, "test_func");
        assert!(sig.ret_ty.is_some());
        assert_eq!(sig.param_count(), 0);
    }

    /// Builds a two-parameter `add` function with a single block and checks
    /// the resulting name, parameter count, block count and instruction count.
    #[test]
    fn test_function_builder() {
        // Every value in this function has the same 64-bit integer type.
        let i64_ty = || MirType::Scalar(ScalarType::I64);

        let add = Instruction::IntBinary {
            op: IntBinOp::Add,
            ty: i64_ty(),
            dst: Register::Virtual(VirtualReg::gpr(2)),
            lhs: Operand::Register(Register::Virtual(VirtualReg::gpr(0))),
            rhs: Operand::Register(Register::Virtual(VirtualReg::gpr(1))),
        };
        let ret = Instruction::Ret {
            value: Some(Operand::Register(Register::Virtual(VirtualReg::gpr(2)))),
        };

        let func = FunctionBuilder::new("add")
            .param(Register::Virtual(VirtualReg::gpr(0)), i64_ty())
            .param(Register::Virtual(VirtualReg::gpr(1)), i64_ty())
            .returns(i64_ty())
            .block("entry")
            .instr(add)
            .instr(ret)
            .build();

        assert_eq!(func.sig.name, "add");
        assert_eq!(func.sig.param_count(), 2);
        assert_eq!(func.blocks.len(), 1);
        assert_eq!(func.instruction_count(), 2);
    }

    /// Validation fails while the entry block is missing and succeeds once an
    /// `entry` block ending in `Ret` has been added.
    #[test]
    fn test_function_validation() {
        let signature = Signature::new("test");
        let mut func = Function::new(signature);
        assert!(func.validate().is_err());

        let mut entry = Block::new("entry");
        entry.push(Instruction::Ret { value: None });
        func.add_block(entry);
        assert!(func.validate().is_ok());
    }
}