embeddenator_vsa/coherency.rs

//! Host-Device Coherency Protocol for GPU VRAM
//!
//! This module implements a coherency protocol for maintaining consistency
//! between CPU (host) and GPU (device) memory for engrams.
//!
//! # Design
//!
//! The protocol uses a state machine with dirty bits to track modifications:
//!
//! ```text
//!  HostOnly ──(upload)──→ Synced ←──(download)── DeviceOnly
//!                        ↙      ↘
//!            (modify host)      (modify device)
//!                      ↓            ↓
//!                HostDirty        DeviceDirty
//!                      ↘            ↙
//!                         (sync)
//!                            ↓
//!                         Synced
//! ```
//!
//! # States
//!
//! - `HostOnly`: Data exists only on the host
//! - `DeviceOnly`: Data exists only on the device; the host copy is absent or stale
//! - `Synced`: Host and device hold identical data
//! - `HostDirty`: Host has been modified; the device copy is stale
//! - `DeviceDirty`: Device has been modified; the host copy is stale
//!
//! # Example
//!
//! ```rust,ignore
//! use embeddenator_vsa::{CoherencyState, CoherentEngram};
//!
//! let mut engram = CoherentEngram::new(data);
//!
//! // Upload to device
//! engram.upload_to_device(&pool)?;
//! assert!(engram.state() == CoherencyState::Synced);
//!
//! // Modify on device
//! engram.mark_device_dirty();
//! assert!(engram.state() == CoherencyState::DeviceDirty);
//!
//! // Sync back to host
//! engram.sync(&pool)?;
//! assert!(engram.state() == CoherencyState::Synced);
//! ```

use std::sync::atomic::{AtomicU64, Ordering};

#[cfg(feature = "cuda")]
use crate::gpu::GpuError;
#[cfg(feature = "cuda")]
use crate::vram_pool::{VramHandle, VramPool};
use crate::vsa::SparseVec;

/// Error type for coherency operations (used when cuda feature is disabled)
#[cfg(not(feature = "cuda"))]
#[derive(Debug, Clone)]
pub enum GpuError {
    /// GPU not available
    NotAvailable,
}

/// Coherency state for host-device memory
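///
/// A quick sketch of how the state predicates behave:
///
/// ```rust,ignore
/// // Synced: both copies are current, nothing to sync.
/// assert!(CoherencyState::Synced.host_is_current());
/// assert!(CoherencyState::Synced.device_is_current());
/// assert!(!CoherencyState::Synced.needs_sync());
///
/// // HostDirty: host is current, device is stale, sync is required.
/// assert!(CoherencyState::HostDirty.host_is_current());
/// assert!(!CoherencyState::HostDirty.device_is_current());
/// assert!(CoherencyState::HostDirty.needs_sync());
/// ```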
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CoherencyState {
    /// Data only exists on host
    HostOnly,
    /// Data only exists on device
    DeviceOnly,
    /// Host and device are synchronized (identical)
    Synced,
    /// Host has been modified, device is stale
    HostDirty,
    /// Device has been modified, host is stale
    DeviceDirty,
}

impl CoherencyState {
    /// Check if host data is current
    pub fn host_is_current(&self) -> bool {
        matches!(
            self,
            CoherencyState::HostOnly | CoherencyState::Synced | CoherencyState::HostDirty
        )
    }

    /// Check if device data is current
    pub fn device_is_current(&self) -> bool {
        matches!(
            self,
            CoherencyState::DeviceOnly | CoherencyState::Synced | CoherencyState::DeviceDirty
        )
    }

    /// Check if synchronization is needed
    pub fn needs_sync(&self) -> bool {
        matches!(
            self,
            CoherencyState::HostDirty | CoherencyState::DeviceDirty
        )
    }
}

/// A coherent engram that can live on both host and device
///
/// This type manages the synchronization state between CPU and GPU memory
/// for a SparseVec engram.
#[derive(Debug)]
pub struct CoherentEngram {
    /// Host-side data (always present as fallback)
    host_data: Vec<u8>,
    /// VRAM handle (if uploaded to device)
    #[cfg(feature = "cuda")]
    device_handle: Option<VramHandle>,
    #[cfg(not(feature = "cuda"))]
    device_handle: Option<()>,
    /// Current coherency state
    state: CoherencyState,
    /// Version counter for optimistic concurrency
    version: AtomicU64,
}

impl CoherentEngram {
    /// Create a new coherent engram from serialized data
    pub fn new(data: Vec<u8>) -> Self {
        Self {
            host_data: data,
            device_handle: None,
            state: CoherencyState::HostOnly,
            version: AtomicU64::new(0),
        }
    }

    /// Create from a SparseVec by serializing it
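    ///
    /// The byte layout is little-endian: `pos_len (u32) | neg_len (u32) | pos dims (u32 each) | neg dims (u32 each)`.
    /// A minimal roundtrip sketch:
    ///
    /// ```rust,ignore
    /// let vec = SparseVec { pos: vec![1, 5, 10], neg: vec![2, 7] };
    /// let engram = CoherentEngram::from_sparse_vec(&vec);
    /// // 8 header bytes + 5 dimensions * 4 bytes each
    /// assert_eq!(engram.host_data().len(), 8 + 5 * 4);
    /// assert_eq!(engram.to_sparse_vec().unwrap().pos, vec.pos);
    /// ```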
    pub fn from_sparse_vec(vec: &SparseVec) -> Self {
        // Simple serialization: pos length, neg length, then pos dims, then neg dims
        let mut data = Vec::new();

        // Write lengths as u32
        let pos_len = vec.pos.len() as u32;
        let neg_len = vec.neg.len() as u32;
        data.extend_from_slice(&pos_len.to_le_bytes());
        data.extend_from_slice(&neg_len.to_le_bytes());

        // Write pos dimensions as u32
        for &dim in &vec.pos {
            data.extend_from_slice(&(dim as u32).to_le_bytes());
        }

        // Write neg dimensions as u32
        for &dim in &vec.neg {
            data.extend_from_slice(&(dim as u32).to_le_bytes());
        }

        Self::new(data)
    }

    /// Deserialize back to SparseVec
    pub fn to_sparse_vec(&self) -> Option<SparseVec> {
        if self.host_data.len() < 8 {
            return None;
        }

        let pos_len = u32::from_le_bytes([
            self.host_data[0],
            self.host_data[1],
            self.host_data[2],
            self.host_data[3],
        ]) as usize;
        let neg_len = u32::from_le_bytes([
            self.host_data[4],
            self.host_data[5],
            self.host_data[6],
            self.host_data[7],
        ]) as usize;

        let expected_size = 8 + (pos_len + neg_len) * 4;
        if self.host_data.len() < expected_size {
            return None;
        }

        let mut pos = Vec::with_capacity(pos_len);
        let mut offset = 8;
        for _ in 0..pos_len {
            let dim = u32::from_le_bytes([
                self.host_data[offset],
                self.host_data[offset + 1],
                self.host_data[offset + 2],
                self.host_data[offset + 3],
            ]) as usize;
            pos.push(dim);
            offset += 4;
        }

        let mut neg = Vec::with_capacity(neg_len);
        for _ in 0..neg_len {
            let dim = u32::from_le_bytes([
                self.host_data[offset],
                self.host_data[offset + 1],
                self.host_data[offset + 2],
                self.host_data[offset + 3],
            ]) as usize;
            neg.push(dim);
            offset += 4;
        }

        Some(SparseVec { pos, neg })
    }

    /// Get current coherency state
    pub fn state(&self) -> CoherencyState {
        self.state
    }

    /// Get version number
    pub fn version(&self) -> u64 {
        self.version.load(Ordering::Acquire)
    }

    /// Increment version
    fn bump_version(&self) {
        self.version.fetch_add(1, Ordering::AcqRel);
    }

    /// Get host data (may be stale if DeviceDirty)
    pub fn host_data(&self) -> &[u8] {
        &self.host_data
    }

    /// Get mutable host data, marking the engram `HostDirty`
    /// (data that exists only on the host stays `HostOnly`)
    pub fn host_data_mut(&mut self) -> &mut Vec<u8> {
        self.state = match self.state {
            CoherencyState::HostOnly => CoherencyState::HostOnly,
            _ => CoherencyState::HostDirty,
        };
        self.bump_version();
        &mut self.host_data
    }

    /// Check if data is on device
    pub fn is_on_device(&self) -> bool {
        self.device_handle.is_some()
    }

    /// Upload host data to device
    #[cfg(feature = "cuda")]
    pub fn upload_to_device(&mut self, pool: &VramPool) -> Result<(), GpuError> {
        // Allocate if not already on device
        let handle = if let Some(h) = self.device_handle {
            h
        } else {
            let h = pool.allocate(self.host_data.len())?;
            self.device_handle = Some(h);
            h
        };

        // Upload data
        pool.upload(&handle, &self.host_data)?;

        // Update state
        self.state = CoherencyState::Synced;
        Ok(())
    }

    /// Download device data to host
    #[cfg(feature = "cuda")]
    pub fn download_to_host(&mut self, pool: &VramPool) -> Result<(), GpuError> {
        let handle = self
            .device_handle
            .ok_or_else(|| GpuError::InvalidValue("No device handle".to_string()))?;

        self.host_data = pool.download(&handle)?;
        self.state = CoherencyState::Synced;
        self.bump_version();
        Ok(())
    }

    /// Sync: ensure host and device are consistent
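    ///
    /// A sketch of the `HostDirty` path (assuming `pool` is a live `VramPool`
    /// and `engram` has already been uploaded):
    ///
    /// ```rust,ignore
    /// engram.upload_to_device(&pool)?;   // Synced
    /// engram.host_data_mut().push(0xFF); // HostDirty
    /// engram.sync(&pool)?;               // re-uploads, back to Synced
    /// assert_eq!(engram.state(), CoherencyState::Synced);
    /// ```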
    #[cfg(feature = "cuda")]
    pub fn sync(&mut self, pool: &VramPool) -> Result<(), GpuError> {
        match self.state {
            CoherencyState::HostDirty => {
                // Upload host changes to device
                if self.device_handle.is_some() {
                    self.upload_to_device(pool)?;
                }
                // If not on device, nothing to sync
                self.state = if self.device_handle.is_some() {
                    CoherencyState::Synced
                } else {
                    CoherencyState::HostOnly
                };
            }
            CoherencyState::DeviceDirty => {
                // Download device changes to host
                self.download_to_host(pool)?;
            }
            _ => {
                // Already synced or single-location
            }
        }
        Ok(())
    }

    /// Mark device data as dirty (modified on GPU)
    pub fn mark_device_dirty(&mut self) {
        if self.device_handle.is_some() {
            self.state = CoherencyState::DeviceDirty;
            self.bump_version();
        }
    }

    /// Release device memory
    #[cfg(feature = "cuda")]
    pub fn release_device(&mut self, pool: &VramPool) -> Result<(), GpuError> {
        if let Some(handle) = self.device_handle.take() {
            // Ensure we have current data on host first
            if self.state == CoherencyState::DeviceDirty {
                let data = pool.download(&handle)?;
                self.host_data = data;
            }
            pool.free(handle)?;
            self.state = CoherencyState::HostOnly;
        }
        Ok(())
    }

    /// Get device handle (if on device)
    #[cfg(feature = "cuda")]
    pub fn device_handle(&self) -> Option<VramHandle> {
        self.device_handle
    }
}

// Stubs for non-CUDA builds
#[cfg(not(feature = "cuda"))]
impl CoherentEngram {
    /// Stub for `upload_to_device`; always fails because CUDA support is disabled
    pub fn upload_to_device(&mut self, _pool: &()) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    /// Stub for `download_to_host`; always fails because CUDA support is disabled
    pub fn download_to_host(&mut self, _pool: &()) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    /// Stub for `sync`; always fails because CUDA support is disabled
    pub fn sync(&mut self, _pool: &()) -> Result<(), GpuError> {
        Err(GpuError::NotAvailable)
    }

    /// Stub for `release_device`; succeeds trivially since nothing is on a device
    pub fn release_device(&mut self, _pool: &()) -> Result<(), GpuError> {
        Ok(())
    }
}

/// Coherency manager for multiple engrams
///
/// Manages a collection of coherent engrams with batched sync operations.
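///
/// A small usage sketch (with the `cuda` feature enabled; `pool` is assumed
/// to be a `VramPool`):
///
/// ```rust,ignore
/// let mut manager = CoherencyManager::new();
/// let id = manager.register(CoherentEngram::new(vec![1, 2, 3]));
///
/// // Mutate the engram through the manager, then flush all dirty engrams at once.
/// if let Some(engram) = manager.get_mut(id) {
///     engram.host_data_mut().push(4);
/// }
/// manager.sync_all(&pool)?;
///
/// assert_eq!(manager.stats().total, 1);
/// ```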
pub struct CoherencyManager {
    /// Tracked engrams by ID
    engrams: std::collections::HashMap<u64, CoherentEngram>,
    /// Next engram ID
    next_id: AtomicU64,
}

impl Default for CoherencyManager {
    fn default() -> Self {
        Self::new()
    }
}

impl CoherencyManager {
    /// Create a new coherency manager
    pub fn new() -> Self {
        Self {
            engrams: std::collections::HashMap::new(),
            next_id: AtomicU64::new(1),
        }
    }

    /// Register a new engram
    pub fn register(&mut self, engram: CoherentEngram) -> u64 {
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        self.engrams.insert(id, engram);
        id
    }

    /// Get an engram by ID
    pub fn get(&self, id: u64) -> Option<&CoherentEngram> {
        self.engrams.get(&id)
    }

    /// Get mutable engram by ID
    pub fn get_mut(&mut self, id: u64) -> Option<&mut CoherentEngram> {
        self.engrams.get_mut(&id)
    }

    /// Remove an engram
    pub fn remove(&mut self, id: u64) -> Option<CoherentEngram> {
        self.engrams.remove(&id)
    }

    /// Sync all dirty engrams
    #[cfg(feature = "cuda")]
    pub fn sync_all(&mut self, pool: &VramPool) -> Result<(), GpuError> {
        for engram in self.engrams.values_mut() {
            if engram.state().needs_sync() {
                engram.sync(pool)?;
            }
        }
        Ok(())
    }

    /// Get statistics
    pub fn stats(&self) -> CoherencyStats {
        let total = self.engrams.len();
        let host_only = self
            .engrams
            .values()
            .filter(|e| e.state() == CoherencyState::HostOnly)
            .count();
        let device_only = self
            .engrams
            .values()
            .filter(|e| e.state() == CoherencyState::DeviceOnly)
            .count();
        let synced = self
            .engrams
            .values()
            .filter(|e| e.state() == CoherencyState::Synced)
            .count();
        let dirty = self
            .engrams
            .values()
            .filter(|e| e.state().needs_sync())
            .count();

        CoherencyStats {
            total,
            host_only,
            device_only,
            synced,
            dirty,
        }
    }
}

/// Statistics about coherency state
#[derive(Clone, Debug, Default)]
pub struct CoherencyStats {
    /// Total engrams tracked
    pub total: usize,
    /// Host-only engrams
    pub host_only: usize,
    /// Device-only engrams
    pub device_only: usize,
    /// Synced engrams
    pub synced: usize,
    /// Dirty engrams needing sync
    pub dirty: usize,
}

// ============================================================================
// Multi-Tier Coherency Protocol (#48)
// ============================================================================

/// Memory tier identifier for multi-tier coherency
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Tier {
    /// GPU VRAM
    Vram = 0,
    /// Host RAM
    Host = 1,
    /// Disk storage
    Disk = 2,
}

impl Tier {
    /// Get tier priority (lower = faster)
    pub fn priority(&self) -> u8 {
        match self {
            Tier::Vram => 0,
            Tier::Host => 1,
            Tier::Disk => 2,
        }
    }
}

/// Bitmask tracking which tiers have valid copies of data
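///
/// A short sketch of typical mask operations:
///
/// ```rust,ignore
/// let mut mask = TierMask::NONE;
/// mask.add(Tier::Disk);
/// mask.add(Tier::Host);
///
/// assert!(mask.has(Tier::Host) && !mask.has(Tier::Vram));
/// assert_eq!(mask.count(), 2);
/// // Reads prefer the fastest tier that holds a valid copy.
/// assert_eq!(mask.fastest(), Some(Tier::Host));
/// ```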
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TierMask(u8);

impl TierMask {
    /// Empty mask (no valid copies)
    pub const NONE: TierMask = TierMask(0);
    /// VRAM tier bit
    pub const VRAM: TierMask = TierMask(1 << 0);
    /// Host tier bit
    pub const HOST: TierMask = TierMask(1 << 1);
    /// Disk tier bit
    pub const DISK: TierMask = TierMask(1 << 2);

    /// Create a new mask from a tier
    pub fn from_tier(tier: Tier) -> Self {
        match tier {
            Tier::Vram => Self::VRAM,
            Tier::Host => Self::HOST,
            Tier::Disk => Self::DISK,
        }
    }

    /// Check if tier has valid copy
    pub fn has(&self, tier: Tier) -> bool {
        let bit = match tier {
            Tier::Vram => Self::VRAM.0,
            Tier::Host => Self::HOST.0,
            Tier::Disk => Self::DISK.0,
        };
        (self.0 & bit) != 0
    }

    /// Add a tier to the mask
    pub fn add(&mut self, tier: Tier) {
        self.0 |= TierMask::from_tier(tier).0;
    }

    /// Remove a tier from the mask
    pub fn remove(&mut self, tier: Tier) {
        self.0 &= !TierMask::from_tier(tier).0;
    }

    /// Union of two masks
    pub fn union(&self, other: TierMask) -> TierMask {
        TierMask(self.0 | other.0)
    }

    /// Count number of valid tiers
    pub fn count(&self) -> u32 {
        self.0.count_ones()
    }

    /// Check if any tier has valid copy
    pub fn any(&self) -> bool {
        self.0 != 0
    }

    /// Get the fastest tier with valid copy
    pub fn fastest(&self) -> Option<Tier> {
        if self.has(Tier::Vram) {
            Some(Tier::Vram)
        } else if self.has(Tier::Host) {
            Some(Tier::Host)
        } else if self.has(Tier::Disk) {
            Some(Tier::Disk)
        } else {
            None
        }
    }

    /// Iterator over valid tiers
    pub fn iter(&self) -> impl Iterator<Item = Tier> {
        let mask = *self;
        [Tier::Vram, Tier::Host, Tier::Disk]
            .into_iter()
            .filter(move |&t| mask.has(t))
    }
}

/// Write policy for multi-tier coherency
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum WritePolicy {
    /// Write to fastest tier only, mark others stale
    #[default]
    WriteBack,
    /// Write to fastest tier and immediately propagate to home tier
    WriteThrough,
    /// Write to all tiers with valid copies
    WriteAll,
}

/// Extended coherency state for multi-tier systems
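///
/// A minimal lifecycle sketch:
///
/// ```rust,ignore
/// let mut state = TieredState::host_resident();
/// state.record_write(Tier::Vram, WritePolicy::WriteBack); // VRAM now owns dirty data
/// assert!(state.needs_sync());
///
/// state.sync_to(Tier::Host); // copy back to the home tier
/// assert!(!state.needs_sync());
/// ```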
#[derive(Clone, Debug)]
pub struct TieredState {
    /// Mask of tiers with valid copies
    valid: TierMask,
    /// Mask of tiers with dirty (modified) data
    dirty: TierMask,
    /// The tier that was most recently written
    owner: Option<Tier>,
    /// Home tier (where data should be persisted)
    home: Tier,
    /// Current epoch for sync tracking
    epoch: u64,
}

impl TieredState {
    /// Create a new state with data on one tier
    pub fn new(tier: Tier, home: Tier) -> Self {
        Self {
            valid: TierMask::from_tier(tier),
            dirty: TierMask::NONE,
            owner: Some(tier),
            home,
            epoch: 0,
        }
    }

    /// Create state for host-resident data
    pub fn host_resident() -> Self {
        Self::new(Tier::Host, Tier::Host)
    }

    /// Create state for disk-resident data
    pub fn disk_resident() -> Self {
        Self::new(Tier::Disk, Tier::Disk)
    }

    /// Check if tier has valid copy
    pub fn is_valid(&self, tier: Tier) -> bool {
        self.valid.has(tier)
    }

    /// Check if tier is dirty
    pub fn is_dirty(&self, tier: Tier) -> bool {
        self.dirty.has(tier)
    }

    /// Check if any tier needs sync
    pub fn needs_sync(&self) -> bool {
        self.dirty.any()
    }

    /// Get fastest tier with valid data
    pub fn fastest_valid(&self) -> Option<Tier> {
        self.valid.fastest()
    }

    /// Get the owner tier (most recent write)
    pub fn owner(&self) -> Option<Tier> {
        self.owner
    }

    /// Get current epoch
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Check whether the given tier currently has a valid copy of the data.
    pub fn has_valid_copy(&self, tier: Tier) -> bool {
        self.valid.has(tier)
    }

    /// Record a write to a tier
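    ///
    /// A sketch of how the policies differ (matching the behavior below):
    ///
    /// ```rust,ignore
    /// // WriteBack: only the writer keeps a valid copy, and it is dirty.
    /// let mut state = TieredState::host_resident();
    /// state.record_write(Tier::Vram, WritePolicy::WriteBack);
    /// assert!(state.is_dirty(Tier::Vram) && !state.is_valid(Tier::Host));
    ///
    /// // WriteThrough: writer and home tier stay valid, nothing is dirty.
    /// let mut state = TieredState::host_resident();
    /// state.record_write(Tier::Vram, WritePolicy::WriteThrough);
    /// assert!(state.is_valid(Tier::Host) && !state.needs_sync());
    /// ```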
    pub fn record_write(&mut self, tier: Tier, policy: WritePolicy) {
        match policy {
            WritePolicy::WriteBack => {
                // Invalidate all other copies
                self.valid = TierMask::from_tier(tier);
                self.dirty = TierMask::NONE;
                self.dirty.add(tier);
            }
            WritePolicy::WriteThrough => {
                // Keep valid on writer and home tier
                self.valid = TierMask::from_tier(tier);
                self.valid.add(self.home);
                self.dirty = TierMask::NONE;
            }
            WritePolicy::WriteAll => {
                // All valid copies remain valid; ensure writer tier is marked valid
                self.valid.add(tier);
                self.dirty = TierMask::NONE;
            }
        }
        self.owner = Some(tier);
        self.epoch += 1;
    }

    /// Mark a tier as synced (has current data)
    pub fn mark_synced(&mut self, tier: Tier) {
        self.valid.add(tier);
        self.dirty.remove(tier);
    }

    /// Invalidate a tier (remove from valid set)
    pub fn invalidate(&mut self, tier: Tier) {
        self.valid.remove(tier);
        self.dirty.remove(tier);
        if self.owner == Some(tier) {
            self.owner = self.valid.fastest();
        }
    }

    /// Sync dirty data from owner to target tier.
    ///
    /// Returns `true` if a sync was performed.
    pub fn sync_to(&mut self, target: Tier) -> bool {
        if !self.needs_sync() || self.is_valid(target) {
            return false;
        }
        self.valid.add(target);
        if target == self.home {
            self.dirty = TierMask::NONE;
        }
        true
    }

    /// Get tiers that need sync
    pub fn tiers_needing_sync(&self) -> Vec<Tier> {
        if !self.needs_sync() {
            return vec![];
        }
        // If dirty, we need to sync to home tier
        if !self.is_valid(self.home) {
            vec![self.home]
        } else {
            vec![]
        }
    }
}

/// Multi-tier coherent data block
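///
/// A brief sketch of tracking a block across tiers:
///
/// ```rust,ignore
/// let mut block = TieredBlock::new_host(7, 4096);
/// block.write(Tier::Vram, WritePolicy::WriteBack);
///
/// assert!(block.needs_sync());     // host copy is now stale
/// assert_eq!(block.version(), 1);  // each write bumps the version
/// assert!(block.read(Tier::Vram)); // VRAM holds the only valid copy
/// assert!(!block.read(Tier::Host));
/// ```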
#[derive(Debug)]
pub struct TieredBlock {
    /// Block ID
    pub id: u64,
    /// Size in bytes
    pub size: usize,
    /// Coherency state
    state: TieredState,
    /// Version for optimistic concurrency
    version: AtomicU64,
}

impl TieredBlock {
    /// Create a new block on host tier
    pub fn new_host(id: u64, size: usize) -> Self {
        Self {
            id,
            size,
            state: TieredState::host_resident(),
            version: AtomicU64::new(0),
        }
    }

    /// Create a new block on disk tier
    pub fn new_disk(id: u64, size: usize) -> Self {
        Self {
            id,
            size,
            state: TieredState::disk_resident(),
            version: AtomicU64::new(0),
        }
    }

    /// Get current state
    pub fn state(&self) -> &TieredState {
        &self.state
    }

    /// Get mutable state
    pub fn state_mut(&mut self) -> &mut TieredState {
        &mut self.state
    }

    /// Get version
    pub fn version(&self) -> u64 {
        self.version.load(Ordering::Acquire)
    }

    /// Bump version
    pub fn bump_version(&self) {
        self.version.fetch_add(1, Ordering::AcqRel);
    }

    /// Check if a tier has a valid copy for reading
    pub fn read(&self, tier: Tier) -> bool {
        self.state.has_valid_copy(tier)
    }

    /// Record a write to tier
    pub fn write(&mut self, tier: Tier, policy: WritePolicy) {
        self.state.record_write(tier, policy);
        self.bump_version();
    }

    /// Check if sync is needed
    pub fn needs_sync(&self) -> bool {
        self.state.needs_sync()
    }
}

/// Sync protocol for batched coherency operations
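///
/// A small sketch of queueing and draining syncs:
///
/// ```rust,ignore
/// let protocol = SyncProtocol::with_policy(WritePolicy::WriteBack);
/// protocol.queue_sync(42, Tier::Vram, Tier::Host);
///
/// for (block_id, from, to) in protocol.drain_pending() {
///     // perform the actual transfer for `block_id` from `from` to `to` here
/// }
/// protocol.advance_epoch();
/// ```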
#[derive(Debug, Default)]
pub struct SyncProtocol {
    /// Current global epoch
    epoch: AtomicU64,
    /// Write policy
    policy: WritePolicy,
    /// Pending syncs (block_id, source_tier, target_tier)
    pending: std::sync::RwLock<Vec<(u64, Tier, Tier)>>,
}

impl SyncProtocol {
    /// Create with default write policy
    pub fn new() -> Self {
        Self {
            epoch: AtomicU64::new(0),
            policy: WritePolicy::default(),
            pending: std::sync::RwLock::new(Vec::new()),
        }
    }

    /// Create with specific write policy
    pub fn with_policy(policy: WritePolicy) -> Self {
        Self {
            epoch: AtomicU64::new(0),
            policy,
            pending: std::sync::RwLock::new(Vec::new()),
        }
    }

    /// Get current epoch
    pub fn epoch(&self) -> u64 {
        self.epoch.load(Ordering::Acquire)
    }

    /// Advance the global epoch (called after a barrier sync), returning the previous epoch
    pub fn advance_epoch(&self) -> u64 {
        self.epoch.fetch_add(1, Ordering::AcqRel)
    }

    /// Get write policy
    pub fn policy(&self) -> WritePolicy {
        self.policy
    }

    /// Queue a sync operation
    pub fn queue_sync(&self, block_id: u64, from: Tier, to: Tier) {
        let mut pending = self
            .pending
            .write()
            .expect("SyncProtocol pending lock poisoned in queue_sync");
        pending.push((block_id, from, to));
    }

    /// Get and clear pending syncs
    pub fn drain_pending(&self) -> Vec<(u64, Tier, Tier)> {
        let mut pending = self
            .pending
            .write()
            .expect("SyncProtocol pending lock poisoned in drain_pending");
        std::mem::take(&mut *pending)
    }

    /// Check if any syncs are pending
    pub fn has_pending(&self) -> bool {
        let pending = self
            .pending
            .read()
            .expect("SyncProtocol pending lock poisoned in has_pending");
        !pending.is_empty()
    }

    /// Count pending syncs
    pub fn pending_count(&self) -> usize {
        let pending = self
            .pending
            .read()
            .expect("SyncProtocol pending lock poisoned in pending_count");
        pending.len()
    }
}

/// Statistics for multi-tier coherency
#[derive(Clone, Debug, Default)]
pub struct TieredCoherencyStats {
    /// Total blocks tracked
    pub total_blocks: usize,
    /// Blocks with VRAM copy
    pub vram_copies: usize,
    /// Blocks with host copy
    pub host_copies: usize,
    /// Blocks with disk copy
    pub disk_copies: usize,
    /// Blocks needing sync
    pub dirty_blocks: usize,
    /// Total syncs performed
    pub sync_count: u64,
    /// Current epoch
    pub epoch: u64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coherency_state_checks() {
        assert!(CoherencyState::HostOnly.host_is_current());
        assert!(!CoherencyState::HostOnly.device_is_current());

        assert!(CoherencyState::Synced.host_is_current());
        assert!(CoherencyState::Synced.device_is_current());

        assert!(CoherencyState::HostDirty.needs_sync());
        assert!(CoherencyState::DeviceDirty.needs_sync());
        assert!(!CoherencyState::Synced.needs_sync());
    }

    #[test]
    fn test_coherent_engram_new() {
        let data = vec![1, 2, 3, 4];
        let engram = CoherentEngram::new(data.clone());

        assert_eq!(engram.state(), CoherencyState::HostOnly);
        assert_eq!(engram.host_data(), &data);
        assert!(!engram.is_on_device());
    }

    #[test]
    fn test_coherent_engram_sparse_vec_roundtrip() {
        let vec = SparseVec {
            pos: vec![1, 5, 10, 100],
            neg: vec![2, 7, 50],
        };

        let engram = CoherentEngram::from_sparse_vec(&vec);
        let recovered = engram.to_sparse_vec().unwrap();

        assert_eq!(recovered.pos, vec.pos);
        assert_eq!(recovered.neg, vec.neg);
    }

    #[test]
    fn test_coherent_engram_modify_marks_dirty() {
        let mut engram = CoherentEngram::new(vec![1, 2, 3]);
        assert_eq!(engram.state(), CoherencyState::HostOnly);

        // Modifying host data when Synced should mark as HostDirty
        engram.state = CoherencyState::Synced;
        let _ = engram.host_data_mut();
        assert_eq!(engram.state(), CoherencyState::HostDirty);
    }

    #[test]
    fn test_coherency_manager() {
        let mut manager = CoherencyManager::new();

        let e1 = CoherentEngram::new(vec![1, 2, 3]);
        let e2 = CoherentEngram::new(vec![4, 5, 6]);

        let id1 = manager.register(e1);
        let id2 = manager.register(e2);

        assert!(manager.get(id1).is_some());
        assert!(manager.get(id2).is_some());
        assert!(manager.get(999).is_none());

        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.host_only, 2);
    }

    // Multi-tier coherency tests (#48)

    #[test]
    fn test_tier_priority() {
        assert_eq!(Tier::Vram.priority(), 0);
        assert_eq!(Tier::Host.priority(), 1);
        assert_eq!(Tier::Disk.priority(), 2);
    }

    #[test]
    fn test_tier_mask_operations() {
        let mut mask = TierMask::NONE;
        assert!(!mask.any());
        assert_eq!(mask.count(), 0);

        mask.add(Tier::Host);
        assert!(mask.has(Tier::Host));
        assert!(!mask.has(Tier::Vram));
        assert_eq!(mask.count(), 1);

        mask.add(Tier::Vram);
        assert!(mask.has(Tier::Host));
        assert!(mask.has(Tier::Vram));
        assert_eq!(mask.count(), 2);

        mask.remove(Tier::Host);
        assert!(!mask.has(Tier::Host));
        assert!(mask.has(Tier::Vram));
        assert_eq!(mask.count(), 1);
    }

    #[test]
    fn test_tier_mask_fastest() {
        let mut mask = TierMask::NONE;
        assert_eq!(mask.fastest(), None);

        mask.add(Tier::Disk);
        assert_eq!(mask.fastest(), Some(Tier::Disk));

        mask.add(Tier::Host);
        assert_eq!(mask.fastest(), Some(Tier::Host));

        mask.add(Tier::Vram);
        assert_eq!(mask.fastest(), Some(Tier::Vram));
    }

    #[test]
    fn test_tier_mask_union() {
        let a = TierMask::HOST;
        let b = TierMask::DISK;
        let union = a.union(b);

        assert!(union.has(Tier::Host));
        assert!(union.has(Tier::Disk));
        assert!(!union.has(Tier::Vram));
    }

    #[test]
    fn test_tiered_state_new() {
        let state = TieredState::host_resident();
        assert!(state.is_valid(Tier::Host));
        assert!(!state.is_valid(Tier::Vram));
        assert!(!state.is_valid(Tier::Disk));
        assert!(!state.needs_sync());
    }

    #[test]
    fn test_tiered_state_write_back() {
        let mut state = TieredState::host_resident();

        // Write to VRAM with writeback policy
        state.record_write(Tier::Vram, WritePolicy::WriteBack);

        assert!(state.is_valid(Tier::Vram));
        assert!(!state.is_valid(Tier::Host)); // Invalidated
        assert!(state.is_dirty(Tier::Vram));
        assert!(state.needs_sync());
        assert_eq!(state.owner(), Some(Tier::Vram));
    }

    #[test]
    fn test_tiered_state_write_through() {
        let mut state = TieredState::host_resident();

        // Write to VRAM with writethrough policy
        state.record_write(Tier::Vram, WritePolicy::WriteThrough);

        assert!(state.is_valid(Tier::Vram));
        assert!(state.is_valid(Tier::Host)); // Still valid (home tier)
        assert!(!state.needs_sync()); // No sync needed
    }

    #[test]
    fn test_tiered_state_write_all() {
        let mut state = TieredState::host_resident();

        // Mark multiple tiers as valid
        state.mark_synced(Tier::Vram);
        state.mark_synced(Tier::Disk);
        assert!(state.is_valid(Tier::Host));
        assert!(state.is_valid(Tier::Vram));
        assert!(state.is_valid(Tier::Disk));

        // Write with WriteAll policy - all copies remain valid
        state.record_write(Tier::Host, WritePolicy::WriteAll);

        // Writer tier should be valid
        assert!(state.is_valid(Tier::Host));
        // WriteAll keeps existing valid copies (unlike WriteBack which invalidates)
        // Note: WriteAll means the write is broadcast to all valid copies
        assert!(!state.needs_sync()); // No sync needed with WriteAll
        assert_eq!(state.owner(), Some(Tier::Host));
    }

    #[test]
    fn test_has_valid_copy() {
        let state = TieredState::host_resident();

        assert!(state.has_valid_copy(Tier::Host));
        assert!(!state.has_valid_copy(Tier::Vram));
        assert!(!state.has_valid_copy(Tier::Disk));
    }

    #[test]
    fn test_tiered_state_sync() {
        let mut state = TieredState::host_resident();
        state.record_write(Tier::Vram, WritePolicy::WriteBack);

        // Sync to host
        let synced = state.sync_to(Tier::Host);
        assert!(synced);
        assert!(state.is_valid(Tier::Host));
        assert!(state.is_valid(Tier::Vram));
        assert!(!state.needs_sync());
    }

    #[test]
    fn test_tiered_state_invalidate() {
        let mut state = TieredState::host_resident();
        state.mark_synced(Tier::Vram);
        assert!(state.is_valid(Tier::Vram));

        state.invalidate(Tier::Vram);
        assert!(!state.is_valid(Tier::Vram));
        assert!(state.is_valid(Tier::Host));
    }

    #[test]
    fn test_tiered_block() {
        let mut block = TieredBlock::new_host(1, 1024);
        assert_eq!(block.id, 1);
        assert_eq!(block.size, 1024);
        assert_eq!(block.version(), 0);

        block.write(Tier::Host, WritePolicy::WriteBack);
        assert_eq!(block.version(), 1);
        assert!(block.needs_sync());
    }

    #[test]
    fn test_sync_protocol() {
        let protocol = SyncProtocol::new();
        assert_eq!(protocol.epoch(), 0);
        assert!(!protocol.has_pending());

        protocol.queue_sync(1, Tier::Vram, Tier::Host);
        protocol.queue_sync(2, Tier::Host, Tier::Disk);
        assert!(protocol.has_pending());
        assert_eq!(protocol.pending_count(), 2);

        let pending = protocol.drain_pending();
        assert_eq!(pending.len(), 2);
        assert!(!protocol.has_pending());

        let new_epoch = protocol.advance_epoch();
        assert_eq!(new_epoch, 0);
        assert_eq!(protocol.epoch(), 1);
    }

    #[test]
    fn test_sync_protocol_policy() {
        let default = SyncProtocol::new();
        assert_eq!(default.policy(), WritePolicy::WriteBack);

        let writethrough = SyncProtocol::with_policy(WritePolicy::WriteThrough);
        assert_eq!(writethrough.policy(), WritePolicy::WriteThrough);
    }
}