rtlola-streamir 0.1.0

A framework for the compilation of stream-based languages through an intermediate representation
Documentation
use std::collections::HashMap;

use rtlola_frontend::mir::{self, StreamAccessKind};

use crate::ir::{
    expressions::{Constant, Expr, ExprKind, Function, Operator},
    StreamReference, Type,
};

use super::LoweringError;

pub(super) fn convert_stream_expression(
    expr: mir::Expression,
    default: Option<Expr>,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<Expr, LoweringError> {
    let mir::Expression { ty, kind } = expr;

    let lir_ty = match ty {
        mir::Type::Option(inner) => *inner,
        other => other,
    }
    .into();

    let lir_kind = match kind {
        mir::ExpressionKind::LoadConstant(c) => {
            Ok(ExprKind::Constant(translate_constant(c, &lir_ty)))
        }

        mir::ExpressionKind::ArithLog(op, exprs) => translate_operator(op, exprs, sr2sr),

        mir::ExpressionKind::Ite {
            condition,
            consequence,
            alternative,
        } => translate_ite(*condition, *consequence, *alternative, sr2sr),

        mir::ExpressionKind::StreamAccess {
            target,
            parameters,
            access_kind: StreamAccessKind::Sync,
        } => translate_sync_access(target, parameters, sr2sr),

        mir::ExpressionKind::StreamAccess {
            target,
            parameters,
            access_kind: StreamAccessKind::Offset(offset),
        } => translate_offset_access(
            target,
            parameters,
            offset,
            default.ok_or(LoweringError::DefaultRequired)?,
            sr2sr,
        ),
        mir::ExpressionKind::StreamAccess {
            target,
            parameters,
            access_kind: StreamAccessKind::Hold,
        } => translate_hold_access(
            target,
            parameters,
            default.ok_or(LoweringError::DefaultRequired)?,
            sr2sr,
        ),

        mir::ExpressionKind::Default { expr, default } => {
            let default_expr = convert_stream_expression(*default, None, sr2sr)?;
            return convert_stream_expression(*expr, Some(default_expr), sr2sr);
        }
        mir::ExpressionKind::StreamAccess {
            target,
            parameters,
            access_kind:
                StreamAccessKind::SlidingWindow(window)
                | StreamAccessKind::DiscreteWindow(window)
                | StreamAccessKind::InstanceAggregation(window),
        } => translate_window(target, parameters, window, default, sr2sr),

        mir::ExpressionKind::StreamAccess {
            target,
            parameters,
            access_kind: StreamAccessKind::Fresh,
        } => translate_is_fresh(target, parameters, sr2sr),

        mir::ExpressionKind::StreamAccess {
            target,
            parameters,
            access_kind: StreamAccessKind::Get,
        } => translate_get(
            target,
            parameters,
            default.ok_or(LoweringError::DefaultRequired)?,
            sr2sr,
        ),

        mir::ExpressionKind::Convert { expr: inner_exp } => {
            translate_convert(*inner_exp, &lir_ty, sr2sr)
        }

        mir::ExpressionKind::Function(func, inner_exp) if func == "cast" => {
            let inner_exp = inner_exp.into_iter().next().unwrap();
            translate_convert(inner_exp, &lir_ty, sr2sr)
        }

        mir::ExpressionKind::ParameterAccess(target, parameter) => {
            translate_parameter_access(target, parameter, sr2sr)
        }

        mir::ExpressionKind::Function(func, inner_exps) => {
            let f = match func.as_str() {
                "sqrt" => Function::Sqrt,
                "sin" => Function::Sin,
                "arcsin" => Function::Arcsin,
                "cos" => Function::Cos,
                "arccos" => Function::Arccos,
                "tan" => Function::Tan,
                "arctan" => Function::Arctan,
                "abs" => Function::Abs,
                "min" => Function::Min,
                "max" => Function::Max,
                name => return Err(LoweringError::UnsupportedFunction(name.into())),
            };
            let args = inner_exps
                .into_iter()
                .map(|expr| convert_stream_expression(expr, None, sr2sr))
                .collect::<Result<Vec<_>, LoweringError>>()?;
            Ok(ExprKind::FunctionCall(f, args))
        }

        mir::ExpressionKind::Tuple(inner_exps) => {
            let inner_exps = inner_exps
                .into_iter()
                .map(|expr| convert_stream_expression(expr, None, sr2sr))
                .collect::<Result<Vec<_>, LoweringError>>()?;
            Ok(ExprKind::Tuple(inner_exps))
        }
        mir::ExpressionKind::TupleAccess(tuple_expr, i) => {
            let tuple_expr = convert_stream_expression(*tuple_expr, None, sr2sr)?;
            Ok(ExprKind::TupleAccess(Box::new(tuple_expr), i))
        }
        mir::ExpressionKind::LambdaParameterAccess { wref, pref } => {
            Ok(ExprKind::LambdaParameterAccess(wref.into(), pref))
        }
    }?;

    Ok(Expr {
        ty: lir_ty,
        kind: lir_kind,
    })
}

fn translate_constant(c: mir::Constant, ty: &Type) -> Constant {
    match (c, ty) {
        (mir::Constant::Str(s), _) => Constant::Str(s),
        (mir::Constant::Bool(b), _) => Constant::Bool(b),
        (mir::Constant::UInt(u), Type::UInt(b)) => Constant::UInt(u, *b),
        (mir::Constant::Int(i), Type::Int(b)) => Constant::Int(i, *b),
        (mir::Constant::Float(i), Type::Float32) => Constant::Float32(i),
        (mir::Constant::Float(i), Type::Float64) => Constant::Float64(i),
        (c, t) => unreachable!("ensured by type checker: {c}: {t:?}"),
    }
}

fn translate_operator(
    op: mir::ArithLogOp,
    operands: Vec<mir::Expression>,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    let operator = match op {
        mir::ArithLogOp::Not => Operator::Not,
        mir::ArithLogOp::Neg => Operator::Neg,
        mir::ArithLogOp::Add => Operator::Add,
        mir::ArithLogOp::Sub => Operator::Sub,
        mir::ArithLogOp::Mul => Operator::Mul,
        mir::ArithLogOp::Div => Operator::Div,
        mir::ArithLogOp::Rem => Operator::Rem,
        mir::ArithLogOp::Pow => Operator::Pow,
        mir::ArithLogOp::And => Operator::And,
        mir::ArithLogOp::Or => Operator::Or,
        mir::ArithLogOp::BitXor => Operator::BitXor,
        mir::ArithLogOp::BitAnd => Operator::BitAnd,
        mir::ArithLogOp::BitOr => Operator::BitOr,
        mir::ArithLogOp::BitNot => Operator::BitNot,
        mir::ArithLogOp::Shl => Operator::Shl,
        mir::ArithLogOp::Shr => Operator::Shr,
        mir::ArithLogOp::Eq => Operator::Eq,
        mir::ArithLogOp::Lt => Operator::Lt,
        mir::ArithLogOp::Le => Operator::Le,
        mir::ArithLogOp::Ne => Operator::Ne,
        mir::ArithLogOp::Ge => Operator::Ge,
        mir::ArithLogOp::Gt => Operator::Gt,
    };

    Ok(match operands.len() {
        1 => {
            let operand = operands.into_iter().next().unwrap();
            ExprKind::UnaryOperation(
                operator,
                Box::new(convert_stream_expression(operand, None, sr2sr)?),
            )
        }
        2 => {
            let (lhs, rhs) = {
                let mut op_iter = operands.into_iter();
                (op_iter.next().unwrap(), op_iter.next().unwrap())
            };
            ExprKind::BinaryOperation(
                operator,
                Box::new(convert_stream_expression(lhs, None, sr2sr)?),
                Box::new(convert_stream_expression(rhs, None, sr2sr)?),
            )
        }
        _ => unreachable!(),
    })
}

fn translate_ite(
    condition: mir::Expression,
    consequence: mir::Expression,
    alternative: mir::Expression,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::Ite(
        Box::new(convert_stream_expression(condition, None, sr2sr)?),
        Box::new(convert_stream_expression(consequence, None, sr2sr)?),
        Box::new(convert_stream_expression(alternative, None, sr2sr)?),
    ))
}

fn translate_sync_access(
    target: mir::StreamReference,
    parameters: Vec<mir::Expression>,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::SyncStreamAccess {
        target: sr2sr[&target],
        parameters: parameters
            .into_iter()
            .map(|expr| convert_stream_expression(expr, None, sr2sr))
            .collect::<Result<Vec<_>, LoweringError>>()?,
    })
}

fn translate_offset_access(
    target: mir::StreamReference,
    parameters: Vec<mir::Expression>,
    offset: mir::Offset,
    default: Expr,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    let offset = match offset {
        mir::Offset::Future(_) => return Err(LoweringError::FutureAccess),
        mir::Offset::Past(i) => i,
    };

    Ok(ExprKind::OffsetStreamAccess {
        target: sr2sr[&target],
        offset,
        default: Box::new(default),
        parameters: parameters
            .into_iter()
            .map(|parameter| convert_stream_expression(parameter, None, sr2sr))
            .collect::<Result<Vec<_>, LoweringError>>()?,
    })
}

fn translate_hold_access(
    target: mir::StreamReference,
    parameters: Vec<mir::Expression>,
    default: Expr,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::HoldStreamAccess {
        target: sr2sr[&target],
        default: Box::new(default),
        parameters: parameters
            .into_iter()
            .map(|parameter| convert_stream_expression(parameter, None, sr2sr))
            .collect::<Result<Vec<_>, LoweringError>>()?,
    })
}

fn translate_window(
    target: mir::StreamReference,
    parameters: Vec<mir::Expression>,
    window: mir::WindowReference,
    default: Option<Expr>,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    let default = default.map(Box::new);

    Ok(ExprKind::WindowAccess {
        target: sr2sr[&target],
        window: window.into(),
        parameters: parameters
            .into_iter()
            .map(|x| convert_stream_expression(x, None, sr2sr))
            .collect::<Result<Vec<_>, LoweringError>>()?,
        default,
    })
}

fn translate_convert(
    expr: mir::Expression,
    convert_to: &Type,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::Cast(
        convert_to.clone(),
        Box::new(convert_stream_expression(expr, None, sr2sr)?),
    ))
}

fn translate_parameter_access(
    target: mir::StreamReference,
    parameter: usize,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::ParameterAccess(sr2sr[&target], parameter))
}

fn translate_is_fresh(
    target: mir::StreamReference,
    parameters: Vec<mir::Expression>,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::IsFresh {
        target: sr2sr[&target],
        parameters: parameters
            .into_iter()
            .map(|x| convert_stream_expression(x, None, sr2sr))
            .collect::<Result<Vec<_>, LoweringError>>()?,
    })
}

fn translate_get(
    target: mir::StreamReference,
    parameters: Vec<mir::Expression>,
    default: Expr,
    sr2sr: &HashMap<mir::StreamReference, StreamReference>,
) -> Result<ExprKind, LoweringError> {
    Ok(ExprKind::GetAccess {
        target: sr2sr[&target],
        default: Box::new(default),
        parameters: parameters
            .into_iter()
            .map(|x| convert_stream_expression(x, None, sr2sr))
            .collect::<Result<Vec<_>, LoweringError>>()?,
    })
}