xlog-prob 0.5.0

Probabilistic inference engines for XLOG
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
//! GPU-native knowledge compilation.
//!
//! This module is the home of GPU-native compilation + verification utilities.
//!
//! Production correctness requires the GPU CDCL equivalence verifier (see `validation`).

use std::sync::Arc;
use std::time::Instant;

use cudarc::driver::DeviceSlice;
use xlog_core::{Result, XlogError};
use xlog_cuda::memory::TrackedCudaSlice;
use xlog_cuda::CudaKernelProvider;
use xlog_solve::{GpuCdclConfig, GpuCnf};

use crate::compilation::gpu_cache::{GpuCircuitCache, GpuCircuitCacheHandle};
use crate::gpu::GpuXgcf;

pub mod disk_cache;
pub mod gpu_cache;
pub mod gpu_cnf;
pub mod gpu_d4;
pub mod gpu_pir;
pub mod gpu_pir_intern;
pub mod gpu_weights;
pub mod sparse_matrix;
pub mod validation;

pub use gpu_cnf::{encode_cnf_gpu, GpuCnfEncoding, GpuCnfVarTables};
pub use gpu_d4::GpuCompileConfig;
pub use gpu_pir::{GpuPirGraph, GpuPirRoots, PIR_AND, PIR_LIT, PIR_NEG_LIT, PIR_OR};
// PIR_CONST and PIR_DECISION are used within gpu_pir.rs and gpu_pir_intern.rs
// via direct module paths; no crate-level re-export needed.
pub use gpu_pir_intern::{GpuPirInterner, PirBatch};
pub use gpu_weights::GpuWeights;
pub use gpu_weights::{
    apply_query_vars_device, build_evidence_by_var_gpu, build_weights_gpu, map_nodes_to_vars_gpu,
    restore_query_vars_device,
};
// GpuCsrCnf is currently unused (dead code); remove re-export.
pub use validation::{
    build_equivalence_queries_gpu, validate_equivalence_gpu, validate_equivalence_gpu_gated,
    GpuEquivalenceConfig, GpuEquivalenceQueries,
};
// check_equivalence_gpu and check_equivalence_gpu_gated are called only
// within validation.rs itself; no crate-level re-export needed.

/// Per-stage compilation timing (populated only when XLOG_WARMUP_PROFILE=1).
///
/// Returned by `compile_gpu_d4_and_verify_cached` only when profiling is enabled. Each `*_sec`
/// field is wall-clock seconds for one pipeline stage, taken after a device synchronize so queued
/// GPU work is included. Stages that did not run (e.g. after a cache hit) keep the default `0.0`.
#[derive(Debug, Clone, Default)]
pub struct CircuitCompileProfile {
    /// Hashing the input CNF on the device (`gpu_cache::hash_cnf_gpu`).
    pub cnf_hash_sec: f64,
    /// D4 knowledge compilation (`gpu_d4::compile_gpu_d4_gated`).
    pub d4_compile_sec: f64,
    /// GPU CDCL equivalence check of the unsmoothed circuit.
    pub verify_sec: f64,
    /// Smoothing over the random variables (only a device sync when there are none).
    pub smooth_sec: f64,
    /// Storing the evaluation circuit into the GPU circuit cache.
    pub cache_store_sec: f64,
    /// Free-variable mask computation, its host readback, and the optional mask store.
    pub free_var_mask_sec: f64,
    /// `true` when the GPU circuit cache already held the circuit, so nothing was compiled.
    pub gpu_cache_hit: bool,
    /// `true` when the circuit was restored from the on-disk cache instead of compiled.
    pub disk_cache_hit: bool,
}

/// Reports whether per-stage warmup profiling is switched on.
///
/// Only the exact value `"1"` in `XLOG_WARMUP_PROFILE` enables it; an unset variable, a
/// non-UTF-8 value, or any other string (including `"true"` or `" 1"`) leaves it off.
fn warmup_profiling_enabled() -> bool {
    matches!(std::env::var("XLOG_WARMUP_PROFILE").as_deref(), Ok("1"))
}

/// Device-resident random-variable list for GPU smoothing.
///
/// Holds a device buffer of `u32` variable ids plus `count`, the number of leading entries that
/// are valid. Both constructors guarantee `count <= list.len()`.
pub struct DeviceRandomVarList {
    list: TrackedCudaSlice<u32>,
    count: u32,
}

impl DeviceRandomVarList {
    /// Wraps an existing device buffer whose first `count` entries are the random variables.
    ///
    /// # Errors
    ///
    /// Returns `XlogError::Compilation` if the buffer length does not fit in `u32` or if
    /// `count` exceeds the buffer length.
    pub fn from_device(list: TrackedCudaSlice<u32>, count: u32) -> Result<Self> {
        let len = u32::try_from(list.len()).map_err(|_| {
            XlogError::Compilation("DeviceRandomVarList: list length exceeds u32".to_string())
        })?;
        if count > len {
            return Err(XlogError::Compilation(format!(
                "DeviceRandomVarList: count {} exceeds list len {}",
                count, len
            )));
        }
        Ok(Self { list, count })
    }

    /// Uploads `host` into a freshly allocated device buffer; every entry is considered valid.
    ///
    /// # Errors
    ///
    /// - `XlogError::Compilation` if `host.len()` does not fit in `u32`. This is checked before
    ///   any device allocation, so oversized input never touches the GPU.
    /// - Any error from the provider's device allocation.
    /// - `XlogError::Kernel` if the host-to-device copy fails.
    pub fn from_host(provider: &CudaKernelProvider, host: &[u32]) -> Result<Self> {
        // Validate the length first: the original order allocated and uploaded the whole
        // buffer before discovering it could not be described by a `u32` count.
        let count = u32::try_from(host.len()).map_err(|_| {
            XlogError::Compilation("DeviceRandomVarList: host len exceeds u32".to_string())
        })?;
        let memory = provider.memory();
        let mut list = memory.alloc::<u32>(host.len())?;
        if !host.is_empty() {
            provider
                .device()
                .inner()
                .htod_sync_copy_into(host, &mut list)
                .map_err(|e| {
                    XlogError::Kernel(format!("DeviceRandomVarList upload failed: {}", e))
                })?;
        }
        Ok(Self { list, count })
    }

    /// Returns `true` when no random variables are recorded (`count == 0`).
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Number of valid leading entries in [`Self::list`].
    pub fn count(&self) -> u32 {
        self.count
    }

    /// The underlying device buffer; only the first [`Self::count`] entries are meaningful.
    pub fn list(&self) -> &TrackedCudaSlice<u32> {
        &self.list
    }
}

/// Compile CNF on GPU, then verify equivalence with GPU CDCL.
///
/// This uncached variant neither consults a circuit cache nor smooths the circuit.
/// `decision_var_limit` is passed to the verifier unchanged.
///
/// # Errors
///
/// Fails if `config.cdcl_conflict_budget` is set or if the CDCL sizing derived from `config` is
/// invalid; both are checked before any GPU work starts. Also fails if compilation or the
/// equivalence check fails.
pub fn compile_gpu_d4_and_verify(
    cnf: &GpuCnf,
    decision_var_limit: &TrackedCudaSlice<u32>,
    provider: &Arc<CudaKernelProvider>,
    config: &GpuCompileConfig,
) -> Result<GpuXgcf> {
    if config.cdcl_conflict_budget.is_some() {
        return Err(XlogError::Compilation(
            "cdcl_conflict_budget is not supported by the GPU CDCL verifier".to_string(),
        ));
    }
    // Derive the verifier config before compiling. It is a pure check on `config`, and doing it
    // first avoids spending a full GPU compile on a configuration the verifier will reject.
    let cdcl = cdcl_config_from_compile(config)?;
    let circuit = gpu_d4::compile_gpu_d4(cnf, provider, config)?;
    validate_equivalence_gpu(
        cnf,
        decision_var_limit,
        &circuit,
        provider,
        GpuEquivalenceConfig {
            cdcl,
            reuse_workspace: config.incremental_verify,
        },
    )?;
    Ok(circuit)
}

/// Compile CNF on GPU, cache the circuit, then verify equivalence with GPU CDCL.
///
/// Pipeline:
/// 1. Hash the CNF on the device and look it up in the GPU cache.
/// 2. On a GPU cache miss, try the disk cache.
/// 3. Otherwise run the D4 compile.
/// 4. Run the equivalence check on the unsmoothed circuit.
/// 5. Smooth over `random_vars`.
/// 6. Store the result in the GPU cache.
/// 7. Compute the free-var mask.
/// 8. Write to the disk cache (best effort).
///
/// `canonical_cnf_hash`: a process-independent hash of the PIR structure, used as
/// the `cnf_hash` in the disk cache key. Computed via [`crate::cnf::canonical_pir_hash`].
/// If `None`, disk caching is skipped.
///
/// Returns the cache handle, plus a [`CircuitCompileProfile`] when `XLOG_WARMUP_PROFILE=1`
/// (otherwise `None`).
///
/// # Errors
///
/// - `config.cdcl_conflict_budget` is set, or the smoothing or CDCL sizing derived from `config`
///   is invalid. These checks run before any device work or cache access, so they are reported
///   even when the circuit would have been a cache hit.
/// - Any device transfer, synchronization, compile, verification, smoothing or cache-store
///   failure. Disk cache read and write failures are ignored (the disk cache is best-effort).
///   A failure while restoring a successfully read artifact is propagated.
pub fn compile_gpu_d4_and_verify_cached(
    cnf: &GpuCnf,
    decision_var_limit: &TrackedCudaSlice<u32>,
    provider: &Arc<CudaKernelProvider>,
    config: &GpuCompileConfig,
    cache: &mut GpuCircuitCache,
    random_vars: &DeviceRandomVarList,
    canonical_cnf_hash: Option<u64>,
) -> Result<(GpuCircuitCacheHandle, Option<CircuitCompileProfile>)> {
    if config.cdcl_conflict_budget.is_some() {
        return Err(XlogError::Compilation(
            "cdcl_conflict_budget is not supported by the GPU CDCL verifier".to_string(),
        ));
    }
    // Derive the D4 and verifier configs before any device work or cache access. Both are pure
    // checks on `config`. Failing here means an unusable config never reaches
    // `lookup_or_insert_device`, which (judging by its name) may create a cache entry that a
    // later config error would leave unfilled.
    let d4_config = d4_config_for_smoothing(config, random_vars.count())?;
    let cdcl = cdcl_config_from_compile(config)?;

    let profiling = warmup_profiling_enabled();
    let mut profile = CircuitCompileProfile::default();

    // --- CNF hash stage ---
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: hash_cnf_gpu");
    let t_hash = profiling.then(Instant::now);
    let key = gpu_cache::hash_cnf_gpu(cnf, provider)?;
    if let Some(t0) = t_hash {
        sync_device(provider, "sync after hash_cnf_gpu")?;
        profile.cnf_hash_sec = t0.elapsed().as_secs_f64();
    }
    debug_stage_sync(provider, profiling, "sync after hash_cnf_gpu failed")?;
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: lookup_or_insert_device");
    let lookup = cache.lookup_or_insert_device(&key)?;
    let mut handle = lookup.into_handle()?;

    // --- GPU cache check ---
    //
    // D→H copy of the `compile_needed` flag to decide whether we need to compile at all.
    // If compile_needed == 0, the GPU cache already has the circuit (GPU cache hit).
    // If compile_needed == 1, we check the disk cache before falling through to D4.
    let compile_needed_host: Vec<u32> = provider
        .device()
        .inner()
        .dtoh_sync_copy(handle.compile_needed_device())
        .map_err(|e| XlogError::Kernel(format!("dtoh compile_needed: {}", e)))?;
    // An empty readback would previously panic on `[0]`; surface it as a kernel error instead.
    let compile_needed = compile_needed_host.first().copied().ok_or_else(|| {
        XlogError::Kernel("dtoh compile_needed: flag buffer is empty".to_string())
    })?;

    // GPU cache hit — short-circuit the entire compile pipeline.
    if compile_needed == 0 {
        profile.gpu_cache_hit = true;
        return Ok((handle, profiling.then_some(profile)));
    }

    // --- Disk cache check (only on GPU cache miss) ---
    //
    // The disk key uses the caller-supplied canonical PIR hash (process-independent) instead of
    // the GPU CNF hash (which varies per process due to PirNodeId non-determinism).
    let cache_key = match canonical_cnf_hash {
        Some(cnf_hash) if compile_needed == 1 => Some(disk_cache::CircuitCacheKey {
            cnf_hash,
            config_hash: hash_compile_config(config),
            random_vars_hash: hash_random_vars(random_vars, provider)?,
            sm: detect_compute_capability(provider)?,
        }),
        _ => None,
    };

    if let Some(ref disk_key) = cache_key {
        #[cfg(debug_assertions)]
        eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: checking disk cache");
        // Read errors are treated as misses: the disk cache is best-effort.
        if let Ok(Some(artifact)) = disk_cache::read_artifact(disk_key) {
            #[cfg(debug_assertions)]
            eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: disk cache hit");
            // NOTE(review): a restore failure (e.g. a stale or corrupt artifact) is propagated
            // rather than falling back to a fresh compile — confirm this is intended.
            cache.restore_from_host_arrays(&mut handle, &artifact)?;
            sync_device(provider, "sync after disk cache restore")?;
            profile.disk_cache_hit = true;
            return Ok((handle, profiling.then_some(profile)));
        }
        #[cfg(debug_assertions)]
        eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: disk cache miss");
    }

    // --- D4 compile stage ---
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: compile_gpu_d4_gated");
    let t_d4 = profiling.then(Instant::now);
    let circuit_base =
        gpu_d4::compile_gpu_d4_gated(cnf, provider, &d4_config, handle.compile_needed_device())?;
    if let Some(t0) = t_d4 {
        sync_device(provider, "sync after d4 compile")?;
        profile.d4_compile_sec = t0.elapsed().as_secs_f64();
    }
    debug_stage_sync(provider, profiling, "sync after compile_gpu_d4_gated failed")?;
    if circuit_base.num_nodes() == 0 || circuit_base.num_levels() == 0 {
        // Defensive: D4 returned an empty circuit (the primary GPU cache hit is handled
        // by the compile_needed == 0 early return above; this catches degenerate CNFs).
        // NOTE(review): nothing is stored into `handle` on this path — confirm callers cope
        // with a handle whose slot was never filled.
        return Ok((handle, profiling.then_some(profile)));
    }

    // --- Verify equivalence stage ---
    //
    // Verify equivalence on the *base* circuit (pre-smoothing) to keep the verifier CNFs minimal.
    //
    // `encode_cnf_gpu` sets `decision_var_limit` to the end of the leaf+choice var range. For
    // deterministic programs with no probabilistic vars, this range is empty (limit=0). In that
    // case, the verifier must still be able to branch, so fall back to `cnf.num_vars` (all CNF
    // vars are semantically meaningful when there is no probabilistic decision set).
    let verifier_decision_var_limit = if random_vars.is_empty() {
        &cnf.num_vars
    } else {
        decision_var_limit
    };
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: validate_equivalence_gpu_gated");
    let t_verify = profiling.then(Instant::now);
    validate_equivalence_gpu_gated(
        cnf,
        verifier_decision_var_limit,
        &circuit_base,
        provider,
        GpuEquivalenceConfig {
            cdcl,
            reuse_workspace: config.incremental_verify,
        },
        handle.compile_needed_device(),
    )?;
    if let Some(t0) = t_verify {
        sync_device(provider, "sync after verify")?;
        profile.verify_sec = t0.elapsed().as_secs_f64();
    }
    debug_stage_sync(provider, profiling, "sync after validate_equivalence_gpu_gated failed")?;

    // --- Smoothing stage ---
    //
    // Smoothing is evaluation-only (WMC/grad correctness); it is semantics-preserving and does not
    // need to participate in the equivalence check.
    let t_smooth = profiling.then(Instant::now);
    let circuit_eval = if random_vars.is_empty() {
        circuit_base
    } else {
        #[cfg(debug_assertions)]
        eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: smooth_random_vars_device");
        let smoothed = circuit_base.smooth_random_vars_device(
            provider,
            random_vars.list(),
            random_vars.count(),
            config.smooth_node_cap,
            config.smooth_edge_cap,
        )?;
        debug_stage_sync(provider, profiling, "sync after smooth_random_vars_device failed")?;
        smoothed
    };
    if let Some(t0) = t_smooth {
        sync_device(provider, "sync after smooth")?;
        profile.smooth_sec = t0.elapsed().as_secs_f64();
    }

    // --- Cache store stage ---
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: store_from_xgcf");
    let t_store = profiling.then(Instant::now);
    cache.store_from_xgcf(&mut handle, &circuit_eval)?;
    if let Some(t0) = t_store {
        sync_device(provider, "sync after cache store")?;
        profile.cache_store_sec = t0.elapsed().as_secs_f64();
    }
    debug_stage_sync(provider, profiling, "sync after store_from_xgcf failed")?;

    // --- Free-var mask stage ---
    #[cfg(debug_assertions)]
    eprintln!("[xlog-prob] compile_gpu_d4_and_verify_cached: compute_free_var_mask_gpu_gated");
    let t_fvm = profiling.then(Instant::now);
    let free_var_mask = gpu_d4::compute_free_var_mask_gpu_gated(
        cnf,
        &circuit_eval,
        provider,
        handle.compile_needed_device(),
    )?;
    debug_stage_sync(
        provider,
        profiling,
        "sync after compute_free_var_mask_gpu_gated failed",
    )?;
    // Only enable free-var correction if there are actual free variables.
    // When the mask is all-zero (common for smoothed d-DNNF circuits),
    // skipping this keeps has_free_var_mask[slot]=false, which avoids unnecessary
    // free-var correction kernel launches on every subsequent eval.
    let mask_host: Vec<u8> = provider
        .device()
        .inner()
        .dtoh_sync_copy(&free_var_mask)
        .map_err(|e| XlogError::Kernel(format!("Failed to read free_var_mask: {}", e)))?;
    let has_free_vars = mask_host.iter().any(|&b| b != 0);
    #[cfg(debug_assertions)]
    eprintln!(
        "[xlog-prob] free_var_mask: {} free vars, batched eval {}",
        if has_free_vars { "has" } else { "no" },
        if has_free_vars { "DISABLED" } else { "ENABLED" },
    );
    if has_free_vars {
        cache.store_free_var_mask(&mut handle, &free_var_mask)?;
        debug_stage_sync(provider, profiling, "sync after store_free_var_mask failed")?;
    }
    if let Some(t0) = t_fvm {
        sync_device(provider, "sync after free_var_mask")?;
        profile.free_var_mask_sec = t0.elapsed().as_secs_f64();
    }

    // --- Disk cache write (opportunistic) ---
    //
    // After a successful compilation, write the artifact to disk for the next warm start.
    // Errors are deliberately ignored — the disk cache is best-effort — but the debug log
    // only claims a write when one actually succeeded.
    if let Some(ref disk_key) = cache_key {
        if let Ok(artifact) = cache.build_artifact_from_device(&handle, provider) {
            if disk_cache::write_artifact(disk_key, &artifact).is_ok() {
                #[cfg(debug_assertions)]
                eprintln!(
                    "[xlog-prob] compile_gpu_d4_and_verify_cached: wrote disk cache artifact"
                );
            }
        }
    }

    Ok((handle, profiling.then_some(profile)))
}

/// Blocks until all queued work on `provider`'s device has finished.
///
/// A driver error is reported as `XlogError::Kernel` with the text `"{context}: {error}"`.
fn sync_device(provider: &CudaKernelProvider, context: &str) -> Result<()> {
    provider
        .device()
        .synchronize()
        .map_err(|e| XlogError::Kernel(format!("{}: {}", context, e)))
}

/// Debug builds only: synchronizes after a pipeline stage so a failing kernel is reported at the
/// stage that launched it.
///
/// Skipped when `profiling` is set, because the profiling path already synchronizes at the end
/// of each timed stage. Compiles to a no-op in release builds.
fn debug_stage_sync(provider: &CudaKernelProvider, profiling: bool, context: &str) -> Result<()> {
    if cfg!(debug_assertions) && !profiling {
        sync_device(provider, context)?;
    }
    Ok(())
}

/// Returns the config the D4 compiler should use when the circuit will later be smoothed.
///
/// With no random variables, the caller's config is returned unchanged. Otherwise,
/// `random_var_count + 2` nodes of `smooth_node_cap` are held back as smoothing headroom, and
/// the compile runs with the remaining budget.
///
/// # Errors
///
/// Returns `XlogError::Compilation` in three cases:
/// - the headroom computation overflows;
/// - the headroom consumes the whole node cap;
/// - fewer than 3 base nodes remain.
fn d4_config_for_smoothing(
    config: &GpuCompileConfig,
    random_var_count: u32,
) -> Result<GpuCompileConfig> {
    let mut adjusted = *config;
    if random_var_count == 0 {
        return Ok(adjusted);
    }

    let headroom = random_var_count
        .checked_add(2)
        .ok_or_else(|| XlogError::Compilation("smooth headroom overflow".to_string()))?;

    // A cap that does not strictly exceed the headroom leaves no room for the base circuit.
    let base_cap = match config.smooth_node_cap.checked_sub(headroom) {
        Some(remaining) if remaining > 0 => remaining,
        _ => {
            return Err(XlogError::Compilation(format!(
                "GpuCompileConfig smooth_node_cap {} too small for smoothing headroom {}",
                config.smooth_node_cap, headroom
            )))
        }
    };

    if base_cap < 3 {
        return Err(XlogError::Compilation(
            "GpuCompileConfig smooth_node_cap leaves <3 base nodes".to_string(),
        ));
    }

    adjusted.smooth_node_cap = base_cap;
    Ok(adjusted)
}

/// Translates the compile config's CDCL knobs into a sized [`GpuCdclConfig`] for the verifier.
///
/// The learned-clause arena is sized deterministically from `cdcl_learned_bytes`, assuming every
/// learned clause has `LITS_PER_CLAUSE` literals. The reduce interval is 20× the restart
/// interval.
///
/// # Errors
///
/// Returns `XlogError::Compilation` in these cases:
/// - a knob is zero;
/// - the byte budget cannot hold a single clause;
/// - a derived capacity does not fit in `u32`;
/// - the reduce interval overflows.
fn cdcl_config_from_compile(config: &GpuCompileConfig) -> Result<GpuCdclConfig> {
    // Assumed average learned clause length.
    const LITS_PER_CLAUSE: u32 = 4;
    // Proof record per clause: a two-word (conflict, steps) header plus two u32s per literal.
    const PROOF_U32_PER_CLAUSE: u32 = 2 + 2 * LITS_PER_CLAUSE;
    // Offsets + lbd + activity + flags + proof offsets (rounded up).
    const META_BYTES_PER_CLAUSE: u64 = 24;
    // Metadata + proof words + literal words; evaluated at compile time, so cannot overflow.
    const BYTES_PER_CLAUSE: u64 = META_BYTES_PER_CLAUSE
        + 4 * (PROOF_U32_PER_CLAUSE as u64)
        + 4 * (LITS_PER_CLAUSE as u64);

    let compile_err = |msg: &str| XlogError::Compilation(msg.to_string());

    if config.cdcl_restart_interval == 0 {
        return Err(compile_err("cdcl_restart_interval must be > 0"));
    }
    if config.cdcl_learned_bytes == 0 {
        return Err(compile_err("cdcl_learned_bytes must be > 0"));
    }

    let clause_budget = config.cdcl_learned_bytes / BYTES_PER_CLAUSE;
    if clause_budget == 0 {
        return Err(compile_err(
            "cdcl_learned_bytes too small for learned clause arena",
        ));
    }

    // Range checks run in the order clauses → literals → proof words.
    let max_learned_clauses = u32::try_from(clause_budget)
        .map_err(|_| compile_err("max_learned_clauses exceeds u32::MAX"))?;
    let max_learned_lits = max_learned_clauses
        .checked_mul(LITS_PER_CLAUSE)
        .ok_or_else(|| compile_err("max_learned_lits exceeds u32::MAX"))?;
    let max_proof_u32 = max_learned_clauses
        .checked_mul(PROOF_U32_PER_CLAUSE)
        .ok_or_else(|| compile_err("max_proof_u32 exceeds u32::MAX"))?;

    let reduce_interval = config
        .cdcl_restart_interval
        .checked_mul(20)
        .ok_or_else(|| compile_err("cdcl reduce_interval overflow"))?;

    Ok(GpuCdclConfig {
        max_learned_clauses,
        max_learned_lits,
        max_proof_u32,
        restart_base: config.cdcl_restart_interval,
        reduce_interval,
        ..Default::default()
    })
}

// ---------------------------------------------------------------------------
// Disk cache helpers
// ---------------------------------------------------------------------------

/// 64-bit FNV-1a over `bytes`.
///
/// Unlike `std::collections::hash_map::DefaultHasher`, whose algorithm is unspecified, this is
/// stable across processes and Rust versions, which the on-disk cache keys rely on.
/// Matches the FNV-1a algorithm used in the GPU hash kernel (kernels/cache.cu).
fn fnv1a_u64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes
        .iter()
        .fold(OFFSET_BASIS, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}

/// Fingerprints the `GpuCompileConfig` fields that shape the compiled circuit, for the disk
/// cache key.
///
/// Fields are serialized little-endian in a fixed order, then hashed with `fnv1a_u64`. The
/// order is part of the on-disk key format. The CDCL verifier settings do not change circuit
/// topology; they are included anyway so that changing the verifier configuration also
/// invalidates cached artifacts.
fn hash_compile_config(config: &GpuCompileConfig) -> u64 {
    let encoded: Vec<u8> = [
        &config.frontier_depth.to_le_bytes()[..],
        &config.max_frontier_items.to_le_bytes()[..],
        &config.max_depth.to_le_bytes()[..],
        &config.smooth_node_cap.to_le_bytes()[..],
        &config.smooth_edge_cap.to_le_bytes()[..],
        &config.cdcl_restart_interval.to_le_bytes()[..],
        &config.cdcl_learned_bytes.to_le_bytes()[..],
    ]
    .concat();
    fnv1a_u64(&encoded)
}

/// Fingerprints the smoothing random-variable list for the disk cache key.
///
/// The hash covers the little-endian `count` followed by the first `count` entries, which are
/// copied device→host. Entries past `count` in the allocation are ignored. An empty list
/// never touches the device.
fn hash_random_vars(
    random_vars: &DeviceRandomVarList,
    provider: &Arc<CudaKernelProvider>,
) -> Result<u64> {
    let count = random_vars.count();
    if count == 0 {
        return Ok(fnv1a_u64(&count.to_le_bytes()));
    }

    let host: Vec<u32> = provider
        .device()
        .inner()
        .dtoh_sync_copy(random_vars.list())
        .map_err(|e| XlogError::Kernel(format!("dtoh random_vars for disk cache hash: {}", e)))?;
    let live = &host[..count as usize];

    let mut encoded = Vec::with_capacity(4 * (live.len() + 1));
    encoded.extend_from_slice(&count.to_le_bytes());
    encoded.extend(live.iter().flat_map(|var| var.to_le_bytes()));
    Ok(fnv1a_u64(&encoded))
}

/// Query the device compute capability and encode as `major * 10 + minor` (e.g. 89 for sm_89).
fn detect_compute_capability(provider: &Arc<CudaKernelProvider>) -> Result<u32> {
    use cudarc::driver::sys::CUdevice_attribute;

    let device = provider.device().inner();
    let major = device
        .attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
        .map_err(|e| {
            XlogError::Kernel(format!("Failed to query compute capability major: {}", e))
        })?;
    let minor = device
        .attribute(CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
        .map_err(|e| {
            XlogError::Kernel(format!("Failed to query compute capability minor: {}", e))
        })?;
    let major_u32: u32 = major.try_into().map_err(|_| {
        XlogError::Kernel(format!(
            "compute capability major {} cannot be converted to u32",
            major
        ))
    })?;
    let minor_u32: u32 = minor.try_into().map_err(|_| {
        XlogError::Kernel(format!(
            "compute capability minor {} cannot be converted to u32",
            minor
        ))
    })?;
    Ok(major_u32 * 10 + minor_u32)
}