reflex/embedding/
device.rs1use candle_core::Device;
2use tracing::warn;
3
4#[cfg(any(feature = "metal", feature = "cuda"))]
5use tracing::info;
6
7#[cfg(not(any(feature = "metal", feature = "cuda")))]
8use tracing::debug;
9
10use super::error::EmbeddingError;
11
12pub fn select_device() -> Result<Device, EmbeddingError> {
14 #[cfg(any(feature = "metal", feature = "cuda"))]
15 let mut failures: Vec<String> = Vec::new();
16
17 #[cfg(not(any(feature = "metal", feature = "cuda")))]
18 let failures: Vec<String> = Vec::new();
19
20 #[cfg(feature = "metal")]
21 {
22 match Device::new_metal(0) {
23 Ok(device) => {
24 info!("Using Metal GPU acceleration");
25 return Ok(device);
26 }
27 Err(e) => {
28 let msg = e.to_string();
29 if cfg!(feature = "cuda") {
30 warn!(error = %msg, "Metal device unavailable, trying CUDA");
31 } else {
32 warn!(error = %msg, "Metal device unavailable");
33 }
34 failures.push(format!("metal failed: {msg}"));
35 }
36 }
37 }
38
39 #[cfg(feature = "cuda")]
40 {
41 match Device::new_cuda(0) {
42 Ok(device) => {
43 info!("Using CUDA GPU acceleration");
44 return Ok(device);
45 }
46 Err(e) => {
47 let msg = e.to_string();
48 warn!(error = %msg, "CUDA device unavailable");
49 failures.push(format!("cuda failed: {msg}"));
50 }
51 }
52 }
53
54 #[cfg(not(any(feature = "metal", feature = "cuda")))]
55 {
56 debug!("No GPU features enabled");
57 }
58
59 let reason = if !cfg!(any(feature = "metal", feature = "cuda")) {
60 "no GPU backend compiled".to_string()
61 } else if failures.is_empty() {
62 "no GPU device available".to_string()
63 } else {
64 failures.join("; ")
65 };
66
67 warn!(reason = %reason, "Falling back to CPU device");
68 Ok(Device::Cpu)
69}