train_station/device.rs
1//! Device management system for Train Station ML Library
2//!
3//! This module provides a unified device abstraction for CPU and CUDA operations with thread-safe
4//! context management. The device system follows PyTorch's device API design while maintaining
5//! zero dependencies for CPU operations and feature-gated CUDA support.
6//!
7//! # Design Philosophy
8//!
9//! The device management system is designed for:
10//! - **Thread Safety**: Thread-local device contexts with automatic restoration
11//! - **Zero Dependencies**: CPU operations require no external dependencies
12//! - **Feature Isolation**: CUDA support is completely optional and feature-gated
13//! - **PyTorch Compatibility**: Familiar API design for users coming from PyTorch
14//! - **Performance**: Minimal overhead for device switching and context management
15//!
16//! # Organization
17//!
18//! The device module is organized into several key components:
19//! - **Device Types**: `DeviceType` enum for CPU and CUDA device types
20//! - **Device Representation**: `Device` struct with type and index information
21//! - **Context Management**: Thread-local device stack with automatic restoration
22//! - **Global Default**: Atomic global default device for new tensor creation
23//! - **CUDA Integration**: Feature-gated CUDA availability and device count functions
24//!
25//! # Key Features
26//!
27//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
28//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
29//! - **Global Default Device**: Configurable default device for new tensor creation
30//! - **CUDA Feature Gates**: All CUDA functionality is feature-gated and optional
31//! - **Runtime Validation**: CUDA device indices are validated at runtime
32//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
33//!
34//! # Examples
35//!
36//! ## Basic Device Usage
37//!
38//! ```rust
39//! use train_station::{Device, DeviceType};
40//!
41//! // Create CPU device
42//! let cpu_device = Device::cpu();
43//! assert!(cpu_device.is_cpu());
44//! assert_eq!(cpu_device.index(), 0);
45//! assert_eq!(cpu_device.to_string(), "cpu");
46//!
47//! // Create CUDA device (when feature enabled)
48//! #[cfg(feature = "cuda")]
49//! {
50//! if train_station::cuda_is_available() {
51//! let cuda_device = Device::cuda(0);
52//! assert!(cuda_device.is_cuda());
53//! assert_eq!(cuda_device.index(), 0);
54//! assert_eq!(cuda_device.to_string(), "cuda:0");
55//! }
56//! }
57//! ```
58//!
59//! ## Device Context Management
60//!
61//! ```rust
62//! use train_station::{Device, with_device, current_device, set_default_device};
63//!
64//! // Get current device context
65//! let initial_device = current_device();
66//! assert!(initial_device.is_cpu());
67//!
68//! // Execute code with specific device context
69//! let result = with_device(Device::cpu(), || {
70//! assert_eq!(current_device(), Device::cpu());
71//! // Device is automatically restored when closure exits
72//! 42
73//! });
74//!
75//! assert_eq!(result, 42);
76//! assert_eq!(current_device(), initial_device);
77//!
78//! // Set global default device
79//! set_default_device(Device::cpu());
80//! assert_eq!(train_station::get_default_device(), Device::cpu());
81//! ```
82//!
83//! ## CUDA Availability Checking
84//!
85//! ```rust
86//! use train_station::{cuda_is_available, cuda_device_count, Device};
87//!
88//! // Check CUDA availability
89//! if cuda_is_available() {
90//! let device_count = cuda_device_count();
91//! println!("CUDA available with {} devices", device_count);
92//!
93//! // Create tensors on CUDA devices
94//! for i in 0..device_count {
95//! let device = Device::cuda(i);
96//! // Use device for tensor operations
97//! }
98//! } else {
99//! println!("CUDA not available, using CPU only");
100//! }
101//! ```
102//!
103//! ## Nested Device Contexts
104//!
105//! ```rust
106//! use train_station::{Device, with_device, current_device};
107//!
108//! let original_device = current_device();
109//!
110//! // Nested device contexts are supported
111//! with_device(Device::cpu(), || {
112//! assert_eq!(current_device(), Device::cpu());
113//!
114//! with_device(Device::cpu(), || {
115//! assert_eq!(current_device(), Device::cpu());
116//! // Inner context
117//! });
118//!
119//! assert_eq!(current_device(), Device::cpu());
120//! // Outer context
121//! });
122//!
123//! // Original device is restored
124//! assert_eq!(current_device(), original_device);
125//! ```
126//!
127//! # Thread Safety
128//!
129//! The device management system is designed to be thread-safe:
130//!
131//! - **Thread-Local Contexts**: Each thread maintains its own device context stack
132//! - **Atomic Global Default**: Global default device uses atomic operations for thread safety
133//! - **Context Isolation**: Device contexts are isolated between threads
134//! - **Automatic Cleanup**: Device contexts are automatically cleaned up when threads terminate
135//! - **No Shared State**: No shared mutable state between threads for device contexts
136//!
137//! # Memory Safety
138//!
139//! The device system prioritizes memory safety:
140//!
141//! - **RAII Patterns**: Device contexts use RAII for automatic resource management
142//! - **No Unsafe Code**: All device management code is safe Rust
143//! - **Thread-Local Storage**: Uses thread-local storage for isolation
144//! - **Automatic Restoration**: Device contexts are automatically restored when dropped
145//! - **Feature Gates**: CUDA functionality is completely isolated when not enabled
146//!
147//! # Performance Characteristics
148//!
149//! - **Zero-Cost CPU Operations**: CPU device operations have no runtime overhead
150//! - **Minimal Context Switching**: Device context switching is optimized for performance
151//! - **Thread-Local Access**: Device context access is O(1) thread-local lookup
152//! - **Atomic Global Default**: Global default device access uses relaxed atomic operations
153//! - **Stack-Based Contexts**: Device context stack uses efficient Vec operations
154//!
155//! # Feature Flags
156//!
157//! - **`cuda`**: Enables CUDA device support and related functions
158//! - **No CUDA**: When CUDA feature is disabled, all CUDA functions return safe defaults
159//!
160//! # Error Handling
161//!
162//! - **CUDA Validation**: CUDA device indices are validated at runtime
163//! - **Feature Gates**: CUDA functions panic with clear messages when feature is disabled
164//! - **Device Availability**: CUDA functions check device availability before use
165//! - **Graceful Degradation**: System gracefully falls back to CPU when CUDA is unavailable
166
167use std::cell::RefCell;
168use std::fmt;
169use std::sync::atomic::{AtomicUsize, Ordering};
170
/// Device types supported by Train Station
///
/// Identifies the kind of hardware a tensor lives on and where its operations
/// execute. CPU support is always compiled in; CUDA support exists only when
/// the crate is built with the `cuda` feature.
///
/// # Examples
///
/// ```rust
/// use train_station::{DeviceType, Device};
///
/// let cpu_type = DeviceType::Cpu;
/// let cpu_device = Device::from(cpu_type);
/// assert!(cpu_device.is_cpu());
///
/// #[cfg(feature = "cuda")]
/// {
///     let cuda_type = DeviceType::Cuda;
///     let cuda_device = Device::from(cuda_type);
///     assert!(cuda_device.is_cuda());
/// }
/// ```
///
/// # Thread Safety
///
/// This type is `Copy` and thread-safe; values can be freely shared between threads.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// CPU device for general-purpose computation
    Cpu,
    /// CUDA GPU device for accelerated computation (feature-gated)
    Cuda,
}

impl fmt::Display for DeviceType {
    /// Format the device type as a lowercase string
    ///
    /// # Returns
    ///
    /// - `"cpu"` for CPU devices
    /// - `"cuda"` for CUDA devices
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            DeviceType::Cpu => "cpu",
            DeviceType::Cuda => "cuda",
        };
        f.write_str(name)
    }
}

/// Device representation for tensor operations
///
/// A device specifies where tensors are located and where operations should be
/// performed. Each device has a type (CPU or CUDA) and an index (always 0 for
/// CPU, the GPU ordinal for CUDA). The device system provides thread-safe
/// context management and automatic restoration of contexts.
///
/// # Fields
///
/// * `device_type` - The type of device (CPU or CUDA)
/// * `index` - Device index (0 for CPU, GPU ID for CUDA)
///
/// # Examples
///
/// ```rust
/// use train_station::Device;
///
/// let cpu = Device::cpu();
/// assert!(cpu.is_cpu());
/// assert_eq!(cpu.index(), 0);
/// assert_eq!(cpu.to_string(), "cpu");
///
/// #[cfg(feature = "cuda")]
/// {
///     if train_station::cuda_is_available() {
///         let cuda = Device::cuda(0);
///         assert!(cuda.is_cuda());
///         assert_eq!(cuda.to_string(), "cuda:0");
///     }
/// }
/// ```
///
/// # Thread Safety
///
/// `Device` is `Copy` and can be shared between threads. Device *contexts* are
/// managed per-thread via thread-local storage.
///
/// # Memory Layout
///
/// The struct stores a one-byte enum discriminant plus a `usize` index; with
/// alignment padding it occupies 16 bytes on typical 64-bit targets. It
/// implements `Copy` for cheap by-value passing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Device {
    device_type: DeviceType,
    index: usize,
}

impl Device {
    /// Create a CPU device
    ///
    /// CPU devices always have index 0 and are always available regardless of
    /// feature flags or system configuration.
    ///
    /// # Returns
    ///
    /// A `Device` representing the CPU (always index 0)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use train_station::Device;
    ///
    /// let device = Device::cpu();
    /// assert!(device.is_cpu());
    /// assert_eq!(device.index(), 0);
    /// ```
    pub fn cpu() -> Self {
        Device {
            device_type: DeviceType::Cpu,
            index: 0,
        }
    }

    /// Create a CUDA device
    ///
    /// Creates a device representing a specific CUDA GPU. The index is
    /// validated against the number of devices reported by the CUDA backend.
    ///
    /// # Arguments
    ///
    /// * `index` - CUDA device index (0-based)
    ///
    /// # Returns
    ///
    /// A `Device` representing the specified CUDA device
    ///
    /// # Panics
    ///
    /// Panics in the following cases:
    /// - CUDA feature is not enabled (`--features cuda` not specified)
    /// - CUDA is not available on the system
    /// - Device index is out of range (>= number of available devices)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use train_station::Device;
    ///
    /// #[cfg(feature = "cuda")]
    /// {
    ///     if train_station::cuda_is_available() {
    ///         if train_station::cuda_device_count() > 0 {
    ///             let cuda = Device::cuda(0);
    ///             assert!(cuda.is_cuda());
    ///         }
    ///     }
    /// }
    /// ```
    pub fn cuda(index: usize) -> Self {
        #[cfg(feature = "cuda")]
        {
            use crate::cuda;

            // CUDA must be usable at all before index validation makes sense.
            if !cuda::cuda_is_available() {
                panic!("CUDA is not available on this system");
            }

            let device_count = cuda::cuda_device_count() as usize;
            // Report the count itself rather than `device_count - 1`: the old
            // message underflowed if the backend ever reported availability
            // with zero devices.
            if index >= device_count {
                panic!(
                    "CUDA device index {} out of range ({} CUDA device(s) available)",
                    index, device_count
                );
            }

            Device {
                device_type: DeviceType::Cuda,
                index,
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            let _ = index; // silence unused-parameter warning in this build
            panic!("CUDA support not enabled. Enable with --features cuda");
        }
    }

    /// Get the device type
    ///
    /// # Returns
    ///
    /// The `DeviceType` variant representing this device's type
    pub fn device_type(&self) -> DeviceType {
        self.device_type
    }

    /// Get the device index
    ///
    /// # Returns
    ///
    /// The device index (0 for CPU, GPU ID for CUDA)
    pub fn index(&self) -> usize {
        self.index
    }

    /// Check if this is a CPU device
    ///
    /// # Returns
    ///
    /// `true` if this device represents a CPU, `false` otherwise
    pub fn is_cpu(&self) -> bool {
        self.device_type == DeviceType::Cpu
    }

    /// Check if this is a CUDA device
    ///
    /// # Returns
    ///
    /// `true` if this device represents a CUDA GPU, `false` otherwise
    pub fn is_cuda(&self) -> bool {
        self.device_type == DeviceType::Cuda
    }
}
459
460impl Default for Device {
461 /// Create the default device (CPU)
462 ///
463 /// # Returns
464 ///
465 /// A CPU device (same as `Device::cpu()`)
466 ///
467 /// # Examples
468 ///
469 /// ```rust
470 /// use train_station::Device;
471 ///
472 /// let device = Device::default();
473 /// assert!(device.is_cpu());
474 /// assert_eq!(device, Device::cpu());
475 /// ```
476 fn default() -> Self {
477 Device::cpu()
478 }
479}
480
481impl fmt::Display for Device {
482 /// Format the device as a string
483 ///
484 /// # Returns
485 ///
486 /// String representation of the device:
487 /// - `"cpu"` for CPU devices
488 /// - `"cuda:{index}"` for CUDA devices
489 ///
490 /// # Examples
491 ///
492 /// ```rust
493 /// use train_station::Device;
494 ///
495 /// let cpu = Device::cpu();
496 /// assert_eq!(cpu.to_string(), "cpu");
497 ///
498 /// #[cfg(feature = "cuda")]
499 /// {
500 /// if train_station::cuda_is_available() {
501 /// let cuda = Device::cuda(0);
502 /// assert_eq!(cuda.to_string(), "cuda:0");
503 /// }
504 /// }
505 /// ```
506 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507 match self.device_type {
508 DeviceType::Cpu => write!(f, "cpu"),
509 DeviceType::Cuda => write!(f, "cuda:{}", self.index),
510 }
511 }
512}
513
514impl From<DeviceType> for Device {
515 /// Convert DeviceType to Device with index 0
516 ///
517 /// # Arguments
518 ///
519 /// * `device_type` - The device type to convert
520 ///
521 /// # Returns
522 ///
523 /// A Device with the specified type and index 0
524 ///
525 /// # Panics
526 ///
527 /// Panics if `device_type` is `DeviceType::Cuda` and CUDA is not available
528 /// or the feature is not enabled.
529 ///
530 /// # Examples
531 ///
532 /// ```rust
533 /// use train_station::{Device, DeviceType};
534 ///
535 /// let cpu_type = DeviceType::Cpu;
536 /// let cpu_device = Device::from(cpu_type);
537 /// assert!(cpu_device.is_cpu());
538 /// assert_eq!(cpu_device.index(), 0);
539 /// ```
540 fn from(device_type: DeviceType) -> Self {
541 match device_type {
542 DeviceType::Cpu => Device::cpu(),
543 DeviceType::Cuda => {
544 // Call Device::cuda(0) which handles the proper feature flag checking
545 Device::cuda(0)
546 }
547 }
548 }
549}
550
551// ================================================================================================
552// Device Context Management
553// ================================================================================================
554
thread_local! {
    /// Thread-local storage for the device context stack
    ///
    /// Each thread owns an isolated stack whose top entry is that thread's
    /// current device; it is seeded with the CPU device.
    ///
    /// NOTE(review): `set_current_device` overwrites the top entry rather than
    /// pushing, so in practice this Vec never grows past its initial single
    /// element; nesting is realized by `DeviceContext` saving and restoring
    /// the previous device. Confirm before relying on actual stack depth.
    ///
    /// # Thread Safety
    ///
    /// Thread-local storage: no synchronization is required for access within
    /// a single thread, and contexts are fully isolated between threads.
    static DEVICE_STACK: RefCell<Vec<Device>> = RefCell::new(vec![Device::cpu()]);
}
569
/// Global default device, encoded as a numeric ID (starts as CPU)
///
/// Stores the device used when creating new tensors without an explicit
/// device specification. The encoding matches `device_to_id`/`id_to_device`:
/// 0 means CPU and `1000 + index` means CUDA device `index`.
///
/// # Thread Safety
///
/// Read and written with atomic operations (`Ordering::Relaxed`), so multiple
/// threads can access the default device concurrently without data races.
static GLOBAL_DEFAULT_DEVICE: AtomicUsize = AtomicUsize::new(0); // 0 = CPU
581
/// Device context guard for RAII-style device switching
///
/// Creating a guard switches the current thread to a device; dropping the
/// guard restores the device that was active before, similar to PyTorch's
/// device context manager. Restoration happens in `Drop`, so it runs even if
/// the guarded code panics and unwinds.
///
/// # Thread Safety
///
/// Not thread-safe: a guard manipulates only the creating thread's
/// thread-local context and should not be shared between threads.
///
/// # Examples
///
/// ```rust
/// use train_station::{Device, with_device, current_device};
///
/// let original_device = current_device();
///
/// // Use with_device instead of DeviceContext::new for public API
/// with_device(Device::cpu(), || {
///     assert_eq!(current_device(), Device::cpu());
///     // Device context is automatically restored when closure exits
/// });
///
/// assert_eq!(current_device(), original_device);
/// ```
pub struct DeviceContext {
    // Device that was current when this guard was created; restored on drop.
    previous_device: Device,
}
612
613impl DeviceContext {
614 /// Create a new device context guard
615 ///
616 /// This function switches to the specified device and creates a guard
617 /// that will automatically restore the previous device when dropped.
618 ///
619 /// # Arguments
620 ///
621 /// * `device` - The device to switch to
622 ///
623 /// # Returns
624 ///
625 /// A `DeviceContext` guard that will restore the previous device when dropped
626 ///
627 /// # Side Effects
628 ///
629 /// Changes the current thread's device context to the specified device.
630 fn new(device: Device) -> Self {
631 let previous_device = current_device();
632 set_current_device(device);
633
634 DeviceContext { previous_device }
635 }
636}
637
638impl Drop for DeviceContext {
639 /// Restore the previous device context when the guard is dropped
640 ///
641 /// This ensures that device contexts are properly cleaned up even if
642 /// exceptions occur or the guard is dropped early.
643 fn drop(&mut self) {
644 set_current_device(self.previous_device);
645 }
646}
647
648/// Set the global default device
649///
650/// This affects the default device for new tensors created without an explicit device.
651/// It does not affect the current thread's device context.
652///
653/// # Arguments
654///
655/// * `device` - The device to set as the global default
656///
657/// # Thread Safety
658///
659/// This function is thread-safe and uses atomic operations to update the global default.
660///
661/// # Examples
662///
663/// ```rust
664/// use train_station::{Device, set_default_device, get_default_device};
665///
666/// // Set global default to CPU
667/// set_default_device(Device::cpu());
668/// assert_eq!(get_default_device(), Device::cpu());
669///
670/// // The global default affects new tensor creation
671/// // (tensor creation would use this default device)
672/// ```
673pub fn set_default_device(device: Device) {
674 let device_id = device_to_id(device);
675 GLOBAL_DEFAULT_DEVICE.store(device_id, Ordering::Relaxed);
676}
677
678/// Get the global default device
679///
680/// # Returns
681///
682/// The current global default device
683///
684/// # Thread Safety
685///
686/// This function is thread-safe and uses atomic operations to read the global default.
687///
688/// # Examples
689///
690/// ```rust
691/// use train_station::{Device, get_default_device, set_default_device};
692///
693/// let initial_default = get_default_device();
694/// assert!(initial_default.is_cpu());
695///
696/// set_default_device(Device::cpu());
697/// assert_eq!(get_default_device(), Device::cpu());
698/// ```
699pub fn get_default_device() -> Device {
700 let device_id = GLOBAL_DEFAULT_DEVICE.load(Ordering::Relaxed);
701 id_to_device(device_id)
702}
703
704/// Get the current thread's device context
705///
706/// # Returns
707///
708/// The current device context for this thread
709///
710/// # Thread Safety
711///
712/// This function is thread-safe and returns the device context for the current thread only.
713///
714/// # Examples
715///
716/// ```rust
717/// use train_station::{Device, current_device, with_device};
718///
719/// let initial_device = current_device();
720/// assert!(initial_device.is_cpu());
721///
722/// with_device(Device::cpu(), || {
723/// assert_eq!(current_device(), Device::cpu());
724/// });
725///
726/// assert_eq!(current_device(), initial_device);
727/// ```
728pub fn current_device() -> Device {
729 DEVICE_STACK.with(|stack| stack.borrow().last().copied().unwrap_or_else(Device::cpu))
730}
731
732/// Set the current thread's device context
733///
734/// This function updates the current thread's device context. It modifies the
735/// top of the thread-local device stack.
736///
737/// # Arguments
738///
739/// * `device` - The device to set as the current context
740///
741/// # Thread Safety
742///
743/// This function is thread-safe and only affects the current thread's context.
744///
745/// # Side Effects
746///
747/// Changes the current thread's device context to the specified device.
748fn set_current_device(device: Device) {
749 DEVICE_STACK.with(|stack| {
750 let mut stack = stack.borrow_mut();
751 if stack.is_empty() {
752 stack.push(device);
753 } else {
754 // Replace the top of the stack
755 if let Some(last) = stack.last_mut() {
756 *last = device;
757 }
758 }
759 });
760}
761
762/// Execute a closure with a specific device context
763///
764/// This function temporarily switches to the specified device for the duration
765/// of the closure, then automatically restores the previous device. This is
766/// the recommended way to execute code with a specific device context.
767///
768/// # Arguments
769///
770/// * `device` - The device to use for the closure
771/// * `f` - The closure to execute
772///
773/// # Returns
774///
775/// The result of the closure
776///
777/// # Thread Safety
778///
779/// This function is thread-safe and only affects the current thread's context.
780///
781/// # Examples
782///
783/// ```rust
784/// use train_station::{Device, with_device, current_device};
785///
786/// let original_device = current_device();
787///
788/// let result = with_device(Device::cpu(), || {
789/// assert_eq!(current_device(), Device::cpu());
790/// // Perform operations with CPU device
791/// 42
792/// });
793///
794/// assert_eq!(result, 42);
795/// assert_eq!(current_device(), original_device);
796/// ```
797pub fn with_device<F, R>(device: Device, f: F) -> R
798where
799 F: FnOnce() -> R,
800{
801 let _context = DeviceContext::new(device);
802 f()
803}
804
805// Helper functions for device ID conversion
806/// Convert a device to a numeric ID for storage
807///
808/// # Arguments
809///
810/// * `device` - The device to convert
811///
812/// # Returns
813///
814/// A numeric ID representing the device:
815/// - 0 for CPU devices
816/// - 1000 + index for CUDA devices
817///
818/// # Thread Safety
819///
820/// This function is thread-safe and has no side effects.
821fn device_to_id(device: Device) -> usize {
822 match device.device_type {
823 DeviceType::Cpu => 0,
824 DeviceType::Cuda => 1000 + device.index, // Offset CUDA devices by 1000
825 }
826}
827
828/// Convert a numeric ID back to a device
829///
830/// # Arguments
831///
832/// * `id` - The numeric ID to convert
833///
834/// # Returns
835///
836/// A device representing the ID:
837/// - ID 0 → CPU device
838/// - ID >= 1000 → CUDA device with index (ID - 1000)
839/// - Invalid IDs → CPU device (fallback)
840///
841/// # Thread Safety
842///
843/// This function is thread-safe and has no side effects.
844fn id_to_device(id: usize) -> Device {
845 if id == 0 {
846 Device::cpu()
847 } else if id >= 1000 {
848 Device::cuda(id - 1000)
849 } else {
850 Device::cpu() // Fallback to CPU for invalid IDs
851 }
852}
853
854// ================================================================================================
855// CUDA Availability Functions (Direct delegation to cuda_ffi)
856// ================================================================================================
857
/// Check if CUDA is available
///
/// Reports whether CUDA can be used: the crate must have been built with the
/// `cuda` feature, and the CUDA backend must report availability (presumably
/// requiring at least one usable device — see `crate::cuda::cuda_is_available`).
///
/// # Returns
///
/// - `true` if the `cuda` feature is enabled and the CUDA backend reports availability
/// - `false` if the feature is disabled or CUDA is unavailable
///
/// # Thread Safety
///
/// Thread-safe; may be called concurrently from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::cuda_is_available;
///
/// if cuda_is_available() {
///     println!("CUDA is available");
/// } else {
///     println!("CUDA is not available, using CPU only");
/// }
/// ```
pub fn cuda_is_available() -> bool {
    #[cfg(feature = "cuda")]
    return crate::cuda::cuda_is_available();

    #[cfg(not(feature = "cuda"))]
    false
}
897
/// Get the number of CUDA devices available
///
/// # Returns
///
/// Number of CUDA devices:
/// - 0 when the `cuda` feature is disabled
/// - 0 when CUDA is not available on the system
/// - The count reported by the CUDA backend otherwise
///
/// # Thread Safety
///
/// Thread-safe; may be called concurrently from multiple threads.
///
/// # Examples
///
/// ```rust
/// use train_station::{cuda_device_count, Device};
///
/// let device_count = cuda_device_count();
/// println!("Found {} CUDA devices", device_count);
///
/// for i in 0..device_count {
///     let device = Device::cuda(i);
///     println!("CUDA device {}: {}", i, device);
/// }
/// ```
#[allow(unused)]
pub fn cuda_device_count() -> usize {
    #[cfg(feature = "cuda")]
    return crate::cuda::cuda_device_count() as usize;

    #[cfg(not(feature = "cuda"))]
    0
}
939
940// ================================================================================================
941// Tests
942// ================================================================================================
943
#[cfg(test)]
mod tests {
    use super::*;

    // --- Device construction and accessors ---

    #[test]
    fn test_cpu_device() {
        let device = Device::cpu();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert_eq!(device.index(), 0);
        assert!(device.is_cpu());
        assert!(!device.is_cuda());
        assert_eq!(device.to_string(), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device.device_type(), DeviceType::Cpu);
        assert!(device.is_cpu());
    }

    #[test]
    fn test_device_type_display() {
        assert_eq!(DeviceType::Cpu.to_string(), "cpu");
        assert_eq!(DeviceType::Cuda.to_string(), "cuda");
    }

    #[test]
    fn test_device_from_device_type() {
        let device = Device::from(DeviceType::Cpu);
        assert!(device.is_cpu());
        assert_eq!(device.index(), 0);
    }

    // These two tests only make sense when the `cuda` feature is disabled;
    // with the feature on, Device::cuda panics with a different message (or
    // succeeds), so the expected-substring match would fail.

    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_cuda_device_panics() {
        Device::cuda(0);
    }

    #[test]
    #[should_panic(expected = "CUDA support not enabled. Enable with --features cuda")]
    fn test_device_from_cuda_type_panics() {
        let _ = Device::from(DeviceType::Cuda);
    }

    #[test]
    fn test_device_equality() {
        let cpu1 = Device::cpu();
        let cpu2 = Device::cpu();
        assert_eq!(cpu1, cpu2);
    }

    // --- Context management tests ---

    #[test]
    fn test_current_device() {
        // Each thread's stack is seeded with the CPU device.
        assert_eq!(current_device(), Device::cpu());
    }

    #[test]
    fn test_default_device() {
        let initial_default = get_default_device();
        assert_eq!(initial_default, Device::cpu());

        // Should still be CPU after setting it explicitly
        set_default_device(Device::cpu());
        assert_eq!(get_default_device(), Device::cpu());
    }

    #[test]
    fn test_device_context_guard() {
        let original_device = current_device();

        {
            let _guard = DeviceContext::new(Device::cpu());
            assert_eq!(current_device(), Device::cpu());
        }

        // Device should be restored after guard is dropped
        assert_eq!(current_device(), original_device);
    }

    #[test]
    fn test_with_device() {
        let original_device = current_device();

        let result = with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());
            42
        });

        assert_eq!(result, 42);
        assert_eq!(current_device(), original_device);
    }

    #[test]
    fn test_nested_device_contexts() {
        let original = current_device();

        with_device(Device::cpu(), || {
            assert_eq!(current_device(), Device::cpu());

            with_device(Device::cpu(), || {
                assert_eq!(current_device(), Device::cpu());
            });

            assert_eq!(current_device(), Device::cpu());
        });

        assert_eq!(current_device(), original);
    }

    // --- ID encoding helpers ---

    #[test]
    fn test_device_id_conversion() {
        assert_eq!(device_to_id(Device::cpu()), 0);
        assert_eq!(id_to_device(0), Device::cpu());

        // Test invalid ID fallback (IDs 1..999 are not valid encodings)
        assert_eq!(id_to_device(999), Device::cpu());
    }

    #[test]
    fn test_cuda_availability_check() {
        // These functions should be callable regardless of CUDA availability
        let available = cuda_is_available();
        let device_count = cuda_device_count();

        if available {
            assert!(device_count > 0, "CUDA available but no devices found");
        } else {
            assert_eq!(device_count, 0, "CUDA not available but devices reported");
        }
    }
}
1077}