train_station/
device.rs

1//! Device management system for Train Station ML Library
2//!
3//! This module provides a unified device abstraction for CPU and CUDA operations with thread-safe
4//! context management. The device system follows PyTorch's device API design while maintaining
5//! zero dependencies for CPU operations and feature-gated CUDA support.
6//!
7//! # Design Philosophy
8//!
9//! The device management system is designed for:
10//! - **Thread Safety**: Thread-local device contexts with automatic restoration
11//! - **Zero Dependencies**: CPU operations require no external dependencies
12//! - **Feature Isolation**: CUDA support is completely optional and feature-gated
13//! - **PyTorch Compatibility**: Familiar API design for users coming from PyTorch
14//! - **Performance**: Minimal overhead for device switching and context management
15//!
16//! # Organization
17//!
18//! The device module is organized into several key components:
19//! - **Device Types**: `DeviceType` enum for CPU and CUDA device types
20//! - **Device Representation**: `Device` struct with type and index information
21//! - **Context Management**: Thread-local device stack with automatic restoration
22//! - **Global Default**: Atomic global default device for new tensor creation
23//! - **CUDA Integration**: Feature-gated CUDA availability and device count functions
24//!
25//! # Key Features
26//!
27//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
28//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
29//! - **Global Default Device**: Configurable default device for new tensor creation
30//! - **CUDA Feature Gates**: All CUDA functionality is feature-gated and optional
31//! - **Runtime Validation**: CUDA device indices are validated at runtime
32//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
33//!
34//! # Examples
35//!
36//! ## Basic Device Usage
37//!
38//! ```rust
39//! use train_station::{Device, DeviceType};
40//!
41//! // Create CPU device
42//! let cpu_device = Device::cpu();
43//! assert!(cpu_device.is_cpu());
44//! assert_eq!(cpu_device.index(), 0);
45//! assert_eq!(cpu_device.to_string(), "cpu");
46//!
47//! // Create CUDA device (when feature enabled)
48//! #[cfg(feature = "cuda")]
49//! {
50//!     if train_station::cuda_is_available() {
51//!         let cuda_device = Device::cuda(0);
52//!         assert!(cuda_device.is_cuda());
53//!         assert_eq!(cuda_device.index(), 0);
54//!         assert_eq!(cuda_device.to_string(), "cuda:0");
55//!     }
56//! }
57//! ```
58//!
59//! ## Device Context Management
60//!
61//! ```rust
62//! use train_station::{Device, with_device, current_device, set_default_device};
63//!
64//! // Get current device context
65//! let initial_device = current_device();
66//! assert!(initial_device.is_cpu());
67//!
68//! // Execute code with specific device context
69//! let result = with_device(Device::cpu(), || {
70//!     assert_eq!(current_device(), Device::cpu());
71//!     // Device is automatically restored when closure exits
72//!     42
73//! });
74//!
75//! assert_eq!(result, 42);
76//! assert_eq!(current_device(), initial_device);
77//!
78//! // Set global default device
79//! set_default_device(Device::cpu());
80//! assert_eq!(train_station::get_default_device(), Device::cpu());
81//! ```
82//!
83//! ## CUDA Availability Checking
84//!
85//! ```rust
86//! use train_station::{cuda_is_available, cuda_device_count, Device};
87//!
88//! // Check CUDA availability
89//! if cuda_is_available() {
90//!     let device_count = cuda_device_count();
91//!     println!("CUDA available with {} devices", device_count);
92//!     
93//!     // Create tensors on CUDA devices
94//!     for i in 0..device_count {
95//!         let device = Device::cuda(i);
96//!         // Use device for tensor operations
97//!     }
98//! } else {
99//!     println!("CUDA not available, using CPU only");
100//! }
101//! ```
102//!
103//! ## Nested Device Contexts
104//!
105//! ```rust
106//! use train_station::{Device, with_device, current_device};
107//!
108//! let original_device = current_device();
109//!
110//! // Nested device contexts are supported
111//! with_device(Device::cpu(), || {
112//!     assert_eq!(current_device(), Device::cpu());
113//!     
114//!     with_device(Device::cpu(), || {
115//!         assert_eq!(current_device(), Device::cpu());
116//!         // Inner context
117//!     });
118//!     
119//!     assert_eq!(current_device(), Device::cpu());
120//!     // Outer context
121//! });
122//!
123//! // Original device is restored
124//! assert_eq!(current_device(), original_device);
125//! ```
126//!
127//! # Thread Safety
128//!
129//! The device management system is designed to be thread-safe:
130//!
131//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
132//! - **Atomic Global Default**: Global default device uses atomic operations for thread safety
133//! - **Context Isolation**: Device contexts are isolated between threads
134//! - **Automatic Cleanup**: Device contexts are automatically cleaned up when threads terminate
135//! - **No Shared State**: No shared mutable state between threads for device contexts
136//!
137//! # Memory Safety
138//!
139//! The device system prioritizes memory safety:
140//!
141//! - **RAII Patterns**: Device contexts use RAII for automatic resource management
142//! - **No Unsafe Code**: All device management code is safe Rust
143//! - **Thread-Local Storage**: Uses thread-local storage for isolation
144//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
145//! - **Feature Gates**: CUDA functionality is completely isolated when not enabled
146//!
147//! # Performance Characteristics
148//!
149//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
150//! - **Minimal Context Switching**: Device context switching is optimized for performance
151//! - **Thread-Local Access**: Device context access is O(1) thread-local lookup
152//! - **Atomic Global Default**: Global default device access uses relaxed atomic operations
153//! - **Stack-Based Contexts**: Device context stack uses efficient Vec operations
154//!
155//! # Feature Flags
156//!
157//! - **`cuda`**: Enables CUDA device support and related functions
158//! - **No CUDA**: When CUDA feature is disabled, all CUDA functions return safe defaults
159//!
160//! # Error Handling
161//!
162//! - **CUDA Validation**: CUDA device indices are validated at runtime
163//! - **Feature Gates**: CUDA functions panic with clear messages when feature is disabled
164//! - **Device Availability**: CUDA functions check device availability before use
165//! - **Graceful Degradation**: System gracefully falls back to CPU when CUDA is unavailable
166
167use std::cell::RefCell;
168use std::fmt;
169use std::sync::atomic::{AtomicUsize, Ordering};
170
/// Device types supported by Train Station
///
/// This enum represents the different types of devices where tensor operations
/// can be performed. Currently supports CPU and CUDA GPU devices. Only the
/// *kind* of device is encoded here; the device index (which GPU) is carried
/// separately by `Device`.
///
/// # Variants
///
/// * `Cpu` - CPU device for general-purpose computation
/// * `Cuda` - CUDA GPU device for accelerated computation (feature-gated)
///
/// # Examples
///
/// ```rust
/// use train_station::{DeviceType, Device};
///
/// let cpu_type = DeviceType::Cpu;
/// let cpu_device = Device::from(cpu_type);
/// assert!(cpu_device.is_cpu());
///
/// #[cfg(feature = "cuda")]
/// {
///     let cuda_type = DeviceType::Cuda;
///     let cuda_device = Device::from(cuda_type);
///     assert!(cuda_device.is_cuda());
/// }
/// ```
///
/// # Thread Safety
///
/// This type is thread-safe and can be shared between threads. It is `Copy`,
/// so values are passed around cheaply by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU device for general-purpose computation
    Cpu,
    /// CUDA GPU device for accelerated computation (feature-gated)
    Cuda,
}
208
209impl fmt::Display for DeviceType {
210    /// Format the device type as a string
211    ///
212    /// # Returns
213    ///
214    /// String representation of the device type:
215    /// - `"cpu"` for CPU devices
216    /// - `"cuda"` for CUDA devices
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        match self {
219            DeviceType::Cpu => write!(f, "cpu"),
220            DeviceType::Cuda => write!(f, "cuda"),
221        }
222    }
223}
224
/// Device representation for tensor operations
///
/// A device specifies where tensors are located and where operations should be performed.
/// Each device has a type (CPU or CUDA) and an index (0 for CPU, GPU ID for CUDA).
/// The device system provides thread-safe context management and automatic resource cleanup.
///
/// # Fields
///
/// * `device_type` - The type of device (CPU or CUDA)
/// * `index` - Device index (0 for CPU, GPU ID for CUDA)
///
/// # Examples
///
/// ```rust
/// use train_station::Device;
///
/// // Create CPU device
/// let cpu = Device::cpu();
/// assert!(cpu.is_cpu());
/// assert_eq!(cpu.index(), 0);
/// assert_eq!(cpu.to_string(), "cpu");
///
/// // Create CUDA device (when feature enabled)
/// #[cfg(feature = "cuda")]
/// {
///     if train_station::cuda_is_available() {
///         let cuda = Device::cuda(0);
///         assert!(cuda.is_cuda());
///         assert_eq!(cuda.index(), 0);
///         assert_eq!(cuda.to_string(), "cuda:0");
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// This type is thread-safe and can be shared between threads. Device contexts
/// are managed per-thread using thread-local storage.
///
/// # Memory Layout
///
/// The device struct is small and efficient:
/// - Size: 16 bytes on typical 64-bit targets (the 1-byte enum discriminant is
///   padded to the 8-byte alignment of the `usize` index)
/// - Alignment: 8 bytes (that of `usize`)
/// - Copy semantics: Implements Copy for efficient passing
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Device {
    /// The type of device (CPU or CUDA)
    device_type: DeviceType,
    /// Device index: always 0 for CPU; the GPU ordinal for CUDA
    index: usize,
}
275
276impl Device {
277    /// Create a CPU device
278    ///
279    /// CPU devices always have index 0 and are always available regardless
280    /// of feature flags or system configuration.
281    ///
282    /// # Returns
283    ///
284    /// A Device representing the CPU (always index 0)
285    ///
286    /// # Examples
287    ///
288    /// ```rust
289    /// use train_station::Device;
290    ///
291    /// let device = Device::cpu();
292    /// assert!(device.is_cpu());
293    /// assert_eq!(device.index(), 0);
294    /// assert_eq!(device.device_type(), train_station::DeviceType::Cpu);
295    /// ```
296    pub fn cpu() -> Self {
297        Device {
298            device_type: DeviceType::Cpu,
299            index: 0,
300        }
301    }
302
303    /// Create a CUDA device
304    ///
305    /// Creates a device representing a specific CUDA GPU. The device index
306    /// must be valid for the current system configuration.
307    ///
308    /// # Arguments
309    ///
310    /// * `index` - CUDA device index (0-based)
311    ///
312    /// # Returns
313    ///
314    /// A Device representing the specified CUDA device
315    ///
316    /// # Panics
317    ///
318    /// Panics in the following cases:
319    /// - CUDA feature is not enabled (`--features cuda` not specified)
320    /// - CUDA is not available on the system
321    /// - Device index is out of range (>= number of available devices)
322    ///
323    /// # Examples
324    ///
325    /// ```rust
326    /// use train_station::Device;
327    ///
328    /// // CPU device is always available
329    /// let cpu = Device::cpu();
330    ///
331    /// // CUDA device (when feature enabled and available)
332    /// #[cfg(feature = "cuda")]
333    /// {
334    ///     if train_station::cuda_is_available() {
335    ///         let device_count = train_station::cuda_device_count();
336    ///         if device_count > 0 {
337    ///             let cuda = Device::cuda(0);
338    ///             assert!(cuda.is_cuda());
339    ///             assert_eq!(cuda.index(), 0);
340    ///         }
341    ///     }
342    /// }
343    /// ```
344    pub fn cuda(index: usize) -> Self {
345        #[cfg(feature = "cuda")]
346        {
347            use crate::cuda;
348
349            // Check if CUDA is available
350            if !cuda::cuda_is_available() {
351                panic!("CUDA is not available on this system");
352            }
353
354            // Check if device index is valid
355            let device_count = cuda::cuda_device_count();
356            if index >= device_count as usize {
357                panic!(
358                    "CUDA device index {} out of range (0-{})",
359                    index,
360                    device_count - 1
361                );
362            }
363
364            Device {
365                device_type: DeviceType::Cuda,
366                index,
367            }
368        }
369
370        #[cfg(not(feature = "cuda"))]
371        {
372            let _ = index;
373            panic!("CUDA support not enabled. Enable with --features cuda");
374        }
375    }
376
377    /// Get the device type
378    ///
379    /// # Returns
380    ///
381    /// The `DeviceType` enum variant representing this device's type
382    ///
383    /// # Examples
384    ///
385    /// ```rust
386    /// use train_station::{Device, DeviceType};
387    ///
388    /// let cpu = Device::cpu();
389    /// assert_eq!(cpu.device_type(), DeviceType::Cpu);
390    /// ```
391    pub fn device_type(&self) -> DeviceType {
392        self.device_type
393    }
394
395    /// Get the device index
396    ///
397    /// # Returns
398    ///
399    /// The device index (0 for CPU, GPU ID for CUDA)
400    ///
401    /// # Examples
402    ///
403    /// ```rust
404    /// use train_station::Device;
405    ///
406    /// let cpu = Device::cpu();
407    /// assert_eq!(cpu.index(), 0);
408    ///
409    /// #[cfg(feature = "cuda")]
410    /// {
411    ///     if train_station::cuda_is_available() {
412    ///         let cuda = Device::cuda(0);
413    ///         assert_eq!(cuda.index(), 0);
414    ///     }
415    /// }
416    /// ```
417    pub fn index(&self) -> usize {
418        self.index
419    }
420
421    /// Check if this is a CPU device
422    ///
423    /// # Returns
424    ///
425    /// `true` if this device represents a CPU, `false` otherwise
426    ///
427    /// # Examples
428    ///
429    /// ```rust
430    /// use train_station::Device;
431    ///
432    /// let cpu = Device::cpu();
433    /// assert!(cpu.is_cpu());
434    /// assert!(!cpu.is_cuda());
435    /// ```
436    pub fn is_cpu(&self) -> bool {
437        self.device_type == DeviceType::Cpu
438    }
439
440    /// Check if this is a CUDA device
441    ///
442    /// # Returns
443    ///
444    /// `true` if this device represents a CUDA GPU, `false` otherwise
445    ///
446    /// # Examples
447    ///
448    /// ```rust
449    /// use train_station::Device;
450    ///
451    /// let cpu = Device::cpu();
452    /// assert!(!cpu.is_cuda());
453    /// assert!(cpu.is_cpu());
454    /// ```
455    pub fn is_cuda(&self) -> bool {
456        self.device_type == DeviceType::Cuda
457    }
458}
459
460impl Default for Device {
461    /// Create the default device (CPU)
462    ///
463    /// # Returns
464    ///
465    /// A CPU device (same as `Device::cpu()`)
466    ///
467    /// # Examples
468    ///
469    /// ```rust
470    /// use train_station::Device;
471    ///
472    /// let device = Device::default();
473    /// assert!(device.is_cpu());
474    /// assert_eq!(device, Device::cpu());
475    /// ```
476    fn default() -> Self {
477        Device::cpu()
478    }
479}
480
481impl fmt::Display for Device {
482    /// Format the device as a string
483    ///
484    /// # Returns
485    ///
486    /// String representation of the device:
487    /// - `"cpu"` for CPU devices
488    /// - `"cuda:{index}"` for CUDA devices
489    ///
490    /// # Examples
491    ///
492    /// ```rust
493    /// use train_station::Device;
494    ///
495    /// let cpu = Device::cpu();
496    /// assert_eq!(cpu.to_string(), "cpu");
497    ///
498    /// #[cfg(feature = "cuda")]
499    /// {
500    ///     if train_station::cuda_is_available() {
501    ///         let cuda = Device::cuda(0);
502    ///         assert_eq!(cuda.to_string(), "cuda:0");
503    ///     }
504    /// }
505    /// ```
506    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507        match self.device_type {
508            DeviceType::Cpu => write!(f, "cpu"),
509            DeviceType::Cuda => write!(f, "cuda:{}", self.index),
510        }
511    }
512}
513
514impl From<DeviceType> for Device {
515    /// Convert DeviceType to Device with index 0
516    ///
517    /// # Arguments
518    ///
519    /// * `device_type` - The device type to convert
520    ///
521    /// # Returns
522    ///
523    /// A Device with the specified type and index 0
524    ///
525    /// # Panics
526    ///
527    /// Panics if `device_type` is `DeviceType::Cuda` and CUDA is not available
528    /// or the feature is not enabled.
529    ///
530    /// # Examples
531    ///
532    /// ```rust
533    /// use train_station::{Device, DeviceType};
534    ///
535    /// let cpu_type = DeviceType::Cpu;
536    /// let cpu_device = Device::from(cpu_type);
537    /// assert!(cpu_device.is_cpu());
538    /// assert_eq!(cpu_device.index(), 0);
539    /// ```
540    fn from(device_type: DeviceType) -> Self {
541        match device_type {
542            DeviceType::Cpu => Device::cpu(),
543            DeviceType::Cuda => {
544                // Call Device::cuda(0) which handles the proper feature flag checking
545                Device::cuda(0)
546            }
547        }
548    }
549}
550
551// ================================================================================================
552// Device Context Management
553// ================================================================================================
554
thread_local! {
    /// Thread-local storage for device context stack
    ///
    /// Each thread maintains its own stack of device contexts, seeded with the
    /// CPU device so `current_device` always has a sensible answer. The top of
    /// the stack represents the current device context for that thread. When a
    /// new context is pushed, it becomes the current device. When a context is
    /// popped, the previous device is restored.
    ///
    /// # Thread Safety
    ///
    /// This is thread-local storage, so each thread has its own isolated stack.
    /// No synchronization is required for access within a single thread.
    static DEVICE_STACK: RefCell<Vec<Device>> = RefCell::new(vec![Device::cpu()]);
}
569
/// Global default device (starts as CPU)
///
/// This atomic variable stores the global default device that is used when
/// creating new tensors without an explicit device specification. The device
/// is stored as a numeric ID (see `device_to_id` / `id_to_device`: 0 means
/// CPU, `1000 + index` means CUDA device `index`) so it fits in a single
/// atomic word for efficient atomic operations.
///
/// # Thread Safety
///
/// Uses atomic operations for thread-safe access. Multiple threads can read
/// and write the default device concurrently without data races.
static GLOBAL_DEFAULT_DEVICE: AtomicUsize = AtomicUsize::new(0); // 0 = CPU
581
/// Device context guard for RAII-style device switching
///
/// This struct provides automatic restoration of the previous device context
/// when it goes out of scope, similar to PyTorch's device context manager.
/// The guard ensures that device contexts are properly cleaned up even if
/// exceptions occur.
///
/// # Thread Safety
///
/// This type is not thread-safe and should not be shared between threads.
/// Each thread should create its own device context guards.
///
/// # Examples
///
/// ```rust
/// use train_station::{Device, with_device, current_device};
///
/// let original_device = current_device();
///
/// // Use with_device instead of DeviceContext::new for public API
/// with_device(Device::cpu(), || {
///     assert_eq!(current_device(), Device::cpu());
///     // Device context is automatically restored when closure exits
/// });
///
/// assert_eq!(current_device(), original_device);
/// ```
pub struct DeviceContext {
    /// Device that was current when this guard was created; restored on drop.
    previous_device: Device,
}
612
613impl DeviceContext {
614    /// Create a new device context guard
615    ///
616    /// This function switches to the specified device and creates a guard
617    /// that will automatically restore the previous device when dropped.
618    ///
619    /// # Arguments
620    ///
621    /// * `device` - The device to switch to
622    ///
623    /// # Returns
624    ///
625    /// A `DeviceContext` guard that will restore the previous device when dropped
626    ///
627    /// # Side Effects
628    ///
629    /// Changes the current thread's device context to the specified device.
630    fn new(device: Device) -> Self {
631        let previous_device = current_device();
632        set_current_device(device);
633
634        DeviceContext { previous_device }
635    }
636}
637
impl Drop for DeviceContext {
    /// Restore the previous device context when the guard is dropped
    ///
    /// This ensures that device contexts are properly cleaned up even if
    /// exceptions occur or the guard is dropped early. Because drop runs
    /// during unwinding, a panicking closure still restores the prior device.
    fn drop(&mut self) {
        set_current_device(self.previous_device);
    }
}
647
648/// Set the global default device
649///
650/// This affects the default device for new tensors created without an explicit device.
651/// It does not affect the current thread's device context.
652///
653/// # Arguments
654///
655/// * `device` - The device to set as the global default
656///
657/// # Thread Safety
658///
659/// This function is thread-safe and uses atomic operations to update the global default.
660///
661/// # Examples
662///
663/// ```rust
664/// use train_station::{Device, set_default_device, get_default_device};
665///
666/// // Set global default to CPU
667/// set_default_device(Device::cpu());
668/// assert_eq!(get_default_device(), Device::cpu());
669///
670/// // The global default affects new tensor creation
671/// // (tensor creation would use this default device)
672/// ```
673pub fn set_default_device(device: Device) {
674    let device_id = device_to_id(device);
675    GLOBAL_DEFAULT_DEVICE.store(device_id, Ordering::Relaxed);
676}
677
678/// Get the global default device
679///
680/// # Returns
681///
682/// The current global default device
683///
684/// # Thread Safety
685///
686/// This function is thread-safe and uses atomic operations to read the global default.
687///
688/// # Examples
689///
690/// ```rust
691/// use train_station::{Device, get_default_device, set_default_device};
692///
693/// let initial_default = get_default_device();
694/// assert!(initial_default.is_cpu());
695///
696/// set_default_device(Device::cpu());
697/// assert_eq!(get_default_device(), Device::cpu());
698/// ```
699pub fn get_default_device() -> Device {
700    let device_id = GLOBAL_DEFAULT_DEVICE.load(Ordering::Relaxed);
701    id_to_device(device_id)
702}
703
704/// Get the current thread's device context
705///
706/// # Returns
707///
708/// The current device context for this thread
709///
710/// # Thread Safety
711///
712/// This function is thread-safe and returns the device context for the current thread only.
713///
714/// # Examples
715///
716/// ```rust
717/// use train_station::{Device, current_device, with_device};
718///
719/// let initial_device = current_device();
720/// assert!(initial_device.is_cpu());
721///
722/// with_device(Device::cpu(), || {
723///     assert_eq!(current_device(), Device::cpu());
724/// });
725///
726/// assert_eq!(current_device(), initial_device);
727/// ```
728pub fn current_device() -> Device {
729    DEVICE_STACK.with(|stack| stack.borrow().last().copied().unwrap_or_else(Device::cpu))
730}
731
732/// Set the current thread's device context
733///
734/// This function updates the current thread's device context. It modifies the
735/// top of the thread-local device stack.
736///
737/// # Arguments
738///
739/// * `device` - The device to set as the current context
740///
741/// # Thread Safety
742///
743/// This function is thread-safe and only affects the current thread's context.
744///
745/// # Side Effects
746///
747/// Changes the current thread's device context to the specified device.
748fn set_current_device(device: Device) {
749    DEVICE_STACK.with(|stack| {
750        let mut stack = stack.borrow_mut();
751        if stack.is_empty() {
752            stack.push(device);
753        } else {
754            // Replace the top of the stack
755            if let Some(last) = stack.last_mut() {
756                *last = device;
757            }
758        }
759    });
760}
761
762/// Execute a closure with a specific device context
763///
764/// This function temporarily switches to the specified device for the duration
765/// of the closure, then automatically restores the previous device. This is
766/// the recommended way to execute code with a specific device context.
767///
768/// # Arguments
769///
770/// * `device` - The device to use for the closure
771/// * `f` - The closure to execute
772///
773/// # Returns
774///
775/// The result of the closure
776///
777/// # Thread Safety
778///
779/// This function is thread-safe and only affects the current thread's context.
780///
781/// # Examples
782///
783/// ```rust
784/// use train_station::{Device, with_device, current_device};
785///
786/// let original_device = current_device();
787///
788/// let result = with_device(Device::cpu(), || {
789///     assert_eq!(current_device(), Device::cpu());
790///     // Perform operations with CPU device
791///     42
792/// });
793///
794/// assert_eq!(result, 42);
795/// assert_eq!(current_device(), original_device);
796/// ```
797pub fn with_device<F, R>(device: Device, f: F) -> R
798where
799    F: FnOnce() -> R,
800{
801    let _context = DeviceContext::new(device);
802    f()
803}
804
805// Helper functions for device ID conversion
806/// Convert a device to a numeric ID for storage
807///
808/// # Arguments
809///
810/// * `device` - The device to convert
811///
812/// # Returns
813///
814/// A numeric ID representing the device:
815/// - 0 for CPU devices
816/// - 1000 + index for CUDA devices
817///
818/// # Thread Safety
819///
820/// This function is thread-safe and has no side effects.
821fn device_to_id(device: Device) -> usize {
822    match device.device_type {
823        DeviceType::Cpu => 0,
824        DeviceType::Cuda => 1000 + device.index, // Offset CUDA devices by 1000
825    }
826}
827
828/// Convert a numeric ID back to a device
829///
830/// # Arguments
831///
832/// * `id` - The numeric ID to convert
833///
834/// # Returns
835///
836/// A device representing the ID:
837/// - ID 0 → CPU device
838/// - ID >= 1000 → CUDA device with index (ID - 1000)
839/// - Invalid IDs → CPU device (fallback)
840///
841/// # Thread Safety
842///
843/// This function is thread-safe and has no side effects.
844fn id_to_device(id: usize) -> Device {
845    if id == 0 {
846        Device::cpu()
847    } else if id >= 1000 {
848        Device::cuda(id - 1000)
849    } else {
850        Device::cpu() // Fallback to CPU for invalid IDs
851    }
852}
853
854// ================================================================================================
855// CUDA Availability Functions (Direct delegation to cuda_ffi)
856// ================================================================================================
857
/// Check if CUDA is available
///
/// Reports whether the `cuda` feature was compiled in *and* at least one
/// usable CUDA device was detected. Without the feature this is a
/// compile-time constant `false` with zero runtime cost.
///
/// # Returns
///
/// - `true` if the CUDA feature is enabled and at least one CUDA device is available
/// - `false` if the CUDA feature is disabled or no CUDA devices are found
///
/// # Thread Safety
///
/// This function is thread-safe and can be called from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::cuda_is_available;
///
/// if cuda_is_available() {
///     println!("CUDA is available");
///     // Create CUDA tensors and perform GPU operations
/// } else {
///     println!("CUDA is not available, using CPU only");
///     // Fall back to CPU operations
/// }
/// ```
pub fn cuda_is_available() -> bool {
    #[cfg(not(feature = "cuda"))]
    {
        false
    }

    #[cfg(feature = "cuda")]
    {
        crate::cuda::cuda_is_available()
    }
}
897
/// Get the number of CUDA devices available
///
/// Returns how many CUDA devices were detected on this system. Without the
/// `cuda` feature this is a compile-time constant 0 with zero runtime cost.
///
/// # Returns
///
/// Number of CUDA devices available:
/// - 0 if the CUDA feature is disabled
/// - 0 if CUDA is not available on the system
/// - Number of available CUDA devices otherwise
///
/// # Thread Safety
///
/// This function is thread-safe and can be called from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::{cuda_device_count, Device};
///
/// let device_count = cuda_device_count();
/// println!("Found {} CUDA devices", device_count);
///
/// for i in 0..device_count {
///     let device = Device::cuda(i);
///     println!("CUDA device {}: {}", i, device);
/// }
/// ```
#[allow(unused)]
pub fn cuda_device_count() -> usize {
    #[cfg(not(feature = "cuda"))]
    {
        0
    }

    #[cfg(feature = "cuda")]
    {
        crate::cuda::cuda_device_count() as usize
    }
}
939
940// ================================================================================================
941// Tests
942// ================================================================================================
943
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::cpu();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert_eq!(device.index(), 0);
        assert!(device.is_cpu());
        assert!(!device.is_cuda());
        assert_eq!(device.to_string(), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert!(device.is_cpu());
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "cpu");
        assert_eq!(DeviceType::Cuda.to_string(), "cuda");
    }

    #[test]
    fn test_device_from_device_type() {
        let device = Device::from(DeviceType::Cpu);
        assert!(device.is_cpu());
        assert_eq!(device.index(), 0);
    }

    /// Only meaningful without the `cuda` feature: with `--features cuda`
    /// the call either succeeds or panics with a different message, so the
    /// `should_panic` expectation would fail. Gate on `cfg(not(feature))`.
    #[test]
    #[cfg(not(feature = "cuda"))]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_cuda_device_panics() {
        Device::cuda(0);
    }

    /// Same gating rationale as `test_cuda_device_panics`: the expected panic
    /// message only exists in non-CUDA builds.
    #[test]
    #[cfg(not(feature = "cuda"))]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_device_from_cuda_type_panics() {
        let _ = Device::from(DeviceType::Cuda);
    }

    #[test]
    fn test_device_equality() {
        let cpu1 = Device::cpu();
        let cpu2 = Device::cpu();
        assert_eq!(cpu1, cpu2);
    }

    // Context management tests
    #[test]
    fn test_current_device() {
        assert_eq!(current_device(), Device::cpu());
    }

    #[test]
    fn test_default_device() {
        let initial_default = get_default_device();
        assert_eq!(initial_default, Device::cpu());

        // Should still be CPU after setting it explicitly
        set_default_device(Device::cpu());
        assert_eq!(get_default_device(), Device::cpu());
    }

    #[test]
    fn test_device_context_guard() {
        let original_device = current_device();

        {
            let _guard = DeviceContext::new(Device::cpu());
            assert_eq!(current_device(), Device::cpu());
        }

        // Device should be restored after guard is dropped
        assert_eq!(current_device(), original_device);
    }

    #[test]
    fn test_with_device() {
        let original_device = current_device();

        let result = with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());
            42
        });

        assert_eq!(result, 42);
        assert_eq!(current_device(), original_device);
    }

    #[test]
    fn test_nested_device_contexts() {
        let original = current_device();

        with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());

            with_device(Device::cpu(), || {
                assert_eq!(current_device(), Device::cpu());
            });

            assert_eq!(current_device(), Device::cpu());
        });

        assert_eq!(current_device(), original);
    }

    #[test]
    fn test_device_id_conversion() {
        assert_eq!(device_to_id(Device::cpu()), 0);
        assert_eq!(id_to_device(0), Device::cpu());

        // Test invalid ID fallback
        assert_eq!(id_to_device(999), Device::cpu());
    }

    #[test]
    fn test_cuda_availability_check() {
        // These functions should be callable regardless of CUDA availability
        let available = cuda_is_available();
        let device_count = cuda_device_count();

        if available {
            assert!(device_count > 0, "CUDA available but no devices found");
        } else {
            assert_eq!(device_count, 0, "CUDA not available but devices reported");
        }
    }
}
1077}