1use super::core::RepositoryAccess;
2use super::types::ProtocolError;
3use crate::hash::ObjectHash;
4use crate::internal::metadata::{EntryMeta, MetaAttached};
5use crate::internal::object::types::ObjectType;
6use crate::internal::object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree};
7use crate::internal::pack::{Pack, encode::PackEncoder, entry::Entry};
8use bytes::Bytes;
9use std::collections::{HashSet, VecDeque};
10use std::io::Cursor;
11use tokio;
12use tokio::sync::mpsc;
13use tokio_stream::wrappers::ReceiverStream;
14
/// Generates Git pack files from repository objects and parses incoming
/// pack data, on top of an abstract `RepositoryAccess` backend.
pub struct PackGenerator<'a, R>
where
    R: RepositoryAccess,
{
    // Backend used to look up commits, trees and blobs by hash.
    repo_access: &'a R,
}
25
26impl<'a, R> PackGenerator<'a, R>
27where
28 R: RepositoryAccess,
29{
    /// Creates a pack generator backed by the given repository accessor.
    pub fn new(repo_access: &'a R) -> Self {
        Self { repo_access }
    }
33
34 pub async fn generate_full_pack(
36 &self,
37 want: Vec<String>,
38 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
39 let (tx, rx) = mpsc::channel(1024);
40
41 let all_objects = self.collect_all_objects(want).await?;
43
44 tokio::spawn(async move {
46 if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
47 tracing::error!("Failed to generate pack stream: {}", e);
48 }
49 });
50
51 Ok(ReceiverStream::new(rx))
52 }
53
54 pub async fn generate_incremental_pack(
56 &self,
57 want: Vec<String>,
58 have: Vec<String>,
59 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
60 let (tx, rx) = mpsc::channel(1024);
61
62 let wanted_objects = self.collect_all_objects(want).await?;
64
65 let have_objects = self.collect_all_objects(have).await?;
67
68 let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
70
71 tokio::spawn(async move {
73 if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
74 tracing::error!("Failed to generate incremental pack stream: {}", e);
75 }
76 });
77
78 Ok(ReceiverStream::new(rx))
79 }
80
81 pub async fn unpack_stream(
83 &self,
84 pack_data: Bytes,
85 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
86 use std::sync::{Arc, Mutex};
87
88 let commits = Arc::new(Mutex::new(Vec::new()));
89 let trees = Arc::new(Mutex::new(Vec::new()));
90 let blobs = Arc::new(Mutex::new(Vec::new()));
91
92 let commits_clone = commits.clone();
93 let trees_clone = trees.clone();
94 let blobs_clone = blobs.clone();
95
96 let mut pack = Pack::new(None, None, None, true);
98 let mut cursor = Cursor::new(pack_data.to_vec());
99
100 pack.decode(
102 &mut cursor,
103 move |entry: MetaAttached<Entry, EntryMeta>| match entry.inner.obj_type {
104 ObjectType::Commit => {
105 if let Ok(commit) = Commit::from_bytes(&entry.inner.data, entry.inner.hash) {
106 commits_clone.lock().unwrap().push(commit);
107 } else {
108 tracing::warn!("Failed to parse commit from pack entry");
109 }
110 }
111 ObjectType::Tree => {
112 if let Ok(tree) = Tree::from_bytes(&entry.inner.data, entry.inner.hash) {
113 trees_clone.lock().unwrap().push(tree);
114 } else {
115 tracing::warn!("Failed to parse tree from pack entry");
116 }
117 }
118 ObjectType::Blob => {
119 if let Ok(blob) = Blob::from_bytes(&entry.inner.data, entry.inner.hash) {
120 blobs_clone.lock().unwrap().push(blob);
121 } else {
122 tracing::warn!("Failed to parse blob from pack entry");
123 }
124 }
125 _ => {
126 tracing::warn!("Unknown object type in pack: {:?}", entry.inner.obj_type);
127 }
128 },
129 None::<fn(ObjectHash)>,
130 )
131 .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {}", e)))?;
132
133 let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
135 let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
136 let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();
137
138 Ok((commits_result, trees_result, blobs_result))
139 }
140
141 async fn collect_all_objects(
143 &self,
144 commit_hashes: Vec<String>,
145 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
146 let mut commits = Vec::new();
147 let mut trees = Vec::new();
148 let mut blobs = Vec::new();
149
150 let mut visited_commits = HashSet::new();
151 let mut visited_trees = HashSet::new();
152 let mut visited_blobs = HashSet::new();
153
154 let mut commit_queue = VecDeque::from(commit_hashes);
155
156 while let Some(commit_hash) = commit_queue.pop_front() {
158 if visited_commits.contains(&commit_hash) {
159 continue;
160 }
161 visited_commits.insert(commit_hash.clone());
162
163 let commit = self
165 .repo_access
166 .get_commit(&commit_hash)
167 .await
168 .map_err(|e| {
169 ProtocolError::repository_error(format!(
170 "Failed to get commit {}: {}",
171 commit_hash, e
172 ))
173 })?;
174
175 for parent in &commit.parent_commit_ids {
177 let parent_str = parent.to_string();
178 if !visited_commits.contains(&parent_str) {
179 commit_queue.push_back(parent_str);
180 }
181 }
182
183 Box::pin(self.collect_tree_objects(
185 &commit.tree_id.to_string(),
186 &mut trees,
187 &mut blobs,
188 &mut visited_trees,
189 &mut visited_blobs,
190 ))
191 .await?;
192
193 commits.push(commit);
194 }
195
196 Ok((commits, trees, blobs))
197 }
198
199 async fn collect_tree_objects(
201 &self,
202 tree_hash: &str,
203 trees: &mut Vec<Tree>,
204 blobs: &mut Vec<Blob>,
205 visited_trees: &mut HashSet<String>,
206 visited_blobs: &mut HashSet<String>,
207 ) -> Result<(), ProtocolError> {
208 if visited_trees.contains(tree_hash) {
209 return Ok(());
210 }
211 visited_trees.insert(tree_hash.to_string());
212
213 let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
214 ProtocolError::repository_error(format!("Failed to get tree {}: {}", tree_hash, e))
215 })?;
216
217 for entry in &tree.tree_items {
218 let entry_hash = entry.id.to_string();
219 match entry.mode {
220 crate::internal::object::tree::TreeItemMode::Tree => {
221 Box::pin(self.collect_tree_objects(
222 &entry_hash,
223 trees,
224 blobs,
225 visited_trees,
226 visited_blobs,
227 ))
228 .await?;
229 }
230 crate::internal::object::tree::TreeItemMode::Blob
231 | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
232 if !visited_blobs.contains(&entry_hash) {
233 visited_blobs.insert(entry_hash.clone());
234 let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
235 ProtocolError::repository_error(format!(
236 "Failed to get blob {}: {}",
237 entry_hash, e
238 ))
239 })?;
240 blobs.push(blob);
241 }
242 }
243 _ => {}
244 }
245 }
246
247 trees.push(tree);
248 Ok(())
249 }
250
251 fn filter_objects(
253 wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
254 have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
255 ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
256 let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
257 let (have_commits, have_trees, have_blobs) = have;
258
259 let have_commit_hashes: HashSet<String> =
261 have_commits.iter().map(|c| c.id.to_string()).collect();
262 let have_tree_hashes: HashSet<String> =
263 have_trees.iter().map(|t| t.id.to_string()).collect();
264 let have_blob_hashes: HashSet<String> =
265 have_blobs.iter().map(|b| b.id.to_string()).collect();
266
267 let filtered_commits: Vec<Commit> = wanted_commits
269 .into_iter()
270 .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
271 .collect();
272
273 let filtered_trees: Vec<Tree> = wanted_trees
274 .into_iter()
275 .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
276 .collect();
277
278 let filtered_blobs: Vec<Blob> = wanted_blobs
279 .into_iter()
280 .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
281 .collect();
282
283 (filtered_commits, filtered_trees, filtered_blobs)
284 }
285
286 async fn generate_pack_stream(
288 objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
289 tx: mpsc::Sender<Vec<u8>>,
290 ) -> Result<(), ProtocolError> {
291 let (commits, trees, blobs) = objects;
292
293 let mut entries = Vec::new();
295
296 for commit in commits {
297 entries.push(Entry::from(commit));
298 }
299
300 for tree in trees {
301 entries.push(Entry::from(tree));
302 }
303
304 for blob in blobs {
305 entries.push(Entry::from(blob));
306 }
307
308 let (pack_tx, mut pack_rx) = mpsc::channel(1024);
310 let (entry_tx, entry_rx) = mpsc::channel(1024);
311 let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); tokio::spawn(async move {
315 if let Err(e) = encoder.encode(entry_rx).await {
316 tracing::error!("Failed to encode pack: {}", e);
317 }
318 });
319
320 tokio::spawn(async move {
322 for entry in entries {
323 if entry_tx
324 .send(MetaAttached {
325 inner: entry,
326 meta: EntryMeta::new(),
327 })
328 .await
329 .is_err()
330 {
331 break; }
333 }
334 });
336
337 while let Some(chunk) = pack_rx.recv().await {
339 if tx.send(chunk).await.is_err() {
340 break; }
342 }
343
344 Ok(())
345 }
346}
347
348#[cfg(test)]
349mod tests {
350 use super::*;
351 use crate::hash::{HashKind, set_hash_kind_for_test};
352 use crate::internal::object::blob::Blob;
353 use crate::internal::object::commit::Commit;
354 use crate::internal::object::signature::{Signature, SignatureType};
355 use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
356 use async_trait::async_trait;
357 use bytes::Bytes;
358
    // Minimal no-op `RepositoryAccess` for the round-trip tests. The pack
    // encode/decode path under test never touches the backend, so every
    // method is a stub.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        // Only method that errors: object lookup is unused by these tests.
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }
400
401 #[tokio::test]
402 async fn test_pack_roundtrip_encode_decode() {
403 let _guard = set_hash_kind_for_test(HashKind::Sha1);
404 let blob1 = Blob::from_content("hello");
406 let blob2 = Blob::from_content("world");
407
408 let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
410 let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
411 let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();
412
413 let author = Signature::new(
415 SignatureType::Author,
416 "tester".to_string(),
417 "tester@example.com".to_string(),
418 );
419 let committer = Signature::new(
420 SignatureType::Committer,
421 "tester".to_string(),
422 "tester@example.com".to_string(),
423 );
424 let commit = Commit::new(author, committer, tree.id, vec![], "init commit");
425
426 let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
428 PackGenerator::<DummyRepoAccess>::generate_pack_stream(
429 (
430 vec![commit.clone()],
431 vec![tree.clone()],
432 vec![blob1.clone(), blob2.clone()],
433 ),
434 tx,
435 )
436 .await
437 .unwrap();
438
439 let mut pack_bytes: Vec<u8> = Vec::new();
440 while let Some(chunk) = rx.recv().await {
441 pack_bytes.extend_from_slice(&chunk);
442 }
443 println!("Encoded pack size: {} bytes", pack_bytes.len());
444
445 let dummy = DummyRepoAccess;
447 let generator = PackGenerator::new(&dummy);
448 let (decoded_commits, decoded_trees, decoded_blobs) = generator
449 .unpack_stream(Bytes::from(pack_bytes))
450 .await
451 .unwrap();
452
453 println!(
454 "Decoded commits: {:?}",
455 decoded_commits
456 .iter()
457 .map(|c| c.id.to_string())
458 .collect::<Vec<_>>()
459 );
460 println!(
461 "Decoded trees: {:?}",
462 decoded_trees
463 .iter()
464 .map(|t| t.id.to_string())
465 .collect::<Vec<_>>()
466 );
467 println!(
468 "Decoded blobs: {:?}",
469 decoded_blobs
470 .iter()
471 .map(|b| b.id.to_string())
472 .collect::<Vec<_>>()
473 );
474
475 assert_eq!(decoded_commits.len(), 1);
477 assert_eq!(decoded_trees.len(), 1);
478 assert_eq!(decoded_blobs.len(), 2);
479
480 assert_eq!(decoded_commits[0].id, commit.id);
481 assert_eq!(decoded_trees[0].id, tree.id);
482
483 let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
484 orig_blob_ids.sort();
485 let mut decoded_blob_ids = decoded_blobs
486 .iter()
487 .map(|b| b.id.to_string())
488 .collect::<Vec<_>>();
489 decoded_blob_ids.sort();
490 assert_eq!(orig_blob_ids, decoded_blob_ids);
491 }
492 #[tokio::test]
493 async fn test_pack_roundtrip_encode_decode_sha256() {
494 let _guard = set_hash_kind_for_test(HashKind::Sha256);
495 let blob1 = Blob::from_content("hello");
497 let blob2 = Blob::from_content("world");
498
499 let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
501 let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
502 let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();
503
504 let author = Signature::new(
506 SignatureType::Author,
507 "tester".to_string(),
508 "tester@example.com".to_string(),
509 );
510 let committer = Signature::new(
511 SignatureType::Committer,
512 "tester".to_string(),
513 "tester@example.com".to_string(),
514 );
515 let commit = Commit::new(author, committer, tree.id, vec![], "init commit");
516
517 let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
519 PackGenerator::<DummyRepoAccess>::generate_pack_stream(
520 (
521 vec![commit.clone()],
522 vec![tree.clone()],
523 vec![blob1.clone(), blob2.clone()],
524 ),
525 tx,
526 )
527 .await
528 .unwrap();
529
530 let mut pack_bytes: Vec<u8> = Vec::new();
531 while let Some(chunk) = rx.recv().await {
532 pack_bytes.extend_from_slice(&chunk);
533 }
534 println!("Encoded pack size: {} bytes", pack_bytes.len());
535
536 let dummy = DummyRepoAccess;
538 let generator = PackGenerator::new(&dummy);
539 let (decoded_commits, decoded_trees, decoded_blobs) = generator
540 .unpack_stream(Bytes::from(pack_bytes))
541 .await
542 .unwrap();
543
544 println!(
545 "Decoded commits: {:?}",
546 decoded_commits
547 .iter()
548 .map(|c| c.id.to_string())
549 .collect::<Vec<_>>()
550 );
551 println!(
552 "Decoded trees: {:?}",
553 decoded_trees
554 .iter()
555 .map(|t| t.id.to_string())
556 .collect::<Vec<_>>()
557 );
558 println!(
559 "Decoded blobs: {:?}",
560 decoded_blobs
561 .iter()
562 .map(|b| b.id.to_string())
563 .collect::<Vec<_>>()
564 );
565
566 assert_eq!(decoded_commits.len(), 1);
568 assert_eq!(decoded_trees.len(), 1);
569 assert_eq!(decoded_blobs.len(), 2);
570
571 assert_eq!(decoded_commits[0].id, commit.id);
572 assert_eq!(decoded_trees[0].id, tree.id);
573
574 let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
575 orig_blob_ids.sort();
576 let mut decoded_blob_ids = decoded_blobs
577 .iter()
578 .map(|b| b.id.to_string())
579 .collect::<Vec<_>>();
580 decoded_blob_ids.sort();
581 assert_eq!(orig_blob_ids, decoded_blob_ids);
582 }
583}