1use super::core::RepositoryAccess;
2use super::types::ProtocolError;
3use crate::hash::SHA1;
4use crate::internal::metadata::{EntryMeta, MetaAttached};
5use crate::internal::object::types::ObjectType;
6use crate::internal::object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree};
7use crate::internal::pack::{Pack, encode::PackEncoder, entry::Entry};
8use bytes::Bytes;
9use std::collections::{HashSet, VecDeque};
10use std::io::Cursor;
11use tokio;
12use tokio::sync::mpsc;
13use tokio_stream::wrappers::ReceiverStream;
14
/// Generates and unpacks Git pack files on behalf of a protocol handler,
/// resolving objects through a borrowed repository backend.
pub struct PackGenerator<'a, R>
where
    R: RepositoryAccess,
{
    // Read-only handle used to look up commits, trees, and blobs.
    repo_access: &'a R,
}
25
26impl<'a, R> PackGenerator<'a, R>
27where
28 R: RepositoryAccess,
29{
30 pub fn new(repo_access: &'a R) -> Self {
31 Self { repo_access }
32 }
33
34 pub async fn generate_full_pack(
36 &self,
37 want: Vec<String>,
38 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
39 let (tx, rx) = mpsc::channel(1024);
40
41 let all_objects = self.collect_all_objects(want).await?;
43
44 tokio::spawn(async move {
46 if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
47 tracing::error!("Failed to generate pack stream: {}", e);
48 }
49 });
50
51 Ok(ReceiverStream::new(rx))
52 }
53
54 pub async fn generate_incremental_pack(
56 &self,
57 want: Vec<String>,
58 have: Vec<String>,
59 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
60 let (tx, rx) = mpsc::channel(1024);
61
62 let wanted_objects = self.collect_all_objects(want).await?;
64
65 let have_objects = self.collect_all_objects(have).await?;
67
68 let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
70
71 tokio::spawn(async move {
73 if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
74 tracing::error!("Failed to generate incremental pack stream: {}", e);
75 }
76 });
77
78 Ok(ReceiverStream::new(rx))
79 }
80
81 pub async fn unpack_stream(
83 &self,
84 pack_data: Bytes,
85 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
86 use std::sync::{Arc, Mutex};
87
88 let commits = Arc::new(Mutex::new(Vec::new()));
89 let trees = Arc::new(Mutex::new(Vec::new()));
90 let blobs = Arc::new(Mutex::new(Vec::new()));
91
92 let commits_clone = commits.clone();
93 let trees_clone = trees.clone();
94 let blobs_clone = blobs.clone();
95
96 let mut pack = Pack::new(None, None, None, true);
98 let mut cursor = Cursor::new(pack_data.to_vec());
99
100 pack.decode(
102 &mut cursor,
103 move |entry: MetaAttached<Entry, EntryMeta>| match entry.inner.obj_type {
104 ObjectType::Commit => {
105 if let Ok(commit) = Commit::from_bytes(&entry.inner.data, entry.inner.hash) {
106 commits_clone.lock().unwrap().push(commit);
107 } else {
108 tracing::warn!("Failed to parse commit from pack entry");
109 }
110 }
111 ObjectType::Tree => {
112 if let Ok(tree) = Tree::from_bytes(&entry.inner.data, entry.inner.hash) {
113 trees_clone.lock().unwrap().push(tree);
114 } else {
115 tracing::warn!("Failed to parse tree from pack entry");
116 }
117 }
118 ObjectType::Blob => {
119 if let Ok(blob) = Blob::from_bytes(&entry.inner.data, entry.inner.hash) {
120 blobs_clone.lock().unwrap().push(blob);
121 } else {
122 tracing::warn!("Failed to parse blob from pack entry");
123 }
124 }
125 _ => {
126 tracing::warn!("Unknown object type in pack: {:?}", entry.inner.obj_type);
127 }
128 },
129 None::<fn(SHA1)>,
130 )
131 .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {}", e)))?;
132
133 let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
135 let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
136 let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();
137
138 Ok((commits_result, trees_result, blobs_result))
139 }
140
141 async fn collect_all_objects(
143 &self,
144 commit_hashes: Vec<String>,
145 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
146 let mut commits = Vec::new();
147 let mut trees = Vec::new();
148 let mut blobs = Vec::new();
149
150 let mut visited_commits = HashSet::new();
151 let mut visited_trees = HashSet::new();
152 let mut visited_blobs = HashSet::new();
153
154 let mut commit_queue = VecDeque::from(commit_hashes);
155
156 while let Some(commit_hash) = commit_queue.pop_front() {
158 if visited_commits.contains(&commit_hash) {
159 continue;
160 }
161 visited_commits.insert(commit_hash.clone());
162
163 let commit = self
165 .repo_access
166 .get_commit(&commit_hash)
167 .await
168 .map_err(|e| {
169 ProtocolError::repository_error(format!(
170 "Failed to get commit {}: {}",
171 commit_hash, e
172 ))
173 })?;
174
175 for parent in &commit.parent_commit_ids {
177 let parent_str = parent.to_string();
178 if !visited_commits.contains(&parent_str) {
179 commit_queue.push_back(parent_str);
180 }
181 }
182
183 Box::pin(self.collect_tree_objects(
185 &commit.tree_id.to_string(),
186 &mut trees,
187 &mut blobs,
188 &mut visited_trees,
189 &mut visited_blobs,
190 ))
191 .await?;
192
193 commits.push(commit);
194 }
195
196 Ok((commits, trees, blobs))
197 }
198
199 async fn collect_tree_objects(
201 &self,
202 tree_hash: &str,
203 trees: &mut Vec<Tree>,
204 blobs: &mut Vec<Blob>,
205 visited_trees: &mut HashSet<String>,
206 visited_blobs: &mut HashSet<String>,
207 ) -> Result<(), ProtocolError> {
208 if visited_trees.contains(tree_hash) {
209 return Ok(());
210 }
211 visited_trees.insert(tree_hash.to_string());
212
213 let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
214 ProtocolError::repository_error(format!("Failed to get tree {}: {}", tree_hash, e))
215 })?;
216
217 for entry in &tree.tree_items {
218 let entry_hash = entry.id.to_string();
219 match entry.mode {
220 crate::internal::object::tree::TreeItemMode::Tree => {
221 Box::pin(self.collect_tree_objects(
222 &entry_hash,
223 trees,
224 blobs,
225 visited_trees,
226 visited_blobs,
227 ))
228 .await?;
229 }
230 crate::internal::object::tree::TreeItemMode::Blob
231 | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
232 if !visited_blobs.contains(&entry_hash) {
233 visited_blobs.insert(entry_hash.clone());
234 let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
235 ProtocolError::repository_error(format!(
236 "Failed to get blob {}: {}",
237 entry_hash, e
238 ))
239 })?;
240 blobs.push(blob);
241 }
242 }
243 _ => {}
244 }
245 }
246
247 trees.push(tree);
248 Ok(())
249 }
250
251 fn filter_objects(
253 wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
254 have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
255 ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
256 let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
257 let (have_commits, have_trees, have_blobs) = have;
258
259 let have_commit_hashes: HashSet<String> =
261 have_commits.iter().map(|c| c.id.to_string()).collect();
262 let have_tree_hashes: HashSet<String> =
263 have_trees.iter().map(|t| t.id.to_string()).collect();
264 let have_blob_hashes: HashSet<String> =
265 have_blobs.iter().map(|b| b.id.to_string()).collect();
266
267 let filtered_commits: Vec<Commit> = wanted_commits
269 .into_iter()
270 .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
271 .collect();
272
273 let filtered_trees: Vec<Tree> = wanted_trees
274 .into_iter()
275 .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
276 .collect();
277
278 let filtered_blobs: Vec<Blob> = wanted_blobs
279 .into_iter()
280 .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
281 .collect();
282
283 (filtered_commits, filtered_trees, filtered_blobs)
284 }
285
286 async fn generate_pack_stream(
288 objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
289 tx: mpsc::Sender<Vec<u8>>,
290 ) -> Result<(), ProtocolError> {
291 let (commits, trees, blobs) = objects;
292
293 let mut entries = Vec::new();
295
296 for commit in commits {
297 entries.push(Entry::from(commit));
298 }
299
300 for tree in trees {
301 entries.push(Entry::from(tree));
302 }
303
304 for blob in blobs {
305 entries.push(Entry::from(blob));
306 }
307
308 let (pack_tx, mut pack_rx) = mpsc::channel(1024);
310 let (entry_tx, entry_rx) = mpsc::channel(1024);
311 let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); tokio::spawn(async move {
315 if let Err(e) = encoder.encode(entry_rx).await {
316 tracing::error!("Failed to encode pack: {}", e);
317 }
318 });
319
320 tokio::spawn(async move {
322 for entry in entries {
323 if entry_tx
324 .send(MetaAttached {
325 inner: entry,
326 meta: EntryMeta::new(),
327 })
328 .await
329 .is_err()
330 {
331 break; }
333 }
334 });
336
337 while let Some(chunk) = pack_rx.recv().await {
339 if tx.send(chunk).await.is_err() {
340 break; }
342 }
343
344 Ok(())
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use async_trait::async_trait;
    use bytes::Bytes;

    /// Minimal `RepositoryAccess` stub: the round-trip test below only
    /// exercises the pack encode/decode paths, so every method is a no-op.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Encodes a tiny in-memory repository (one commit, one tree, two
    /// blobs) into a pack and decodes it back, asserting every object
    /// survives the round trip with its hash intact.
    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        // Two blobs under a single tree, referenced by one root commit.
        let hello_blob = Blob::from_content("hello");
        let world_blob = Blob::from_content("world");

        let tree = Tree::from_tree_items(vec![
            TreeItem::new(TreeItemMode::Blob, hello_blob.id, "hello.txt".to_string()),
            TreeItem::new(TreeItemMode::Blob, world_blob.id, "world.txt".to_string()),
        ])
        .unwrap();

        // Both signatures share the same identity; a tiny helper keeps the
        // construction in one place.
        fn sig(kind: SignatureType) -> Signature {
            Signature::new(
                kind,
                "tester".to_string(),
                "tester@example.com".to_string(),
            )
        }
        let commit = Commit::new(
            sig(SignatureType::Author),
            sig(SignatureType::Committer),
            tree.id,
            vec![],
            "init commit",
        );

        // Encode all three object kinds into a pack stream.
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![hello_blob.clone(), world_blob.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        // Drain the chunk stream into one contiguous pack buffer.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        println!("Encoded pack size: {} bytes", pack_bytes.len());

        // Decode the buffer back through the generator.
        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        println!(
            "Decoded commits: {:?}",
            decoded_commits
                .iter()
                .map(|c| c.id.to_string())
                .collect::<Vec<_>>()
        );
        println!(
            "Decoded trees: {:?}",
            decoded_trees
                .iter()
                .map(|t| t.id.to_string())
                .collect::<Vec<_>>()
        );
        println!(
            "Decoded blobs: {:?}",
            decoded_blobs
                .iter()
                .map(|b| b.id.to_string())
                .collect::<Vec<_>>()
        );

        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Blob order in a pack is not guaranteed; compare as sorted sets.
        let mut expected_blob_ids = vec![hello_blob.id.to_string(), world_blob.id.to_string()];
        expected_blob_ids.sort();
        let mut actual_blob_ids: Vec<String> = decoded_blobs
            .iter()
            .map(|b| b.id.to_string())
            .collect();
        actual_blob_ids.sort();
        assert_eq!(expected_blob_ids, actual_blob_ids);
    }
}