xad-rs 0.2.0

Automatic differentiation library for Rust — forward/reverse mode AD, a Rust port of the C++ XAD library (https://github.com/auto-differentiation/xad)
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
//! `LabeledForwardTape` — setup-and-freeze scope for labeled forward-mode values.
//!
//! Parallel to `src/labeled/areal.rs`'s `LabeledTape` (reverse mode). This
//! type owns the `VarRegistry` and manages a thread-local active-generation
//! counter used by the debug-only cross-registry guard on forward-mode
//! labeled wrappers (`LabeledFReal`, `LabeledDual`, `LabeledDual2`).
//!
//! # Two APIs, two shapes
//!
//! `LabeledForwardTape` hands out labeled wrappers via two distinct APIs,
//! one per wrapper family. The asymmetry is forced by the positional
//! `Dual::variable(value, idx, n)` constructor contract — `Dual` stores a
//! fixed-length `grad: Vec<f64>` whose length `n` must be known at
//! construction time, i.e. BEFORE all inputs have been declared. `FReal<T>`
//! is single-direction and has no such requirement.
//!
//! ## Shape 1 — single-step: `LabeledFReal<T>` (existing API)
//!
//! 1. `let mut ft = LabeledForwardTape::new();`
//! 2. `let x: LabeledFReal<f64> = ft.input_freal("x", 2.0);` — returns a
//!    fully-seeded wrapper immediately (no total-variable count needed).
//! 3. `let _registry = ft.freeze();` — finalises the registry and stamps
//!    the TLS `ACTIVE_REGISTRY` slot for read-back (`ACTIVE_GEN` was
//!    already stamped by `new()` in debug builds).
//! 4. `ft` is dropped at end of scope, restoring the previous TLS values.
//!
//! ## Shape 2 — declare + freeze + handle: `LabeledDual`, `LabeledDual2<f64>`
//!
//! Because `Dual::variable(value, idx, n)` needs `n` at construction time,
//! we defer construction until `freeze_dual(self)` is called:
//!
//! 1. `let mut ft = LabeledForwardTape::new();`
//! 2. `let x_h: DualHandle = ft.declare_dual("x", 2.0);` — registers the
//!    name, stores `(name, value)` in a pending buffer, returns a
//!    lightweight `Copy` handle.
//! 3. `let scope: LabeledForwardScope = ft.freeze_dual();` — consumes the
//!    tape, allocates `Arc<VarRegistry>`, constructs each pending `Dual`
//!    with the final `n = pending.len()`, stamps `ACTIVE_REGISTRY` /
//!    `ACTIVE_GEN`, and returns a scope object that holds the wrappers.
//! 4. `let x: &LabeledDual = scope.dual(x_h);` — zero-copy, zero-alloc
//!    immutable lookup indexed by the handle's position.
//! 5. `scope` is dropped at end of scope, restoring the previous TLS
//!    values. The `Arc<VarRegistry>` lives as long as `scope` does.
//!
//! The same pattern applies to `LabeledDual2<f64>` via `declare_dual2_f64`,
//! `Dual2Handle`, and `scope.dual2(handle)`. Both wrapper families share
//! ONE scope object; `freeze_dual` builds both vectors at once. (Mixing
//! `input_freal` — Shape 1 — with `declare_dual` / `declare_dual2_f64` —
//! Shape 2 — on the same tape is not supported: users must pick one shape
//! per tape.)
//!
//! # Why two shapes?
//!
//! `FReal<T>::new(value, T::one())` requires only a value and the tangent
//! seed; no "total count" parameter. A single-step `input_freal` API
//! therefore works cleanly. `Dual::variable(value, idx, n)` requires `n`
//! at construction time, and `n` is unknown until the user finishes
//! declaring inputs. The handle pattern is the smallest API shape that
//! defers construction without breaking Shape A's zero-cost promise on the
//! per-value wrappers.
//!
//! # Debug-only field layout quirk
//!
//! `LabeledFReal<T>`, `LabeledDual`, `LabeledDual2<T>` carry
//! `#[cfg(debug_assertions)] gen_id: u64` — release builds have zero extra
//! bytes, debug builds have 8. This is valid Rust and does not affect the
//! labeled layer (no FFI here, no `mem::size_of` asserts).
//!
//! # `!Send` compile-fail assertion
//!
//! ```compile_fail,E0277
//! use xad_rs::labeled::LabeledForwardTape;
//! fn assert_send<T: Send>(_: T) {}
//! assert_send(LabeledForwardTape::new());
//! ```
//!
//! # `!Send` compile-fail assertion on `LabeledForwardScope`
//!
//! ```compile_fail,E0277
//! use xad_rs::labeled::{LabeledForwardScope, LabeledForwardTape};
//! fn assert_send<T: Send>(_: T) {}
//! let mut ft = LabeledForwardTape::new();
//! let _h = ft.declare_dual("x", 1.0);
//! let scope: LabeledForwardScope = ft.freeze_dual();
//! assert_send(scope);
//! ```

use std::cell::Cell;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
#[cfg(debug_assertions)]
use std::sync::atomic::{AtomicU64, Ordering};

use indexmap::IndexSet;

use crate::labeled::VarRegistry;

// === TLS generation counter (debug-only) ===

/// Process-wide source of fresh tape generation IDs (debug builds only).
/// Starts at 1, whereas `ACTIVE_GEN` is initialised to 0 and reset to 0
/// by `LabeledForwardTape::deactivate_all()`.
#[cfg(debug_assertions)]
static NEXT_GEN: AtomicU64 = AtomicU64::new(1);

thread_local! {
    /// Generation ID stamped by the most recent `LabeledForwardTape::new()`
    /// on this thread that has not yet been unwound by a `Drop` (debug
    /// builds only). Read through `current_gen()`.
    #[cfg(debug_assertions)]
    static ACTIVE_GEN: Cell<u64> = const { Cell::new(0) };

    /// Raw-pointer slot for the currently-active frozen `VarRegistry` on
    /// this thread. Stamped by `LabeledForwardTape::freeze()` and
    /// `LabeledForwardTape::freeze_dual()` using `Arc::as_ptr` on the
    /// tape's own `Arc<VarRegistry>`; restored by the `Drop` impl of the
    /// tape (after `freeze()`) or of the `LabeledForwardScope` (after
    /// `freeze_dual()`) via save-restore discipline. Null when nothing
    /// has been frozen on the current thread or after
    /// `LabeledForwardTape::deactivate_all()`.
    static ACTIVE_REGISTRY: Cell<*const VarRegistry> = const { Cell::new(std::ptr::null()) };
}

/// Invoke `f` with the registry that is currently active on this thread.
///
/// `f` receives `None` when the thread's `ACTIVE_REGISTRY` slot is null,
/// i.e. nothing has been frozen on this thread (or the slot was cleared
/// by `LabeledForwardTape::deactivate_all()`).
///
/// Used by `LabeledDual::partial`, `LabeledDual2::first_derivative`, and
/// `LabeledDual2::second_derivative` to look up positional indices from
/// string names without holding a per-value `Arc<VarRegistry>` handle.
pub(crate) fn with_active_registry<R>(f: impl FnOnce(Option<&VarRegistry>) -> R) -> R {
    ACTIVE_REGISTRY.with(|slot| {
        let raw = slot.get();
        // SAFETY: A non-null pointer in the slot was produced by
        // `Arc::as_ptr` on an `Arc<VarRegistry>` that is kept alive by the
        // owner which stamped it: the `LabeledForwardTape` itself (its
        // `registry` field) after `freeze()`, or the `LabeledForwardScope`
        // (its `registry` field) after `freeze_dual()`. Each owner's `Drop`
        // writes the previously saved pointer back before the `Arc` is
        // released, and both owners are `!Send`, so the slot is never read
        // from another thread. The one escape is `std::mem::forget` on the
        // owner; `LabeledForwardTape::deactivate_all()` is the documented
        // recovery for that case.
        let active: Option<&VarRegistry> = unsafe { raw.as_ref() };
        f(active)
    })
}

/// Read the active generation for the current thread.
///
/// Called by labeled forward wrapper constructors (`input_freal` etc.)
/// and by operator impls to stamp / check the `gen` field.
#[cfg(debug_assertions)]
#[inline(always)]
pub(crate) fn current_gen() -> u64 {
    ACTIVE_GEN.with(|c| c.get())
}

/// Cross-generation check — debug only, empty in release.
///
/// Compares the generation stamps of two operands and aborts the
/// operation when they differ.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when `lhs != rhs`, reporting both
/// generation values in the message.
#[cfg(debug_assertions)]
#[inline(always)]
pub(crate) fn check_gen(lhs: u64, rhs: u64) {
    assert_eq!(
        lhs, rhs,
        "xad_rs::labeled: cross-registry forward-mode op detected (lhs tape generation = {lhs}, rhs tape generation = {rhs}). \
         Both operands must come from the same LabeledForwardTape scope."
    );
}

/// Release-build stand-in for `check_gen`: an empty function.
// NOTE(review): the parameters are `()` here — presumably because the
// wrappers' `gen_id` fields are compiled out in release builds; confirm
// against the call sites.
#[cfg(not(debug_assertions))]
#[inline(always)]
#[allow(dead_code)]
pub(crate) fn check_gen(_lhs: (), _rhs: ()) {}

// === LabeledForwardTape ===

/// Labeled forward-mode scope. Owns the pending `VarRegistry` and (in
/// debug builds) manages the TLS generation counter.
///
/// See the module-level docs for the two-phase contract and the
/// save-restore TLS discipline across nested scopes.
pub struct LabeledForwardTape {
    /// Insertion-ordered set of every name registered so far via
    /// `input_freal`, `declare_dual` or `declare_dual2_f64`; the frozen
    /// `VarRegistry` is built from it.
    builder: IndexSet<String>,
    /// The frozen registry. `None` until `freeze()` / `freeze_dual()`
    /// builds it from `builder`.
    registry: Option<Arc<VarRegistry>>,
    /// Pending `(name, value)` entries destined for `LabeledDual`
    /// construction in `freeze_dual`. Each entry's position in this
    /// `Vec` equals the `DualHandle` idx returned by `declare_dual`.
    pending_dual: Vec<(String, f64)>,
    /// Pending `(name, value)` entries destined for `LabeledDual2<f64>`
    /// construction in `freeze_dual`. Each entry's position in this
    /// `Vec` equals the `Dual2Handle` idx returned by `declare_dual2_f64`.
    pending_dual2_f64: Vec<(String, f64)>,
    /// Save-restore slot for `ACTIVE_REGISTRY`. Holds the TLS pointer that
    /// was live when this tape called `freeze()` / `freeze_dual()`,
    /// restored in `Drop` so nested scopes unwind cleanly.
    /// `std::ptr::null()` when the tape has not been frozen yet.
    prev_registry: *const VarRegistry,
    /// Generation this tape activated itself under. Debug-only, retained
    /// for diagnostics (not currently read — `prev_gen` carries the
    /// save-restore state and `ACTIVE_GEN` carries the live value).
    #[cfg(debug_assertions)]
    #[allow(dead_code)]
    gen_id: u64,
    /// Debug-only: the `ACTIVE_GEN` value that `new()` displaced; written
    /// back to `ACTIVE_GEN` by `Drop`.
    #[cfg(debug_assertions)]
    prev_gen: u64,
    /// Set once the registry has been frozen. Gates the registration
    /// methods (which assert `!frozen`) and the `ACTIVE_REGISTRY` restore
    /// in `Drop`.
    frozen: bool,
    /// Raw-pointer marker that keeps the type `!Send` (see the
    /// `compile_fail` doctest in the module docs).
    _not_send: PhantomData<*const ()>,
}

impl LabeledForwardTape {
    /// Create an empty, unfrozen forward tape.
    ///
    /// In debug builds this draws a fresh generation ID from `NEXT_GEN`
    /// and installs it in the thread's `ACTIVE_GEN` slot, so subsequent
    /// `input_*` / `constant_*` calls land under this generation. The
    /// displaced value is remembered in `prev_gen` and written back by
    /// `Drop`. `ACTIVE_REGISTRY` is not touched until the tape is frozen.
    pub fn new() -> Self {
        #[cfg(debug_assertions)]
        let gen_id = NEXT_GEN.fetch_add(1, Ordering::Relaxed);
        // `Cell::replace` stores the new generation and hands back the one
        // that was live before, in a single step.
        #[cfg(debug_assertions)]
        let prev_gen = ACTIVE_GEN.with(|slot| slot.replace(gen_id));

        Self {
            builder: IndexSet::new(),
            registry: None,
            pending_dual: Vec::new(),
            pending_dual2_f64: Vec::new(),
            prev_registry: std::ptr::null(),
            #[cfg(debug_assertions)]
            gen_id,
            #[cfg(debug_assertions)]
            prev_gen,
            frozen: false,
            _not_send: PhantomData,
        }
    }

    /// Register `name` as an input and hand back a `LabeledFReal<T>`
    /// wrapping `FReal::new(value, T::one())`, i.e. seeded with a unit
    /// tangent.
    ///
    /// Registration is eager (matches `LabeledTape::input`): the name is
    /// added to the pending registry immediately, unless it is already
    /// present.
    ///
    /// # Panics
    ///
    /// Panics if the tape has already been frozen.
    pub fn input_freal<T: crate::traits::Scalar>(
        &mut self,
        name: &str,
        value: T,
    ) -> crate::labeled::LabeledFReal<T> {
        assert!(
            !self.frozen,
            "LabeledForwardTape::input_freal({:?}) called after freeze(); add all inputs before forward pass",
            name
        );

        // Only allocate an owned copy of the name when it is new.
        let already_registered = self.builder.contains(name);
        if !already_registered {
            self.builder.insert(name.to_owned());
        }

        let seeded = crate::freal::FReal::<T>::new(value, T::one());
        crate::labeled::LabeledFReal::<T>::__from_inner(seeded)
    }

    /// Build a derivative-free `LabeledFReal<T>` from `value` by wrapping
    /// `FReal::constant(value)`. Does not touch the registry and may be
    /// called whether or not the tape is frozen.
    // NOTE(review): the generation stamp is presumably applied inside
    // `LabeledFReal::__from_inner` — confirm there.
    pub fn constant_freal<T: crate::traits::Scalar>(
        &self,
        value: T,
    ) -> crate::labeled::LabeledFReal<T> {
        let unseeded = crate::freal::FReal::<T>::constant(value);
        crate::labeled::LabeledFReal::<T>::__from_inner(unseeded)
    }

    /// Freeze the tape: build the `VarRegistry` from every name registered
    /// so far, keep one `Arc` handle inside the tape and return another.
    ///
    /// Also points the thread's `ACTIVE_REGISTRY` slot at the new registry,
    /// saving the displaced pointer in `prev_registry` so that `Drop` can
    /// restore it and nested scopes unwind cleanly.
    ///
    /// # Panics
    ///
    /// Panics if the tape is already frozen.
    pub fn freeze(&mut self) -> Arc<VarRegistry> {
        assert!(!self.frozen, "LabeledForwardTape::freeze() called twice");

        let names = self.builder.iter().cloned();
        let shared = Arc::new(VarRegistry::from_names(names));

        // `Arc::as_ptr` points into the shared allocation, which stays alive
        // at least as long as the clone stored in `self.registry`, i.e. as
        // long as `self`.
        let raw: *const VarRegistry = Arc::as_ptr(&shared);
        self.registry = Some(Arc::clone(&shared));

        // Swap the new pointer in and remember the one it displaced.
        self.prev_registry = ACTIVE_REGISTRY.with(|slot| slot.replace(raw));
        self.frozen = true;
        shared
    }

    /// True if [`freeze`](Self::freeze) has been called.
    ///
    /// Simply reports the tape's `frozen` flag; a freshly constructed tape
    /// returns `false`.
    #[inline]
    pub fn is_frozen(&self) -> bool {
        self.frozen
    }

    /// Access the frozen registry, if any. Returns `None` until
    /// [`freeze`](Self::freeze) has been called.
    ///
    /// The returned reference borrows the tape's own `Arc` handle; clone
    /// it with `Arc::clone` if an owned handle is needed.
    #[inline]
    pub fn registry(&self) -> Option<&Arc<VarRegistry>> {
        self.registry.as_ref()
    }

    /// Recovery hook for a tape or scope that was leaked with
    /// `std::mem::forget` and therefore never ran its `Drop`.
    ///
    /// Resets the current thread's forward-mode TLS state: the generation
    /// slot goes back to `0` (debug builds only) and the active-registry
    /// pointer goes back to null.
    pub fn deactivate_all() {
        #[cfg(debug_assertions)]
        ACTIVE_GEN.with(|gen_slot| gen_slot.set(0));
        ACTIVE_REGISTRY.with(|reg_slot| reg_slot.set(std::ptr::null()));
    }

    // ============ Shape 2 API: declare + freeze_dual + handle lookup ============
    //
    // `LabeledDual` and `LabeledDual2<f64>` use a two-step construction
    // pattern because `Dual::variable(value, idx, n)` needs the final
    // variable count `n` at construction time. See the module docs for
    // the rationale and a worked example.

    /// Queue a named `LabeledDual` input; the wrapper itself is built
    /// later by [`freeze_dual`](Self::freeze_dual).
    ///
    /// `name` is added to the pending registry (insertion order, no
    /// duplicates) and `(name, value)` is appended to the pending-dual
    /// buffer. The returned `DualHandle` is resolved to the seeded
    /// `LabeledDual` through [`LabeledForwardScope::dual`].
    ///
    /// The handle stores the entry's position in the pending buffer, not
    /// a registry index; callers MUST treat it as opaque.
    ///
    /// # Panics
    ///
    /// Panics if the tape has already been frozen.
    pub fn declare_dual(&mut self, name: &str, value: f64) -> DualHandle {
        assert!(
            !self.frozen,
            "LabeledForwardTape::declare_dual({:?}) called after freeze",
            name
        );

        // Register the name once; a repeated declaration reuses the slot.
        let already_registered = self.builder.contains(name);
        if !already_registered {
            self.builder.insert(name.to_owned());
        }

        // The handle is the slot the new entry is about to occupy.
        let handle = DualHandle {
            idx: self.pending_dual.len(),
        };
        self.pending_dual.push((name.to_owned(), value));
        handle
    }

    /// Queue a named `LabeledDual2<f64>` input; the wrapper itself is
    /// built later by [`freeze_dual`](Self::freeze_dual).
    ///
    /// Mirrors [`declare_dual`](Self::declare_dual) for the second-order
    /// wrapper: `name` is added to the pending registry (no duplicates)
    /// and `(name, value)` is appended to the pending-dual2 buffer. Each
    /// declared input becomes its own seeded `LabeledDual2<f64>` in the
    /// returned scope; `merge_seeded` at operator time enforces the "one
    /// active direction per expression" rule.
    ///
    /// # Panics
    ///
    /// Panics if the tape has already been frozen.
    pub fn declare_dual2_f64(&mut self, name: &str, value: f64) -> Dual2Handle {
        assert!(
            !self.frozen,
            "LabeledForwardTape::declare_dual2_f64({:?}) called after freeze",
            name
        );

        // Register the name once; a repeated declaration reuses the slot.
        let already_registered = self.builder.contains(name);
        if !already_registered {
            self.builder.insert(name.to_owned());
        }

        // The handle is the slot the new entry is about to occupy.
        let handle = Dual2Handle {
            idx: self.pending_dual2_f64.len(),
        };
        self.pending_dual2_f64.push((name.to_owned(), value));
        handle
    }

    /// Consume the tape, finalise the registry, construct every pending
    /// `LabeledDual` and `LabeledDual2<f64>` wrapper, point the TLS
    /// `ACTIVE_REGISTRY` slot at the new registry, and return a
    /// [`LabeledForwardScope`] holding the registry and the fully-seeded
    /// wrappers.
    ///
    /// Every `LabeledDual` gets a gradient of length `registry.len()` and
    /// is seeded at the registry index of its name. Sizing by the registry
    /// (rather than by the number of `declare_dual` calls) keeps the seed
    /// index in range and keeps the gradient length equal to the one used
    /// by [`LabeledForwardScope::constant_dual`], even when names were
    /// declared more than once or `declare_dual2_f64` names share the
    /// registry. When every declared name is a unique dual the two counts
    /// coincide.
    ///
    /// The TLS save-restore state (`prev_registry`, and `prev_gen` in
    /// debug builds) is handed over to the returned scope, whose `Drop`
    /// performs the restore.
    ///
    /// # Panics
    ///
    /// Panics if the tape was already frozen via
    /// [`freeze`](Self::freeze).
    pub fn freeze_dual(mut self) -> LabeledForwardScope {
        assert!(
            !self.frozen,
            "LabeledForwardTape::freeze_dual called after freeze"
        );
        // 1. Finalise the registry.
        let reg = Arc::new(VarRegistry::from_names(self.builder.iter().cloned()));
        self.registry = Some(Arc::clone(&reg));

        // 2. Stamp ACTIVE_REGISTRY save-restore style BEFORE constructing
        //    wrappers. `frozen` is set right away so that, should wrapper
        //    construction below panic, the tape's own `Drop` still
        //    restores the displaced pointer during unwinding.
        let new_ptr: *const VarRegistry = Arc::as_ptr(&reg);
        ACTIVE_REGISTRY.with(|c| {
            self.prev_registry = c.get();
            c.set(new_ptr);
        });
        self.frozen = true;

        // 3. Construct the LabeledDual wrappers. The gradient length is
        //    the REGISTRY length, because `reg_idx` is a registry position:
        //    the registry also contains names declared through
        //    `declare_dual2_f64`, and contains a repeated name only once,
        //    so `pending_dual.len()` can be smaller or larger than it.
        let n_vars = reg.len();
        let duals: Vec<crate::labeled::LabeledDual> = self
            .pending_dual
            .iter()
            .map(|(name, value)| {
                let reg_idx = reg
                    .index_of(name)
                    .expect("declared name missing from frozen registry");
                let inner = crate::dual::Dual::variable(*value, reg_idx, n_vars);
                crate::labeled::LabeledDual::__from_inner(inner)
            })
            .collect();

        // 4. Construct the LabeledDual2<f64> wrappers. `Dual2::variable`
        //    takes only a value (single-direction seeded); the labeled
        //    form additionally records the registry index of the seeded
        //    name so that `.first_derivative("x")` can distinguish the
        //    active direction from constants.
        let dual2s_f64: Vec<crate::labeled::LabeledDual2<f64>> = self
            .pending_dual2_f64
            .iter()
            .map(|(name, value)| {
                let reg_idx = reg
                    .index_of(name)
                    .expect("declared name missing from frozen registry");
                let inner = crate::dual2::Dual2::<f64>::variable(*value);
                crate::labeled::LabeledDual2::<f64>::__from_parts(inner, Some(reg_idx))
            })
            .collect();

        // 5. Transfer save-restore state into the scope. The tape's own
        //    `Drop` still fires when this function returns (`self` is
        //    owned by the function), so it is DISARMED by:
        //      - clearing `self.frozen`, which skips the ACTIVE_REGISTRY
        //        restore branch, and
        //      - overwriting `self.prev_gen` with the CURRENT ACTIVE_GEN,
        //        which turns the debug-mode ACTIVE_GEN restore into a
        //        no-op (it writes back the value that is already there).
        //    The scope's `Drop` takes over the real restore.
        let prev_registry = self.prev_registry;
        #[cfg(debug_assertions)]
        let prev_gen = self.prev_gen;

        self.frozen = false;
        self.prev_registry = std::ptr::null();
        #[cfg(debug_assertions)]
        {
            self.prev_gen = ACTIVE_GEN.with(|c| c.get());
        }

        LabeledForwardScope {
            registry: reg,
            duals,
            dual2s_f64,
            prev_registry,
            #[cfg(debug_assertions)]
            prev_gen,
            _not_send: PhantomData,
        }
    }
}

impl Default for LabeledForwardTape {
    fn default() -> Self {
        Self::new()
    }
}

// ============ Shape 2 API: handle + scope types ============

/// Opaque `Copy` handle returned by
/// [`LabeledForwardTape::declare_dual`], resolved to a
/// `&LabeledDual` via [`LabeledForwardScope::dual`] after
/// [`LabeledForwardTape::freeze_dual`].
///
/// The handle is a lightweight index — not a pointer, no lifetime — so
/// it can be captured by closures, stored in structs, and moved around
/// freely until the owning `LabeledForwardScope` is dropped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DualHandle {
    /// Position of the declaration in the tape's pending-dual buffer,
    /// which is also the wrapper's position in the scope's `duals` vector.
    idx: usize,
}

/// Opaque `Copy` handle returned by
/// [`LabeledForwardTape::declare_dual2_f64`], resolved to a
/// `&LabeledDual2<f64>` via [`LabeledForwardScope::dual2`] after
/// [`LabeledForwardTape::freeze_dual`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Dual2Handle {
    /// Position of the declaration in the tape's pending-dual2 buffer,
    /// which is also the wrapper's position in the scope's `dual2s_f64`
    /// vector.
    idx: usize,
}

/// Frozen forward-mode scope. Holds the finalised `Arc<VarRegistry>`
/// plus the per-handle `LabeledDual` / `LabeledDual2<f64>` wrapper
/// vectors produced by [`LabeledForwardTape::freeze_dual`].
///
/// The scope object:
///
/// - Owns the `Arc<VarRegistry>` that keeps the registry allocation
///   alive for the duration of the forward pass.
/// - Owns the TLS save-restore state inherited from the parent tape;
///   its `Drop` impl restores `ACTIVE_REGISTRY` and `ACTIVE_GEN` to
///   the values that were live BEFORE the parent tape was created.
/// - Is `!Send` via `PhantomData<*const ()>` for the same reason
///   `LabeledForwardTape` is `!Send`: moving the TLS state across
///   threads would corrupt the active-registry / active-gen contract.
///
/// Users do NOT construct `LabeledForwardScope` directly — it is the
/// return value of `ft.freeze_dual()`.
pub struct LabeledForwardScope {
    /// The frozen registry. While the scope is live, the thread's
    /// `ACTIVE_REGISTRY` pointer refers to this allocation.
    registry: Arc<VarRegistry>,
    /// Seeded `LabeledDual` wrappers, indexed by `DualHandle`.
    duals: Vec<crate::labeled::LabeledDual>,
    /// Seeded `LabeledDual2<f64>` wrappers, indexed by `Dual2Handle`.
    dual2s_f64: Vec<crate::labeled::LabeledDual2<f64>>,
    /// Inherited from the parent `LabeledForwardTape` — the
    /// `ACTIVE_REGISTRY` TLS value that was live BEFORE the tape was
    /// frozen. `Drop` restores this so nested scopes unwind cleanly.
    prev_registry: *const VarRegistry,
    /// Inherited from the parent `LabeledForwardTape` — the
    /// `ACTIVE_GEN` TLS value that was live BEFORE the tape was
    /// constructed. `Drop` restores this in debug builds.
    #[cfg(debug_assertions)]
    prev_gen: u64,
    /// Raw-pointer marker that keeps the type `!Send`.
    _not_send: PhantomData<*const ()>,
}

impl LabeledForwardScope {
    /// Retrieve the `LabeledDual` wrapper for a handle returned by
    /// [`LabeledForwardTape::declare_dual`].
    ///
    /// The returned reference borrows from the scope and is valid as
    /// long as the scope lives. Panics if the handle is out of range
    /// (this can only happen if the user constructed a `DualHandle`
    /// from a different tape and passed it in — a programmer error).
    #[inline]
    pub fn dual(&self, handle: DualHandle) -> &crate::labeled::LabeledDual {
        &self.duals[handle.idx]
    }

    /// Retrieve the `LabeledDual2<f64>` wrapper for a handle returned
    /// by [`LabeledForwardTape::declare_dual2_f64`].
    ///
    /// The returned reference borrows from the scope and is valid as
    /// long as the scope lives.
    #[inline]
    pub fn dual2(&self, handle: Dual2Handle) -> &crate::labeled::LabeledDual2<f64> {
        &self.dual2s_f64[handle.idx]
    }

    /// Shared access to the frozen registry. Useful when user code
    /// wants to iterate names for output labeling or similar.
    #[inline]
    pub fn registry(&self) -> &Arc<VarRegistry> {
        &self.registry
    }

    /// Construct a derivative-free `LabeledDual` constant scoped to this
    /// forward scope's registry. The resulting wrapper has a zero
    /// gradient of length `registry.len()` and is stamped with the
    /// scope's active TLS generation.
    ///
    /// Unlike [`dual`](Self::dual), the returned value is owned, not
    /// borrowed from the scope — constants are typically used inside
    /// loops where a fresh value is produced per iteration.
    #[inline]
    pub fn constant_dual(&self, value: f64) -> crate::labeled::LabeledDual {
        let inner = crate::dual::Dual::constant(value, self.registry.len());
        crate::labeled::LabeledDual::__from_inner(inner)
    }

    /// Construct a derivative-free `LabeledDual2<f64>` constant scoped
    /// to this forward scope's registry. The resulting wrapper has
    /// `seeded = None` (no active direction) and is stamped with the
    /// scope's active TLS generation.
    #[inline]
    pub fn constant_dual2_f64(&self, value: f64) -> crate::labeled::LabeledDual2<f64> {
        let inner = crate::dual2::Dual2::<f64>::constant(value);
        crate::labeled::LabeledDual2::<f64>::__from_parts(inner, None)
    }
}

/// Compact `Debug` output: reports only the sizes of the registry and of
/// the two wrapper vectors, not their contents.
impl fmt::Debug for LabeledForwardScope {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let registry_len = self.registry.len();
        let dual_count = self.duals.len();
        let dual2_count = self.dual2s_f64.len();

        let mut out = f.debug_struct("LabeledForwardScope");
        out.field("registry_len", &registry_len);
        out.field("duals", &dual_count);
        out.field("dual2s_f64", &dual2_count);
        out.finish()
    }
}

/// Restores the TLS state saved by the parent tape. Unlike
/// `LabeledForwardTape`'s `Drop`, the restore is unconditional: a scope
/// only exists once the registry has been frozen.
impl Drop for LabeledForwardScope {
    fn drop(&mut self) {
        // Debug builds: put back the generation that was live before the
        // parent tape was constructed.
        #[cfg(debug_assertions)]
        ACTIVE_GEN.with(|gen_slot| gen_slot.set(self.prev_gen));
        // Put back the registry pointer that was live before the freeze.
        ACTIVE_REGISTRY.with(|reg_slot| reg_slot.set(self.prev_registry));
    }
}

impl fmt::Debug for LabeledForwardTape {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LabeledForwardTape")
            .field("frozen", &self.frozen)
            .field("inputs", &self.builder.len())
            .field(
                "registry_len",
                &self.registry.as_ref().map(|r: &Arc<VarRegistry>| r.len()),
            )
            .finish()
    }
}

/// Save-restore discipline: puts the thread's TLS slots back to the
/// values they held before this tape displaced them, so nested
/// `LabeledForwardTape` scopes unwind correctly.
impl Drop for LabeledForwardTape {
    fn drop(&mut self) {
        // Debug builds: `new()` always displaced ACTIVE_GEN, so it is
        // always restored.
        #[cfg(debug_assertions)]
        ACTIVE_GEN.with(|gen_slot| gen_slot.set(self.prev_gen));

        // ACTIVE_REGISTRY was only displaced if the tape got frozen.
        if !self.frozen {
            return;
        }
        ACTIVE_REGISTRY.with(|reg_slot| reg_slot.set(self.prev_registry));
    }
}