1use std::{
5 collections::{HashSet, VecDeque},
6 io::Cursor,
7};
8
9use bytes::Bytes;
10use tokio::{self, sync::mpsc};
11use tokio_stream::wrappers::ReceiverStream;
12
13use super::{core::RepositoryAccess, types::ProtocolError};
14use crate::{
15 hash::ObjectHash,
16 internal::{
17 metadata::{EntryMeta, MetaAttached},
18 object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree, types::ObjectType},
19 pack::{Pack, encode::PackEncoder, entry::Entry},
20 },
21};
22
/// Generates and decodes Git pack files on top of a [`RepositoryAccess`]
/// backend.
///
/// Holds only a borrowed repository handle, so a generator is cheap to
/// construct per request.
pub struct PackGenerator<'a, R>
where
    R: RepositoryAccess,
{
    /// Backend used to look up commits, trees and blobs by hash.
    repo_access: &'a R,
}
33
34impl<'a, R> PackGenerator<'a, R>
35where
36 R: RepositoryAccess,
37{
38 pub fn new(repo_access: &'a R) -> Self {
39 Self { repo_access }
40 }
41
42 pub async fn generate_full_pack(
44 &self,
45 want: Vec<String>,
46 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
47 let (tx, rx) = mpsc::channel(1024);
48
49 let all_objects = self.collect_all_objects(want).await?;
51
52 tokio::spawn(async move {
54 if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
55 tracing::error!("Failed to generate pack stream: {}", e);
56 }
57 });
58
59 Ok(ReceiverStream::new(rx))
60 }
61
62 pub async fn generate_incremental_pack(
64 &self,
65 want: Vec<String>,
66 have: Vec<String>,
67 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
68 let (tx, rx) = mpsc::channel(1024);
69
70 let wanted_objects = self.collect_all_objects(want).await?;
72
73 let have_objects = self.collect_all_objects(have).await?;
75
76 let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
78
79 tokio::spawn(async move {
81 if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
82 tracing::error!("Failed to generate incremental pack stream: {}", e);
83 }
84 });
85
86 Ok(ReceiverStream::new(rx))
87 }
88
89 pub async fn unpack_stream(
91 &self,
92 pack_data: Bytes,
93 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
94 use std::sync::{Arc, Mutex};
95
96 let commits = Arc::new(Mutex::new(Vec::new()));
97 let trees = Arc::new(Mutex::new(Vec::new()));
98 let blobs = Arc::new(Mutex::new(Vec::new()));
99
100 let commits_clone = commits.clone();
101 let trees_clone = trees.clone();
102 let blobs_clone = blobs.clone();
103
104 let mut pack = Pack::new(None, None, None, true);
106 let mut cursor = Cursor::new(pack_data.to_vec());
107
108 pack.decode(
110 &mut cursor,
111 move |entry: MetaAttached<Entry, EntryMeta>| match entry.inner.obj_type {
112 ObjectType::Commit => {
113 if let Ok(commit) = Commit::from_bytes(&entry.inner.data, entry.inner.hash) {
114 commits_clone.lock().unwrap().push(commit);
115 } else {
116 tracing::warn!("Failed to parse commit from pack entry");
117 }
118 }
119 ObjectType::Tree => {
120 if let Ok(tree) = Tree::from_bytes(&entry.inner.data, entry.inner.hash) {
121 trees_clone.lock().unwrap().push(tree);
122 } else {
123 tracing::warn!("Failed to parse tree from pack entry");
124 }
125 }
126 ObjectType::Blob => {
127 if let Ok(blob) = Blob::from_bytes(&entry.inner.data, entry.inner.hash) {
128 blobs_clone.lock().unwrap().push(blob);
129 } else {
130 tracing::warn!("Failed to parse blob from pack entry");
131 }
132 }
133 _ => {
134 tracing::warn!("Unknown object type in pack: {:?}", entry.inner.obj_type);
135 }
136 },
137 None::<fn(ObjectHash)>,
138 )
139 .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {e}")))?;
140
141 let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
143 let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
144 let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();
145
146 Ok((commits_result, trees_result, blobs_result))
147 }
148
149 async fn collect_all_objects(
151 &self,
152 commit_hashes: Vec<String>,
153 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
154 let mut commits = Vec::new();
155 let mut trees = Vec::new();
156 let mut blobs = Vec::new();
157
158 let mut visited_commits = HashSet::new();
159 let mut visited_trees = HashSet::new();
160 let mut visited_blobs = HashSet::new();
161
162 let mut commit_queue = VecDeque::from(commit_hashes);
163
164 while let Some(commit_hash) = commit_queue.pop_front() {
166 if visited_commits.contains(&commit_hash) {
167 continue;
168 }
169 visited_commits.insert(commit_hash.clone());
170
171 let commit = self
173 .repo_access
174 .get_commit(&commit_hash)
175 .await
176 .map_err(|e| {
177 ProtocolError::repository_error(format!(
178 "Failed to get commit {commit_hash}: {e}"
179 ))
180 })?;
181
182 for parent in &commit.parent_commit_ids {
184 let parent_str = parent.to_string();
185 if !visited_commits.contains(&parent_str) {
186 commit_queue.push_back(parent_str);
187 }
188 }
189
190 Box::pin(self.collect_tree_objects(
192 &commit.tree_id.to_string(),
193 &mut trees,
194 &mut blobs,
195 &mut visited_trees,
196 &mut visited_blobs,
197 ))
198 .await?;
199
200 commits.push(commit);
201 }
202
203 Ok((commits, trees, blobs))
204 }
205
206 async fn collect_tree_objects(
208 &self,
209 tree_hash: &str,
210 trees: &mut Vec<Tree>,
211 blobs: &mut Vec<Blob>,
212 visited_trees: &mut HashSet<String>,
213 visited_blobs: &mut HashSet<String>,
214 ) -> Result<(), ProtocolError> {
215 if visited_trees.contains(tree_hash) {
216 return Ok(());
217 }
218 visited_trees.insert(tree_hash.to_string());
219
220 let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
221 ProtocolError::repository_error(format!("Failed to get tree {tree_hash}: {e}"))
222 })?;
223
224 for entry in &tree.tree_items {
225 let entry_hash = entry.id.to_string();
226 match entry.mode {
227 crate::internal::object::tree::TreeItemMode::Tree => {
228 Box::pin(self.collect_tree_objects(
229 &entry_hash,
230 trees,
231 blobs,
232 visited_trees,
233 visited_blobs,
234 ))
235 .await?;
236 }
237 crate::internal::object::tree::TreeItemMode::Blob
238 | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
239 if !visited_blobs.contains(&entry_hash) {
240 visited_blobs.insert(entry_hash.clone());
241 let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
242 ProtocolError::repository_error(format!(
243 "Failed to get blob {entry_hash}: {e}"
244 ))
245 })?;
246 blobs.push(blob);
247 }
248 }
249 _ => {}
250 }
251 }
252
253 trees.push(tree);
254 Ok(())
255 }
256
257 fn filter_objects(
259 wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
260 have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
261 ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
262 let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
263 let (have_commits, have_trees, have_blobs) = have;
264
265 let have_commit_hashes: HashSet<String> =
267 have_commits.iter().map(|c| c.id.to_string()).collect();
268 let have_tree_hashes: HashSet<String> =
269 have_trees.iter().map(|t| t.id.to_string()).collect();
270 let have_blob_hashes: HashSet<String> =
271 have_blobs.iter().map(|b| b.id.to_string()).collect();
272
273 let filtered_commits: Vec<Commit> = wanted_commits
275 .into_iter()
276 .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
277 .collect();
278
279 let filtered_trees: Vec<Tree> = wanted_trees
280 .into_iter()
281 .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
282 .collect();
283
284 let filtered_blobs: Vec<Blob> = wanted_blobs
285 .into_iter()
286 .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
287 .collect();
288
289 (filtered_commits, filtered_trees, filtered_blobs)
290 }
291
292 async fn generate_pack_stream(
294 objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
295 tx: mpsc::Sender<Vec<u8>>,
296 ) -> Result<(), ProtocolError> {
297 let (commits, trees, blobs) = objects;
298
299 let mut entries = Vec::new();
301
302 for commit in commits {
303 entries.push(Entry::from(commit));
304 }
305
306 for tree in trees {
307 entries.push(Entry::from(tree));
308 }
309
310 for blob in blobs {
311 entries.push(Entry::from(blob));
312 }
313
314 let (pack_tx, mut pack_rx) = mpsc::channel(1024);
316 let (entry_tx, entry_rx) = mpsc::channel(1024);
317 let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); tokio::spawn(async move {
321 if let Err(e) = encoder.encode(entry_rx).await {
322 tracing::error!("Failed to encode pack: {}", e);
323 }
324 });
325
326 tokio::spawn(async move {
328 for entry in entries {
329 if entry_tx
330 .send(MetaAttached {
331 inner: entry,
332 meta: EntryMeta::new(),
333 })
334 .await
335 .is_err()
336 {
337 break; }
339 }
340 });
342
343 while let Some(chunk) = pack_rx.recv().await {
345 if tx.send(chunk).await.is_err() {
346 break; }
348 }
349
350 Ok(())
351 }
352}
353
354#[cfg(test)]
355mod tests {
356 use async_trait::async_trait;
357 use bytes::Bytes;
358
359 use super::*;
360 use crate::{
361 hash::{HashKind, set_hash_kind_for_test},
362 internal::object::{
363 blob::Blob,
364 commit::Commit,
365 signature::{Signature, SignatureType},
366 tree::{Tree, TreeItem, TreeItemMode},
367 },
368 };
    // Stub repository backend: the roundtrip test below never touches a
    // real repository, so this carries no state.
    #[derive(Clone)]
    struct DummyRepoAccess;
372
    // Every method is a fixed no-op or error: the pack roundtrip exercises
    // only `generate_pack_stream`/`unpack_stream`, which never call these.
    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }
411
    /// End-to-end roundtrip: encodes a small commit/tree/blob set into a
    /// pack stream, decodes it back, and compares object ids — under the
    /// given hash kind.
    async fn run_pack_roundtrip(kind: HashKind) {
        // Pins the process-wide hash kind for the duration of this call.
        let _guard = set_hash_kind_for_test(kind);
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Encode via the private helper directly; no repository lookups are
        // needed since the objects are supplied in-memory.
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![blob1.clone(), blob2.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        // Drain the stream into a single contiguous pack buffer.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }

        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Blob order within the decoded pack is not asserted by this test,
        // so compare sorted id lists instead of positions.
        let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
        orig_blob_ids.sort_unstable();
        let mut decoded_blob_ids = decoded_blobs
            .iter()
            .map(|b| b.id.to_string())
            .collect::<Vec<_>>();
        decoded_blob_ids.sort_unstable();
        assert_eq!(orig_blob_ids, decoded_blob_ids);
    }
474
    // Runs the roundtrip under both supported hash kinds to catch
    // hash-width-dependent encoding bugs.
    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        run_pack_roundtrip(HashKind::Sha1).await;
        run_pack_roundtrip(HashKind::Sha256).await;
    }
481}