1use std::collections::HashMap;
5
6use bytes::{BufMut, Bytes, BytesMut};
7use tokio_stream::wrappers::ReceiverStream;
8
9use super::{
10 core::{AuthenticationService, RepositoryAccess},
11 pack::PackGenerator,
12 types::{
13 COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolError, ProtocolStream,
14 RECEIVE_CAP_LIST, RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol,
15 UPLOAD_CAP_LIST,
16 },
17 utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space},
18};
19use crate::hash::{HashKind, ObjectHash, get_hash_kind};
/// Server-side state for a Git smart-protocol exchange: ref advertisement
/// (`info/refs`), upload-pack (fetch/clone) and receive-pack (push) handling.
///
/// `R` supplies repository access (refs, objects, ref updates) and `A`
/// authenticates HTTP and SSH callers; both are provided by the embedding
/// application.
pub struct SmartProtocol<R, A>
where
    R: RepositoryAccess,
    A: AuthenticationService,
{
    /// Transport the replies are built for (e.g. `TransportProtocol::Http`).
    pub transport_protocol: TransportProtocol,
    /// Capabilities the client requested in the request currently being handled.
    pub capabilities: Vec<Capability>,
    /// Optional side-band setting; starts out as `None`.
    pub side_band: Option<SideBand>,
    /// Ref update commands parsed from the most recent receive-pack request.
    pub command_list: Vec<RefCommand>,
    /// Hash algorithm used for object ids on the wire; an `object-format=`
    /// capability from the client can switch it.
    pub wire_hash_kind: HashKind,
    /// Hash algorithm of the local repository, captured at construction; each
    /// incoming request resets the wire hash kind to this value before parsing.
    pub local_hash_kind: HashKind,
    /// All-zero object id string matching `wire_hash_kind`; it marks ref
    /// creations/deletions and fills the id slot of placeholder advertisements.
    pub zero_id: String,
    /// Repository backend used for ref listing, object lookup and ref updates.
    repo_storage: R,
    /// Authentication backend for HTTP and SSH callers.
    auth_service: A,
}
40
41impl<R, A> SmartProtocol<R, A>
42where
43 R: RepositoryAccess,
44 A: AuthenticationService,
45{
46 pub fn set_wire_hash_kind(&mut self, kind: HashKind) {
48 self.wire_hash_kind = kind;
49 self.zero_id = ObjectHash::zero_str(kind);
50 }
51
52 pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
54 Self {
55 transport_protocol,
56 capabilities: Vec::new(),
57 side_band: None,
58 command_list: Vec::new(),
59 repo_storage,
60 auth_service,
61 wire_hash_kind: HashKind::default(), local_hash_kind: get_hash_kind(),
63 zero_id: ObjectHash::zero_str(HashKind::default()),
64 }
65 }
66
67 pub async fn authenticate_http(
69 &self,
70 headers: &HashMap<String, String>,
71 ) -> Result<(), ProtocolError> {
72 self.auth_service.authenticate_http(headers).await
73 }
74
75 pub async fn authenticate_ssh(
77 &self,
78 username: &str,
79 public_key: &[u8],
80 ) -> Result<(), ProtocolError> {
81 self.auth_service
82 .authenticate_ssh(username, public_key)
83 .await
84 }
85
86 pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
88 self.transport_protocol = protocol;
89 }
90
91 pub async fn git_info_refs(
93 &self,
94 service_type: ServiceType,
95 ) -> Result<BytesMut, ProtocolError> {
96 let refs = self
97 .repo_storage
98 .get_repository_refs()
99 .await
100 .map_err(|e| ProtocolError::repository_error(format!("Failed to get refs: {e}")))?;
101 let hex_len = self.wire_hash_kind.hex_len();
102 for (name, h) in &refs {
103 if h.len() != hex_len {
104 return Err(ProtocolError::invalid_request(&format!(
105 "Hash length mismatch for ref {}: expected {}, got {}",
106 name,
107 hex_len,
108 h.len()
109 )));
110 }
111 } let head_hash = refs
114 .iter()
115 .find(|(name, _)| {
116 name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
117 })
118 .map(|(_, hash)| hash.clone())
119 .unwrap_or_else(|| self.zero_id.clone());
120
121 let git_refs: Vec<super::types::GitRef> = refs
122 .into_iter()
123 .map(|(name, hash)| super::types::GitRef { name, hash })
124 .collect();
125 let format_cap = match self.wire_hash_kind {
127 HashKind::Sha1 => " object-format=sha1",
128 HashKind::Sha256 => " object-format=sha256",
129 };
130 let cap_list = match service_type {
132 ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
133 ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
134 };
135
136 let name = if head_hash == self.zero_id {
138 "capabilities^{}"
139 } else {
140 "HEAD"
141 };
142 let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
143 let mut ref_list = vec![pkt_line];
144
145 for git_ref in git_refs {
146 let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
147 ref_list.push(pkt_line);
148 }
149
150 let pkt_line_stream =
151 build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
152 tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
153 Ok(pkt_line_stream)
154 }
155
156 pub async fn git_upload_pack(
158 &mut self,
159 upload_request: Bytes,
160 ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
161 self.capabilities.clear();
162 self.set_wire_hash_kind(self.local_hash_kind);
163 let mut upload_request = upload_request;
164 let mut want: Vec<String> = Vec::new();
165 let mut have: Vec<String> = Vec::new();
166 let mut last_common_commit = String::new();
167
168 let mut read_first_line = false;
169 loop {
170 let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
171
172 if bytes_take == 0 {
173 break;
174 }
175
176 if pkt_line.is_empty() {
177 break;
178 }
179
180 let mut pkt_line = pkt_line;
181 let command = read_until_white_space(&mut pkt_line);
182
183 match command.as_str() {
184 "want" => {
185 let hash = read_until_white_space(&mut pkt_line);
186 want.push(hash);
187 if !read_first_line {
188 let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
189 self.parse_capabilities(&cap_str);
190 read_first_line = true;
191 }
192 }
193 "have" => {
194 let hash = read_until_white_space(&mut pkt_line);
195 have.push(hash);
196 }
197 "done" => {
198 break;
199 }
200 _ => {
201 tracing::warn!("Unknown upload-pack command: {}", command);
202 }
203 }
204 }
205
206 let mut protocol_buf = BytesMut::new();
207
208 let pack_generator = PackGenerator::new(&self.repo_storage);
210
211 if have.is_empty() {
212 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
214 let pack_stream = pack_generator.generate_full_pack(want).await?;
215 return Ok((pack_stream, protocol_buf));
216 }
217
218 for hash in &have {
220 let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
221 ProtocolError::repository_error(format!("Failed to check commit existence: {e}"))
222 })?;
223 if exists {
224 add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
225 if last_common_commit.is_empty() {
226 last_common_commit = hash.clone();
227 }
228 }
229 }
230
231 if last_common_commit.is_empty() {
232 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
234 let pack_stream = pack_generator.generate_full_pack(want).await?;
235 return Ok((pack_stream, protocol_buf));
236 }
237
238 add_pkt_line_string(
240 &mut protocol_buf,
241 format!("ACK {last_common_commit} ready\n"),
242 );
243
244 let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
245
246 Ok((pack_stream, protocol_buf))
247 }
248
249 pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
251 self.command_list.clear();
252 self.capabilities.clear();
253 self.set_wire_hash_kind(self.local_hash_kind);
254 let mut first_line = true;
255 loop {
256 let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
257
258 if bytes_take == 0 {
259 break;
260 }
261
262 if pkt_line.is_empty() {
263 break;
264 }
265
266 if first_line {
267 if let Some(pos) = pkt_line.iter().position(|b| *b == b'\0') {
268 let caps = String::from_utf8_lossy(&pkt_line[(pos + 1)..]).to_string();
269 self.parse_capabilities(&caps);
270 }
271 first_line = false;
272 }
273
274 let ref_command = self.parse_ref_command(&mut pkt_line.clone());
275 self.command_list.push(ref_command);
276 }
277 }
278
279 pub async fn git_receive_pack_stream(
281 &mut self,
282 data_stream: ProtocolStream,
283 ) -> Result<Bytes, ProtocolError> {
284 let mut request_data = BytesMut::new();
286 let mut stream = data_stream;
287
288 while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
289 let chunk = chunk_result
290 .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {e}")))?;
291 request_data.extend_from_slice(&chunk);
292 }
293
294 let mut protocol_bytes = request_data.freeze();
295 self.command_list.clear();
296 self.capabilities.clear();
297 self.set_wire_hash_kind(self.local_hash_kind);
298 let mut first_line = true;
299 let mut saw_flush = false;
300 loop {
301 let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
302
303 if bytes_take == 0 {
304 if protocol_bytes.is_empty() {
305 break;
306 }
307 return Err(ProtocolError::invalid_request(
308 "Invalid pkt-line in receive-pack request",
309 ));
310 }
311
312 if pkt_line.is_empty() {
313 saw_flush = true;
314 break;
315 }
316
317 if first_line {
318 if let Some(pos) = pkt_line.iter().position(|b| *b == b'\0') {
319 let caps = String::from_utf8_lossy(&pkt_line[(pos + 1)..]).to_string();
320 self.parse_capabilities(&caps);
321 }
322 first_line = false;
323 }
324
325 let ref_command = self.parse_ref_command(&mut pkt_line.clone());
326 self.command_list.push(ref_command);
327 }
328
329 if !saw_flush {
330 return Err(ProtocolError::invalid_request(
331 "Missing flush before pack data",
332 ));
333 }
334
335 let pack_data = if protocol_bytes.is_empty() {
337 None
338 } else {
339 Some(protocol_bytes)
340 };
341
342 if let Some(pack_data) = pack_data {
343 let pack_generator = PackGenerator::new(&self.repo_storage);
345 let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data).await?;
347
348 self.repo_storage
350 .handle_pack_objects(commits, trees, blobs)
351 .await
352 .map_err(|e| {
353 ProtocolError::repository_error(format!("Failed to store pack objects: {e}"))
354 })?;
355 }
356
357 let mut report_status = BytesMut::new();
359 add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
360
361 let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
362 ProtocolError::repository_error(format!("Failed to check default branch: {e}"))
363 })?;
364
365 for command in &mut self.command_list {
367 if command.ref_type == RefTypeEnum::Tag {
368 let old_hash = if command.old_hash == self.zero_id {
371 None
372 } else {
373 Some(command.old_hash.as_str())
374 };
375 if let Err(e) = self
376 .repo_storage
377 .update_reference(&command.ref_name, old_hash, &command.new_hash)
378 .await
379 {
380 command.failed(e.to_string());
381 }
382 } else {
383 if !default_exist {
385 command.default_branch = true;
386 default_exist = true;
387 }
388 let old_hash = if command.old_hash == self.zero_id {
390 None
391 } else {
392 Some(command.old_hash.as_str())
393 };
394 if let Err(e) = self
395 .repo_storage
396 .update_reference(&command.ref_name, old_hash, &command.new_hash)
397 .await
398 {
399 command.failed(e.to_string());
400 }
401 }
402 add_pkt_line_string(&mut report_status, command.get_status());
403 }
404
405 self.repo_storage.post_receive_hook().await.map_err(|e| {
407 ProtocolError::repository_error(format!("Post-receive hook failed: {e}"))
408 })?;
409
410 report_status.put(&PKT_LINE_END_MARKER[..]);
411 Ok(report_status.freeze())
412 }
413
414 pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
416 let mut to_bytes = BytesMut::new();
417 if self.capabilities.contains(&Capability::SideBand)
418 || self.capabilities.contains(&Capability::SideBand64k)
419 {
420 let length = length + 5;
421 to_bytes.put(Bytes::from(format!("{length:04x}")));
422 to_bytes.put_u8(SideBand::PackfileData.value());
423 to_bytes.put(from_bytes);
424 } else {
425 to_bytes.put(from_bytes);
426 }
427 to_bytes
428 }
429
430 pub fn parse_capabilities(&mut self, cap_str: &str) {
432 for cap in cap_str.split_whitespace() {
433 if let Some(fmt) = cap.strip_prefix("object-format=") {
434 match fmt {
435 "sha1" => self.set_wire_hash_kind(HashKind::Sha1),
436 "sha256" => self.set_wire_hash_kind(HashKind::Sha256),
437 _ => {
438 tracing::warn!("Unknown object-format capability: {}", fmt);
439 }
440 }
441 continue;
442 }
443 if let Ok(capability) = cap.parse::<Capability>() {
444 self.capabilities.push(capability);
445 }
446 }
447 }
448
449 pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
451 let old_id = read_until_white_space(pkt_line);
452 let new_id = read_until_white_space(pkt_line);
453 let ref_name = read_until_white_space(pkt_line);
454 let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
455
456 RefCommand::new(old_id, new_id, ref_name)
457 }
458}
459
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };

    use async_trait::async_trait;
    use bytes::BytesMut;
    use tokio::sync::mpsc;

    use super::*;
    use crate::protocol::utils;
    use crate::{
        hash::{HashKind, ObjectHash, set_hash_kind_for_test},
        internal::{
            metadata::{EntryMeta, MetaAttached},
            object::{
                blob::Blob,
                commit::Commit,
                signature::{Signature, SignatureType},
                tree::{Tree, TreeItem, TreeItemMode},
            },
            pack::{encode::PackEncoder, entry::Entry},
        },
    };

    /// One recorded `update_reference` call: (ref name, old id, new id).
    type UpdateRecord = (String, Option<String>, String);
    type UpdateList = Vec<UpdateRecord>;
    type SharedUpdates = Arc<Mutex<UpdateList>>;

    /// In-memory repository double that records ref updates and hook calls.
    /// Clones share state, so a test can keep a handle after moving one into
    /// `SmartProtocol`.
    #[derive(Clone)]
    struct TestRepoAccess {
        updates: SharedUpdates,
        stored_count: Arc<Mutex<usize>>,
        default_branch_exists: Arc<Mutex<bool>>,
        post_called: Arc<AtomicBool>,
    }

    impl TestRepoAccess {
        fn new() -> Self {
            Self {
                updates: Arc::new(Mutex::new(vec![])),
                stored_count: Arc::new(Mutex::new(0)),
                default_branch_exists: Arc::new(Mutex::new(false)),
                post_called: Arc::new(AtomicBool::new(false)),
            }
        }

        /// Number of `update_reference` calls observed so far.
        fn updates_len(&self) -> usize {
            self.updates.lock().unwrap().len()
        }

        /// Whether `post_receive_hook` has been invoked.
        fn post_hook_called(&self) -> bool {
            self.post_called.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RepositoryAccess for TestRepoAccess {
        /// SHA-1 sized refs: a zero `HEAD` and a `main` branch.
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![
                (
                    "HEAD".to_string(),
                    "0000000000000000000000000000000000000000".to_string(),
                ),
                (
                    "refs/heads/main".to_string(),
                    "1111111111111111111111111111111111111111".to_string(),
                ),
            ])
        }

        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            *self.stored_count.lock().unwrap() += 1;
            Ok(())
        }

        async fn update_reference(
            &self,
            ref_name: &str,
            old_hash: Option<&str>,
            new_hash: &str,
        ) -> Result<(), ProtocolError> {
            self.updates.lock().unwrap().push((
                ref_name.to_string(),
                old_hash.map(|s| s.to_string()),
                new_hash.to_string(),
            ));
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }

        /// Reports "no default branch" on the first call only; afterwards it
        /// behaves as if one exists.
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            let mut exists = self.default_branch_exists.lock().unwrap();
            let current = *exists;
            *exists = true;
            Ok(current)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            self.post_called.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Authentication double that accepts every caller.
    struct TestAuth;

    #[async_trait]
    impl AuthenticationService for TestAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// End-to-end push: one create command plus a real pack produces
    /// `unpack ok`, an `ok` status for the ref, a trailing flush, exactly one ref
    /// update, and a post-receive hook call.
    #[tokio::test]
    async fn test_receive_pack_stream_status_report() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Encode commit, tree and blobs into a pack via the encoder task.
        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
        let (entry_tx, entry_rx) = mpsc::channel(1024);
        let mut encoder = PackEncoder::new(4, 10, pack_tx);

        let encoder_task = tokio::spawn(async move {
            if let Err(e) = encoder.encode(entry_rx).await {
                panic!("Failed to encode pack: {}", e);
            }
        });

        let commit_clone = commit.clone();
        let tree_clone = tree.clone();
        let blob1_clone = blob1.clone();
        let blob2_clone = blob2.clone();
        let feeder_task = tokio::spawn(async move {
            let _ = entry_tx
                .send(MetaAttached {
                    inner: Entry::from(commit_clone),
                    meta: EntryMeta::new(),
                })
                .await;
            let _ = entry_tx
                .send(MetaAttached {
                    inner: Entry::from(tree_clone),
                    meta: EntryMeta::new(),
                })
                .await;
            let _ = entry_tx
                .send(MetaAttached {
                    inner: Entry::from(blob1_clone),
                    meta: EntryMeta::new(),
                })
                .await;
            let _ = entry_tx
                .send(MetaAttached {
                    inner: Entry::from(blob2_clone),
                    meta: EntryMeta::new(),
                })
                .await;
        });

        // Drains until the encoder drops `pack_tx`, i.e. until it finishes.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = pack_rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        // Join both tasks so a panic inside them fails the test here, instead of
        // surfacing later as a confusing truncated-pack error.
        feeder_task.await.expect("entry feeder task panicked");
        encoder_task.await.expect("pack encoder task panicked");

        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), auth);
        smart.set_wire_hash_kind(HashKind::Sha1);

        let mut request = BytesMut::new();
        add_pkt_line_string(
            &mut request,
            format!(
                "{} {} refs/heads/main\0report-status\n",
                smart.zero_id, commit.id
            ),
        );
        request.put(&PKT_LINE_END_MARKER[..]);
        request.extend_from_slice(&pack_bytes);

        let request_stream = Box::pin(futures::stream::once(async { Ok(request.freeze()) }));

        let result_bytes = smart
            .git_receive_pack_stream(request_stream)
            .await
            .expect("receive-pack should succeed");

        let mut out = result_bytes.clone();
        let (_c1, l1) = utils::read_pkt_line(&mut out);
        assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");

        let (_c2, l2) = utils::read_pkt_line(&mut out);
        assert_eq!(
            String::from_utf8(l2.to_vec()).unwrap(),
            "ok refs/heads/main"
        );

        // Final flush-pkt: 4 bytes consumed, empty payload.
        let (c3, l3) = utils::read_pkt_line(&mut out);
        assert_eq!(c3, 4);
        assert!(l3.is_empty());

        assert_eq!(repo_access.updates_len(), 1);
        assert!(repo_access.post_hook_called());
    }

    /// 40-hex refs are rejected on a SHA-256 wire and accepted on a SHA-1 wire.
    #[tokio::test]
    async fn info_refs_rejects_sha256_with_sha1_refs() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
        smart.set_wire_hash_kind(HashKind::Sha256);
        let res = smart.git_info_refs(ServiceType::UploadPack).await;
        assert!(res.is_err(), "expected failure when hash lengths mismatch");

        smart.set_wire_hash_kind(HashKind::Sha1);

        let res = smart.git_info_refs(ServiceType::UploadPack).await;
        assert!(
            res.is_ok(),
            "expected SHA1 refs to be accepted when wire is SHA1"
        );
    }

    /// `object-format=sha256` switches the wire hash (and zero id length), and
    /// recognised capabilities are recorded.
    #[tokio::test]
    async fn parse_capabilities_updates_hash_and_caps() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);

        smart.parse_capabilities("object-format=sha256 side-band-64k multi_ack");

        assert_eq!(smart.wire_hash_kind, HashKind::Sha256);
        assert_eq!(smart.zero_id.len(), HashKind::Sha256.hex_len());
        assert!(
            smart.capabilities.contains(&Capability::SideBand64k),
            "side-band-64k should be recorded"
        );
    }

    /// 64-hex refs on a SHA-256 wire are accepted, and the advertisement carries
    /// `object-format=sha256`.
    #[tokio::test]
    async fn info_refs_accepts_sha256_refs_and_emits_capability() {
        /// Repository double returning SHA-256 sized refs.
        #[derive(Clone)]
        struct Sha256Repo;

        #[async_trait]
        impl RepositoryAccess for Sha256Repo {
            async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
                Ok(vec![
                    (
                        "HEAD".to_string(),
                        "0000000000000000000000000000000000000000000000000000000000000000"
                            .to_string(),
                    ),
                    (
                        "refs/heads/main".to_string(),
                        "1111111111111111111111111111111111111111111111111111111111111111"
                            .to_string(),
                    ),
                ])
            }
            async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
                Ok(true)
            }
            async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
                Ok(vec![])
            }
            async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
                Ok(())
            }
            async fn update_reference(
                &self,
                _ref_name: &str,
                _old_hash: Option<&str>,
                _new_hash: &str,
            ) -> Result<(), ProtocolError> {
                Ok(())
            }
            async fn get_objects_for_pack(
                &self,
                _wants: &[String],
                _haves: &[String],
            ) -> Result<Vec<String>, ProtocolError> {
                Ok(vec![])
            }
            async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
                Ok(false)
            }
            async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
                Ok(())
            }
        }

        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let repo_access = Sha256Repo;
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
        smart.set_wire_hash_kind(HashKind::Sha256);

        let resp = smart
            .git_info_refs(ServiceType::UploadPack)
            .await
            .expect("sha256 refs should be accepted");
        let resp_str = String::from_utf8(resp.to_vec()).expect("pkt-line should be valid UTF-8");
        assert!(
            resp_str.contains("object-format=sha256"),
            "capability line should advertise sha256"
        );
    }

    /// Two command lines followed by a flush decode into two commands, in order.
    #[tokio::test]
    async fn parse_receive_pack_commands_decodes_commands() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);

        let zero = ObjectHash::zero_str(HashKind::Sha1);
        let mut pkt = BytesMut::new();
        add_pkt_line_string(&mut pkt, format!("{zero} {zero} refs/heads/main\n"));
        add_pkt_line_string(&mut pkt, format!("{zero} {zero} refs/tags/v1.0\n"));
        pkt.put(&PKT_LINE_END_MARKER[..]);

        smart.parse_receive_pack_commands(pkt.freeze());

        assert_eq!(smart.command_list.len(), 2);
        assert_eq!(smart.command_list[0].ref_name, "refs/heads/main");
        assert_eq!(smart.command_list[1].ref_name, "refs/tags/v1.0");
    }

    /// A command list that ends without a flush-pkt is rejected as an invalid
    /// request.
    #[tokio::test]
    async fn receive_pack_missing_flush_errors() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);

        let zero = ObjectHash::zero_str(HashKind::Sha1);
        let mut pkt = BytesMut::new();
        add_pkt_line_string(&mut pkt, format!("{zero} {zero} refs/heads/main\n"));

        let request_stream = Box::pin(futures::stream::once(async { Ok(pkt.freeze()) }));
        let err = smart
            .git_receive_pack_stream(request_stream)
            .await
            .unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidRequest(_)));
    }
}