Skip to main content

oxicuda_driver/
debug.rs

1//! Kernel debugging utilities for OxiCUDA.
2//!
3//! This module provides tools for debugging GPU kernels without traditional
4//! debuggers. It includes memory checking, NaN/Inf detection, printf
5//! emulation, assertion support, and PTX instrumentation for automated
6//! bounds/NaN checking.
7//!
8//! # Architecture
9//!
10//! The debugging system is layered:
11//!
12//! 1. **`KernelDebugger`** — Top-level manager that creates debug sessions and
13//!    manages breakpoints / watchpoints.
14//! 2. **`DebugSession`** — Collects [`DebugEvent`]s for a single kernel launch.
15//! 3. **`MemoryChecker`** — Validates memory accesses against known allocations.
16//! 4. **`NanInfChecker`** — Scans host-side buffers for NaN / Inf values.
17//! 5. **`PrintfBuffer`** — Parses a raw byte buffer that emulates GPU `printf`.
18//! 6. **`KernelAssertions`** — Convenience assertion helpers that produce
19//!    [`DebugEvent`]s instead of panicking.
20//! 7. **`DebugPtxInstrumenter`** — Instruments PTX source for automated
21//!    bounds/NaN checking and printf support.
22//!
23//! # Example
24//!
25//! ```rust
26//! use oxicuda_driver::debug::*;
27//!
28//! let config = KernelDebugConfig::default();
29//! let mut debugger = KernelDebugger::new(config);
30//! let session = debugger.attach("my_kernel").unwrap();
31//! assert_eq!(session.kernel_name(), "my_kernel");
32//! ```
33
34use std::fmt;
35
36use crate::error::{CudaError, CudaResult};
37
38// ---------------------------------------------------------------------------
39// DebugLevel
40// ---------------------------------------------------------------------------
41
/// How much diagnostic output a kernel debug session produces.
///
/// Variants are declared from quietest to noisiest, so the derived `Ord`
/// allows comparisons such as `level >= DebugLevel::Warn`. The default level
/// is [`DebugLevel::Info`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DebugLevel {
    /// Suppress all debug output.
    Off,
    /// Report errors only.
    Error,
    /// Report errors and warnings.
    Warn,
    /// Report errors, warnings, and informational messages.
    #[default]
    Info,
    /// Report verbose debugging detail.
    Debug,
    /// Report everything (maximum verbosity).
    Trace,
}

impl fmt::Display for DebugLevel {
    // Writes the upper-case level name (e.g. `WARN`); width/alignment flags
    // on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            DebugLevel::Off => "OFF",
            DebugLevel::Error => "ERROR",
            DebugLevel::Warn => "WARN",
            DebugLevel::Info => "INFO",
            DebugLevel::Debug => "DEBUG",
            DebugLevel::Trace => "TRACE",
        };
        f.write_str(label)
    }
}

// ---------------------------------------------------------------------------
// KernelDebugConfig
// ---------------------------------------------------------------------------

/// Tunables for a kernel debug session.
///
/// The [`Default`] configuration enables bounds, NaN, and Inf checking at
/// [`DebugLevel::Info`], leaves race detection off, and sizes the printf
/// buffer at 1 MiB with at most 32 printf calls per thread.
#[derive(Debug, Clone)]
pub struct KernelDebugConfig {
    /// Verbosity of debug output.
    pub debug_level: DebugLevel,
    /// Instrument memory accesses with bounds checks.
    pub enable_bounds_check: bool,
    /// Detect NaN values in floating-point registers.
    pub enable_nan_check: bool,
    /// Detect infinite values in floating-point registers.
    pub enable_inf_check: bool,
    /// Attempt to detect potential race conditions.
    pub enable_race_detection: bool,
    /// Capacity of the GPU-side printf buffer, in bytes.
    pub print_buffer_size: usize,
    /// Maximum number of printf calls per thread before output is truncated.
    pub max_print_per_thread: usize,
}

impl Default for KernelDebugConfig {
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;
        const PRINTS_PER_THREAD: usize = 32;

        Self {
            debug_level: DebugLevel::Info,
            enable_race_detection: false,
            enable_bounds_check: true,
            enable_nan_check: true,
            enable_inf_check: true,
            print_buffer_size: ONE_MIB,
            max_print_per_thread: PRINTS_PER_THREAD,
        }
    }
}
109
110// ---------------------------------------------------------------------------
111// DebugEventType
112// ---------------------------------------------------------------------------
113
/// Classification of a debug event raised during kernel execution.
#[derive(Debug, Clone, PartialEq)]
pub enum DebugEventType {
    /// A memory access fell outside the allocated bounds (the memory checker
    /// also uses this variant for writes to read-only regions).
    OutOfBounds {
        /// Address that was accessed.
        address: u64,
        /// Width of the attempted access, in bytes.
        size: usize,
    },
    /// A NaN was observed in a floating-point register.
    NanDetected {
        /// Name of the register or variable.
        register: String,
        /// The NaN value, widened to `f64`.
        value: f64,
    },
    /// An infinity was observed in a floating-point register.
    InfDetected {
        /// Name of the register or variable.
        register: String,
    },
    /// A potential race on a shared-memory address.
    RaceCondition {
        /// The contended address.
        address: u64,
    },
    /// A kernel-side assertion failed.
    Assertion {
        /// Text of the asserted condition.
        condition: String,
        /// Source file name (may be empty).
        file: String,
        /// Source line number (may be 0).
        line: u32,
    },
    /// A kernel-side printf call.
    Printf {
        /// The printf format string.
        format: String,
    },
    /// A breakpoint fired.
    Breakpoint {
        /// Identifier of the breakpoint.
        id: u32,
    },
}

impl DebugEventType {
    /// Short category label used as the `[...]` prefix of a formatted event.
    fn tag(&self) -> &'static str {
        match *self {
            DebugEventType::OutOfBounds { .. } => "OOB",
            DebugEventType::NanDetected { .. } => "NaN",
            DebugEventType::InfDetected { .. } => "Inf",
            DebugEventType::RaceCondition { .. } => "RACE",
            DebugEventType::Assertion { .. } => "ASSERT",
            DebugEventType::Printf { .. } => "PRINTF",
            DebugEventType::Breakpoint { .. } => "BP",
        }
    }

    /// `true` when `self` and `other` are the same variant, regardless of
    /// field contents. Backs [`DebugSession::filter_events`].
    fn same_kind(&self, other: &Self) -> bool {
        let (mine, theirs) = (std::mem::discriminant(self), std::mem::discriminant(other));
        mine == theirs
    }
}

// ---------------------------------------------------------------------------
// DebugEvent
// ---------------------------------------------------------------------------

/// One debug event recorded during kernel execution.
#[derive(Debug, Clone)]
pub struct DebugEvent {
    /// Category and payload of the event.
    pub event_type: DebugEventType,
    /// CUDA thread index `(x, y, z)` that raised the event.
    pub thread_id: (u32, u32, u32),
    /// CUDA block index `(x, y, z)` that raised the event.
    pub block_id: (u32, u32, u32),
    /// Timestamp in nanoseconds (monotonic, relative to session start).
    pub timestamp_ns: u64,
    /// Human-readable description.
    pub message: String,
}

impl fmt::Display for DebugEvent {
    // Renders `[TAG] block(x,y,z) thread(x,y,z) @<ts>ns: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (bx, by, bz) = self.block_id;
        let (tx, ty, tz) = self.thread_id;
        write!(
            f,
            "[{tag}] block({bx},{by},{bz}) thread({tx},{ty},{tz}) @{ts}ns: {msg}",
            tag = self.event_type.tag(),
            ts = self.timestamp_ns,
            msg = self.message,
        )
    }
}
219
220// ---------------------------------------------------------------------------
221// WatchType
222// ---------------------------------------------------------------------------
223
/// Which kind of memory access a watchpoint reacts to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WatchType {
    /// Fire on loads.
    Read,
    /// Fire on stores.
    Write,
    /// Fire on both loads and stores.
    ReadWrite,
}

// ---------------------------------------------------------------------------
// Breakpoint / Watchpoint helpers
// ---------------------------------------------------------------------------

/// A breakpoint registered with a `KernelDebugger`.
// `line` is stored for bookkeeping but not read in this part of the module,
// so dead-code warnings are silenced.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Breakpoint {
    /// Identifier returned to the caller.
    id: u32,
    /// PTX source line the breakpoint targets.
    line: u32,
}

/// A memory watchpoint registered with a `KernelDebugger`.
// The watched range and access kind are stored but not yet consulted in this
// part of the module, so dead-code warnings are silenced.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Watchpoint {
    /// Identifier returned to the caller.
    id: u32,
    /// Start address of the watched range.
    address: u64,
    /// Length of the watched range, in bytes.
    size: usize,
    /// Which accesses trigger the watchpoint.
    watch_type: WatchType,
}
254
255// ---------------------------------------------------------------------------
256// KernelDebugger
257// ---------------------------------------------------------------------------
258
259/// Top-level kernel debugging manager.
260///
261/// Create one per debugging session group. It holds global breakpoints and
262/// watchpoints and spawns [`DebugSession`] instances per kernel launch.
263#[derive(Debug)]
264pub struct KernelDebugger {
265    config: KernelDebugConfig,
266    breakpoints: Vec<Breakpoint>,
267    watchpoints: Vec<Watchpoint>,
268    next_bp_id: u32,
269    next_wp_id: u32,
270}
271
272impl KernelDebugger {
273    /// Create a new kernel debugger with the given configuration.
274    pub fn new(config: KernelDebugConfig) -> Self {
275        Self {
276            config,
277            breakpoints: Vec::new(),
278            watchpoints: Vec::new(),
279            next_bp_id: 1,
280            next_wp_id: 1,
281        }
282    }
283
284    /// Attach the debugger to a kernel launch, returning a new debug session.
285    ///
286    /// On macOS this always succeeds with a synthetic (empty) session because
287    /// no actual GPU driver is available.
288    ///
289    /// # Errors
290    ///
291    /// Returns [`CudaError::InvalidValue`] if `kernel_name` is empty.
292    pub fn attach(&mut self, kernel_name: &str) -> CudaResult<DebugSession> {
293        if kernel_name.is_empty() {
294            return Err(CudaError::InvalidValue);
295        }
296        Ok(DebugSession {
297            kernel_name: kernel_name.to_owned(),
298            events: Vec::new(),
299            config: self.config.clone(),
300        })
301    }
302
303    /// Set a breakpoint at a PTX source line. Returns the breakpoint ID.
304    pub fn set_breakpoint(&mut self, line: u32) -> u32 {
305        let id = self.next_bp_id;
306        self.next_bp_id = self.next_bp_id.saturating_add(1);
307        self.breakpoints.push(Breakpoint { id, line });
308        id
309    }
310
311    /// Remove a breakpoint by ID. Returns `true` if it was found and removed.
312    pub fn remove_breakpoint(&mut self, bp_id: u32) -> bool {
313        let before = self.breakpoints.len();
314        self.breakpoints.retain(|bp| bp.id != bp_id);
315        self.breakpoints.len() < before
316    }
317
318    /// Set a memory watchpoint. Returns the watchpoint ID.
319    pub fn watchpoint(&mut self, address: u64, size: usize, watch_type: WatchType) -> u32 {
320        let id = self.next_wp_id;
321        self.next_wp_id = self.next_wp_id.saturating_add(1);
322        self.watchpoints.push(Watchpoint {
323            id,
324            address,
325            size,
326            watch_type,
327        });
328        id
329    }
330
331    /// Returns a reference to the current debug configuration.
332    pub fn config(&self) -> &KernelDebugConfig {
333        &self.config
334    }
335}
336
337// ---------------------------------------------------------------------------
338// DebugSummary
339// ---------------------------------------------------------------------------
340
/// Per-category event counts for a debug session.
///
/// `errors` aggregates out-of-bounds, race-condition, and assertion events;
/// `warnings` aggregates NaN and Inf events. Printf and breakpoint events only
/// contribute to `total_events`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DebugSummary {
    /// Number of events recorded, of any kind.
    pub total_events: usize,
    /// Error-level events (out-of-bounds, assertions, races).
    pub errors: usize,
    /// Warning-level events (NaN, Inf).
    pub warnings: usize,
    /// NaN-detection events.
    pub nan_count: usize,
    /// Inf-detection events.
    pub inf_count: usize,
    /// Out-of-bounds events.
    pub oob_count: usize,
    /// Race-condition events.
    pub race_count: usize,
}
359
360// ---------------------------------------------------------------------------
361// DebugSession
362// ---------------------------------------------------------------------------
363
364/// An active debug session for a single kernel launch.
365///
366/// Collects [`DebugEvent`]s and provides analysis / reporting helpers.
367#[derive(Debug)]
368pub struct DebugSession {
369    kernel_name: String,
370    events: Vec<DebugEvent>,
371    #[allow(dead_code)]
372    config: KernelDebugConfig,
373}
374
375impl DebugSession {
376    /// The kernel name this session is attached to.
377    pub fn kernel_name(&self) -> &str {
378        &self.kernel_name
379    }
380
381    /// All events collected so far.
382    pub fn events(&self) -> &[DebugEvent] {
383        &self.events
384    }
385
386    /// Record a new debug event.
387    pub fn add_event(&mut self, event: DebugEvent) {
388        self.events.push(event);
389    }
390
391    /// Return references to events whose type matches the discriminant of
392    /// `event_type` (field values inside variants are ignored for matching).
393    pub fn filter_events(&self, event_type: &DebugEventType) -> Vec<&DebugEvent> {
394        self.events
395            .iter()
396            .filter(|e| e.event_type.same_kind(event_type))
397            .collect()
398    }
399
400    /// Compute aggregate statistics over all collected events.
401    pub fn summary(&self) -> DebugSummary {
402        let mut s = DebugSummary {
403            total_events: self.events.len(),
404            ..DebugSummary::default()
405        };
406        for ev in &self.events {
407            match &ev.event_type {
408                DebugEventType::OutOfBounds { .. } => {
409                    s.oob_count += 1;
410                    s.errors += 1;
411                }
412                DebugEventType::NanDetected { .. } => {
413                    s.nan_count += 1;
414                    s.warnings += 1;
415                }
416                DebugEventType::InfDetected { .. } => {
417                    s.inf_count += 1;
418                    s.warnings += 1;
419                }
420                DebugEventType::RaceCondition { .. } => {
421                    s.race_count += 1;
422                    s.errors += 1;
423                }
424                DebugEventType::Assertion { .. } => {
425                    s.errors += 1;
426                }
427                DebugEventType::Printf { .. } | DebugEventType::Breakpoint { .. } => {}
428            }
429        }
430        s
431    }
432
433    /// Produce a human-readable debug report.
434    pub fn format_report(&self) -> String {
435        let summary = self.summary();
436        let mut out = String::with_capacity(512);
437        out.push_str(&format!("=== Debug Report: {} ===\n", self.kernel_name));
438        out.push_str(&format!(
439            "Total events: {}  (errors: {}, warnings: {})\n",
440            summary.total_events, summary.errors, summary.warnings
441        ));
442        if summary.oob_count > 0 {
443            out.push_str(&format!("  Out-of-bounds: {}\n", summary.oob_count));
444        }
445        if summary.nan_count > 0 {
446            out.push_str(&format!("  NaN detected:  {}\n", summary.nan_count));
447        }
448        if summary.inf_count > 0 {
449            out.push_str(&format!("  Inf detected:  {}\n", summary.inf_count));
450        }
451        if summary.race_count > 0 {
452            out.push_str(&format!("  Race cond.:    {}\n", summary.race_count));
453        }
454        out.push_str("--- Events ---\n");
455        for ev in &self.events {
456            out.push_str(&format!("  {ev}\n"));
457        }
458        out.push_str("=== End Report ===\n");
459        out
460    }
461}
462
463// ---------------------------------------------------------------------------
464// MemoryRegion / MemoryChecker
465// ---------------------------------------------------------------------------
466
467/// A contiguous GPU memory allocation known to the memory checker.
468#[derive(Debug, Clone)]
469pub struct MemoryRegion {
470    /// Base device address of the allocation.
471    pub base_address: u64,
472    /// Size in bytes.
473    pub size: usize,
474    /// Human-readable name for diagnostics.
475    pub name: String,
476    /// Whether the allocation is read-only.
477    pub is_readonly: bool,
478}
479
480/// Validates memory accesses against a set of known [`MemoryRegion`]s.
481#[derive(Debug)]
482pub struct MemoryChecker {
483    allocations: Vec<MemoryRegion>,
484}
485
486impl MemoryChecker {
487    /// Create a memory checker from a list of known allocations.
488    pub fn new(allocations: Vec<MemoryRegion>) -> Self {
489        Self { allocations }
490    }
491
492    /// Check whether a memory access is valid.
493    ///
494    /// Returns `Some(DebugEvent)` if the access is out of bounds or violates
495    /// read-only protections; `None` if the access is valid.
496    pub fn check_access(&self, address: u64, size: usize, is_write: bool) -> Option<DebugEvent> {
497        // Find the allocation that contains this address.
498        let region = self.allocations.iter().find(|r| {
499            address >= r.base_address
500                && address
501                    .checked_add(size as u64)
502                    .is_some_and(|end| end <= r.base_address + r.size as u64)
503        });
504
505        match region {
506            Some(r) if is_write && r.is_readonly => Some(DebugEvent {
507                event_type: DebugEventType::OutOfBounds { address, size },
508                thread_id: (0, 0, 0),
509                block_id: (0, 0, 0),
510                timestamp_ns: 0,
511                message: format!("Write to read-only region '{}' at {:#x}", r.name, address),
512            }),
513            Some(_) => None,
514            None => Some(DebugEvent {
515                event_type: DebugEventType::OutOfBounds { address, size },
516                thread_id: (0, 0, 0),
517                block_id: (0, 0, 0),
518                timestamp_ns: 0,
519                message: format!(
520                    "Access at {:#x} (size {}) does not fall within any known allocation",
521                    address, size
522                ),
523            }),
524        }
525    }
526}
527
528// ---------------------------------------------------------------------------
529// NanInfChecker / NanInfLocation
530// ---------------------------------------------------------------------------
531
/// Position and value of a NaN or infinity found in a scanned buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct NanInfLocation {
    /// Offset of the offending element within the buffer.
    pub index: usize,
    /// The offending value, widened to `f64` so both buffer types report uniformly.
    pub value: f64,
    /// `true` for NaN, `false` for positive or negative infinity.
    pub is_nan: bool,
}

/// Stateless scanner that locates NaN and infinite values in host-side buffers.
#[derive(Debug, Clone, Copy)]
pub struct NanInfChecker;

impl NanInfChecker {
    /// Every NaN / ±Inf element of an `f32` buffer, in index order.
    ///
    /// Values are widened with `f64::from` before being reported.
    pub fn check_f32(data: &[f32]) -> Vec<NanInfLocation> {
        Self::scan(data.iter().copied().map(f64::from))
    }

    /// Every NaN / ±Inf element of an `f64` buffer, in index order.
    pub fn check_f64(data: &[f64]) -> Vec<NanInfLocation> {
        Self::scan(data.iter().copied())
    }

    /// Shared classifier: keep only non-finite values, tagging each as NaN or Inf.
    fn scan(values: impl Iterator<Item = f64>) -> Vec<NanInfLocation> {
        values
            .enumerate()
            .filter(|(_, v)| !v.is_finite())
            .map(|(index, value)| NanInfLocation {
                index,
                value,
                is_nan: value.is_nan(),
            })
            .collect()
    }
}
596
597// ---------------------------------------------------------------------------
598// PrintfBuffer / PrintfEntry / PrintfArg
599// ---------------------------------------------------------------------------
600
/// A single argument captured from a GPU printf call.
#[derive(Debug, Clone, PartialEq)]
pub enum PrintfArg {
    /// Integer argument (wire tag `0`, 8-byte little-endian `i64`).
    Int(i64),
    /// Floating-point argument (wire tag `1`, 8-byte little-endian `f64`).
    Float(f64),
    /// String argument (wire tag `2`, `u32` length prefix followed by the
    /// bytes; invalid UTF-8 is replaced lossily).
    String(String),
}

/// A parsed GPU-side printf entry.
#[derive(Debug, Clone)]
pub struct PrintfEntry {
    /// Thread index that issued the printf.
    pub thread_id: (u32, u32, u32),
    /// Block index that issued the printf.
    pub block_id: (u32, u32, u32),
    /// The format string (decoded lossily from UTF-8).
    pub format_string: String,
    /// Parsed arguments, in wire order.
    pub args: Vec<PrintfArg>,
}

/// GPU-side printf emulation buffer.
///
/// The buffer layout is a simple framed protocol:
///
/// ```text
/// [entry_count: u32_le]
/// repeated entry_count times:
///   [thread_x: u32_le] [thread_y: u32_le] [thread_z: u32_le]
///   [block_x:  u32_le] [block_y:  u32_le] [block_z:  u32_le]
///   [fmt_len:  u32_le] [fmt_bytes: u8 * fmt_len]
///   [arg_count: u32_le]
///   repeated arg_count times:
///     [tag: u8]  0=Int(i64_le), 1=Float(f64_le), 2=String(u32_le len + bytes)
/// ```
///
/// Every length and count field comes from device memory and is treated as
/// untrusted. Truncated or inconsistent data stops parsing; it never panics
/// or triggers an oversized allocation.
#[derive(Debug)]
pub struct PrintfBuffer {
    buffer_size: usize,
}

impl PrintfBuffer {
    /// Smallest possible encoding of one argument: a 1-byte tag plus a 4-byte
    /// string length prefix (integers and floats need 1 + 8 bytes).
    const MIN_ARG_BYTES: usize = 5;

    /// Create a printf buffer descriptor with the given maximum size.
    pub fn new(buffer_size: usize) -> Self {
        Self { buffer_size }
    }

    /// Returns the configured buffer size.
    pub fn buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Parse a raw byte buffer into structured printf entries.
    ///
    /// Entries are decoded in order until `entry_count` entries have been
    /// read, or until an entry is truncated or carries an unknown argument
    /// tag. Everything decoded before that point is returned. Input shorter
    /// than the 4-byte header yields an empty vec.
    ///
    /// `raw` is parsed as given; it is not truncated to
    /// [`PrintfBuffer::buffer_size`].
    pub fn parse_entries(&self, raw: &[u8]) -> Vec<PrintfEntry> {
        let mut cursor = 0usize;
        let Some(entry_count) = Self::read_u32(raw, &mut cursor) else {
            return Vec::new();
        };

        let mut entries = Vec::new();
        for _ in 0..entry_count {
            let Some(entry) = self.parse_single_entry(raw, &mut cursor) else {
                break;
            };
            entries.push(entry);
        }
        entries
    }

    /// Decode one entry starting at `*cursor` and advance past it.
    ///
    /// Returns `None` on truncation or an unknown argument tag. In that case
    /// `*cursor` may be left partway into the entry.
    fn parse_single_entry(&self, raw: &[u8], cursor: &mut usize) -> Option<PrintfEntry> {
        let tx = Self::read_u32(raw, cursor)?;
        let ty = Self::read_u32(raw, cursor)?;
        let tz = Self::read_u32(raw, cursor)?;
        let bx = Self::read_u32(raw, cursor)?;
        let by = Self::read_u32(raw, cursor)?;
        let bz = Self::read_u32(raw, cursor)?;

        let fmt_len = Self::read_u32(raw, cursor)? as usize;
        let fmt_bytes = Self::read_bytes(raw, cursor, fmt_len)?;
        let format_string = String::from_utf8_lossy(fmt_bytes).into_owned();

        let arg_count = Self::read_u32(raw, cursor)? as usize;
        // `arg_count` is untrusted: never reserve more slots than the bytes
        // left in `raw` could possibly encode. Otherwise a corrupt count such
        // as 0xFFFF_FFFF requests a multi-gigabyte allocation and aborts.
        let remaining = raw.len().saturating_sub(*cursor);
        let mut args = Vec::with_capacity(arg_count.min(remaining / Self::MIN_ARG_BYTES));
        for _ in 0..arg_count {
            let arg = match Self::read_u8(raw, cursor)? {
                0 => PrintfArg::Int(Self::read_i64(raw, cursor)?),
                1 => PrintfArg::Float(Self::read_f64(raw, cursor)?),
                2 => {
                    let slen = Self::read_u32(raw, cursor)? as usize;
                    let sbytes = Self::read_bytes(raw, cursor, slen)?;
                    PrintfArg::String(String::from_utf8_lossy(sbytes).into_owned())
                }
                _ => return None,
            };
            args.push(arg);
        }

        Some(PrintfEntry {
            thread_id: (tx, ty, tz),
            block_id: (bx, by, bz),
            format_string,
            args,
        })
    }

    // --- Low-level readers ---
    // Each reader returns `None` without moving the cursor if not enough bytes remain.

    fn read_u8(raw: &[u8], cursor: &mut usize) -> Option<u8> {
        let [byte] = Self::read_array::<1>(raw, cursor)?;
        Some(byte)
    }

    fn read_u32(raw: &[u8], cursor: &mut usize) -> Option<u32> {
        Self::read_array::<4>(raw, cursor).map(u32::from_le_bytes)
    }

    fn read_i64(raw: &[u8], cursor: &mut usize) -> Option<i64> {
        Self::read_array::<8>(raw, cursor).map(i64::from_le_bytes)
    }

    fn read_f64(raw: &[u8], cursor: &mut usize) -> Option<f64> {
        Self::read_array::<8>(raw, cursor).map(f64::from_le_bytes)
    }

    /// Copy exactly `N` bytes at `*cursor` into an array and advance.
    fn read_array<const N: usize>(raw: &[u8], cursor: &mut usize) -> Option<[u8; N]> {
        let bytes = Self::read_bytes(raw, cursor, N)?;
        let mut out = [0u8; N];
        // `read_bytes` returned exactly N bytes, so the lengths always match.
        out.copy_from_slice(bytes);
        Some(out)
    }

    /// Borrow `len` bytes at `*cursor` and advance past them.
    ///
    /// `checked_add` plus `get` mean an attacker-chosen `len` can neither
    /// overflow the end index nor cause an out-of-range slice panic.
    fn read_bytes<'a>(raw: &'a [u8], cursor: &mut usize, len: usize) -> Option<&'a [u8]> {
        let end = (*cursor).checked_add(len)?;
        let slice = raw.get(*cursor..end)?;
        *cursor = end;
        Some(slice)
    }
}
767
768// ---------------------------------------------------------------------------
769// KernelAssertions
770// ---------------------------------------------------------------------------
771
772/// Convenience assertion helpers that produce [`DebugEvent`]s instead of
773/// panicking, suitable for GPU kernel emulation / validation.
774#[derive(Debug, Clone, Copy)]
775pub struct KernelAssertions;
776
777impl KernelAssertions {
778    /// Assert that `index < len`. Returns an event if the assertion fails.
779    pub fn assert_bounds(index: usize, len: usize, name: &str) -> Option<DebugEvent> {
780        if index < len {
781            return None;
782        }
783        Some(DebugEvent {
784            event_type: DebugEventType::Assertion {
785                condition: format!("{name}[{index}] < {len}"),
786                file: String::new(),
787                line: 0,
788            },
789            thread_id: (0, 0, 0),
790            block_id: (0, 0, 0),
791            timestamp_ns: 0,
792            message: format!("Bounds check failed: {name}[{index}] out of range (len={len})"),
793        })
794    }
795
796    /// Assert that `value` is not NaN. Returns an event if it is.
797    pub fn assert_not_nan(value: f64, name: &str) -> Option<DebugEvent> {
798        if !value.is_nan() {
799            return None;
800        }
801        Some(DebugEvent {
802            event_type: DebugEventType::NanDetected {
803                register: name.to_owned(),
804                value,
805            },
806            thread_id: (0, 0, 0),
807            block_id: (0, 0, 0),
808            timestamp_ns: 0,
809            message: format!("NaN detected in '{name}'"),
810        })
811    }
812
813    /// Assert that `value` is finite (not NaN and not Inf). Returns an event
814    /// if it is not.
815    pub fn assert_finite(value: f64, name: &str) -> Option<DebugEvent> {
816        if value.is_finite() {
817            return None;
818        }
819        if value.is_nan() {
820            return Some(DebugEvent {
821                event_type: DebugEventType::NanDetected {
822                    register: name.to_owned(),
823                    value,
824                },
825                thread_id: (0, 0, 0),
826                block_id: (0, 0, 0),
827                timestamp_ns: 0,
828                message: format!("Non-finite (NaN) value in '{name}'"),
829            });
830        }
831        Some(DebugEvent {
832            event_type: DebugEventType::InfDetected {
833                register: name.to_owned(),
834            },
835            thread_id: (0, 0, 0),
836            block_id: (0, 0, 0),
837            timestamp_ns: 0,
838            message: format!("Non-finite (Inf) value in '{name}'"),
839        })
840    }
841
842    /// Assert that `value` is strictly positive. Returns an event if it is
843    /// not (including NaN, zero, and negative values).
844    pub fn assert_positive(value: f64, name: &str) -> Option<DebugEvent> {
845        if value > 0.0 {
846            return None;
847        }
848        if value.is_nan() {
849            return Some(DebugEvent {
850                event_type: DebugEventType::NanDetected {
851                    register: name.to_owned(),
852                    value,
853                },
854                thread_id: (0, 0, 0),
855                block_id: (0, 0, 0),
856                timestamp_ns: 0,
857                message: format!("Expected positive value for '{name}', got NaN"),
858            });
859        }
860        Some(DebugEvent {
861            event_type: DebugEventType::Assertion {
862                condition: format!("{name} > 0"),
863                file: String::new(),
864                line: 0,
865            },
866            thread_id: (0, 0, 0),
867            block_id: (0, 0, 0),
868            timestamp_ns: 0,
869            message: format!("Expected positive value for '{name}', got {value}"),
870        })
871    }
872}
873
874// ---------------------------------------------------------------------------
875// DebugPtxInstrumenter
876// ---------------------------------------------------------------------------
877
/// Rewrites PTX source to add debugging checks.
///
/// Each pass is gated by a flag captured from a [`KernelDebugConfig`]. The
/// bounds-check pass declares a `__oxicuda_debug_buf` kernel parameter and
/// emits a `trap` sequence after global loads/stores. NaN detection and
/// printf support are controlled by the remaining flags.
#[derive(Debug)]
pub struct DebugPtxInstrumenter {
    /// Whether to instrument global memory accesses with bounds checks.
    enable_bounds_check: bool,
    /// Whether to instrument NaN detection.
    enable_nan_check: bool,
    /// Whether printf buffer support is enabled (a non-zero printf buffer size).
    enable_printf: bool,
}
890
891impl DebugPtxInstrumenter {
892    /// Create an instrumenter from a debug configuration.
893    pub fn new(config: &KernelDebugConfig) -> Self {
894        Self {
895            enable_bounds_check: config.enable_bounds_check,
896            enable_nan_check: config.enable_nan_check,
897            enable_printf: config.print_buffer_size > 0,
898        }
899    }
900
901    /// Insert bounds-checking instrumentation into PTX source.
902    ///
903    /// Adds `setp` + `trap` sequences after every `ld.global` / `st.global`
904    /// instruction to validate the address against the allocation size
905    /// parameter.
906    pub fn instrument_bounds_checks(&self, ptx: &str) -> String {
907        if !self.enable_bounds_check {
908            return ptx.to_owned();
909        }
910
911        let mut output = String::with_capacity(ptx.len() + ptx.len() / 4);
912        // Add debug parameter declaration at the top of each kernel entry.
913        let mut added_param = false;
914
915        for line in ptx.lines() {
916            let trimmed = line.trim();
917            // Insert debug buffer param after .entry
918            if trimmed.starts_with(".entry") && !added_param {
919                output.push_str(line);
920                output.push('\n');
921                output.push_str("    // [oxicuda-debug] bounds-check instrumentation\n");
922                output.push_str("    .param .u64 __oxicuda_debug_buf;\n");
923                added_param = true;
924                continue;
925            }
926
927            // Instrument global loads/stores
928            if (trimmed.starts_with("ld.global") || trimmed.starts_with("st.global"))
929                && !trimmed.starts_with("// [oxicuda-debug]")
930            {
931                output.push_str(line);
932                output.push('\n');
933                output.push_str("    // [oxicuda-debug] bounds check for above access\n");
934                output.push_str("    setp.ge.u64 %p_oob, %rd_addr, %rd_alloc_end;\n");
935                output.push_str("    @%p_oob trap;\n");
936            } else {
937                output.push_str(line);
938                output.push('\n');
939            }
940        }
941
942        output
943    }
944
945    /// Insert NaN-detection instrumentation into PTX source.
946    ///
947    /// After every floating-point arithmetic instruction (`add.f32`,
948    /// `mul.f64`, etc.) a `testp.nan` check is inserted.
949    pub fn instrument_nan_checks(&self, ptx: &str) -> String {
950        if !self.enable_nan_check {
951            return ptx.to_owned();
952        }
953
954        let fp_ops = [
955            "add.f32", "add.f64", "sub.f32", "sub.f64", "mul.f32", "mul.f64", "div.f32", "div.f64",
956            "fma.f32", "fma.f64",
957        ];
958
959        let mut output = String::with_capacity(ptx.len() + ptx.len() / 4);
960
961        for line in ptx.lines() {
962            output.push_str(line);
963            output.push('\n');
964
965            let trimmed = line.trim();
966            if fp_ops.iter().any(|op| trimmed.starts_with(op)) {
967                // Extract the destination register (first token after the op).
968                if let Some(dest) = trimmed.split_whitespace().nth(1) {
969                    let dest_clean = dest.trim_end_matches(',');
970                    let width = if trimmed.contains(".f64") {
971                        "f64"
972                    } else {
973                        "f32"
974                    };
975                    output.push_str(&format!(
976                        "    // [oxicuda-debug] NaN check for {dest_clean}\n"
977                    ));
978                    output.push_str(&format!("    testp.nan.{width} %p_nan, {dest_clean};\n"));
979                    output.push_str("    @%p_nan trap;\n");
980                }
981            }
982        }
983
984        output
985    }
986
987    /// Insert printf buffer support into PTX source.
988    ///
989    /// Adds a `.param .u64 __oxicuda_printf_buf` to each entry and inserts
990    /// stub store sequences where `// PRINTF` markers appear.
991    pub fn instrument_printf(&self, ptx: &str) -> String {
992        if !self.enable_printf {
993            return ptx.to_owned();
994        }
995
996        let mut output = String::with_capacity(ptx.len() + ptx.len() / 4);
997        let mut added_param = false;
998
999        for line in ptx.lines() {
1000            let trimmed = line.trim();
1001            if trimmed.starts_with(".entry") && !added_param {
1002                output.push_str(line);
1003                output.push('\n');
1004                output.push_str("    // [oxicuda-debug] printf buffer parameter\n");
1005                output.push_str("    .param .u64 __oxicuda_printf_buf;\n");
1006                added_param = true;
1007                continue;
1008            }
1009
1010            if trimmed.starts_with("// PRINTF") {
1011                output.push_str("    // [oxicuda-debug] printf store sequence\n");
1012                output.push_str("    ld.param.u64 %rd_pbuf, [__oxicuda_printf_buf];\n");
1013                output.push_str("    atom.global.add.u32 %r_poff, [%rd_pbuf], 1;\n");
1014            } else {
1015                output.push_str(line);
1016                output.push('\n');
1017            }
1018        }
1019
1020        output
1021    }
1022
1023    /// Remove all OxiCUDA debug instrumentation from PTX source.
1024    pub fn strip_debug(&self, ptx: &str) -> String {
1025        let mut output = String::with_capacity(ptx.len());
1026        let mut skip_next = false;
1027
1028        for line in ptx.lines() {
1029            if skip_next {
1030                skip_next = false;
1031                continue;
1032            }
1033
1034            let trimmed = line.trim();
1035
1036            // Skip debug comment lines and the instruction immediately after.
1037            if trimmed.starts_with("// [oxicuda-debug]") {
1038                // Also skip the next instrumentation line.
1039                skip_next = true;
1040                continue;
1041            }
1042
1043            // Skip debug parameter declarations.
1044            if trimmed.contains("__oxicuda_debug_buf") || trimmed.contains("__oxicuda_printf_buf") {
1045                continue;
1046            }
1047
1048            output.push_str(line);
1049            output.push('\n');
1050        }
1051
1052        output
1053    }
1054}
1055
1056// ===========================================================================
1057// Tests
1058// ===========================================================================
1059
#[cfg(test)]
mod tests {
    use super::*;

    // -- Config defaults --

    #[test]
    fn config_default_values() {
        let cfg = KernelDebugConfig::default();
        assert_eq!(cfg.debug_level, DebugLevel::Info);
        assert!(cfg.enable_bounds_check);
        assert!(cfg.enable_nan_check);
        assert!(cfg.enable_inf_check);
        assert!(!cfg.enable_race_detection);
        assert_eq!(cfg.print_buffer_size, 1024 * 1024);
        assert_eq!(cfg.max_print_per_thread, 32);
    }

    // -- KernelDebugger creation --

    #[test]
    fn debugger_creation_with_config() {
        let cfg = KernelDebugConfig {
            debug_level: DebugLevel::Trace,
            enable_bounds_check: false,
            ..KernelDebugConfig::default()
        };
        let debugger = KernelDebugger::new(cfg);
        assert_eq!(debugger.config().debug_level, DebugLevel::Trace);
        assert!(!debugger.config().enable_bounds_check);
    }

    // -- Session lifecycle --

    #[test]
    fn debug_session_lifecycle() {
        let cfg = KernelDebugConfig::default();
        let mut debugger = KernelDebugger::new(cfg);
        let session = debugger
            .attach("test_kernel")
            .expect("attaching a named kernel should succeed");
        assert_eq!(session.kernel_name(), "test_kernel");
        assert!(session.events().is_empty());

        // Attaching with empty name is an error.
        assert!(debugger.attach("").is_err());
    }

    // -- Breakpoints --

    #[test]
    fn breakpoint_set_and_remove() {
        let mut debugger = KernelDebugger::new(KernelDebugConfig::default());
        let bp1 = debugger.set_breakpoint(42);
        let bp2 = debugger.set_breakpoint(100);
        assert_ne!(bp1, bp2);

        assert!(debugger.remove_breakpoint(bp1));
        // Removing again should return false.
        assert!(!debugger.remove_breakpoint(bp1));
        // bp2 should still be present.
        assert!(debugger.remove_breakpoint(bp2));
    }

    // -- Memory checker: valid access --

    #[test]
    fn memory_checker_valid_access() {
        let checker = MemoryChecker::new(vec![MemoryRegion {
            base_address: 0x1000,
            size: 256,
            name: "buf_a".into(),
            is_readonly: false,
        }]);
        // Read within bounds.
        assert!(checker.check_access(0x1000, 16, false).is_none());
        // Write within bounds.
        assert!(checker.check_access(0x1080, 32, true).is_none());
    }

    // -- Memory checker: OOB detection --

    #[test]
    fn memory_checker_out_of_bounds() {
        let checker = MemoryChecker::new(vec![MemoryRegion {
            base_address: 0x1000,
            size: 256,
            name: "buf_a".into(),
            is_readonly: false,
        }]);
        // Access past the end.
        let ev = checker
            .check_access(0x1100, 16, false)
            .expect("access past the end of buf_a must be reported");
        assert!(matches!(ev.event_type, DebugEventType::OutOfBounds { .. }));

        // Completely outside.
        assert!(checker.check_access(0x5000, 4, true).is_some());
    }

    // -- NaN detection in f32 --

    #[test]
    fn nan_detection_f32() {
        let data = [1.0_f32, f32::NAN, 3.0, f32::NAN];
        let locs = NanInfChecker::check_f32(&data);
        assert_eq!(locs.len(), 2);
        assert_eq!(locs[0].index, 1);
        assert!(locs[0].is_nan);
        assert_eq!(locs[1].index, 3);
        assert!(locs[1].is_nan);
    }

    // -- Inf detection in f64 --

    #[test]
    fn inf_detection_f64() {
        let data = [1.0_f64, f64::INFINITY, f64::NEG_INFINITY, 4.0];
        let locs = NanInfChecker::check_f64(&data);
        assert_eq!(locs.len(), 2);
        assert!(!locs[0].is_nan);
        assert_eq!(locs[0].index, 1);
        assert!(!locs[1].is_nan);
        assert_eq!(locs[1].index, 2);
    }

    // -- Printf buffer parsing --

    #[test]
    fn printf_buffer_parsing() {
        let buf = PrintfBuffer::new(4096);

        // Build a raw buffer with one entry containing one Int arg.
        let mut raw = Vec::new();
        // entry_count = 1
        raw.extend_from_slice(&1_u32.to_le_bytes());
        // thread_id (1,0,0)
        raw.extend_from_slice(&1_u32.to_le_bytes());
        raw.extend_from_slice(&0_u32.to_le_bytes());
        raw.extend_from_slice(&0_u32.to_le_bytes());
        // block_id (0,0,0)
        raw.extend_from_slice(&0_u32.to_le_bytes());
        raw.extend_from_slice(&0_u32.to_le_bytes());
        raw.extend_from_slice(&0_u32.to_le_bytes());
        // format string "val=%d"
        let fmt = b"val=%d";
        raw.extend_from_slice(&(fmt.len() as u32).to_le_bytes());
        raw.extend_from_slice(fmt);
        // arg_count = 1
        raw.extend_from_slice(&1_u32.to_le_bytes());
        // tag=0 (Int), value=42
        raw.push(0);
        raw.extend_from_slice(&42_i64.to_le_bytes());

        let entries = buf.parse_entries(&raw);
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].thread_id, (1, 0, 0));
        assert_eq!(entries[0].format_string, "val=%d");
        assert_eq!(entries[0].args.len(), 1);
        assert_eq!(entries[0].args[0], PrintfArg::Int(42));
    }

    // -- Assertions --

    #[test]
    fn assertion_checks() {
        // bounds: in range => None, out of range => Some
        assert!(KernelAssertions::assert_bounds(5, 10, "arr").is_none());
        assert!(KernelAssertions::assert_bounds(10, 10, "arr").is_some());

        // NaN
        assert!(KernelAssertions::assert_not_nan(1.0, "x").is_none());
        assert!(KernelAssertions::assert_not_nan(f64::NAN, "x").is_some());

        // finite
        assert!(KernelAssertions::assert_finite(1.0, "x").is_none());
        assert!(KernelAssertions::assert_finite(f64::INFINITY, "x").is_some());
        assert!(KernelAssertions::assert_finite(f64::NAN, "x").is_some());

        // positive
        assert!(KernelAssertions::assert_positive(1.0, "x").is_none());
        assert!(KernelAssertions::assert_positive(f64::INFINITY, "x").is_none());
        assert!(KernelAssertions::assert_positive(0.0, "x").is_some());
        let neg = KernelAssertions::assert_positive(-1.0, "x").expect("-1 is not positive");
        assert!(matches!(neg.event_type, DebugEventType::Assertion { .. }));
        // NaN is reported as a NaN detection, not as a failed comparison.
        let nan = KernelAssertions::assert_positive(f64::NAN, "x").expect("NaN is not positive");
        assert!(matches!(nan.event_type, DebugEventType::NanDetected { .. }));
    }

    // -- Event filtering --

    #[test]
    fn debug_event_filtering() {
        let cfg = KernelDebugConfig::default();
        let mut debugger = KernelDebugger::new(cfg);
        let mut session = debugger.attach("filter_test").expect("session");

        session.add_event(DebugEvent {
            event_type: DebugEventType::NanDetected {
                register: "f0".into(),
                value: f64::NAN,
            },
            thread_id: (0, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 100,
            message: "nan".into(),
        });
        session.add_event(DebugEvent {
            event_type: DebugEventType::OutOfBounds {
                address: 0xDEAD,
                size: 4,
            },
            thread_id: (1, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 200,
            message: "oob".into(),
        });
        session.add_event(DebugEvent {
            event_type: DebugEventType::NanDetected {
                register: "f1".into(),
                value: f64::NAN,
            },
            thread_id: (2, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 300,
            message: "nan2".into(),
        });

        let nans = session.filter_events(&DebugEventType::NanDetected {
            register: String::new(),
            value: 0.0,
        });
        assert_eq!(nans.len(), 2);

        let oobs = session.filter_events(&DebugEventType::OutOfBounds {
            address: 0,
            size: 0,
        });
        assert_eq!(oobs.len(), 1);
    }

    // -- Summary statistics --

    #[test]
    fn summary_statistics() {
        let cfg = KernelDebugConfig::default();
        let mut debugger = KernelDebugger::new(cfg);
        let mut session = debugger.attach("summary_test").expect("session");

        session.add_event(DebugEvent {
            event_type: DebugEventType::NanDetected {
                register: "f0".into(),
                value: f64::NAN,
            },
            thread_id: (0, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 0,
            message: String::new(),
        });
        session.add_event(DebugEvent {
            event_type: DebugEventType::InfDetected {
                register: "f1".into(),
            },
            thread_id: (0, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 0,
            message: String::new(),
        });
        session.add_event(DebugEvent {
            event_type: DebugEventType::OutOfBounds {
                address: 0x100,
                size: 4,
            },
            thread_id: (0, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 0,
            message: String::new(),
        });
        session.add_event(DebugEvent {
            event_type: DebugEventType::RaceCondition { address: 0x200 },
            thread_id: (0, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 0,
            message: String::new(),
        });

        let s = session.summary();
        assert_eq!(s.total_events, 4);
        assert_eq!(s.errors, 2); // OOB + race
        assert_eq!(s.warnings, 2); // NaN + Inf
        assert_eq!(s.nan_count, 1);
        assert_eq!(s.inf_count, 1);
        assert_eq!(s.oob_count, 1);
        assert_eq!(s.race_count, 1);
    }

    // -- Format report --

    #[test]
    fn format_report_output() {
        let cfg = KernelDebugConfig::default();
        let mut debugger = KernelDebugger::new(cfg);
        let mut session = debugger.attach("report_test").expect("session");

        session.add_event(DebugEvent {
            event_type: DebugEventType::NanDetected {
                register: "f0".into(),
                value: f64::NAN,
            },
            thread_id: (0, 0, 0),
            block_id: (0, 0, 0),
            timestamp_ns: 42,
            message: "NaN found".into(),
        });

        let report = session.format_report();
        assert!(report.contains("report_test"));
        assert!(report.contains("Total events: 1"));
        assert!(report.contains("NaN detected:  1"));
        assert!(report.contains("NaN found"));
        assert!(report.contains("=== End Report ==="));
    }

    // -- PTX instrumentation: bounds checks --

    #[test]
    fn ptx_instrumentation_bounds_checks() {
        let cfg = KernelDebugConfig::default();
        let inst = DebugPtxInstrumenter::new(&cfg);

        let ptx = ".entry my_kernel {\n    ld.global.f32 %f0, [%rd0];\n    ret;\n}\n";
        let result = inst.instrument_bounds_checks(ptx);

        assert!(result.contains("__oxicuda_debug_buf"));
        assert!(result.contains("setp.ge.u64"));
        assert!(result.contains("@%p_oob trap"));
    }

    // -- PTX instrumentation: disabled passes through --

    #[test]
    fn ptx_instrumentation_disabled_is_passthrough() {
        let cfg = KernelDebugConfig {
            enable_bounds_check: false,
            enable_nan_check: false,
            print_buffer_size: 0,
            ..KernelDebugConfig::default()
        };
        let inst = DebugPtxInstrumenter::new(&cfg);

        let ptx = ".entry k {\n    ld.global.f32 %f0, [%rd0];\n    add.f32 %f1, %f0, %f0;\n    // PRINTF x\n    ret;\n}\n";
        assert_eq!(inst.instrument_bounds_checks(ptx), ptx);
        assert_eq!(inst.instrument_nan_checks(ptx), ptx);
        assert_eq!(inst.instrument_printf(ptx), ptx);
    }

    // -- PTX strip debug --

    #[test]
    fn ptx_strip_debug_roundtrip() {
        let cfg = KernelDebugConfig::default();
        let inst = DebugPtxInstrumenter::new(&cfg);

        let original = ".entry kern {\n    add.f32 %f0, %f1, %f2;\n    ret;\n}\n";
        let instrumented = inst.instrument_nan_checks(original);
        assert!(instrumented.contains("[oxicuda-debug]"));

        let stripped = inst.strip_debug(&instrumented);
        // After stripping, no debug markers should remain.
        assert!(!stripped.contains("[oxicuda-debug]"));
        // Original instructions should still be present.
        assert!(stripped.contains("add.f32"));
        assert!(stripped.contains("ret;"));
    }

    #[test]
    fn ptx_strip_debug_removes_parameters() {
        let inst = DebugPtxInstrumenter::new(&KernelDebugConfig::default());

        let ptx = ".entry k {\n    ld.global.f32 %f0, [%rd0];\n    ret;\n}\n";
        let stripped = inst.strip_debug(&inst.instrument_bounds_checks(ptx));
        assert!(!stripped.contains("__oxicuda_debug_buf"));
        assert!(!stripped.contains("[oxicuda-debug]"));
        assert!(stripped.contains("ld.global.f32 %f0, [%rd0];"));
    }
}