1use crate::error::{ChkpttError, Result};
2use bitcode::{Decode, Encode};
3use memmap2::Mmap;
4use std::io::{BufWriter, Seek, SeekFrom, Write};
5use std::path::PathBuf;
6use tempfile::NamedTempFile;
7
/// Magic bytes written at the very start of `trees.dat`.
const TREE_PACK_MAGIC: &[u8; 4] = b"CKTR";

/// Pack format version, stored in the `trees.dat` header right after the magic.
const TREE_PACK_VERSION: u32 = 1;

/// Size in bytes of one `trees.idx` record: a 32-byte hash followed by a `u64`
/// offset and a `u64` size.
const TREE_IDX_ENTRY_SIZE: usize = 32 + 2 * std::mem::size_of::<u64>();

/// Size in bytes of the `trees.dat` header: magic (4) + version (4) + entry count (4).
const TREE_HEADER_SIZE: u64 = 4 + 4 + 4;
12
/// Kind of object a [`TreeEntry`] describes.
///
/// NOTE(review): the type derives `Encode`/`Decode`, so variant order is presumably
/// part of the serialized format — confirm before reordering or inserting variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Encode, Decode)]
pub enum EntryType {
    /// A file entry.
    File,
    /// A directory entry.
    Dir,
    /// A symbolic-link entry.
    Symlink,
}
19
/// One named child inside a tree object.
///
/// A tree is stored as a bitcode-encoded list of these entries, sorted by `name`
/// (see `TreeStore::write`).
#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
pub struct TreeEntry {
    /// Entry name; trees are sorted by this field before being encoded.
    pub name: String,
    /// What kind of object this entry is.
    pub entry_type: EntryType,
    /// 32-byte hash associated with the entry.
    /// NOTE(review): presumably the content hash of the referenced object — confirm.
    pub hash: [u8; 32],
    /// NOTE(review): presumably the content size in bytes — confirm against the producer.
    pub size: u64,
    /// NOTE(review): presumably Unix-style mode bits — confirm against the producer.
    pub mode: u32,
}
28
/// In-memory form of one `trees.idx` record.
///
/// On disk each record is the 32-byte hash followed by `offset` and `size`, both as
/// little-endian `u64`.
#[derive(Debug, Clone)]
struct TreeIdxEntry {
    /// Hash of the packed tree; the index file is sorted by this field.
    hash: [u8; 32],
    /// Byte offset in `trees.dat` where the tree's record (hash, length, payload) starts.
    offset: u64,
    /// Length in bytes of the encoded tree payload, excluding the record's hash and
    /// length prefix.
    size: u64,
}
36
/// Content-addressed store for tree objects.
///
/// Trees live either as loose files under `base_dir/<first two hash chars>/<rest>` or
/// inside a pack made of `base_dir/trees.dat` and `base_dir/trees.idx`.
pub struct TreeStore {
    /// Directory that holds the loose objects and the pack files.
    base_dir: PathBuf,
    /// Read-only mapping of `trees.dat`, if a pack could be opened when the store was created.
    pack_dat: Option<Mmap>,
    /// Read-only mapping of `trees.idx`, if a pack could be opened when the store was created.
    pack_idx: Option<Mmap>,
    /// Number of whole index records in `pack_idx` (0 when no pack is mapped).
    pack_entry_count: usize,
}
45
46impl TreeStore {
47 pub fn new(base_dir: PathBuf) -> Self {
48 let dat_path = base_dir.join("trees.dat");
49 let idx_path = base_dir.join("trees.idx");
50
51 let (pack_dat, pack_idx, pack_entry_count) = match (
52 std::fs::File::open(&dat_path),
53 std::fs::File::open(&idx_path),
54 ) {
55 (Ok(dat_file), Ok(idx_file)) => match (unsafe { Mmap::map(&dat_file) }, unsafe {
56 Mmap::map(&idx_file)
57 }) {
58 (Ok(dat), Ok(idx)) => {
59 let count = idx.len() / TREE_IDX_ENTRY_SIZE;
60 (Some(dat), Some(idx), count)
61 }
62 _ => (None, None, 0),
63 },
64 (Err(dat_error), _) if dat_error.kind() == std::io::ErrorKind::NotFound => {
65 (None, None, 0)
66 }
67 (_, Err(idx_error)) if idx_error.kind() == std::io::ErrorKind::NotFound => {
68 (None, None, 0)
69 }
70 _ => (None, None, 0),
71 };
72
73 Self {
74 base_dir,
75 pack_dat,
76 pack_idx,
77 pack_entry_count,
78 }
79 }
80
81 fn tree_path(&self, hash_hex: &str) -> PathBuf {
82 let (prefix, rest) = hash_hex.split_at(2);
83 self.base_dir.join(prefix).join(rest)
84 }
85
86 pub fn write(&self, entries: &[TreeEntry]) -> Result<String> {
89 let mut sorted = entries.to_vec();
90 sorted.sort_unstable_by(|a, b| a.name.cmp(&b.name));
91 let encoded = bitcode::encode(&sorted);
92 let hash_hex = blake3::hash(&encoded).to_hex().to_string();
93 let path = self.tree_path(&hash_hex);
94 let parent = path
95 .parent()
96 .ok_or_else(|| ChkpttError::Other("Tree path missing parent directory".into()))?;
97 std::fs::create_dir_all(parent)?;
98
99 let mut tmp = NamedTempFile::new_in(parent)?;
100 tmp.write_all(&encoded)?;
101 tmp.flush()?;
102
103 match tmp.persist_noclobber(&path) {
104 Ok(_) => Ok(hash_hex),
105 Err(error) if error.error.kind() == std::io::ErrorKind::AlreadyExists => Ok(hash_hex),
106 Err(error) => Err(error.error.into()),
107 }
108 }
109
110 pub fn write_pack(&self, entries: &[(String, Vec<u8>)]) -> Result<()> {
113 if entries.is_empty() {
114 return Ok(());
115 }
116
117 std::fs::create_dir_all(&self.base_dir)?;
118
119 let dat_path = self.base_dir.join("trees.dat");
120 let idx_path = self.base_dir.join("trees.idx");
121
122 let mut all_idx_entries: Vec<TreeIdxEntry> = Vec::new();
124 let mut existing_hashes: std::collections::HashSet<[u8; 32]> =
125 std::collections::HashSet::new();
126
127 let existing_dat_len = if let (Some(dat), Some(idx)) = (&self.pack_dat, &self.pack_idx) {
128 for i in 0..self.pack_entry_count {
129 let pos = i * TREE_IDX_ENTRY_SIZE;
130 let mut hash = [0u8; 32];
131 hash.copy_from_slice(&idx[pos..pos + 32]);
132 let offset = u64::from_le_bytes(idx[pos + 32..pos + 40].try_into().unwrap());
133 let size = u64::from_le_bytes(idx[pos + 40..pos + 48].try_into().unwrap());
134 existing_hashes.insert(hash);
135 all_idx_entries.push(TreeIdxEntry { hash, offset, size });
136 }
137 dat.len() as u64
138 } else {
139 TREE_HEADER_SIZE
140 };
141
142 let mut dat_tmp = NamedTempFile::new_in(&self.base_dir)?;
144 {
145 let mut writer = BufWriter::with_capacity(256 * 1024, &mut dat_tmp);
146
147 if let Some(dat) = &self.pack_dat {
148 writer.write_all(dat)?;
149 } else {
150 writer.write_all(&[0u8; TREE_HEADER_SIZE as usize])?;
151 }
152
153 let mut offset = existing_dat_len;
154
155 for (hash_hex, encoded) in entries {
156 let hash = hex_to_bytes(hash_hex)?;
157 if existing_hashes.contains(&hash) {
158 continue;
159 }
160 let data_len = encoded.len() as u64;
161 writer.write_all(&hash)?;
163 writer.write_all(&data_len.to_le_bytes())?;
164 writer.write_all(encoded)?;
165
166 all_idx_entries.push(TreeIdxEntry {
167 hash,
168 offset,
169 size: data_len,
170 });
171 offset += 32 + 8 + data_len;
172 }
173
174 writer.flush()?;
175 }
176
177 let total_count = all_idx_entries.len() as u32;
179 dat_tmp.seek(SeekFrom::Start(0))?;
180 dat_tmp.write_all(TREE_PACK_MAGIC)?;
181 dat_tmp.write_all(&TREE_PACK_VERSION.to_le_bytes())?;
182 dat_tmp.write_all(&total_count.to_le_bytes())?;
183 dat_tmp.flush()?;
184
185 dat_tmp
187 .persist(&dat_path)
188 .map_err(|e| ChkpttError::Other(e.error.to_string()))?;
189
190 all_idx_entries.sort_unstable_by(|a, b| a.hash.cmp(&b.hash));
192 let mut idx_buf: Vec<u8> = Vec::with_capacity(all_idx_entries.len() * TREE_IDX_ENTRY_SIZE);
193 for entry in &all_idx_entries {
194 idx_buf.extend_from_slice(&entry.hash);
195 idx_buf.extend_from_slice(&entry.offset.to_le_bytes());
196 idx_buf.extend_from_slice(&entry.size.to_le_bytes());
197 }
198 let idx_tmp_path = idx_path.with_extension("idx.tmp");
199 std::fs::write(&idx_tmp_path, &idx_buf)?;
200 std::fs::rename(&idx_tmp_path, &idx_path)?;
201
202 Ok(())
203 }
204
205 pub fn read(&self, hash_hex: &str) -> Result<Vec<TreeEntry>> {
207 if let Some(data) = self.read_from_pack(hash_hex) {
209 let entries: Vec<TreeEntry> = bitcode::decode(&data)?;
210 return Ok(entries);
211 }
212
213 let path = self.tree_path(hash_hex);
215 let data = match std::fs::read(&path) {
216 Ok(data) => data,
217 Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
218 return Err(ChkpttError::ObjectNotFound(hash_hex.to_string()));
219 }
220 Err(error) => return Err(error.into()),
221 };
222 let entries: Vec<TreeEntry> = bitcode::decode(&data)?;
223 Ok(entries)
224 }
225
226 fn read_from_pack(&self, hash_hex: &str) -> Option<Vec<u8>> {
228 let idx = self.pack_idx.as_ref()?;
229 let dat = self.pack_dat.as_ref()?;
230 let hash_bytes = hex_to_bytes(hash_hex).ok()?;
231
232 let mut lo = 0usize;
234 let mut hi = self.pack_entry_count;
235 while lo < hi {
236 let mid = lo + (hi - lo) / 2;
237 let pos = mid * TREE_IDX_ENTRY_SIZE;
238 let mid_hash = &idx[pos..pos + 32];
239 match mid_hash.cmp(&hash_bytes) {
240 std::cmp::Ordering::Equal => {
241 let offset = u64::from_le_bytes(idx[pos + 32..pos + 40].try_into().unwrap());
242 let size = u64::from_le_bytes(idx[pos + 40..pos + 48].try_into().unwrap());
243 let data_start = offset as usize + 32 + 8;
244 let data_end = data_start + size as usize;
245 if data_end > dat.len() {
246 return None;
247 }
248 return Some(dat[data_start..data_end].to_vec());
249 }
250 std::cmp::Ordering::Less => lo = mid + 1,
251 std::cmp::Ordering::Greater => hi = mid,
252 }
253 }
254 None
255 }
256}
257
258fn hex_to_bytes(hex: &str) -> Result<[u8; 32]> {
259 let mut bytes = [0u8; 32];
260 if hex.len() != 64 {
261 return Err(ChkpttError::Other(format!(
262 "Invalid hash length: {}",
263 hex.len()
264 )));
265 }
266 for i in 0..32 {
267 bytes[i] = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16)
268 .map_err(|_| ChkpttError::Other("Invalid hex".into()))?;
269 }
270 Ok(bytes)
271}