Skip to main content

chkpt_core/store/
tree.rs

1use crate::error::{ChkpttError, Result};
2use bitcode::{Decode, Encode};
3use memmap2::Mmap;
4use std::io::{BufWriter, Seek, SeekFrom, Write};
5use std::path::PathBuf;
6use tempfile::NamedTempFile;
7
/// Magic bytes written at the very start of the `trees.dat` header.
const TREE_PACK_MAGIC: &[u8; 4] = b"CKTR";
/// Pack format version, written as a little-endian `u32` right after the magic.
const TREE_PACK_VERSION: u32 = 1;
/// Size in bytes of one record in `trees.idx`.
const TREE_IDX_ENTRY_SIZE: usize = 32 + 8 + 8; // hash(32) + offset(8) + size(8)
/// Size in bytes of the `trees.dat` header: magic(4) + version(4) + entry count(4).
const TREE_HEADER_SIZE: u64 = 12;
12
/// Kind of filesystem object a [`TreeEntry`] describes.
///
/// NOTE(review): this type is bitcode-encoded as part of `TreeEntry`, and
/// `TreeStore::write` hashes the encoded bytes, so reordering or inserting
/// variants presumably changes stored tree hashes — confirm before editing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Encode, Decode)]
pub enum EntryType {
    File,
    Dir,
    Symlink,
}
19
20#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
21pub struct TreeEntry {
22    pub name: String,
23    pub entry_type: EntryType,
24    pub hash: [u8; 32],
25    pub size: u64,
26    pub mode: u32,
27}
28
/// Index entry for a tree in the pack.
///
/// Mirrors one `TREE_IDX_ENTRY_SIZE`-byte record of `trees.idx`; `offset` and
/// `size` are serialised as little-endian `u64`s.
#[derive(Debug, Clone)]
struct TreeIdxEntry {
    /// 32-byte hash that names the tree; `trees.idx` is sorted by this field.
    hash: [u8; 32],
    /// Byte offset in `trees.dat` where the tree's record (hash + size + data) starts.
    offset: u64,
    /// Length in bytes of the tree's encoded data, excluding the 40-byte record prefix.
    size: u64,
}
36
/// Store for tree objects under `base_dir`.
///
/// Trees live either as loose files (`<base_dir>/<2 hex chars>/<remaining hex chars>`)
/// or inside a pack made of `trees.dat` plus the sorted index `trees.idx`.
/// The pack files are memory-mapped once, in [`TreeStore::new`].
pub struct TreeStore {
    /// Directory that holds the loose tree files and the pack files.
    base_dir: PathBuf,
    /// mmap'd tree pack data file (if exists)
    pack_dat: Option<Mmap>,
    /// mmap'd tree pack index file (if exists)
    pack_idx: Option<Mmap>,
    /// Number of complete index records in `pack_idx` (0 when there is no pack).
    pack_entry_count: usize,
}
45
46impl TreeStore {
47    pub fn new(base_dir: PathBuf) -> Self {
48        let dat_path = base_dir.join("trees.dat");
49        let idx_path = base_dir.join("trees.idx");
50
51        let (pack_dat, pack_idx, pack_entry_count) = match (
52            std::fs::File::open(&dat_path),
53            std::fs::File::open(&idx_path),
54        ) {
55            (Ok(dat_file), Ok(idx_file)) => match (unsafe { Mmap::map(&dat_file) }, unsafe {
56                Mmap::map(&idx_file)
57            }) {
58                (Ok(dat), Ok(idx)) => {
59                    let count = idx.len() / TREE_IDX_ENTRY_SIZE;
60                    (Some(dat), Some(idx), count)
61                }
62                _ => (None, None, 0),
63            },
64            (Err(dat_error), _) if dat_error.kind() == std::io::ErrorKind::NotFound => {
65                (None, None, 0)
66            }
67            (_, Err(idx_error)) if idx_error.kind() == std::io::ErrorKind::NotFound => {
68                (None, None, 0)
69            }
70            _ => (None, None, 0),
71        };
72
73        Self {
74            base_dir,
75            pack_dat,
76            pack_idx,
77            pack_entry_count,
78        }
79    }
80
81    fn tree_path(&self, hash_hex: &str) -> PathBuf {
82        let (prefix, rest) = hash_hex.split_at(2);
83        self.base_dir.join(prefix).join(rest)
84    }
85
86    /// Write tree entries (sorted by name). Returns hash hex.
87    /// Used for single-tree writes (tests, small operations).
88    pub fn write(&self, entries: &[TreeEntry]) -> Result<String> {
89        let mut sorted = entries.to_vec();
90        sorted.sort_unstable_by(|a, b| a.name.cmp(&b.name));
91        let encoded = bitcode::encode(&sorted);
92        let hash_hex = blake3::hash(&encoded).to_hex().to_string();
93        let path = self.tree_path(&hash_hex);
94        let parent = path
95            .parent()
96            .ok_or_else(|| ChkpttError::Other("Tree path missing parent directory".into()))?;
97        std::fs::create_dir_all(parent)?;
98
99        let mut tmp = NamedTempFile::new_in(parent)?;
100        tmp.write_all(&encoded)?;
101        tmp.flush()?;
102
103        match tmp.persist_noclobber(&path) {
104            Ok(_) => Ok(hash_hex),
105            Err(error) if error.error.kind() == std::io::ErrorKind::AlreadyExists => Ok(hash_hex),
106            Err(error) => Err(error.error.into()),
107        }
108    }
109
110    /// Write a batch of pre-computed trees to a pack file.
111    /// Each entry is (hash_hex, encoded_data).
112    pub fn write_pack(&self, entries: &[(String, Vec<u8>)]) -> Result<()> {
113        if entries.is_empty() {
114            return Ok(());
115        }
116
117        std::fs::create_dir_all(&self.base_dir)?;
118
119        let dat_path = self.base_dir.join("trees.dat");
120        let idx_path = self.base_dir.join("trees.idx");
121
122        // Collect existing idx entries
123        let mut all_idx_entries: Vec<TreeIdxEntry> = Vec::new();
124        let mut existing_hashes: std::collections::HashSet<[u8; 32]> =
125            std::collections::HashSet::new();
126
127        let existing_dat_len = if let (Some(dat), Some(idx)) = (&self.pack_dat, &self.pack_idx) {
128            for i in 0..self.pack_entry_count {
129                let pos = i * TREE_IDX_ENTRY_SIZE;
130                let mut hash = [0u8; 32];
131                hash.copy_from_slice(&idx[pos..pos + 32]);
132                let offset = u64::from_le_bytes(idx[pos + 32..pos + 40].try_into().unwrap());
133                let size = u64::from_le_bytes(idx[pos + 40..pos + 48].try_into().unwrap());
134                existing_hashes.insert(hash);
135                all_idx_entries.push(TreeIdxEntry { hash, offset, size });
136            }
137            dat.len() as u64
138        } else {
139            TREE_HEADER_SIZE
140        };
141
142        // Write new .dat
143        let mut dat_tmp = NamedTempFile::new_in(&self.base_dir)?;
144        {
145            let mut writer = BufWriter::with_capacity(256 * 1024, &mut dat_tmp);
146
147            if let Some(dat) = &self.pack_dat {
148                writer.write_all(dat)?;
149            } else {
150                writer.write_all(&[0u8; TREE_HEADER_SIZE as usize])?;
151            }
152
153            let mut offset = existing_dat_len;
154
155            for (hash_hex, encoded) in entries {
156                let hash = hex_to_bytes(hash_hex)?;
157                if existing_hashes.contains(&hash) {
158                    continue;
159                }
160                let data_len = encoded.len() as u64;
161                // Write: hash(32) + size(8) + data(N)
162                writer.write_all(&hash)?;
163                writer.write_all(&data_len.to_le_bytes())?;
164                writer.write_all(encoded)?;
165
166                all_idx_entries.push(TreeIdxEntry {
167                    hash,
168                    offset,
169                    size: data_len,
170                });
171                offset += 32 + 8 + data_len;
172            }
173
174            writer.flush()?;
175        }
176
177        // Write header
178        let total_count = all_idx_entries.len() as u32;
179        dat_tmp.seek(SeekFrom::Start(0))?;
180        dat_tmp.write_all(TREE_PACK_MAGIC)?;
181        dat_tmp.write_all(&TREE_PACK_VERSION.to_le_bytes())?;
182        dat_tmp.write_all(&total_count.to_le_bytes())?;
183        dat_tmp.flush()?;
184
185        // Persist .dat
186        dat_tmp
187            .persist(&dat_path)
188            .map_err(|e| ChkpttError::Other(e.error.to_string()))?;
189
190        // Sort idx and write
191        all_idx_entries.sort_unstable_by(|a, b| a.hash.cmp(&b.hash));
192        let mut idx_buf: Vec<u8> = Vec::with_capacity(all_idx_entries.len() * TREE_IDX_ENTRY_SIZE);
193        for entry in &all_idx_entries {
194            idx_buf.extend_from_slice(&entry.hash);
195            idx_buf.extend_from_slice(&entry.offset.to_le_bytes());
196            idx_buf.extend_from_slice(&entry.size.to_le_bytes());
197        }
198        let idx_tmp_path = idx_path.with_extension("idx.tmp");
199        std::fs::write(&idx_tmp_path, &idx_buf)?;
200        std::fs::rename(&idx_tmp_path, &idx_path)?;
201
202        Ok(())
203    }
204
205    /// Read tree entries by hash. Checks pack first, then loose files.
206    pub fn read(&self, hash_hex: &str) -> Result<Vec<TreeEntry>> {
207        // Check pack first
208        if let Some(data) = self.read_from_pack(hash_hex) {
209            let entries: Vec<TreeEntry> = bitcode::decode(&data)?;
210            return Ok(entries);
211        }
212
213        // Fall back to loose file
214        let path = self.tree_path(hash_hex);
215        let data = match std::fs::read(&path) {
216            Ok(data) => data,
217            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
218                return Err(ChkpttError::ObjectNotFound(hash_hex.to_string()));
219            }
220            Err(error) => return Err(error.into()),
221        };
222        let entries: Vec<TreeEntry> = bitcode::decode(&data)?;
223        Ok(entries)
224    }
225
226    /// Read raw data from the tree pack by hash.
227    fn read_from_pack(&self, hash_hex: &str) -> Option<Vec<u8>> {
228        let idx = self.pack_idx.as_ref()?;
229        let dat = self.pack_dat.as_ref()?;
230        let hash_bytes = hex_to_bytes(hash_hex).ok()?;
231
232        // Binary search in idx
233        let mut lo = 0usize;
234        let mut hi = self.pack_entry_count;
235        while lo < hi {
236            let mid = lo + (hi - lo) / 2;
237            let pos = mid * TREE_IDX_ENTRY_SIZE;
238            let mid_hash = &idx[pos..pos + 32];
239            match mid_hash.cmp(&hash_bytes) {
240                std::cmp::Ordering::Equal => {
241                    let offset = u64::from_le_bytes(idx[pos + 32..pos + 40].try_into().unwrap());
242                    let size = u64::from_le_bytes(idx[pos + 40..pos + 48].try_into().unwrap());
243                    let data_start = offset as usize + 32 + 8;
244                    let data_end = data_start + size as usize;
245                    if data_end > dat.len() {
246                        return None;
247                    }
248                    return Some(dat[data_start..data_end].to_vec());
249                }
250                std::cmp::Ordering::Less => lo = mid + 1,
251                std::cmp::Ordering::Greater => hi = mid,
252            }
253        }
254        None
255    }
256}
257
258fn hex_to_bytes(hex: &str) -> Result<[u8; 32]> {
259    let mut bytes = [0u8; 32];
260    if hex.len() != 64 {
261        return Err(ChkpttError::Other(format!(
262            "Invalid hash length: {}",
263            hex.len()
264        )));
265    }
266    for i in 0..32 {
267        bytes[i] = u8::from_str_radix(&hex[i * 2..i * 2 + 2], 16)
268            .map_err(|_| ChkpttError::Other("Invalid hex".into()))?;
269    }
270    Ok(bytes)
271}