1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use ad_core::error::{ADError, ADResult};
5use ad_core::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
6use ad_core::ndarray_pool::NDArrayPool;
7use ad_core::plugin::file_base::{NDFileMode, NDFileWriter, NDPluginFileBase};
8use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
9
10use jpeg_encoder::{Encoder as JpegEncoder, ColorType as JpegColorType};
11
/// Writes `NDArray` frames to disk as JPEG files.
///
/// Holds the destination of the currently "open" file together with the
/// encoder quality that is handed to the JPEG encoder on every write.
#[derive(Debug)]
pub struct JpegWriter {
    /// Destination of the currently open file; `None` while no file is open.
    current_path: Option<PathBuf>,
    /// Quality value passed to the JPEG encoder on each write.
    quality: u8,
}

impl JpegWriter {
    /// Creates a writer with no file open and the given encoder `quality`.
    pub fn new(quality: u8) -> Self {
        Self {
            current_path: None,
            quality,
        }
    }

    /// Replaces the encoder quality used by subsequent writes.
    pub fn set_quality(&mut self, quality: u8) {
        self.quality = quality;
    }
}
30
31impl NDFileWriter for JpegWriter {
32 fn open_file(&mut self, path: &Path, _mode: NDFileMode, array: &NDArray) -> ADResult<()> {
33 if array.data.data_type() != NDDataType::UInt8 {
34 return Err(ADError::UnsupportedConversion(
35 "JPEG only supports UInt8 data".into(),
36 ));
37 }
38 self.current_path = Some(path.to_path_buf());
39 Ok(())
40 }
41
42 fn write_file(&mut self, array: &NDArray) -> ADResult<()> {
43 let path = self.current_path.as_ref()
44 .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
45
46 let info = array.info();
47 let width = info.x_size;
48 let height = info.y_size;
49
50 let data = match &array.data {
51 NDDataBuffer::U8(v) => v.as_slice(),
52 _ => {
53 return Err(ADError::UnsupportedConversion(
54 "JPEG only supports UInt8".into(),
55 ))
56 }
57 };
58
59 let color_type = if info.color_size == 3 {
60 JpegColorType::Rgb
61 } else {
62 JpegColorType::Luma
63 };
64
65 let mut buf = Vec::new();
66 let encoder = JpegEncoder::new(&mut buf, self.quality);
67 encoder
68 .encode(data, width as u16, height as u16, color_type)
69 .map_err(|e| ADError::UnsupportedConversion(format!("JPEG encode error: {}", e)))?;
70
71 std::fs::write(path, &buf)?;
72 Ok(())
73 }
74
75 fn read_file(&mut self) -> ADResult<NDArray> {
76 let path = self.current_path.as_ref()
77 .ok_or_else(|| ADError::UnsupportedConversion("no file open".into()))?;
78
79 let file_data = std::fs::read(path)?;
80 let mut decoder = jpeg_decoder::Decoder::new(&file_data[..]);
81 let pixels = decoder
82 .decode()
83 .map_err(|e| ADError::UnsupportedConversion(format!("JPEG decode error: {}", e)))?;
84 let info = decoder.info().unwrap();
85
86 let (width, height) = (info.width as usize, info.height as usize);
87
88 let dims = match info.pixel_format {
89 jpeg_decoder::PixelFormat::L8 => {
90 vec![NDDimension::new(width), NDDimension::new(height)]
91 }
92 jpeg_decoder::PixelFormat::RGB24 => {
93 vec![
94 NDDimension::new(3),
95 NDDimension::new(width),
96 NDDimension::new(height),
97 ]
98 }
99 _ => {
100 return Err(ADError::UnsupportedConversion(
101 "unsupported JPEG pixel format".into(),
102 ))
103 }
104 };
105
106 let mut arr = NDArray::new(dims, NDDataType::UInt8);
107 arr.data = NDDataBuffer::U8(pixels);
108 Ok(arr)
109 }
110
111 fn close_file(&mut self) -> ADResult<()> {
112 self.current_path = None;
113 Ok(())
114 }
115
116 fn supports_multiple_arrays(&self) -> bool {
117 false
118 }
119}
120
121pub struct JpegFileProcessor {
123 file_base: NDPluginFileBase,
124 writer: JpegWriter,
125}
126
127impl JpegFileProcessor {
128 pub fn new(quality: u8) -> Self {
129 Self {
130 file_base: NDPluginFileBase::new(),
131 writer: JpegWriter::new(quality),
132 }
133 }
134
135 pub fn file_base_mut(&mut self) -> &mut NDPluginFileBase {
136 &mut self.file_base
137 }
138}
139
140impl Default for JpegFileProcessor {
141 fn default() -> Self {
142 Self::new(90)
143 }
144}
145
146impl NDPluginProcess for JpegFileProcessor {
147 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
148 let _ = self
149 .file_base
150 .process_array(Arc::new(array.clone()), &mut self.writer);
151 ProcessResult::empty() }
153
154 fn plugin_type(&self) -> &str {
155 "NDFileJPEG"
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataBuffer, NDDimension};
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Per-process counter that keeps temp-file names distinct between tests.
    static TEST_COUNTER: AtomicU32 = AtomicU32::new(0);

    /// Builds a temp-file path that is unique across tests and test processes.
    ///
    /// The process id is part of the name so that two test binaries running at
    /// the same time cannot overwrite each other's files.
    fn temp_path(prefix: &str) -> PathBuf {
        let n = TEST_COUNTER.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!(
            "adcore_test_{}_{}_{}.jpg",
            prefix,
            std::process::id(),
            n
        ))
    }

    /// Temp-file handle that removes the file on drop, so a failing assertion
    /// does not leave stray files behind.
    struct TempFile(PathBuf);

    impl TempFile {
        fn new(prefix: &str) -> Self {
            TempFile(temp_path(prefix))
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Creates a two-dimensional `UInt8` array whose element `i` is `pixel(i)`.
    fn gray_array(width: usize, height: usize, pixel: impl Fn(usize) -> u8) -> NDArray {
        let mut arr = NDArray::new(
            vec![NDDimension::new(width), NDDimension::new(height)],
            NDDataType::UInt8,
        );
        if let NDDataBuffer::U8(ref mut v) = arr.data {
            for (i, px) in v.iter_mut().enumerate() {
                *px = pixel(i);
            }
        }
        arr
    }

    /// Runs a full open/write/close cycle for `arr` at the given quality.
    fn write_jpeg(target: &TempFile, quality: u8, arr: &NDArray) {
        let mut writer = JpegWriter::new(quality);
        writer
            .open_file(&target.0, NDFileMode::Single, arr)
            .unwrap();
        writer.write_file(arr).unwrap();
        writer.close_file().unwrap();
    }

    #[test]
    fn test_write_u8() {
        let tmp = TempFile::new("jpeg");
        let arr = gray_array(8, 8, |i| (i * 4) as u8);

        write_jpeg(&tmp, 90, &arr);

        let data = std::fs::read(&tmp.0).unwrap();
        // A JPEG stream starts with the SOI marker and ends with the EOI marker.
        assert_eq!(&data[0..2], &[0xFF, 0xD8]);
        assert_eq!(&data[data.len() - 2..], &[0xFF, 0xD9]);
    }

    #[test]
    fn test_rejects_non_u8() {
        let path = temp_path("jpeg_u16");
        let mut writer = JpegWriter::new(90);

        let arr = NDArray::new(
            vec![NDDimension::new(4), NDDimension::new(4)],
            NDDataType::UInt16,
        );

        let result = writer.open_file(&path, NDFileMode::Single, &arr);
        assert!(result.is_err());
    }

    #[test]
    fn test_quality_affects_size() {
        let tmp_high = TempFile::new("jpeg_hi");
        let tmp_low = TempFile::new("jpeg_lo");

        let arr = gray_array(32, 32, |i| (i % 256) as u8);

        write_jpeg(&tmp_high, 95, &arr);
        write_jpeg(&tmp_low, 10, &arr);

        let size_high = std::fs::metadata(&tmp_high.0).unwrap().len();
        let size_low = std::fs::metadata(&tmp_low.0).unwrap().len();
        assert!(
            size_high > size_low,
            "high quality ({}) should be larger than low quality ({})",
            size_high,
            size_low
        );
    }

    #[test]
    fn test_roundtrip_luma() {
        let tmp = TempFile::new("jpeg_rt");
        let mut writer = JpegWriter::new(100);

        let arr = gray_array(8, 8, |_| 128);

        writer
            .open_file(&tmp.0, NDFileMode::Single, &arr)
            .unwrap();
        writer.write_file(&arr).unwrap();

        let read_back = writer.read_file().unwrap();
        assert_eq!(read_back.data.data_type(), NDDataType::UInt8);

        // Fail loudly if the buffer is not U8 instead of skipping the pixel checks.
        let pixels = match read_back.data {
            NDDataBuffer::U8(ref v) => v,
            _ => panic!("decoded buffer is not U8"),
        };
        assert_eq!(pixels.len(), 64);
        // JPEG is lossy, so allow a small deviation from the flat grey input.
        for &px in pixels.iter() {
            assert!(
                (px as i16 - 128).unsigned_abs() < 5,
                "pixel {} too far from 128",
                px
            );
        }

        writer.close_file().unwrap();
    }
}