1pub mod audio;
2pub mod bs1770;
3pub mod chat_template;
4pub mod coco_classes;
5pub mod imagenet;
6pub mod token_output_stream;
7pub mod wav;
8use candle::utils::{cuda_is_available, metal_is_available};
9use candle::{Device, Result, Tensor};
10
11pub fn device(cpu: bool) -> Result<Device> {
12 if cpu {
13 Ok(Device::Cpu)
14 } else if cuda_is_available() {
15 Ok(Device::new_cuda(0)?)
16 } else if metal_is_available() {
17 Ok(Device::new_metal(0)?)
18 } else {
19 #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
20 {
21 println!(
22 "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
23 );
24 }
25 #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
26 {
27 println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
28 }
29 Ok(Device::Cpu)
30 }
31}
32
33pub fn load_image<P: AsRef<std::path::Path>>(
34 p: P,
35 resize_longest: Option<usize>,
36) -> Result<(Tensor, usize, usize)> {
37 let img = image::ImageReader::open(p)?
38 .decode()
39 .map_err(candle::Error::wrap)?;
40 let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
41 let img = match resize_longest {
42 None => img,
43 Some(resize_longest) => {
44 let (height, width) = (img.height(), img.width());
45 let resize_longest = resize_longest as u32;
46 let (height, width) = if height < width {
47 let h = (resize_longest * height) / width;
48 (h, resize_longest)
49 } else {
50 let w = (resize_longest * width) / height;
51 (resize_longest, w)
52 };
53 img.resize_exact(width, height, image::imageops::FilterType::CatmullRom)
54 }
55 };
56 let (height, width) = (img.height() as usize, img.width() as usize);
57 let img = img.to_rgb8();
58 let data = img.into_raw();
59 let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
60 Ok((data, initial_h, initial_w))
61}
62
63pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
64 p: P,
65 width: usize,
66 height: usize,
67) -> Result<Tensor> {
68 let img = image::ImageReader::open(p)?
69 .decode()
70 .map_err(candle::Error::wrap)?
71 .resize_to_fill(
72 width as u32,
73 height as u32,
74 image::imageops::FilterType::Triangle,
75 );
76 let img = img.to_rgb8();
77 let data = img.into_raw();
78 Tensor::from_vec(data, (width, height, 3), &Device::Cpu)?.permute((2, 0, 1))
79}
80
81pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
84 let p = p.as_ref();
85 let (channel, height, width) = img.dims3()?;
86 if channel != 3 {
87 candle::bail!("save_image expects an input of shape (3, height, width)")
88 }
89 let img = img.permute((1, 2, 0))?.flatten_all()?;
90 let pixels = img.to_vec1::<u8>()?;
91 let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
92 match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
93 Some(image) => image,
94 None => candle::bail!("error saving image {p:?}"),
95 };
96 image.save(p).map_err(candle::Error::wrap)?;
97 Ok(())
98}
99
100pub fn save_image_resize<P: AsRef<std::path::Path>>(
101 img: &Tensor,
102 p: P,
103 h: usize,
104 w: usize,
105) -> Result<()> {
106 let p = p.as_ref();
107 let (channel, height, width) = img.dims3()?;
108 if channel != 3 {
109 candle::bail!("save_image expects an input of shape (3, height, width)")
110 }
111 let img = img.permute((1, 2, 0))?.flatten_all()?;
112 let pixels = img.to_vec1::<u8>()?;
113 let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
114 match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
115 Some(image) => image,
116 None => candle::bail!("error saving image {p:?}"),
117 };
118 let image = image::DynamicImage::from(image);
119 let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
120 image.save(p).map_err(candle::Error::wrap)?;
121 Ok(())
122}
123
124pub fn hub_load_safetensors(
126 repo: &hf_hub::api::sync::ApiRepo,
127 json_file: &str,
128) -> Result<Vec<std::path::PathBuf>> {
129 let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
130 let json_file = std::fs::File::open(json_file)?;
131 let json: serde_json::Value =
132 serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
133 let weight_map = match json.get("weight_map") {
134 None => candle::bail!("no weight map in {json_file:?}"),
135 Some(serde_json::Value::Object(map)) => map,
136 Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
137 };
138 let mut safetensors_files = std::collections::HashSet::new();
139 for value in weight_map.values() {
140 if let Some(file) = value.as_str() {
141 safetensors_files.insert(file.to_string());
142 }
143 }
144 let safetensors_files = safetensors_files
145 .iter()
146 .map(|v| repo.get(v).map_err(candle::Error::wrap))
147 .collect::<Result<Vec<_>>>()?;
148 Ok(safetensors_files)
149}
150
151pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
152 path: P,
153 json_file: &str,
154) -> Result<Vec<std::path::PathBuf>> {
155 let path = path.as_ref();
156 let jsfile = std::fs::File::open(path.join(json_file))?;
157 let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?;
158 let weight_map = match json.get("weight_map") {
159 None => candle::bail!("no weight map in {json_file:?}"),
160 Some(serde_json::Value::Object(map)) => map,
161 Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
162 };
163 let mut safetensors_files = std::collections::HashSet::new();
164 for value in weight_map.values() {
165 if let Some(file) = value.as_str() {
166 safetensors_files.insert(file);
167 }
168 }
169 let safetensors_files: Vec<_> = safetensors_files
170 .into_iter()
171 .map(|v| path.join(v))
172 .collect();
173 Ok(safetensors_files)
174}