candle_examples/
lib.rs

1pub mod audio;
2pub mod bs1770;
3pub mod chat_template;
4pub mod coco_classes;
5pub mod imagenet;
6pub mod token_output_stream;
7pub mod wav;
8use candle::utils::{cuda_is_available, metal_is_available};
9use candle::{Device, Result, Tensor};
10
11pub fn device(cpu: bool) -> Result<Device> {
12    if cpu {
13        Ok(Device::Cpu)
14    } else if cuda_is_available() {
15        Ok(Device::new_cuda(0)?)
16    } else if metal_is_available() {
17        Ok(Device::new_metal(0)?)
18    } else {
19        #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
20        {
21            println!(
22                "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
23            );
24        }
25        #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
26        {
27            println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
28        }
29        Ok(Device::Cpu)
30    }
31}
32
33pub fn load_image<P: AsRef<std::path::Path>>(
34    p: P,
35    resize_longest: Option<usize>,
36) -> Result<(Tensor, usize, usize)> {
37    let img = image::ImageReader::open(p)?
38        .decode()
39        .map_err(candle::Error::wrap)?;
40    let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
41    let img = match resize_longest {
42        None => img,
43        Some(resize_longest) => {
44            let (height, width) = (img.height(), img.width());
45            let resize_longest = resize_longest as u32;
46            let (height, width) = if height < width {
47                let h = (resize_longest * height) / width;
48                (h, resize_longest)
49            } else {
50                let w = (resize_longest * width) / height;
51                (resize_longest, w)
52            };
53            img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
54        }
55    };
56    let (height, width) = (img.height() as usize, img.width() as usize);
57    let img = img.to_rgb8();
58    let data = img.into_raw();
59    let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
60    Ok((data, initial_h, initial_w))
61}
62
63pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
64    p: P,
65    width: usize,
66    height: usize,
67) -> Result<Tensor> {
68    let img = image::ImageReader::open(p)?
69        .decode()
70        .map_err(candle::Error::wrap)?
71        .resize_to_fill(
72            width as u32,
73            height as u32,
74            image::imageops::FilterType::Triangle,
75        );
76    let img = img.to_rgb8();
77    let data = img.into_raw();
78    Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
79}
80
81/// Saves an image to disk using the image crate, this expects an input with shape
82/// (c, height, width).
83pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
84    let p = p.as_ref();
85    let (channel, height, width) = img.dims3()?;
86    if channel != 3 {
87        candle::bail!("save_image expects an input of shape (3, height, width)")
88    }
89    let img = img.permute((1, 2, 0))?.flatten_all()?;
90    let pixels = img.to_vec1::<u8>()?;
91    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
92        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
93            Some(image) => image,
94            None => candle::bail!("error saving image {p:?}"),
95        };
96    image.save(p).map_err(candle::Error::wrap)?;
97    Ok(())
98}
99
100pub fn save_image_resize<P: AsRef<std::path::Path>>(
101    img: &Tensor,
102    p: P,
103    h: usize,
104    w: usize,
105) -> Result<()> {
106    let p = p.as_ref();
107    let (channel, height, width) = img.dims3()?;
108    if channel != 3 {
109        candle::bail!("save_image expects an input of shape (3, height, width)")
110    }
111    let img = img.permute((1, 2, 0))?.flatten_all()?;
112    let pixels = img.to_vec1::<u8>()?;
113    let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
114        match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
115            Some(image) => image,
116            None => candle::bail!("error saving image {p:?}"),
117        };
118    let image = image::DynamicImage::from(image);
119    let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
120    image.save(p).map_err(candle::Error::wrap)?;
121    Ok(())
122}
123
124/// Loads the safetensors files for a model from the hub based on a json index file.
125pub fn hub_load_safetensors(
126    repo: &hf_hub::api::sync::ApiRepo,
127    json_file: &str,
128) -> Result<Vec<std::path::PathBuf>> {
129    let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
130    let json_file = std::fs::File::open(json_file)?;
131    let json: serde_json::Value =
132        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
133    let weight_map = match json.get("weight_map") {
134        None => candle::bail!("no weight map in {json_file:?}"),
135        Some(serde_json::Value::Object(map)) => map,
136        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
137    };
138    let mut safetensors_files = std::collections::HashSet::new();
139    for value in weight_map.values() {
140        if let Some(file) = value.as_str() {
141            safetensors_files.insert(file.to_string());
142        }
143    }
144    let safetensors_files = safetensors_files
145        .iter()
146        .map(|v| repo.get(v).map_err(candle::Error::wrap))
147        .collect::<Result<Vec<_>>>()?;
148    Ok(safetensors_files)
149}
150
151pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
152    path: P,
153    json_file: &str,
154) -> Result<Vec<std::path::PathBuf>> {
155    let path = path.as_ref();
156    let jsfile = std::fs::File::open(path.join(json_file))?;
157    let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?;
158    let weight_map = match json.get("weight_map") {
159        None => candle::bail!("no weight map in {json_file:?}"),
160        Some(serde_json::Value::Object(map)) => map,
161        Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
162    };
163    let mut safetensors_files = std::collections::HashSet::new();
164    for value in weight_map.values() {
165        if let Some(file) = value.as_str() {
166            safetensors_files.insert(file);
167        }
168    }
169    let safetensors_files: Vec<_> = safetensors_files
170        .into_iter()
171        .map(|v| path.join(v))
172        .collect();
173    Ok(safetensors_files)
174}