Skip to main content

vyre_driver/
graph_capture.rs

1//! Backend-neutral planning for replayable graph-capture dispatch paths.
2//!
3//! CUDA graphs, WGPU command replay, and future persistent-dispatch recorders
4//! all need the same first step: walk a [`BindingPlan`] once, classify which
5//! runtime buffers require stable input storage, which require output readback
6//! storage, and how many kernel pointer arguments are needed in lowered binding
7//! order. This module owns that logic so backend crates do not fork planner
8//! invariants while adding API-specific capture and replay code.
9
10use crate::binding::{BindingPlan, BindingRole};
11use crate::transfer_accounting::TransferAccountingPolicy;
12use crate::BackendError;
13
14const GRAPH_CAPTURE_BINDING_ACCOUNTING: TransferAccountingPolicy =
15    TransferAccountingPolicy::new("graph capture binding plan", "record a smaller graph shape");
16
/// Schema version stamped into every scan graph-capture edit classification;
/// records carrying any other version are not considered complete.
pub const SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION: u32 = 1_u32;
19
/// Table sizes and replay-safety verdict needed to record one replayable
/// dispatch graph, as computed by [`plan_graph_capture_bindings`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GraphCaptureBindingPlan {
    /// Number of device/storage slots for runtime inputs. Input-output
    /// bindings land here too: their input allocation doubles as the buffer
    /// that is later read back as output.
    pub input_device_capacity: usize,
    /// Number of device/storage slots for bindings without an input index.
    /// Kept apart from [`Self::output_readback_capacity`] because an
    /// input-output binding needs readback metadata but no second device
    /// pointer.
    pub output_device_capacity: usize,
    /// Number of host/readback slots, one per binding that has an output view.
    pub output_readback_capacity: usize,
    /// Number of pointer arguments handed to the captured kernel, in binding
    /// order (shared bindings excluded).
    pub kernel_pointer_capacity: usize,
    /// `kernel_pointer_capacity` plus one trailing launch-parameter pointer.
    pub kernel_argument_capacity: usize,
    /// Whether a backend may replay a steady-state graph without uploads once
    /// the device inputs have been initialized. False as soon as any binding
    /// is both an input and an output.
    pub resident_input_replay_safe: bool,
}
42
/// Kind of scan workload edit that may invalidate a captured dispatch graph.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScanGraphCaptureEditKind {
    /// The resident pattern database was uploaded or replaced.
    PatternDatabaseUpload,
    /// The haystack bytes differ from the previous graph dispatch.
    HaystackBufferChange,
    /// The output slab changed size or layout.
    OutputSlabResize,
    /// The verifier program or its metadata changed.
    VerifierChange,
}

impl ScanGraphCaptureEditKind {
    /// Snake-case identifier for this edit kind, used verbatim in evidence.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            ScanGraphCaptureEditKind::PatternDatabaseUpload => "pattern_database_upload",
            ScanGraphCaptureEditKind::HaystackBufferChange => "haystack_buffer_change",
            ScanGraphCaptureEditKind::OutputSlabResize => "output_slab_resize",
            ScanGraphCaptureEditKind::VerifierChange => "verifier_change",
        }
    }
}
68
/// What a backend should do with its captured graph after a scan edit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphCaptureEditAction {
    /// Run the captured graph again as-is; no parameter update is needed.
    Replay,
    /// Patch graph parameters or copied input bytes, keeping the recorded topology.
    Update,
    /// Record the graph again because topology, pointer shape, or code changed.
    Recapture,
}

impl GraphCaptureEditAction {
    /// Lowercase identifier for this action, used verbatim in evidence.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            GraphCaptureEditAction::Replay => "replay",
            GraphCaptureEditAction::Update => "update",
            GraphCaptureEditAction::Recapture => "recapture",
        }
    }
}
91
/// Whether a captured graph's topology survives one scan edit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphCaptureEditStability {
    /// Topology and pointer-table shape of the captured graph are still valid.
    GraphStable,
    /// Topology or pointer-table shape of the captured graph is invalidated.
    GraphBreaking,
}

impl GraphCaptureEditStability {
    /// Snake-case identifier for this stability class, used verbatim in evidence.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            GraphCaptureEditStability::GraphStable => "graph_stable",
            GraphCaptureEditStability::GraphBreaking => "graph_breaking",
        }
    }
}
111
112/// Input facts for classifying one scan graph-capture edit.
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
114pub struct ScanGraphCaptureEdit {
115    /// Edit kind to classify.
116    pub kind: ScanGraphCaptureEditKind,
117    /// Previous resident artifact or buffer byte length.
118    pub previous_byte_len: u64,
119    /// New resident artifact or buffer byte length.
120    pub next_byte_len: u64,
121    /// Previous content, table, or verifier digest.
122    pub previous_digest: u64,
123    /// New content, table, or verifier digest.
124    pub next_digest: u64,
125}
126
127impl ScanGraphCaptureEdit {
128    /// Construct scan graph-capture edit facts.
129    #[must_use]
130    pub const fn new(
131        kind: ScanGraphCaptureEditKind,
132        previous_byte_len: u64,
133        next_byte_len: u64,
134        previous_digest: u64,
135        next_digest: u64,
136    ) -> Self {
137        Self {
138            kind,
139            previous_byte_len,
140            next_byte_len,
141            previous_digest,
142            next_digest,
143        }
144    }
145
146    const fn shape_unchanged(self) -> bool {
147        self.previous_byte_len == self.next_byte_len
148    }
149
150    const fn digest_unchanged(self) -> bool {
151        self.previous_digest == self.next_digest
152    }
153}
154
155/// Evidence emitted by scan graph-capture edit classification.
156#[derive(Debug, Clone, Copy, PartialEq, Eq)]
157pub struct ScanGraphCaptureEditClassification {
158    /// Evidence schema version.
159    pub schema_version: u32,
160    /// Edit kind that was classified.
161    pub edit_kind: ScanGraphCaptureEditKind,
162    /// Replay, update, or recapture action.
163    pub action: GraphCaptureEditAction,
164    /// Whether graph topology and pointer shape remain stable.
165    pub stability: GraphCaptureEditStability,
166    /// Exact reason code for tests, logs, and release evidence.
167    pub reason: &'static str,
168    /// True when the graph can be replayed without re-recording.
169    pub graph_stable: bool,
170    /// True when the edit invalidates the captured graph.
171    pub graph_breaking: bool,
172    /// True when content changed but graph shape did not.
173    pub parameter_update_required: bool,
174}
175
176impl ScanGraphCaptureEditClassification {
177    /// Return true when this evidence has a valid schema, exact reason, and a
178    /// self-consistent action/stability pair.
179    #[must_use]
180    pub const fn is_complete(self) -> bool {
181        self.schema_version == SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION
182            && !self.reason.is_empty()
183            && self.graph_stable == matches!(self.stability, GraphCaptureEditStability::GraphStable)
184            && self.graph_breaking
185                == matches!(self.stability, GraphCaptureEditStability::GraphBreaking)
186            && self.parameter_update_required
187                == matches!(self.action, GraphCaptureEditAction::Update)
188    }
189}
190
191/// Build a backend-neutral capture plan from a lowered binding plan.
192///
193/// # Errors
194///
195/// Returns [`BackendError::InvalidProgram`] if capacity arithmetic would
196/// overflow on the host.
197pub fn plan_graph_capture_bindings(
198    bindings: &BindingPlan,
199) -> Result<GraphCaptureBindingPlan, BackendError> {
200    let mut input_device_capacity = 0usize;
201    let mut output_device_capacity = 0usize;
202    let mut output_readback_capacity = 0usize;
203    let mut kernel_pointer_capacity = 0usize;
204    let mut resident_input_replay_safe = true;
205
206    for binding in &bindings.bindings {
207        if binding.role == BindingRole::Shared {
208            continue;
209        }
210
211        kernel_pointer_capacity =
212            graph_capture_capacity_add(kernel_pointer_capacity, 1, "kernel pointer table")?;
213
214        if binding.input_index.is_some() {
215            input_device_capacity =
216                graph_capture_capacity_add(input_device_capacity, 1, "input device table")?;
217        } else {
218            output_device_capacity =
219                graph_capture_capacity_add(output_device_capacity, 1, "output device table")?;
220        }
221
222        if binding.output_index.is_some() {
223            output_readback_capacity =
224                graph_capture_capacity_add(output_readback_capacity, 1, "output readback table")?;
225        }
226
227        if binding.input_index.is_some() && binding.output_index.is_some() {
228            resident_input_replay_safe = false;
229        }
230    }
231
232    let kernel_argument_capacity =
233        graph_capture_capacity_add(kernel_pointer_capacity, 1, "kernel argument table")?;
234
235    Ok(GraphCaptureBindingPlan {
236        input_device_capacity,
237        output_device_capacity,
238        output_readback_capacity,
239        kernel_pointer_capacity,
240        kernel_argument_capacity,
241        resident_input_replay_safe,
242    })
243}
244
245/// Classify one scan workload edit for replayable graph capture.
246///
247/// Pattern database and verifier changes are graph-breaking because they alter
248/// resident scan code/data semantics. Haystack content changes with identical
249/// byte length are graph-stable parameter updates. Output slab resizing is
250/// graph-breaking because readback and pointer-shape assumptions change.
251#[must_use]
252pub const fn classify_scan_graph_capture_edit(
253    edit: ScanGraphCaptureEdit,
254) -> ScanGraphCaptureEditClassification {
255    match edit.kind {
256        ScanGraphCaptureEditKind::PatternDatabaseUpload => {
257            if edit.shape_unchanged() && edit.digest_unchanged() {
258                scan_graph_capture_classification(
259                    edit.kind,
260                    GraphCaptureEditAction::Replay,
261                    GraphCaptureEditStability::GraphStable,
262                    "pattern_database_unchanged",
263                )
264            } else {
265                scan_graph_capture_classification(
266                    edit.kind,
267                    GraphCaptureEditAction::Recapture,
268                    GraphCaptureEditStability::GraphBreaking,
269                    "pattern_database_changed",
270                )
271            }
272        }
273        ScanGraphCaptureEditKind::HaystackBufferChange => {
274            if edit.shape_unchanged() {
275                if edit.digest_unchanged() {
276                    scan_graph_capture_classification(
277                        edit.kind,
278                        GraphCaptureEditAction::Replay,
279                        GraphCaptureEditStability::GraphStable,
280                        "haystack_unchanged",
281                    )
282                } else {
283                    scan_graph_capture_classification(
284                        edit.kind,
285                        GraphCaptureEditAction::Update,
286                        GraphCaptureEditStability::GraphStable,
287                        "haystack_contents_changed_same_shape",
288                    )
289                }
290            } else {
291                scan_graph_capture_classification(
292                    edit.kind,
293                    GraphCaptureEditAction::Recapture,
294                    GraphCaptureEditStability::GraphBreaking,
295                    "haystack_shape_changed",
296                )
297            }
298        }
299        ScanGraphCaptureEditKind::OutputSlabResize => {
300            if edit.shape_unchanged() {
301                scan_graph_capture_classification(
302                    edit.kind,
303                    GraphCaptureEditAction::Replay,
304                    GraphCaptureEditStability::GraphStable,
305                    "output_slab_unchanged",
306                )
307            } else {
308                scan_graph_capture_classification(
309                    edit.kind,
310                    GraphCaptureEditAction::Recapture,
311                    GraphCaptureEditStability::GraphBreaking,
312                    "output_slab_size_changed",
313                )
314            }
315        }
316        ScanGraphCaptureEditKind::VerifierChange => {
317            if edit.shape_unchanged() && edit.digest_unchanged() {
318                scan_graph_capture_classification(
319                    edit.kind,
320                    GraphCaptureEditAction::Replay,
321                    GraphCaptureEditStability::GraphStable,
322                    "verifier_unchanged",
323                )
324            } else {
325                scan_graph_capture_classification(
326                    edit.kind,
327                    GraphCaptureEditAction::Recapture,
328                    GraphCaptureEditStability::GraphBreaking,
329                    "verifier_changed",
330                )
331            }
332        }
333    }
334}
335
336const fn scan_graph_capture_classification(
337    edit_kind: ScanGraphCaptureEditKind,
338    action: GraphCaptureEditAction,
339    stability: GraphCaptureEditStability,
340    reason: &'static str,
341) -> ScanGraphCaptureEditClassification {
342    ScanGraphCaptureEditClassification {
343        schema_version: SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION,
344        edit_kind,
345        action,
346        stability,
347        reason,
348        graph_stable: matches!(stability, GraphCaptureEditStability::GraphStable),
349        graph_breaking: matches!(stability, GraphCaptureEditStability::GraphBreaking),
350        parameter_update_required: matches!(action, GraphCaptureEditAction::Update),
351    }
352}
353
354fn graph_capture_capacity_add(lhs: usize, rhs: usize, label: &str) -> Result<usize, BackendError> {
355    GRAPH_CAPTURE_BINDING_ACCOUNTING.add_usize_capacity(lhs, rhs, label)
356}
357
#[cfg(test)]
mod tests {
    use super::{
        classify_scan_graph_capture_edit, graph_capture_capacity_add, plan_graph_capture_bindings,
        GraphCaptureBindingPlan, GraphCaptureEditAction, GraphCaptureEditStability,
        ScanGraphCaptureEdit, ScanGraphCaptureEditKind, SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION,
    };
    use crate::binding::{Binding, BindingPlan, BindingRole};
    use std::sync::Arc;

    /// Build a 16-element, 4-byte-per-element test binding whose buffer index
    /// mirrors its slot.
    fn binding(
        name: &'static str,
        slot: u32,
        role: BindingRole,
        input_index: Option<usize>,
        output_index: Option<usize>,
    ) -> Binding {
        Binding {
            name: Arc::from(name),
            binding: slot,
            buffer_index: slot as usize,
            role,
            element_size: 4,
            preferred_alignment: 4,
            element_count: 16,
            static_byte_len: Some(64),
            input_index,
            output_index,
        }
    }

    /// Wrap `bindings` in a plan. The capture planner only walks
    /// `BindingPlan::bindings`, so the index side tables are left empty.
    fn plan(bindings: Vec<Binding>) -> BindingPlan {
        BindingPlan {
            bindings,
            input_indices: vec![],
            output_indices: vec![],
            shared_indices: vec![],
        }
    }

    #[test]
    fn graph_capture_binding_plan_counts_distinct_device_and_readback_tables() {
        let bindings = plan(vec![
            binding("input", 0, BindingRole::Input, Some(0), None),
            binding("shared", 1, BindingRole::Shared, None, None),
            binding("output", 2, BindingRole::Output, None, Some(0)),
            binding("state", 3, BindingRole::InputOutput, Some(1), Some(1)),
        ]);

        assert_eq!(
            plan_graph_capture_bindings(&bindings)
                .expect("Fix: graph capture planning should accept normal bindings"),
            GraphCaptureBindingPlan {
                input_device_capacity: 2,
                output_device_capacity: 1,
                output_readback_capacity: 2,
                kernel_pointer_capacity: 3,
                kernel_argument_capacity: 4,
                resident_input_replay_safe: false,
            }
        );
    }

    #[test]
    fn generated_graph_capture_binding_plan_preserves_order_independent_counts() {
        let mut state = 0x9e37_79b9_7f4a_7c15_u64;
        for case_index in 0..768usize {
            let binding_count = 1 + (next_u64(&mut state) as usize % 96);
            let mut bindings = Vec::with_capacity(binding_count);
            let mut expected_input_device_capacity = 0usize;
            let mut expected_output_device_capacity = 0usize;
            let mut expected_output_readback_capacity = 0usize;
            let mut expected_kernel_pointer_capacity = 0usize;
            let mut expected_safe = true;
            let mut next_input = 0usize;
            let mut next_output = 0usize;

            for slot in 0..binding_count {
                let role_selector = (next_u64(&mut state) % 4) as u8;
                let (role, input_index, output_index) = match role_selector {
                    0 => {
                        let index = next_input;
                        next_input += 1;
                        (BindingRole::Input, Some(index), None)
                    }
                    1 => {
                        let index = next_output;
                        next_output += 1;
                        (BindingRole::Output, None, Some(index))
                    }
                    2 => {
                        let input = next_input;
                        let output = next_output;
                        next_input += 1;
                        next_output += 1;
                        expected_safe = false;
                        (BindingRole::InputOutput, Some(input), Some(output))
                    }
                    _ => (BindingRole::Shared, None, None),
                };

                if role != BindingRole::Shared {
                    expected_kernel_pointer_capacity += 1;
                    if input_index.is_some() {
                        expected_input_device_capacity += 1;
                    } else {
                        expected_output_device_capacity += 1;
                    }
                    if output_index.is_some() {
                        expected_output_readback_capacity += 1;
                    }
                }

                bindings.push(binding(
                    "generated",
                    slot as u32,
                    role,
                    input_index,
                    output_index,
                ));
            }

            let planned = plan_graph_capture_bindings(&plan(bindings))
                .expect("Fix: generated graph capture plan should fit host capacities");
            assert_eq!(
                planned,
                GraphCaptureBindingPlan {
                    input_device_capacity: expected_input_device_capacity,
                    output_device_capacity: expected_output_device_capacity,
                    output_readback_capacity: expected_output_readback_capacity,
                    kernel_pointer_capacity: expected_kernel_pointer_capacity,
                    kernel_argument_capacity: expected_kernel_pointer_capacity + 1,
                    resident_input_replay_safe: expected_safe,
                },
                "case {case_index}"
            );
        }
    }

    #[test]
    fn graph_capture_capacity_overflow_fails_loudly() {
        let error = graph_capture_capacity_add(usize::MAX, 1, "kernel argument table")
            .expect_err("Fix: graph capture capacity overflow must not wrap");
        let message = error.to_string();
        assert!(message.contains("graph capture binding plan"));
        assert!(message.contains("kernel argument table"));
        assert!(message.contains("record a smaller graph shape"));
    }

    /// Exercises every branch of the classifier, including the unchanged
    /// (replay) branch of each edit kind.
    #[test]
    fn scan_graph_capture_classifies_replay_update_and_recapture_reasons() {
        let cases = [
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::PatternDatabaseUpload,
                    4096,
                    4096,
                    11,
                    11,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "pattern_database_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::PatternDatabaseUpload,
                    4096,
                    4096,
                    11,
                    12,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "pattern_database_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::PatternDatabaseUpload,
                    4096,
                    8192,
                    11,
                    11,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "pattern_database_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::HaystackBufferChange,
                    8192,
                    8192,
                    21,
                    21,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "haystack_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::HaystackBufferChange,
                    8192,
                    8192,
                    21,
                    22,
                ),
                GraphCaptureEditAction::Update,
                GraphCaptureEditStability::GraphStable,
                "haystack_contents_changed_same_shape",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::HaystackBufferChange,
                    8192,
                    16_384,
                    21,
                    22,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "haystack_shape_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::OutputSlabResize,
                    1024,
                    1024,
                    31,
                    31,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "output_slab_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::OutputSlabResize,
                    1024,
                    2048,
                    31,
                    31,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "output_slab_size_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::VerifierChange,
                    512,
                    512,
                    41,
                    41,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "verifier_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::VerifierChange,
                    512,
                    512,
                    41,
                    42,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "verifier_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::VerifierChange,
                    512,
                    1024,
                    41,
                    41,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "verifier_changed",
            ),
        ];

        for (edit, action, stability, reason) in cases {
            let classified = classify_scan_graph_capture_edit(edit);
            assert!(classified.is_complete(), "{reason}");
            assert_eq!(classified.schema_version, SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION);
            assert_eq!(classified.edit_kind, edit.kind);
            assert_eq!(classified.action, action);
            assert_eq!(classified.stability, stability);
            assert_eq!(classified.reason, reason);
        }
    }

    #[test]
    fn scan_graph_capture_same_shape_haystack_update_is_not_a_hidden_recapture() {
        let classified = classify_scan_graph_capture_edit(ScanGraphCaptureEdit::new(
            ScanGraphCaptureEditKind::HaystackBufferChange,
            65_536,
            65_536,
            100,
            101,
        ));

        assert_eq!(classified.action, GraphCaptureEditAction::Update);
        assert!(classified.graph_stable);
        assert!(!classified.graph_breaking);
        assert!(classified.parameter_update_required);
        assert_eq!(classified.reason, "haystack_contents_changed_same_shape");
    }

    /// Evidence labels are consumed verbatim, so pin every one of them.
    #[test]
    fn scan_graph_capture_evidence_labels_are_stable() {
        let kinds = [
            (ScanGraphCaptureEditKind::PatternDatabaseUpload, "pattern_database_upload"),
            (ScanGraphCaptureEditKind::HaystackBufferChange, "haystack_buffer_change"),
            (ScanGraphCaptureEditKind::OutputSlabResize, "output_slab_resize"),
            (ScanGraphCaptureEditKind::VerifierChange, "verifier_change"),
        ];
        for (kind, label) in kinds {
            assert_eq!(kind.as_str(), label);
        }

        assert_eq!(GraphCaptureEditAction::Replay.as_str(), "replay");
        assert_eq!(GraphCaptureEditAction::Update.as_str(), "update");
        assert_eq!(GraphCaptureEditAction::Recapture.as_str(), "recapture");
        assert_eq!(GraphCaptureEditStability::GraphStable.as_str(), "graph_stable");
        assert_eq!(GraphCaptureEditStability::GraphBreaking.as_str(), "graph_breaking");
    }

    /// Advance a 64-bit LCG and return the upper 32 bits of its new state.
    ///
    /// In a power-of-two-modulus LCG the lowest `k` bits repeat every `2^k`
    /// steps. Taking `% 4` of the raw state would therefore just cycle the
    /// four binding roles in a fixed order. The high half varies properly, so
    /// generated cases really do shuffle role order.
    fn next_u64(state: &mut u64) -> u64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        *state >> 32
    }
}