1use std::fs::File;
5use std::io;
6use std::path::Path;
7
8use npyz::{self, WriterBuilder};
9use npyz::{TypeStr, npz};
10use zip::write::ExtendedFileOptions;
11
12use crate::error::PineappleError;
13
14pub fn write_numpy<T, P: AsRef<Path>>(
22 path: P,
23 data: Vec<T>,
24 shape: Vec<u64>,
25) -> Result<(), PineappleError>
26where
27 T: npyz::Serialize + npyz::AutoSerialize,
28{
29 let mut buffer = vec![];
30 let mut writer = npyz::WriteOptions::<T>::new()
31 .default_dtype()
32 .shape(&shape)
33 .writer(&mut buffer)
34 .begin_nd()
35 .map_err(|_| PineappleError::ImageWriteError)?;
36
37 for d in data {
38 let _ = writer.push(&d);
39 }
40
41 writer.finish().map_err(|_| PineappleError::ImageWriteError)?;
42 std::fs::write(path, buffer).map_err(|_| PineappleError::ImageWriteError)?;
43 Ok(())
44}
45
46pub fn write_embeddings_npz<P: AsRef<Path>>(
62 images: Vec<String>,
63 ids: Vec<u32>,
64 centroids: Vec<[f32; 2]>,
65 embeddings: Vec<Vec<f32>>,
66 output: &P,
67) -> Result<(), PineappleError> {
68 let file = io::BufWriter::new(
69 File::create(output)
70 .map_err(|_| PineappleError::OtherError("Failed to create .npz file".to_string()))?,
71 );
72
73 let mut zip = zip::ZipWriter::new(file);
74
75 if images.len() != embeddings.len() {
76 return Err(PineappleError::OtherError(
77 "Image names and embeddings must have same length when saving .npz.".to_string(),
78 ));
79 }
80
81 if !ids.is_empty() && ids.len() != embeddings.len() {
82 return Err(PineappleError::OtherError(
83 "Object identifiers and embeddings must have same length when saving .npz.".to_string(),
84 ));
85 }
86
87 if !centroids.is_empty() && centroids.len() != embeddings.len() {
88 return Err(PineappleError::OtherError(
89 "Object centroids and embeddings must have same length when saving .npz.".to_string(),
90 ));
91 }
92
93 let n = embeddings.len() as u64;
94 let m = embeddings[0].len() as u64;
95
96 zip.start_file::<_, ExtendedFileOptions>(
99 npz::file_name_from_array_name("image"),
100 Default::default(),
101 )
102 .map_err(|_| {
103 PineappleError::OtherError(
104 "Failed to initiailize zip file for image names in .npz file".to_string(),
105 )
106 })?;
107
108 let mut writer = npyz::WriteOptions::new()
109 .dtype(npyz::DType::Plain("<U53".parse::<TypeStr>().unwrap()))
110 .shape(&[n])
111 .writer(&mut zip)
112 .begin_nd()
113 .map_err(|_| {
114 PineappleError::OtherError(
115 "Failed to initiailize writer for image names in .npz file".to_string(),
116 )
117 })?;
118
119 writer
120 .extend(images.iter().map(|image| image.as_str()))
121 .map_err(|_| {
122 PineappleError::OtherError("Failed to add image names to .npz file".to_string())
123 })?;
124
125 writer.finish().map_err(|_| {
126 PineappleError::OtherError("Failed to write image names to .npz file".to_string())
127 })?;
128
129 if !ids.is_empty() {
132 zip.start_file::<_, ExtendedFileOptions>(
133 npz::file_name_from_array_name("id"),
134 Default::default(),
135 )
136 .map_err(|_| {
137 PineappleError::OtherError(
138 "Failed to initiailize zip file for identifiers in .npz file".to_string(),
139 )
140 })?;
141
142 let mut writer = npyz::WriteOptions::new()
143 .default_dtype()
144 .shape(&[n])
145 .writer(&mut zip)
146 .begin_nd()
147 .map_err(|_| {
148 PineappleError::OtherError(
149 "Failed to initialize writer for identifiers in .npz file".to_string(),
150 )
151 })?;
152
153 writer.extend(ids).map_err(|_| {
154 PineappleError::OtherError("Failed to add identifiers to .npz file".to_string())
155 })?;
156
157 writer.finish().map_err(|_| {
158 PineappleError::OtherError("Failed to write identifiers to .npz file".to_string())
159 })?;
160 }
161
162 if !centroids.is_empty() {
165 zip.start_file::<_, ExtendedFileOptions>(
166 npz::file_name_from_array_name("centroid"),
167 Default::default(),
168 )
169 .map_err(|_| {
170 PineappleError::OtherError(
171 "Failed to initiailize zip file for centroids in .npz file".to_string(),
172 )
173 })?;
174
175 let mut writer = npyz::WriteOptions::new()
176 .default_dtype()
177 .shape(&[n, 2])
178 .writer(&mut zip)
179 .begin_nd()
180 .map_err(|_| {
181 PineappleError::OtherError(
182 "Failed to initialize writer for centroids in .npz file".to_string(),
183 )
184 })?;
185
186 writer
187 .extend(centroids.iter().flat_map(|r| r.iter().cloned()))
188 .map_err(|_| {
189 PineappleError::OtherError("Failed to add centroids to .npz file".to_string())
190 })?;
191
192 writer.finish().map_err(|_| {
193 PineappleError::OtherError("Failed to write centroids to .npz file".to_string())
194 })?;
195 }
196
197 zip.start_file::<_, ExtendedFileOptions>(
200 npz::file_name_from_array_name("embedding"),
201 Default::default(),
202 )
203 .map_err(|_| {
204 PineappleError::OtherError(
205 "Failed to initiailize zip file for embeddings in .npz file".to_string(),
206 )
207 })?;
208
209 let mut writer = npyz::WriteOptions::new()
210 .default_dtype()
211 .shape(&[n, m])
212 .writer(&mut zip)
213 .begin_nd()
214 .map_err(|_| {
215 PineappleError::OtherError(
216 "Failed to initiailize writer for embeddings in .npz file".to_string(),
217 )
218 })?;
219
220 writer
221 .extend(embeddings.iter().flat_map(|r| r.iter().cloned()))
222 .map_err(|_| PineappleError::OtherError("Failed to add embeddings to .npz file".to_string()))?;
223
224 writer.finish().map_err(|_| {
225 PineappleError::OtherError("Failed to write image names to .npz file".to_string())
226 })?;
227
228 zip.finish()
229 .map_err(|_| PineappleError::OtherError("Failed to zip .npz file".to_string()))?;
230
231 Ok(())
232}