// git_internal/protocol/pack.rs

1use super::core::RepositoryAccess;
2use super::types::ProtocolError;
3use crate::hash::ObjectHash;
4use crate::internal::metadata::{EntryMeta, MetaAttached};
5use crate::internal::object::types::ObjectType;
6use crate::internal::object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree};
7use crate::internal::pack::{Pack, encode::PackEncoder, entry::Entry};
8use bytes::Bytes;
9use std::collections::{HashSet, VecDeque};
10use std::io::Cursor;
11use tokio;
12use tokio::sync::mpsc;
13use tokio_stream::wrappers::ReceiverStream;
14
/// Pack generation service for Git protocol operations
///
/// This handles the core Git pack generation logic internally within git-internal,
/// using the RepositoryAccess trait only for data access.
///
/// The `RepositoryAccess` bound is declared on the `impl` blocks rather than
/// on the struct itself, per Rust convention: bounds on struct definitions
/// force every mention of the type to repeat them without adding safety.
pub struct PackGenerator<'a, R> {
    /// Backing repository used to fetch commits, trees and blobs.
    repo_access: &'a R,
}
25
26impl<'a, R> PackGenerator<'a, R>
27where
28    R: RepositoryAccess,
29{
30    pub fn new(repo_access: &'a R) -> Self {
31        Self { repo_access }
32    }
33
34    /// Generate a full pack containing all requested objects
35    pub async fn generate_full_pack(
36        &self,
37        want: Vec<String>,
38    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
39        let (tx, rx) = mpsc::channel(1024);
40
41        // Collect all objects needed for the wanted commits
42        let all_objects = self.collect_all_objects(want).await?;
43
44        // Generate pack data
45        tokio::spawn(async move {
46            if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
47                tracing::error!("Failed to generate pack stream: {}", e);
48            }
49        });
50
51        Ok(ReceiverStream::new(rx))
52    }
53
54    /// Generate an incremental pack containing only objects not in 'have'
55    pub async fn generate_incremental_pack(
56        &self,
57        want: Vec<String>,
58        have: Vec<String>,
59    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
60        let (tx, rx) = mpsc::channel(1024);
61
62        // Collect objects for wanted commits
63        let wanted_objects = self.collect_all_objects(want).await?;
64
65        // Collect objects for have commits (to exclude)
66        let have_objects = self.collect_all_objects(have).await?;
67
68        // Filter out objects that are already in 'have'
69        let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
70
71        // Generate pack data
72        tokio::spawn(async move {
73            if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
74                tracing::error!("Failed to generate incremental pack stream: {}", e);
75            }
76        });
77
78        Ok(ReceiverStream::new(rx))
79    }
80
    /// Unpack incoming pack stream and extract objects
    ///
    /// Decodes `pack_data` (a complete Git pack) and partitions the decoded
    /// entries into commits, trees and blobs. Entries that fail to parse, or
    /// carry an unexpected object type, are logged and skipped rather than
    /// failing the whole unpack.
    ///
    /// # Errors
    /// Returns `ProtocolError::invalid_request` when the pack itself cannot
    /// be decoded.
    pub async fn unpack_stream(
        &self,
        pack_data: Bytes,
    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
        use std::sync::{Arc, Mutex};

        // Accumulators shared with the decode callback. Arc<Mutex<_>> keeps
        // accumulation sound even if Pack::decode invokes the callback off
        // the current thread — NOTE(review): decode's threading model is not
        // visible here; confirm against Pack's implementation.
        let commits = Arc::new(Mutex::new(Vec::new()));
        let trees = Arc::new(Mutex::new(Vec::new()));
        let blobs = Arc::new(Mutex::new(Vec::new()));

        // Clones moved into the callback; the originals are unwrapped below.
        let commits_clone = commits.clone();
        let trees_clone = trees.clone();
        let blobs_clone = blobs.clone();

        // Create a Pack instance for decoding
        let mut pack = Pack::new(None, None, None, true);
        let mut cursor = Cursor::new(pack_data.to_vec());

        // Decode the pack and collect entries, dispatching on each entry's
        // object type; parse failures are non-fatal (logged and skipped).
        pack.decode(
            &mut cursor,
            move |entry: MetaAttached<Entry, EntryMeta>| match entry.inner.obj_type {
                ObjectType::Commit => {
                    if let Ok(commit) = Commit::from_bytes(&entry.inner.data, entry.inner.hash) {
                        commits_clone.lock().unwrap().push(commit);
                    } else {
                        tracing::warn!("Failed to parse commit from pack entry");
                    }
                }
                ObjectType::Tree => {
                    if let Ok(tree) = Tree::from_bytes(&entry.inner.data, entry.inner.hash) {
                        trees_clone.lock().unwrap().push(tree);
                    } else {
                        tracing::warn!("Failed to parse tree from pack entry");
                    }
                }
                ObjectType::Blob => {
                    if let Ok(blob) = Blob::from_bytes(&entry.inner.data, entry.inner.hash) {
                        blobs_clone.lock().unwrap().push(blob);
                    } else {
                        tracing::warn!("Failed to parse blob from pack entry");
                    }
                }
                _ => {
                    tracing::warn!("Unknown object type in pack: {:?}", entry.inner.obj_type);
                }
            },
            None::<fn(ObjectHash)>,
        )
        .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {}", e)))?;

        // Extract the results. This assumes decode has dropped the callback
        // (and its Arc clones) by the time it returns — otherwise
        // Arc::try_unwrap panics. TODO(review): confirm decode does not
        // retain the callback past its return.
        let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
        let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
        let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();

        Ok((commits_result, trees_result, blobs_result))
    }
140
141    /// Collect all objects reachable from the given commit hashes
142    async fn collect_all_objects(
143        &self,
144        commit_hashes: Vec<String>,
145    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
146        let mut commits = Vec::new();
147        let mut trees = Vec::new();
148        let mut blobs = Vec::new();
149
150        let mut visited_commits = HashSet::new();
151        let mut visited_trees = HashSet::new();
152        let mut visited_blobs = HashSet::new();
153
154        let mut commit_queue = VecDeque::from(commit_hashes);
155
156        // BFS traversal of commit graph
157        while let Some(commit_hash) = commit_queue.pop_front() {
158            if visited_commits.contains(&commit_hash) {
159                continue;
160            }
161            visited_commits.insert(commit_hash.clone());
162
163            // Get commit object
164            let commit = self
165                .repo_access
166                .get_commit(&commit_hash)
167                .await
168                .map_err(|e| {
169                    ProtocolError::repository_error(format!(
170                        "Failed to get commit {}: {}",
171                        commit_hash, e
172                    ))
173                })?;
174
175            // Add parent commits to queue
176            for parent in &commit.parent_commit_ids {
177                let parent_str = parent.to_string();
178                if !visited_commits.contains(&parent_str) {
179                    commit_queue.push_back(parent_str);
180                }
181            }
182
183            // Collect tree objects
184            Box::pin(self.collect_tree_objects(
185                &commit.tree_id.to_string(),
186                &mut trees,
187                &mut blobs,
188                &mut visited_trees,
189                &mut visited_blobs,
190            ))
191            .await?;
192
193            commits.push(commit);
194        }
195
196        Ok((commits, trees, blobs))
197    }
198
199    /// Recursively collect tree and blob objects
200    async fn collect_tree_objects(
201        &self,
202        tree_hash: &str,
203        trees: &mut Vec<Tree>,
204        blobs: &mut Vec<Blob>,
205        visited_trees: &mut HashSet<String>,
206        visited_blobs: &mut HashSet<String>,
207    ) -> Result<(), ProtocolError> {
208        if visited_trees.contains(tree_hash) {
209            return Ok(());
210        }
211        visited_trees.insert(tree_hash.to_string());
212
213        let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
214            ProtocolError::repository_error(format!("Failed to get tree {}: {}", tree_hash, e))
215        })?;
216
217        for entry in &tree.tree_items {
218            let entry_hash = entry.id.to_string();
219            match entry.mode {
220                crate::internal::object::tree::TreeItemMode::Tree => {
221                    Box::pin(self.collect_tree_objects(
222                        &entry_hash,
223                        trees,
224                        blobs,
225                        visited_trees,
226                        visited_blobs,
227                    ))
228                    .await?;
229                }
230                crate::internal::object::tree::TreeItemMode::Blob
231                | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
232                    if !visited_blobs.contains(&entry_hash) {
233                        visited_blobs.insert(entry_hash.clone());
234                        let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
235                            ProtocolError::repository_error(format!(
236                                "Failed to get blob {}: {}",
237                                entry_hash, e
238                            ))
239                        })?;
240                        blobs.push(blob);
241                    }
242                }
243                _ => {}
244            }
245        }
246
247        trees.push(tree);
248        Ok(())
249    }
250
251    /// Filter objects to exclude those already in 'have'
252    fn filter_objects(
253        wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
254        have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
255    ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
256        let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
257        let (have_commits, have_trees, have_blobs) = have;
258
259        // Create hash sets for efficient lookup
260        let have_commit_hashes: HashSet<String> =
261            have_commits.iter().map(|c| c.id.to_string()).collect();
262        let have_tree_hashes: HashSet<String> =
263            have_trees.iter().map(|t| t.id.to_string()).collect();
264        let have_blob_hashes: HashSet<String> =
265            have_blobs.iter().map(|b| b.id.to_string()).collect();
266
267        // Filter out objects that are in 'have'
268        let filtered_commits: Vec<Commit> = wanted_commits
269            .into_iter()
270            .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
271            .collect();
272
273        let filtered_trees: Vec<Tree> = wanted_trees
274            .into_iter()
275            .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
276            .collect();
277
278        let filtered_blobs: Vec<Blob> = wanted_blobs
279            .into_iter()
280            .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
281            .collect();
282
283        (filtered_commits, filtered_trees, filtered_blobs)
284    }
285
286    /// Generate pack stream from objects
287    async fn generate_pack_stream(
288        objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
289        tx: mpsc::Sender<Vec<u8>>,
290    ) -> Result<(), ProtocolError> {
291        let (commits, trees, blobs) = objects;
292
293        // Convert objects to entries
294        let mut entries = Vec::new();
295
296        for commit in commits {
297            entries.push(Entry::from(commit));
298        }
299
300        for tree in trees {
301            entries.push(Entry::from(tree));
302        }
303
304        for blob in blobs {
305            entries.push(Entry::from(blob));
306        }
307
308        // Create PackEncoder and encode entries
309        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
310        let (entry_tx, entry_rx) = mpsc::channel(1024);
311        let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); // window_size = 10
312
313        // Spawn encoding task
314        tokio::spawn(async move {
315            if let Err(e) = encoder.encode(entry_rx).await {
316                tracing::error!("Failed to encode pack: {}", e);
317            }
318        });
319
320        // Send entries to encoder
321        tokio::spawn(async move {
322            for entry in entries {
323                if entry_tx
324                    .send(MetaAttached {
325                        inner: entry,
326                        meta: EntryMeta::new(),
327                    })
328                    .await
329                    .is_err()
330                {
331                    break; // Receiver dropped
332                }
333            }
334            // Drop sender to signal end of entries
335        });
336
337        // Forward pack data to output channel
338        while let Some(chunk) = pack_rx.recv().await {
339            if tx.send(chunk).await.is_err() {
340                break; // Receiver dropped
341            }
342        }
343
344        Ok(())
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::{HashKind, set_hash_kind_for_test};
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use async_trait::async_trait;
    use bytes::Bytes;

    /// Minimal `RepositoryAccess` stub: the roundtrip tests never hit a
    /// real repository, so every method is a trivial/failing placeholder.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Shared roundtrip body (previously duplicated verbatim in both tests):
    /// build two blobs, a tree and a commit, encode them into a pack stream,
    /// decode the pack back, and assert the decoded object ids match the
    /// originals. The active hash kind is whatever guard the caller holds.
    async fn assert_pack_roundtrip() {
        // Create two Blob objects
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        // Create a Tree containing two file items
        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        // Create a Commit pointing to the Tree
        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Generate pack stream
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![blob1.clone(), blob2.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        // Drain the stream into a contiguous pack buffer.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }

        // Unpack the pack stream
        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        // Verify object ID roundtrip consistency
        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Blob order inside the pack is not guaranteed; compare sorted ids.
        let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
        orig_blob_ids.sort();
        let mut decoded_blob_ids = decoded_blobs
            .iter()
            .map(|b| b.id.to_string())
            .collect::<Vec<_>>();
        decoded_blob_ids.sort();
        assert_eq!(orig_blob_ids, decoded_blob_ids);
    }

    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        // Guard must outlive the roundtrip so SHA-1 stays active throughout.
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        assert_pack_roundtrip().await;
    }

    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode_sha256() {
        // Guard must outlive the roundtrip so SHA-256 stays active throughout.
        let _guard = set_hash_kind_for_test(HashKind::Sha256);
        assert_pack_roundtrip().await;
    }
}
583}