content_extractor_rl/
device.rs1use candle_core::Device;
7use tracing::info;
8
9pub fn get_device() -> Device {
11 if std::env::var("CONTENT_EXTRACTOR_RL_FORCE_CPU").is_ok() {
13 info!("CONTENT_EXTRACTOR_RL_FORCE_CPU set, using CPU");
14 return Device::Cpu;
15 }
16
17 #[cfg(feature = "cuda")]
19 {
20 if candle_core::utils::cuda_is_available() {
21 match Device::new_cuda(0) {
22 Ok(device) => {
23 info!("Using CUDA device (GPU)");
24 info!("Training will use GPU acceleration");
25 return device;
26 }
27 Err(e) => {
28 tracing::warn!("CUDA available but failed to initialize: {}. Falling back to CPU", e);
29 }
30 }
31 } else {
32 info!("CUDA not available, using CPU");
33 }
34 }
35
36 #[cfg(not(feature = "cuda"))]
37 {
38 info!("Using CPU (built without CUDA support)");
39 }
40
41 Device::Cpu
42}
43
44pub fn get_device_with_preference(prefer_cpu: bool) -> Device {
46 if prefer_cpu {
47 info!("Using CPU (forced)");
48 return Device::Cpu;
49 }
50
51 get_device()
52}
53
/// Reports whether CUDA can be used.
///
/// Built with the `cuda` feature, this forwards to
/// `candle_core::utils::cuda_is_available()`. Built without it, the answer is
/// always `false`.
pub fn cuda_is_available() -> bool {
    // Exactly one of the two bindings survives conditional compilation.
    #[cfg(feature = "cuda")]
    let available = candle_core::utils::cuda_is_available();

    #[cfg(not(feature = "cuda"))]
    let available = false;

    available
}
66
67pub fn get_device_info(device: &Device) -> String {
69 match device {
70 Device::Cpu => "CPU".to_string(),
71 Device::Cuda(_) => {
72 "CUDA GPU".to_string()
75 }
76 Device::Metal(_) => "Metal GPU".to_string(),
77 }
78}
79
80pub fn get_device_info_string(device: &Device) -> String {
82 let build_info = if cfg!(feature = "cuda") {
83 "CUDA support enabled"
84 } else {
85 "CUDA support disabled"
86 };
87
88 let runtime_info = match device {
89 Device::Cuda(_) => "CUDA GPU",
90 Device::Cpu => "CPU",
91 _ => "Other device",
92 };
93
94 let status = match device {
95 Device::Cuda(_) => "GPU acceleration active",
96 Device::Cpu => "Running on CPU",
97 _ => "Unknown device",
98 };
99
100 format!(
101 "\n\
102 ╔════════════════════════════════════════╗\n\
103 ║ Content Extractor RL - Device Info ║\n\
104 ╠════════════════════════════════════════╣\n\
105 ║ Build: {:<31} ║\n\
106 ║ Runtime: {:<29} ║\n\
107 ║ Status: {:<30} ║\n\
108 ╚════════════════════════════════════════╝",
109 build_info, runtime_info, status
110 )
111}
112
113pub fn print_device_info() {
114 let device = get_device();
115 println!("{}", get_device_info_string(&device));
116}
117
118pub fn log_device_info() {
120 let device = get_device();
121 tracing::info!("{}", get_device_info_string(&device));
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable that forces CPU selection in `get_device`.
    const FORCE_CPU_VAR: &str = "CONTENT_EXTRACTOR_RL_FORCE_CPU";

    /// Smoke test: device selection must succeed and yield a non-empty label,
    /// whatever hardware the test machine has.
    #[test]
    fn test_device_selection() {
        let device = get_device();
        println!("Selected device: {:?}", device);

        let info = get_device_info(&device);
        println!("Device info: {}", info);
        assert!(!info.is_empty());

        println!("CUDA available: {}", cuda_is_available());
    }

    /// Setting the override variable must force the CPU.
    #[test]
    fn test_force_cpu() {
        std::env::set_var(FORCE_CPU_VAR, "1");
        let device = get_device();
        // Clean up before asserting so a failing assertion cannot leave the
        // override set for other tests running in this process.
        std::env::remove_var(FORCE_CPU_VAR);

        assert!(matches!(device, Device::Cpu));
    }

    /// An explicit CPU preference bypasses device detection entirely.
    #[test]
    fn test_prefer_cpu() {
        let device = get_device_with_preference(true);
        assert!(matches!(device, Device::Cpu));
    }

    #[test]
    fn test_device_info_cpu() {
        let device = Device::Cpu;
        let info = get_device_info(&device);
        assert_eq!(info, "CPU");
    }

    /// The banner for a CPU device reports the CPU status line.
    #[test]
    fn test_device_info_string_cpu() {
        let banner = get_device_info_string(&Device::Cpu);
        assert!(banner.contains("Running on CPU"));
    }
}