Skip to main content

ai_image/
hooks.rs

1//! This module provides a way to register decoding hooks for image formats not directly supported
2//! by this crate.
3
4use alloc::{boxed::Box, string::String, vec::Vec};
5use std::{
6    collections::HashMap,
7    ffi::{OsStr, OsString},
8    io::{BufRead, BufReader, Read, Seek},
9    sync::RwLock,
10};
11
12use crate::{ImageDecoder, ImageResult};
13
/// Object-safe combination of [`Read`] and [`Seek`], so that one `dyn ReadSeek` trait object can
/// stand in for any seekable reader.
///
/// Every type implementing both traits receives this trait automatically via the blanket impl.
pub(crate) trait ReadSeek: Read + Seek {}

impl<R> ReadSeek for R where R: Read + Seek {}
16
17/// Stores ascii lowercase extension to hook mapping
18pub(crate) static DECODING_HOOKS: RwLock<Option<HashMap<OsString, DecodingHook>>> =
19    RwLock::new(None);
20
/// A registered format-detection entry: `(signature, mask, extension)`.
///
/// * `signature` – magic bytes expected at the start of the input.
/// * `mask` – mask bytes for `signature`; an empty slice when no mask was supplied.
/// * `extension` – the ASCII-lowercase extension associated with this signature.
pub(crate) type DetectionHook = (
    &'static [u8], // signature
    &'static [u8], // mask (empty = no mask)
    OsString,      // ASCII-lowercase extension
);

/// All registered detection entries, kept in registration order.
pub(crate) static GUESS_FORMAT_HOOKS: RwLock<Vec<DetectionHook>> = RwLock::new(Vec::new());
23
24/// A wrapper around a type-erased trait object that implements `Read` and `Seek`.
25pub struct GenericReader<'a>(pub(crate) BufReader<Box<dyn ReadSeek + 'a>>);
26impl Read for GenericReader<'_> {
27    fn read(&mut self, buf: &mut [u8]) -> no_std_io::io::Result<usize> {
28        self.0.read(buf)
29    }
30    fn read_vectored(
31        &mut self,
32        bufs: &mut [no_std_io::io::IoSliceMut<'_>],
33    ) -> no_std_io::io::Result<usize> {
34        self.0.read_vectored(bufs)
35    }
36    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> no_std_io::io::Result<usize> {
37        self.0.read_to_end(buf)
38    }
39    fn read_to_string(&mut self, buf: &mut String) -> no_std_io::io::Result<usize> {
40        self.0.read_to_string(buf)
41    }
42    fn read_exact(&mut self, buf: &mut [u8]) -> no_std_io::io::Result<()> {
43        self.0.read_exact(buf)
44    }
45}
46impl BufRead for GenericReader<'_> {
47    fn fill_buf(&mut self) -> no_std_io::io::Result<&[u8]> {
48        self.0.fill_buf()
49    }
50    fn consume(&mut self, amt: usize) {
51        self.0.consume(amt)
52    }
53    fn read_until(&mut self, byte: u8, buf: &mut Vec<u8>) -> no_std_io::io::Result<usize> {
54        self.0.read_until(byte, buf)
55    }
56    fn read_line(&mut self, buf: &mut String) -> no_std_io::io::Result<usize> {
57        self.0.read_line(buf)
58    }
59}
60impl Seek for GenericReader<'_> {
61    fn seek(&mut self, pos: no_std_io::io::SeekFrom) -> no_std_io::io::Result<u64> {
62        self.0.seek(pos)
63    }
64    fn rewind(&mut self) -> no_std_io::io::Result<()> {
65        self.0.rewind()
66    }
67    fn stream_position(&mut self) -> no_std_io::io::Result<u64> {
68        self.0.stream_position()
69    }
70
71    // TODO: Add `seek_relative` once MSRV is at least 1.80.0
72}
73
74/// A function to produce an `ImageDecoder` for a given image format.
75pub type DecodingHook =
76    Box<dyn for<'a> Fn(GenericReader<'a>) -> ImageResult<Box<dyn ImageDecoder + 'a>> + Send + Sync>;
77
78/// Register a new decoding hook or returns false if one already exists for the given format.
79pub fn register_decoding_hook(extension: OsString, hook: DecodingHook) -> bool {
80    let extension = extension.to_ascii_lowercase();
81    let mut hooks = DECODING_HOOKS.write().unwrap();
82    if hooks.is_none() {
83        *hooks = Some(HashMap::new());
84    }
85    match hooks.as_mut().unwrap().entry(extension) {
86        std::collections::hash_map::Entry::Vacant(entry) => {
87            entry.insert(hook);
88            true
89        }
90        std::collections::hash_map::Entry::Occupied(_) => false,
91    }
92}
93
94/// Returns whether a decoding hook has been registered for the given format.
95pub fn decoding_hook_registered(extension: &OsStr) -> bool {
96    let extension = extension.to_ascii_lowercase();
97    DECODING_HOOKS
98        .read()
99        .unwrap()
100        .as_ref()
101        .map(|hooks| hooks.contains_key(&extension))
102        .unwrap_or(false)
103}
104
105/// Registers a format detection hook.
106///
107/// The signature field holds the magic bytes from the start of the file that must be matched to
108/// detect the format. The mask field is optional and can be used to specify which bytes in the
109/// signature should be ignored during the detection.
110///
111/// # Examples
112///
113/// ## Using the mask to ignore some bytes
114///
115/// ```
116/// # use ai_image::hooks::register_format_detection_hook;
117/// // WebP signature is 'riff' followed by 4 bytes of length and then by 'webp'.
118/// // This requires a mask to ignore the length.
119/// register_format_detection_hook("webp".into(),
120///      &[b'r', b'i', b'f', b'f', 0, 0, 0, 0, b'w', b'e', b'b', b'p'],
121/// Some(&[0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff]),
122/// );
123/// ```
124///
125/// ## Multiple signatures
126///
127/// ```
128/// # use ai_image::hooks::register_format_detection_hook;
129/// // JPEG XL has two different signatures: https://en.wikipedia.org/wiki/JPEG_XL
130/// // This function should be called twice to register them both.
131/// register_format_detection_hook("jxl".into(), &[0xff, 0x0a], None);
132/// register_format_detection_hook("jxl".into(),
133///      &[0x00, 0x00, 0x00, 0x0c, 0x4a, 0x58, 0x4c, 0x20, 0x0d, 0x0a, 0x87, 0x0a], None,
134/// );
135/// ```
136///
137pub fn register_format_detection_hook(
138    extension: OsString,
139    signature: &'static [u8],
140    mask: Option<&'static [u8]>,
141) {
142    let extension = extension.to_ascii_lowercase();
143    GUESS_FORMAT_HOOKS
144        .write()
145        .unwrap()
146        .push((signature, mask.unwrap_or(&[]), extension));
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{load_from_memory, ColorType, DynamicImage, ImageReader};
    use no_std_io::io::Cursor;

    const MOCK_HOOK_EXTENSION: &str = "MOCKHOOK";

    /// Pixels produced by [`MockDecoder`]: a 3x1 RGB8 image (red, green, blue).
    const MOCK_IMAGE_OUTPUT: [u8; 9] = [255, 0, 0, 0, 255, 0, 0, 0, 255];

    /// Decoder that ignores its input and always yields [`MOCK_IMAGE_OUTPUT`].
    struct MockDecoder {}
    impl ImageDecoder for MockDecoder {
        fn dimensions(&self) -> (u32, u32) {
            // Three bytes per RGB8 pixel, all in a single row.
            ((MOCK_IMAGE_OUTPUT.len() / 3) as u32, 1)
        }
        fn color_type(&self) -> ColorType {
            ColorType::Rgb8
        }
        fn read_image(self, buf: &mut [u8]) -> ImageResult<()> {
            buf[..MOCK_IMAGE_OUTPUT.len()].copy_from_slice(&MOCK_IMAGE_OUTPUT);
            Ok(())
        }
        fn read_image_boxed(self: Box<Self>, buf: &mut [u8]) -> ImageResult<()> {
            (*self).read_image(buf)
        }
    }

    /// Returns whether `image` holds exactly the pixels produced by [`MockDecoder`].
    fn is_mock_decoder_output(image: DynamicImage) -> bool {
        image.as_rgb8().unwrap().as_raw() == &MOCK_IMAGE_OUTPUT
    }

    #[test]
    fn decoding_hook() {
        // Both tests register the same hook for the same extension. Whichever runs second gets
        // `false` back, which is harmless, so the result is deliberately not asserted.
        register_decoding_hook(
            MOCK_HOOK_EXTENSION.into(),
            Box::new(|_| Ok(Box::new(MockDecoder {}))),
        );

        // The mixed-case file extension checks that hook lookup is case-insensitive.
        let image = ImageReader::open("tests/images/hook/extension.MoCkHoOk")
            .unwrap()
            .decode()
            .unwrap();

        assert!(is_mock_decoder_output(image));
    }

    #[test]
    fn detection_hook() {
        // See `decoding_hook` for why the registration result is ignored.
        register_decoding_hook(
            MOCK_HOOK_EXTENSION.into(),
            Box::new(|_| Ok(Box::new(MockDecoder {}))),
        );

        // Bytes 4..8 are masked out, so any value may appear there (here: "JUNK").
        register_format_detection_hook(
            MOCK_HOOK_EXTENSION.into(),
            &[b'H', b'E', b'A', b'D', 0, 0, 0, 0, b'M', b'O', b'C', b'K'],
            Some(&[0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff]),
        );

        const TEST_INPUT_IMAGE: [u8; 16] = [
            b'H', b'E', b'A', b'D', b'J', b'U', b'N', b'K', b'M', b'O', b'C', b'K', b'm', b'o',
            b'r', b'e',
        ];
        let image = ImageReader::new(Cursor::new(TEST_INPUT_IMAGE))
            .with_guessed_format()
            .unwrap()
            .decode()
            .unwrap();

        assert!(is_mock_decoder_output(image));

        let image_via_free_function = load_from_memory(&TEST_INPUT_IMAGE).unwrap();
        assert!(is_mock_decoder_output(image_via_free_function));
    }
}