starlight/lower/
lower_state.rs

1use std::num::NonZeroUsize;
2
3use awint::{
4    awint_dag::{smallvec::smallvec, ConcatFieldsType, ConcatType, Op::*, PState},
5    bw,
6};
7
8use crate::{
9    ensemble::Ensemble,
10    epoch::EpochShared,
11    lower::{lower_op, LowerManagement},
12    Error,
13};
14
15impl Ensemble {
16    /// Used for forbidden meta psuedo-DSL techniques in which a single state is
17    /// replaced by more basic states.
18    pub fn graft(&mut self, p_state: PState, operands: &[PState]) -> Result<(), Error> {
19        #[cfg(debug_assertions)]
20        {
21            if (self.stator.states[p_state].op.operands_len() + 1) != operands.len() {
22                return Err(Error::OtherStr(
23                    "wrong number of operands for the `graft` function",
24                ))
25            }
26            for (i, op) in self.stator.states[p_state].op.operands().iter().enumerate() {
27                let current_nzbw = self.stator.states[operands[i + 1]].nzbw;
28                let current_is_opaque = self.stator.states[operands[i + 1]].op.is_opaque();
29                if self.stator.states[op].nzbw != current_nzbw {
30                    return Err(Error::OtherString(format!(
31                        "operand {}: a bitwidth of {:?} is trying to be grafted to a bitwidth of \
32                         {:?}",
33                        i, current_nzbw, self.stator.states[op].nzbw
34                    )))
35                }
36                if !current_is_opaque {
37                    return Err(Error::OtherStr(
38                        "expected an `Opaque` for the `graft` function",
39                    ))
40                }
41            }
42            let lhs_w = self.stator.states[p_state].nzbw.get();
43            let rhs_w = self.stator.states[operands[0]].nzbw.get();
44            if lhs_w != rhs_w {
45                return Err(Error::BitwidthMismatch(lhs_w, rhs_w))
46            }
47        }
48
49        // graft input
50        for i in 1..operands.len() {
51            let grafted = operands[i];
52            let graftee = self.stator.states.get(p_state).unwrap().op.operands()[i - 1];
53            if let Some(grafted) = self.stator.states.get_mut(grafted) {
54                // change the grafted `Opaque` into a `Copy` that routes to the graftee instead
55                // of needing to change all the operands of potentially many nodes
56                grafted.op = Copy([graftee]);
57            } else {
58                // else the operand is not used because it was optimized away, this is removing
59                // a tree outside of the grafted part
60                self.state_dec_rc(graftee).unwrap();
61            }
62        }
63
64        // graft output
65        let grafted = operands[0];
66        self.stator.states.get_mut(p_state).unwrap().op = Copy([grafted]);
67        self.stator.states[grafted].inc_rc();
68
69        Ok(())
70    }
71
72    pub fn lower_op(epoch_shared: &EpochShared, p_state: PState) -> Result<bool, Error> {
73        struct Tmp<'a> {
74            ptr: PState,
75            epoch_shared: &'a EpochShared,
76        }
77        impl<'a> LowerManagement<PState> for Tmp<'a> {
78            fn graft(&mut self, operands: &[PState]) {
79                self.epoch_shared
80                    .epoch_data
81                    .borrow_mut()
82                    .ensemble
83                    .graft(self.ptr, operands)
84                    .unwrap();
85            }
86
87            fn get_nzbw(&self, p: PState) -> NonZeroUsize {
88                self.epoch_shared
89                    .epoch_data
90                    .borrow()
91                    .ensemble
92                    .stator
93                    .states
94                    .get(p)
95                    .unwrap()
96                    .nzbw
97            }
98
99            fn is_literal(&self, p: PState) -> bool {
100                self.epoch_shared
101                    .epoch_data
102                    .borrow()
103                    .ensemble
104                    .stator
105                    .states
106                    .get(p)
107                    .unwrap()
108                    .op
109                    .is_literal()
110            }
111
112            fn usize(&self, p: PState) -> usize {
113                if let Literal(ref lit) = self
114                    .epoch_shared
115                    .epoch_data
116                    .borrow()
117                    .ensemble
118                    .stator
119                    .states
120                    .get(p)
121                    .unwrap()
122                    .op
123                {
124                    if lit.bw() != 64 {
125                        panic!()
126                    }
127                    lit.to_usize()
128                } else {
129                    panic!()
130                }
131            }
132
133            fn bool(&self, p: PState) -> bool {
134                if let Literal(ref lit) = self
135                    .epoch_shared
136                    .epoch_data
137                    .borrow()
138                    .ensemble
139                    .stator
140                    .states
141                    .get(p)
142                    .unwrap()
143                    .op
144                {
145                    if lit.bw() != 1 {
146                        panic!()
147                    }
148                    lit.to_bool()
149                } else {
150                    panic!()
151                }
152            }
153
154            fn dec_rc(&mut self, p: PState) {
155                self.epoch_shared
156                    .epoch_data
157                    .borrow_mut()
158                    .ensemble
159                    .state_dec_rc(p)
160                    .unwrap()
161            }
162        }
163        let lock = epoch_shared.epoch_data.borrow();
164        let state = lock.ensemble.stator.states.get(p_state).unwrap();
165        let start_op = state.op.clone();
166        let out_w = state.nzbw;
167        drop(lock);
168        lower_op(start_op, out_w, Tmp {
169            ptr: p_state,
170            epoch_shared,
171        })
172    }
173
174    /// Lowers the rootward tree from `p_state` down to the elementary `Op`s
175    pub fn dfs_lower_states_to_elementary(
176        epoch_shared: &EpochShared,
177        p_state: PState,
178    ) -> Result<(), Error> {
179        let mut lock = epoch_shared.epoch_data.borrow_mut();
180        if let Some(state) = lock.ensemble.stator.states.get(p_state) {
181            if state.lowered_to_elementary {
182                return Ok(())
183            }
184        } else {
185            return Err(Error::InvalidPtr)
186        }
187        lock.ensemble.stator.states[p_state].lowered_to_elementary = true;
188
189        drop(lock);
190        let mut path: Vec<(usize, PState)> = vec![(0, p_state)];
191        loop {
192            let (i, p_state) = path[path.len() - 1];
193            let mut lock = epoch_shared.epoch_data.borrow_mut();
194            let state = &lock.ensemble.stator.states[p_state];
195            let ops = state.op.operands();
196            if ops.is_empty() {
197                // reached a root
198                path.pop().unwrap();
199                if path.is_empty() {
200                    break
201                }
202                path.last_mut().unwrap().0 += 1;
203            } else if i >= ops.len() {
204                // checked all sources, attempt evaluation first, this is crucial in preventing
205                // wasted work in multiple layer lowerings
206                match lock.ensemble.eval_state(p_state) {
207                    Ok(()) => {
208                        path.pop().unwrap();
209                        if path.is_empty() {
210                            break
211                        } else {
212                            continue
213                        }
214                    }
215                    // Continue on to lowering
216                    Err(Error::Unevaluatable) => (),
217                    Err(e) => {
218                        lock.ensemble.stator.states[p_state].err = Some(e.clone());
219                        return Err(e)
220                    }
221                }
222                let needs_lower = match lock.ensemble.stator.states[p_state].op {
223                    Opaque(..) | Argument(_) | Literal(_) | Assert(_) | Copy(_) | StaticGet(..)
224                    | Repeat(_) | StaticLut(..) => false,
225                    // for dynamic LUTs
226                    Mux(_) => false,
227                    Lut([lut, inx]) => {
228                        if let Literal(ref lit) = lock.ensemble.stator.states[lut].op {
229                            let lit = lit.clone();
230                            let out_w = lock.ensemble.stator.states[p_state].nzbw.get();
231                            let inx_w = lock.ensemble.stator.states[inx].nzbw.get();
232                            let no_op = if let Ok(inx_w) = u32::try_from(inx_w) {
233                                if let Some(num_entries) = 1usize.checked_shl(inx_w) {
234                                    (out_w * num_entries) != lit.bw()
235                                } else {
236                                    true
237                                }
238                            } else {
239                                true
240                            };
241                            if no_op {
242                                // TODO should I add the extra arg to `Lut` to fix this edge case?
243                                // or `Unknown` it?
244                                lock.ensemble.stator.states[p_state].op = Opaque(smallvec![], None);
245                                lock.ensemble.state_dec_rc(inx).unwrap();
246                            } else {
247                                lock.ensemble.stator.states[p_state].op =
248                                    StaticLut(ConcatType::from_iter([inx]), lit);
249                            }
250                            lock.ensemble.state_dec_rc(lut).unwrap();
251                        }
252                        // else it is a dynamic LUT that could be lowered on the
253                        // `LNode` side if needed
254                        false
255                    }
256                    Get([bits, inx]) => {
257                        if let Literal(ref lit) = lock.ensemble.stator.states[inx].op {
258                            let lit = lit.clone();
259                            let lit_u = lit.to_usize();
260                            if lit_u >= lock.ensemble.stator.states[bits].nzbw.get() {
261                                // TODO I realize now that no-op `get` specifically is fundamentally
262                                // ill-defined to some extent because it returns `Option<bool>`, it
263                                // must be asserted against, this
264                                // provides the next best thing
265
266                                // or TODO does it just cause `Unknown`?
267                                lock.ensemble.stator.states[p_state].op = Opaque(smallvec![], None);
268                                lock.ensemble.state_dec_rc(bits).unwrap();
269                            } else {
270                                lock.ensemble.stator.states[p_state].op = ConcatFields(
271                                    ConcatFieldsType::from_iter([(bits, lit_u, bw(1))]),
272                                );
273                            }
274                            lock.ensemble.state_dec_rc(inx).unwrap();
275                            false
276                        } else {
277                            true
278                        }
279                    }
280                    Set([bits, inx, bit]) => {
281                        if let Literal(ref lit) = lock.ensemble.stator.states[inx].op {
282                            let lit = lit.clone();
283                            let lit_u = lit.to_usize();
284                            let bits_w = lock.ensemble.stator.states[bits].nzbw.get();
285                            if lit_u >= bits_w {
286                                // no-op
287                                lock.ensemble.stator.states[p_state].op = Copy([bits]);
288                                lock.ensemble.state_dec_rc(bit).unwrap();
289                            } else if let Some(lo_rem) = NonZeroUsize::new(lit_u) {
290                                if let Some(hi_rem) = NonZeroUsize::new(bits_w - 1 - lit_u) {
291                                    lock.ensemble.stator.states[p_state].op =
292                                        ConcatFields(ConcatFieldsType::from_iter([
293                                            (bits, 0, lo_rem),
294                                            (bit, 0, bw(1)),
295                                            (bits, lit_u + 1, hi_rem),
296                                        ]));
297                                } else {
298                                    // setting the last bit
299                                    lock.ensemble.stator.states[p_state].op =
300                                        ConcatFields(ConcatFieldsType::from_iter([
301                                            (bits, 0, lo_rem),
302                                            (bit, 0, bw(1)),
303                                        ]));
304                                }
305                            } else if let Some(rem) = NonZeroUsize::new(bits_w - 1) {
306                                // setting the first bit
307                                lock.ensemble.stator.states[p_state].op =
308                                    ConcatFields(ConcatFieldsType::from_iter([
309                                        (bit, 0, bw(1)),
310                                        (bits, 1, rem),
311                                    ]));
312                            } else {
313                                // setting a single bit
314                                lock.ensemble.stator.states[p_state].op = Copy([bit]);
315                                lock.ensemble.state_dec_rc(bits).unwrap();
316                            }
317                            lock.ensemble.state_dec_rc(inx).unwrap();
318                            false
319                        } else {
320                            true
321                        }
322                    }
323                    _ => true,
324                };
325                drop(lock);
326                let lowering_done = if needs_lower {
327                    // this is used to be able to remove ultimately unused temporaries
328                    let mut temporary = EpochShared::shared_with(epoch_shared);
329                    temporary.set_as_current();
330                    let lowering_done = match Ensemble::lower_op(&temporary, p_state) {
331                        Ok(lowering_done) => lowering_done,
332                        Err(e) => {
333                            temporary.remove_as_current().unwrap();
334                            let mut lock = epoch_shared.epoch_data.borrow_mut();
335                            lock.ensemble.stator.states[p_state].err = Some(e.clone());
336                            return Err(e)
337                        }
338                    };
339                    // shouldn't be adding additional assertions
340                    // TODO make sure there is no meta lowering using assertions assert!(temporary.
341                    // assertions_empty());
342                    let states = temporary.take_states_added();
343                    temporary.remove_as_current().unwrap();
344                    let mut lock = epoch_shared.epoch_data.borrow_mut();
345                    for p_state in states {
346                        lock.ensemble
347                            .remove_state_if_pruning_allowed(p_state)
348                            .unwrap();
349                    }
350                    lowering_done
351                } else {
352                    true
353                };
354                if lowering_done {
355                    path.pop().unwrap();
356                    if path.is_empty() {
357                        break
358                    }
359                } else {
360                    // else do not call `path.pop`, restart the DFS here
361                    path.last_mut().unwrap().0 = 0;
362                }
363            } else {
364                let mut p_next = ops[i];
365                if lock.ensemble.stator.states[p_next].lowered_to_elementary {
366                    // do not visit
367                    path.last_mut().unwrap().0 += 1;
368                } else {
369                    while let Copy([a]) = lock.ensemble.stator.states[p_next].op {
370                        // special optimization case: forward Copies
371                        lock.ensemble.stator.states[p_state].op.operands_mut()[i] = a;
372                        lock.ensemble.stator.states[a].inc_rc();
373                        lock.ensemble.state_dec_rc(p_next).unwrap();
374                        p_next = a;
375                    }
376                    lock.ensemble.stator.states[p_next].lowered_to_elementary = true;
377                    path.push((0, p_next));
378                }
379                drop(lock);
380            }
381        }
382
383        Ok(())
384    }
385}