Skip to main content

oxicuda_driver/
debug.rs

1//! Kernel debugging utilities for OxiCUDA.
2//!
3//! This module provides tools for debugging GPU kernels without traditional
4//! debuggers. It includes memory checking, NaN/Inf detection, printf
5//! emulation, assertion support, and PTX instrumentation for automated
6//! bounds/NaN checking.
7//!
8//! # Architecture
9//!
10//! The debugging system is layered:
11//!
12//! 1. **`KernelDebugger`** — Top-level manager that creates debug sessions and
13//!    manages breakpoints / watchpoints.
14//! 2. **`DebugSession`** — Collects [`DebugEvent`]s for a single kernel launch.
15//! 3. **`MemoryChecker`** — Validates memory accesses against known allocations.
16//! 4. **`NanInfChecker`** — Scans host-side buffers for NaN / Inf values.
17//! 5. **`PrintfBuffer`** — Parses a raw byte buffer that emulates GPU `printf`.
18//! 6. **`KernelAssertions`** — Convenience assertion helpers that produce
19//!    [`DebugEvent`]s instead of panicking.
20//! 7. **`DebugPtxInstrumenter`** — Instruments PTX source for automated
21//!    bounds/NaN checking and printf support.
22//!
23//! # Example
24//!
25//! ```rust
26//! use oxicuda_driver::debug::*;
27//!
28//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
29//! let config = KernelDebugConfig::default();
30//! let mut debugger = KernelDebugger::new(config);
31//! let session = debugger.attach("my_kernel")?;
32//! assert_eq!(session.kernel_name(), "my_kernel");
33//! # Ok(())
34//! # }
35//! ```
36
37use std::fmt;
38
39use crate::error::{CudaError, CudaResult};
40
41// ---------------------------------------------------------------------------
42// DebugLevel
43// ---------------------------------------------------------------------------
44
/// Verbosity level for kernel debugging output.
///
/// Variants are declared from least to most verbose, so the derived ordering
/// satisfies `Off < Error < Warn < Info < Debug < Trace`. The default level
/// is [`DebugLevel::Info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum DebugLevel {
    /// Debug output is disabled.
    Off,
    /// Errors only.
    Error,
    /// Warnings in addition to errors.
    Warn,
    /// Informational messages in addition to warnings and errors.
    #[default]
    Info,
    /// Verbose debugging output.
    Debug,
    /// Maximum verbosity — every detail is logged.
    Trace,
}

impl fmt::Display for DebugLevel {
    /// Writes the upper-case name of the level (for example `INFO`).
    ///
    /// Width and alignment flags of the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Off => "OFF",
            Self::Error => "ERROR",
            Self::Warn => "WARN",
            Self::Info => "INFO",
            Self::Debug => "DEBUG",
            Self::Trace => "TRACE",
        };
        f.write_str(label)
    }
}
75
76// ---------------------------------------------------------------------------
77// KernelDebugConfig
78// ---------------------------------------------------------------------------
79
80/// Configuration for a kernel debug session.
81#[derive(Debug, Clone)]
82pub struct KernelDebugConfig {
83    /// The verbosity level for debug output.
84    pub debug_level: DebugLevel,
85    /// Whether to instrument bounds checking on memory accesses.
86    pub enable_bounds_check: bool,
87    /// Whether to detect NaN values in floating-point registers.
88    pub enable_nan_check: bool,
89    /// Whether to detect Inf values in floating-point registers.
90    pub enable_inf_check: bool,
91    /// Whether to detect potential race conditions.
92    pub enable_race_detection: bool,
93    /// Size of the GPU-side printf buffer in bytes.
94    pub print_buffer_size: usize,
95    /// Maximum number of printf calls per thread before truncation.
96    pub max_print_per_thread: usize,
97}
98
99impl Default for KernelDebugConfig {
100    fn default() -> Self {
101        Self {
102            debug_level: DebugLevel::Info,
103            enable_bounds_check: true,
104            enable_nan_check: true,
105            enable_inf_check: true,
106            enable_race_detection: false,
107            print_buffer_size: 1024 * 1024, // 1 MiB
108            max_print_per_thread: 32,
109        }
110    }
111}
112
113// ---------------------------------------------------------------------------
114// DebugEventType
115// ---------------------------------------------------------------------------
116
/// The kind of debug event captured during kernel execution.
#[derive(Debug, Clone, PartialEq)]
pub enum DebugEventType {
    /// A memory access fell outside the allocated bounds.
    OutOfBounds {
        /// The faulting address.
        address: u64,
        /// Size of the attempted access, in bytes.
        size: usize,
    },
    /// A NaN was detected in a floating-point register.
    NanDetected {
        /// Register or variable name.
        register: String,
        /// The NaN bit-pattern reinterpreted as f64.
        value: f64,
    },
    /// An infinity was detected in a floating-point register.
    InfDetected {
        /// Register or variable name.
        register: String,
    },
    /// A potential race condition on a shared memory address.
    RaceCondition {
        /// The conflicting address.
        address: u64,
    },
    /// A kernel-side assertion.
    Assertion {
        /// The assertion condition expression.
        condition: String,
        /// Source file name.
        file: String,
        /// Source line number.
        line: u32,
    },
    /// A kernel-side printf invocation.
    Printf {
        /// The format string.
        format: String,
    },
    /// A breakpoint was hit.
    Breakpoint {
        /// The breakpoint identifier.
        id: u32,
    },
}

impl DebugEventType {
    /// Short category label for this event kind; it prefixes the
    /// [`DebugEvent`] display output.
    fn tag(&self) -> &'static str {
        match *self {
            Self::Breakpoint { .. } => "BP",
            Self::Printf { .. } => "PRINTF",
            Self::Assertion { .. } => "ASSERT",
            Self::RaceCondition { .. } => "RACE",
            Self::InfDetected { .. } => "Inf",
            Self::NanDetected { .. } => "NaN",
            Self::OutOfBounds { .. } => "OOB",
        }
    }

    /// Returns `true` when `self` and `other` are the same variant; the
    /// payload fields are not compared. Used by
    /// [`DebugSession::filter_events`].
    fn same_kind(&self, other: &Self) -> bool {
        let mine = std::mem::discriminant(self);
        let theirs = std::mem::discriminant(other);
        mine == theirs
    }
}
185
186// ---------------------------------------------------------------------------
187// DebugEvent
188// ---------------------------------------------------------------------------
189
190/// A single debug event captured during kernel execution.
191#[derive(Debug, Clone)]
192pub struct DebugEvent {
193    /// What kind of event occurred.
194    pub event_type: DebugEventType,
195    /// The CUDA thread index `(x, y, z)` that triggered the event.
196    pub thread_id: (u32, u32, u32),
197    /// The CUDA block index `(x, y, z)` that triggered the event.
198    pub block_id: (u32, u32, u32),
199    /// Timestamp in nanoseconds (monotonic, relative to session start).
200    pub timestamp_ns: u64,
201    /// Free-form human-readable message.
202    pub message: String,
203}
204
205impl fmt::Display for DebugEvent {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        write!(
208            f,
209            "[{tag}] block({bx},{by},{bz}) thread({tx},{ty},{tz}) @{ts}ns: {msg}",
210            tag = self.event_type.tag(),
211            bx = self.block_id.0,
212            by = self.block_id.1,
213            bz = self.block_id.2,
214            tx = self.thread_id.0,
215            ty = self.thread_id.1,
216            tz = self.thread_id.2,
217            ts = self.timestamp_ns,
218            msg = self.message,
219        )
220    }
221}
222
223// ---------------------------------------------------------------------------
224// WatchType
225// ---------------------------------------------------------------------------
226
/// The kind of memory access a watchpoint monitors.
///
/// Supplied to [`KernelDebugger::watchpoint`] when a watchpoint is
/// registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WatchType {
    /// Trigger on reads.
    Read,
    /// Trigger on writes.
    Write,
    /// Trigger on both reads and writes.
    ReadWrite,
}
237
238// ---------------------------------------------------------------------------
239// Breakpoint / Watchpoint helpers
240// ---------------------------------------------------------------------------
241
/// A breakpoint registered through [`KernelDebugger::set_breakpoint`].
// NOTE(review): `dead_code` is presumably allowed because `line` is stored
// but not read anywhere in the visible part of this module — confirm.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Breakpoint {
    /// Identifier returned to the caller; used to remove the breakpoint.
    id: u32,
    /// PTX source line the breakpoint was set on.
    line: u32,
}
248
/// A memory watchpoint registered through [`KernelDebugger::watchpoint`].
// NOTE(review): `dead_code` is presumably allowed because these fields are
// stored but not read anywhere in the visible part of this module — confirm.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Watchpoint {
    /// Identifier returned to the caller.
    id: u32,
    /// Start address of the watched range.
    address: u64,
    /// Length of the watched range (presumably in bytes — TODO confirm).
    size: usize,
    /// Which kinds of access trigger the watchpoint.
    watch_type: WatchType,
}
257
258// ---------------------------------------------------------------------------
259// KernelDebugger
260// ---------------------------------------------------------------------------
261
262/// Top-level kernel debugging manager.
263///
264/// Create one per debugging session group. It holds global breakpoints and
265/// watchpoints and spawns [`DebugSession`] instances per kernel launch.
266#[derive(Debug)]
267pub struct KernelDebugger {
268    config: KernelDebugConfig,
269    breakpoints: Vec<Breakpoint>,
270    watchpoints: Vec<Watchpoint>,
271    next_bp_id: u32,
272    next_wp_id: u32,
273}
274
275impl KernelDebugger {
276    /// Create a new kernel debugger with the given configuration.
277    pub fn new(config: KernelDebugConfig) -> Self {
278        Self {
279            config,
280            breakpoints: Vec::new(),
281            watchpoints: Vec::new(),
282            next_bp_id: 1,
283            next_wp_id: 1,
284        }
285    }
286
287    /// Attach the debugger to a kernel launch, returning a new debug session.
288    ///
289    /// On macOS this always succeeds with a synthetic (empty) session because
290    /// no actual GPU driver is available.
291    ///
292    /// # Errors
293    ///
294    /// Returns [`CudaError::InvalidValue`] if `kernel_name` is empty.
295    pub fn attach(&mut self, kernel_name: &str) -> CudaResult<DebugSession> {
296        if kernel_name.is_empty() {
297            return Err(CudaError::InvalidValue);
298        }
299        Ok(DebugSession {
300            kernel_name: kernel_name.to_owned(),
301            events: Vec::new(),
302            config: self.config.clone(),
303        })
304    }
305
306    /// Set a breakpoint at a PTX source line. Returns the breakpoint ID.
307    pub fn set_breakpoint(&mut self, line: u32) -> u32 {
308        let id = self.next_bp_id;
309        self.next_bp_id = self.next_bp_id.saturating_add(1);
310        self.breakpoints.push(Breakpoint { id, line });
311        id
312    }
313
314    /// Remove a breakpoint by ID. Returns `true` if it was found and removed.
315    pub fn remove_breakpoint(&mut self, bp_id: u32) -> bool {
316        let before = self.breakpoints.len();
317        self.breakpoints.retain(|bp| bp.id != bp_id);
318        self.breakpoints.len() < before
319    }
320
321    /// Set a memory watchpoint. Returns the watchpoint ID.
322    pub fn watchpoint(&mut self, address: u64, size: usize, watch_type: WatchType) -> u32 {
323        let id = self.next_wp_id;
324        self.next_wp_id = self.next_wp_id.saturating_add(1);
325        self.watchpoints.push(Watchpoint {
326            id,
327            address,
328            size,
329            watch_type,
330        });
331        id
332    }
333
334    /// Returns a reference to the current debug configuration.
335    pub fn config(&self) -> &KernelDebugConfig {
336        &self.config
337    }
338}
339
340// ---------------------------------------------------------------------------
341// DebugSummary
342// ---------------------------------------------------------------------------
343
/// Aggregate statistics for a debug session.
///
/// Produced by [`DebugSession::summary`]. Printf and breakpoint events are
/// included in `total_events` only; they count as neither errors nor
/// warnings.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DebugSummary {
    /// Total number of debug events.
    pub total_events: usize,
    /// Number of error-level events (OOB, assertions, races).
    pub errors: usize,
    /// Number of warning-level events (NaN, Inf).
    pub warnings: usize,
    /// Number of NaN detection events.
    pub nan_count: usize,
    /// Number of Inf detection events.
    pub inf_count: usize,
    /// Number of out-of-bounds events.
    pub oob_count: usize,
    /// Number of race condition events.
    pub race_count: usize,
}
362
363// ---------------------------------------------------------------------------
364// DebugSession
365// ---------------------------------------------------------------------------
366
367/// An active debug session for a single kernel launch.
368///
369/// Collects [`DebugEvent`]s and provides analysis / reporting helpers.
370#[derive(Debug)]
371pub struct DebugSession {
372    kernel_name: String,
373    events: Vec<DebugEvent>,
374    #[allow(dead_code)]
375    config: KernelDebugConfig,
376}
377
378impl DebugSession {
379    /// The kernel name this session is attached to.
380    pub fn kernel_name(&self) -> &str {
381        &self.kernel_name
382    }
383
384    /// All events collected so far.
385    pub fn events(&self) -> &[DebugEvent] {
386        &self.events
387    }
388
389    /// Record a new debug event.
390    pub fn add_event(&mut self, event: DebugEvent) {
391        self.events.push(event);
392    }
393
394    /// Return references to events whose type matches the discriminant of
395    /// `event_type` (field values inside variants are ignored for matching).
396    pub fn filter_events(&self, event_type: &DebugEventType) -> Vec<&DebugEvent> {
397        self.events
398            .iter()
399            .filter(|e| e.event_type.same_kind(event_type))
400            .collect()
401    }
402
403    /// Compute aggregate statistics over all collected events.
404    pub fn summary(&self) -> DebugSummary {
405        let mut s = DebugSummary {
406            total_events: self.events.len(),
407            ..DebugSummary::default()
408        };
409        for ev in &self.events {
410            match &ev.event_type {
411                DebugEventType::OutOfBounds { .. } => {
412                    s.oob_count += 1;
413                    s.errors += 1;
414                }
415                DebugEventType::NanDetected { .. } => {
416                    s.nan_count += 1;
417                    s.warnings += 1;
418                }
419                DebugEventType::InfDetected { .. } => {
420                    s.inf_count += 1;
421                    s.warnings += 1;
422                }
423                DebugEventType::RaceCondition { .. } => {
424                    s.race_count += 1;
425                    s.errors += 1;
426                }
427                DebugEventType::Assertion { .. } => {
428                    s.errors += 1;
429                }
430                DebugEventType::Printf { .. } | DebugEventType::Breakpoint { .. } => {}
431            }
432        }
433        s
434    }
435
436    /// Produce a human-readable debug report.
437    pub fn format_report(&self) -> String {
438        let summary = self.summary();
439        let mut out = String::with_capacity(512);
440        out.push_str(&format!("=== Debug Report: {} ===\n", self.kernel_name));
441        out.push_str(&format!(
442            "Total events: {}  (errors: {}, warnings: {})\n",
443            summary.total_events, summary.errors, summary.warnings
444        ));
445        if summary.oob_count > 0 {
446            out.push_str(&format!("  Out-of-bounds: {}\n", summary.oob_count));
447        }
448        if summary.nan_count > 0 {
449            out.push_str(&format!("  NaN detected:  {}\n", summary.nan_count));
450        }
451        if summary.inf_count > 0 {
452            out.push_str(&format!("  Inf detected:  {}\n", summary.inf_count));
453        }
454        if summary.race_count > 0 {
455            out.push_str(&format!("  Race cond.:    {}\n", summary.race_count));
456        }
457        out.push_str("--- Events ---\n");
458        for ev in &self.events {
459            out.push_str(&format!("  {ev}\n"));
460        }
461        out.push_str("=== End Report ===\n");
462        out
463    }
464}
465
466// ---------------------------------------------------------------------------
467// MemoryRegion / MemoryChecker
468// ---------------------------------------------------------------------------
469
470/// A contiguous GPU memory allocation known to the memory checker.
471#[derive(Debug, Clone)]
472pub struct MemoryRegion {
473    /// Base device address of the allocation.
474    pub base_address: u64,
475    /// Size in bytes.
476    pub size: usize,
477    /// Human-readable name for diagnostics.
478    pub name: String,
479    /// Whether the allocation is read-only.
480    pub is_readonly: bool,
481}
482
483/// Validates memory accesses against a set of known [`MemoryRegion`]s.
484#[derive(Debug)]
485pub struct MemoryChecker {
486    allocations: Vec<MemoryRegion>,
487}
488
489impl MemoryChecker {
490    /// Create a memory checker from a list of known allocations.
491    pub fn new(allocations: Vec<MemoryRegion>) -> Self {
492        Self { allocations }
493    }
494
495    /// Check whether a memory access is valid.
496    ///
497    /// Returns `Some(DebugEvent)` if the access is out of bounds or violates
498    /// read-only protections; `None` if the access is valid.
499    pub fn check_access(&self, address: u64, size: usize, is_write: bool) -> Option<DebugEvent> {
500        // Find the allocation that contains this address.
501        let region = self.allocations.iter().find(|r| {
502            address >= r.base_address
503                && address
504                    .checked_add(size as u64)
505                    .is_some_and(|end| end <= r.base_address + r.size as u64)
506        });
507
508        match region {
509            Some(r) if is_write && r.is_readonly => Some(DebugEvent {
510                event_type: DebugEventType::OutOfBounds { address, size },
511                thread_id: (0, 0, 0),
512                block_id: (0, 0, 0),
513                timestamp_ns: 0,
514                message: format!("Write to read-only region '{}' at {:#x}", r.name, address),
515            }),
516            Some(_) => None,
517            None => Some(DebugEvent {
518                event_type: DebugEventType::OutOfBounds { address, size },
519                thread_id: (0, 0, 0),
520                block_id: (0, 0, 0),
521                timestamp_ns: 0,
522                message: format!(
523                    "Access at {:#x} (size {}) does not fall within any known allocation",
524                    address, size
525                ),
526            }),
527        }
528    }
529}
530
531// ---------------------------------------------------------------------------
532// NanInfChecker / NanInfLocation
533// ---------------------------------------------------------------------------
534
/// Location of a NaN or Inf value found in a buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct NanInfLocation {
    /// Index into the buffer.
    pub index: usize,
    /// The offending value, widened to f64 for uniform reporting.
    pub value: f64,
    /// `true` if NaN, `false` if Inf.
    pub is_nan: bool,
}

/// Scans host-side buffers for NaN and Inf values.
#[derive(Debug, Clone, Copy)]
pub struct NanInfChecker;

impl NanInfChecker {
    /// Check an `f32` buffer for NaN and Inf values.
    ///
    /// Returns one [`NanInfLocation`] per non-finite element, in buffer
    /// order. Values are widened to `f64` in the result.
    pub fn check_f32(data: &[f32]) -> Vec<NanInfLocation> {
        let mut found = Vec::new();
        for (index, &sample) in data.iter().enumerate() {
            // Non-finite means exactly "NaN or ±Inf".
            if sample.is_finite() {
                continue;
            }
            found.push(NanInfLocation {
                index,
                value: f64::from(sample),
                is_nan: sample.is_nan(),
            });
        }
        found
    }

    /// Check an `f64` buffer for NaN and Inf values.
    ///
    /// Returns one [`NanInfLocation`] per non-finite element, in buffer
    /// order.
    pub fn check_f64(data: &[f64]) -> Vec<NanInfLocation> {
        let mut found = Vec::new();
        for (index, &sample) in data.iter().enumerate() {
            // Non-finite means exactly "NaN or ±Inf".
            if sample.is_finite() {
                continue;
            }
            found.push(NanInfLocation {
                index,
                value: sample,
                is_nan: sample.is_nan(),
            });
        }
        found
    }
}
599
600// ---------------------------------------------------------------------------
601// PrintfBuffer / PrintfEntry / PrintfArg
602// ---------------------------------------------------------------------------
603
/// A single argument captured from a GPU printf call.
#[derive(Debug, Clone, PartialEq)]
pub enum PrintfArg {
    /// Integer argument.
    Int(i64),
    /// Floating-point argument.
    Float(f64),
    /// String argument.
    String(String),
}

/// A parsed GPU-side printf entry.
#[derive(Debug, Clone)]
pub struct PrintfEntry {
    /// Thread index that issued the printf.
    pub thread_id: (u32, u32, u32),
    /// Block index that issued the printf.
    pub block_id: (u32, u32, u32),
    /// The format string.
    pub format_string: String,
    /// Parsed arguments.
    pub args: Vec<PrintfArg>,
}

/// GPU-side printf emulation buffer.
///
/// The buffer layout is a simple framed protocol:
///
/// ```text
/// [entry_count: u32_le]
/// repeated entry_count times:
///   [thread_x: u32_le] [thread_y: u32_le] [thread_z: u32_le]
///   [block_x:  u32_le] [block_y:  u32_le] [block_z:  u32_le]
///   [fmt_len:  u32_le] [fmt_bytes: u8 * fmt_len]
///   [arg_count: u32_le]
///   repeated arg_count times:
///     [tag: u8]  0=Int(i64_le), 1=Float(f64_le), 2=String(u32_le len + bytes)
/// ```
#[derive(Debug)]
pub struct PrintfBuffer {
    /// Configured maximum buffer size in bytes; not consulted while parsing.
    buffer_size: usize,
}

impl PrintfBuffer {
    /// Smallest possible encoded argument: a one-byte tag followed by the
    /// `u32` length of an empty string.
    const MIN_ENCODED_ARG_LEN: usize = 5;

    /// Create a printf buffer descriptor with the given maximum size.
    pub fn new(buffer_size: usize) -> Self {
        Self { buffer_size }
    }

    /// Returns the configured buffer size.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Parse a raw byte buffer into structured printf entries.
    ///
    /// Parsing never panics on malformed input. It stops at the first entry
    /// that is truncated or contains an unknown argument tag and returns the
    /// entries decoded before that point, so the result is empty when the
    /// buffer is too short to hold the entry count or when the first entry is
    /// already malformed. Invalid UTF-8 in strings is replaced lossily.
    pub fn parse_entries(&self, raw: &[u8]) -> Vec<PrintfEntry> {
        let mut cursor = 0usize;
        let Some(entry_count) = Self::read_u32(raw, &mut cursor) else {
            return Vec::new();
        };

        // No pre-allocation: `entry_count` is untrusted and may be far
        // larger than what the buffer actually holds.
        let mut entries = Vec::new();
        for _ in 0..entry_count {
            let Some(entry) = self.parse_single_entry(raw, &mut cursor) else {
                break;
            };
            entries.push(entry);
        }
        entries
    }

    /// Decode one entry starting at `cursor`, advancing `cursor` past it.
    ///
    /// Returns `None` if the entry is truncated or uses an unknown argument
    /// tag; `cursor` may then point into the middle of the broken entry.
    fn parse_single_entry(&self, raw: &[u8], cursor: &mut usize) -> Option<PrintfEntry> {
        let tx = Self::read_u32(raw, cursor)?;
        let ty = Self::read_u32(raw, cursor)?;
        let tz = Self::read_u32(raw, cursor)?;
        let bx = Self::read_u32(raw, cursor)?;
        let by = Self::read_u32(raw, cursor)?;
        let bz = Self::read_u32(raw, cursor)?;

        let fmt_len = Self::read_u32(raw, cursor)? as usize;
        let fmt_bytes = Self::read_bytes(raw, cursor, fmt_len)?;
        let format_string = String::from_utf8_lossy(fmt_bytes).into_owned();

        let arg_count = Self::read_u32(raw, cursor)? as usize;
        // `arg_count` comes straight from the buffer. Cap the pre-allocation
        // by the number of arguments the remaining bytes could possibly
        // encode, so a corrupt count cannot request gigabytes of memory.
        let remaining = raw.len().saturating_sub(*cursor);
        let capacity = arg_count.min(remaining / Self::MIN_ENCODED_ARG_LEN);
        let mut args = Vec::with_capacity(capacity);

        for _ in 0..arg_count {
            let arg = match Self::read_u8(raw, cursor)? {
                0 => PrintfArg::Int(Self::read_i64(raw, cursor)?),
                1 => PrintfArg::Float(Self::read_f64(raw, cursor)?),
                2 => {
                    let slen = Self::read_u32(raw, cursor)? as usize;
                    let sbytes = Self::read_bytes(raw, cursor, slen)?;
                    PrintfArg::String(String::from_utf8_lossy(sbytes).into_owned())
                }
                // Unknown tag: the payload length is unknowable, so the
                // entry cannot be decoded any further.
                _ => return None,
            };
            args.push(arg);
        }

        Some(PrintfEntry {
            thread_id: (tx, ty, tz),
            block_id: (bx, by, bz),
            format_string,
            args,
        })
    }

    // --- Low-level readers ---
    //
    // Each reader returns `None` without moving `cursor` when fewer bytes
    // remain than it needs.

    /// Read `len` bytes at `cursor` and advance past them.
    fn read_bytes<'a>(raw: &'a [u8], cursor: &mut usize, len: usize) -> Option<&'a [u8]> {
        // `checked_add` guards against wrap-around on targets where `usize`
        // is 32 bits and `len` is attacker-controlled.
        let end = cursor.checked_add(len)?;
        let slice = raw.get(*cursor..end)?;
        *cursor = end;
        Some(slice)
    }

    /// Read exactly `N` bytes at `cursor` into a fixed-size array.
    fn read_array<const N: usize>(raw: &[u8], cursor: &mut usize) -> Option<[u8; N]> {
        let bytes = Self::read_bytes(raw, cursor, N)?;
        let mut out = [0u8; N];
        out.copy_from_slice(bytes);
        Some(out)
    }

    /// Read a single byte.
    fn read_u8(raw: &[u8], cursor: &mut usize) -> Option<u8> {
        let byte = *raw.get(*cursor)?;
        *cursor += 1;
        Some(byte)
    }

    /// Read a little-endian `u32`.
    fn read_u32(raw: &[u8], cursor: &mut usize) -> Option<u32> {
        Self::read_array::<4>(raw, cursor).map(u32::from_le_bytes)
    }

    /// Read a little-endian `i64`.
    fn read_i64(raw: &[u8], cursor: &mut usize) -> Option<i64> {
        Self::read_array::<8>(raw, cursor).map(i64::from_le_bytes)
    }

    /// Read a little-endian `f64`.
    fn read_f64(raw: &[u8], cursor: &mut usize) -> Option<f64> {
        Self::read_array::<8>(raw, cursor).map(f64::from_le_bytes)
    }
}
770
771// ---------------------------------------------------------------------------
772// KernelAssertions
773// ---------------------------------------------------------------------------
774
775/// Convenience assertion helpers that produce [`DebugEvent`]s instead of
776/// panicking, suitable for GPU kernel emulation / validation.
777#[derive(Debug, Clone, Copy)]
778pub struct KernelAssertions;
779
780impl KernelAssertions {
781    /// Assert that `index < len`. Returns an event if the assertion fails.
782    pub fn assert_bounds(index: usize, len: usize, name: &str) -> Option<DebugEvent> {
783        if index < len {
784            return None;
785        }
786        Some(DebugEvent {
787            event_type: DebugEventType::Assertion {
788                condition: format!("{name}[{index}] < {len}"),
789                file: String::new(),
790                line: 0,
791            },
792            thread_id: (0, 0, 0),
793            block_id: (0, 0, 0),
794            timestamp_ns: 0,
795            message: format!("Bounds check failed: {name}[{index}] out of range (len={len})"),
796        })
797    }
798
799    /// Assert that `value` is not NaN. Returns an event if it is.
800    pub fn assert_not_nan(value: f64, name: &str) -> Option<DebugEvent> {
801        if !value.is_nan() {
802            return None;
803        }
804        Some(DebugEvent {
805            event_type: DebugEventType::NanDetected {
806                register: name.to_owned(),
807                value,
808            },
809            thread_id: (0, 0, 0),
810            block_id: (0, 0, 0),
811            timestamp_ns: 0,
812            message: format!("NaN detected in '{name}'"),
813        })
814    }
815
816    /// Assert that `value` is finite (not NaN and not Inf). Returns an event
817    /// if it is not.
818    pub fn assert_finite(value: f64, name: &str) -> Option<DebugEvent> {
819        if value.is_finite() {
820            return None;
821        }
822        if value.is_nan() {
823            return Some(DebugEvent {
824                event_type: DebugEventType::NanDetected {
825                    register: name.to_owned(),
826                    value,
827                },
828                thread_id: (0, 0, 0),
829                block_id: (0, 0, 0),
830                timestamp_ns: 0,
831                message: format!("Non-finite (NaN) value in '{name}'"),
832            });
833        }
834        Some(DebugEvent {
835            event_type: DebugEventType::InfDetected {
836                register: name.to_owned(),
837            },
838            thread_id: (0, 0, 0),
839            block_id: (0, 0, 0),
840            timestamp_ns: 0,
841            message: format!("Non-finite (Inf) value in '{name}'"),
842        })
843    }
844
845    /// Assert that `value` is strictly positive. Returns an event if it is
846    /// not (including NaN, zero, and negative values).
847    pub fn assert_positive(value: f64, name: &str) -> Option<DebugEvent> {
848        if value > 0.0 {
849            return None;
850        }
851        if value.is_nan() {
852            return Some(DebugEvent {
853                event_type: DebugEventType::NanDetected {
854                    register: name.to_owned(),
855                    value,
856                },
857                thread_id: (0, 0, 0),
858                block_id: (0, 0, 0),
859                timestamp_ns: 0,
860                message: format!("Expected positive value for '{name}', got NaN"),
861            });
862        }
863        Some(DebugEvent {
864            event_type: DebugEventType::Assertion {
865                condition: format!("{name} > 0"),
866                file: String::new(),
867                line: 0,
868            },
869            thread_id: (0, 0, 0),
870            block_id: (0, 0, 0),
871            timestamp_ns: 0,
872            message: format!("Expected positive value for '{name}', got {value}"),
873        })
874    }
875}
876
877// ---------------------------------------------------------------------------
878// DebugPtxInstrumenter
879// ---------------------------------------------------------------------------
880
/// Instruments PTX source code with debugging checks.
///
/// This instrumenter inserts additional PTX instructions for bounds checking,
/// NaN detection, and printf buffer support. The instrumented code writes
/// diagnostic data to a designated debug buffer that the host can read back
/// after kernel execution.
///
/// The three flags are derived from a [`KernelDebugConfig`] in
/// [`DebugPtxInstrumenter::new`].
#[derive(Debug)]
pub struct DebugPtxInstrumenter {
    /// Copied from `KernelDebugConfig::enable_bounds_check`.
    enable_bounds_check: bool,
    /// Copied from `KernelDebugConfig::enable_nan_check`.
    enable_nan_check: bool,
    /// `true` when the configured printf buffer size is non-zero.
    enable_printf: bool,
}
893
894impl DebugPtxInstrumenter {
895    /// Create an instrumenter from a debug configuration.
896    pub fn new(config: &KernelDebugConfig) -> Self {
897        Self {
898            enable_bounds_check: config.enable_bounds_check,
899            enable_nan_check: config.enable_nan_check,
900            enable_printf: config.print_buffer_size > 0,
901        }
902    }
903
904    /// Insert bounds-checking instrumentation into PTX source.
905    ///
906    /// Adds `setp` + `trap` sequences after every `ld.global` / `st.global`
907    /// instruction to validate the address against the allocation size
908    /// parameter.
909    pub fn instrument_bounds_checks(&self, ptx: &str) -> String {
910        if !self.enable_bounds_check {
911            return ptx.to_owned();
912        }
913
914        let mut output = String::with_capacity(ptx.len() + ptx.len() / 4);
915        // Add debug parameter declaration at the top of each kernel entry.
916        let mut added_param = false;
917
918        for line in ptx.lines() {
919            let trimmed = line.trim();
920            // Insert debug buffer param after .entry
921            if trimmed.starts_with(".entry") && !added_param {
922                output.push_str(line);
923                output.push('\n');
924                output.push_str("    // [oxicuda-debug] bounds-check instrumentation\n");
925                output.push_str("    .param .u64 __oxicuda_debug_buf;\n");
926                added_param = true;
927                continue;
928            }
929
930            // Instrument global loads/stores
931            if (trimmed.starts_with("ld.global") || trimmed.starts_with("st.global"))
932                && !trimmed.starts_with("// [oxicuda-debug]")
933            {
934                output.push_str(line);
935                output.push('\n');
936                output.push_str("    // [oxicuda-debug] bounds check for above access\n");
937                output.push_str("    setp.ge.u64 %p_oob, %rd_addr, %rd_alloc_end;\n");
938                output.push_str("    @%p_oob trap;\n");
939            } else {
940                output.push_str(line);
941                output.push('\n');
942            }
943        }
944
945        output
946    }
947
948    /// Insert NaN-detection instrumentation into PTX source.
949    ///
950    /// After every floating-point arithmetic instruction (`add.f32`,
951    /// `mul.f64`, etc.) a `testp.nan` check is inserted.
952    pub fn instrument_nan_checks(&self, ptx: &str) -> String {
953        if !self.enable_nan_check {
954            return ptx.to_owned();
955        }
956
957        let fp_ops = [
958            "add.f32", "add.f64", "sub.f32", "sub.f64", "mul.f32", "mul.f64", "div.f32", "div.f64",
959            "fma.f32", "fma.f64",
960        ];
961
962        let mut output = String::with_capacity(ptx.len() + ptx.len() / 4);
963
964        for line in ptx.lines() {
965            output.push_str(line);
966            output.push('\n');
967
968            let trimmed = line.trim();
969            if fp_ops.iter().any(|op| trimmed.starts_with(op)) {
970                // Extract the destination register (first token after the op).
971                if let Some(dest) = trimmed.split_whitespace().nth(1) {
972                    let dest_clean = dest.trim_end_matches(',');
973                    let width = if trimmed.contains(".f64") {
974                        "f64"
975                    } else {
976                        "f32"
977                    };
978                    output.push_str(&format!(
979                        "    // [oxicuda-debug] NaN check for {dest_clean}\n"
980                    ));
981                    output.push_str(&format!("    testp.nan.{width} %p_nan, {dest_clean};\n"));
982                    output.push_str("    @%p_nan trap;\n");
983                }
984            }
985        }
986
987        output
988    }
989
990    /// Insert printf buffer support into PTX source.
991    ///
992    /// Adds a `.param .u64 __oxicuda_printf_buf` to each entry and inserts
993    /// stub store sequences where `// PRINTF` markers appear.
994    pub fn instrument_printf(&self, ptx: &str) -> String {
995        if !self.enable_printf {
996            return ptx.to_owned();
997        }
998
999        let mut output = String::with_capacity(ptx.len() + ptx.len() / 4);
1000        let mut added_param = false;
1001
1002        for line in ptx.lines() {
1003            let trimmed = line.trim();
1004            if trimmed.starts_with(".entry") && !added_param {
1005                output.push_str(line);
1006                output.push('\n');
1007                output.push_str("    // [oxicuda-debug] printf buffer parameter\n");
1008                output.push_str("    .param .u64 __oxicuda_printf_buf;\n");
1009                added_param = true;
1010                continue;
1011            }
1012
1013            if trimmed.starts_with("// PRINTF") {
1014                output.push_str("    // [oxicuda-debug] printf store sequence\n");
1015                output.push_str("    ld.param.u64 %rd_pbuf, [__oxicuda_printf_buf];\n");
1016                output.push_str("    atom.global.add.u32 %r_poff, [%rd_pbuf], 1;\n");
1017            } else {
1018                output.push_str(line);
1019                output.push('\n');
1020            }
1021        }
1022
1023        output
1024    }
1025
1026    /// Remove all OxiCUDA debug instrumentation from PTX source.
1027    pub fn strip_debug(&self, ptx: &str) -> String {
1028        let mut output = String::with_capacity(ptx.len());
1029        let mut skip_next = false;
1030
1031        for line in ptx.lines() {
1032            if skip_next {
1033                skip_next = false;
1034                continue;
1035            }
1036
1037            let trimmed = line.trim();
1038
1039            // Skip debug comment lines and the instruction immediately after.
1040            if trimmed.starts_with("// [oxicuda-debug]") {
1041                // Also skip the next instrumentation line.
1042                skip_next = true;
1043                continue;
1044            }
1045
1046            // Skip debug parameter declarations.
1047            if trimmed.contains("__oxicuda_debug_buf") || trimmed.contains("__oxicuda_printf_buf") {
1048                continue;
1049            }
1050
1051            output.push_str(line);
1052            output.push('\n');
1053        }
1054
1055        output
1056    }
1057}
1058
1059// ===========================================================================
1060// Tests
1061// ===========================================================================
1062
1063#[cfg(test)]
1064mod tests {
1065    use super::*;
1066
1067    // -- Config defaults --
1068
1069    #[test]
1070    fn config_default_values() {
1071        let cfg = KernelDebugConfig::default();
1072        assert_eq!(cfg.debug_level, DebugLevel::Info);
1073        assert!(cfg.enable_bounds_check);
1074        assert!(cfg.enable_nan_check);
1075        assert!(cfg.enable_inf_check);
1076        assert!(!cfg.enable_race_detection);
1077        assert_eq!(cfg.print_buffer_size, 1024 * 1024);
1078        assert_eq!(cfg.max_print_per_thread, 32);
1079    }
1080
1081    // -- KernelDebugger creation --
1082
1083    #[test]
1084    fn debugger_creation_with_config() {
1085        let cfg = KernelDebugConfig {
1086            debug_level: DebugLevel::Trace,
1087            enable_bounds_check: false,
1088            ..KernelDebugConfig::default()
1089        };
1090        let debugger = KernelDebugger::new(cfg);
1091        assert_eq!(debugger.config().debug_level, DebugLevel::Trace);
1092        assert!(!debugger.config().enable_bounds_check);
1093    }
1094
1095    // -- Session lifecycle --
1096
1097    #[test]
1098    fn debug_session_lifecycle() {
1099        let cfg = KernelDebugConfig::default();
1100        let mut debugger = KernelDebugger::new(cfg);
1101        let session = debugger.attach("test_kernel");
1102        assert!(session.is_ok());
1103        let session = session.expect("session");
1104        assert_eq!(session.kernel_name(), "test_kernel");
1105        assert!(session.events().is_empty());
1106
1107        // Attaching with empty name is an error.
1108        let err = debugger.attach("");
1109        assert!(err.is_err());
1110    }
1111
1112    // -- Breakpoints --
1113
1114    #[test]
1115    fn breakpoint_set_and_remove() {
1116        let mut debugger = KernelDebugger::new(KernelDebugConfig::default());
1117        let bp1 = debugger.set_breakpoint(42);
1118        let bp2 = debugger.set_breakpoint(100);
1119        assert_ne!(bp1, bp2);
1120
1121        assert!(debugger.remove_breakpoint(bp1));
1122        // Removing again should return false.
1123        assert!(!debugger.remove_breakpoint(bp1));
1124        // bp2 should still be present.
1125        assert!(debugger.remove_breakpoint(bp2));
1126    }
1127
1128    // -- Memory checker: valid access --
1129
1130    #[test]
1131    fn memory_checker_valid_access() {
1132        let checker = MemoryChecker::new(vec![MemoryRegion {
1133            base_address: 0x1000,
1134            size: 256,
1135            name: "buf_a".into(),
1136            is_readonly: false,
1137        }]);
1138        // Read within bounds.
1139        assert!(checker.check_access(0x1000, 16, false).is_none());
1140        // Write within bounds.
1141        assert!(checker.check_access(0x1080, 32, true).is_none());
1142    }
1143
1144    // -- Memory checker: OOB detection --
1145
1146    #[test]
1147    fn memory_checker_out_of_bounds() {
1148        let checker = MemoryChecker::new(vec![MemoryRegion {
1149            base_address: 0x1000,
1150            size: 256,
1151            name: "buf_a".into(),
1152            is_readonly: false,
1153        }]);
1154        // Access past the end.
1155        let ev = checker.check_access(0x1100, 16, false);
1156        assert!(ev.is_some());
1157        let ev = ev.expect("oob event");
1158        assert!(matches!(ev.event_type, DebugEventType::OutOfBounds { .. }));
1159
1160        // Completely outside.
1161        let ev2 = checker.check_access(0x5000, 4, true);
1162        assert!(ev2.is_some());
1163    }
1164
1165    // -- NaN detection in f32 --
1166
1167    #[test]
1168    fn nan_detection_f32() {
1169        let data = [1.0_f32, f32::NAN, 3.0, f32::NAN];
1170        let locs = NanInfChecker::check_f32(&data);
1171        assert_eq!(locs.len(), 2);
1172        assert_eq!(locs[0].index, 1);
1173        assert!(locs[0].is_nan);
1174        assert_eq!(locs[1].index, 3);
1175    }
1176
1177    // -- Inf detection in f64 --
1178
1179    #[test]
1180    fn inf_detection_f64() {
1181        let data = [1.0_f64, f64::INFINITY, f64::NEG_INFINITY, 4.0];
1182        let locs = NanInfChecker::check_f64(&data);
1183        assert_eq!(locs.len(), 2);
1184        assert!(!locs[0].is_nan);
1185        assert_eq!(locs[0].index, 1);
1186        assert!(!locs[1].is_nan);
1187        assert_eq!(locs[1].index, 2);
1188    }
1189
1190    // -- Printf buffer parsing --
1191
1192    #[test]
1193    fn printf_buffer_parsing() {
1194        let buf = PrintfBuffer::new(4096);
1195
1196        // Build a raw buffer with one entry containing one Int arg.
1197        let mut raw = Vec::new();
1198        // entry_count = 1
1199        raw.extend_from_slice(&1_u32.to_le_bytes());
1200        // thread_id (1,0,0)
1201        raw.extend_from_slice(&1_u32.to_le_bytes());
1202        raw.extend_from_slice(&0_u32.to_le_bytes());
1203        raw.extend_from_slice(&0_u32.to_le_bytes());
1204        // block_id (0,0,0)
1205        raw.extend_from_slice(&0_u32.to_le_bytes());
1206        raw.extend_from_slice(&0_u32.to_le_bytes());
1207        raw.extend_from_slice(&0_u32.to_le_bytes());
1208        // format string "val=%d"
1209        let fmt = b"val=%d";
1210        raw.extend_from_slice(&(fmt.len() as u32).to_le_bytes());
1211        raw.extend_from_slice(fmt);
1212        // arg_count = 1
1213        raw.extend_from_slice(&1_u32.to_le_bytes());
1214        // tag=0 (Int), value=42
1215        raw.push(0);
1216        raw.extend_from_slice(&42_i64.to_le_bytes());
1217
1218        let entries = buf.parse_entries(&raw);
1219        assert_eq!(entries.len(), 1);
1220        assert_eq!(entries[0].thread_id, (1, 0, 0));
1221        assert_eq!(entries[0].format_string, "val=%d");
1222        assert_eq!(entries[0].args.len(), 1);
1223        assert_eq!(entries[0].args[0], PrintfArg::Int(42));
1224    }
1225
1226    // -- Assertions --
1227
1228    #[test]
1229    fn assertion_checks() {
1230        // bounds: in range => None
1231        assert!(KernelAssertions::assert_bounds(5, 10, "arr").is_none());
1232        // bounds: out of range => Some
1233        let ev = KernelAssertions::assert_bounds(10, 10, "arr");
1234        assert!(ev.is_some());
1235
1236        // NaN
1237        assert!(KernelAssertions::assert_not_nan(1.0, "x").is_none());
1238        assert!(KernelAssertions::assert_not_nan(f64::NAN, "x").is_some());
1239
1240        // finite
1241        assert!(KernelAssertions::assert_finite(1.0, "x").is_none());
1242        assert!(KernelAssertions::assert_finite(f64::INFINITY, "x").is_some());
1243        assert!(KernelAssertions::assert_finite(f64::NAN, "x").is_some());
1244
1245        // positive
1246        assert!(KernelAssertions::assert_positive(1.0, "x").is_none());
1247        assert!(KernelAssertions::assert_positive(0.0, "x").is_some());
1248        assert!(KernelAssertions::assert_positive(-1.0, "x").is_some());
1249        assert!(KernelAssertions::assert_positive(f64::NAN, "x").is_some());
1250    }
1251
1252    // -- Event filtering --
1253
1254    #[test]
1255    fn debug_event_filtering() {
1256        let cfg = KernelDebugConfig::default();
1257        let mut debugger = KernelDebugger::new(cfg);
1258        let mut session = debugger.attach("filter_test").expect("session");
1259
1260        session.add_event(DebugEvent {
1261            event_type: DebugEventType::NanDetected {
1262                register: "f0".into(),
1263                value: f64::NAN,
1264            },
1265            thread_id: (0, 0, 0),
1266            block_id: (0, 0, 0),
1267            timestamp_ns: 100,
1268            message: "nan".into(),
1269        });
1270        session.add_event(DebugEvent {
1271            event_type: DebugEventType::OutOfBounds {
1272                address: 0xDEAD,
1273                size: 4,
1274            },
1275            thread_id: (1, 0, 0),
1276            block_id: (0, 0, 0),
1277            timestamp_ns: 200,
1278            message: "oob".into(),
1279        });
1280        session.add_event(DebugEvent {
1281            event_type: DebugEventType::NanDetected {
1282                register: "f1".into(),
1283                value: f64::NAN,
1284            },
1285            thread_id: (2, 0, 0),
1286            block_id: (0, 0, 0),
1287            timestamp_ns: 300,
1288            message: "nan2".into(),
1289        });
1290
1291        let nans = session.filter_events(&DebugEventType::NanDetected {
1292            register: String::new(),
1293            value: 0.0,
1294        });
1295        assert_eq!(nans.len(), 2);
1296
1297        let oobs = session.filter_events(&DebugEventType::OutOfBounds {
1298            address: 0,
1299            size: 0,
1300        });
1301        assert_eq!(oobs.len(), 1);
1302    }
1303
1304    // -- Summary statistics --
1305
1306    #[test]
1307    fn summary_statistics() {
1308        let cfg = KernelDebugConfig::default();
1309        let mut debugger = KernelDebugger::new(cfg);
1310        let mut session = debugger.attach("summary_test").expect("session");
1311
1312        session.add_event(DebugEvent {
1313            event_type: DebugEventType::NanDetected {
1314                register: "f0".into(),
1315                value: f64::NAN,
1316            },
1317            thread_id: (0, 0, 0),
1318            block_id: (0, 0, 0),
1319            timestamp_ns: 0,
1320            message: String::new(),
1321        });
1322        session.add_event(DebugEvent {
1323            event_type: DebugEventType::InfDetected {
1324                register: "f1".into(),
1325            },
1326            thread_id: (0, 0, 0),
1327            block_id: (0, 0, 0),
1328            timestamp_ns: 0,
1329            message: String::new(),
1330        });
1331        session.add_event(DebugEvent {
1332            event_type: DebugEventType::OutOfBounds {
1333                address: 0x100,
1334                size: 4,
1335            },
1336            thread_id: (0, 0, 0),
1337            block_id: (0, 0, 0),
1338            timestamp_ns: 0,
1339            message: String::new(),
1340        });
1341        session.add_event(DebugEvent {
1342            event_type: DebugEventType::RaceCondition { address: 0x200 },
1343            thread_id: (0, 0, 0),
1344            block_id: (0, 0, 0),
1345            timestamp_ns: 0,
1346            message: String::new(),
1347        });
1348
1349        let s = session.summary();
1350        assert_eq!(s.total_events, 4);
1351        assert_eq!(s.errors, 2); // OOB + race
1352        assert_eq!(s.warnings, 2); // NaN + Inf
1353        assert_eq!(s.nan_count, 1);
1354        assert_eq!(s.inf_count, 1);
1355        assert_eq!(s.oob_count, 1);
1356        assert_eq!(s.race_count, 1);
1357    }
1358
1359    // -- Format report --
1360
1361    #[test]
1362    fn format_report_output() {
1363        let cfg = KernelDebugConfig::default();
1364        let mut debugger = KernelDebugger::new(cfg);
1365        let mut session = debugger.attach("report_test").expect("session");
1366
1367        session.add_event(DebugEvent {
1368            event_type: DebugEventType::NanDetected {
1369                register: "f0".into(),
1370                value: f64::NAN,
1371            },
1372            thread_id: (0, 0, 0),
1373            block_id: (0, 0, 0),
1374            timestamp_ns: 42,
1375            message: "NaN found".into(),
1376        });
1377
1378        let report = session.format_report();
1379        assert!(report.contains("report_test"));
1380        assert!(report.contains("Total events: 1"));
1381        assert!(report.contains("NaN detected:  1"));
1382        assert!(report.contains("NaN found"));
1383        assert!(report.contains("=== End Report ==="));
1384    }
1385
1386    // -- PTX instrumentation: bounds checks --
1387
1388    #[test]
1389    fn ptx_instrumentation_bounds_checks() {
1390        let cfg = KernelDebugConfig::default();
1391        let inst = DebugPtxInstrumenter::new(&cfg);
1392
1393        let ptx = ".entry my_kernel {\n    ld.global.f32 %f0, [%rd0];\n    ret;\n}\n";
1394        let result = inst.instrument_bounds_checks(ptx);
1395
1396        assert!(result.contains("__oxicuda_debug_buf"));
1397        assert!(result.contains("setp.ge.u64"));
1398        assert!(result.contains("@%p_oob trap"));
1399    }
1400
1401    // -- PTX strip debug --
1402
1403    #[test]
1404    fn ptx_strip_debug_roundtrip() {
1405        let cfg = KernelDebugConfig::default();
1406        let inst = DebugPtxInstrumenter::new(&cfg);
1407
1408        let original = ".entry kern {\n    add.f32 %f0, %f1, %f2;\n    ret;\n}\n";
1409        let instrumented = inst.instrument_nan_checks(original);
1410        assert!(instrumented.contains("[oxicuda-debug]"));
1411
1412        let stripped = inst.strip_debug(&instrumented);
1413        // After stripping, no debug markers should remain.
1414        assert!(!stripped.contains("[oxicuda-debug]"));
1415        // Original instructions should still be present.
1416        assert!(stripped.contains("add.f32"));
1417        assert!(stripped.contains("ret;"));
1418    }
1419}