Skip to main content

diffusers/
utils.rs

// A simple wrapper around File::open that adds details about the
// problematic file to the returned error.
3use std::path::Path;
4use tch::Device;
5
6pub(crate) fn file_open<P: AsRef<Path>>(path: P) -> anyhow::Result<std::fs::File> {
7    std::fs::File::open(path.as_ref()).map_err(|e| {
8        let context = format!("error opening {:?}", path.as_ref().to_string_lossy());
9        anyhow::Error::new(e).context(context)
10    })
11}
12
13pub struct DeviceSetup {
14    accelerator_device: Device,
15    cpu: Vec<String>,
16}
17
18impl DeviceSetup {
19    pub fn new(cpu: Vec<String>) -> Self {
20        let accelerator_device =
21            if tch::utils::has_mps() { Device::Mps } else { Device::cuda_if_available() };
22        Self { accelerator_device, cpu }
23    }
24
25    pub fn get(&self, name: &str) -> Device {
26        if self.cpu.iter().any(|c| c == "all" || c == name) {
27            Device::Cpu
28        } else {
29            self.accelerator_device
30        }
31    }
32}