Skip to main content

spark_bert/
util.rs

1use candle_core::utils::{cuda_is_available, metal_is_available};
2use candle_core::{Device, Result};
3use indicatif::{ProgressBar, ProgressStyle};
4
5pub fn device(cpu: bool) -> Result<Device> {
6    if cpu {
7        Ok(Device::Cpu)
8    } else if cuda_is_available() {
9        Ok(Device::new_cuda(0)?)
10    } else if metal_is_available() {
11        Ok(Device::new_metal(0)?)
12    } else {
13        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
14        {
15            println!(
16                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
17            );
18        }
19        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
20        {
21            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
22        }
23        Ok(Device::Cpu)
24    }
25}
26
27pub fn get_progress_bar(len: u64) -> anyhow::Result<ProgressBar> {
28    let pb = ProgressBar::new(len);
29    pb.set_style(
30        ProgressStyle::default_bar()
31            .template(
32                "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {pos}/{len} ({eta})",
33            )
34            .unwrap()
35            .progress_chars("#>-"),
36    );
37    anyhow::Ok(pb)
38}