use compact_str::{CompactString, ToCompactString};
use crate::{
ast::{DefineVariable, Expression, Statement, StringPart},
decorator::{self, Decorator},
name_resolution::NameResolutionError,
prefix_parser::{AliasSpanInfo, PrefixParser, PrefixParserResult},
span::Span,
};
/// Shorthand for results produced in this file; the error type is always
/// `NameResolutionError`.
type Result<T> = std::result::Result<T, NameResolutionError>;
/// Looks up the conversion function associated with a temperature-scale
/// identifier.
///
/// Returns `Some("from_celsius")` for `°C`, `celsius` and `degree_celsius`,
/// `Some("from_fahrenheit")` for `°F`, `fahrenheit` and `degree_fahrenheit`,
/// and `None` for any other identifier. The comparison is exact
/// (case-sensitive, no trimming).
fn temperature_conversion_function(ident: &str) -> Option<&'static str> {
    // (accepted spelling, name of the conversion function) pairs.
    static CONVERSIONS: [(&str, &str); 6] = [
        ("°C", "from_celsius"),
        ("celsius", "from_celsius"),
        ("degree_celsius", "from_celsius"),
        ("°F", "from_fahrenheit"),
        ("fahrenheit", "from_fahrenheit"),
        ("degree_fahrenheit", "from_fahrenheit"),
    ];

    CONVERSIONS
        .iter()
        .find(|(spelling, _)| *spelling == ident)
        .map(|&(_, function)| function)
}
/// Rewrites parsed statements and expressions using the names registered so
/// far, and records the names that the processed statements introduce.
#[derive(Debug, Clone)]
pub(crate) struct Transformer {
/// Parser consulted for every identifier expression; unit names and other
/// identifiers are registered with it as statements are processed.
pub prefix_parser: PrefixParser,
/// Names (including decorator-provided aliases) of the variables defined so far.
pub variable_names: Vec<CompactString>,
/// Names of the functions defined so far.
pub function_names: Vec<CompactString>,
/// One entry per unit definition: the sorted list of that unit's name and aliases.
pub unit_names: Vec<Vec<CompactString>>,
/// Names of the dimensions defined so far.
pub dimension_names: Vec<CompactString>,
}
impl Transformer {
/// Creates a transformer with a fresh `PrefixParser` and no recorded
/// variable, function, unit or dimension names.
pub fn new() -> Self {
    let prefix_parser = PrefixParser::new();
    Self {
        prefix_parser,
        variable_names: Vec::new(),
        function_names: Vec::new(),
        unit_names: Vec::new(),
        dimension_names: Vec::new(),
    }
}
pub fn transform_expression(&self, expression: &mut Expression) {
match expression {
Expression::Scalar(..) | Expression::Boolean(_, _) | Expression::TypedHole(_) => {}
Expression::Identifier(span, identifier) => {
if let PrefixParserResult::UnitIdentifier(
_definition_span,
prefix,
unit_name,
full_name,
) = self.prefix_parser.parse(identifier)
{
*expression = Expression::UnitIdentifier {
span: *span,
prefix,
name: unit_name,
full_name,
};
} else {
*expression = Expression::Identifier(*span, identifier);
}
}
Expression::UnitIdentifier { .. } => {
unreachable!("Prefixed identifiers should not exist prior to this stage")
}
Expression::UnaryOperator { op, expr, span_op } => {
if *op == crate::ast::UnaryOperator::Negate
&& let Expression::BinaryOperator {
op: bin_op,
lhs: inner_lhs,
rhs: inner_rhs,
span_op: _,
} = expr.as_mut()
&& *bin_op == crate::ast::BinaryOperator::Mul
&& let Expression::Identifier(rhs_span, ident) = inner_rhs.as_ref()
&& let Some(fn_name) = temperature_conversion_function(ident)
{
self.transform_expression(inner_lhs);
let negated_arg = Expression::UnaryOperator {
op: crate::ast::UnaryOperator::Negate,
expr: Box::new((**inner_lhs).clone()),
span_op: *span_op,
};
let full_span = span_op.extend(rhs_span);
*expression = Expression::FunctionCall {
ident_span: *rhs_span,
full_span,
callable: Box::new(Expression::Identifier(*rhs_span, fn_name)),
args: vec![negated_arg],
};
return;
}
self.transform_expression(expr);
}
Expression::BinaryOperator {
op,
lhs,
rhs,
span_op: _,
} => {
self.transform_expression(lhs);
self.transform_expression(rhs);
if *op == crate::ast::BinaryOperator::Mul
&& let Expression::Identifier(rhs_span, ident) = rhs.as_ref()
&& let Some(fn_name) = temperature_conversion_function(ident)
{
let full_span = lhs.full_span().extend(rhs_span);
*expression = Expression::FunctionCall {
ident_span: *rhs_span,
full_span,
callable: Box::new(Expression::Identifier(*rhs_span, fn_name)),
args: vec![(**lhs).clone()],
};
}
}
Expression::FunctionCall { args, .. } => {
for arg in args {
self.transform_expression(arg);
}
}
Expression::Condition {
condition,
then_expr,
else_expr,
..
} => {
self.transform_expression(condition);
self.transform_expression(then_expr);
self.transform_expression(else_expr);
}
Expression::String(_, parts) => {
for p in parts {
match p {
StringPart::Fixed(_) => {}
StringPart::Interpolation { expr, .. } => self.transform_expression(expr),
}
}
}
Expression::InstantiateStruct { fields, .. } => {
for (_, _, arg) in fields {
self.transform_expression(arg);
}
}
Expression::AccessField { expr, .. } => {
self.transform_expression(expr);
}
Expression::List(_, elements) => {
for e in elements {
self.transform_expression(e);
}
}
}
}
/// Returns `true` if `decorators` contains an element equal to `decorator`.
fn has_decorator(decorators: &[Decorator], decorator: Decorator) -> bool {
    decorators.contains(&decorator)
}
/// Registers a unit under `name` and under every alias yielded by
/// `decorator::name_and_aliases_spans` for the given decorators.
///
/// Each alias is added to the prefix parser together with its own
/// prefix-acceptance setting and with flags telling whether
/// `Decorator::MetricPrefixes` / `Decorator::BinaryPrefixes` are present in
/// `decorators`. On success the sorted list of registered names is appended
/// to `self.unit_names` as one entry.
///
/// # Errors
///
/// Returns the first error reported by `PrefixParser::add_unit`. In that case
/// nothing is appended to `self.unit_names`, but aliases processed before the
/// failing one have already been passed to the prefix parser.
pub(crate) fn register_name_and_aliases(
    &mut self,
    name: &str,
    name_span: Span,
    decorators: &[Decorator],
) -> Result<()> {
    let accepts_metric = Self::has_decorator(decorators, Decorator::MetricPrefixes);
    let accepts_binary = Self::has_decorator(decorators, Decorator::BinaryPrefixes);

    let mut registered = Vec::new();
    for (alias, accepts_prefix, alias_span) in
        decorator::name_and_aliases_spans(name, name_span, decorators)
    {
        let span_info = AliasSpanInfo {
            name_span,
            alias_span,
        };
        self.prefix_parser.add_unit(
            alias,
            accepts_prefix,
            accepts_metric,
            accepts_binary,
            name,
            span_info,
        )?;
        registered.push(alias.to_compact_string());
    }

    registered.sort();
    self.unit_names.push(registered);
    Ok(())
}
/// Processes a variable definition.
///
/// The variable's name and all of its decorator aliases are appended to
/// `self.variable_names`. The primary identifier (not the aliases) is then
/// registered with the prefix parser, as a shadowing identifier when
/// `allow_shadowing` is `true` and as an ordinary identifier otherwise.
/// Finally the defining expression is rewritten in place.
///
/// # Errors
///
/// Returns the error reported by the prefix parser when the identifier
/// cannot be registered. The names have already been appended to
/// `self.variable_names` at that point, and the expression is left untouched.
fn transform_define_variable(
    &mut self,
    define_variable: &mut DefineVariable,
    allow_shadowing: bool,
) -> Result<()> {
    let DefineVariable {
        identifier_span: span,
        identifier: variable,
        expr: value,
        type_annotation: _,
        decorators: attached,
    } = define_variable;

    for (alias, _) in decorator::name_and_aliases(variable, attached) {
        self.variable_names.push(alias.to_compact_string());
    }

    // Registration happens before the value is rewritten.
    if allow_shadowing {
        self.prefix_parser
            .add_shadowing_identifier(variable, *span)?;
    } else {
        self.prefix_parser.add_other_identifier(variable, *span)?;
    }

    self.transform_expression(value);
    Ok(())
}
/// Processes a single statement: records the names it introduces and
/// rewrites the expressions it contains, in place.
///
/// * Struct definitions and module imports are left untouched.
/// * Dimension, unit, variable and function definitions record their names
///   on `self`; units, variables and functions are also registered with the
///   prefix parser.
/// * Function parameters and local variables are registered on a clone of
///   `self` that is used only to rewrite the function body and the local
///   variable expressions, so they do not end up in `self`.
///
/// # Errors
///
/// Returns the first name-resolution error reported while registering a
/// name. State recorded on `self` before the failure is kept.
fn transform_statement(&mut self, statement: &mut Statement) -> Result<()> {
    match statement {
        // Nothing to register or rewrite.
        Statement::DefineStruct { .. } | Statement::ModuleImport(_, _) => {}

        // Statements that only contain expressions.
        Statement::Expression(expr) => {
            self.transform_expression(expr);
        }
        Statement::ProcedureCall(_, _, args) => {
            for argument in args {
                self.transform_expression(argument);
            }
        }

        Statement::DefineDimension(_, name, _) => {
            self.dimension_names.push(name.to_compact_string());
        }

        Statement::DefineBaseUnit(span, name, _, decorators) => {
            self.register_name_and_aliases(name, *span, decorators)?;
        }
        Statement::DefineDerivedUnit {
            identifier_span,
            identifier,
            expr,
            decorators,
            ..
        } => {
            // The unit is registered before its defining expression is rewritten.
            self.register_name_and_aliases(identifier, *identifier_span, decorators)?;
            self.transform_expression(expr);
        }

        Statement::DefineVariable(define_variable) => {
            self.transform_define_variable(define_variable, false)?;
        }

        Statement::DefineFunction {
            function_name_span,
            function_name,
            parameters,
            body,
            local_variables,
            ..
        } => {
            self.function_names.push(function_name.to_compact_string());
            self.prefix_parser
                .add_other_identifier(function_name, *function_name_span)?;

            // Parameters and local variables are registered on a copy that is
            // dropped at the end of this arm.
            let mut scoped = self.clone();
            for (param_span, param, _) in &*parameters {
                scoped
                    .prefix_parser
                    .add_shadowing_identifier(param, *param_span)?;
            }
            // All local variables are registered (primary identifier only)
            // before the body or any local expression is rewritten.
            for local in &mut *local_variables {
                scoped
                    .variable_names
                    .push(local.identifier.to_compact_string());
                scoped
                    .prefix_parser
                    .add_shadowing_identifier(local.identifier, local.identifier_span)?;
            }

            if let Some(body_expr) = body {
                scoped.transform_expression(body_expr);
            }
            for local in local_variables {
                scoped.transform_expression(&mut local.expr);
            }
        }
    }

    Ok(())
}
/// Processes `statements` in order and returns them, rewritten, as a vector.
///
/// # Errors
///
/// Stops at the first statement that fails and returns its error; later
/// statements are not pulled from the iterator. Names registered on `self`
/// by earlier statements are kept.
pub fn transform<'a>(
    &mut self,
    statements: impl IntoIterator<Item = Statement<'a>>,
) -> Result<Vec<Statement<'a>>> {
    let mut transformed = Vec::new();
    for mut statement in statements {
        self.transform_statement(&mut statement)?;
        transformed.push(statement);
    }
    Ok(transformed)
}
}