Skip to main content

include_flate_codegen/
lib.rs

1// include-flate
2// Copyright (C) SOFe
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16extern crate proc_macro;
17
18use std::fs::{self, File};
19use std::io::{Read, Seek};
20use std::path::PathBuf;
21use std::str::{FromStr, from_utf8};
22
23use include_flate_compress::{CompressionMethod, apply_compression};
24use proc_macro::TokenStream;
25use proc_macro_error2::{emit_warning, proc_macro_error};
26use proc_macro2::Span;
27use quote::quote;
28use syn::{Error, LitByteStr};
29
30/// `deflate_file!("file")` is equivalent to `include_bytes!("file.gz")`.
31///
32/// # Parameters
33/// This macro accepts exactly one literal parameter that refers to a path relative to
34/// `CARGO_MANIFEST_DIR`. Absolute paths are not supported.
35///
36/// Note that **this is distinct from the behaviour of the builtin `include_bytes!`/`include_str!` macros** —
37/// `includle_bytes!`/`include_str!` paths are relative to the current source file, while `deflate_file!` paths are relative to
38/// `CARGO_MANIFEST_DIR`.
39///
40/// # Returns
41/// This macro expands to a `b"byte string"` literal that contains the deflated form of the file.
42///
43/// # Compile errors
44/// - If the argument is not a single literal
45/// - If the referenced file does not exist or is not readable
46#[proc_macro]
47#[proc_macro_error]
48pub fn deflate_file(ts: TokenStream) -> TokenStream {
49    match inner(ts, false) {
50        Ok(ts) => ts.into(),
51        Err(err) => err.to_compile_error().into(),
52    }
53}
54
55/// This macro is identical to `deflate_file!()`, except it additionally performs UTF-8 validation.
56///
57/// # Compile errors
58/// - The compile errors in `deflate_file!`
59/// - If the file contents are not all valid UTF-8
60#[proc_macro]
61#[proc_macro_error]
62pub fn deflate_utf8_file(ts: TokenStream) -> TokenStream {
63    match inner(ts, true) {
64        Ok(ts) => ts.into(),
65        Err(err) => err.to_compile_error().into(),
66    }
67}
68
69/// An arguments expected provided by the proc-macro.
70///
71/// ```ignore
72/// flate!(pub static DATA: [u8] from "assets/009f.dat"); // default, DEFLATE
73/// flate!(pub static DATA: [u8] from "assets/009f.dat" with zstd); // Use Zstd for this file spcifically
74/// flate!(pub static DATA: [u8] from "assets/009f.dat" with deflate); // Explicitly use DEFLATE.
75/// ```
76struct FlateArgs {
77    path: syn::LitStr,
78    algorithm: Option<CompressionMethodTy>,
79}
80
81impl syn::parse::Parse for FlateArgs {
82    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
83        let path = input.parse()?;
84
85        let algorithm = if input.is_empty() {
86            None
87        } else {
88            let lookahead = input.lookahead1();
89            if lookahead.peek(kw::deflate) {
90                #[cfg(feature = "deflate")]
91                {
92                    input.parse::<kw::deflate>()?;
93                    Some(CompressionMethodTy(CompressionMethod::Deflate))
94                }
95                #[cfg(not(feature = "deflate"))]
96                return Err(Error::new(
97                    input.span(),
98                    "Please enable the `deflate` feature",
99                ));
100            } else if lookahead.peek(kw::zstd) {
101                #[cfg(feature = "zstd")]
102                {
103                    input.parse::<kw::zstd>()?;
104                    Some(CompressionMethodTy(CompressionMethod::Zstd))
105                }
106                #[cfg(not(feature = "zstd"))]
107                return Err(Error::new(input.span(), "Please enable the `zstd` feature"));
108            } else {
109                return Err(lookahead.error());
110            }
111        };
112
113        Ok(Self { path, algorithm })
114    }
115}
116
117mod kw {
118    syn::custom_keyword!(deflate);
119    syn::custom_keyword!(zstd);
120}
121
122#[derive(Debug, Default)]
123struct CompressionMethodTy(CompressionMethod);
124
/// Returns the percentage of space saved by compression:
/// `(original_size - compressed_size) / original_size * 100`.
///
/// A value near `100` means the file compressed very well, a value near `0` means
/// compression barely helped, and a negative value means the compressed output is larger
/// than the input. Callers can therefore treat a small result as a "low compression
/// ratio".
///
/// An empty input (`original_size == 0`) yields `0.0` (nothing could be saved) instead of
/// a non-finite value.
fn compression_ratio(original_size: u64, compressed_size: u64) -> f64 {
    if original_size == 0 {
        return 0.0;
    }
    let original = original_size as f64;
    // Multiply before dividing to keep rounding error low for integral percentages.
    (original - compressed_size as f64) * 100.0 / original
}
128
129fn inner(ts: TokenStream, utf8: bool) -> syn::Result<impl Into<TokenStream>> {
130    fn emap<E: std::fmt::Display>(error: E) -> Error {
131        Error::new(Span::call_site(), error)
132    }
133
134    let dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").map_err(emap)?);
135
136    let args: FlateArgs = syn::parse2::<FlateArgs>(ts.to_owned().into())?;
137    let path = PathBuf::from_str(&args.path.value()).map_err(emap)?;
138    let algo = args.algorithm.unwrap_or_default();
139
140    if path.is_absolute() {
141        Err(emap("absolute paths are not supported"))?;
142    }
143
144    let target = dir.join(&path);
145
146    let mut file = File::open(&target).map_err(emap)?;
147
148    let mut vec = Vec::<u8>::new();
149    if utf8 {
150        std::io::copy(&mut file, &mut vec).map_err(emap)?;
151        from_utf8(&vec).map_err(emap)?;
152    }
153
154    let mut compressed_buffer = Vec::<u8>::new();
155
156    {
157        let mut compressed_cursor = std::io::Cursor::new(&mut compressed_buffer);
158        let mut source: Box<dyn Read> = if utf8 {
159            Box::new(std::io::Cursor::new(vec))
160        } else {
161            file.seek(std::io::SeekFrom::Start(0)).map_err(emap)?;
162            Box::new(&file)
163        };
164
165        apply_compression(&mut source, &mut compressed_cursor, algo.0).map_err(emap)?;
166    }
167
168    let bytes = LitByteStr::new(&compressed_buffer, Span::call_site());
169    let result = quote!(#bytes);
170
171    #[cfg(not(feature = "no-compression-warnings"))]
172    {
173        let compression_ratio = compression_ratio(
174            fs::metadata(&target).map_err(emap)?.len(),
175            compressed_buffer.len() as u64,
176        );
177
178        if compression_ratio < 10.0f64 {
179            emit_warning!(
180                &args.path,
181                "Detected low compression ratio ({:.2}%) for file {:?} with `{:?}`. Consider using other compression methods.",
182                compression_ratio,
183                path.display(),
184                algo.0,
185            );
186        }
187    }
188
189    Ok(result)
190}