git_internal/protocol/
pack.rs

1//! Transport-agnostic pack generator that reuses repository storage traits to walk commits, expand
2//! trees/blobs, and either stream packs to clients or unpack uploads for server-side ingestion.
3
4use std::{
5    collections::{HashSet, VecDeque},
6    io::Cursor,
7};
8
9use bytes::Bytes;
10use tokio::{self, sync::mpsc};
11use tokio_stream::wrappers::ReceiverStream;
12
13use super::{core::RepositoryAccess, types::ProtocolError};
14use crate::{
15    hash::ObjectHash,
16    internal::{
17        metadata::{EntryMeta, MetaAttached},
18        object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree, types::ObjectType},
19        pack::{Pack, encode::PackEncoder, entry::Entry},
20    },
21};
22
/// Pack generation service for Git protocol operations
///
/// This handles the core Git pack generation logic internally within git-internal,
/// using the RepositoryAccess trait only for data access.
pub struct PackGenerator<'a, R>
where
    R: RepositoryAccess,
{
    // Borrowed storage backend used to resolve commits, trees, and blobs.
    repo_access: &'a R,
}
33
34impl<'a, R> PackGenerator<'a, R>
35where
36    R: RepositoryAccess,
37{
    /// Creates a generator that borrows `repo_access` for all object lookups.
    pub fn new(repo_access: &'a R) -> Self {
        Self { repo_access }
    }
41
42    /// Generate a full pack containing all requested objects
43    pub async fn generate_full_pack(
44        &self,
45        want: Vec<String>,
46    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
47        let (tx, rx) = mpsc::channel(1024);
48
49        // Collect all objects needed for the wanted commits
50        let all_objects = self.collect_all_objects(want).await?;
51
52        // Generate pack data
53        tokio::spawn(async move {
54            if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
55                tracing::error!("Failed to generate pack stream: {}", e);
56            }
57        });
58
59        Ok(ReceiverStream::new(rx))
60    }
61
62    /// Generate an incremental pack containing only objects not in 'have'
63    pub async fn generate_incremental_pack(
64        &self,
65        want: Vec<String>,
66        have: Vec<String>,
67    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
68        let (tx, rx) = mpsc::channel(1024);
69
70        // Collect objects for wanted commits
71        let wanted_objects = self.collect_all_objects(want).await?;
72
73        // Collect objects for have commits (to exclude)
74        let have_objects = self.collect_all_objects(have).await?;
75
76        // Filter out objects that are already in 'have'
77        let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
78
79        // Generate pack data
80        tokio::spawn(async move {
81            if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
82                tracing::error!("Failed to generate incremental pack stream: {}", e);
83            }
84        });
85
86        Ok(ReceiverStream::new(rx))
87    }
88
    /// Unpack incoming pack stream and extract objects
    ///
    /// Decodes `pack_data` with [`Pack::decode`] and sorts each entry into one
    /// of three buckets by object type. Entries that fail to parse, or whose
    /// type is not commit/tree/blob, are logged and skipped instead of
    /// aborting the whole unpack.
    pub async fn unpack_stream(
        &self,
        pack_data: Bytes,
    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
        use std::sync::{Arc, Mutex};

        // The decode callback is a `move` closure, so results are gathered
        // through shared Arc<Mutex<Vec<_>>> handles cloned into it.
        let commits = Arc::new(Mutex::new(Vec::new()));
        let trees = Arc::new(Mutex::new(Vec::new()));
        let blobs = Arc::new(Mutex::new(Vec::new()));

        let commits_clone = commits.clone();
        let trees_clone = trees.clone();
        let blobs_clone = blobs.clone();

        // Create a Pack instance for decoding
        let mut pack = Pack::new(None, None, None, true);
        let mut cursor = Cursor::new(pack_data.to_vec());

        // Decode the pack and collect entries
        pack.decode(
            &mut cursor,
            move |entry: MetaAttached<Entry, EntryMeta>| match entry.inner.obj_type {
                ObjectType::Commit => {
                    if let Ok(commit) = Commit::from_bytes(&entry.inner.data, entry.inner.hash) {
                        commits_clone.lock().unwrap().push(commit);
                    } else {
                        tracing::warn!("Failed to parse commit from pack entry");
                    }
                }
                ObjectType::Tree => {
                    if let Ok(tree) = Tree::from_bytes(&entry.inner.data, entry.inner.hash) {
                        trees_clone.lock().unwrap().push(tree);
                    } else {
                        tracing::warn!("Failed to parse tree from pack entry");
                    }
                }
                ObjectType::Blob => {
                    if let Ok(blob) = Blob::from_bytes(&entry.inner.data, entry.inner.hash) {
                        blobs_clone.lock().unwrap().push(blob);
                    } else {
                        tracing::warn!("Failed to parse blob from pack entry");
                    }
                }
                _ => {
                    tracing::warn!("Unknown object type in pack: {:?}", entry.inner.obj_type);
                }
            },
            None::<fn(ObjectHash)>,
        )
        .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {e}")))?;

        // Extract the results.
        // NOTE(review): `try_unwrap(...).unwrap()` assumes `Pack::decode` drops
        // the callback (and with it the cloned Arcs) before returning, leaving
        // exactly one strong reference; if the Pack ever retains the closure
        // this panics — confirm against the Pack API.
        let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
        let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
        let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();

        Ok((commits_result, trees_result, blobs_result))
    }
148
149    /// Collect all objects reachable from the given commit hashes
150    async fn collect_all_objects(
151        &self,
152        commit_hashes: Vec<String>,
153    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
154        let mut commits = Vec::new();
155        let mut trees = Vec::new();
156        let mut blobs = Vec::new();
157
158        let mut visited_commits = HashSet::new();
159        let mut visited_trees = HashSet::new();
160        let mut visited_blobs = HashSet::new();
161
162        let mut commit_queue = VecDeque::from(commit_hashes);
163
164        // BFS traversal of commit graph
165        while let Some(commit_hash) = commit_queue.pop_front() {
166            if visited_commits.contains(&commit_hash) {
167                continue;
168            }
169            visited_commits.insert(commit_hash.clone());
170
171            // Get commit object
172            let commit = self
173                .repo_access
174                .get_commit(&commit_hash)
175                .await
176                .map_err(|e| {
177                    ProtocolError::repository_error(format!(
178                        "Failed to get commit {commit_hash}: {e}"
179                    ))
180                })?;
181
182            // Add parent commits to queue
183            for parent in &commit.parent_commit_ids {
184                let parent_str = parent.to_string();
185                if !visited_commits.contains(&parent_str) {
186                    commit_queue.push_back(parent_str);
187                }
188            }
189
190            // Collect tree objects
191            Box::pin(self.collect_tree_objects(
192                &commit.tree_id.to_string(),
193                &mut trees,
194                &mut blobs,
195                &mut visited_trees,
196                &mut visited_blobs,
197            ))
198            .await?;
199
200            commits.push(commit);
201        }
202
203        Ok((commits, trees, blobs))
204    }
205
206    /// Recursively collect tree and blob objects
207    async fn collect_tree_objects(
208        &self,
209        tree_hash: &str,
210        trees: &mut Vec<Tree>,
211        blobs: &mut Vec<Blob>,
212        visited_trees: &mut HashSet<String>,
213        visited_blobs: &mut HashSet<String>,
214    ) -> Result<(), ProtocolError> {
215        if visited_trees.contains(tree_hash) {
216            return Ok(());
217        }
218        visited_trees.insert(tree_hash.to_string());
219
220        let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
221            ProtocolError::repository_error(format!("Failed to get tree {tree_hash}: {e}"))
222        })?;
223
224        for entry in &tree.tree_items {
225            let entry_hash = entry.id.to_string();
226            match entry.mode {
227                crate::internal::object::tree::TreeItemMode::Tree => {
228                    Box::pin(self.collect_tree_objects(
229                        &entry_hash,
230                        trees,
231                        blobs,
232                        visited_trees,
233                        visited_blobs,
234                    ))
235                    .await?;
236                }
237                crate::internal::object::tree::TreeItemMode::Blob
238                | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
239                    if !visited_blobs.contains(&entry_hash) {
240                        visited_blobs.insert(entry_hash.clone());
241                        let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
242                            ProtocolError::repository_error(format!(
243                                "Failed to get blob {entry_hash}: {e}"
244                            ))
245                        })?;
246                        blobs.push(blob);
247                    }
248                }
249                _ => {}
250            }
251        }
252
253        trees.push(tree);
254        Ok(())
255    }
256
257    /// Filter objects to exclude those already in 'have'
258    fn filter_objects(
259        wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
260        have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
261    ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
262        let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
263        let (have_commits, have_trees, have_blobs) = have;
264
265        // Create hash sets for efficient lookup
266        let have_commit_hashes: HashSet<String> =
267            have_commits.iter().map(|c| c.id.to_string()).collect();
268        let have_tree_hashes: HashSet<String> =
269            have_trees.iter().map(|t| t.id.to_string()).collect();
270        let have_blob_hashes: HashSet<String> =
271            have_blobs.iter().map(|b| b.id.to_string()).collect();
272
273        // Filter out objects that are in 'have'
274        let filtered_commits: Vec<Commit> = wanted_commits
275            .into_iter()
276            .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
277            .collect();
278
279        let filtered_trees: Vec<Tree> = wanted_trees
280            .into_iter()
281            .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
282            .collect();
283
284        let filtered_blobs: Vec<Blob> = wanted_blobs
285            .into_iter()
286            .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
287            .collect();
288
289        (filtered_commits, filtered_trees, filtered_blobs)
290    }
291
292    /// Generate pack stream from objects
293    async fn generate_pack_stream(
294        objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
295        tx: mpsc::Sender<Vec<u8>>,
296    ) -> Result<(), ProtocolError> {
297        let (commits, trees, blobs) = objects;
298
299        // Convert objects to entries
300        let mut entries = Vec::new();
301
302        for commit in commits {
303            entries.push(Entry::from(commit));
304        }
305
306        for tree in trees {
307            entries.push(Entry::from(tree));
308        }
309
310        for blob in blobs {
311            entries.push(Entry::from(blob));
312        }
313
314        // Create PackEncoder and encode entries
315        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
316        let (entry_tx, entry_rx) = mpsc::channel(1024);
317        let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); // window_size = 10
318
319        // Spawn encoding task
320        tokio::spawn(async move {
321            if let Err(e) = encoder.encode(entry_rx).await {
322                tracing::error!("Failed to encode pack: {}", e);
323            }
324        });
325
326        // Send entries to encoder
327        tokio::spawn(async move {
328            for entry in entries {
329                if entry_tx
330                    .send(MetaAttached {
331                        inner: entry,
332                        meta: EntryMeta::new(),
333                    })
334                    .await
335                    .is_err()
336                {
337                    break; // Receiver dropped
338                }
339            }
340            // Drop sender to signal end of entries
341        });
342
343        // Forward pack data to output channel
344        while let Some(chunk) = pack_rx.recv().await {
345            if tx.send(chunk).await.is_err() {
346                break; // Receiver dropped
347            }
348        }
349
350        Ok(())
351    }
352}
353
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use bytes::Bytes;

    use super::*;
    use crate::{
        hash::{HashKind, set_hash_kind_for_test},
        internal::object::{
            blob::Blob,
            commit::Commit,
            signature::{Signature, SignatureType},
            tree::{Tree, TreeItem, TreeItemMode},
        },
    };
    /// Dummy repository access for testing
    ///
    /// The roundtrip test drives `generate_pack_stream` / `unpack_stream`
    /// directly, so these accessors are never consulted; they exist only to
    /// satisfy the `RepositoryAccess` bound.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Encode and decode a pack, asserting that all object IDs survive the roundtrip.
    async fn run_pack_roundtrip(kind: HashKind) {
        // Pin the hash algorithm for the duration of this test run.
        let _guard = set_hash_kind_for_test(kind);
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        // A single tree holding both blobs, plus a root commit pointing at it.
        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Encode: one commit, one tree, two blobs -> raw pack bytes.
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![blob1.clone(), blob2.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        // Drain the channel into a contiguous byte buffer.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }

        // Decode the pack back into objects.
        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Compare blob ids as sorted lists: the test does not rely on the
        // decode order of blobs within the pack.
        let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
        orig_blob_ids.sort_unstable();
        let mut decoded_blob_ids = decoded_blobs
            .iter()
            .map(|b| b.id.to_string())
            .collect::<Vec<_>>();
        decoded_blob_ids.sort_unstable();
        assert_eq!(orig_blob_ids, decoded_blob_ids);
    }

    /// Pack encode/decode roundtrip using SHA-1 and SHA-256
    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        run_pack_roundtrip(HashKind::Sha1).await;
        run_pack_roundtrip(HashKind::Sha256).await;
    }
}