1use std::path::Path;
4use tch::Device;
5
6pub(crate) fn file_open<P: AsRef<Path>>(path: P) -> anyhow::Result<std::fs::File> {
7 std::fs::File::open(path.as_ref()).map_err(|e| {
8 let context = format!("error opening {:?}", path.as_ref().to_string_lossy());
9 anyhow::Error::new(e).context(context)
10 })
11}
12
13pub struct DeviceSetup {
14 accelerator_device: Device,
15 cpu: Vec<String>,
16}
17
18impl DeviceSetup {
19 pub fn new(cpu: Vec<String>) -> Self {
20 let accelerator_device =
21 if tch::utils::has_mps() { Device::Mps } else { Device::cuda_if_available() };
22 Self { accelerator_device, cpu }
23 }
24
25 pub fn get(&self, name: &str) -> Device {
26 if self.cpu.iter().any(|c| c == "all" || c == name) {
27 Device::Cpu
28 } else {
29 self.accelerator_device
30 }
31 }
32}