use super::{
ast::{Argument, Grouping, RandomEffect, RandomTerm, Response},
data_structures::{
FormulaMetadataInfo, Interaction, RandomEffectInfo, Transformation, VariableInfo,
VariableRole,
},
};
use std::collections::HashMap;
#[derive(Default)]
pub struct MetaBuilder {
name_to_id: HashMap<String, u32>,
columns: HashMap<String, VariableInfo>,
has_uncorrelated_slopes_and_intercepts: bool,
is_random_effects_model: bool,
next_id: u32,
}
impl MetaBuilder {
/// Creates an empty builder whose first registered variable receives id `1`.
pub fn new() -> Self {
    Self {
        next_id: 1,
        is_random_effects_model: false,
        has_uncorrelated_slopes_and_intercepts: false,
        columns: HashMap::default(),
        name_to_id: HashMap::default(),
    }
}
/// Returns the id of `name`, registering it first if it is unknown.
///
/// A newly registered variable gets the next sequential id, no roles,
/// transformations, interactions or random effects, and a single generated
/// column equal to its own name.
pub fn ensure_variable(&mut self, name: &str) -> u32 {
    if let Some(&existing) = self.name_to_id.get(name) {
        return existing;
    }

    let id = self.next_id;
    self.next_id += 1;

    let key = name.to_string();
    self.columns.insert(
        key.clone(),
        VariableInfo {
            id,
            roles: Vec::new(),
            transformations: Vec::new(),
            interactions: Vec::new(),
            random_effects: Vec::new(),
            generated_columns: vec![key.clone()],
        },
    );
    self.name_to_id.insert(key, id);
    id
}
/// Adds `role` to variable `name` unless it already has it.
///
/// Does nothing when `name` has not been registered.
pub fn add_role(&mut self, name: &str, role: VariableRole) {
    match self.columns.get_mut(name) {
        Some(info) if !info.roles.contains(&role) => info.roles.push(role),
        _ => {}
    }
}
/// Records `transformation` on variable `name` and refreshes its generated columns.
///
/// `generated_columns` is rebuilt from scratch as follows:
/// - the raw column `name` comes first when the variable has the `Identity` role;
/// - then come the columns of every transformation recorded so far, in
///   insertion order and without duplicates.
///
/// The same transformation may be registered twice, e.g. once as a fixed
/// effect and once as a random slope; its columns are still listed once.
/// Unknown names are ignored.
pub fn add_transformation(&mut self, name: &str, transformation: Transformation) {
    if let Some(var_info) = self.columns.get_mut(name) {
        var_info.transformations.push(transformation);

        // Rebuild instead of overwriting, so that e.g. `log(x) + sqrt(x)`
        // keeps `x_log` alongside `x_sqrt`.
        let mut generated: Vec<String> = Vec::new();
        if var_info.roles.contains(&VariableRole::Identity) {
            generated.push(name.to_string());
        }
        for recorded in &var_info.transformations {
            for col in &recorded.generates_columns {
                if !generated.contains(col) {
                    generated.push(col.clone());
                }
            }
        }
        var_info.generated_columns = generated;
    }
}
/// Appends `interaction` to the interactions of variable `name`.
///
/// Silently does nothing when `name` has not been registered.
pub fn add_interaction(&mut self, name: &str, interaction: Interaction) {
    if let Some(target) = self.columns.get_mut(name) {
        target.interactions.push(interaction);
    }
}
/// Appends `random_effect` to the random-effect entries of variable `name`.
///
/// Silently does nothing when `name` has not been registered.
pub fn add_random_effect(&mut self, name: &str, random_effect: RandomEffectInfo) {
    if let Some(target) = self.columns.get_mut(name) {
        target.random_effects.push(random_effect);
    }
}
/// Registers the response variable(s) on the left-hand side of the formula.
///
/// New response variables are stored with the reserved id `1`, so several
/// multivariate responses share that id. A name that is already registered
/// keeps its id and only gains the `Response` role. Afterwards `next_id` is
/// at least `2`, so later variables never receive id `1`.
pub fn push_response(&mut self, response: &Response) {
    match response {
        Response::Single(name) => self.insert_response_variable(name),
        Response::Multivariate(variables) => {
            for name in variables {
                self.insert_response_variable(name);
            }
        }
    }
    // `max` rather than a plain assignment: if other variables were already
    // registered, resetting the counter to 2 would hand out duplicate ids.
    self.next_id = self.next_id.max(2);
}

/// Adds `name` as a response column with id `1`, or only tags an already
/// registered variable with the `Response` role.
fn insert_response_variable(&mut self, name: &str) {
    if self.name_to_id.contains_key(name) {
        self.add_role(name, VariableRole::Response);
        return;
    }
    self.name_to_id.insert(name.to_string(), 1);
    self.columns.insert(
        name.to_string(),
        VariableInfo {
            id: 1,
            roles: vec![VariableRole::Response],
            transformations: Vec::new(),
            interactions: Vec::new(),
            random_effects: Vec::new(),
            generated_columns: vec![name.to_string()],
        },
    );
}
/// Registers a bare predictor column (e.g. `x` in `y ~ x`) with the `Identity` role.
///
/// After this call the raw column `name` is listed in the variable's
/// `generated_columns` (inserted at the front if missing). This holds even
/// when a transformation such as `log(x)` was registered earlier and had
/// replaced that list.
pub fn push_plain_term(&mut self, name: &str) {
    self.ensure_variable(name);
    self.add_role(name, VariableRole::Identity);
    // `add_transformation` keeps the raw column only when `Identity` is
    // already present, so without this `log(x) + x` would drop `x` while
    // `x + log(x)` keeps it.
    if let Some(var_info) = self.columns.get_mut(name) {
        if !var_info.generated_columns.iter().any(|col| col == name) {
            var_info.generated_columns.insert(0, name.to_string());
        }
    }
}
/// Flattens a fixed-effect term into the variable names it mentions, left to right.
///
/// - A column contributes its own name.
/// - A function call contributes its first bare identifier argument, if any.
/// - An interaction contributes both sides, recursively.
/// - Any other term contributes nothing.
fn extract_all_variables(term: &crate::internal::ast::Term) -> Vec<String> {
    use crate::internal::ast::Term;

    match term {
        Term::Column(name) => vec![name.clone()],
        Term::Function { args, .. } => args
            .iter()
            .filter_map(|arg| match arg {
                Argument::Ident(ident) => Some(ident.clone()),
                _ => None,
            })
            .take(1)
            .collect(),
        Term::Interaction { left, right } => {
            let mut names = Self::extract_all_variables(left);
            names.append(&mut Self::extract_all_variables(right));
            names
        }
        _ => Vec::new(),
    }
}
/// Every subset of `variables` with at least two members, as produced by
/// `combinations`, ordered from the smallest subset size to the largest.
///
/// Fewer than two input variables yield no subsets.
fn generate_interaction_combinations(variables: &[String]) -> Vec<Vec<String>> {
    (2..=variables.len())
        .flat_map(|order| Self::combinations(variables, order))
        .collect()
}
/// All `k`-element subsets of `variables`, each keeping the input order.
///
/// Subsets are produced in lexicographic order of their input positions. For
/// example, `[a, b, c]` with `k = 2` yields `[a, b]`, `[a, c]`, `[b, c]`.
/// `k == 0` yields a single empty subset; `k > variables.len()` yields none.
fn combinations(variables: &[String], k: usize) -> Vec<Vec<String>> {
    let n = variables.len();
    if k > n {
        return Vec::new();
    }

    let mut result = Vec::new();
    // `picks` holds strictly increasing positions into `variables`.
    let mut picks: Vec<usize> = (0..k).collect();
    loop {
        result.push(picks.iter().map(|&p| variables[p].clone()).collect());

        // Find the right-most slot that has not yet reached its maximum value
        // (`slot + n - k`); if every slot is maxed out, all subsets were emitted.
        let mut slot = k;
        loop {
            if slot == 0 {
                return result;
            }
            slot -= 1;
            if picks[slot] != slot + n - k {
                break;
            }
        }

        // Advance that slot and reset every slot after it to the smallest
        // strictly increasing continuation.
        picks[slot] += 1;
        for next in slot + 1..k {
            picks[next] = picks[next - 1] + 1;
        }
    }
}
/// Names a fixed-effect interaction by joining its member variables with `_`.
fn create_interaction_name(variables: &[String]) -> String {
    let mut name = String::new();
    for (index, part) in variables.iter().enumerate() {
        if index > 0 {
            name.push('_');
        }
        name.push_str(part);
    }
    name
}
/// Registers a fixed-effect interaction between `left` and `right`.
///
/// Behaviour, in order:
/// 1. Collect the distinct variables mentioned on both sides, keeping
///    first-seen order.
/// 2. Register each variable as a `FixedEffect`.
/// 3. For every subset of two or more of them, register a variable named
///    by joining the subset with `_`, with the `InteractionTerm` and
///    `FixedEffect` roles.
/// 4. For each member of a subset, record an `Interaction` naming the other
///    members.
///
/// Nothing happens when neither side mentions a variable.
pub fn push_interaction(
    &mut self,
    left: &crate::internal::ast::Term,
    right: &crate::internal::ast::Term,
) {
    let mut unique_variables: Vec<String> = Vec::new();
    let discovered = Self::extract_all_variables(left)
        .into_iter()
        .chain(Self::extract_all_variables(right));
    for candidate in discovered {
        if !unique_variables.contains(&candidate) {
            unique_variables.push(candidate);
        }
    }
    if unique_variables.is_empty() {
        return;
    }

    for base in &unique_variables {
        self.ensure_variable(base);
        self.add_role(base, VariableRole::FixedEffect);
    }

    for combo in Self::generate_interaction_combinations(&unique_variables) {
        let term_name = Self::create_interaction_name(&combo);
        self.ensure_variable(&term_name);
        self.add_role(&term_name, VariableRole::InteractionTerm);
        self.add_role(&term_name, VariableRole::FixedEffect);

        let order = combo.len() as u32;
        for (position, member) in combo.iter().enumerate() {
            // Everyone in the subset except `member`, in subset order.
            let partners: Vec<String> = combo[..position]
                .iter()
                .chain(combo[position + 1..].iter())
                .cloned()
                .collect();
            self.add_interaction(
                member,
                Interaction {
                    with: partners,
                    order,
                    context: "fixed_effects".to_string(),
                    grouping_variable: None,
                },
            );
        }
    }
}
/// Registers a function-call term such as `log(x)` or `poly(x, 2)`.
///
/// `c(...)` and `factor(...)` are delegated to the categorical handler. For
/// any other function, the first bare identifier argument is registered as a
/// `FixedEffect` and receives a `Transformation` named after `fname`. Calls
/// without an identifier argument are ignored.
pub fn push_function_term(&mut self, fname: &str, args: &[Argument]) {
    match fname {
        "c" | "factor" => self.push_categorical_term_with_name(fname, args),
        _ => {
            let mut idents = args.iter().filter_map(|arg| match arg {
                Argument::Ident(ident) => Some(ident.as_str()),
                _ => None,
            });
            if let Some(base_col) = idents.next() {
                self.ensure_variable(base_col);
                self.add_role(base_col, VariableRole::FixedEffect);
                let transformation = Transformation {
                    function: fname.to_string(),
                    parameters: self.extract_function_parameters(fname, args),
                    generates_columns: self.generate_transformation_columns(fname, args),
                };
                self.add_transformation(base_col, transformation);
            }
        }
    }
}
/// Registers a categorical term written as `c(var, ...)` or `factor(var, ...)`.
///
/// The first bare identifier argument names the variable. It gains the
/// `Categorical` and `FixedEffect` roles and a transformation named after
/// `fname`. A named `ref = ...` argument is copied into the transformation
/// parameters under the key `"ref"`. Calls without an identifier argument
/// are ignored.
fn push_categorical_term_with_name(&mut self, fname: &str, args: &[Argument]) {
    let mut idents = args.iter().filter_map(|arg| match arg {
        Argument::Ident(ident) => Some(ident.as_str()),
        _ => None,
    });
    let var_name = match idents.next() {
        Some(found) => found,
        None => return,
    };

    self.ensure_variable(var_name);
    self.add_role(var_name, VariableRole::Categorical);
    self.add_role(var_name, VariableRole::FixedEffect);

    let mut parameters = self.extract_function_parameters(fname, args);
    let reference_level = args.iter().find_map(|arg| match arg {
        Argument::Named(key, value) if key == "ref" => Some(value.clone()),
        _ => None,
    });
    if let (Some(level), serde_json::Value::Object(map)) = (reference_level, &mut parameters) {
        map.insert("ref".to_string(), serde_json::Value::String(level));
    }

    let transformation = Transformation {
        function: fname.to_string(),
        parameters,
        generates_columns: self.generate_transformation_columns(fname, args),
    };
    self.add_transformation(var_name, transformation);
}
pub fn push_random_effect(&mut self, random_effect: &RandomEffect) {
self.is_random_effects_model = true;
if matches!(
random_effect.correlation,
crate::internal::ast::CorrelationType::Uncorrelated
) {
self.has_uncorrelated_slopes_and_intercepts = true;
}
let grouping_var = match &random_effect.grouping {
Grouping::Simple(group) => group.clone(),
Grouping::Gr { group, .. } => group.clone(),
Grouping::Mm { groups } => groups.join("_"),
Grouping::Interaction { left, right } => format!("{}:{}", left, right),
Grouping::Nested { outer, inner } => format!("{}/{}", outer, inner),
};
self.ensure_variable(&grouping_var);
self.add_role(&grouping_var, VariableRole::GroupingVariable);
let has_intercept = random_effect
.terms
.iter()
.any(|term| matches!(term, RandomTerm::Column(name) if name == "1"));
let correlated = !matches!(
random_effect.correlation,
crate::internal::ast::CorrelationType::Uncorrelated
);
let mut variables_in_random_effect = Vec::new();
let mut interactions_in_random_effect = Vec::new();
for term in &random_effect.terms {
match term {
RandomTerm::Column(name) => {
if name != "1" {
self.ensure_variable(name);
self.add_role(name, VariableRole::RandomEffect);
variables_in_random_effect.push(name.clone());
let random_effect_info = RandomEffectInfo {
kind: "slope".to_string(),
grouping_variable: grouping_var.clone(),
has_intercept,
correlated,
includes_interactions: Vec::new(),
variables: None,
};
self.add_random_effect(name, random_effect_info);
}
}
RandomTerm::Function {
name: func_name,
args,
} => {
let base_ident = args.iter().find_map(|a| match a {
Argument::Ident(s) => Some(s.as_str()),
_ => None,
});
if let Some(base_col) = base_ident {
self.ensure_variable(base_col);
self.add_role(base_col, VariableRole::RandomEffect);
variables_in_random_effect.push(base_col.to_string());
let parameters = self.extract_function_parameters(func_name, args);
let generates_columns =
self.generate_transformation_columns(func_name, args);
let transformation = Transformation {
function: func_name.clone(),
parameters,
generates_columns,
};
self.add_transformation(base_col, transformation);
let random_effect_info = RandomEffectInfo {
kind: "slope".to_string(),
grouping_variable: grouping_var.clone(),
has_intercept,
correlated,
includes_interactions: Vec::new(),
variables: None,
};
self.add_random_effect(base_col, random_effect_info);
}
}
RandomTerm::Interaction { left, right } => {
let left_name = match left.as_ref() {
RandomTerm::Column(name) => name.clone(),
_ => "interaction".to_string(),
};
let right_name = match right.as_ref() {
RandomTerm::Column(name) => name.clone(),
_ => "interaction".to_string(),
};
let interaction_name = format!("{}:{}", left_name, right_name);
interactions_in_random_effect.push(interaction_name.clone());
let interaction = Interaction {
with: vec![right_name.clone()],
order: 2,
context: "random_effects".to_string(),
grouping_variable: Some(grouping_var.clone()),
};
self.add_interaction(&left_name, interaction);
let interaction = Interaction {
with: vec![left_name.clone()],
order: 2,
context: "random_effects".to_string(),
grouping_variable: Some(grouping_var.clone()),
};
self.add_interaction(&right_name, interaction);
}
RandomTerm::SuppressIntercept => {
}
}
}
let grouping_random_effect = RandomEffectInfo {
kind: "grouping".to_string(),
grouping_variable: grouping_var.clone(),
has_intercept,
correlated,
includes_interactions: interactions_in_random_effect,
variables: Some(variables_in_random_effect),
};
self.add_random_effect(&grouping_var, grouping_random_effect);
}
fn extract_function_parameters(&self, fname: &str, args: &[Argument]) -> serde_json::Value {
let mut params = serde_json::Map::new();
match fname {
"poly" => {
if let Some(Argument::Integer(degree)) = args.get(1) {
params.insert(
"degree".to_string(),
serde_json::Value::Number((*degree).into()),
);
params.insert("orthogonal".to_string(), serde_json::Value::Bool(true));
}
}
"log" => {
}
"factor" => {
}
_ => {
for (i, arg) in args.iter().enumerate() {
let key = format!("arg_{}", i);
let value = match arg {
Argument::Integer(n) => serde_json::Value::Number((*n).into()),
Argument::String(s) => serde_json::Value::String(s.clone()),
Argument::Boolean(b) => serde_json::Value::Bool(*b),
Argument::Ident(s) => serde_json::Value::String(s.clone()),
Argument::Named(key, value) => {
params.insert(key.clone(), serde_json::Value::String(value.clone()));
continue; }
};
params.insert(key, value);
}
}
}
serde_json::Value::Object(params)
}
/// Names the design-matrix columns produced by `fname(args...)`.
///
/// The base name is the first identifier argument, or `"unknown"` if there is
/// none.
/// - `poly(x, d)`: `x_poly_1` through `x_poly_d`. Falls back to `x_poly` when
///   the second argument is missing, is not an integer literal, or does not
///   fit in `usize` (e.g. a negative degree).
/// - `log(x)`: `x_log`.
/// - `c(x)` / `factor(x)`: `x_categorical`.
/// - Any other `f(x)`: `x_f`.
fn generate_transformation_columns(&self, fname: &str, args: &[Argument]) -> Vec<String> {
    let base_name = args
        .iter()
        .find_map(|a| match a {
            Argument::Ident(s) => Some(s.as_str()),
            _ => None,
        })
        .unwrap_or("unknown");
    match fname {
        "poly" => {
            // Checked conversion: `as usize` would wrap a negative literal to a
            // degree near `usize::MAX`, and collecting that range would try to
            // allocate billions of column names.
            let degree = match args.get(1) {
                Some(Argument::Integer(degree)) => {
                    <usize as std::convert::TryFrom<_>>::try_from(*degree).ok()
                }
                _ => None,
            };
            match degree {
                Some(degree) => (1..=degree)
                    .map(|i| format!("{}_poly_{}", base_name, i))
                    .collect(),
                None => vec![format!("{}_poly", base_name)],
            }
        }
        "log" => vec![format!("{}_log", base_name)],
        "c" | "factor" => vec![format!("{}_categorical", base_name)],
        _ => vec![format!("{}_{}", base_name, fname)],
    }
}
pub fn build(
self,
input: &str,
has_intercept: bool,
family: Option<String>,
) -> crate::internal::data_structures::FormulaMetaData {
let mut all_generated_columns = Vec::new();
let mut sorted_vars: Vec<_> = self.columns.values().collect();
sorted_vars.sort_by_key(|v| v.id);
for var in &sorted_vars {
all_generated_columns.extend(var.generated_columns.clone());
}
if has_intercept {
all_generated_columns.insert(1, "intercept".to_string()); }
let mut all_generated_columns_formula_order = std::collections::HashMap::new();
let mut order_index = 1;
for response_var in sorted_vars.iter().filter(|v| v.id == 1) {
for response_col in &response_var.generated_columns {
all_generated_columns_formula_order
.insert(order_index.to_string(), response_col.clone());
order_index += 1;
}
}
if has_intercept {
all_generated_columns_formula_order
.insert(order_index.to_string(), "intercept".to_string());
order_index += 1;
}
for var in &sorted_vars {
if var.id != 1 {
for col in &var.generated_columns {
all_generated_columns_formula_order
.insert(order_index.to_string(), col.clone());
order_index += 1;
}
}
}
let response_variable_count = self.columns.values().filter(|v| v.id == 1).count() as u32;
crate::internal::data_structures::FormulaMetaData {
formula: input.to_string(),
metadata: FormulaMetadataInfo {
has_intercept,
is_random_effects_model: self.is_random_effects_model,
has_uncorrelated_slopes_and_intercepts: self.has_uncorrelated_slopes_and_intercepts,
family,
response_variable_count,
},
columns: self.columns,
all_generated_columns,
all_generated_columns_formula_order,
}
}
}