use super::function::Function;
use super::types::MirType;
use std::collections::{HashMap, HashSet};
use std::fmt;
/// A named global variable belonging to a [`Module`].
///
/// Built with [`Global::new`] and refined with the builder-style
/// [`Global::immutable`] / [`Global::with_initializer`] methods.
#[derive(Debug, Clone, PartialEq)]
pub struct Global {
    /// Symbol name; [`Module::add_global`] uses it as the map key.
    pub name: String,
    /// MIR type of the global.
    pub ty: MirType,
    /// Whether the global may be written. [`Global::new`] sets this to `true`.
    pub mutable: bool,
    /// Raw initial bytes, if any were supplied via [`Global::with_initializer`].
    /// NOTE(review): the byte layout/endianness is not defined in this file — confirm with the backend.
    pub initializer: Option<Vec<u8>>,
}
impl Global {
pub fn new(name: impl Into<String>, ty: MirType) -> Self {
Self {
name: name.into(),
ty,
mutable: true,
initializer: None,
}
}
pub fn immutable(mut self) -> Self {
self.mutable = false;
self
}
pub fn with_initializer(mut self, data: Vec<u8>) -> Self {
self.initializer = Some(data);
self
}
}
/// A MIR module: a named collection of functions, globals and external-function names.
#[derive(Debug, Clone, PartialEq)]
pub struct Module {
    /// Module name as given to [`Module::new`].
    pub name: String,
    /// Functions keyed by their signature name (see [`Module::add_function`]).
    pub functions: HashMap<String, Function>,
    /// Globals keyed by [`Global::name`] (see [`Module::add_global`]).
    pub globals: HashMap<String, Global>,
    /// Names recorded through [`Module::mark_external`].
    /// NOTE(review): presumably functions defined outside this module — confirm with the linker/backend.
    pub external_functions: HashSet<String>,
}
impl Module {
    /// Creates an empty module named `name` with no functions, globals or externals.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            functions: HashMap::new(),
            globals: HashMap::new(),
            external_functions: HashSet::new(),
        }
    }

    /// Adds `func`, keyed by its signature name.
    ///
    /// A function already registered under the same name is replaced.
    pub fn add_function(&mut self, func: Function) {
        // The key must be an owned copy because `func` itself is moved into the map.
        let name = func.sig.name.clone();
        self.functions.insert(name, func);
    }

    /// Records `name` as an external function.
    pub fn mark_external(&mut self, name: impl Into<String>) {
        self.external_functions.insert(name.into());
    }

    /// Returns `true` if `name` was previously passed to [`Module::mark_external`].
    pub fn is_external(&self, name: &str) -> bool {
        self.external_functions.contains(name)
    }

    /// Adds `global`, keyed by its name.
    ///
    /// A global already registered under the same name is replaced.
    pub fn add_global(&mut self, global: Global) {
        let name = global.name.clone();
        self.globals.insert(name, global);
    }

    /// Looks up a function by name.
    pub fn get_function(&self, name: &str) -> Option<&Function> {
        self.functions.get(name)
    }

    /// Looks up a function by name for in-place modification.
    pub fn get_function_mut(&mut self, name: &str) -> Option<&mut Function> {
        self.functions.get_mut(name)
    }

    /// Looks up a global by name.
    pub fn get_global(&self, name: &str) -> Option<&Global> {
        self.globals.get(name)
    }

    /// Returns the names of all functions in ascending order.
    ///
    /// The names are sorted because the backing `HashMap` has no stable iteration
    /// order. Sorting also matches the order used by the `Display` impl.
    pub fn function_names(&self) -> Vec<&str> {
        Self::sorted_keys(&self.functions)
    }

    /// Returns the names of all globals in ascending order (see [`Module::function_names`]).
    pub fn global_names(&self) -> Vec<&str> {
        Self::sorted_keys(&self.globals)
    }

    /// Sums `instruction_count()` over every function in the module.
    pub fn instruction_count(&self) -> usize {
        self.functions.values().map(|f| f.instruction_count()).sum()
    }

    /// Validates every function and checks for names shared by a function and a global.
    ///
    /// # Errors
    ///
    /// Returns `Err` with one message per problem when anything is wrong. The
    /// messages for failing function validations come first, followed by the
    /// function/global name clashes. Each group is in ascending name order, so
    /// the output is deterministic.
    pub fn validate(&self) -> Result<(), Vec<String>> {
        // Iterate in sorted order: raw HashMap order would make the error list
        // differ from run to run.
        let mut functions: Vec<(&str, &Function)> = self
            .functions
            .iter()
            .map(|(name, func)| (name.as_str(), func))
            .collect();
        functions.sort_unstable_by_key(|&(name, _)| name);

        let mut errors = Vec::new();
        for &(name, func) in &functions {
            if let Err(e) = func.validate() {
                errors.push(format!("Function '{}': {}", name, e));
            }
        }
        for &(name, _) in &functions {
            if self.globals.contains_key(name) {
                errors.push(format!(
                    "Name '{}' used for both function and global",
                    name
                ));
            }
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }

    /// Collects the keys of `map` as string slices, sorted ascending.
    fn sorted_keys<V>(map: &HashMap<String, V>) -> Vec<&str> {
        let mut keys: Vec<&str> = map.keys().map(String::as_str).collect();
        keys.sort_unstable();
        keys
    }
}
impl fmt::Display for Module {
    /// Writes each function using its own `Display` impl, in ascending name
    /// order, with a blank line between consecutive functions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut ordered: Vec<_> = self.functions.iter().collect();
        ordered.sort_by(|a, b| a.0.cmp(b.0));

        for (position, (_, func)) in ordered.into_iter().enumerate() {
            // The separator goes before every entry except the first, which is
            // equivalent to emitting it after every entry except the last.
            if position > 0 {
                writeln!(f)?;
            }
            writeln!(f, "{}", func)?;
        }
        Ok(())
    }
}
/// Fluent, by-value builder that assembles a [`Module`].
///
/// Each builder method takes `self` and returns it; [`ModuleBuilder::build`]
/// yields the finished module.
pub struct ModuleBuilder {
    /// The module under construction.
    module: Module,
}
impl ModuleBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
module: Module::new(name),
}
}
pub fn function(mut self, func: Function) -> Self {
self.module.add_function(func);
self
}
pub fn global(mut self, global: Global) -> Self {
self.module.add_global(global);
self
}
pub fn build(self) -> Module {
self.module
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mir::function::{Function, Signature};
    use crate::mir::types::{MirType, ScalarType};

    #[test]
    fn test_module_creation() {
        let module = Module::new("test_module");
        assert_eq!(module.name, "test_module");
        assert!(module.functions.is_empty());
        assert!(module.globals.is_empty());
    }

    #[test]
    fn test_module_add_function() {
        let mut module = Module::new("test");
        let func = Function::new(Signature::new("test_func"));
        module.add_function(func);
        assert_eq!(module.functions.len(), 1);
        assert!(module.get_function("test_func").is_some());
    }

    #[test]
    fn test_module_add_global() {
        let mut module = Module::new("test");
        let global = Global::new("my_global", MirType::Scalar(ScalarType::I32));
        module.add_global(global);
        assert_eq!(module.globals.len(), 1);
        assert!(module.get_global("my_global").is_some());
    }

    #[test]
    fn test_module_builder() {
        let func = Function::new(Signature::new("main"));
        let global = Global::new("counter", MirType::Scalar(ScalarType::I64));
        let module = ModuleBuilder::new("my_module")
            .function(func)
            .global(global)
            .build();
        assert_eq!(module.name, "my_module");
        assert_eq!(module.functions.len(), 1);
        assert_eq!(module.globals.len(), 1);
    }

    /// `mark_external` / `is_external` round-trip through `external_functions`.
    #[test]
    fn test_module_external_functions() {
        let mut module = Module::new("test");
        assert!(!module.is_external("printf"));
        module.mark_external("printf");
        assert!(module.is_external("printf"));
        assert!(module.external_functions.contains("printf"));
        assert!(!module.is_external("puts"));
    }

    /// `Global::new` defaults, plus the `immutable` / `with_initializer` builders.
    #[test]
    fn test_global_builder_methods() {
        let fresh = Global::new("g", MirType::Scalar(ScalarType::I32));
        assert!(fresh.mutable);
        assert_eq!(fresh.initializer, None);

        let global = Global::new("table", MirType::Scalar(ScalarType::I32))
            .immutable()
            .with_initializer(vec![1, 2, 3, 4]);
        assert_eq!(global.name, "table");
        assert!(!global.mutable);
        assert_eq!(global.initializer, Some(vec![1, 2, 3, 4]));
    }

    /// An empty module has no instructions, no names, and validates cleanly.
    #[test]
    fn test_empty_module_is_valid() {
        let module = Module::new("empty");
        assert_eq!(module.instruction_count(), 0);
        assert!(module.validate().is_ok());
        assert!(module.function_names().is_empty());
        assert!(module.global_names().is_empty());
    }
}