1use bytes::Bytes;
2use std::collections::{HashSet, VecDeque};
3use std::io::Cursor;
4use tokio;
5use tokio::sync::mpsc;
6use tokio_stream::wrappers::ReceiverStream;
7
8use super::core::RepositoryAccess;
9use super::types::ProtocolError;
10use crate::internal::object::types::ObjectType;
11use crate::internal::object::{ObjectTrait, blob::Blob, commit::Commit, tree::Tree};
12use crate::internal::pack::{Pack, encode::PackEncoder, entry::Entry};
13
/// Generates Git pack files from repository objects and decodes incoming packs.
///
/// Holds a borrowed [`RepositoryAccess`] backend for the lifetime `'a`; all
/// object lookups (commits, trees, blobs) during graph traversal go through it.
pub struct PackGenerator<'a, R>
where
    R: RepositoryAccess,
{
    // Backend used to fetch objects while walking the reachability graph.
    repo_access: &'a R,
}
24
25impl<'a, R> PackGenerator<'a, R>
26where
27 R: RepositoryAccess,
28{
29 pub fn new(repo_access: &'a R) -> Self {
30 Self { repo_access }
31 }
32
33 pub async fn generate_full_pack(
35 &self,
36 want: Vec<String>,
37 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
38 let (tx, rx) = mpsc::channel(1024);
39
40 let all_objects = self.collect_all_objects(want).await?;
42
43 tokio::spawn(async move {
45 if let Err(e) = Self::generate_pack_stream(all_objects, tx).await {
46 tracing::error!("Failed to generate pack stream: {}", e);
47 }
48 });
49
50 Ok(ReceiverStream::new(rx))
51 }
52
53 pub async fn generate_incremental_pack(
55 &self,
56 want: Vec<String>,
57 have: Vec<String>,
58 ) -> Result<ReceiverStream<Vec<u8>>, ProtocolError> {
59 let (tx, rx) = mpsc::channel(1024);
60
61 let wanted_objects = self.collect_all_objects(want).await?;
63
64 let have_objects = self.collect_all_objects(have).await?;
66
67 let incremental_objects = Self::filter_objects(wanted_objects, have_objects);
69
70 tokio::spawn(async move {
72 if let Err(e) = Self::generate_pack_stream(incremental_objects, tx).await {
73 tracing::error!("Failed to generate incremental pack stream: {}", e);
74 }
75 });
76
77 Ok(ReceiverStream::new(rx))
78 }
79
80 pub async fn unpack_stream(
82 &self,
83 pack_data: Bytes,
84 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
85 use std::sync::{Arc, Mutex};
86
87 let commits = Arc::new(Mutex::new(Vec::new()));
88 let trees = Arc::new(Mutex::new(Vec::new()));
89 let blobs = Arc::new(Mutex::new(Vec::new()));
90
91 let commits_clone = commits.clone();
92 let trees_clone = trees.clone();
93 let blobs_clone = blobs.clone();
94
95 let mut pack = Pack::new(None, None, None, true);
97 let mut cursor = Cursor::new(pack_data.to_vec());
98
99 pack.decode(
101 &mut cursor,
102 move |entry: Entry, _offset: usize| match entry.obj_type {
103 ObjectType::Commit => {
104 if let Ok(commit) = Commit::from_bytes(&entry.data, entry.hash) {
105 commits_clone.lock().unwrap().push(commit);
106 } else {
107 tracing::warn!("Failed to parse commit from pack entry");
108 }
109 }
110 ObjectType::Tree => {
111 if let Ok(tree) = Tree::from_bytes(&entry.data, entry.hash) {
112 trees_clone.lock().unwrap().push(tree);
113 } else {
114 tracing::warn!("Failed to parse tree from pack entry");
115 }
116 }
117 ObjectType::Blob => {
118 if let Ok(blob) = Blob::from_bytes(&entry.data, entry.hash) {
119 blobs_clone.lock().unwrap().push(blob);
120 } else {
121 tracing::warn!("Failed to parse blob from pack entry");
122 }
123 }
124 _ => {
125 tracing::warn!("Unknown object type in pack: {:?}", entry.obj_type);
126 }
127 },
128 )
129 .map_err(|e| ProtocolError::invalid_request(&format!("Failed to decode pack: {}", e)))?;
130
131 let commits_result = Arc::try_unwrap(commits).unwrap().into_inner().unwrap();
133 let trees_result = Arc::try_unwrap(trees).unwrap().into_inner().unwrap();
134 let blobs_result = Arc::try_unwrap(blobs).unwrap().into_inner().unwrap();
135
136 Ok((commits_result, trees_result, blobs_result))
137 }
138
139 async fn collect_all_objects(
141 &self,
142 commit_hashes: Vec<String>,
143 ) -> Result<(Vec<Commit>, Vec<Tree>, Vec<Blob>), ProtocolError> {
144 let mut commits = Vec::new();
145 let mut trees = Vec::new();
146 let mut blobs = Vec::new();
147
148 let mut visited_commits = HashSet::new();
149 let mut visited_trees = HashSet::new();
150 let mut visited_blobs = HashSet::new();
151
152 let mut commit_queue = VecDeque::from(commit_hashes);
153
154 while let Some(commit_hash) = commit_queue.pop_front() {
156 if visited_commits.contains(&commit_hash) {
157 continue;
158 }
159 visited_commits.insert(commit_hash.clone());
160
161 let commit = self
163 .repo_access
164 .get_commit(&commit_hash)
165 .await
166 .map_err(|e| {
167 ProtocolError::repository_error(format!(
168 "Failed to get commit {}: {}",
169 commit_hash, e
170 ))
171 })?;
172
173 for parent in &commit.parent_commit_ids {
175 let parent_str = parent.to_string();
176 if !visited_commits.contains(&parent_str) {
177 commit_queue.push_back(parent_str);
178 }
179 }
180
181 Box::pin(self.collect_tree_objects(
183 &commit.tree_id.to_string(),
184 &mut trees,
185 &mut blobs,
186 &mut visited_trees,
187 &mut visited_blobs,
188 ))
189 .await?;
190
191 commits.push(commit);
192 }
193
194 Ok((commits, trees, blobs))
195 }
196
197 async fn collect_tree_objects(
199 &self,
200 tree_hash: &str,
201 trees: &mut Vec<Tree>,
202 blobs: &mut Vec<Blob>,
203 visited_trees: &mut HashSet<String>,
204 visited_blobs: &mut HashSet<String>,
205 ) -> Result<(), ProtocolError> {
206 if visited_trees.contains(tree_hash) {
207 return Ok(());
208 }
209 visited_trees.insert(tree_hash.to_string());
210
211 let tree = self.repo_access.get_tree(tree_hash).await.map_err(|e| {
212 ProtocolError::repository_error(format!("Failed to get tree {}: {}", tree_hash, e))
213 })?;
214
215 for entry in &tree.tree_items {
216 let entry_hash = entry.id.to_string();
217 match entry.mode {
218 crate::internal::object::tree::TreeItemMode::Tree => {
219 Box::pin(self.collect_tree_objects(
220 &entry_hash,
221 trees,
222 blobs,
223 visited_trees,
224 visited_blobs,
225 ))
226 .await?;
227 }
228 crate::internal::object::tree::TreeItemMode::Blob
229 | crate::internal::object::tree::TreeItemMode::BlobExecutable => {
230 if !visited_blobs.contains(&entry_hash) {
231 visited_blobs.insert(entry_hash.clone());
232 let blob = self.repo_access.get_blob(&entry_hash).await.map_err(|e| {
233 ProtocolError::repository_error(format!(
234 "Failed to get blob {}: {}",
235 entry_hash, e
236 ))
237 })?;
238 blobs.push(blob);
239 }
240 }
241 _ => {}
242 }
243 }
244
245 trees.push(tree);
246 Ok(())
247 }
248
249 fn filter_objects(
251 wanted: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
252 have: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
253 ) -> (Vec<Commit>, Vec<Tree>, Vec<Blob>) {
254 let (wanted_commits, wanted_trees, wanted_blobs) = wanted;
255 let (have_commits, have_trees, have_blobs) = have;
256
257 let have_commit_hashes: HashSet<String> =
259 have_commits.iter().map(|c| c.id.to_string()).collect();
260 let have_tree_hashes: HashSet<String> =
261 have_trees.iter().map(|t| t.id.to_string()).collect();
262 let have_blob_hashes: HashSet<String> =
263 have_blobs.iter().map(|b| b.id.to_string()).collect();
264
265 let filtered_commits: Vec<Commit> = wanted_commits
267 .into_iter()
268 .filter(|c| !have_commit_hashes.contains(&c.id.to_string()))
269 .collect();
270
271 let filtered_trees: Vec<Tree> = wanted_trees
272 .into_iter()
273 .filter(|t| !have_tree_hashes.contains(&t.id.to_string()))
274 .collect();
275
276 let filtered_blobs: Vec<Blob> = wanted_blobs
277 .into_iter()
278 .filter(|b| !have_blob_hashes.contains(&b.id.to_string()))
279 .collect();
280
281 (filtered_commits, filtered_trees, filtered_blobs)
282 }
283
284 async fn generate_pack_stream(
286 objects: (Vec<Commit>, Vec<Tree>, Vec<Blob>),
287 tx: mpsc::Sender<Vec<u8>>,
288 ) -> Result<(), ProtocolError> {
289 let (commits, trees, blobs) = objects;
290
291 let mut entries = Vec::new();
293
294 for commit in commits {
295 entries.push(Entry::from(commit));
296 }
297
298 for tree in trees {
299 entries.push(Entry::from(tree));
300 }
301
302 for blob in blobs {
303 entries.push(Entry::from(blob));
304 }
305
306 let (pack_tx, mut pack_rx) = mpsc::channel(1024);
308 let (entry_tx, entry_rx) = mpsc::channel(1024);
309 let mut encoder = PackEncoder::new(entries.len(), 10, pack_tx); tokio::spawn(async move {
313 if let Err(e) = encoder.encode(entry_rx).await {
314 tracing::error!("Failed to encode pack: {}", e);
315 }
316 });
317
318 tokio::spawn(async move {
320 for entry in entries {
321 if entry_tx.send(entry).await.is_err() {
322 break; }
324 }
325 });
327
328 while let Some(chunk) = pack_rx.recv().await {
330 if tx.send(chunk).await.is_err() {
331 break; }
333 }
334
335 Ok(())
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use async_trait::async_trait;
    use bytes::Bytes;

    /// Minimal `RepositoryAccess` stub: every method returns an empty or
    /// "not implemented" result. It only satisfies `PackGenerator`'s type
    /// parameter — the round-trip test below never touches the repository.
    #[derive(Clone)]
    struct DummyRepoAccess;

    #[async_trait]
    impl RepositoryAccess for DummyRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Err(ProtocolError::repository_error(
                "not implemented".to_string(),
            ))
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Round-trip: encode one commit + one tree + two blobs into a pack via
    /// `generate_pack_stream`, then decode with `unpack_stream` and verify the
    /// same objects (by id) come back.
    #[tokio::test]
    async fn test_pack_roundtrip_encode_decode() {
        // Two distinct blobs so the decoded-blob-count assertion is meaningful.
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        // A single tree referencing both blobs.
        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        // Root commit (no parents) pointing at the tree above.
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Encode: generate_pack_stream is an associated fn, so the dummy repo
        // type is only needed to name the generic.
        let (tx, mut rx) = mpsc::channel::<Vec<u8>>(64);
        PackGenerator::<DummyRepoAccess>::generate_pack_stream(
            (
                vec![commit.clone()],
                vec![tree.clone()],
                vec![blob1.clone(), blob2.clone()],
            ),
            tx,
        )
        .await
        .unwrap();

        // Drain every encoded chunk into a single pack byte buffer.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        println!("Encoded pack size: {} bytes", pack_bytes.len());

        // Decode the pack back into objects.
        let dummy = DummyRepoAccess;
        let generator = PackGenerator::new(&dummy);
        let (decoded_commits, decoded_trees, decoded_blobs) = generator
            .unpack_stream(Bytes::from(pack_bytes))
            .await
            .unwrap();

        println!(
            "Decoded commits: {:?}",
            decoded_commits
                .iter()
                .map(|c| c.id.to_string())
                .collect::<Vec<_>>()
        );
        println!(
            "Decoded trees: {:?}",
            decoded_trees
                .iter()
                .map(|t| t.id.to_string())
                .collect::<Vec<_>>()
        );
        println!(
            "Decoded blobs: {:?}",
            decoded_blobs
                .iter()
                .map(|b| b.id.to_string())
                .collect::<Vec<_>>()
        );

        // Exactly the objects we encoded must come back.
        assert_eq!(decoded_commits.len(), 1);
        assert_eq!(decoded_trees.len(), 1);
        assert_eq!(decoded_blobs.len(), 2);

        assert_eq!(decoded_commits[0].id, commit.id);
        assert_eq!(decoded_trees[0].id, tree.id);

        // Blob order within the pack is not guaranteed — compare sorted ids.
        let mut orig_blob_ids = vec![blob1.id.to_string(), blob2.id.to_string()];
        orig_blob_ids.sort();
        let mut decoded_blob_ids = decoded_blobs
            .iter()
            .map(|b| b.id.to_string())
            .collect::<Vec<_>>();
        decoded_blob_ids.sort();
        assert_eq!(orig_blob_ids, decoded_blob_ids);
    }
}