reflex/embedding/
device.rs

1use candle_core::Device;
2use tracing::warn;
3
4#[cfg(any(feature = "metal", feature = "cuda"))]
5use tracing::info;
6
7#[cfg(not(any(feature = "metal", feature = "cuda")))]
8use tracing::debug;
9
10use super::error::EmbeddingError;
11
12/// Selects the compute device based on enabled features (falls back to CPU).
13pub fn select_device() -> Result<Device, EmbeddingError> {
14    #[cfg(any(feature = "metal", feature = "cuda"))]
15    let mut failures: Vec<String> = Vec::new();
16
17    #[cfg(not(any(feature = "metal", feature = "cuda")))]
18    let failures: Vec<String> = Vec::new();
19
20    #[cfg(feature = "metal")]
21    {
22        match Device::new_metal(0) {
23            Ok(device) => {
24                info!("Using Metal GPU acceleration");
25                return Ok(device);
26            }
27            Err(e) => {
28                let msg = e.to_string();
29                if cfg!(feature = "cuda") {
30                    warn!(error = %msg, "Metal device unavailable, trying CUDA");
31                } else {
32                    warn!(error = %msg, "Metal device unavailable");
33                }
34                failures.push(format!("metal failed: {msg}"));
35            }
36        }
37    }
38
39    #[cfg(feature = "cuda")]
40    {
41        match Device::new_cuda(0) {
42            Ok(device) => {
43                info!("Using CUDA GPU acceleration");
44                return Ok(device);
45            }
46            Err(e) => {
47                let msg = e.to_string();
48                warn!(error = %msg, "CUDA device unavailable");
49                failures.push(format!("cuda failed: {msg}"));
50            }
51        }
52    }
53
54    #[cfg(not(any(feature = "metal", feature = "cuda")))]
55    {
56        debug!("No GPU features enabled");
57    }
58
59    let reason = if !cfg!(any(feature = "metal", feature = "cuda")) {
60        "no GPU backend compiled".to_string()
61    } else if failures.is_empty() {
62        "no GPU device available".to_string()
63    } else {
64        failures.join("; ")
65    };
66
67    warn!(reason = %reason, "Falling back to CPU device");
68    Ok(Device::Cpu)
69}