// git_internal/protocol/pack.rs

use super::core::RepositoryAccess;
use super::types::ProtocolError;
use crate::hash::SHA1;
use crate::internal::metadata::{EntryMeta, MetaAttached};
use crate::internal::object::types::ObjectType;
use crate::internal::object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree};
use crate::internal::pack::{Pack, encode::PackEncoder, entry::Entry};
use bytes::Bytes;
use std::collections::{HashSet, VecDeque};
use std::io::Cursor;
use tokio;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
14
/// Pack generation service for Git protocol operations.
///
/// Handles the core Git pack generation logic internally within
/// git-internal, using the [`RepositoryAccess`] trait only for data access.
///
/// The trait bound lives on the `impl` block rather than the struct
/// definition (idiomatic Rust: bounds on the definition add noise without
/// adding safety, since the only constructor already requires the bound).
pub struct PackGenerator<'a, R> {
    // Borrowed handle used for all object lookups; the generator owns no data.
    repo_access: &'a R,
}
25
26impl<'a, R> PackGenerator<'a, R>
27where
28    R: RepositoryAccess,
29{
30    pub fn new(repo_access: &'a R) -> Self {
31        Self { repo_access }
32    }
33
34    /// Generate a full pack containing all requested objects
35    pub async fn generate_full_pack(
36        &self,
37        want: Vec<String>,
38    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
39        let (tx, rx) = mpsc::channel(1024);
40
41        // Collect all objects needed for the wanted commits
42        let all_objects = self.collect_all_objects(want).await?;
43
44        // Generate pack data
45        tokio::spawn(async move {
46            if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
47                tracing::error!("Failed to generate pack stream: {}", e);
48            }
49        });
50
51        Ok(ReceiverStream::new(rx))
52    }
53
54    /// Generate an incremental pack containing only objects not in 'have'
55    pub async fn generate_incremental_pack(
56        &self,
57        want: Vec<String>,
58        have: Vec<String>,
59    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
60        let (tx, rx) = mpsc::channel(1024);
61
62        // Collect objects for wanted commits
63        let wanted_objects = self.collect_all_objects(want).await?;
64
65        // Collect objects for have commits (to exclude)
66        let have_objects = self.collect_all_objects(have).await?;
67
68        // Filter out objects that are already in 'have'
69        let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
70
71        // Generate pack data
72        tokio::spawn(async move {
73            if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
74                tracing::error!("Failed to generate incremental pack stream: {}", e);
75            }
76        });
77
78        Ok(ReceiverStream::new(rx))
79    }
80
81    /// Unpack incoming pack stream and extract objects
82    pub async fn unpack_stream(
83        &self,
84        pack_data: Bytes,
85    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
86        use std::sync::{Arc, Mutex};
87
88        let commits = Arc::new(Mutex::new(Vec::new()));
89        let trees = Arc::new(Mutex::new(Vec::new()));
90        let blobs = Arc::new(Mutex::new(Vec::new()));
91
92        let commits_clone = commits.clone();
93        let trees_clone = trees.clone();
94        let blobs_clone = blobs.clone();
95
96        // Create a Pack instance for decoding
97        let mut pack = Pack::new(None, None, None, true);
98        let mut cursor = Cursor::new(pack_data.to_vec());
99
100        // Decode the pack and collect entries
101        pack.decode(
102            &mut cursor,
103            move |entry: MetaAttached<Entry, EntryMeta>| match entry.inner.obj_type {
104                ObjectType::Commit => {
105                    if let Ok(commit) = Commit::from_bytes(&entry.inner.data, entry.inner.hash) {
106                        commits_clone.lock().unwrap().push(commit);
107                    } else {
108                        tracing::warn!("Failed to parse commit from pack entry");
109                    }
110                }
111                ObjectType::Tree => {
112                    if let Ok(tree) = Tree::from_bytes(&entry.inner.data, entry.inner.hash) {
113                        trees_clone.lock().unwrap().push(tree);
114                    } else {
115                        tracing::warn!("Failed to parse tree from pack entry");
116                    }
117                }
118                ObjectType::Blob => {
119                    if let Ok(blob) = Blob::from_bytes(&entry.inner.data, entry.inner.hash) {
120                        blobs_clone.lock().unwrap().push(blob);
121                    } else {
122                        tracing::warn!("Failed to parse blob from pack entry");
123                    }
124                }
125                _ => {
126                    tracing::warn!("Unknown object type in pack: {:?}", entry.inner.obj_type);
127                }
128            },
129            None::<fn(SHA1)>,
130        )
131        .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {}", e)))?;
132
133        // Extract the results
134        let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
135        let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
136        let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();
137
138        Ok((commits_result, trees_result, blobs_result))
139    }
140
141    /// Collect all objects reachable from the given commit hashes
142    async fn collect_all_objects(
143        &self,
144        commit_hashes: Vec<String>,
145    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
146        let mut commits = Vec::new();
147        let mut trees = Vec::new();
148        let mut blobs = Vec::new();
149
150        let mut visited_commits = HashSet::new();
151        let mut visited_trees = HashSet::new();
152        let mut visited_blobs = HashSet::new();
153
154        let mut commit_queue = VecDeque::from(commit_hashes);
155
156        // BFS traversal of commit graph
157        while let Some(commit_hash) = commit_queue.pop_front() {
158            if visited_commits.contains(&commit_hash) {
159                continue;
160            }
161            visited_commits.insert(commit_hash.clone());
162
163            // Get commit object
164            let commit = self
165                .repo_access
166                .get_commit(&commit_hash)
167                .await
168                .map_err(|e| {
169                    ProtocolError::repository_error(format!(
170                        "Failed to get commit {}: {}",
171                        commit_hash, e
172                    ))
173                })?;
174
175            // Add parent commits to queue
176            for parent in &commit.parent_commit_ids {
177                let parent_str = parent.to_string();
178                if !visited_commits.contains(&parent_str) {
179                    commit_queue.push_back(parent_str);
180                }
181            }
182
183            // Collect tree objects
184            Box::pin(self.collect_tree_objects(
185                &commit.tree_id.to_string(),
186                &mut trees,
187                &mut blobs,
188                &mut visited_trees,
189                &mut visited_blobs,
190            ))
191            .await?;
192
193            commits.push(commit);
194        }
195
196        Ok((commits, trees, blobs))
197    }
198
199    /// Recursively collect tree and blob objects
200    async fn collect_tree_objects(
201        &self,
202        tree_hash: &str,
203        trees: &mut Vec<Tree>,
204        blobs: &mut Vec<Blob>,
205        visited_trees: &mut HashSet<String>,
206        visited_blobs: &mut HashSet<String>,
207    ) -> Result<(), ProtocolError> {
208        if visited_trees.contains(tree_hash) {
209            return Ok(());
210        }
211        visited_trees.insert(tree_hash.to_string());
212
213        let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
214            ProtocolError::repository_error(format!("Failed to get tree {}: {}", tree_hash, e))
215        })?;
216
217        for entry in &tree.tree_items {
218            let entry_hash = entry.id.to_string();
219            match entry.mode {
220                crate::internal::object::tree::TreeItemMode::Tree => {
221                    Box::pin(self.collect_tree_objects(
222                        &entry_hash,
223                        trees,
224                        blobs,
225                        visited_trees,
226                        visited_blobs,
227                    ))
228                    .await?;
229                }
230                crate::internal::object::tree::TreeItemMode::Blob
231                | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
232                    if !visited_blobs.contains(&entry_hash) {
233                        visited_blobs.insert(entry_hash.clone());
234                        let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
235                            ProtocolError::repository_error(format!(
236                                "Failed to get blob {}: {}",
237                                entry_hash, e
238                            ))
239                        })?;
240                        blobs.push(blob);
241                    }
242                }
243                _ => {}
244            }
245        }
246
247        trees.push(tree);
248        Ok(())
249    }
250
251    /// Filter objects to exclude those already in 'have'
252    fn filter_objects(
253        wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
254        have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
255    ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
256        let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
257        let (have_commits, have_trees, have_blobs) = have;
258
259        // Create hash sets for efficient lookup
260        let have_commit_hashes: HashSet<String> =
261            have_commits.iter().map(|c| c.id.to_string()).collect();
262        let have_tree_hashes: HashSet<String> =
263            have_trees.iter().map(|t| t.id.to_string()).collect();
264        let have_blob_hashes: HashSet<String> =
265            have_blobs.iter().map(|b| b.id.to_string()).collect();
266
267        // Filter out objects that are in 'have'
268        let filtered_commits: Vec<Commit> = wanted_commits
269            .into_iter()
270            .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
271            .collect();
272
273        let filtered_trees: Vec<Tree> = wanted_trees
274            .into_iter()
275            .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
276            .collect();
277
278        let filtered_blobs: Vec<Blob> = wanted_blobs
279            .into_iter()
280            .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
281            .collect();
282
283        (filtered_commits, filtered_trees, filtered_blobs)
284    }
285
286    /// Generate pack stream from objects
287    async fn generate_pack_stream(
288        objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
289        tx: mpsc::Sender<Vec<u8>>,
290    ) -> Result<(), ProtocolError> {
291        let (commits, trees, blobs) = objects;
292
293        // Convert objects to entries
294        let mut entries = Vec::new();
295
296        for commit in commits {
297            entries.push(Entry::from(commit));
298        }
299
300        for tree in trees {
301            entries.push(Entry::from(tree));
302        }
303
304        for blob in blobs {
305            entries.push(Entry::from(blob));
306        }
307
308        // Create PackEncoder and encode entries
309        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
310        let (entry_tx, entry_rx) = mpsc::channel(1024);
311        let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); // window_size = 10
312
313        // Spawn encoding task
314        tokio::spawn(async move {
315            if let Err(e) = encoder.encode(entry_rx).await {
316                tracing::error!("Failed to encode pack: {}", e);
317            }
318        });
319
320        // Send entries to encoder
321        tokio::spawn(async move {
322            for entry in entries {
323                if entry_tx
324                    .send(MetaAttached {
325                        inner: entry,
326                        meta: EntryMeta::new(),
327                    })
328                    .await
329                    .is_err()
330                {
331                    break; // Receiver dropped
332                }
333            }
334            // Drop sender to signal end of entries
335        });
336
337        // Forward pack data to output channel
338        while let Some(chunk) = pack_rx.recv().await {
339            if tx.send(chunk).await.is_err() {
340                break; // Receiver dropped
341            }
342        }
343
344        Ok(())
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use async_trait::async_trait;
    use bytes::Bytes;

    /// Minimal `RepositoryAccess` stub: the roundtrip test never touches a
    /// repository, it only needs a value to construct the generator with.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(Vec::new())
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(Vec::new())
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Encode a small commit/tree/blob set into a pack and decode it again,
    /// verifying the object ids survive the roundtrip.
    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        // Two blobs, a tree referencing both, and a commit on that tree.
        let hello_blob = Blob::from_content("hello");
        let world_blob = Blob::from_content("world");

        let tree = Tree::from_tree_items(vec![
            TreeItem::new(TreeItemMode::Blob, hello_blob.id, "hello.txt".to_string()),
            TreeItem::new(TreeItemMode::Blob, world_blob.id, "world.txt".to_string()),
        ])
        .unwrap();

        let make_sig = |kind| {
            Signature::new(
                kind,
                "tester".to_string(),
                "tester@example.com".to_string(),
            )
        };
        let commit = Commit::new(
            make_sig(SignatureType::Author),
            make_sig(SignatureType::Committer),
            tree.id,
            vec![],
            "init commit",
        );

        // Encode everything into a pack and gather the byte stream.
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![hello_blob.clone(), world_blob.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        println!("Encoded pack size: {} bytes", pack_bytes.len());

        // Decode the pack back into objects.
        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        let commit_ids: Vec<String> = decoded_commits.iter().map(|c| c.id.to_string()).collect();
        let tree_ids: Vec<String> = decoded_trees.iter().map(|t| t.id.to_string()).collect();
        let blob_ids: Vec<String> = decoded_blobs.iter().map(|b| b.id.to_string()).collect();
        println!("Decoded commits: {:?}", commit_ids);
        println!("Decoded trees:   {:?}", tree_ids);
        println!("Decoded blobs:   {:?}", blob_ids);

        // Counts and ids must match the originals exactly.
        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Blob order inside the pack is unspecified: compare sorted ids.
        let mut expected_blob_ids = vec![hello_blob.id.to_string(), world_blob.id.to_string()];
        expected_blob_ids.sort();
        let mut actual_blob_ids = blob_ids;
        actual_blob_ids.sort();
        assert_eq!(expected_blob_ids, actual_blob_ids);
    }
}