Skip to main content

oxicuda_driver/
link.rs

1//! Link-time optimisation for JIT-linking multiple PTX modules.
2//!
3//! This module wraps the CUDA linker API (`cuLinkCreate`, `cuLinkAddData`,
4//! `cuLinkAddFile`, `cuLinkComplete`, `cuLinkDestroy`) for combining
5//! multiple PTX, cubin, or fatbin inputs into a single linked binary.
6//!
7//! # Platform behaviour
8//!
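//! On macOS (where NVIDIA dropped CUDA support), all linker operations use
//! a synthetic in-memory implementation.  PTX and binary inputs are
//! accumulated and concatenated behind a magic header into a synthetic
//! cubin blob so that the full API surface can be exercised in tests
//! without a GPU.
//!
//! On other platforms the driver calls are not wired up yet (see the
//! `TODO`s in the `gpu_link_*` helpers), so completing a link currently
//! yields an empty cubin.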
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! # use oxicuda_driver::link::{Linker, LinkerOptions};
18//! # fn main() -> Result<(), oxicuda_driver::error::CudaError> {
19//! let opts = LinkerOptions::default();
20//! let mut linker = Linker::new(opts)?;
21//!
22//! linker.add_ptx(r#"
23//!     .version 7.0
24//!     .target sm_70
25//!     .address_size 64
26//!     .visible .entry kernel_a() { ret; }
27//! "#, "module_a.ptx")?;
28//!
29//! linker.add_ptx(r#"
30//!     .version 7.0
31//!     .target sm_70
32//!     .address_size 64
33//!     .visible .entry kernel_b() { ret; }
34//! "#, "module_b.ptx")?;
35//!
36//! let linked = linker.complete()?;
37//! println!("cubin size: {} bytes", linked.cubin_size());
38//! # Ok(())
39//! # }
40//! ```
41
42use std::ffi::{CString, c_void};
43
44use crate::error::{CudaError, CudaResult};
45#[cfg(any(not(target_os = "macos"), test))]
46use crate::ffi::CUjit_option;
47use crate::ffi::CUjitInputType;
48
49// ---------------------------------------------------------------------------
50// OptimizationLevel
51// ---------------------------------------------------------------------------
52
/// Optimisation effort requested from the JIT linker.
///
/// Each variant's discriminant is the raw value passed for
/// `CU_JIT_OPTIMIZATION_LEVEL` (0 through 4).  Higher levels produce
/// faster GPU code at the cost of longer link times.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum OptimizationLevel {
    /// Level 0: no optimisation.
    O0 = 0,
    /// Level 1: minimal optimisation.
    O1 = 1,
    /// Level 2: moderate optimisation.
    O2 = 2,
    /// Level 3: high optimisation.
    O3 = 3,
    /// Level 4: maximum optimisation.  This is the `Default` variant.
    #[default]
    O4 = 4,
}

impl OptimizationLevel {
    /// Returns the numeric level (`0..=4`) to hand to the CUDA JIT option.
    #[inline]
    pub fn as_u32(self) -> u32 {
        // Spelled out per variant; each arm equals the variant's discriminant.
        match self {
            Self::O0 => 0,
            Self::O1 => 1,
            Self::O2 => 2,
            Self::O3 => 3,
            Self::O4 => 4,
        }
    }
}
79
80// ---------------------------------------------------------------------------
81// FallbackStrategy
82// ---------------------------------------------------------------------------
83
/// What the JIT should do when no exactly matching binary exists for the
/// target GPU.
///
/// The discriminants are the raw `CU_JIT_FALLBACK_STRATEGY` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum FallbackStrategy {
    /// Compile from PTX when no binary is available.  This is the
    /// `Default` variant.
    #[default]
    PreferPtx = 0,
    /// Use a compatible binary rather than recompiling from PTX.
    PreferBinary = 1,
}

impl FallbackStrategy {
    /// Returns the numeric value to hand to the CUDA JIT option.
    #[inline]
    pub fn as_u32(self) -> u32 {
        // Each arm equals the variant's declared discriminant.
        match self {
            Self::PreferPtx => 0,
            Self::PreferBinary => 1,
        }
    }
}
103
104// ---------------------------------------------------------------------------
105// LinkInputType
106// ---------------------------------------------------------------------------
107
108/// The type of input data being added to the linker.
109///
110/// Each variant corresponds to a `CUjitInputType` constant.
111#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
112pub enum LinkInputType {
113    /// PTX source code.
114    Ptx,
115    /// Compiled device code (cubin).
116    Cubin,
117    /// Fat binary bundle.
118    Fatbin,
119    /// Relocatable device object.
120    Object,
121    /// Device code library.
122    Library,
123}
124
125impl LinkInputType {
126    /// Convert to the raw FFI enum value.
127    #[inline]
128    pub fn to_raw(self) -> CUjitInputType {
129        match self {
130            Self::Ptx => CUjitInputType::Ptx,
131            Self::Cubin => CUjitInputType::Cubin,
132            Self::Fatbin => CUjitInputType::Fatbin,
133            Self::Object => CUjitInputType::Object,
134            Self::Library => CUjitInputType::Library,
135        }
136    }
137}
138
139// ---------------------------------------------------------------------------
140// LinkerOptions
141// ---------------------------------------------------------------------------
142
143/// Options controlling the JIT linker's behaviour.
144///
145/// These are translated to `CUjit_option` key/value pairs when calling
146/// `cuLinkCreate`.
147#[derive(Debug, Clone)]
148pub struct LinkerOptions {
149    /// Maximum number of registers per thread (`None` = driver default).
150    ///
151    /// Limiting registers increases occupancy but may cause spilling.
152    pub max_registers: Option<u32>,
153
154    /// Optimisation level for the linker (default: [`OptimizationLevel::O4`]).
155    pub optimization_level: OptimizationLevel,
156
157    /// Target compute capability as a bare number (e.g. 70 for sm_70).
158    /// `None` means the linker derives the target from the current context.
159    pub target_sm: Option<u32>,
160
161    /// Whether to generate debug information in the linked binary.
162    pub generate_debug_info: bool,
163
164    /// Whether to generate line-number information.
165    pub generate_line_info: bool,
166
167    /// Whether to request verbose log output from the linker.
168    pub log_verbose: bool,
169
170    /// Fallback strategy when an exact binary match is unavailable.
171    pub fallback_strategy: FallbackStrategy,
172}
173
174impl Default for LinkerOptions {
175    fn default() -> Self {
176        Self {
177            max_registers: None,
178            optimization_level: OptimizationLevel::O4,
179            target_sm: None,
180            generate_debug_info: false,
181            generate_line_info: false,
182            log_verbose: false,
183            fallback_strategy: FallbackStrategy::PreferPtx,
184        }
185    }
186}
187
188/// Size of the JIT log buffers in bytes.
189#[cfg(any(not(target_os = "macos"), test))]
190const LINK_LOG_BUFFER_SIZE: usize = 8192;
191
192impl LinkerOptions {
193    /// Build parallel option-key and option-value arrays for `cuLinkCreate`.
194    ///
195    /// Returns `(keys, values, info_buf, error_buf)`.  The caller must
196    /// keep `info_buf` and `error_buf` alive until after the CUDA call
197    /// completes, because the pointers stored in `values` reference them.
198    #[cfg(any(not(target_os = "macos"), test))]
199    fn build_jit_options(&self) -> (Vec<CUjit_option>, Vec<*mut c_void>, Vec<u8>, Vec<u8>) {
200        let mut keys: Vec<CUjit_option> = Vec::with_capacity(12);
201        let mut vals: Vec<*mut c_void> = Vec::with_capacity(12);
202
203        let mut info_buf: Vec<u8> = vec![0u8; LINK_LOG_BUFFER_SIZE];
204        let mut error_buf: Vec<u8> = vec![0u8; LINK_LOG_BUFFER_SIZE];
205
206        // Info log buffer.
207        keys.push(CUjit_option::InfoLogBuffer);
208        vals.push(info_buf.as_mut_ptr().cast::<c_void>());
209
210        keys.push(CUjit_option::InfoLogBufferSizeBytes);
211        vals.push(LINK_LOG_BUFFER_SIZE as *mut c_void);
212
213        // Error log buffer.
214        keys.push(CUjit_option::ErrorLogBuffer);
215        vals.push(error_buf.as_mut_ptr().cast::<c_void>());
216
217        keys.push(CUjit_option::ErrorLogBufferSizeBytes);
218        vals.push(LINK_LOG_BUFFER_SIZE as *mut c_void);
219
220        // Optimisation level.
221        keys.push(CUjit_option::OptimizationLevel);
222        vals.push(self.optimization_level.as_u32() as *mut c_void);
223
224        // Max registers.
225        if let Some(max_regs) = self.max_registers {
226            keys.push(CUjit_option::MaxRegisters);
227            vals.push(max_regs as *mut c_void);
228        }
229
230        // Target SM.
231        if let Some(sm) = self.target_sm {
232            keys.push(CUjit_option::Target);
233            vals.push(sm as *mut c_void);
234        } else {
235            keys.push(CUjit_option::TargetFromCuContext);
236            vals.push(core::ptr::without_provenance_mut::<c_void>(1));
237        }
238
239        // Debug info.
240        if self.generate_debug_info {
241            keys.push(CUjit_option::GenerateDebugInfo);
242            vals.push(core::ptr::without_provenance_mut::<c_void>(1));
243        }
244
245        // Line info.
246        if self.generate_line_info {
247            keys.push(CUjit_option::GenerateLineInfo);
248            vals.push(core::ptr::without_provenance_mut::<c_void>(1));
249        }
250
251        // Verbose log.
252        if self.log_verbose {
253            keys.push(CUjit_option::LogVerbose);
254            vals.push(core::ptr::without_provenance_mut::<c_void>(1));
255        }
256
257        // Fallback strategy.
258        keys.push(CUjit_option::FallbackStrategy);
259        vals.push(self.fallback_strategy.as_u32() as *mut c_void);
260
261        (keys, vals, info_buf, error_buf)
262    }
263}
264
265// ---------------------------------------------------------------------------
266// LinkedModule
267// ---------------------------------------------------------------------------
268
/// Result of a successful link.
///
/// Owns the linked cubin bytes together with the info and error logs
/// recorded for the link.
#[derive(Debug, Clone)]
pub struct LinkedModule {
    /// The compiled cubin binary data.
    cubin_data: Vec<u8>,
    /// Informational messages from the linker.
    info_log: String,
    /// Error/warning messages from the linker.
    error_log: String,
}

impl LinkedModule {
    /// Borrows the cubin bytes.
    #[inline]
    pub fn cubin(&self) -> &[u8] {
        self.cubin_data.as_slice()
    }

    /// Length of the cubin in bytes.
    #[inline]
    pub fn cubin_size(&self) -> usize {
        self.cubin().len()
    }

    /// Borrows the informational log.
    #[inline]
    pub fn info_log(&self) -> &str {
        self.info_log.as_str()
    }

    /// Borrows the error log.
    #[inline]
    pub fn error_log(&self) -> &str {
        self.error_log.as_str()
    }

    /// Consumes the module and hands back the cubin bytes without copying;
    /// both logs are discarded.
    #[inline]
    pub fn into_cubin(self) -> Vec<u8> {
        let Self { cubin_data, .. } = self;
        cubin_data
    }
}
314
315// ---------------------------------------------------------------------------
316// Linker
317// ---------------------------------------------------------------------------
318
/// Owning wrapper around a CUDA link state handle (`CUlinkState`).
///
/// Inputs are registered through the `add_*` methods and turned into a
/// single [`LinkedModule`] by [`complete`], which consumes the linker.
///
/// On macOS there is no driver handle: `state` stays null and the inputs
/// are kept in memory so a synthetic cubin can be produced on completion.
///
/// # Drop behaviour
///
/// Dropping the linker passes `state` to `platform_destroy`, which ignores
/// the handle on macOS and forwards non-null handles to `gpu_link_destroy`
/// on other platforms.  [`complete`] returns an owned [`LinkedModule`], so
/// the result does not borrow from the linker being dropped.
///
/// NOTE(review): `gpu_link_destroy` is still a stub and does not call
/// `cuLinkDestroy` yet (see its `TODO`).
///
/// [`complete`]: Linker::complete
pub struct Linker {
    /// Raw `CUlinkState` handle (null on macOS / synthetic mode).
    state: *mut c_void,
    /// Linker configuration, as passed to [`Linker::new`].
    options: LinkerOptions,
    /// Number of inputs added so far.
    input_count: usize,
    /// Names of inputs added so far, in insertion order (for diagnostics).
    input_names: Vec<String>,

    // -- macOS synthetic state ------------------------------------------------
    /// Accumulated PTX sources (this field only exists on macOS).
    #[cfg(target_os = "macos")]
    ptx_sources: Vec<String>,
    /// Accumulated binary inputs: cubin/fatbin/object/library (this field
    /// only exists on macOS).
    #[cfg(target_os = "macos")]
    binary_sources: Vec<Vec<u8>>,
}

// SAFETY: The raw `CUlinkState` pointer is only accessed through driver
// API calls which are thread-safe when used with proper synchronisation.
// NOTE(review): only `Send` is asserted here; the raw-pointer field keeps
// the type `!Sync`, so shared access from several threads is not offered.
unsafe impl Send for Linker {}
356
357impl Linker {
358    /// Creates a new linker with the given options.
359    ///
360    /// On platforms with a real CUDA driver, this calls `cuLinkCreate`.
361    /// On macOS, a synthetic linker is created for testing purposes.
362    ///
363    /// # Errors
364    ///
365    /// Returns a [`CudaError`] if `cuLinkCreate` fails (e.g. no active
366    /// CUDA context).
367    pub fn new(options: LinkerOptions) -> CudaResult<Self> {
368        let state = Self::platform_create(&options)?;
369
370        Ok(Self {
371            state,
372            options,
373            input_count: 0,
374            input_names: Vec::new(),
375            #[cfg(target_os = "macos")]
376            ptx_sources: Vec::new(),
377            #[cfg(target_os = "macos")]
378            binary_sources: Vec::new(),
379        })
380    }
381
382    /// Adds PTX source code to the linker.
383    ///
384    /// The PTX is compiled and linked when [`complete`](Self::complete) is
385    /// called.
386    ///
387    /// # Arguments
388    ///
389    /// * `ptx` — PTX source code (must not contain interior null bytes).
390    /// * `name` — A descriptive name for this input (used in error messages).
391    ///
392    /// # Errors
393    ///
394    /// * [`CudaError::InvalidValue`] if `ptx` contains interior null bytes.
395    /// * Other [`CudaError`] variants if `cuLinkAddData` fails.
396    pub fn add_ptx(&mut self, ptx: &str, name: &str) -> CudaResult<()> {
397        let c_ptx = CString::new(ptx).map_err(|_| CudaError::InvalidValue)?;
398        let c_name = CString::new(name).map_err(|_| CudaError::InvalidValue)?;
399        let bytes = c_ptx.as_bytes_with_nul();
400
401        self.platform_add_data(
402            CUjitInputType::Ptx,
403            bytes.as_ptr().cast::<c_void>(),
404            bytes.len(),
405            c_name.as_ptr(),
406        )?;
407
408        #[cfg(target_os = "macos")]
409        {
410            self.ptx_sources.push(ptx.to_string());
411        }
412
413        self.input_count += 1;
414        self.input_names.push(name.to_string());
415        Ok(())
416    }
417
418    /// Adds compiled cubin data to the linker.
419    ///
420    /// # Arguments
421    ///
422    /// * `data` — Raw cubin binary data.
423    /// * `name` — A descriptive name for this input.
424    ///
425    /// # Errors
426    ///
427    /// * [`CudaError::InvalidValue`] if `name` contains interior null bytes
428    ///   or `data` is empty.
429    /// * Other [`CudaError`] variants if `cuLinkAddData` fails.
430    pub fn add_cubin(&mut self, data: &[u8], name: &str) -> CudaResult<()> {
431        if data.is_empty() {
432            return Err(CudaError::InvalidValue);
433        }
434        let c_name = CString::new(name).map_err(|_| CudaError::InvalidValue)?;
435
436        self.platform_add_data(
437            CUjitInputType::Cubin,
438            data.as_ptr().cast::<c_void>(),
439            data.len(),
440            c_name.as_ptr(),
441        )?;
442
443        #[cfg(target_os = "macos")]
444        {
445            self.binary_sources.push(data.to_vec());
446        }
447
448        self.input_count += 1;
449        self.input_names.push(name.to_string());
450        Ok(())
451    }
452
453    /// Adds a fat binary to the linker.
454    ///
455    /// # Arguments
456    ///
457    /// * `data` — Raw fatbin binary data.
458    /// * `name` — A descriptive name for this input.
459    ///
460    /// # Errors
461    ///
462    /// * [`CudaError::InvalidValue`] if `name` contains interior null bytes
463    ///   or `data` is empty.
464    /// * Other [`CudaError`] variants if `cuLinkAddData` fails.
465    pub fn add_fatbin(&mut self, data: &[u8], name: &str) -> CudaResult<()> {
466        if data.is_empty() {
467            return Err(CudaError::InvalidValue);
468        }
469        let c_name = CString::new(name).map_err(|_| CudaError::InvalidValue)?;
470
471        self.platform_add_data(
472            CUjitInputType::Fatbin,
473            data.as_ptr().cast::<c_void>(),
474            data.len(),
475            c_name.as_ptr(),
476        )?;
477
478        #[cfg(target_os = "macos")]
479        {
480            self.binary_sources.push(data.to_vec());
481        }
482
483        self.input_count += 1;
484        self.input_names.push(name.to_string());
485        Ok(())
486    }
487
488    /// Adds a relocatable device object to the linker.
489    ///
490    /// # Arguments
491    ///
492    /// * `data` — Raw object binary data.
493    /// * `name` — A descriptive name for this input.
494    ///
495    /// # Errors
496    ///
497    /// * [`CudaError::InvalidValue`] if `name` contains interior null bytes
498    ///   or `data` is empty.
499    pub fn add_object(&mut self, data: &[u8], name: &str) -> CudaResult<()> {
500        if data.is_empty() {
501            return Err(CudaError::InvalidValue);
502        }
503        let c_name = CString::new(name).map_err(|_| CudaError::InvalidValue)?;
504
505        self.platform_add_data(
506            CUjitInputType::Object,
507            data.as_ptr().cast::<c_void>(),
508            data.len(),
509            c_name.as_ptr(),
510        )?;
511
512        #[cfg(target_os = "macos")]
513        {
514            self.binary_sources.push(data.to_vec());
515        }
516
517        self.input_count += 1;
518        self.input_names.push(name.to_string());
519        Ok(())
520    }
521
522    /// Adds a device code library to the linker.
523    ///
524    /// # Arguments
525    ///
526    /// * `data` — Raw library binary data.
527    /// * `name` — A descriptive name for this input.
528    ///
529    /// # Errors
530    ///
531    /// * [`CudaError::InvalidValue`] if `name` contains interior null bytes
532    ///   or `data` is empty.
533    pub fn add_library(&mut self, data: &[u8], name: &str) -> CudaResult<()> {
534        if data.is_empty() {
535            return Err(CudaError::InvalidValue);
536        }
537        let c_name = CString::new(name).map_err(|_| CudaError::InvalidValue)?;
538
539        self.platform_add_data(
540            CUjitInputType::Library,
541            data.as_ptr().cast::<c_void>(),
542            data.len(),
543            c_name.as_ptr(),
544        )?;
545
546        #[cfg(target_os = "macos")]
547        {
548            self.binary_sources.push(data.to_vec());
549        }
550
551        self.input_count += 1;
552        self.input_names.push(name.to_string());
553        Ok(())
554    }
555
556    /// Returns the number of inputs added to the linker.
557    #[inline]
558    pub fn input_count(&self) -> usize {
559        self.input_count
560    }
561
562    /// Returns the names of all inputs added so far.
563    #[inline]
564    pub fn input_names(&self) -> &[String] {
565        &self.input_names
566    }
567
568    /// Returns a reference to the linker options.
569    #[inline]
570    pub fn options(&self) -> &LinkerOptions {
571        &self.options
572    }
573
574    /// Completes the link, producing a [`LinkedModule`].
575    ///
576    /// This consumes the linker.  The resulting cubin data is copied into
577    /// the `LinkedModule` before the underlying `CUlinkState` is destroyed
578    /// (by `Drop`).
579    ///
580    /// # Errors
581    ///
582    /// * [`CudaError::InvalidValue`] if no inputs have been added.
583    /// * Other [`CudaError`] variants if `cuLinkComplete` fails.
584    pub fn complete(self) -> CudaResult<LinkedModule> {
585        if self.input_count == 0 {
586            return Err(CudaError::InvalidValue);
587        }
588        self.platform_complete()
589    }
590
591    // -----------------------------------------------------------------------
592    // Platform-specific helpers
593    // -----------------------------------------------------------------------
594
595    /// Create the link state.  On macOS, returns a null pointer (synthetic).
596    fn platform_create(options: &LinkerOptions) -> CudaResult<*mut c_void> {
597        #[cfg(target_os = "macos")]
598        {
599            let _ = options;
600            Ok(std::ptr::null_mut())
601        }
602
603        #[cfg(not(target_os = "macos"))]
604        {
605            Self::gpu_link_create(options)
606        }
607    }
608
609    /// Add data to the link state.
610    fn platform_add_data(
611        &self,
612        input_type: CUjitInputType,
613        data: *const c_void,
614        size: usize,
615        name: *const std::ffi::c_char,
616    ) -> CudaResult<()> {
617        #[cfg(target_os = "macos")]
618        {
619            let _ = (input_type, data, size, name);
620            Ok(())
621        }
622
623        #[cfg(not(target_os = "macos"))]
624        {
625            Self::gpu_link_add_data(self.state, input_type, data, size, name)
626        }
627    }
628
629    /// Complete the link and produce a `LinkedModule`.
630    fn platform_complete(self) -> CudaResult<LinkedModule> {
631        #[cfg(target_os = "macos")]
632        {
633            self.synthetic_complete()
634        }
635
636        #[cfg(not(target_os = "macos"))]
637        {
638            Self::gpu_link_complete(self.state)
639        }
640    }
641
642    /// Destroy the link state.
643    fn platform_destroy(state: *mut c_void) {
644        #[cfg(target_os = "macos")]
645        {
646            let _ = state;
647        }
648
649        #[cfg(not(target_os = "macos"))]
650        {
651            if !state.is_null() {
652                Self::gpu_link_destroy(state);
653            }
654        }
655    }
656
657    // -----------------------------------------------------------------------
658    // macOS synthetic implementation
659    // -----------------------------------------------------------------------
660
661    /// Produce a synthetic `LinkedModule` by concatenating all PTX and
662    /// binary inputs.
663    #[cfg(target_os = "macos")]
664    fn synthetic_complete(&self) -> CudaResult<LinkedModule> {
665        let mut cubin = Vec::new();
666
667        // Magic header to identify synthetic cubin.
668        cubin.extend_from_slice(b"OXICUDA_SYNTHETIC_CUBIN\0");
669
670        // Append all PTX sources.
671        for ptx in &self.ptx_sources {
672            cubin.extend_from_slice(ptx.as_bytes());
673            cubin.push(0); // null separator
674        }
675
676        // Append all binary sources.
677        for bin in &self.binary_sources {
678            cubin.extend_from_slice(bin);
679        }
680
681        let info_msg = format!(
682            "Synthetic link complete: {} input(s), {} bytes",
683            self.input_count,
684            cubin.len()
685        );
686
687        Ok(LinkedModule {
688            cubin_data: cubin,
689            info_log: info_msg,
690            error_log: String::new(),
691        })
692    }
693
694    // -----------------------------------------------------------------------
695    // GPU-only stubs (compiled out on macOS)
696    // -----------------------------------------------------------------------
697
698    /// Create link state via `cuLinkCreate`.
699    #[cfg(not(target_os = "macos"))]
700    fn gpu_link_create(options: &LinkerOptions) -> CudaResult<*mut c_void> {
701        let api = crate::loader::try_driver()?;
702        let (mut keys, mut vals, _info_buf, _error_buf) = options.build_jit_options();
703        let num_options = keys.len() as u32;
704
705        let mut state: *mut c_void = std::ptr::null_mut();
706
707        // cuLinkCreate(numOptions, options*, optionValues*, stateOut*)
708        // We load this symbol dynamically — it's part of the module management
709        // group in the CUDA driver.
710        //
711        // For now, use the module-load-data-ex path as a stub.
712        // A full implementation would load cuLinkCreate from the driver.
713        let _ = (api, num_options, &mut keys, &mut vals, &mut state);
714
715        // TODO: Wire up cuLinkCreate when adding link function pointers
716        // to DriverApi.  For now, return the state pointer (which may be
717        // null if the stub is not fully wired).
718        Ok(state)
719    }
720
721    /// Add data via `cuLinkAddData`.
722    #[cfg(not(target_os = "macos"))]
723    fn gpu_link_add_data(
724        state: *mut c_void,
725        input_type: CUjitInputType,
726        data: *const c_void,
727        size: usize,
728        name: *const std::ffi::c_char,
729    ) -> CudaResult<()> {
730        let _ = (state, input_type, data, size, name);
731        // TODO: Wire up cuLinkAddData when adding link function pointers
732        // to DriverApi.
733        Ok(())
734    }
735
736    /// Complete the link via `cuLinkComplete`.
737    #[cfg(not(target_os = "macos"))]
738    fn gpu_link_complete(state: *mut c_void) -> CudaResult<LinkedModule> {
739        let _ = state;
740        // TODO: Wire up cuLinkComplete.
741        Ok(LinkedModule {
742            cubin_data: Vec::new(),
743            info_log: String::new(),
744            error_log: String::new(),
745        })
746    }
747
748    /// Destroy the link state via `cuLinkDestroy`.
749    #[cfg(not(target_os = "macos"))]
750    fn gpu_link_destroy(state: *mut c_void) {
751        if let Ok(api) = crate::loader::try_driver() {
752            // cuLinkDestroy is part of the linker API.
753            // TODO: Wire up when adding to DriverApi.
754            let _ = api;
755            let _ = state;
756        }
757    }
758}
759
760impl Drop for Linker {
761    fn drop(&mut self) {
762        Self::platform_destroy(self.state);
763    }
764}
765
766impl std::fmt::Debug for Linker {
767    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
768        f.debug_struct("Linker")
769            .field("state", &format_args!("{:p}", self.state))
770            .field("input_count", &self.input_count)
771            .field("input_names", &self.input_names)
772            .field("options", &self.options)
773            .finish()
774    }
775}
776
777// ---------------------------------------------------------------------------
778// Convenience helpers
779// ---------------------------------------------------------------------------
780
/// Converts a C-style log buffer into a Rust [`String`].
///
/// Only the bytes before the first NUL are used (the whole buffer when it
/// contains none).  They are decoded as UTF-8 with invalid sequences
/// replaced by U+FFFD, and leading and trailing whitespace is trimmed.
#[allow(dead_code)]
fn buf_to_string(buf: &[u8]) -> String {
    // `split` always yields at least one (possibly empty) piece.
    let text_bytes = buf.split(|&byte| byte == 0).next().unwrap_or(buf);
    let decoded = String::from_utf8_lossy(text_bytes);
    decoded.trim().to_owned()
}
788
789// =========================================================================
790// Tests
791// =========================================================================
792
793#[cfg(test)]
794mod tests {
795    use super::*;
796
    /// PTX fixture declaring a single empty entry point, `kernel_a`.
    /// Only compiled on macOS, where the synthetic-linker tests run.
    #[cfg(target_os = "macos")]
    const SAMPLE_PTX_A: &str = r#"
        .version 7.0
        .target sm_70
        .address_size 64
        .visible .entry kernel_a() { ret; }
    "#;

    /// PTX fixture declaring a single empty entry point, `kernel_b`.
    /// Only compiled on macOS, where the synthetic-linker tests run.
    #[cfg(target_os = "macos")]
    const SAMPLE_PTX_B: &str = r#"
        .version 7.0
        .target sm_70
        .address_size 64
        .visible .entry kernel_b() { ret; }
    "#;
812
813    // -- OptimizationLevel tests --
814
815    #[test]
816    fn optimization_level_values() {
817        assert_eq!(OptimizationLevel::O0.as_u32(), 0);
818        assert_eq!(OptimizationLevel::O1.as_u32(), 1);
819        assert_eq!(OptimizationLevel::O2.as_u32(), 2);
820        assert_eq!(OptimizationLevel::O3.as_u32(), 3);
821        assert_eq!(OptimizationLevel::O4.as_u32(), 4);
822    }
823
824    #[test]
825    fn optimization_level_default() {
826        let level = OptimizationLevel::default();
827        assert_eq!(level, OptimizationLevel::O4);
828    }
829
830    // -- FallbackStrategy tests --
831
832    #[test]
833    fn fallback_strategy_values() {
834        assert_eq!(FallbackStrategy::PreferPtx.as_u32(), 0);
835        assert_eq!(FallbackStrategy::PreferBinary.as_u32(), 1);
836    }
837
838    #[test]
839    fn fallback_strategy_default() {
840        let strategy = FallbackStrategy::default();
841        assert_eq!(strategy, FallbackStrategy::PreferPtx);
842    }
843
844    // -- LinkInputType tests --
845
846    #[test]
847    fn link_input_type_to_raw() {
848        assert_eq!(LinkInputType::Ptx.to_raw(), CUjitInputType::Ptx);
849        assert_eq!(LinkInputType::Cubin.to_raw(), CUjitInputType::Cubin);
850        assert_eq!(LinkInputType::Fatbin.to_raw(), CUjitInputType::Fatbin);
851        assert_eq!(LinkInputType::Object.to_raw(), CUjitInputType::Object);
852        assert_eq!(LinkInputType::Library.to_raw(), CUjitInputType::Library);
853    }
854
855    // -- LinkerOptions tests --
856
857    #[test]
858    fn linker_options_default() {
859        let opts = LinkerOptions::default();
860        assert!(opts.max_registers.is_none());
861        assert_eq!(opts.optimization_level, OptimizationLevel::O4);
862        assert!(opts.target_sm.is_none());
863        assert!(!opts.generate_debug_info);
864        assert!(!opts.generate_line_info);
865        assert!(!opts.log_verbose);
866        assert_eq!(opts.fallback_strategy, FallbackStrategy::PreferPtx);
867    }
868
869    #[test]
870    fn linker_options_custom() {
871        let opts = LinkerOptions {
872            max_registers: Some(32),
873            optimization_level: OptimizationLevel::O2,
874            target_sm: Some(75),
875            generate_debug_info: true,
876            generate_line_info: true,
877            log_verbose: true,
878            fallback_strategy: FallbackStrategy::PreferBinary,
879        };
880        assert_eq!(opts.max_registers, Some(32));
881        assert_eq!(opts.optimization_level, OptimizationLevel::O2);
882        assert_eq!(opts.target_sm, Some(75));
883        assert!(opts.generate_debug_info);
884        assert!(opts.generate_line_info);
885        assert!(opts.log_verbose);
886        assert_eq!(opts.fallback_strategy, FallbackStrategy::PreferBinary);
887    }
888
889    #[test]
890    fn linker_options_build_jit_options_minimal() {
891        let opts = LinkerOptions::default();
892        let (keys, vals, _info_buf, _error_buf) = opts.build_jit_options();
893
894        // Minimum options: info log (2), error log (2), opt level (1),
895        // target from context (1), fallback (1) = 7
896        assert_eq!(keys.len(), vals.len());
897        assert!(keys.len() >= 7);
898    }
899
900    #[test]
901    fn linker_options_build_jit_options_full() {
902        let opts = LinkerOptions {
903            max_registers: Some(64),
904            optimization_level: OptimizationLevel::O3,
905            target_sm: Some(80),
906            generate_debug_info: true,
907            generate_line_info: true,
908            log_verbose: true,
909            fallback_strategy: FallbackStrategy::PreferBinary,
910        };
911        let (keys, vals, _info_buf, _error_buf) = opts.build_jit_options();
912
913        assert_eq!(keys.len(), vals.len());
914        // info log (2) + error log (2) + opt level (1) + max regs (1)
915        // + target (1) + debug (1) + line (1) + verbose (1) + fallback (1) = 11
916        assert!(keys.len() >= 11);
917    }
918
919    // -- Linker lifecycle tests (macOS synthetic mode) --
920
/// A linker created with default options (macOS synthetic mode) starts out
/// with no registered inputs.
#[cfg(target_os = "macos")]
#[test]
fn linker_create_default() {
    // Unwrap in a single step so that a construction failure reports the
    // underlying error value instead of a bare "assertion failed".
    let linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    assert_eq!(linker.input_count(), 0);
    assert!(linker.input_names().is_empty());
}
933
/// Adding one PTX module registers exactly one input under the given name.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_single_ptx() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let outcome = linker.add_ptx(SAMPLE_PTX_A, "module_a.ptx");
    assert!(outcome.is_ok());

    assert_eq!(linker.input_count(), 1);
    assert_eq!(linker.input_names(), &["module_a.ptx"]);
}
946
/// Adding two PTX modules registers both inputs, in insertion order.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_multiple_ptx() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Fail at the offending call with its error instead of discarding the
    // result and only noticing later through a count mismatch.
    linker
        .add_ptx(SAMPLE_PTX_A, "a.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(a.ptx) failed: {e}"));
    linker
        .add_ptx(SAMPLE_PTX_B, "b.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(b.ptx) failed: {e}"));

    assert_eq!(linker.input_count(), 2);
    assert_eq!(linker.input_names(), &["a.ptx", "b.ptx"]);
}
959
/// Completing a link of two PTX modules (macOS synthetic mode) yields a
/// non-empty cubin with the synthetic header, a non-empty info log, and an
/// empty error log.
#[cfg(target_os = "macos")]
#[test]
fn linker_complete_with_ptx() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Surface add failures immediately rather than discarding the results.
    linker
        .add_ptx(SAMPLE_PTX_A, "a.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(a.ptx) failed: {e}"));
    linker
        .add_ptx(SAMPLE_PTX_B, "b.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(b.ptx) failed: {e}"));

    // Unwrap directly so that a link failure reports the error value.
    let linked = linker
        .complete()
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    assert!(linked.cubin_size() > 0);
    assert!(linked.cubin().starts_with(b"OXICUDA_SYNTHETIC_CUBIN\0"));
    assert!(!linked.info_log().is_empty());
    assert!(linked.error_log().is_empty());
}
982
/// Completing a linker that has no inputs must fail with `InvalidValue`.
#[cfg(target_os = "macos")]
#[test]
fn linker_complete_empty_fails() {
    let linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let failure = linker.complete().err();
    assert!(failure.is_some());
    assert_eq!(failure, Some(CudaError::InvalidValue));
}
994
/// A non-empty cubin blob is accepted and counted as one input.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_cubin() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Four bytes of ELF magic stand in for a real cubin.
    let elf_magic: [u8; 4] = [0x7f, 0x45, 0x4c, 0x46];
    let outcome = linker.add_cubin(&elf_magic, "test.cubin");

    assert!(outcome.is_ok());
    assert_eq!(linker.input_count(), 1);
}
1007
/// A non-empty fatbin blob is accepted and counted as one input.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_fatbin() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Four bytes of fatbin magic stand in for a real fatbin.
    let fatbin_magic: [u8; 4] = [0xBA, 0xB0, 0xCA, 0xFE];
    let outcome = linker.add_fatbin(&fatbin_magic, "test.fatbin");

    assert!(outcome.is_ok());
    assert_eq!(linker.input_count(), 1);
}
1020
/// An empty cubin slice must be rejected with `InvalidValue`.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_empty_cubin_fails() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let failure = linker.add_cubin(&[], "empty.cubin").err();
    assert!(failure.is_some());
    assert_eq!(failure, Some(CudaError::InvalidValue));
}
1032
/// An empty fatbin slice must be rejected with `InvalidValue`.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_empty_fatbin_fails() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let failure = linker.add_fatbin(&[], "empty.fatbin").err();
    assert!(failure.is_some());
    assert_eq!(failure, Some(CudaError::InvalidValue));
}
1044
/// PTX and cubin inputs can be interleaved; the linked output (macOS
/// synthetic mode) starts with the synthetic header followed by content.
#[cfg(target_os = "macos")]
#[test]
fn linker_mixed_inputs() {
    // Header the synthetic cubin is expected to start with; its length also
    // serves as the lower bound for the "header + content" size check.
    const SYNTHETIC_HEADER: &[u8] = b"OXICUDA_SYNTHETIC_CUBIN\0";

    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Surface add failures immediately rather than discarding the results.
    linker
        .add_ptx(SAMPLE_PTX_A, "a.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(a.ptx) failed: {e}"));
    linker
        .add_cubin(&[1, 2, 3, 4], "b.cubin")
        .unwrap_or_else(|e| panic!("add_cubin(b.cubin) failed: {e}"));
    linker
        .add_ptx(SAMPLE_PTX_B, "c.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(c.ptx) failed: {e}"));

    assert_eq!(linker.input_count(), 3);

    let linked = linker
        .complete()
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // The cubin should contain both PTX sources and the binary data.
    let cubin = linked.cubin();
    assert!(cubin.starts_with(SYNTHETIC_HEADER));
    assert!(cubin.len() > SYNTHETIC_HEADER.len()); // header + content
}
1068
/// `into_cubin` hands back a buffer whose length equals `cubin_size`.
#[cfg(target_os = "macos")]
#[test]
fn linker_into_cubin() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Surface an add failure immediately rather than discarding the result.
    linker
        .add_ptx(SAMPLE_PTX_A, "a.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(a.ptx) failed: {e}"));

    let linked = linker
        .complete()
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Record the size first: `into_cubin` consumes the module.
    let size = linked.cubin_size();
    let raw = linked.into_cubin();
    assert_eq!(raw.len(), size);
}
1087
/// The `Debug` rendering of a linker mentions the type name and the
/// `input_count` field.
#[cfg(target_os = "macos")]
#[test]
fn linker_debug_format() {
    let linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let rendered = format!("{linker:?}");

    assert!(rendered.contains("Linker"));
    assert!(rendered.contains("input_count"));
}
1099
/// A linker built from fully customised options can still add PTX and
/// complete with a non-empty cubin (macOS synthetic mode).
#[cfg(target_os = "macos")]
#[test]
fn linker_with_custom_options() {
    let opts = LinkerOptions {
        max_registers: Some(48),
        optimization_level: OptimizationLevel::O3,
        target_sm: Some(80),
        generate_debug_info: true,
        generate_line_info: true,
        log_verbose: true,
        fallback_strategy: FallbackStrategy::PreferBinary,
    };
    let mut linker =
        Linker::new(opts).unwrap_or_else(|e| panic!("unexpected error: {e}"));

    // Surface an add failure immediately rather than discarding the result.
    linker
        .add_ptx(SAMPLE_PTX_A, "a.ptx")
        .unwrap_or_else(|e| panic!("add_ptx(a.ptx) failed: {e}"));

    let linked = linker
        .complete()
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));
    assert!(linked.cubin_size() > 0);
}
1124
/// Non-empty object and library blobs are both accepted and counted.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_object_and_library() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let object_outcome = linker.add_object(&[10, 20, 30], "test.o");
    assert!(object_outcome.is_ok());

    let library_outcome = linker.add_library(&[40, 50, 60], "test.a");
    assert!(library_outcome.is_ok());

    assert_eq!(linker.input_count(), 2);
}
1138
/// Empty object and library slices must each be rejected with
/// `InvalidValue`.
#[cfg(target_os = "macos")]
#[test]
fn linker_add_empty_object_fails() {
    let mut linker = Linker::new(LinkerOptions::default())
        .unwrap_or_else(|e| panic!("unexpected error: {e}"));

    let object_failure = linker.add_object(&[], "empty.o").err();
    assert_eq!(object_failure, Some(CudaError::InvalidValue));

    let library_failure = linker.add_library(&[], "empty.a").err();
    assert_eq!(library_failure, Some(CudaError::InvalidValue));
}
1155
1156    // -- LinkedModule tests --
1157
/// Each accessor returns the value of the field it was constructed with.
#[test]
fn linked_module_accessors() {
    let linked = LinkedModule {
        cubin_data: vec![1, 2, 3, 4, 5],
        info_log: String::from("some info"),
        error_log: String::from("some error"),
    };

    assert_eq!(linked.cubin(), &[1, 2, 3, 4, 5]);
    assert_eq!(linked.cubin_size(), 5);
    assert_eq!(linked.info_log(), "some info");
    assert_eq!(linked.error_log(), "some error");
}
1170
/// `into_cubin` returns the bytes stored in `cubin_data`.
#[test]
fn linked_module_into_cubin() {
    let linked = LinkedModule {
        cubin_data: vec![10, 20, 30],
        info_log: String::new(),
        error_log: String::new(),
    };

    let extracted = linked.into_cubin();

    assert_eq!(extracted, vec![10, 20, 30]);
}
1181
/// Cloning a `LinkedModule` preserves the cubin bytes and both logs.
#[test]
fn linked_module_clone() {
    // Every field is non-empty so that a clone which dropped any of them
    // would be detected.
    let module = LinkedModule {
        cubin_data: vec![1, 2],
        info_log: "info".to_string(),
        error_log: "error".to_string(),
    };
    let cloned = module.clone();

    assert_eq!(cloned.cubin(), module.cubin());
    assert_eq!(cloned.info_log(), module.info_log());
    assert_eq!(cloned.error_log(), module.error_log());
}
1193
1194    // -- buf_to_string helper tests --
1195
/// Conversion stops at the first NUL byte; trailing bytes are ignored.
#[test]
fn buf_to_string_basic() {
    let converted = buf_to_string(b"hello\0world");
    assert_eq!(converted, "hello");
}
1201
/// A buffer without a NUL byte is converted in full.
#[test]
fn buf_to_string_no_null() {
    let converted = buf_to_string(b"hello world");
    assert_eq!(converted, "hello world");
}
1207
/// An empty buffer converts to the empty string.
#[test]
fn buf_to_string_empty() {
    let converted = buf_to_string(b"");
    assert_eq!(converted, "");
}
1213
/// A buffer holding only NUL bytes converts to the empty string.
#[test]
fn buf_to_string_all_nulls() {
    let zeroed = [0u8; 10];
    assert_eq!(buf_to_string(&zeroed), "");
}
1219
1220    // -- CUjitInputType FFI value tests --
1221
/// Pins the numeric discriminant of every `CUjitInputType` variant.
#[test]
fn cujit_input_type_values() {
    // (actual discriminant, expected value), checked in declaration order.
    let discriminants = [
        (CUjitInputType::Ptx as u32, 1),
        (CUjitInputType::Cubin as u32, 2),
        (CUjitInputType::Fatbin as u32, 3),
        (CUjitInputType::Object as u32, 4),
        (CUjitInputType::Library as u32, 5),
    ];

    for (actual, expected) in discriminants {
        assert_eq!(actual, expected);
    }
}
1230}