Skip to main content

omena_incremental/
lib.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use salsa::Setter;
4use serde::Serialize;
5
6#[cfg(test)]
7use std::sync::atomic::{AtomicUsize, Ordering};
8
/// Test-only count of how many times `read_salsa_incremental_node_digest` actually
/// executed (as opposed to Salsa returning its memoized result).
#[cfg(test)]
static SALSA_DIGEST_QUERY_RUNS: AtomicUsize = AtomicUsize::new(0);
/// Test-only count of how many times `read_salsa_incremental_node_dependency_ids`
/// actually executed; paired with `SALSA_DIGEST_QUERY_RUNS` to observe
/// field-granular reuse.
#[cfg(test)]
static SALSA_DEPENDENCY_QUERY_RUNS: AtomicUsize = AtomicUsize::new(0);
13
/// Serializable description of the `omena-incremental` engine boundary.
///
/// It covers the engine's identity, its invalidation and query models, the node
/// fields that determine identity, the dirty reasons a plan can report, and the
/// surfaces this crate exposes. Built by [`summarize_omena_incremental_boundary`]
/// and serialized with camelCase keys.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct OmenaIncrementalBoundarySummaryV0 {
    pub schema_version: &'static str,
    pub product: &'static str,
    pub engine_name: &'static str,
    pub invalidation_model: &'static str,
    pub query_model: &'static str,
    /// Node fields compared when diffing graph revisions.
    pub node_identity: Vec<&'static str>,
    /// Every reason string that can appear in `IncrementalComputationNodeV0::reasons`.
    pub dirty_reasons: Vec<&'static str>,
    pub ready_surfaces: Vec<&'static str>,
}
26
/// Capacity used by `IncrementalCancellationRegistryV0::default()`.
///
/// This is the number of cancelled request ids the registry holds before
/// recording another one clears it.
pub const DEFAULT_INCREMENTAL_CANCELLATION_LIMIT: usize = 128;
28
/// Revision stamp supplied with each graph input.
///
/// It is copied unchanged into the snapshots and computation plans derived from
/// that input. Revisions compare by `value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalRevisionV0 {
    pub value: u64,
}
34
/// One revision of the dependency graph, to be planned against a previous
/// snapshot and/or loaded into `OmenaIncrementalDatabaseV0`.
///
/// Node order is not significant: nodes are ordered by id before they are
/// diffed or stored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalGraphInputV0 {
    /// Revision copied into the resulting snapshot and plan.
    pub revision: IncrementalRevisionV0,
    /// Graph nodes, in any order.
    pub nodes: Vec<IncrementalNodeInputV0>,
}

/// A single node of a graph input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalNodeInputV0 {
    /// Identity used to match this node against the previous snapshot.
    pub id: String,
    /// Digest of the node's input. It is compared only for equality, and any
    /// difference marks the node dirty.
    pub digest: String,
    /// Ids of the nodes this node depends on.
    ///
    /// May be unsorted or contain duplicates; the ids are sorted and
    /// deduplicated before comparison.
    pub dependency_ids: Vec<String>,
}
47
/// Normalized, serializable view of a graph at one revision.
///
/// It is passed as the `previous` baseline to `plan_incremental_computation`.
/// `product` depends on the producer:
/// - `"omena-incremental.snapshot"` when built by [`snapshot_from_graph_input`];
/// - `"omena-incremental.salsa-snapshot"` when produced by
///   `OmenaIncrementalDatabaseV0::upsert_graph_input`.
///
/// In both cases `nodes` is ordered by id.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalSnapshotV0 {
    pub schema_version: &'static str,
    pub product: &'static str,
    pub revision: IncrementalRevisionV0,
    pub nodes: Vec<IncrementalSnapshotNodeV0>,
}

/// A node as recorded in an [`IncrementalSnapshotV0`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalSnapshotNodeV0 {
    pub id: String,
    pub digest: String,
    /// Dependency ids, sorted and deduplicated.
    pub dependency_ids: Vec<String>,
}
64
/// Result of diffing a graph input against a previous snapshot.
///
/// Contains every node with its dirty flag and reasons, plus aggregate counts.
/// Produced by `plan_incremental_computation`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalComputationPlanV0 {
    pub schema_version: &'static str,
    pub product: &'static str,
    /// Revision of the graph input that was planned.
    pub revision: IncrementalRevisionV0,
    /// Length of `nodes`.
    pub node_count: usize,
    /// Number of nodes with `dirty == true`.
    pub dirty_node_count: usize,
    /// Number of nodes whose reasons include `"inputDigestChanged"`.
    pub changed_input_count: usize,
    /// Number of nodes whose reasons include `"newNode"`.
    pub new_node_count: usize,
    /// Number of ids present in the previous snapshot but absent from the input.
    pub removed_node_count: usize,
    /// Number of nodes whose reasons include `"dependencyDirty"`.
    pub dependency_dirty_count: usize,
    /// Per-node results, ordered by id.
    pub nodes: Vec<IncrementalComputationNodeV0>,
}

/// Planning result for a single node.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalComputationNodeV0 {
    pub id: String,
    pub digest: String,
    /// Dependency ids, sorted and deduplicated.
    pub dependency_ids: Vec<String>,
    /// Whether the node is dirty; true exactly when `reasons` is non-empty.
    pub dirty: bool,
    /// Why the node is dirty.
    ///
    /// Entries are drawn from `"newNode"`, `"inputDigestChanged"`,
    /// `"dependencySetChanged"`, and `"dependencyDirty"`. Empty when the node
    /// is clean.
    pub reasons: Vec<&'static str>,
}
89
/// Serializable view of the ids held by an `IncrementalCancellationRegistryV0`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalCancellationSnapshotV0 {
    pub schema_version: &'static str,
    pub product: &'static str,
    /// Length of `cancelled_request_ids`.
    pub cancelled_request_count: usize,
    /// Cancelled request ids in sorted order.
    pub cancelled_request_ids: Vec<String>,
}
98
/// Output of `OmenaIncrementalDatabaseV0::plan_and_upsert_graph_input`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct IncrementalDatabaseUpdateV0 {
    pub schema_version: &'static str,
    pub product: &'static str,
    /// Plan computed against the database's snapshot from before this update.
    pub incremental_plan: IncrementalComputationPlanV0,
    /// Snapshot of the database after the input was applied.
    pub next_snapshot: IncrementalSnapshotV0,
}
107
/// Bounded set of cancelled request ids.
///
/// Ids are recorded with `cancel` and consumed with `take_cancelled`. Once the
/// set holds `limit` ids, recording a new id clears the set first. This keeps
/// ids that are never consumed from accumulating without bound.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IncrementalCancellationRegistryV0 {
    /// Maximum number of ids retained; constructors clamp it to at least 1.
    limit: usize,
    cancelled_request_ids: BTreeSet<String>,
}
113
// Salsa input entity for one graph node, created and updated by
// `OmenaIncrementalDatabaseV0`.
//
// Each field has its own setter, so changing one field does not force
// re-execution of tracked queries that read only the others. This is exercised
// by the `salsa_database_reuses_digest_query_when_only_dependencies_change` test.
// `returns(ref)` makes the field getters return references.
#[salsa::input(debug)]
pub struct SalsaIncrementalNodeInputV0 {
    // Node id; this file never calls a setter for it after creation.
    #[returns(ref)]
    id: String,
    #[returns(ref)]
    digest: String,
    #[returns(ref)]
    dependency_ids: Vec<String>,
}
123
/// Long-lived, Salsa-backed store of graph nodes keyed by node id.
///
/// Holds one [`SalsaIncrementalNodeInputV0`] per node id. It also holds the
/// snapshot that `plan_and_upsert_graph_input` uses as its `previous` baseline.
#[derive(Default)]
pub struct OmenaIncrementalDatabaseV0 {
    /// Underlying Salsa database that owns the node inputs and memoized queries.
    db: salsa::DatabaseImpl,
    /// Salsa input handle for each node id currently in the graph.
    node_inputs_by_id: BTreeMap<String, SalsaIncrementalNodeInputV0>,
    /// Baseline for the next plan; `None` until a snapshot has been recorded.
    current_snapshot: Option<IncrementalSnapshotV0>,
}
130
131pub fn summarize_omena_incremental_boundary() -> OmenaIncrementalBoundarySummaryV0 {
132    OmenaIncrementalBoundarySummaryV0 {
133        schema_version: "0",
134        product: "omena-incremental.boundary",
135        engine_name: "omena-incremental",
136        invalidation_model: "stableNodeId+inputDigest+dependencyPropagation",
137        query_model: "salsaInput+trackedQueryFieldGranularReuse",
138        node_identity: vec!["id", "digest", "dependencyIds"],
139        dirty_reasons: vec![
140            "newNode",
141            "inputDigestChanged",
142            "dependencySetChanged",
143            "dependencyDirty",
144        ],
145        ready_surfaces: vec![
146            "incrementalGraphInput",
147            "incrementalSnapshot",
148            "incrementalComputationPlan",
149            "incrementalCancellationRegistry",
150            "salsaPersistentDatabase",
151            "salsaTrackedNodeSnapshotQuery",
152            "salsaFieldGranularReuse",
153            "salsaPlanAndSnapshotUpdate",
154        ],
155    }
156}
157
158pub fn snapshot_from_graph_input(input: &IncrementalGraphInputV0) -> IncrementalSnapshotV0 {
159    IncrementalSnapshotV0 {
160        schema_version: "0",
161        product: "omena-incremental.snapshot",
162        revision: input.revision,
163        nodes: normalized_snapshot_nodes(input),
164    }
165}
166
167pub fn plan_incremental_computation(
168    input: &IncrementalGraphInputV0,
169    previous: Option<&IncrementalSnapshotV0>,
170) -> IncrementalComputationPlanV0 {
171    let normalized_nodes = normalized_snapshot_nodes(input);
172    let previous_by_id = previous
173        .map(|snapshot| {
174            snapshot
175                .nodes
176                .iter()
177                .map(|node| (node.id.as_str(), node))
178                .collect::<BTreeMap<_, _>>()
179        })
180        .unwrap_or_default();
181    let current_ids = normalized_nodes
182        .iter()
183        .map(|node| node.id.as_str())
184        .collect::<BTreeSet<_>>();
185    let removed_node_count = previous_by_id
186        .keys()
187        .filter(|id| !current_ids.contains(**id))
188        .count();
189    let mut dirty_ids = BTreeSet::<String>::new();
190    let mut nodes = normalized_nodes
191        .into_iter()
192        .map(|node| {
193            let mut reasons = Vec::new();
194            match previous_by_id.get(node.id.as_str()) {
195                None => reasons.push("newNode"),
196                Some(previous_node) => {
197                    if previous_node.digest != node.digest {
198                        reasons.push("inputDigestChanged");
199                    }
200                    if previous_node.dependency_ids != node.dependency_ids {
201                        reasons.push("dependencySetChanged");
202                    }
203                }
204            }
205            if !reasons.is_empty() {
206                dirty_ids.insert(node.id.clone());
207            }
208
209            IncrementalComputationNodeV0 {
210                id: node.id,
211                digest: node.digest,
212                dependency_ids: node.dependency_ids,
213                dirty: !reasons.is_empty(),
214                reasons,
215            }
216        })
217        .collect::<Vec<_>>();
218
219    propagate_dependency_dirty(&mut nodes, &mut dirty_ids);
220
221    IncrementalComputationPlanV0 {
222        schema_version: "0",
223        product: "omena-incremental.computation-plan",
224        revision: input.revision,
225        node_count: nodes.len(),
226        dirty_node_count: nodes.iter().filter(|node| node.dirty).count(),
227        changed_input_count: nodes
228            .iter()
229            .filter(|node| node.reasons.contains(&"inputDigestChanged"))
230            .count(),
231        new_node_count: nodes
232            .iter()
233            .filter(|node| node.reasons.contains(&"newNode"))
234            .count(),
235        removed_node_count,
236        dependency_dirty_count: nodes
237            .iter()
238            .filter(|node| node.reasons.contains(&"dependencyDirty"))
239            .count(),
240        nodes,
241    }
242}
243
244#[salsa::tracked(returns(clone))]
245pub fn summarize_salsa_incremental_node_snapshot(
246    db: &dyn salsa::Database,
247    node: SalsaIncrementalNodeInputV0,
248) -> IncrementalSnapshotNodeV0 {
249    IncrementalSnapshotNodeV0 {
250        id: node.id(db).clone(),
251        digest: node.digest(db).clone(),
252        dependency_ids: normalized_ids(node.dependency_ids(db)),
253    }
254}
255
256#[salsa::tracked(returns(clone))]
257pub fn read_salsa_incremental_node_digest(
258    db: &dyn salsa::Database,
259    node: SalsaIncrementalNodeInputV0,
260) -> String {
261    #[cfg(test)]
262    SALSA_DIGEST_QUERY_RUNS.fetch_add(1, Ordering::Relaxed);
263
264    node.digest(db).clone()
265}
266
267#[salsa::tracked(returns(clone))]
268pub fn read_salsa_incremental_node_dependency_ids(
269    db: &dyn salsa::Database,
270    node: SalsaIncrementalNodeInputV0,
271) -> Vec<String> {
272    #[cfg(test)]
273    SALSA_DEPENDENCY_QUERY_RUNS.fetch_add(1, Ordering::Relaxed);
274
275    normalized_ids(node.dependency_ids(db))
276}
277
278fn normalized_snapshot_nodes(input: &IncrementalGraphInputV0) -> Vec<IncrementalSnapshotNodeV0> {
279    let mut nodes = input
280        .nodes
281        .iter()
282        .map(|node| IncrementalSnapshotNodeV0 {
283            id: node.id.clone(),
284            digest: node.digest.clone(),
285            dependency_ids: normalized_ids(&node.dependency_ids),
286        })
287        .collect::<Vec<_>>();
288    nodes.sort_by(|left, right| left.id.cmp(&right.id));
289    nodes
290}
291
/// Returns `ids` sorted in ascending order with duplicates removed.
fn normalized_ids(ids: &[String]) -> Vec<String> {
    let mut unique = ids.to_vec();
    unique.sort_unstable();
    unique.dedup();
    unique
}
299
300fn propagate_dependency_dirty(
301    nodes: &mut [IncrementalComputationNodeV0],
302    dirty_ids: &mut BTreeSet<String>,
303) {
304    loop {
305        let mut changed = false;
306        for node in nodes.iter_mut() {
307            if node.dirty {
308                continue;
309            }
310            if node
311                .dependency_ids
312                .iter()
313                .any(|dependency_id| dirty_ids.contains(dependency_id))
314            {
315                node.dirty = true;
316                node.reasons.push("dependencyDirty");
317                dirty_ids.insert(node.id.clone());
318                changed = true;
319            }
320        }
321
322        if !changed {
323            break;
324        }
325    }
326}
327
328impl OmenaIncrementalDatabaseV0 {
329    pub fn salsa_database(&self) -> &salsa::DatabaseImpl {
330        &self.db
331    }
332
333    pub fn node_input(&self, id: &str) -> Option<SalsaIncrementalNodeInputV0> {
334        self.node_inputs_by_id.get(id).copied()
335    }
336
337    pub fn current_snapshot(&self) -> Option<&IncrementalSnapshotV0> {
338        self.current_snapshot.as_ref()
339    }
340
341    pub fn plan_and_upsert_graph_input(
342        &mut self,
343        input: &IncrementalGraphInputV0,
344    ) -> IncrementalDatabaseUpdateV0 {
345        let incremental_plan = plan_incremental_computation(input, self.current_snapshot.as_ref());
346        let next_snapshot = self.upsert_graph_input(input);
347        self.current_snapshot = Some(next_snapshot.clone());
348
349        IncrementalDatabaseUpdateV0 {
350            schema_version: "0",
351            product: "omena-incremental.salsa-database-update",
352            incremental_plan,
353            next_snapshot,
354        }
355    }
356
357    pub fn upsert_graph_input(&mut self, input: &IncrementalGraphInputV0) -> IncrementalSnapshotV0 {
358        let normalized_nodes = normalized_snapshot_nodes(input);
359        let current_ids = normalized_nodes
360            .iter()
361            .map(|node| node.id.as_str())
362            .collect::<BTreeSet<_>>();
363        self.node_inputs_by_id
364            .retain(|id, _node| current_ids.contains(id.as_str()));
365
366        for node in &normalized_nodes {
367            self.upsert_node_input(node);
368        }
369
370        let nodes = self
371            .node_inputs_by_id
372            .values()
373            .copied()
374            .map(|node| summarize_salsa_incremental_node_snapshot(&self.db, node))
375            .collect::<Vec<_>>();
376
377        IncrementalSnapshotV0 {
378            schema_version: "0",
379            product: "omena-incremental.salsa-snapshot",
380            revision: input.revision,
381            nodes,
382        }
383    }
384
385    fn upsert_node_input(&mut self, node: &IncrementalSnapshotNodeV0) {
386        let Some(node_input) = self.node_inputs_by_id.get(node.id.as_str()).copied() else {
387            let node_input = SalsaIncrementalNodeInputV0::new(
388                &self.db,
389                node.id.clone(),
390                node.digest.clone(),
391                node.dependency_ids.clone(),
392            );
393            self.node_inputs_by_id.insert(node.id.clone(), node_input);
394            return;
395        };
396
397        if node_input.digest(&self.db).as_str() != node.digest.as_str() {
398            node_input.set_digest(&mut self.db).to(node.digest.clone());
399        }
400        if node_input.dependency_ids(&self.db).as_slice() != node.dependency_ids.as_slice() {
401            node_input
402                .set_dependency_ids(&mut self.db)
403                .to(node.dependency_ids.clone());
404        }
405    }
406}
407
408impl Default for IncrementalCancellationRegistryV0 {
409    fn default() -> Self {
410        Self::with_limit(DEFAULT_INCREMENTAL_CANCELLATION_LIMIT)
411    }
412}
413
414impl IncrementalCancellationRegistryV0 {
415    pub fn with_limit(limit: usize) -> Self {
416        Self {
417            limit: limit.max(1),
418            cancelled_request_ids: BTreeSet::new(),
419        }
420    }
421
422    pub fn cancel(&mut self, request_id: impl Into<String>) {
423        if self.cancelled_request_ids.len() >= self.limit {
424            self.cancelled_request_ids.clear();
425        }
426        self.cancelled_request_ids.insert(request_id.into());
427    }
428
429    pub fn take_cancelled(&mut self, request_id: &str) -> bool {
430        self.cancelled_request_ids.remove(request_id)
431    }
432
433    pub fn len(&self) -> usize {
434        self.cancelled_request_ids.len()
435    }
436
437    pub fn is_empty(&self) -> bool {
438        self.cancelled_request_ids.is_empty()
439    }
440
441    pub fn snapshot(&self) -> IncrementalCancellationSnapshotV0 {
442        IncrementalCancellationSnapshotV0 {
443            schema_version: "0",
444            product: "omena-incremental.cancellation-registry",
445            cancelled_request_count: self.cancelled_request_ids.len(),
446            cancelled_request_ids: self.cancelled_request_ids.iter().cloned().collect(),
447        }
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::{
        IncrementalCancellationRegistryV0, IncrementalGraphInputV0, IncrementalNodeInputV0,
        IncrementalRevisionV0, OmenaIncrementalDatabaseV0, SALSA_DEPENDENCY_QUERY_RUNS,
        SALSA_DIGEST_QUERY_RUNS, plan_incremental_computation,
        read_salsa_incremental_node_dependency_ids, read_salsa_incremental_node_digest,
        snapshot_from_graph_input, summarize_omena_incremental_boundary,
    };
    use std::sync::atomic::Ordering;

    #[test]
    fn summarizes_incremental_boundary() {
        let summary = summarize_omena_incremental_boundary();

        assert_eq!(summary.product, "omena-incremental.boundary");
        assert_eq!(
            summary.query_model,
            "salsaInput+trackedQueryFieldGranularReuse"
        );
        assert!(summary.dirty_reasons.contains(&"dependencyDirty"));
        assert!(
            summary
                .ready_surfaces
                .contains(&"incrementalCancellationRegistry")
        );
        assert!(
            summary
                .ready_surfaces
                .contains(&"salsaTrackedNodeSnapshotQuery")
        );
    }

    #[test]
    fn first_plan_marks_all_nodes_dirty() {
        let input = sample_input("a:v1", "b:v1", 1);
        let plan = plan_incremental_computation(&input, None);

        assert_eq!(plan.product, "omena-incremental.computation-plan");
        assert_eq!(plan.node_count, 2);
        assert_eq!(plan.dirty_node_count, 2);
        assert_eq!(plan.new_node_count, 2);
    }

    #[test]
    fn unchanged_second_plan_marks_nodes_clean() {
        let input = sample_input("a:v1", "b:v1", 1);
        let snapshot = snapshot_from_graph_input(&input);
        let next_input = sample_input("a:v1", "b:v1", 2);
        let plan = plan_incremental_computation(&next_input, Some(&snapshot));

        assert_eq!(plan.dirty_node_count, 0);
        assert_eq!(plan.changed_input_count, 0);
    }

    #[test]
    fn changed_dependency_marks_dependent_dirty() {
        let input = sample_input("a:v1", "b:v1", 1);
        let snapshot = snapshot_from_graph_input(&input);
        let next_input = sample_input("a:v2", "b:v1", 2);
        let plan = plan_incremental_computation(&next_input, Some(&snapshot));

        assert_eq!(plan.changed_input_count, 1);
        assert_eq!(plan.dependency_dirty_count, 1);
        assert_eq!(node_reasons(&plan, "a"), vec!["inputDigestChanged"]);
        assert_eq!(node_reasons(&plan, "b"), vec!["dependencyDirty"]);
    }

    #[test]
    fn salsa_database_reuses_digest_query_when_only_dependencies_change() {
        // The counters are process-global. This is the only test that runs the
        // counted queries, so resetting them here is sufficient.
        SALSA_DIGEST_QUERY_RUNS.store(0, Ordering::Relaxed);
        SALSA_DEPENDENCY_QUERY_RUNS.store(0, Ordering::Relaxed);

        let mut db = OmenaIncrementalDatabaseV0::default();
        let input = IncrementalGraphInputV0 {
            revision: IncrementalRevisionV0 { value: 1 },
            nodes: vec![IncrementalNodeInputV0 {
                id: "a".to_string(),
                digest: "a:v1".to_string(),
                dependency_ids: Vec::new(),
            }],
        };
        let snapshot = db.upsert_graph_input(&input);
        assert_eq!(snapshot.product, "omena-incremental.salsa-snapshot");

        // Fail loudly instead of returning early: a missing node must not let
        // this test pass vacuously.
        let node = db
            .node_input("a")
            .expect("node `a` should exist after the first upsert");
        assert_eq!(
            read_salsa_incremental_node_digest(db.salsa_database(), node),
            "a:v1"
        );
        assert_eq!(
            read_salsa_incremental_node_dependency_ids(db.salsa_database(), node),
            Vec::<String>::new()
        );
        assert_eq!(SALSA_DIGEST_QUERY_RUNS.load(Ordering::Relaxed), 1);
        assert_eq!(SALSA_DEPENDENCY_QUERY_RUNS.load(Ordering::Relaxed), 1);

        let next_input = IncrementalGraphInputV0 {
            revision: IncrementalRevisionV0 { value: 2 },
            nodes: vec![IncrementalNodeInputV0 {
                id: "a".to_string(),
                digest: "a:v1".to_string(),
                dependency_ids: vec!["root".to_string()],
            }],
        };
        db.upsert_graph_input(&next_input);

        let node = db
            .node_input("a")
            .expect("node `a` should still exist after the second upsert");
        assert_eq!(
            read_salsa_incremental_node_digest(db.salsa_database(), node),
            "a:v1"
        );
        assert_eq!(
            read_salsa_incremental_node_dependency_ids(db.salsa_database(), node),
            vec!["root".to_string()]
        );
        assert_eq!(SALSA_DIGEST_QUERY_RUNS.load(Ordering::Relaxed), 1);
        assert_eq!(SALSA_DEPENDENCY_QUERY_RUNS.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn salsa_database_update_owns_plan_and_snapshot_progression() {
        let mut db = OmenaIncrementalDatabaseV0::default();
        let input = sample_input("a:v1", "b:v1", 1);
        let first = db.plan_and_upsert_graph_input(&input);

        assert_eq!(first.product, "omena-incremental.salsa-database-update");
        assert_eq!(first.incremental_plan.dirty_node_count, 2);
        assert_eq!(
            first.next_snapshot.product,
            "omena-incremental.salsa-snapshot"
        );
        assert!(db.current_snapshot().is_some());

        let unchanged = db.plan_and_upsert_graph_input(&sample_input("a:v1", "b:v1", 2));
        assert_eq!(unchanged.incremental_plan.dirty_node_count, 0);

        let changed = db.plan_and_upsert_graph_input(&sample_input("a:v2", "b:v1", 3));
        assert_eq!(changed.incremental_plan.changed_input_count, 1);
        assert_eq!(changed.incremental_plan.dependency_dirty_count, 1);
    }

    #[test]
    fn cancellation_registry_tracks_and_consumes_request_ids() {
        let mut registry = IncrementalCancellationRegistryV0::with_limit(4);

        registry.cancel("s:hover-1");

        assert_eq!(registry.len(), 1);
        assert!(registry.take_cancelled("s:hover-1"));
        assert!(!registry.take_cancelled("s:hover-1"));
        assert!(registry.is_empty());
    }

    #[test]
    fn cancellation_registry_bounds_stale_cancelled_requests() {
        let mut registry = IncrementalCancellationRegistryV0::with_limit(2);

        registry.cancel("n:1");
        registry.cancel("n:2");
        registry.cancel("n:3");

        let snapshot = registry.snapshot();
        assert_eq!(snapshot.product, "omena-incremental.cancellation-registry");
        assert_eq!(snapshot.cancelled_request_ids, vec!["n:3"]);
    }

    /// Two-node graph in which `b` depends on `a`. Nodes are listed out of id
    /// order to exercise normalization.
    fn sample_input(a_digest: &str, b_digest: &str, revision: u64) -> IncrementalGraphInputV0 {
        IncrementalGraphInputV0 {
            revision: IncrementalRevisionV0 { value: revision },
            nodes: vec![
                IncrementalNodeInputV0 {
                    id: "b".to_string(),
                    digest: b_digest.to_string(),
                    dependency_ids: vec!["a".to_string()],
                },
                IncrementalNodeInputV0 {
                    id: "a".to_string(),
                    digest: a_digest.to_string(),
                    dependency_ids: Vec::new(),
                },
            ],
        }
    }

    /// Reasons recorded for node `id` in `plan`; empty if the node is absent.
    fn node_reasons(plan: &super::IncrementalComputationPlanV0, id: &str) -> Vec<&'static str> {
        plan.nodes
            .iter()
            .find(|node| node.id == id)
            .map(|node| node.reasons.clone())
            .unwrap_or_default()
    }
}