esvc_core/
workcache.rs

1use crate::{Event, Graph, GraphError, Hash, IncludeSpec};
2use core::fmt;
3use esvc_traits::Engine;
4use rayon::prelude::*;
5use std::collections::{BTreeMap, BTreeSet};
6
7#[cfg(feature = "tracing")]
8use tracing::{event, Level};
9
10// NOTE: the elements of this *must* be public, because the user needs to be
11// able to deconstruct it if they want to modify the engine
12// (e.g. to register a new command at runtime)
13pub struct WorkCache<'a, En: Engine> {
14    pub engine: &'a En,
15    pub sts: BTreeMap<BTreeSet<Hash>, <En as Engine>::Dat>,
16}
17
18impl<'a, En: Engine> core::clone::Clone for WorkCache<'a, En> {
19    fn clone(&self) -> Self {
20        Self {
21            engine: self.engine,
22            sts: self.sts.clone(),
23        }
24    }
25
26    fn clone_from(&mut self, other: &Self) {
27        self.engine = other.engine;
28        self.sts.clone_from(&other.sts);
29    }
30}
31
32impl<En: Engine> fmt::Debug for WorkCache<'_, En> {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        f.debug_struct("WorkCache")
35            .field("sts", &self.sts)
36            .finish_non_exhaustive()
37    }
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum WorkCacheError<EE> {
42    #[error("engine couldn't find command with ID {0}")]
43    CommandNotFound(u32),
44
45    #[error(transparent)]
46    Graph(#[from] GraphError),
47
48    #[error(transparent)]
49    Engine(EE),
50}
51
52pub type RunResult<'a, En> =
53    Result<(&'a <En as Engine>::Dat, BTreeSet<Hash>), WorkCacheError<<En as Engine>::Error>>;
54
55impl<'a, En: Engine> WorkCache<'a, En> {
56    pub fn new(engine: &'a En, init_data: En::Dat) -> Self {
57        let mut sts = BTreeMap::new();
58        sts.insert(BTreeSet::new(), init_data);
59        Self { engine, sts }
60    }
61
62    /// this returns an error if `tt` is not present in `sts`.
63    pub fn run_recursively(
64        &mut self,
65        graph: &Graph<En::Arg>,
66        mut tt: BTreeSet<Hash>,
67        main_evid: Hash,
68        incl: IncludeSpec,
69    ) -> RunResult<'_, En> {
70        // heap of necessary dependencies
71        let mut deps = vec![main_evid];
72
73        let mut data: En::Dat = (*self.sts.get(&tt).ok_or(GraphError::DatasetNotFound)?).clone();
74
75        while let Some(evid) = deps.pop() {
76            if tt.contains(&evid) {
77                // nothing to do
78                continue;
79            } else if evid == main_evid && !deps.is_empty() {
80                return Err(GraphError::DependencyCircuit(main_evid).into());
81            }
82
83            let evwd = graph
84                .events
85                .get(&evid)
86                .ok_or(GraphError::DependencyNotFound(evid))?;
87            let mut necessary_deps = evwd.deps.difference(&tt);
88            if let Some(&x) = necessary_deps.next() {
89                deps.push(evid);
90                // TODO: check for dependency cycles
91                deps.push(x);
92                deps.extend(necessary_deps.copied());
93            } else {
94                if evid == main_evid && incl != IncludeSpec::IncludeAll {
95                    // we want to omit the final dep
96                    break;
97                }
98
99                // run the item, all dependencies are satisfied
100                use std::collections::btree_map::Entry;
101                // TODO: check if `data...clone()` is a bottleneck.
102                match self.sts.entry({
103                    let mut tmp = tt.clone();
104                    tmp.insert(evid);
105                    tmp
106                }) {
107                    Entry::Occupied(o) => {
108                        // reuse cached entry
109                        data = o.get().clone();
110                    }
111                    Entry::Vacant(v) => {
112                        // create cache entry
113                        data = self
114                            .engine
115                            .run_event_bare(evwd.cmd, &evwd.arg, &data)
116                            .map_err(WorkCacheError::Engine)?;
117                        v.insert(data.clone());
118                    }
119                }
120                tt.insert(evid);
121            }
122        }
123
124        let res = self.sts.get(&tt).unwrap();
125        Ok((res, tt))
126    }
127
128    pub fn run_foreach_recursively(
129        &mut self,
130        graph: &Graph<En::Arg>,
131        evids: BTreeMap<Hash, IncludeSpec>,
132    ) -> RunResult<'_, En> {
133        let tt = evids
134            .into_iter()
135            .try_fold(BTreeSet::new(), |tt, (i, idspec)| {
136                self.run_recursively(graph, tt, i, idspec)
137                    .map(|(_, new_tt)| new_tt)
138            })?;
139        let res = self.sts.get(&tt).unwrap();
140        Ok((res, tt))
141    }
142
143    /// NOTE: this ignores the contents of `ev.deps`
144    #[cfg_attr(feature = "tracing", tracing::instrument)]
145    pub fn shelve_event(
146        &mut self,
147        graph: &mut Graph<En::Arg>,
148        mut seed_deps: BTreeSet<Hash>,
149        ev: Event<En::Arg>,
150    ) -> Result<Option<Hash>, WorkCacheError<En::Error>> {
151        // check `ev` for independence
152        #[derive(Clone, Copy, PartialEq)]
153        enum DepSt {
154            Use,
155            Deny,
156        }
157        let mut cur_deps = BTreeMap::new();
158        let engine = self.engine;
159
160        while !seed_deps.is_empty() {
161            let mut new_seed_deps = BTreeSet::new();
162            // calculate cur state
163            let (base_st, _) = self.run_foreach_recursively(
164                graph,
165                seed_deps
166                    .iter()
167                    .chain(
168                        cur_deps
169                            .iter()
170                            .filter(|&(_, &s)| s == DepSt::Use)
171                            .map(|(h, _)| h),
172                    )
173                    .filter(|i| cur_deps.get(i) != Some(&DepSt::Deny))
174                    .map(|&i| (i, IncludeSpec::IncludeAll))
175                    .collect(),
176            )?;
177            let cur_st = engine
178                .run_event_bare(ev.cmd, &ev.arg, base_st)
179                .map_err(WorkCacheError::Engine)?;
180
181            #[cfg(feature = "tracing")]
182            event!(
183                Level::TRACE,
184                "constructed state {:?} +cur> {:?}",
185                base_st,
186                cur_st
187            );
188
189            if cur_deps.is_empty() && base_st == &cur_st {
190                // this is a no-op event, we can't handle it anyways.
191                return Ok(None);
192            }
193
194            for &conc_evid in &seed_deps {
195                if cur_deps.contains_key(&conc_evid) {
196                    continue;
197                }
198                // calculate base state = cur - conc
199                let (base_st, _) = self.run_foreach_recursively(
200                    graph,
201                    seed_deps
202                        .iter()
203                        .chain(
204                            cur_deps
205                                .iter()
206                                .filter(|&(_, s)| s == &DepSt::Use)
207                                .map(|(h, _)| h),
208                        )
209                        .map(|&i| {
210                            (
211                                i,
212                                if i == conc_evid {
213                                    IncludeSpec::IncludeOnlyDeps
214                                } else {
215                                    IncludeSpec::IncludeAll
216                                },
217                            )
218                        })
219                        .collect(),
220                )?;
221                let conc_ev = graph.events.get(&conc_evid).unwrap();
222                #[allow(clippy::if_same_then_else)]
223                let is_indep = if &cur_st == base_st {
224                    // this is a revert
225                    #[cfg(feature = "tracing")]
226                    event!(Level::TRACE, "{} is revert", conc_evid);
227                    false
228                } else if ev.cmd == conc_ev.cmd && ev.arg == conc_ev.arg {
229                    // necessary for non-idempotent events (e.g. s/0/0000/g)
230                    // base_st + conc = cur_st, so we detect if conc has an effect
231                    // even if it was already applied (case above)
232                    #[cfg(feature = "tracing")]
233                    event!(Level::TRACE, "{} is non-idempotent", conc_evid);
234                    false
235                } else {
236                    engine
237                        .run_event_bare(ev.cmd, &ev.arg, base_st)
238                        .and_then(|next_st| {
239                            self.engine
240                                .run_event_bare(conc_ev.cmd, &conc_ev.arg, &next_st)
241                        })
242                        .map_err(WorkCacheError::Engine)?
243                        == cur_st
244                };
245                #[cfg(feature = "tracing")]
246                event!(
247                    Level::TRACE,
248                    "{} is {}dependent",
249                    conc_evid,
250                    if is_indep { "in" } else { "" }
251                );
252                if is_indep {
253                    // independent -> move backward
254                    new_seed_deps.extend(conc_ev.deps.iter().copied());
255                } else {
256                    // not independent -> move forward
257                    // make sure that we don't overwrite `deny` entries
258                    cur_deps.entry(conc_evid).or_insert(DepSt::Use);
259                    cur_deps.extend(conc_ev.deps.iter().map(|&dep| (dep, DepSt::Deny)));
260                }
261            }
262            seed_deps = new_seed_deps;
263        }
264
265        // mangle deps
266        let ev = Event {
267            cmd: ev.cmd,
268            arg: ev.arg,
269            deps: cur_deps
270                .into_iter()
271                .flat_map(|(dep, st)| if st == DepSt::Use { Some(dep) } else { None })
272                .collect(),
273        };
274
275        // register event
276        let (collinfo, evhash) = graph.ensure_event(ev);
277        if let Some(ev) = collinfo {
278            return Err(GraphError::HashCollision(evhash, format!("{:?}", ev)).into());
279        }
280
281        Ok(Some(evhash))
282    }
283
284    pub fn check_if_mergable(
285        &mut self,
286        graph: &Graph<En::Arg>,
287        sts: BTreeSet<Hash>,
288    ) -> Result<Option<Self>, WorkCacheError<En::Error>> {
289        // we run this recursively (and non-parallel), which is a bit unfortunate,
290        // but we get the benefit that we can share the cache...
291        let bases = sts
292            .iter()
293            .map(|&h| {
294                self.run_recursively(graph, BTreeSet::new(), h, IncludeSpec::IncludeAll)
295                    .map(|r| (h, r.1))
296            })
297            .collect::<Result<BTreeMap<_, _>, _>>()?;
298
299        // calculate 2d matrix
300        let ret = bases
301            .iter()
302            .enumerate()
303            .flat_map(|(ni, (_, i))| {
304                sts.iter()
305                    .enumerate()
306                    .filter(move |(nj, _)| ni != *nj)
307                    .map(|(_, &j)| (i.clone(), j))
308            })
309            .collect::<Vec<_>>()
310            .into_par_iter()
311            // source: https://sts10.github.io/2019/06/06/is-all-equal-function.html
312            .try_fold(|| (true, None), {
313                |acc: (bool, Option<_>), (i, j)| {
314                    if !acc.0 {
315                        return Ok((false, None));
316                    }
317                    let mut this = self.clone();
318                    this.run_recursively(graph, i, j, IncludeSpec::IncludeAll)?;
319                    let elem = this.sts;
320                    Ok(if acc.1.map(|prev| prev == elem).unwrap_or(true) {
321                        (true, Some(elem))
322                    } else {
323                        (false, None)
324                    })
325                }
326            })
327            .collect::<Result<Vec<_>, WorkCacheError<_>>>()?
328            .into_iter()
329            .flat_map(|(uacc, x)| x.map(|y| (uacc, y)))
330            .fold((true, None), {
331                |acc, (uacc, elem)| {
332                    let is_mrgb = uacc && acc.0 && acc.1.map(|prev| prev == elem).unwrap_or(true);
333                    (is_mrgb, if is_mrgb { Some(elem) } else { None })
334                }
335            });
336        Ok(ret.1.map(|sts| Self {
337            engine: self.engine,
338            sts,
339        }))
340    }
341}
342
343// this is somewhat equivalent to the fuzzer code,
344// and is used to test known edge cases
345#[cfg(test)]
346mod tests {
347    use super::*;
348    #[derive(Clone, Debug, PartialEq, serde::Serialize)]
349    struct SearEvent<'a>(&'a str, &'a str);
350
351    impl<'a> From<SearEvent<'a>> for Event<SearEvent<'a>> {
352        fn from(ev: SearEvent<'a>) -> Self {
353            Event {
354                cmd: 0,
355                arg: ev,
356                deps: Default::default(),
357            }
358        }
359    }
360
361    struct SearEngine;
362
363    impl Engine for SearEngine {
364        type Error = ();
365        type Arg = SearEvent<'static>;
366        type Dat = String;
367
368        fn run_event_bare(&self, cmd: u32, arg: &SearEvent, dat: &String) -> Result<String, ()> {
369            assert_eq!(cmd, 0);
370            Ok(dat.replace(&arg.0, &arg.1))
371        }
372    }
373
374    fn assert_no_reorder_inner(start: &str, sears: Vec<SearEvent<'static>>) {
375        let expected = sears
376            .iter()
377            .fold(start.to_string(), |acc, item| acc.replace(&item.0, &item.1));
378        let e = SearEngine;
379        let mut g = Graph::default();
380        let mut w = WorkCache::new(&e, start.to_string());
381        let mut xs = BTreeSet::new();
382        for i in sears {
383            if let Some(h) = w
384                .shelve_event(&mut g, xs.clone(), i.into())
385                .expect("unable to shelve event")
386            {
387                xs.insert(h);
388            }
389        }
390
391        let minx: BTreeSet<_> = g
392            .fold_state(xs.iter().map(|&y| (y, false)).collect(), false)
393            .unwrap()
394            .into_iter()
395            .map(|x| x.0)
396            .collect();
397
398        let evs: BTreeMap<_, _> = minx
399            .iter()
400            .map(|&i| (i, crate::IncludeSpec::IncludeAll))
401            .collect();
402
403        let (got, tt) = w.run_foreach_recursively(&g, evs.clone()).unwrap();
404        assert_eq!(xs, tt);
405        assert_eq!(*got, expected);
406    }
407
408    fn assert_no_reorder(start: &str, sears: Vec<SearEvent<'static>>) {
409        #[cfg(feature = "tracing")]
410        tracing::subscriber::with_default(
411            tracing_subscriber::fmt()
412                .with_max_level(tracing::Level::TRACE)
413                .with_writer(std::io::stderr)
414                .finish(),
415            || {
416                assert_no_reorder_inner(start, sears);
417            },
418        );
419        #[cfg(not(feature = "tracing"))]
420        assert_no_reorder_inner(start, sears);
421    }
422
423    #[test]
424    fn equal_but_non_idempotent() {
425        assert_no_reorder(
426            "x",
427            vec![
428                SearEvent("x", "xx"),
429                SearEvent("x", "xx"),
430                SearEvent("x", "y"),
431            ],
432        );
433    }
434
435    #[test]
436    fn indirect_dep() {
437        assert_no_reorder(
438            "Hi, what's up??",
439            vec![
440                SearEvent("Hi", "Hello UwU"),
441                SearEvent("UwU", "World"),
442                SearEvent("what", "wow"),
443                SearEvent("s up", "sup"),
444                SearEvent("??", "!"),
445                SearEvent("sup!", "soap?"),
446                SearEvent("p", "np"),
447            ],
448        );
449    }
450
451    #[test]
452    fn revert_then() {
453        assert_no_reorder(
454            "a",
455            vec![
456                SearEvent("a", "xaa"),
457                SearEvent("xa", ""),
458                SearEvent("a", "bbbbb"),
459            ],
460        );
461    }
462}