include_fs/
lib.rs

1use std::collections::HashMap;
2use std::env;
3use std::fs::File;
4use std::io::{self, Write};
5use std::path::{Path, PathBuf};
6use std::sync::LazyLock;
7use thiserror::Error;
8use walkdir::WalkDir;
9
/// Four-byte signature ("INFS") that opens every archive.
const MAGIC: &[u8; 4] = &[b'I', b'N', b'F', b'S'];
11
12#[derive(Error, Debug)]
13pub enum ArchiveError {
14  #[error("Path too long: {path} ({len} bytes, max {max} bytes)")]
15  PathTooLong {
16    path: String,
17    len: usize,
18    max: usize,
19  },
20
21  #[error("Too many files: {count} (max {max})")]
22  TooManyFiles { count: usize, max: usize },
23
24  #[error("IO error: {0}")]
25  Io(#[from] std::io::Error),
26
27  #[error("Source directory must be a subdirectory of the manifest directory")]
28  InvalidSourceDirectory,
29
30  #[error("Failed to collect files: {0}")]
31  WalkDir(#[from] walkdir::Error),
32}
33
34#[derive(Error, Debug)]
35pub enum FsError {
36  #[error("File not found")]
37  NotFound,
38
39  #[error("Invalid archive")]
40  InvalidArchive,
41}
42
/// A file scheduled for inclusion in an archive.
#[derive(Debug)]
struct FileEntry {
  /// Path the file is read from; its archive key is derived from this path.
  pub path: PathBuf,
  /// Number of data bytes recorded for the file in the archive header.
  pub size: u64,
}

impl FileEntry {
  /// Creates an entry for `path` with a recorded size of `size` bytes.
  pub fn new(path: impl Into<PathBuf>, size: u64) -> Self {
    let path = path.into();
    Self { path, size }
  }
}
57
58fn compute_header(files: &[FileEntry]) -> Result<Vec<u8>, ArchiveError> {
59  // Validate file count fits in u32
60  if files.len() > u32::MAX as usize {
61    return Err(ArchiveError::TooManyFiles {
62      count: files.len(),
63      max: u32::MAX as usize,
64    });
65  }
66
67  let mut header_size = 4 + 4; // magic + file count
68  for file in files {
69    let path_str = file.path.to_string_lossy();
70    let path_len = path_str.len();
71
72    if path_len > u16::MAX as usize {
73      return Err(ArchiveError::PathTooLong {
74        path: path_str.to_string(),
75        len: path_str.len(),
76        max: u16::MAX as usize,
77      });
78    }
79
80    // path_len + path + size + offset
81    header_size += 2 + path_len + 8 + 8;
82  }
83
84  let mut header = Vec::with_capacity(header_size);
85
86  header.extend_from_slice(MAGIC);
87  header.extend_from_slice(&(files.len() as u32).to_le_bytes());
88
89  let mut data_offset = header_size as u64;
90  for file in files {
91    let path_str = file.path.to_string_lossy();
92    let path_bytes = path_str.as_bytes();
93
94    header.extend_from_slice(&(path_bytes.len() as u16).to_le_bytes());
95    header.extend_from_slice(path_bytes);
96    header.extend_from_slice(&file.size.to_le_bytes());
97    header.extend_from_slice(&data_offset.to_le_bytes());
98
99    data_offset += file.size;
100  }
101
102  Ok(header)
103}
104
105fn write_archive(files: &[FileEntry], output_path: &Path) -> Result<(), ArchiveError> {
106  let mut file = File::create(output_path)?;
107
108  // Write header
109  let header = compute_header(files)?;
110  file.write_all(&header)?;
111
112  // Write file data
113  for file_entry in files {
114    let mut f = File::open(&file_entry.path)?;
115    io::copy(&mut f, &mut file)?;
116  }
117
118  Ok(())
119}
120
121pub fn embed_fs(source_dir: &str, name: &str) -> Result<(), ArchiveError> {
122  let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("no CARGO_MANIFEST_DIR");
123  let source_dir = Path::new(&manifest_dir).join(source_dir).canonicalize()?;
124
125  // Ensure the source directory is a subdirectory of the manifest directory
126  if !source_dir.starts_with(&manifest_dir) {
127    return Err(ArchiveError::InvalidSourceDirectory);
128  }
129
130  let relative_source_dir = source_dir.strip_prefix(&manifest_dir).unwrap();
131  println!("cargo:rerun-if-changed={}", relative_source_dir.display());
132
133  let mut files = Vec::new();
134  let walk = WalkDir::new(&source_dir).follow_links(false);
135  for entry in walk {
136    let entry = entry?;
137    let meta = entry.metadata()?;
138    if !meta.is_file() {
139      continue;
140    }
141
142    let path = entry.path().strip_prefix(&manifest_dir).unwrap();
143    files.push(FileEntry::new(path, meta.len()));
144  }
145
146  let out_dir = env::var("OUT_DIR").expect("no OUT_DIR");
147  let output_file = format!("{}.embed_fs", name);
148  let output_path = Path::new(&out_dir).join(output_file);
149
150  write_archive(&files, &output_path)
151}
152
/// Index record for one file stored in an embedded archive.
///
/// Derives the common traits so users of the public `file_index` map can
/// debug-print, clone and compare entries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FsEntry {
  /// Path key the file is stored under.
  pub path: String,
  /// Length of the file's data in bytes.
  pub size: u64,
  /// Absolute byte offset of the file's data within the archive.
  data_offset: u64,
}

impl FsEntry {
  /// Creates an index record for `size` bytes at `data_offset` stored under `path`.
  pub fn new(path: String, size: u64, data_offset: u64) -> Self {
    Self {
      path,
      size,
      data_offset,
    }
  }
}
168
169pub type IncludeFs = LazyLock<IncludeFsInner>;
170
171pub struct IncludeFsInner {
172  pub file_index: HashMap<String, FsEntry>,
173  pub archive_bytes: Vec<u8>,
174}
175
176impl IncludeFsInner {
177  pub fn new(archive_bytes: &[u8]) -> Result<Self, FsError> {
178    if &archive_bytes[0..4] != MAGIC {
179      return Err(FsError::InvalidArchive);
180    }
181
182    let file_count = u32::from_le_bytes([
183      archive_bytes[4],
184      archive_bytes[5],
185      archive_bytes[6],
186      archive_bytes[7],
187    ]) as usize;
188
189    let mut offset = 8;
190    let mut file_index = HashMap::with_capacity(file_count);
191
192    for _ in 0..file_count {
193      let path_len =
194        u16::from_le_bytes([archive_bytes[offset], archive_bytes[offset + 1]]) as usize;
195      offset += 2;
196
197      let path = String::from_utf8_lossy(&archive_bytes[offset..offset + path_len]).to_string();
198      offset += path_len;
199
200      let size = u64::from_le_bytes([
201        archive_bytes[offset],
202        archive_bytes[offset + 1],
203        archive_bytes[offset + 2],
204        archive_bytes[offset + 3],
205        archive_bytes[offset + 4],
206        archive_bytes[offset + 5],
207        archive_bytes[offset + 6],
208        archive_bytes[offset + 7],
209      ]);
210      offset += 8;
211
212      let data_offset = u64::from_le_bytes([
213        archive_bytes[offset],
214        archive_bytes[offset + 1],
215        archive_bytes[offset + 2],
216        archive_bytes[offset + 3],
217        archive_bytes[offset + 4],
218        archive_bytes[offset + 5],
219        archive_bytes[offset + 6],
220        archive_bytes[offset + 7],
221      ]);
222      offset += 8;
223
224      file_index.insert(path.clone(), FsEntry::new(path, size, data_offset));
225    }
226
227    Ok(Self {
228      file_index,
229      archive_bytes: archive_bytes.to_vec(),
230    })
231  }
232
233  pub fn exists(&self, path: &str) -> bool {
234    self.file_index.contains_key(path)
235  }
236
237  pub fn get(&self, path: &str) -> Result<&[u8], FsError> {
238    let Some(entry) = self.file_index.get(path) else {
239      return Err(FsError::NotFound);
240    };
241
242    let start = entry.data_offset as usize;
243    let end = start + entry.size as usize;
244    Ok(&self.archive_bytes[start..end])
245  }
246
247  pub fn list_paths(&self) -> Vec<&str> {
248    self.file_index.keys().map(|s| s.as_str()).collect()
249  }
250}
251
/// Embeds an archive generated by `embed_fs` and exposes it as a lazily-initialized
/// `IncludeFs`.
///
/// `$name` must be a string literal matching the `name` passed to `embed_fs` in
/// the build script. The archive bytes are included at compile time from
/// `$OUT_DIR/<name>.embed_fs` and parsed on first access.
///
/// Paths are resolved through `$crate`, so the macro keeps working when the
/// dependency is renamed in `Cargo.toml`.
///
/// # Panics
///
/// On first access, if the embedded archive cannot be parsed.
///
/// # Examples
///
/// ```ignore
/// static ASSETS: include_fs::IncludeFs = include_fs::include_fs!("assets");
///
/// fn logo() -> &'static [u8] {
///     ASSETS.get("assets/logo.png").unwrap()
/// }
/// ```
#[macro_export]
macro_rules! include_fs {
  ($name:expr) => {
    ::std::sync::LazyLock::new(|| {
      let archive_bytes = ::std::include_bytes!(::std::concat!(
        ::std::env!("OUT_DIR"),
        "/",
        $name,
        ".embed_fs"
      ));
      $crate::IncludeFsInner::new(archive_bytes).expect("Failed to initialize IncludeFs")
    })
  };
}
261
#[cfg(test)]
mod tests {
  use super::*;

  // Builds a complete archive (header followed by data) from `(path, contents)` pairs.
  fn build_archive(files: &[(&str, &[u8])]) -> Vec<u8> {
    let entries: Vec<FileEntry> = files
      .iter()
      .map(|(path, data)| FileEntry::new(*path, data.len() as u64))
      .collect();
    let mut archive = compute_header(&entries).unwrap();
    for (_, data) in files {
      archive.extend_from_slice(data);
    }
    archive
  }

  #[test]
  fn test_compute_header() {
    let files = vec![
      FileEntry::new("src/main.rs", 1024),
      FileEntry::new("assets/image.png", 2048),
    ];

    let header = compute_header(&files).unwrap();

    // Verify magic
    assert_eq!(&header[0..4], b"INFS");

    // Verify file count
    let file_count = u32::from_le_bytes([header[4], header[5], header[6], header[7]]);
    assert_eq!(file_count, 2);

    // The header size is fully determined by the path lengths.
    let expected_size = 4 + 4 + // magic + count
      2 + "src/main.rs".len() + 8 + 8 + // first file
      2 + "assets/image.png".len() + 8 + 8; // second file

    assert_eq!(header.len(), expected_size);
  }

  #[test]
  fn test_path_too_long() {
    let long_path = "a".repeat(u16::MAX as usize + 1);
    let files = vec![FileEntry::new(long_path.clone(), 100)];

    let Err(ArchiveError::PathTooLong { path, len, max }) = compute_header(&files) else {
      panic!("expected ArchiveError::PathTooLong");
    };
    assert_eq!(path, long_path);
    assert_eq!(len, u16::MAX as usize + 1);
    assert_eq!(max, u16::MAX as usize);
  }

  #[test]
  fn test_round_trip() {
    let archive = build_archive(&[
      ("a.txt", &b"abc"[..]),
      ("empty", &b""[..]),
      ("dir/b.bin", &b"xy"[..]),
    ]);
    let fs = IncludeFsInner::new(&archive).unwrap();

    assert_eq!(fs.get("a.txt").unwrap(), b"abc");
    assert_eq!(fs.get("empty").unwrap(), b"");
    assert_eq!(fs.get("dir/b.bin").unwrap(), b"xy");
    assert!(fs.exists("a.txt"));
    assert!(!fs.exists("missing.txt"));
    assert!(matches!(fs.get("missing.txt"), Err(FsError::NotFound)));

    let mut paths = fs.list_paths();
    paths.sort_unstable();
    assert_eq!(paths, ["a.txt", "dir/b.bin", "empty"]);
  }

  #[test]
  fn test_empty_archive() {
    let archive = build_archive(&[]);
    assert_eq!(archive.len(), 8);

    let fs = IncludeFsInner::new(&archive).unwrap();
    assert!(fs.list_paths().is_empty());
  }

  #[test]
  fn test_invalid_magic() {
    let mut archive = build_archive(&[]);
    archive[..4].copy_from_slice(b"NOPE");
    assert!(matches!(
      IncludeFsInner::new(&archive),
      Err(FsError::InvalidArchive)
    ));
  }
}