1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
use crate::ir::{macros::cpa, Scope, Variable, Vectorization};
use serde::{Deserialize, Serialize};

/// Writes either `lhs` or `rhs` into `out`, selected by `cond`.
///
/// When `out` is vectorized the selection is performed lane by lane; operands
/// whose vectorization is 1 are reused unchanged for every lane.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ConditionalAssign {
    /// Selector: when it holds, `lhs` is written, otherwise `rhs`.
    pub cond: Variable,
    /// Value written into `out` when `cond` holds.
    pub lhs: Variable,
    /// Value written into `out` when `cond` does not hold.
    pub rhs: Variable,
    /// Destination variable receiving the selected value.
    pub out: Variable,
}

impl ConditionalAssign {
    /// Lowers this operation into `scope` as explicit if/else IR.
    ///
    /// If `out` has a vectorization of 1, a single if/else is emitted that
    /// assigns `lhs` to `out` when `cond` holds and `rhs` otherwise.
    ///
    /// Otherwise one if/else is emitted per lane `i` of `out`. The lane's
    /// condition chooses whether lane `i` of `lhs` or of `rhs` is written into
    /// `out[i]`. Any operand whose vectorization is 1 is used as-is for every
    /// lane rather than being indexed.
    pub fn expand(self, scope: &mut Scope) {
        // Yields `var` itself when it is not vectorized; otherwise declares a
        // fresh local of `var`'s element type in `scope` and stores `var[index]`
        // into it.
        fn lane_of(scope: &mut Scope, var: Variable, index: usize) -> Variable {
            if var.item().vectorization != 1 {
                let element = scope.create_local(var.item().elem());
                cpa!(scope, element = var[index]);
                return element;
            }
            var
        }

        let Self { cond, lhs, rhs, out } = self;
        let width = out.item().vectorization;

        if width == 1 {
            cpa!(scope, if (cond).then(|scope| {
                cpa!(scope, out = lhs);
            }).else(|scope| {
                cpa!(scope, out = rhs);
            }));
            return;
        }

        for lane in 0..width {
            let lane = lane as usize;
            // The condition lane is extracted in the enclosing scope, while each
            // operand lane is only extracted inside the branch that consumes it.
            let lane_cond = lane_of(scope, cond, lane);

            cpa!(scope, if (lane_cond).then(|scope| {
                let picked = lane_of(scope, lhs, lane);
                let position: Variable = lane.into();
                cpa!(scope, out[position] = picked);
            }).else(|scope| {
                let picked = lane_of(scope, rhs, lane);
                let position: Variable = lane.into();
                cpa!(scope, out[position] = picked);
            }));
        }
    }

    /// Returns a copy of this operation in which every operand (`cond`, `lhs`,
    /// `rhs`, `out`) has been passed through `Variable::vectorize` with the
    /// given `vectorization`.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        let widen = |var: Variable| var.vectorize(vectorization);

        Self {
            cond: widen(self.cond),
            lhs: widen(self.lhs),
            rhs: widen(self.rhs),
            out: widen(self.out),
        }
    }
}