echidna 0.9.0

A high-performance automatic differentiation library for Rust
Documentation
------------------------------ MODULE Revolve ------------------------------
(*
 * Formal specification of the base Revolve (binomial) gradient checkpointing
 * algorithm used in echidna.
 *
 * Models the full lifecycle:
 *   Phase 1 (Schedule):  Compute optimal checkpoint positions via recursive
 *                         interval splitting (mirrors schedule_recursive).
 *   Phase 2 (Forward):   Walk forward through all steps, storing state at
 *                         checkpoint positions (mirrors grad_checkpointed).
 *   Phase 3 (Backward):  Process segments in reverse, marking covered steps
 *                         (mirrors backward_from_checkpoints).
 *
 * State vectors are abstracted to step indices. We verify the bookkeeping
 * — which steps are checkpointed, which segments are covered by the
 * backward pass — not numerical values.
 *
 * The backward pass is deliberately coarse: each segment is modelled as an
 * atomic action that marks steps as covered. The real implementation
 * recomputes forward within each segment, but uses only a local buffer
 * (not additional checkpoint slots), so this abstraction is safe.
 *
 * Rust correspondence: grad_checkpointed() in src/checkpoint.rs
 *)

EXTENDS BinomialBeta, Naturals, Sequences, FiniteSets

CONSTANTS
    NumSteps,        \* Length of the forward computation, in steps (>= 2)
    NumCheckpoints   \* Checkpoint slots, not counting pinned step 0 (1 .. NumSteps)

\* Parameter sanity: at least two steps, and between 1 and NumSteps slots.
ASSUME /\ NumSteps >= 2
       /\ NumCheckpoints \in 1 .. NumSteps

---------------------------------------------------------------------------
(* Variables *)
---------------------------------------------------------------------------

VARIABLE phase              \* One of "schedule", "forward", "backward", "done"
VARIABLE positions          \* Checkpoint positions chosen by the schedule phase
VARIABLE workStack          \* Pending <<start, end, slots>> items (schedule phase)
VARIABLE currentStep        \* Forward-pass step counter
VARIABLE storedCheckpoints  \* Step indices actually stored by the forward pass
VARIABLE segIndex           \* Backward-pass segment index, counting down to 0
VARIABLE coveredSteps       \* Steps whose VJP the backward pass has computed

\* Tuple of every state variable, used for stuttering and fairness.
vars == <<phase, positions, workStack, currentStep,
          storedCheckpoints, segIndex, coveredSteps>>

---------------------------------------------------------------------------
(* Helper operators *)
---------------------------------------------------------------------------

(*
 * Backward-pass helper: the exclusive upper bound of segment `seg`, where
 * `seg` is an index into storedCheckpoints. A segment stops where the next
 * stored checkpoint begins; the final segment runs through to NumSteps.
 *)
SegEnd(seg) ==
    IF seg >= Len(storedCheckpoints)
    THEN NumSteps
    ELSE storedCheckpoints[seg + 1]

---------------------------------------------------------------------------
(* Initial state *)
---------------------------------------------------------------------------

Init ==
    \* Schedule phase: a single work item covering [0, NumSteps] with every slot.
    /\ phase     = "schedule"
    /\ workStack = << <<0, NumSteps, NumCheckpoints>> >>
    /\ positions = {}
    \* Forward-pass bookkeeping: step 0 (the initial state) is always stored.
    /\ currentStep       = 0
    /\ storedCheckpoints = <<0>>
    \* Backward-pass bookkeeping.
    /\ segIndex     = 0
    /\ coveredSteps = {}

---------------------------------------------------------------------------
(* Phase 1: Schedule — recursive interval splitting *)
---------------------------------------------------------------------------
(*
 * Schedule phase, one work item per step (mirrors schedule_recursive() in
 * src/checkpoint.rs). The top item <<start, end, slots>> is popped. If the
 * interval spans at least two steps, has at least one free slot, and
 * OptimalAdvance yields a split strictly inside it, the split is added to
 * positions and both halves are pushed back on top of the stack:
 *   - the left half  [start, split) with one slot fewer, and
 *   - the right half [split, end)   with the same number of slots.
 * Otherwise the item is simply discarded.
 *
 * workStack is a genuine LIFO stack: push = prepend, pop = Head. Intervals
 * are therefore expanded depth-first, left half first, which is the order a
 * left-then-right recursive implementation visits them. Each item is
 * processed independently of the others, so the final positions set does
 * not depend on this order.
 *)
ScheduleStep ==
    /\ phase = "schedule"
    /\ workStack # << >>
    /\ LET item  == Head(workStack)
           rest  == Tail(workStack)
           start == item[1]
           end   == item[2]
           slots == item[3]
           steps == end - start
       IN
       IF steps <= 1 \/ slots = 0
       THEN
           \* Base case: nothing left to split, just pop
           /\ workStack' = rest
           /\ positions' = positions
       ELSE
           LET split == start + OptimalAdvance(steps, slots)
           IN
           IF split > start /\ split < end
           THEN
               /\ positions' = positions \union {split}
               \* Push both halves on top of the stack; the left half
               \* (one slot fewer) ends up on top and is processed next.
               /\ workStack' = << <<start, split, slots - 1>>,
                                  <<split, end, slots>> >> \o rest
           ELSE
               \* Degenerate split (not strictly inside the interval): just pop
               /\ workStack' = rest
               /\ positions' = positions
    /\ UNCHANGED <<phase, currentStep, storedCheckpoints, segIndex, coveredSteps>>

(*
 * Leave the schedule phase once no work items remain. Only phase changes;
 * the computed positions are carried over to the forward pass.
 *)
ScheduleDone ==
    /\ phase = "schedule"
    /\ Len(workStack) = 0
    /\ UNCHANGED <<positions, workStack, currentStep, storedCheckpoints,
                   segIndex, coveredSteps>>
    /\ phase' = "forward"

---------------------------------------------------------------------------
(* Phase 2: Forward pass *)
---------------------------------------------------------------------------
(*
 * Forward pass (mirrors the forward loop in grad_checkpointed()): advance
 * currentStep by one. The state reached, currentStep + 1, is stored as a
 * checkpoint only if all of the following hold:
 *   - it comes before the final state (< NumSteps),
 *   - the schedule chose it, and
 *   - a slot is still free.
 * At most NumCheckpoints + 1 entries are kept, the extra one being pinned
 * step 0.
 *)
ForwardStep ==
    LET reached == currentStep + 1
        hasRoom == Len(storedCheckpoints) <= NumCheckpoints
        keepIt  == /\ reached < NumSteps
                   /\ reached \in positions
                   /\ hasRoom
    IN
    /\ phase = "forward"
    /\ currentStep < NumSteps
    /\ currentStep' = reached
    /\ storedCheckpoints' =
           IF keepIt THEN Append(storedCheckpoints, reached)
                     ELSE storedCheckpoints
    /\ UNCHANGED <<phase, positions, workStack, segIndex, coveredSteps>>

(*
 * Forward pass finished (currentStep has reached NumSteps). Enter the
 * backward phase with segIndex pointing at the last stored checkpoint.
 *)
ForwardDone ==
    /\ phase = "forward"
    /\ currentStep = NumSteps
    /\ segIndex' = Len(storedCheckpoints)
    /\ phase' = "backward"
    /\ UNCHANGED <<positions, workStack, currentStep, storedCheckpoints,
                   coveredSteps>>

---------------------------------------------------------------------------
(* Phase 3: Backward pass *)
---------------------------------------------------------------------------
(*
 * Backward pass (mirrors backward_from_checkpoints() in src/checkpoint.rs).
 * Handle segment segIndex, which starts at storedCheckpoints[segIndex] and
 * stops just before SegEnd(segIndex). All steps in that half-open range are
 * marked covered in one atomic action. segIndex then moves to the previous
 * segment, so segments are visited last-to-first.
 *)
BackwardSegment ==
    /\ phase = "backward"
    /\ segIndex > 0
    /\ LET first == storedCheckpoints[segIndex]
           last  == SegEnd(segIndex) - 1
       IN  /\ segIndex' = segIndex - 1
           /\ coveredSteps' = coveredSteps \cup (first .. last)
    /\ UNCHANGED <<phase, positions, workStack, currentStep, storedCheckpoints>>

(*
 * Every segment has been handled (segIndex is back at 0): move to "done".
 *)
BackwardDone ==
    /\ phase = "backward"
    /\ segIndex = 0
    /\ UNCHANGED <<positions, workStack, currentStep, storedCheckpoints,
                   segIndex, coveredSteps>>
    /\ phase' = "done"

---------------------------------------------------------------------------
(* Next-state relation *)
---------------------------------------------------------------------------

(*
 * Once phase = "done", none of the phase actions is enabled. TLC reports
 * such a state as a deadlock unless deadlock checking is switched off.
 * This disjunct lets the finished algorithm stutter instead. It never
 * changes vars, so it is not a <<Next>>_vars step and adds no obligation
 * under the weak fairness in Spec.
 *)
Terminating ==
    /\ phase = "done"
    /\ UNCHANGED vars

\* One step of the algorithm: exactly one phase action, or terminal stuttering.
Next ==
    \/ ScheduleStep
    \/ ScheduleDone
    \/ ForwardStep
    \/ ForwardDone
    \/ BackwardSegment
    \/ BackwardDone
    \/ Terminating

\* Weak fairness on Next ensures the algorithm keeps progressing until "done".
Spec == Init /\ [][Next]_vars /\ WF_vars(Next)

---------------------------------------------------------------------------
(* Invariants *)
---------------------------------------------------------------------------

(*
 * SAFETY: beyond the pinned initial state at step 0, the forward pass never
 * stores more than NumCheckpoints states. Equivalently, storedCheckpoints
 * has at most NumCheckpoints + 1 entries.
 *
 * Rust: all_positions.truncate(num_checkpoints) in grad_checkpointed
 *)
BudgetInvariant ==
    Len(storedCheckpoints) \in 0 .. (NumCheckpoints + 1)

(*
 * SAFETY: the schedule only picks interior steps. Every position lies in
 * 1 .. NumSteps - 1, so it is never the initial state 0 and never the end.
 *
 * Rust: next_step < num_steps guard in forward loop
 *)
PositionRangeInvariant ==
    positions \subseteq 1 .. (NumSteps - 1)

(*
 * SAFETY: once the forward pass is over, stored checkpoint step indices are
 * strictly increasing. Checking neighbouring entries is enough, because a
 * strict increase between every pair of neighbours implies a strict
 * increase between every pair.
 *
 * Rust: positions are computed from a sorted, deduped Vec;
 *       forward pass inserts in order.
 *)
SortedCheckpoints ==
    LET n == Len(storedCheckpoints)
    IN  phase \in {"backward", "done"} =>
            \A k \in 1 .. (n - 1) : storedCheckpoints[k] < storedCheckpoints[k + 1]

(*
 * SAFETY: step 0 (the initial state) is always stored, and it is always the
 * first checkpoint. Non-emptiness is asserted as a conjunct rather than
 * used as a guard. If step 0 were ever lost entirely, the invariant would
 * therefore be violated instead of holding vacuously.
 *
 * Rust: checkpoints.push((0, x0.to_vec())) is always first.
 *)
InitialStateStored ==
    /\ Len(storedCheckpoints) >= 1
    /\ storedCheckpoints[1] = 0

(*
 * SAFETY (schedule phase): every pending work item <<start, end, slots>>
 * describes a non-empty sub-interval of [0, NumSteps] with a non-negative
 * slot count.
 *)
WorkStackBoundsInvariant ==
    phase = "schedule" =>
        \A k \in DOMAIN workStack :
            LET w == workStack[k]
            IN  /\ 0 <= w[1]
                /\ w[1] < w[2]
                /\ w[2] <= NumSteps
                /\ 0 <= w[3]

(*
 * COMPLETENESS: when the algorithm is done, the backward pass has covered
 * exactly the steps 0 .. NumSteps - 1. No step's VJP is skipped.
 *)
CompletenessProperty ==
    \/ phase # "done"
    \/ coveredSteps = 0 .. (NumSteps - 1)

---------------------------------------------------------------------------
(* Temporal properties *)
---------------------------------------------------------------------------

(*
 * LIVENESS: every behaviour eventually reaches the "done" phase. This
 * relies on the weak fairness included in Spec.
 *)
Termination ==
    LET finished == phase = "done"
    IN  <>finished

==========================================================================