1pub mod audio;
2pub mod bs1770;
3pub mod coco_classes;
4pub mod imagenet;
5pub mod token_output_stream;
6pub mod wav;
7use candle::utils::{cuda_is_available, metal_is_available};
8use candle::{Device, Result, Tensor};
9
10pub fn device(cpu: bool) -> Result<Device> {
11 if cpu {
12 Ok(Device::Cpu)
13 } else if cuda_is_available() {
14 Ok(Device::new_cuda(0)?)
15 } else if metal_is_available() {
16 Ok(Device::new_metal(0)?)
17 } else {
18 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
19 {
20 println!(
21 "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
22 );
23 }
24 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
25 {
26 println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
27 }
28 Ok(Device::Cpu)
29 }
30}
31
32pub fn load_image<P: AsRef<std::path::Path>>(
33 p: P,
34 resize_longest: Option<usize>,
35) -> Result<(Tensor, usize, usize)> {
36 let img = image::ImageReader::open(p)?
37 .decode()
38 .map_err(candle::Error::wrap)?;
39 let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
40 let img = match resize_longest {
41 None => img,
42 Some(resize_longest) => {
43 let (height, width) = (img.height(), img.width());
44 let resize_longest = resize_longest as u32;
45 let (height, width) = if height < width {
46 let h = (resize_longest * height) / width;
47 (h, resize_longest)
48 } else {
49 let w = (resize_longest * width) / height;
50 (resize_longest, w)
51 };
52 img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
53 }
54 };
55 let (height, width) = (img.height() as usize, img.width() as usize);
56 let img = img.to_rgb8();
57 let data = img.into_raw();
58 let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
59 Ok((data, initial_h, initial_w))
60}
61
62pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
63 p: P,
64 width: usize,
65 height: usize,
66) -> Result<Tensor> {
67 let img = image::ImageReader::open(p)?
68 .decode()
69 .map_err(candle::Error::wrap)?
70 .resize_to_fill(
71 width as u32,
72 height as u32,
73 image::imageops::FilterType::Triangle,
74 );
75 let img = img.to_rgb8();
76 let data = img.into_raw();
77 Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
78}
79
80pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
83 let p = p.as_ref();
84 let (channel, height, width) = img.dims3()?;
85 if channel != 3 {
86 candle::bail!("save_image expects an input of shape (3, height, width)")
87 }
88 let img = img.permute((1, 2, 0))?.flatten_all()?;
89 let pixels = img.to_vec1::<u8>()?;
90 let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
91 match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
92 Some(image) => image,
93 None => candle::bail!("error saving image {p:?}"),
94 };
95 image.save(p).map_err(candle::Error::wrap)?;
96 Ok(())
97}
98
99pub fn save_image_resize<P: AsRef<std::path::Path>>(
100 img: &Tensor,
101 p: P,
102 h: usize,
103 w: usize,
104) -> Result<()> {
105 let p = p.as_ref();
106 let (channel, height, width) = img.dims3()?;
107 if channel != 3 {
108 candle::bail!("save_image expects an input of shape (3, height, width)")
109 }
110 let img = img.permute((1, 2, 0))?.flatten_all()?;
111 let pixels = img.to_vec1::<u8>()?;
112 let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
113 match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
114 Some(image) => image,
115 None => candle::bail!("error saving image {p:?}"),
116 };
117 let image = image::DynamicImage::from(image);
118 let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
119 image.save(p).map_err(candle::Error::wrap)?;
120 Ok(())
121}
122
123pub fn hub_load_safetensors(
125 repo: &hf_hub::api::sync::ApiRepo,
126 json_file: &str,
127) -> Result<Vec<std::path::PathBuf>> {
128 let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
129 let json_file = std::fs::File::open(json_file)?;
130 let json: serde_json::Value =
131 serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
132 let weight_map = match json.get("weight_map") {
133 None => candle::bail!("no weight map in {json_file:?}"),
134 Some(serde_json::Value::Object(map)) => map,
135 Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
136 };
137 let mut safetensors_files = std::collections::HashSet::new();
138 for value in weight_map.values() {
139 if let Some(file) = value.as_str() {
140 safetensors_files.insert(file.to_string());
141 }
142 }
143 let safetensors_files = safetensors_files
144 .iter()
145 .map(|v| repo.get(v).map_err(candle::Error::wrap))
146 .collect::<Result<Vec<_>>>()?;
147 Ok(safetensors_files)
148}
149
150pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
151 path: P,
152 json_file: &str,
153) -> Result<Vec<std::path::PathBuf>> {
154 let path = path.as_ref();
155 let jsfile = std::fs::File::open(path.join(json_file))?;
156 let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?;
157 let weight_map = match json.get("weight_map") {
158 None => candle::bail!("no weight map in {json_file:?}"),
159 Some(serde_json::Value::Object(map)) => map,
160 Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
161 };
162 let mut safetensors_files = std::collections::HashSet::new();
163 for value in weight_map.values() {
164 if let Some(file) = value.as_str() {
165 safetensors_files.insert(file);
166 }
167 }
168 let safetensors_files: Vec<_> = safetensors_files
169 .into_iter()
170 .map(|v| path.join(v))
171 .collect();
172 Ok(safetensors_files)
173}