// git_internal/protocol/pack.rs

1use bytes::Bytes;
2use std::collections::{HashSet, VecDeque};
3use std::io::Cursor;
4use tokio;
5use tokio::sync::mpsc;
6use tokio_stream::wrappers::ReceiverStream;
7
8use super::core::RepositoryAccess;
9use super::types::ProtocolError;
10use crate::internal::object::types::ObjectType;
11use crate::internal::object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree};
12use crate::internal::pack::{Pack, encode::PackEncoder, entry::Entry};
13
/// Pack generation service for Git protocol operations
///
/// This handles the core Git pack generation logic internally within git-internal,
/// using the RepositoryAccess trait only for data access.
pub struct PackGenerator<'a, R>
where
    R: RepositoryAccess,
{
    // Backend used to look up commits, trees and blobs; borrowed so the
    // generator imposes no ownership requirements on the repository handle.
    repo_access: &'a R,
}
24
25impl<'a, R> PackGenerator<'a, R>
26where
27    R: RepositoryAccess,
28{
29    pub fn new(repo_access: &'a R) -> Self {
30        Self { repo_access }
31    }
32
33    /// Generate a full pack containing all requested objects
34    pub async fn generate_full_pack(
35        &self,
36        want: Vec<String>,
37    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
38        let (tx, rx) = mpsc::channel(1024);
39
40        // Collect all objects needed for the wanted commits
41        let all_objects = self.collect_all_objects(want).await?;
42
43        // Generate pack data
44        tokio::spawn(async move {
45            if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
46                tracing::error!("Failed to generate pack stream: {}", e);
47            }
48        });
49
50        Ok(ReceiverStream::new(rx))
51    }
52
53    /// Generate an incremental pack containing only objects not in 'have'
54    pub async fn generate_incremental_pack(
55        &self,
56        want: Vec<String>,
57        have: Vec<String>,
58    ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
59        let (tx, rx) = mpsc::channel(1024);
60
61        // Collect objects for wanted commits
62        let wanted_objects = self.collect_all_objects(want).await?;
63
64        // Collect objects for have commits (to exclude)
65        let have_objects = self.collect_all_objects(have).await?;
66
67        // Filter out objects that are already in 'have'
68        let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
69
70        // Generate pack data
71        tokio::spawn(async move {
72            if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
73                tracing::error!("Failed to generate incremental pack stream: {}", e);
74            }
75        });
76
77        Ok(ReceiverStream::new(rx))
78    }
79
80    /// Unpack incoming pack stream and extract objects
81    pub async fn unpack_stream(
82        &self,
83        pack_data: Bytes,
84    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
85        use std::sync::{Arc, Mutex};
86
87        let commits = Arc::new(Mutex::new(Vec::new()));
88        let trees = Arc::new(Mutex::new(Vec::new()));
89        let blobs = Arc::new(Mutex::new(Vec::new()));
90
91        let commits_clone = commits.clone();
92        let trees_clone = trees.clone();
93        let blobs_clone = blobs.clone();
94
95        // Create a Pack instance for decoding
96        let mut pack = Pack::new(None, None, None, true);
97        let mut cursor = Cursor::new(pack_data.to_vec());
98
99        // Decode the pack and collect entries
100        pack.decode(
101            &mut cursor,
102            move |entry: Entry, _offset: usize| match entry.obj_type {
103                ObjectType::Commit => {
104                    if let Ok(commit) = Commit::from_bytes(&entry.data, entry.hash) {
105                        commits_clone.lock().unwrap().push(commit);
106                    } else {
107                        tracing::warn!("Failed to parse commit from pack entry");
108                    }
109                }
110                ObjectType::Tree => {
111                    if let Ok(tree) = Tree::from_bytes(&entry.data, entry.hash) {
112                        trees_clone.lock().unwrap().push(tree);
113                    } else {
114                        tracing::warn!("Failed to parse tree from pack entry");
115                    }
116                }
117                ObjectType::Blob => {
118                    if let Ok(blob) = Blob::from_bytes(&entry.data, entry.hash) {
119                        blobs_clone.lock().unwrap().push(blob);
120                    } else {
121                        tracing::warn!("Failed to parse blob from pack entry");
122                    }
123                }
124                _ => {
125                    tracing::warn!("Unknown object type in pack: {:?}", entry.obj_type);
126                }
127            },
128        )
129        .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {}", e)))?;
130
131        // Extract the results
132        let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
133        let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
134        let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();
135
136        Ok((commits_result, trees_result, blobs_result))
137    }
138
139    /// Collect all objects reachable from the given commit hashes
140    async fn collect_all_objects(
141        &self,
142        commit_hashes: Vec<String>,
143    ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
144        let mut commits = Vec::new();
145        let mut trees = Vec::new();
146        let mut blobs = Vec::new();
147
148        let mut visited_commits = HashSet::new();
149        let mut visited_trees = HashSet::new();
150        let mut visited_blobs = HashSet::new();
151
152        let mut commit_queue = VecDeque::from(commit_hashes);
153
154        // BFS traversal of commit graph
155        while let Some(commit_hash) = commit_queue.pop_front() {
156            if visited_commits.contains(&commit_hash) {
157                continue;
158            }
159            visited_commits.insert(commit_hash.clone());
160
161            // Get commit object
162            let commit = self
163                .repo_access
164                .get_commit(&commit_hash)
165                .await
166                .map_err(|e| {
167                    ProtocolError::repository_error(format!(
168                        "Failed to get commit {}: {}",
169                        commit_hash, e
170                    ))
171                })?;
172
173            // Add parent commits to queue
174            for parent in &commit.parent_commit_ids {
175                let parent_str = parent.to_string();
176                if !visited_commits.contains(&parent_str) {
177                    commit_queue.push_back(parent_str);
178                }
179            }
180
181            // Collect tree objects
182            Box::pin(self.collect_tree_objects(
183                &commit.tree_id.to_string(),
184                &mut trees,
185                &mut blobs,
186                &mut visited_trees,
187                &mut visited_blobs,
188            ))
189            .await?;
190
191            commits.push(commit);
192        }
193
194        Ok((commits, trees, blobs))
195    }
196
197    /// Recursively collect tree and blob objects
198    async fn collect_tree_objects(
199        &self,
200        tree_hash: &str,
201        trees: &mut Vec<Tree>,
202        blobs: &mut Vec<Blob>,
203        visited_trees: &mut HashSet<String>,
204        visited_blobs: &mut HashSet<String>,
205    ) -> Result<(), ProtocolError> {
206        if visited_trees.contains(tree_hash) {
207            return Ok(());
208        }
209        visited_trees.insert(tree_hash.to_string());
210
211        let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
212            ProtocolError::repository_error(format!("Failed to get tree {}: {}", tree_hash, e))
213        })?;
214
215        for entry in &tree.tree_items {
216            let entry_hash = entry.id.to_string();
217            match entry.mode {
218                crate::internal::object::tree::TreeItemMode::Tree => {
219                    Box::pin(self.collect_tree_objects(
220                        &entry_hash,
221                        trees,
222                        blobs,
223                        visited_trees,
224                        visited_blobs,
225                    ))
226                    .await?;
227                }
228                crate::internal::object::tree::TreeItemMode::Blob
229                | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
230                    if !visited_blobs.contains(&entry_hash) {
231                        visited_blobs.insert(entry_hash.clone());
232                        let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
233                            ProtocolError::repository_error(format!(
234                                "Failed to get blob {}: {}",
235                                entry_hash, e
236                            ))
237                        })?;
238                        blobs.push(blob);
239                    }
240                }
241                _ => {}
242            }
243        }
244
245        trees.push(tree);
246        Ok(())
247    }
248
249    /// Filter objects to exclude those already in 'have'
250    fn filter_objects(
251        wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
252        have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
253    ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
254        let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
255        let (have_commits, have_trees, have_blobs) = have;
256
257        // Create hash sets for efficient lookup
258        let have_commit_hashes: HashSet<String> =
259            have_commits.iter().map(|c| c.id.to_string()).collect();
260        let have_tree_hashes: HashSet<String> =
261            have_trees.iter().map(|t| t.id.to_string()).collect();
262        let have_blob_hashes: HashSet<String> =
263            have_blobs.iter().map(|b| b.id.to_string()).collect();
264
265        // Filter out objects that are in 'have'
266        let filtered_commits: Vec<Commit> = wanted_commits
267            .into_iter()
268            .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
269            .collect();
270
271        let filtered_trees: Vec<Tree> = wanted_trees
272            .into_iter()
273            .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
274            .collect();
275
276        let filtered_blobs: Vec<Blob> = wanted_blobs
277            .into_iter()
278            .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
279            .collect();
280
281        (filtered_commits, filtered_trees, filtered_blobs)
282    }
283
284    /// Generate pack stream from objects
285    async fn generate_pack_stream(
286        objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
287        tx: mpsc::Sender<Vec<u8>>,
288    ) -> Result<(), ProtocolError> {
289        let (commits, trees, blobs) = objects;
290
291        // Convert objects to entries
292        let mut entries = Vec::new();
293
294        for commit in commits {
295            entries.push(Entry::from(commit));
296        }
297
298        for tree in trees {
299            entries.push(Entry::from(tree));
300        }
301
302        for blob in blobs {
303            entries.push(Entry::from(blob));
304        }
305
306        // Create PackEncoder and encode entries
307        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
308        let (entry_tx, entry_rx) = mpsc::channel(1024);
309        let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); // window_size = 10
310
311        // Spawn encoding task
312        tokio::spawn(async move {
313            if let Err(e) = encoder.encode(entry_rx).await {
314                tracing::error!("Failed to encode pack: {}", e);
315            }
316        });
317
318        // Send entries to encoder
319        tokio::spawn(async move {
320            for entry in entries {
321                if entry_tx.send(entry).await.is_err() {
322                    break; // Receiver dropped
323                }
324            }
325            // Drop sender to signal end of entries
326        });
327
328        // Forward pack data to output channel
329        while let Some(chunk) = pack_rx.recv().await {
330            if tx.send(chunk).await.is_err() {
331                break; // Receiver dropped
332            }
333        }
334
335        Ok(())
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use async_trait::async_trait;
    use bytes::Bytes;

    // Minimal stand-in repository: the round-trip test below only exercises
    // `generate_pack_stream` and `unpack_stream`, neither of which touches
    // the repository, so every method can return an empty or dummy value.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    // Round-trip test: encode two blobs, a tree and a commit into a pack,
    // then decode that pack and verify the same object ids come back out.
    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        // Create two Blob objects
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        // Create a Tree containing two file items
        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        // Create a Commit pointing to the Tree
        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Generate pack stream and gather every emitted chunk into one buffer.
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![blob1.clone(), blob2.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        println!("Encoded pack size: {} bytes", pack_bytes.len());

        // Unpack the pack stream
        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        println!(
            "Decoded commits: {:?}",
            decoded_commits
                .iter()
                .map(|c| c.id.to_string())
                .collect::<Vec<_>>()
        );
        println!(
            "Decoded trees:   {:?}",
            decoded_trees
                .iter()
                .map(|t| t.id.to_string())
                .collect::<Vec<_>>()
        );
        println!(
            "Decoded blobs:   {:?}",
            decoded_blobs
                .iter()
                .map(|b| b.id.to_string())
                .collect::<Vec<_>>()
        );

        // Verify object ID roundtrip consistency
        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Blob order inside a pack is not guaranteed, so compare sorted ids.
        let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
        orig_blob_ids.sort();
        let mut decoded_blob_ids = decoded_blobs
            .iter()
            .map(|b| b.id.to_string())
            .collect::<Vec<_>>();
        decoded_blob_ids.sort();
        assert_eq!(orig_blob_ids, decoded_blob_ids);
    }
}