candle_examples/
lib.rs

1pub mod audio;
2pub mod bs1770;
3pub mod coco_classes;
4pub mod imagenet;
5pub mod token_output_stream;
6pub mod wav;
7use candle::utils::{cuda_is_available, metal_is_available};
8use candle::{Device, Result, Tensor};
9
10pub fn device(cpu: bool) -> Result<Device> {
11    if cpu {
12        Ok(Device::Cpu)
13    } else if cuda_is_available() {
14        Ok(Device::new_cuda(0)?)
15    } else if metal_is_available() {
16        Ok(Device::new_metal(0)?)
17    } else {
18        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
19        {
20            println!(
21                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
22            );
23        }
24        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
25        {
26            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
27        }
28        Ok(Device::Cpu)
29    }
30}
31
32pub fn load_image<P: AsRef<std::path::Path>>(
33    p: P,
34    resize_longest: Option<usize>,
35) -> Result<(Tensor, usize, usize)> {
36    let img = image::ImageReader::open(p)?
37        .decode()
38        .map_err(candle::Error::wrap)?;
39    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
40    let img = match resize_longest {
41        None => img,
42        Some(resize_longest) => {
43            let (height, width) = (img.height(), img.width());
44            let resize_longest = resize_longest as u32;
45            let (height, width) = if height < width {
46                let h = (resize_longest * height) / width;
47                (h, resize_longest)
48            } else {
49                let w = (resize_longest * width) / height;
50                (resize_longest, w)
51            };
52            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
53        }
54    };
55    let (height, width) = (img.height() as usize, img.width() as usize);
56    let img = img.to_rgb8();
57    let data = img.into_raw();
58    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
59    Ok((data, initial_h, initial_w))
60}
61
62pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
63    p: P,
64    width: usize,
65    height: usize,
66) -> Result<Tensor> {
67    let img = image::ImageReader::open(p)?
68        .decode()
69        .map_err(candle::Error::wrap)?
70        .resize_to_fill(
71            width as u32,
72            height as u32,
73            image::imageops::FilterType::Triangle,
74        );
75    let img = img.to_rgb8();
76    let data = img.into_raw();
77    Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
78}
79
80/// Saves an image to disk using the image crate, this expects an input with shape
81/// (c, height, width).
82pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
83    let p = p.as_ref();
84    let (channel, height, width) = img.dims3()?;
85    if channel != 3 {
86        candle::bail!("save_image expects an input of shape (3, height, width)")
87    }
88    let img = img.permute((1, 2, 0))?.flatten_all()?;
89    let pixels = img.to_vec1::<u8>()?;
90    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
91        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
92            Some(image) => image,
93            None => candle::bail!("error saving image {p:?}"),
94        };
95    image.save(p).map_err(candle::Error::wrap)?;
96    Ok(())
97}
98
99pub fn save_image_resize<P: AsRef<std::path::Path>>(
100    img: &Tensor,
101    p: P,
102    h: usize,
103    w: usize,
104) -> Result<()> {
105    let p = p.as_ref();
106    let (channel, height, width) = img.dims3()?;
107    if channel != 3 {
108        candle::bail!("save_image expects an input of shape (3, height, width)")
109    }
110    let img = img.permute((1, 2, 0))?.flatten_all()?;
111    let pixels = img.to_vec1::<u8>()?;
112    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
113        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
114            Some(image) => image,
115            None => candle::bail!("error saving image {p:?}"),
116        };
117    let image = image::DynamicImage::from(image);
118    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
119    image.save(p).map_err(candle::Error::wrap)?;
120    Ok(())
121}
122
123/// Loads the safetensors files for a model from the hub based on a json index file.
124pub fn hub_load_safetensors(
125    repo: &hf_hub::api::sync::ApiRepo,
126    json_file: &str,
127) -> Result<Vec<std::path::PathBuf>> {
128    let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
129    let json_file = std::fs::File::open(json_file)?;
130    let json: serde_json::Value =
131        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
132    let weight_map = match json.get("weight_map") {
133        None => candle::bail!("no weight map in {json_file:?}"),
134        Some(serde_json::Value::Object(map)) => map,
135        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
136    };
137    let mut safetensors_files = std::collections::HashSet::new();
138    for value in weight_map.values() {
139        if let Some(file) = value.as_str() {
140            safetensors_files.insert(file.to_string());
141        }
142    }
143    let safetensors_files = safetensors_files
144        .iter()
145        .map(|v| repo.get(v).map_err(candle::Error::wrap))
146        .collect::<Result<Vec<_>>>()?;
147    Ok(safetensors_files)
148}
149
150pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
151    path: P,
152    json_file: &str,
153) -> Result<Vec<std::path::PathBuf>> {
154    let path = path.as_ref();
155    let jsfile = std::fs::File::open(path.join(json_file))?;
156    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?;
157    let weight_map = match json.get("weight_map") {
158        None => candle::bail!("no weight map in {json_file:?}"),
159        Some(serde_json::Value::Object(map)) => map,
160        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
161    };
162    let mut safetensors_files = std::collections::HashSet::new();
163    for value in weight_map.values() {
164        if let Some(file) = value.as_str() {
165            safetensors_files.insert(file);
166        }
167    }
168    let safetensors_files: Vec<_> = safetensors_files
169        .into_iter()
170        .map(|v| path.join(v))
171        .collect();
172    Ok(safetensors_files)
173}