1use std::{collections::HashMap, str::FromStr};
6
7use async_trait::async_trait;
8use bytes::{BufMut, Bytes, BytesMut};
9use futures::stream::StreamExt;
10
11use crate::{
12 hash::ObjectHash,
13 internal::object::ObjectTrait,
14 protocol::{
15 smart::SmartProtocol,
16 types::{Capability, ProtocolError, ProtocolStream, ServiceType, SideBand},
17 },
18};
19
20#[async_trait]
25pub trait RepositoryAccess: Send + Sync + Clone {
26 async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError>;
28
29 async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError>;
31
32 async fn get_object(&self, object_hash: &str) -> Result<Vec<u8>, ProtocolError>;
34
35 async fn store_pack_data(&self, pack_data: &[u8]) -> Result<(), ProtocolError>;
37
38 async fn update_reference(
40 &self,
41 ref_name: &str,
42 old_hash: Option<&str>,
43 new_hash: &str,
44 ) -> Result<(), ProtocolError>;
45
46 async fn get_objects_for_pack(
48 &self,
49 wants: &[String],
50 haves: &[String],
51 ) -> Result<Vec<String>, ProtocolError>;
52
53 async fn has_default_branch(&self) -> Result<bool, ProtocolError>;
55
56 async fn post_receive_hook(&self) -> Result<(), ProtocolError>;
58
59 async fn get_blob(
64 &self,
65 object_hash: &str,
66 ) -> Result<crate::internal::object::blob::Blob, ProtocolError> {
67 let data = self.get_object(object_hash).await?;
68 let hash = ObjectHash::from_str(object_hash)
69 .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
70
71 crate::internal::object::blob::Blob::from_bytes(&data, hash)
72 .map_err(|e| ProtocolError::repository_error(format!("Failed to parse blob: {e}")))
73 }
74
75 async fn get_commit(
80 &self,
81 commit_hash: &str,
82 ) -> Result<crate::internal::object::commit::Commit, ProtocolError> {
83 let data = self.get_object(commit_hash).await?;
84 let hash = ObjectHash::from_str(commit_hash)
85 .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
86
87 crate::internal::object::commit::Commit::from_bytes(&data, hash)
88 .map_err(|e| ProtocolError::repository_error(format!("Failed to parse commit: {e}")))
89 }
90
91 async fn get_tree(
96 &self,
97 tree_hash: &str,
98 ) -> Result<crate::internal::object::tree::Tree, ProtocolError> {
99 let data = self.get_object(tree_hash).await?;
100 let hash = ObjectHash::from_str(tree_hash)
101 .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
102
103 crate::internal::object::tree::Tree::from_bytes(&data, hash)
104 .map_err(|e| ProtocolError::repository_error(format!("Failed to parse tree: {e}")))
105 }
106
107 async fn commit_exists(&self, commit_hash: &str) -> Result<bool, ProtocolError> {
112 match self.has_object(commit_hash).await {
113 Ok(exists) => {
114 if !exists {
115 return Ok(false);
116 }
117
118 match self.get_commit(commit_hash).await {
120 Ok(_) => Ok(true),
121 Err(_) => Ok(false), }
123 }
124 Err(e) => Err(e),
125 }
126 }
127
128 async fn handle_pack_objects(
133 &self,
134 commits: Vec<crate::internal::object::commit::Commit>,
135 trees: Vec<crate::internal::object::tree::Tree>,
136 blobs: Vec<crate::internal::object::blob::Blob>,
137 ) -> Result<(), ProtocolError> {
138 for blob in blobs {
140 let data = blob.to_data().map_err(|e| {
141 ProtocolError::repository_error(format!("Failed to serialize blob: {e}"))
142 })?;
143 self.store_pack_data(&data).await.map_err(|e| {
144 ProtocolError::repository_error(format!("Failed to store blob {}: {}", blob.id, e))
145 })?;
146 }
147
148 for tree in trees {
150 let data = tree.to_data().map_err(|e| {
151 ProtocolError::repository_error(format!("Failed to serialize tree: {e}"))
152 })?;
153 self.store_pack_data(&data).await.map_err(|e| {
154 ProtocolError::repository_error(format!("Failed to store tree {}: {}", tree.id, e))
155 })?;
156 }
157
158 for commit in commits {
160 let data = commit.to_data().map_err(|e| {
161 ProtocolError::repository_error(format!("Failed to serialize commit: {e}"))
162 })?;
163 self.store_pack_data(&data).await.map_err(|e| {
164 ProtocolError::repository_error(format!(
165 "Failed to store commit {}: {}",
166 commit.id, e
167 ))
168 })?;
169 }
170
171 Ok(())
172 }
173}
174
175#[async_trait]
177pub trait AuthenticationService: Send + Sync {
178 async fn authenticate_http(
180 &self,
181 headers: &std::collections::HashMap<String, String>,
182 ) -> Result<(), ProtocolError>;
183
184 async fn authenticate_ssh(
186 &self,
187 username: &str,
188 public_key: &[u8],
189 ) -> Result<(), ProtocolError>;
190}
191
192pub struct GitProtocol<R: RepositoryAccess, A: AuthenticationService> {
199 smart_protocol: SmartProtocol<R, A>,
200}
201
202impl<R: RepositoryAccess, A: AuthenticationService> GitProtocol<R, A> {
203 pub fn new(repo_access: R, auth_service: A) -> Self {
205 Self {
206 smart_protocol: SmartProtocol::new(
207 super::types::TransportProtocol::Http,
208 repo_access,
209 auth_service,
210 ),
211 }
212 }
213
214 pub async fn authenticate_http(
216 &self,
217 headers: &HashMap<String, String>,
218 ) -> Result<(), ProtocolError> {
219 self.smart_protocol.authenticate_http(headers).await
220 }
221
222 pub async fn authenticate_ssh(
224 &self,
225 username: &str,
226 public_key: &[u8],
227 ) -> Result<(), ProtocolError> {
228 self.smart_protocol
229 .authenticate_ssh(username, public_key)
230 .await
231 }
232
233 pub fn set_transport(&mut self, protocol: super::types::TransportProtocol) {
235 self.smart_protocol.set_transport_protocol(protocol);
236 }
237
238 pub async fn info_refs(&self, service: &str) -> Result<Vec<u8>, ProtocolError> {
240 let service_type = match service {
241 "git-upload-pack" => ServiceType::UploadPack,
242 "git-receive-pack" => ServiceType::ReceivePack,
243 _ => return Err(ProtocolError::invalid_service(service)),
244 };
245
246 let bytes = self.smart_protocol.git_info_refs(service_type).await?;
247 Ok(bytes.to_vec())
248 }
249
250 pub async fn upload_pack(
252 &mut self,
253 request_data: &[u8],
254 ) -> Result<ProtocolStream, ProtocolError> {
255 const SIDE_BAND_PACKET_LEN: usize = 1000;
256 const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
257 const SIDE_BAND_HEADER_LEN: usize = 5; let request_bytes = bytes::Bytes::from(request_data.to_vec());
260 let (pack_stream, protocol_buf) =
261 self.smart_protocol.git_upload_pack(request_bytes).await?;
262 let ack_bytes = protocol_buf.freeze();
263
264 let ack_stream: ProtocolStream = if ack_bytes.is_empty() {
265 Box::pin(futures::stream::empty::<Result<Bytes, ProtocolError>>())
266 } else {
267 Box::pin(futures::stream::once(async move { Ok(ack_bytes) }))
268 };
269
270 let sideband_max = if self
271 .smart_protocol
272 .capabilities
273 .contains(&Capability::SideBand64k)
274 {
275 Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
276 } else if self
277 .smart_protocol
278 .capabilities
279 .contains(&Capability::SideBand)
280 {
281 Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
282 } else {
283 None
284 };
285
286 let data_stream: ProtocolStream = if let Some(max_payload) = sideband_max {
287 let stream = pack_stream.flat_map(move |chunk| {
288 let packets = build_side_band_packets(&chunk, max_payload);
289 futures::stream::iter(packets.into_iter().map(Ok))
290 });
291 let stream = stream.chain(futures::stream::once(async {
292 Ok(Bytes::from_static(b"0000"))
293 }));
294 Box::pin(stream)
295 } else {
296 Box::pin(pack_stream.map(|data| Ok(Bytes::from(data))))
297 };
298
299 Ok(Box::pin(ack_stream.chain(data_stream)))
300 }
301
302 pub async fn receive_pack(
304 &mut self,
305 request_stream: ProtocolStream,
306 ) -> Result<ProtocolStream, ProtocolError> {
307 const SIDE_BAND_PACKET_LEN: usize = 1000;
308 const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
309 const SIDE_BAND_HEADER_LEN: usize = 5; let result_bytes = self
312 .smart_protocol
313 .git_receive_pack_stream(request_stream)
314 .await?;
315
316 let sideband_max = if self
317 .smart_protocol
318 .capabilities
319 .contains(&Capability::SideBand64k)
320 {
321 Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
322 } else if self
323 .smart_protocol
324 .capabilities
325 .contains(&Capability::SideBand)
326 {
327 Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
328 } else {
329 None
330 };
331
332 if let Some(max_payload) = sideband_max {
334 let packets = build_side_band_packets(result_bytes.as_ref(), max_payload);
335 let stream = futures::stream::iter(packets.into_iter().map(Ok)).chain(
336 futures::stream::once(async { Ok(Bytes::from_static(b"0000")) }),
337 );
338 Ok(Box::pin(stream))
339 } else {
340 Ok(Box::pin(futures::stream::once(async { Ok(result_bytes) })))
342 }
343 }
344}
345
346fn build_side_band_packets(chunk: &[u8], max_payload: usize) -> Vec<Bytes> {
347 if chunk.is_empty() {
348 return Vec::new();
349 }
350
351 let mut out = Vec::new();
352 let mut offset = 0;
353
354 while offset < chunk.len() {
355 let end = (offset + max_payload).min(chunk.len());
356 let payload = &chunk[offset..end];
357 let length = payload.len() + 5; let mut pkt = BytesMut::with_capacity(length);
359 pkt.put(Bytes::from(format!("{length:04x}")));
360 pkt.put_u8(SideBand::PackfileData.value());
361 pkt.put(payload);
362 out.push(pkt.freeze());
363 offset = end;
364 }
365
366 out
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::{HashKind, set_hash_kind_for_test};
    use crate::internal::object::{
        blob::Blob,
        commit::Commit,
        signature::{Signature, SignatureType},
        tree::{Tree, TreeItem, TreeItemMode},
    };
    use crate::protocol::types::TransportProtocol;
    use crate::protocol::utils;
    use async_trait::async_trait;
    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;

    /// Repository stub that advertises a fixed ref list and stores no objects.
    #[derive(Clone)]
    struct MockRepo {
        refs: Vec<(String, String)>,
    }

    #[async_trait]
    impl RepositoryAccess for MockRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(self.refs.clone())
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(Vec::new())
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(wants.to_vec())
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Authentication stub that accepts every request.
    struct MockAuth;

    #[async_trait]
    impl AuthenticationService for MockAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Builds a protocol whose repository advertises `refs/heads/main` and `HEAD`.
    ///
    /// Both refs point at the default object hash.
    fn make_protocol() -> GitProtocol<MockRepo, MockAuth> {
        GitProtocol::new(
            MockRepo {
                refs: vec![
                    (
                        "refs/heads/main".to_string(),
                        ObjectHash::default().to_string(),
                    ),
                    ("HEAD".to_string(), ObjectHash::default().to_string()),
                ],
            },
            MockAuth,
        )
    }

    /// Repository holding one commit, its tree, and the tree's blobs.
    ///
    /// It is used to exercise real pack generation.
    #[derive(Clone)]
    struct SideBandRepo {
        commit: Commit,
        tree: Tree,
        blobs: Vec<Blob>,
    }

    #[async_trait]
    impl RepositoryAccess for SideBandRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![(
                "refs/heads/main".to_string(),
                self.commit.id.to_string(),
            )])
        }

        async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError> {
            let known = object_hash == self.commit.id.to_string()
                || object_hash == self.tree.id.to_string()
                || self.blobs.iter().any(|b| b.id.to_string() == object_hash);
            Ok(known)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(Vec::new())
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(Vec::new())
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_commit(&self, commit_hash: &str) -> Result<Commit, ProtocolError> {
            if commit_hash == self.commit.id.to_string() {
                Ok(self.commit.clone())
            } else {
                Err(ProtocolError::ObjectNotFound(commit_hash.to_string()))
            }
        }

        async fn get_tree(&self, tree_hash: &str) -> Result<Tree, ProtocolError> {
            if tree_hash == self.tree.id.to_string() {
                Ok(self.tree.clone())
            } else {
                Err(ProtocolError::ObjectNotFound(tree_hash.to_string()))
            }
        }

        async fn get_blob(&self, blob_hash: &str) -> Result<Blob, ProtocolError> {
            self.blobs
                .iter()
                .find(|b| b.id.to_string() == blob_hash)
                .cloned()
                .ok_or_else(|| ProtocolError::ObjectNotFound(blob_hash.to_string()))
        }
    }

    /// Creates a repository with one `hello.txt` blob, a tree containing it, and a root commit.
    fn build_repo_with_objects() -> (SideBandRepo, Commit) {
        let blob = Blob::from_content("hello");
        let item = TreeItem::new(TreeItemMode::Blob, blob.id, "hello.txt".to_string());
        let tree = Tree::from_tree_items(vec![item]).unwrap();
        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        let repo = SideBandRepo {
            commit: commit.clone(),
            tree,
            blobs: vec![blob],
        };

        (repo, commit)
    }

    #[tokio::test]
    async fn upload_pack_emits_ack_before_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();
        let mut proto = GitProtocol::new(repo, MockAuth);
        let mut request = BytesMut::new();
        utils::add_pkt_line_string(&mut request, format!("want {}\n", commit.id));
        utils::add_pkt_line_string(&mut request, "done\n".to_string());

        let mut stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut out = BytesMut::new();
        while let Some(chunk) = stream.next().await {
            out.extend_from_slice(&chunk.expect("stream chunk"));
        }

        let mut out_bytes = out.freeze();
        let (_len, line) = utils::read_pkt_line(&mut out_bytes);
        assert_eq!(line, Bytes::from_static(b"NAK\n"));
        assert!(
            out_bytes.as_ref().starts_with(b"PACK"),
            "pack should follow ack"
        );
    }

    #[tokio::test]
    async fn upload_pack_sideband_frames_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();

        let mut proto = GitProtocol::new(repo, MockAuth);
        let mut request = BytesMut::new();
        utils::add_pkt_line_string(&mut request, format!("want {} side-band-64k\n", commit.id));
        utils::add_pkt_line_string(&mut request, "done\n".to_string());

        let mut stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut out = BytesMut::new();
        while let Some(chunk) = stream.next().await {
            out.extend_from_slice(&chunk.expect("stream chunk"));
        }

        let mut out_bytes = out.freeze();
        let (_len, line) = utils::read_pkt_line(&mut out_bytes);
        assert_eq!(line, Bytes::from_static(b"NAK\n"));

        let raw = out_bytes.as_ref();
        assert!(raw.len() > 9, "side-band packet should include PACK header");
        let len_hex = std::str::from_utf8(&raw[..4]).expect("hex length");
        let pkt_len = usize::from_str_radix(len_hex, 16).expect("parse length");
        assert!(pkt_len > 5, "side-band packet should contain data");
        assert_eq!(raw[4], SideBand::PackfileData.value());
        assert_eq!(&raw[5..9], b"PACK");
        assert!(raw.ends_with(b"0000"), "side-band stream should flush");
    }

    #[test]
    fn side_band_packets_split_payload_and_prefix_header() {
        let data: Vec<u8> = (0u8..10).collect();
        let packets = build_side_band_packets(&data, 4);
        let band = SideBand::PackfileData.value();

        assert_eq!(packets.len(), 3);

        assert_eq!(&packets[0][..4], b"0009");
        assert_eq!(packets[0][4], band);
        assert_eq!(&packets[0][5..], &data[0..4]);

        assert_eq!(&packets[1][..4], b"0009");
        assert_eq!(packets[1][4], band);
        assert_eq!(&packets[1][5..], &data[4..8]);

        // The final packet carries only the remaining two bytes.
        assert_eq!(&packets[2][..4], b"0007");
        assert_eq!(packets[2][4], band);
        assert_eq!(&packets[2][5..], &data[8..10]);
    }

    #[test]
    fn side_band_packets_empty_input_yields_nothing() {
        assert!(build_side_band_packets(&[], 100).is_empty());
    }

    #[tokio::test]
    async fn info_refs_includes_refs_and_caps() {
        let proto = make_protocol();
        let bytes = proto.info_refs("git-upload-pack").await.expect("info_refs");
        let text = String::from_utf8(bytes).expect("utf8");
        assert!(text.contains("refs/heads/main"));
        assert!(text.contains("capabilities"));
        assert!(text.contains("object-format"));
    }

    #[tokio::test]
    async fn info_refs_invalid_service_errors() {
        let proto = make_protocol();
        let err = proto.info_refs("git-invalid").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidService(_)));
    }

    // `set_transport` is synchronous, so no async runtime is needed here.
    #[test]
    fn can_switch_transport() {
        let mut proto = make_protocol();
        proto.set_transport(TransportProtocol::Ssh);
    }

    #[tokio::test]
    async fn info_refs_hash_length_mismatch_errors() {
        let proto = GitProtocol::new(
            MockRepo {
                refs: vec![(
                    "refs/heads/main".to_string(),
                    "f".repeat(HashKind::Sha256.hex_len()),
                )],
            },
            MockAuth,
        );
        let err = proto.info_refs("git-upload-pack").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidRequest(_)));
    }
}