Skip to main content

oxicuda_launch/
arg_serialize.rs

1//! Kernel argument serialization, Debug/Display formatting, and launch logging.
2//!
3//! This module provides infrastructure for serializing kernel arguments
4//! into a human-readable format, logging kernel launches, and producing
5//! aggregate launch summaries. It is useful for debugging, profiling,
6//! and tracing GPU kernel invocations.
7//!
8//! # Overview
9//!
10//! - [`ArgType`] — describes the data type of a serialized kernel argument.
11//! - [`SerializedArg`] — a single kernel argument with its type, name, and
12//!   string representation of its value.
13//! - [`LaunchLog`] — a complete record of a single kernel launch including
14//!   the kernel name, grid/block dimensions, shared memory, and arguments.
15//! - [`LaunchLogger`] — collects [`LaunchLog`] entries for analysis.
16//! - [`LaunchSummary`] — aggregate statistics (per-kernel launch counts, etc.).
17//! - [`SerializableKernelArgs`] — extends [`KernelArgs`]
18//!   with the ability to serialize arguments into [`SerializedArg`] form.
19//!
20//! # Example
21//!
22//! ```rust
23//! use oxicuda_launch::arg_serialize::*;
24//! use oxicuda_launch::{LaunchParams, Dim3};
25//!
26//! let arg = SerializedArg::new(Some("n".to_string()), ArgType::U32, "1024".to_string(), 4);
27//! assert_eq!(arg.name(), Some("n"));
28//! assert_eq!(arg.value_repr(), "1024");
29//!
30//! let params = LaunchParams::new(4u32, 256u32);
31//! let formatted = format_launch_params(&params);
32//! assert!(formatted.contains("grid"));
33//! ```
34
35use std::collections::HashMap;
36use std::fmt;
37use std::time::Instant;
38
39use crate::grid::Dim3;
40use crate::kernel::KernelArgs;
41use crate::params::LaunchParams;
42
43// ---------------------------------------------------------------------------
44// ArgType
45// ---------------------------------------------------------------------------
46
/// Describes the data type of a serialized kernel argument.
///
/// Covers the common scalar types used in GPU kernels, a generic
/// pointer type, and a [`Custom`](ArgType::Custom) variant for
/// user-defined or composite types.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArgType {
    /// Unsigned 8-bit integer (`u8`).
    U8,
    /// Unsigned 16-bit integer (`u16`).
    U16,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Signed 8-bit integer (`i8`).
    I8,
    /// Signed 16-bit integer (`i16`).
    I16,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// 32-bit floating point (`f32`).
    F32,
    /// 64-bit floating point (`f64`).
    F64,
    /// A raw pointer (device or host).
    Ptr,
    /// A user-defined or composite type with a descriptive name.
    Custom(String),
}

impl ArgType {
    /// Returns the short type name used in formatted output.
    ///
    /// Scalar variants map to their Rust spelling (`"u32"`, `"f64"`, ...),
    /// [`Ptr`](ArgType::Ptr) maps to `"ptr"`, and
    /// [`Custom`](ArgType::Custom) returns its stored name verbatim.
    /// This is the same text that the `Display` impl prints.
    #[inline]
    pub fn as_str(&self) -> &str {
        match self {
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::F32 => "f32",
            Self::F64 => "f64",
            Self::Ptr => "ptr",
            Self::Custom(name) => name.as_str(),
        }
    }
}

impl fmt::Display for ArgType {
    /// Writes [`as_str`](ArgType::as_str) through [`fmt::Formatter::pad`].
    ///
    /// Width, fill, alignment and precision flags therefore behave exactly
    /// as they do for `str`. For example, `format!("{:<4}", ArgType::U8)`
    /// yields `"u8  "`, which lets callers align a type column.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
98
99// ---------------------------------------------------------------------------
100// SerializedArg
101// ---------------------------------------------------------------------------
102
103/// A serialized representation of a single kernel argument.
104///
105/// Captures the argument's optional name, data type, a human-readable
106/// string representation of its value, and its size in bytes.
107#[derive(Debug, Clone)]
108pub struct SerializedArg {
109    /// Optional human-readable name for the argument (e.g., parameter name).
110    name: Option<String>,
111    /// The data type of the argument.
112    arg_type: ArgType,
113    /// A string representation of the argument's value.
114    value_repr: String,
115    /// Size of the argument in bytes.
116    size_bytes: usize,
117}
118
119impl SerializedArg {
120    /// Creates a new `SerializedArg`.
121    #[inline]
122    pub fn new(
123        name: Option<String>,
124        arg_type: ArgType,
125        value_repr: String,
126        size_bytes: usize,
127    ) -> Self {
128        Self {
129            name,
130            arg_type,
131            value_repr,
132            size_bytes,
133        }
134    }
135
136    /// Returns the optional name of this argument.
137    #[inline]
138    pub fn name(&self) -> Option<&str> {
139        self.name.as_deref()
140    }
141
142    /// Returns the data type of this argument.
143    #[inline]
144    pub fn arg_type(&self) -> &ArgType {
145        &self.arg_type
146    }
147
148    /// Returns the string representation of the argument value.
149    #[inline]
150    pub fn value_repr(&self) -> &str {
151        &self.value_repr
152    }
153
154    /// Returns the size of this argument in bytes.
155    #[inline]
156    pub fn size_bytes(&self) -> usize {
157        self.size_bytes
158    }
159
160    /// Returns the total size of all arguments in a slice.
161    pub fn total_size(args: &[Self]) -> usize {
162        args.iter().map(|a| a.size_bytes).sum()
163    }
164}
165
166impl fmt::Display for SerializedArg {
167    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168        match &self.name {
169            Some(name) => write!(f, "{name}: {} = {}", self.arg_type, self.value_repr),
170            None => write!(f, "{}: {}", self.arg_type, self.value_repr),
171        }
172    }
173}
174
175// ---------------------------------------------------------------------------
176// LaunchLog
177// ---------------------------------------------------------------------------
178
179/// A complete record of a single kernel launch.
180///
181/// Captures the kernel name, launch configuration (grid, block, shared
182/// memory), serialized arguments, and a timestamp. Named `LaunchLog`
183/// to avoid conflicting with [`LaunchRecord`](crate::LaunchRecord) in
184/// the `graph_launch` module.
185pub struct LaunchLog {
186    /// Name of the kernel function.
187    kernel_name: String,
188    /// Grid dimensions (number of thread blocks).
189    grid: Dim3,
190    /// Block dimensions (threads per block).
191    block: Dim3,
192    /// Dynamic shared memory in bytes.
193    shared_mem: u32,
194    /// Serialized kernel arguments.
195    args: Vec<SerializedArg>,
196    /// Timestamp when this launch was recorded.
197    timestamp: Instant,
198}
199
200impl LaunchLog {
201    /// Creates a new `LaunchLog` entry.
202    ///
203    /// The timestamp is set to the current instant.
204    pub fn new(
205        kernel_name: String,
206        grid: Dim3,
207        block: Dim3,
208        shared_mem: u32,
209        args: Vec<SerializedArg>,
210    ) -> Self {
211        Self {
212            kernel_name,
213            grid,
214            block,
215            shared_mem,
216            args,
217            timestamp: Instant::now(),
218        }
219    }
220
221    /// Creates a new `LaunchLog` from a kernel name, [`LaunchParams`], and args.
222    pub fn from_params(
223        kernel_name: String,
224        params: &LaunchParams,
225        args: Vec<SerializedArg>,
226    ) -> Self {
227        Self::new(
228            kernel_name,
229            params.grid,
230            params.block,
231            params.shared_mem_bytes,
232            args,
233        )
234    }
235
236    /// Returns the kernel function name.
237    #[inline]
238    pub fn kernel_name(&self) -> &str {
239        &self.kernel_name
240    }
241
242    /// Returns the grid dimensions.
243    #[inline]
244    pub fn grid(&self) -> Dim3 {
245        self.grid
246    }
247
248    /// Returns the block dimensions.
249    #[inline]
250    pub fn block(&self) -> Dim3 {
251        self.block
252    }
253
254    /// Returns the shared memory size in bytes.
255    #[inline]
256    pub fn shared_mem(&self) -> u32 {
257        self.shared_mem
258    }
259
260    /// Returns the serialized arguments.
261    #[inline]
262    pub fn args(&self) -> &[SerializedArg] {
263        &self.args
264    }
265
266    /// Returns the timestamp when this launch was recorded.
267    #[inline]
268    pub fn timestamp(&self) -> Instant {
269        self.timestamp
270    }
271
272    /// Returns the total number of threads in this launch.
273    #[inline]
274    pub fn total_threads(&self) -> u64 {
275        self.grid.total() as u64 * self.block.total() as u64
276    }
277}
278
279impl fmt::Display for LaunchLog {
280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281        let grid_str = format!("({},{},{})", self.grid.x, self.grid.y, self.grid.z);
282        let block_str = format!("({},{},{})", self.block.x, self.block.y, self.block.z);
283        let args_str = format_args_inner(&self.args);
284        write!(
285            f,
286            "{}<<<{}, {}, {}>>>( {} )",
287            self.kernel_name, grid_str, block_str, self.shared_mem, args_str
288        )
289    }
290}
291
292impl fmt::Debug for LaunchLog {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        f.debug_struct("LaunchLog")
295            .field("kernel_name", &self.kernel_name)
296            .field("grid", &self.grid)
297            .field("block", &self.block)
298            .field("shared_mem", &self.shared_mem)
299            .field("args_count", &self.args.len())
300            .finish()
301    }
302}
303
304// ---------------------------------------------------------------------------
305// LaunchLogger
306// ---------------------------------------------------------------------------
307
308/// Collects [`LaunchLog`] entries for inspection and analysis.
309///
310/// Provides append-only storage of launch records with methods to
311/// retrieve entries, clear the log, and produce aggregate summaries.
312///
313/// # Example
314///
315/// ```rust
316/// use oxicuda_launch::arg_serialize::*;
317/// use oxicuda_launch::Dim3;
318///
319/// let mut logger = LaunchLogger::new();
320/// logger.log(LaunchLog::new("kern_a".into(), Dim3::x(4), Dim3::x(256), 0, vec![]));
321/// logger.log(LaunchLog::new("kern_a".into(), Dim3::x(8), Dim3::x(256), 0, vec![]));
322/// logger.log(LaunchLog::new("kern_b".into(), Dim3::x(1), Dim3::x(128), 0, vec![]));
323/// let summary = logger.summary();
324/// assert_eq!(summary.total_launches(), 3);
325/// ```
326#[derive(Debug)]
327pub struct LaunchLogger {
328    /// Stored launch log entries.
329    entries: Vec<LaunchLog>,
330}
331
332impl LaunchLogger {
333    /// Creates a new empty `LaunchLogger`.
334    #[inline]
335    pub fn new() -> Self {
336        Self {
337            entries: Vec::new(),
338        }
339    }
340
341    /// Appends a [`LaunchLog`] entry to the logger.
342    #[inline]
343    pub fn log(&mut self, record: LaunchLog) {
344        self.entries.push(record);
345    }
346
347    /// Returns a slice of all recorded launch log entries.
348    #[inline]
349    pub fn entries(&self) -> &[LaunchLog] {
350        &self.entries
351    }
352
353    /// Clears all recorded entries.
354    #[inline]
355    pub fn clear(&mut self) {
356        self.entries.clear();
357    }
358
359    /// Returns the number of recorded entries.
360    #[inline]
361    pub fn len(&self) -> usize {
362        self.entries.len()
363    }
364
365    /// Returns `true` if no entries have been recorded.
366    #[inline]
367    pub fn is_empty(&self) -> bool {
368        self.entries.is_empty()
369    }
370
371    /// Produces a [`LaunchSummary`] from all recorded entries.
372    ///
373    /// The summary aggregates per-kernel launch counts and provides
374    /// the total number of launches.
375    pub fn summary(&self) -> LaunchSummary {
376        let mut per_kernel: HashMap<String, KernelLaunchStats> = HashMap::new();
377        for entry in &self.entries {
378            let stats = per_kernel
379                .entry(entry.kernel_name.clone())
380                .or_insert_with(|| KernelLaunchStats {
381                    kernel_name: entry.kernel_name.clone(),
382                    launch_count: 0,
383                    total_threads: 0,
384                    total_shared_mem: 0,
385                });
386            stats.launch_count += 1;
387            stats.total_threads += entry.total_threads();
388            stats.total_shared_mem += u64::from(entry.shared_mem);
389        }
390        LaunchSummary {
391            total_launches: self.entries.len(),
392            per_kernel,
393        }
394    }
395}
396
397impl Default for LaunchLogger {
398    #[inline]
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
404// ---------------------------------------------------------------------------
405// KernelLaunchStats
406// ---------------------------------------------------------------------------
407
/// Aggregated counters for one kernel inside a [`LaunchSummary`].
#[derive(Debug, Clone)]
pub struct KernelLaunchStats {
    /// Name of the kernel these counters belong to.
    kernel_name: String,
    /// How many launches of this kernel were recorded.
    launch_count: usize,
    /// Sum of threads over every recorded launch of this kernel.
    total_threads: u64,
    /// Sum of requested dynamic shared memory (bytes) over every launch.
    total_shared_mem: u64,
}

impl KernelLaunchStats {
    /// Kernel function name.
    #[inline]
    pub fn kernel_name(&self) -> &str {
        self.kernel_name.as_str()
    }

    /// Number of recorded launches of this kernel.
    #[inline]
    pub fn launch_count(&self) -> usize {
        self.launch_count
    }

    /// Threads summed across all recorded launches.
    #[inline]
    pub fn total_threads(&self) -> u64 {
        self.total_threads
    }

    /// Dynamic shared-memory bytes summed across all recorded launches.
    #[inline]
    pub fn total_shared_mem(&self) -> u64 {
        self.total_shared_mem
    }
}

impl fmt::Display for KernelLaunchStats {
    /// One-line human-readable rendering of the counters.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            kernel_name,
            launch_count,
            total_threads,
            total_shared_mem,
        } = self;
        write!(
            f,
            "{kernel_name}: {launch_count} launches, {total_threads} total threads, \
             {total_shared_mem} bytes shared mem"
        )
    }
}
456
457// ---------------------------------------------------------------------------
458// LaunchSummary
459// ---------------------------------------------------------------------------
460
461/// Aggregate statistics over all recorded kernel launches.
462///
463/// Produced by [`LaunchLogger::summary`], this provides per-kernel
464/// launch counts and total launch counts for analysis and debugging.
465#[derive(Debug)]
466pub struct LaunchSummary {
467    /// Total number of kernel launches recorded.
468    total_launches: usize,
469    /// Per-kernel statistics, keyed by kernel function name.
470    per_kernel: HashMap<String, KernelLaunchStats>,
471}
472
473impl LaunchSummary {
474    /// Returns the total number of kernel launches across all kernels.
475    #[inline]
476    pub fn total_launches(&self) -> usize {
477        self.total_launches
478    }
479
480    /// Returns per-kernel statistics as a map keyed by kernel name.
481    #[inline]
482    pub fn per_kernel(&self) -> &HashMap<String, KernelLaunchStats> {
483        &self.per_kernel
484    }
485
486    /// Returns the number of distinct kernels that were launched.
487    #[inline]
488    pub fn unique_kernels(&self) -> usize {
489        self.per_kernel.len()
490    }
491
492    /// Returns the statistics for a specific kernel by name, if present.
493    #[inline]
494    pub fn kernel_stats(&self, name: &str) -> Option<&KernelLaunchStats> {
495        self.per_kernel.get(name)
496    }
497}
498
499impl fmt::Display for LaunchSummary {
500    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
501        writeln!(f, "LaunchSummary: {} total launches", self.total_launches)?;
502        let mut names: Vec<&String> = self.per_kernel.keys().collect();
503        names.sort();
504        for name in names {
505            if let Some(stats) = self.per_kernel.get(name) {
506                writeln!(f, "  {stats}")?;
507            }
508        }
509        Ok(())
510    }
511}
512
513// ---------------------------------------------------------------------------
514// SerializableKernelArgs trait
515// ---------------------------------------------------------------------------
516
517/// Extension trait for [`KernelArgs`] that can serialize arguments
518/// into [`SerializedArg`] form for logging and debugging.
519///
520/// # Safety
521///
522/// Implementors must uphold the same invariants as [`KernelArgs`].
523/// The serialized arguments must correspond one-to-one with the
524/// pointers returned by `as_param_ptrs`.
525pub unsafe trait SerializableKernelArgs: KernelArgs {
526    /// Serializes the kernel arguments into a vector of [`SerializedArg`].
527    fn serialize_args(&self) -> Vec<SerializedArg>;
528}
529
530// SerializableKernelArgs for () (no arguments)
531unsafe impl SerializableKernelArgs for () {
532    fn serialize_args(&self) -> Vec<SerializedArg> {
533        Vec::new()
534    }
535}
536
537// ---------------------------------------------------------------------------
538// Helper trait for individual argument serialization
539// ---------------------------------------------------------------------------
540
541/// Helper trait to serialize a single value into a [`SerializedArg`].
542///
543/// Implemented for all common scalar types used in GPU kernels.
544pub trait SerializeArg: Copy {
545    /// Returns the [`ArgType`] for this value.
546    fn arg_type() -> ArgType;
547
548    /// Returns a string representation of this value.
549    fn value_repr(&self) -> String;
550
551    /// Returns the size of this type in bytes.
552    fn size_bytes() -> usize;
553
554    /// Produces a [`SerializedArg`] with an optional name.
555    fn to_serialized(&self, name: Option<String>) -> SerializedArg {
556        SerializedArg::new(
557            name,
558            Self::arg_type(),
559            self.value_repr(),
560            Self::size_bytes(),
561        )
562    }
563}
564
565macro_rules! impl_serialize_arg_int {
566    ($ty:ty, $variant:ident) => {
567        impl SerializeArg for $ty {
568            #[inline]
569            fn arg_type() -> ArgType {
570                ArgType::$variant
571            }
572            #[inline]
573            fn value_repr(&self) -> String {
574                self.to_string()
575            }
576            #[inline]
577            fn size_bytes() -> usize {
578                std::mem::size_of::<$ty>()
579            }
580        }
581    };
582}
583
584impl_serialize_arg_int!(u8, U8);
585impl_serialize_arg_int!(u16, U16);
586impl_serialize_arg_int!(u32, U32);
587impl_serialize_arg_int!(u64, U64);
588impl_serialize_arg_int!(i8, I8);
589impl_serialize_arg_int!(i16, I16);
590impl_serialize_arg_int!(i32, I32);
591impl_serialize_arg_int!(i64, I64);
592
593impl SerializeArg for f32 {
594    #[inline]
595    fn arg_type() -> ArgType {
596        ArgType::F32
597    }
598    #[inline]
599    fn value_repr(&self) -> String {
600        if self.fract() == 0.0 && self.is_finite() {
601            format!("{self:.1}")
602        } else {
603            format!("{self}")
604        }
605    }
606    #[inline]
607    fn size_bytes() -> usize {
608        4
609    }
610}
611
612impl SerializeArg for f64 {
613    #[inline]
614    fn arg_type() -> ArgType {
615        ArgType::F64
616    }
617    #[inline]
618    fn value_repr(&self) -> String {
619        if self.fract() == 0.0 && self.is_finite() {
620            format!("{self:.1}")
621        } else {
622            format!("{self}")
623        }
624    }
625    #[inline]
626    fn size_bytes() -> usize {
627        8
628    }
629}
630
631impl SerializeArg for usize {
632    #[inline]
633    fn arg_type() -> ArgType {
634        ArgType::Ptr
635    }
636    #[inline]
637    fn value_repr(&self) -> String {
638        format!("0x{self:x}")
639    }
640    #[inline]
641    fn size_bytes() -> usize {
642        std::mem::size_of::<usize>()
643    }
644}
645
646impl SerializeArg for isize {
647    #[inline]
648    fn arg_type() -> ArgType {
649        ArgType::Ptr
650    }
651    #[inline]
652    fn value_repr(&self) -> String {
653        format!("0x{self:x}")
654    }
655    #[inline]
656    fn size_bytes() -> usize {
657        std::mem::size_of::<isize>()
658    }
659}
660
661// ---------------------------------------------------------------------------
662// Macro-generated SerializableKernelArgs for tuples
663// ---------------------------------------------------------------------------
664
665macro_rules! impl_serializable_kernel_args_tuple {
666    ($($idx:tt: $T:ident),+) => {
667        /// # Safety
668        ///
669        /// The serialized arguments correspond one-to-one with the pointers
670        /// from `as_param_ptrs`.
671        unsafe impl<$($T: Copy + SerializeArg),+> SerializableKernelArgs for ($($T,)+) {
672            fn serialize_args(&self) -> Vec<SerializedArg> {
673                vec![
674                    $(self.$idx.to_serialized(Some(format!("arg{}", $idx))),)+
675                ]
676            }
677        }
678    };
679}
680
681impl_serializable_kernel_args_tuple!(0: A);
682impl_serializable_kernel_args_tuple!(0: A, 1: B);
683impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C);
684impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D);
685impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E);
686impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
687impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
688impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H);
689impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I);
690impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J);
691impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K);
692impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L);
693
694// ---------------------------------------------------------------------------
695// Formatting helpers
696// ---------------------------------------------------------------------------
697
698/// Pretty-prints a [`LaunchParams`] configuration.
699///
700/// Produces a string like `"grid=(4,1,1) block=(256,1,1) smem=0"`.
701pub fn format_launch_params(params: &LaunchParams) -> String {
702    format!(
703        "grid=({},{},{}) block=({},{},{}) smem={}",
704        params.grid.x,
705        params.grid.y,
706        params.grid.z,
707        params.block.x,
708        params.block.y,
709        params.block.z,
710        params.shared_mem_bytes,
711    )
712}
713
714/// Pretty-prints a slice of [`SerializedArg`] values.
715///
716/// Produces a comma-separated string of argument representations.
717/// Each argument is formatted using its [`Display`](fmt::Display) impl.
718pub fn format_args(args: &[SerializedArg]) -> String {
719    format_args_inner(args)
720}
721
722/// Internal formatting helper shared by `format_args` and `LaunchLog::Display`.
723fn format_args_inner(args: &[SerializedArg]) -> String {
724    if args.is_empty() {
725        return String::new();
726    }
727    let parts: Vec<String> = args.iter().map(|a| a.to_string()).collect();
728    parts.join(", ")
729}
730
731// ---------------------------------------------------------------------------
732// Tests
733// ---------------------------------------------------------------------------
734
#[cfg(test)]
mod tests {
    use super::*;
    use crate::params::LaunchParams;

    /// Builds a launch record that carries no arguments.
    fn bare_launch(name: &str, grid: Dim3, block: Dim3, smem: u32) -> LaunchLog {
        LaunchLog::new(name.to_string(), grid, block, smem, Vec::new())
    }

    /// Shorthand constructor for a serialized argument.
    fn sarg(name: Option<&str>, ty: ArgType, repr: &str, size: usize) -> SerializedArg {
        SerializedArg::new(name.map(str::to_string), ty, repr.to_string(), size)
    }

    #[test]
    fn arg_type_display() {
        let cases = [
            (ArgType::U32, "u32"),
            (ArgType::F64, "f64"),
            (ArgType::Ptr, "ptr"),
            (ArgType::Custom("MyType".into()), "MyType"),
        ];
        for (ty, expected) in cases {
            assert_eq!(ty.to_string(), expected);
        }
    }

    #[test]
    fn arg_type_equality() {
        assert_eq!(ArgType::U32, ArgType::U32);
        assert_ne!(ArgType::U32, ArgType::U64);
        let foo = || ArgType::Custom("Foo".to_string());
        assert_eq!(foo(), foo());
    }

    #[test]
    fn serialized_arg_new_and_accessors() {
        let count = sarg(Some("count"), ArgType::U32, "42", 4);
        assert_eq!(count.name(), Some("count"));
        assert_eq!(count.arg_type(), &ArgType::U32);
        assert_eq!(count.value_repr(), "42");
        assert_eq!(count.size_bytes(), 4);
    }

    #[test]
    fn serialized_arg_no_name() {
        let anon = sarg(None, ArgType::F32, "3.14", 4);
        assert!(anon.name().is_none());
        assert_eq!(anon.to_string(), "f32: 3.14");
    }

    #[test]
    fn serialized_arg_with_name_display() {
        let x = sarg(Some("x"), ArgType::I64, "-100", 8);
        assert_eq!(x.to_string(), "x: i64 = -100");
    }

    #[test]
    fn serialized_arg_total_size() {
        let args = [
            sarg(None, ArgType::U32, "1", 4),
            sarg(None, ArgType::U64, "2", 8),
            sarg(None, ArgType::F32, "3.0", 4),
        ];
        assert_eq!(SerializedArg::total_size(&args), 16);
    }

    #[test]
    fn launch_log_creation_and_accessors() {
        let log = LaunchLog::new(
            "vector_add".into(),
            Dim3::x(4),
            Dim3::x(256),
            1024,
            vec![sarg(None, ArgType::U32, "42", 4)],
        );
        assert_eq!(log.kernel_name(), "vector_add");
        assert_eq!((log.grid(), log.block()), (Dim3::x(4), Dim3::x(256)));
        assert_eq!(log.shared_mem(), 1024);
        assert_eq!(log.args().len(), 1);
        assert_eq!(log.total_threads(), 4 * 256);
    }

    #[test]
    fn launch_log_from_params() {
        let params = LaunchParams::new(Dim3::xy(2, 2), Dim3::x(128)).with_shared_mem(512);
        let log = LaunchLog::from_params("matmul".to_string(), &params, Vec::new());
        assert_eq!(log.kernel_name(), "matmul");
        assert_eq!(log.grid(), Dim3::xy(2, 2));
        assert_eq!(log.shared_mem(), 512);
    }

    #[test]
    fn launch_log_display() {
        let args = vec![
            sarg(Some("a"), ArgType::U64, "0x1000", 8),
            sarg(Some("n"), ArgType::U32, "1024", 4),
        ];
        let rendered =
            LaunchLog::new("my_kernel".into(), Dim3::x(4), Dim3::x(256), 0, args).to_string();
        for needle in [
            "my_kernel<<<",
            "(4,1,1)",
            "(256,1,1)",
            "a: u64 = 0x1000",
            "n: u32 = 1024",
        ] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered:?}");
        }
    }

    #[test]
    fn launch_log_debug() {
        let dbg = format!("{:?}", bare_launch("kern", Dim3::x(1), Dim3::x(1), 0));
        assert!(dbg.contains("LaunchLog"));
        assert!(dbg.contains("kern"));
    }

    #[test]
    fn launch_logger_basic_workflow() {
        let mut logger = LaunchLogger::new();
        assert!(logger.is_empty());
        assert_eq!(logger.len(), 0);

        logger.log(bare_launch("kern_a", Dim3::x(4), Dim3::x(256), 0));
        logger.log(bare_launch("kern_b", Dim3::x(8), Dim3::x(128), 512));
        assert_eq!(logger.len(), 2);
        assert!(!logger.is_empty());
        let names: Vec<&str> = logger.entries().iter().map(LaunchLog::kernel_name).collect();
        assert_eq!(names, ["kern_a", "kern_b"]);

        logger.clear();
        assert!(logger.is_empty());
    }

    #[test]
    fn launch_logger_default() {
        assert!(LaunchLogger::default().is_empty());
    }

    #[test]
    fn launch_summary_aggregation() {
        let mut logger = LaunchLogger::new();
        logger.log(bare_launch("kern_a", Dim3::x(4), Dim3::x(256), 0));
        logger.log(bare_launch("kern_a", Dim3::x(8), Dim3::x(256), 1024));
        logger.log(bare_launch("kern_b", Dim3::x(1), Dim3::x(128), 0));

        let summary = logger.summary();
        assert_eq!(summary.total_launches(), 3);
        assert_eq!(summary.unique_kernels(), 2);

        let a_stats = summary
            .kernel_stats("kern_a")
            .expect("kern_a stats should exist in test");
        assert_eq!(a_stats.launch_count(), 2);
        assert_eq!(a_stats.total_threads(), 4 * 256 + 8 * 256);
        assert_eq!(a_stats.total_shared_mem(), 1024);

        let b_stats = summary
            .kernel_stats("kern_b")
            .expect("kern_b stats should exist in test");
        assert_eq!(b_stats.launch_count(), 1);
    }

    #[test]
    fn launch_summary_display() {
        let mut logger = LaunchLogger::new();
        logger.log(bare_launch("kern", Dim3::x(1), Dim3::x(1), 0));
        let text = logger.summary().to_string();
        for needle in ["LaunchSummary", "1 total launches", "kern"] {
            assert!(text.contains(needle), "missing {needle:?} in {text:?}");
        }
    }

    #[test]
    fn serialize_arg_trait_scalars() {
        let n = 42u32.to_serialized(Some("n".into()));
        assert_eq!(n.arg_type(), &ArgType::U32);
        assert_eq!(n.value_repr(), "42");
        assert_eq!(n.size_bytes(), 4);

        let wide = 3.15f64.to_serialized(None);
        assert_eq!(wide.arg_type(), &ArgType::F64);
        assert_eq!(wide.value_repr(), "3.15");
        assert_eq!(wide.size_bytes(), 8);

        assert_eq!(1.0f32.to_serialized(None).value_repr(), "1.0");
    }

    #[test]
    fn serializable_kernel_args_unit() {
        assert!(().serialize_args().is_empty());
    }

    #[test]
    fn serializable_kernel_args_tuple() {
        let serialized = (42u32, 3.15f64).serialize_args();
        assert_eq!(serialized.len(), 2);
        let expected = [
            ("arg0", ArgType::U32, "42"),
            ("arg1", ArgType::F64, "3.15"),
        ];
        for (arg, (name, ty, repr)) in serialized.iter().zip(expected) {
            assert_eq!(arg.name(), Some(name));
            assert_eq!(arg.arg_type(), &ty);
            assert_eq!(arg.value_repr(), repr);
        }
    }

    #[test]
    fn format_launch_params_output() {
        let params = LaunchParams::new(Dim3::xy(4, 2), Dim3::x(256)).with_shared_mem(4096);
        let s = format_launch_params(&params);
        for needle in ["grid=(4,2,1)", "block=(256,1,1)", "smem=4096"] {
            assert!(s.contains(needle), "missing {needle:?} in {s:?}");
        }
    }

    #[test]
    fn format_args_output() {
        let s = format_args(&[
            sarg(Some("a"), ArgType::U64, "0x1000", 8),
            sarg(Some("n"), ArgType::U32, "1024", 4),
        ]);
        assert!(s.contains("a: u64 = 0x1000"));
        assert!(s.contains("n: u32 = 1024"));
    }

    #[test]
    fn format_args_empty() {
        assert!(format_args(&[]).is_empty());
    }

    #[test]
    fn kernel_launch_stats_display() {
        let s = KernelLaunchStats {
            kernel_name: "matmul".into(),
            launch_count: 5,
            total_threads: 1_000_000,
            total_shared_mem: 4096,
        }
        .to_string();
        for needle in ["matmul", "5 launches", "1000000 total threads"] {
            assert!(s.contains(needle), "missing {needle:?} in {s:?}");
        }
    }
}