Skip to main content

content_extractor_rl/
device.rs

1//! Device selection for CPU/CUDA
2// ============================================================================
3// FILE: crates/content-extractor-rl/src/device.rs
4// ============================================================================
5
6use candle_core::Device;
7use tracing::info;
8
9/// Get best available device (CUDA if available, otherwise CPU)
10pub fn get_device() -> Device {
11    // Check environment variable for forcing CPU
12    if std::env::var("CONTENT_EXTRACTOR_RL_FORCE_CPU").is_ok() {
13        info!("CONTENT_EXTRACTOR_RL_FORCE_CPU set, using CPU");
14        return Device::Cpu;
15    }
16
17    // Try to use CUDA if available
18    #[cfg(feature = "cuda")]
19    {
20        if candle_core::utils::cuda_is_available() {
21            match Device::new_cuda(0) {
22                Ok(device) => {
23                    info!("Using CUDA device (GPU)");
24                    info!("Training will use GPU acceleration");
25                    return device;
26                }
27                Err(e) => {
28                    tracing::warn!("CUDA available but failed to initialize: {}. Falling back to CPU", e);
29                }
30            }
31        } else {
32            info!("CUDA not available, using CPU");
33        }
34    }
35
36    #[cfg(not(feature = "cuda"))]
37    {
38        info!("Using CPU (built without CUDA support)");
39    }
40
41    Device::Cpu
42}
43
44/// Get device with preference (for testing/debugging)
45pub fn get_device_with_preference(prefer_cpu: bool) -> Device {
46    if prefer_cpu {
47        info!("Using CPU (forced)");
48        return Device::Cpu;
49    }
50
51    get_device()
52}
53
/// Report whether CUDA can be used.
///
/// Always `false` when the crate is built without the `cuda` feature;
/// otherwise the answer comes from `candle_core::utils::cuda_is_available`.
pub fn cuda_is_available() -> bool {
    // Exactly one of these two bindings survives conditional compilation.
    #[cfg(feature = "cuda")]
    let available = candle_core::utils::cuda_is_available();
    #[cfg(not(feature = "cuda"))]
    let available = false;

    available
}
66
67/// Get device info string
68pub fn get_device_info(device: &Device) -> String {
69    match device {
70        Device::Cpu => "CPU".to_string(),
71        Device::Cuda(_) => {
72            // CudaDevice in candle doesn't expose device ID directly
73            // We just indicate it's using CUDA
74            "CUDA GPU".to_string()
75        }
76        Device::Metal(_) => "Metal GPU".to_string(),
77    }
78}
79
80/// get device information
81pub fn get_device_info_string(device: &Device) -> String {
82    let build_info = if cfg!(feature = "cuda") {
83        "CUDA support enabled"
84    } else {
85        "CUDA support disabled"
86    };
87
88    let runtime_info = match device {
89        Device::Cuda(_) => "CUDA GPU",
90        Device::Cpu => "CPU",
91        _ => "Other device",
92    };
93
94    let status = match device {
95        Device::Cuda(_) => "GPU acceleration active",
96        Device::Cpu => "Running on CPU",
97        _ => "Unknown device",
98    };
99
100    format!(
101        "\n\
102         ╔════════════════════════════════════════╗\n\
103         ║   Content Extractor RL - Device Info      ║\n\
104         ╠════════════════════════════════════════╣\n\
105         ║ Build: {:<31} ║\n\
106         ║ Runtime: {:<29} ║\n\
107         ║ Status: {:<30} ║\n\
108         ╚════════════════════════════════════════╝",
109        build_info, runtime_info, status
110    )
111}
112
113pub fn print_device_info() {
114    let device = get_device();
115    println!("{}", get_device_info_string(&device));
116}
117
118// Add new function for logging only
119pub fn log_device_info() {
120    let device = get_device();
121    tracing::info!("{}", get_device_info_string(&device));
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable that makes `get_device` return the CPU.
    const FORCE_CPU_VAR: &str = "CONTENT_EXTRACTOR_RL_FORCE_CPU";

    /// Smoke test: device selection and the info helpers run without panicking.
    /// The selected device depends on the machine, so nothing is asserted.
    #[test]
    fn test_device_selection() {
        let device = get_device();
        println!("Selected device: {:?}", device);

        let info = get_device_info(&device);
        println!("Device info: {}", info);

        println!("CUDA available: {}", cuda_is_available());
    }

    /// Setting the override variable must yield the CPU device.
    #[test]
    fn test_force_cpu() {
        // The environment is process-wide, so remember what was there and put
        // it back BEFORE asserting: a failing assertion must not leave the
        // override set for the other tests in this process.
        let previous = std::env::var_os(FORCE_CPU_VAR);

        std::env::set_var(FORCE_CPU_VAR, "1");
        let device = get_device();

        match previous {
            Some(value) => std::env::set_var(FORCE_CPU_VAR, value),
            None => std::env::remove_var(FORCE_CPU_VAR),
        }

        assert!(matches!(device, Device::Cpu));
    }

    /// The CPU device is labelled exactly "CPU".
    #[test]
    fn test_device_info_cpu() {
        assert_eq!(get_device_info(&Device::Cpu), "CPU");
    }
}
153}