1use std::{collections::HashMap, str::FromStr};
6
7use async_trait::async_trait;
8use bytes::{BufMut, Bytes, BytesMut};
9use futures::stream::StreamExt;
10
11use crate::{
12 hash::ObjectHash,
13 internal::object::ObjectTrait,
14 protocol::{
15 smart::SmartProtocol,
16 types::{Capability, ProtocolError, ProtocolStream, ServiceType, SideBand},
17 },
18};
19
20#[async_trait]
25pub trait RepositoryAccess: Send + Sync + Clone {
26 async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError>;
28
29 async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError>;
31
32 async fn get_object(&self, object_hash: &str) -> Result<Vec<u8>, ProtocolError>;
34
35 async fn store_pack_data(&self, pack_data: &[u8]) -> Result<(), ProtocolError>;
37
38 async fn update_reference(
40 &self,
41 ref_name: &str,
42 old_hash: Option<&str>,
43 new_hash: &str,
44 ) -> Result<(), ProtocolError>;
45
46 async fn get_objects_for_pack(
48 &self,
49 wants: &[String],
50 haves: &[String],
51 ) -> Result<Vec<String>, ProtocolError>;
52
53 async fn has_default_branch(&self) -> Result<bool, ProtocolError>;
55
56 async fn post_receive_hook(&self) -> Result<(), ProtocolError>;
58
59 async fn get_blob(
64 &self,
65 object_hash: &str,
66 ) -> Result<crate::internal::object::blob::Blob, ProtocolError> {
67 let data = self.get_object(object_hash).await?;
68 let hash = ObjectHash::from_str(object_hash)
69 .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
70
71 crate::internal::object::blob::Blob::from_bytes(&data, hash)
72 .map_err(|e| ProtocolError::repository_error(format!("Failed to parse blob: {e}")))
73 }
74
75 async fn get_commit(
80 &self,
81 commit_hash: &str,
82 ) -> Result<crate::internal::object::commit::Commit, ProtocolError> {
83 let data = self.get_object(commit_hash).await?;
84 let hash = ObjectHash::from_str(commit_hash)
85 .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
86
87 crate::internal::object::commit::Commit::from_bytes(&data, hash)
88 .map_err(|e| ProtocolError::repository_error(format!("Failed to parse commit: {e}")))
89 }
90
91 async fn get_tree(
96 &self,
97 tree_hash: &str,
98 ) -> Result<crate::internal::object::tree::Tree, ProtocolError> {
99 let data = self.get_object(tree_hash).await?;
100 let hash = ObjectHash::from_str(tree_hash)
101 .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
102
103 crate::internal::object::tree::Tree::from_bytes(&data, hash)
104 .map_err(|e| ProtocolError::repository_error(format!("Failed to parse tree: {e}")))
105 }
106
107 async fn commit_exists(&self, commit_hash: &str) -> Result<bool, ProtocolError> {
112 match self.has_object(commit_hash).await {
113 Ok(exists) => {
114 if !exists {
115 return Ok(false);
116 }
117
118 match self.get_commit(commit_hash).await {
120 Ok(_) => Ok(true),
121 Err(_) => Ok(false), }
123 }
124 Err(e) => Err(e),
125 }
126 }
127
128 async fn handle_pack_objects(
133 &self,
134 commits: Vec<crate::internal::object::commit::Commit>,
135 trees: Vec<crate::internal::object::tree::Tree>,
136 blobs: Vec<crate::internal::object::blob::Blob>,
137 ) -> Result<(), ProtocolError> {
138 for blob in blobs {
140 let data = blob.to_data().map_err(|e| {
141 ProtocolError::repository_error(format!("Failed to serialize blob: {e}"))
142 })?;
143 self.store_pack_data(&data).await.map_err(|e| {
144 ProtocolError::repository_error(format!("Failed to store blob {}: {}", blob.id, e))
145 })?;
146 }
147
148 for tree in trees {
150 let data = tree.to_data().map_err(|e| {
151 ProtocolError::repository_error(format!("Failed to serialize tree: {e}"))
152 })?;
153 self.store_pack_data(&data).await.map_err(|e| {
154 ProtocolError::repository_error(format!("Failed to store tree {}: {}", tree.id, e))
155 })?;
156 }
157
158 for commit in commits {
160 let data = commit.to_data().map_err(|e| {
161 ProtocolError::repository_error(format!("Failed to serialize commit: {e}"))
162 })?;
163 self.store_pack_data(&data).await.map_err(|e| {
164 ProtocolError::repository_error(format!(
165 "Failed to store commit {}: {}",
166 commit.id, e
167 ))
168 })?;
169 }
170
171 Ok(())
172 }
173}
174
175#[async_trait]
177pub trait AuthenticationService: Send + Sync {
178 async fn authenticate_http(
180 &self,
181 headers: &std::collections::HashMap<String, String>,
182 ) -> Result<(), ProtocolError>;
183
184 async fn authenticate_ssh(
186 &self,
187 username: &str,
188 public_key: &[u8],
189 ) -> Result<(), ProtocolError>;
190}
191
192pub struct GitProtocol<R: RepositoryAccess, A: AuthenticationService> {
199 smart_protocol: SmartProtocol<R, A>,
200}
201
202impl<R: RepositoryAccess, A: AuthenticationService> GitProtocol<R, A> {
203 pub fn new(repo_access: R, auth_service: A) -> Self {
205 Self {
206 smart_protocol: SmartProtocol::new(
207 super::types::TransportProtocol::Http,
208 repo_access,
209 auth_service,
210 ),
211 }
212 }
213
214 pub async fn authenticate_http(
216 &self,
217 headers: &HashMap<String, String>,
218 ) -> Result<(), ProtocolError> {
219 self.smart_protocol.authenticate_http(headers).await
220 }
221
222 pub async fn authenticate_ssh(
224 &self,
225 username: &str,
226 public_key: &[u8],
227 ) -> Result<(), ProtocolError> {
228 self.smart_protocol
229 .authenticate_ssh(username, public_key)
230 .await
231 }
232
233 pub fn set_transport(&mut self, protocol: super::types::TransportProtocol) {
235 self.smart_protocol.set_transport_protocol(protocol);
236 }
237
238 pub async fn info_refs(&self, service: &str) -> Result<Vec<u8>, ProtocolError> {
240 let service_type = match service {
241 "git-upload-pack" => ServiceType::UploadPack,
242 "git-receive-pack" => ServiceType::ReceivePack,
243 _ => return Err(ProtocolError::invalid_service(service)),
244 };
245
246 let bytes = self.smart_protocol.git_info_refs(service_type).await?;
247 Ok(bytes.to_vec())
248 }
249
250 pub async fn upload_pack(
252 &mut self,
253 request_data: &[u8],
254 ) -> Result<ProtocolStream, ProtocolError> {
255 const SIDE_BAND_PACKET_LEN: usize = 1000;
256 const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
257 const SIDE_BAND_HEADER_LEN: usize = 5; let request_bytes = bytes::Bytes::from(request_data.to_vec());
260 let (pack_stream, protocol_buf) =
261 self.smart_protocol.git_upload_pack(request_bytes).await?;
262 let ack_bytes = protocol_buf.freeze();
263
264 let ack_stream: ProtocolStream = if ack_bytes.is_empty() {
265 Box::pin(futures::stream::empty::<Result<Bytes, ProtocolError>>())
266 } else {
267 Box::pin(futures::stream::once(async move { Ok(ack_bytes) }))
268 };
269
270 let sideband_max = if self
271 .smart_protocol
272 .capabilities
273 .contains(&Capability::SideBand64k)
274 {
275 Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
276 } else if self
277 .smart_protocol
278 .capabilities
279 .contains(&Capability::SideBand)
280 {
281 Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
282 } else {
283 None
284 };
285
286 let data_stream: ProtocolStream = if let Some(max_payload) = sideband_max {
287 let stream = pack_stream.flat_map(move |chunk| {
288 let packets = build_side_band_packets(&chunk, max_payload);
289 futures::stream::iter(packets.into_iter().map(Ok))
290 });
291 let stream = stream.chain(futures::stream::once(async {
292 Ok(Bytes::from_static(b"0000"))
293 }));
294 Box::pin(stream)
295 } else {
296 Box::pin(pack_stream.map(|data| Ok(Bytes::from(data))))
297 };
298
299 Ok(Box::pin(ack_stream.chain(data_stream)))
300 }
301
302 pub async fn receive_pack(
304 &mut self,
305 request_stream: ProtocolStream,
306 ) -> Result<ProtocolStream, ProtocolError> {
307 const SIDE_BAND_PACKET_LEN: usize = 1000;
308 const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
309 const SIDE_BAND_HEADER_LEN: usize = 5; let result_bytes = self
312 .smart_protocol
313 .git_receive_pack_stream(request_stream)
314 .await?;
315
316 let sideband_max = if self
317 .smart_protocol
318 .capabilities
319 .contains(&Capability::SideBand64k)
320 {
321 Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
322 } else if self
323 .smart_protocol
324 .capabilities
325 .contains(&Capability::SideBand)
326 {
327 Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
328 } else {
329 None
330 };
331
332 if let Some(max_payload) = sideband_max {
334 let packets = build_side_band_packets(result_bytes.as_ref(), max_payload);
335 let stream = futures::stream::iter(packets.into_iter().map(Ok)).chain(
336 futures::stream::once(async { Ok(Bytes::from_static(b"0000")) }),
337 );
338 Ok(Box::pin(stream))
339 } else {
340 Ok(Box::pin(futures::stream::once(async { Ok(result_bytes) })))
342 }
343 }
344}
345
346fn build_side_band_packets(chunk: &[u8], max_payload: usize) -> Vec<Bytes> {
347 if chunk.is_empty() {
348 return Vec::new();
349 }
350
351 let mut out = Vec::new();
352 let mut offset = 0;
353
354 while offset < chunk.len() {
355 let end = (offset + max_payload).min(chunk.len());
356 let payload = &chunk[offset..end];
357 let length = payload.len() + 5; let mut pkt = BytesMut::with_capacity(length);
359 pkt.put(Bytes::from(format!("{length:04x}")));
360 pkt.put_u8(SideBand::PackfileData.value());
361 pkt.put(payload);
362 out.push(pkt.freeze());
363 offset = end;
364 }
365
366 out
367}
368
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;

    use super::*;
    use crate::{
        hash::{HashKind, set_hash_kind_for_test},
        internal::object::{
            blob::Blob,
            commit::Commit,
            signature::{Signature, SignatureType},
            tree::{Tree, TreeItem, TreeItemMode},
        },
        protocol::{types::TransportProtocol, utils},
    };

    /// Repository stub that advertises a fixed ref list and holds no objects.
    #[derive(Clone)]
    struct MockRepo {
        refs: Vec<(String, String)>,
    }

    #[async_trait]
    impl RepositoryAccess for MockRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(self.refs.clone())
        }

        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        /// Echoes the requested ids back unchanged.
        async fn get_objects_for_pack(
            &self,
            wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(Vec::from(wants))
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Authentication stub that accepts every caller.
    struct MockAuth;

    #[async_trait]
    impl AuthenticationService for MockAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Protocol over a `MockRepo` whose `refs/heads/main` and `HEAD` both
    /// point at the default object hash.
    fn make_protocol() -> GitProtocol<MockRepo, MockAuth> {
        let default_hash = ObjectHash::default().to_string();
        let refs = vec![
            ("refs/heads/main".to_string(), default_hash.clone()),
            ("HEAD".to_string(), default_hash),
        ];

        GitProtocol::new(MockRepo { refs }, MockAuth)
    }

    /// Repository stub holding one commit, one tree and a list of blobs,
    /// served through the typed getters rather than raw object bytes.
    #[derive(Clone)]
    struct SideBandRepo {
        commit: Commit,
        tree: Tree,
        blobs: Vec<Blob>,
    }

    #[async_trait]
    impl RepositoryAccess for SideBandRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            let main = ("refs/heads/main".to_string(), self.commit.id.to_string());
            Ok(vec![main])
        }

        async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError> {
            let is_requested = |hex: String| hex == object_hash;

            Ok(is_requested(self.commit.id.to_string())
                || is_requested(self.tree.id.to_string())
                || self.blobs.iter().any(|blob| is_requested(blob.id.to_string())))
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_commit(&self, commit_hash: &str) -> Result<Commit, ProtocolError> {
            if commit_hash != self.commit.id.to_string() {
                return Err(ProtocolError::ObjectNotFound(commit_hash.to_string()));
            }
            Ok(self.commit.clone())
        }

        async fn get_tree(&self, tree_hash: &str) -> Result<Tree, ProtocolError> {
            if tree_hash != self.tree.id.to_string() {
                return Err(ProtocolError::ObjectNotFound(tree_hash.to_string()));
            }
            Ok(self.tree.clone())
        }

        async fn get_blob(&self, blob_hash: &str) -> Result<Blob, ProtocolError> {
            match self.blobs.iter().find(|b| b.id.to_string() == blob_hash) {
                Some(blob) => Ok(blob.clone()),
                None => Err(ProtocolError::ObjectNotFound(blob_hash.to_string())),
            }
        }
    }

    /// Builds a one-commit repository (`hello.txt` containing "hello") and
    /// returns it together with that commit.
    fn build_repo_with_objects() -> (SideBandRepo, Commit) {
        let blob = Blob::from_content("hello");
        let item = TreeItem::new(TreeItemMode::Blob, blob.id, "hello.txt".to_string());
        let tree = Tree::from_tree_items(vec![item]).unwrap();

        let sign = |kind: SignatureType| {
            Signature::new(
                kind,
                "tester".to_string(),
                "tester@example.com".to_string(),
            )
        };
        let author = sign(SignatureType::Author);
        let committer = sign(SignatureType::Committer);
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        let repo = SideBandRepo {
            commit: commit.clone(),
            tree,
            blobs: vec![blob],
        };

        (repo, commit)
    }

    /// Encodes an upload-pack request made of `want_line` followed by `done`.
    fn want_request(want_line: String) -> BytesMut {
        let mut request = BytesMut::new();
        utils::add_pkt_line_string(&mut request, want_line);
        utils::add_pkt_line_string(&mut request, "done\n".to_string());
        request
    }

    /// Concatenates every chunk of `stream`, panicking on the first error.
    async fn drain(mut stream: ProtocolStream) -> Bytes {
        let mut out = BytesMut::new();
        while let Some(chunk) = stream.next().await {
            out.extend_from_slice(&chunk.expect("stream chunk"));
        }
        out.freeze()
    }

    /// Without side-band, the response is a NAK pkt-line followed by the raw pack.
    #[tokio::test]
    async fn upload_pack_emits_ack_before_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();
        let mut proto = GitProtocol::new(repo, MockAuth);
        let request = want_request(format!("want {}\n", commit.id));

        let stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut out_bytes = drain(stream).await;

        let (_len, line) = utils::read_pkt_line(&mut out_bytes);
        assert_eq!(line, Bytes::from_static(b"NAK\n"));
        assert!(
            out_bytes.as_ref().starts_with(b"PACK"),
            "pack should follow ack"
        );
    }

    /// With `side-band-64k`, the pack is wrapped in band-1 packets and the
    /// stream ends with a flush packet.
    #[tokio::test]
    async fn upload_pack_sideband_frames_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();
        let mut proto = GitProtocol::new(repo, MockAuth);
        let request = want_request(format!("want {} side-band-64k\n", commit.id));

        let stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut out_bytes = drain(stream).await;

        let (_len, line) = utils::read_pkt_line(&mut out_bytes);
        assert_eq!(line, Bytes::from_static(b"NAK\n"));

        let raw = out_bytes.as_ref();
        assert!(raw.len() > 9, "side-band packet should include PACK header");

        let len_hex = std::str::from_utf8(&raw[..4]).expect("hex length");
        let pkt_len = usize::from_str_radix(len_hex, 16).expect("parse length");
        assert!(pkt_len > 5, "side-band packet should contain data");

        assert_eq!(raw[4], SideBand::PackfileData.value());
        assert_eq!(&raw[5..9], b"PACK");
        assert!(raw.ends_with(b"0000"), "side-band stream should flush");
    }

    /// The advertisement mentions the ref, a capability list and the object format.
    #[tokio::test]
    async fn info_refs_includes_refs_and_caps() {
        let proto = make_protocol();
        let advertisement = proto.info_refs("git-upload-pack").await.expect("info_refs");
        let text = String::from_utf8(advertisement).expect("utf8");

        assert!(text.contains("refs/heads/main"));
        assert!(text.contains("capabilities"));
        assert!(text.contains("object-format"));
    }

    /// Unknown service names are rejected with `InvalidService`.
    #[tokio::test]
    async fn info_refs_invalid_service_errors() {
        let proto = make_protocol();
        let err = proto.info_refs("git-invalid").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidService(_)));
    }

    /// Switching the transport must not panic.
    #[tokio::test]
    async fn can_switch_transport() {
        let mut proto = make_protocol();
        proto.set_transport(TransportProtocol::Ssh);
    }

    /// A ref whose hash has SHA-256 length is rejected with `InvalidRequest`
    /// when no SHA-256 hash kind has been selected for the test.
    #[tokio::test]
    async fn info_refs_hash_length_mismatch_errors() {
        let sha256_len_hash = "f".repeat(HashKind::Sha256.hex_len());
        let repo = MockRepo {
            refs: vec![("refs/heads/main".to_string(), sha256_len_hash)],
        };
        let proto = GitProtocol::new(repo, MockAuth);

        let err = proto.info_refs("git-upload-pack").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidRequest(_)));
    }
}