Skip to main content

git_internal/protocol/
smart.rs

1//! Implementation of the Git smart protocol state machine, handling capability negotiation, pkt
2//! exchanges, authentication delegation, and bridging repository storage to transport streams.
3
4use std::collections::HashMap;
5
6use bytes::{BufMut, Bytes, BytesMut};
7use tokio_stream::wrappers::ReceiverStream;
8
9use super::{
10    core::{AuthenticationService, RepositoryAccess},
11    pack::PackGenerator,
12    types::{
13        COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolError, ProtocolStream,
14        RECEIVE_CAP_LIST, RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol,
15        UPLOAD_CAP_LIST,
16    },
17    utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space},
18};
19use crate::hash::{HashKind, ObjectHash, get_hash_kind};
20/// Smart Git Protocol implementation
21///
22/// This struct handles the Git smart protocol operations for both HTTP and SSH transports.
23/// It uses trait abstractions to decouple from specific business logic implementations.
24pub struct SmartProtocol<R, A>
25where
26    R: RepositoryAccess,
27    A: AuthenticationService,
28{
29    pub transport_protocol: TransportProtocol,
30    pub capabilities: Vec<Capability>,
31    pub side_band: Option<SideBand>,
32    pub command_list: Vec<RefCommand>,
33    pub wire_hash_kind: HashKind,
34    pub local_hash_kind: HashKind,
35    pub zero_id: String,
36    // Trait-based dependencies
37    repo_storage: R,
38    auth_service: A,
39}
40
41impl<R, A> SmartProtocol<R, A>
42where
43    R: RepositoryAccess,
44    A: AuthenticationService,
45{
46    /// Set the wire hash kind (sha1 or sha256)
47    pub fn set_wire_hash_kind(&mut self, kind: HashKind) {
48        self.wire_hash_kind = kind;
49        self.zero_id = ObjectHash::zero_str(kind);
50    }
51
52    /// Create a new SmartProtocol instance
53    pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
54        Self {
55            transport_protocol,
56            capabilities: Vec::new(),
57            side_band: None,
58            command_list: Vec::new(),
59            repo_storage,
60            auth_service,
61            wire_hash_kind: HashKind::default(), // Default to SHA-1
62            local_hash_kind: get_hash_kind(),
63            zero_id: ObjectHash::zero_str(HashKind::default()),
64        }
65    }
66
67    /// Authenticate an HTTP request using the injected auth service
68    pub async fn authenticate_http(
69        &self,
70        headers: &HashMap<String, String>,
71    ) -> Result<(), ProtocolError> {
72        self.auth_service.authenticate_http(headers).await
73    }
74
75    /// Authenticate an SSH session using username and public key
76    pub async fn authenticate_ssh(
77        &self,
78        username: &str,
79        public_key: &[u8],
80    ) -> Result<(), ProtocolError> {
81        self.auth_service
82            .authenticate_ssh(username, public_key)
83            .await
84    }
85
86    /// Set transport protocol (Http, Ssh, etc.)
87    pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
88        self.transport_protocol = protocol;
89    }
90
91    /// Get git info refs for the repository, with explicit service type
92    pub async fn git_info_refs(
93        &self,
94        service_type: ServiceType,
95    ) -> Result<BytesMut, ProtocolError> {
96        let refs = self
97            .repo_storage
98            .get_repository_refs()
99            .await
100            .map_err(|e| ProtocolError::repository_error(format!("Failed to get refs: {e}")))?;
101        let hex_len = self.wire_hash_kind.hex_len();
102        for (name, h) in &refs {
103            if h.len() != hex_len {
104                return Err(ProtocolError::invalid_request(&format!(
105                    "Hash length mismatch for ref {}: expected {}, got {}",
106                    name,
107                    hex_len,
108                    h.len()
109                )));
110            }
111        } // Ensure refs match the expected wire hash kind
112        // Convert to the expected format (head_hash, git_refs)
113        let head_hash = refs
114            .iter()
115            .find(|(name, _)| {
116                name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
117            })
118            .map(|(_, hash)| hash.clone())
119            .unwrap_or_else(|| self.zero_id.clone());
120
121        let git_refs: Vec<super::types::GitRef> = refs
122            .into_iter()
123            .map(|(name, hash)| super::types::GitRef { name, hash })
124            .collect();
125        // capability add object-format,declare the wire hash kind
126        let format_cap = match self.wire_hash_kind {
127            HashKind::Sha1 => " object-format=sha1",
128            HashKind::Sha256 => " object-format=sha256",
129        };
130        // Determine capabilities based on service type
131        let cap_list = match service_type {
132            ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
133            ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
134        };
135
136        // The stream MUST include capability declarations behind a NUL on the first ref.
137        let name = if head_hash == self.zero_id {
138            "capabilities^{}"
139        } else {
140            "HEAD"
141        };
142        let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
143        let mut ref_list = vec![pkt_line];
144
145        for git_ref in git_refs {
146            let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
147            ref_list.push(pkt_line);
148        }
149
150        let pkt_line_stream =
151            build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
152        tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
153        Ok(pkt_line_stream)
154    }
155
156    /// Handle git-upload-pack request
157    pub async fn git_upload_pack(
158        &mut self,
159        upload_request: Bytes,
160    ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
161        self.capabilities.clear();
162        self.set_wire_hash_kind(self.local_hash_kind);
163        let mut upload_request = upload_request;
164        let mut want: Vec<String> = Vec::new();
165        let mut have: Vec<String> = Vec::new();
166        let mut last_common_commit = String::new();
167
168        let mut read_first_line = false;
169        loop {
170            let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
171
172            if bytes_take == 0 {
173                break;
174            }
175
176            if pkt_line.is_empty() {
177                break;
178            }
179
180            let mut pkt_line = pkt_line;
181            let command = read_until_white_space(&mut pkt_line);
182
183            match command.as_str() {
184                "want" => {
185                    let hash = read_until_white_space(&mut pkt_line);
186                    want.push(hash);
187                    if !read_first_line {
188                        let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
189                        self.parse_capabilities(&cap_str);
190                        read_first_line = true;
191                    }
192                }
193                "have" => {
194                    let hash = read_until_white_space(&mut pkt_line);
195                    have.push(hash);
196                }
197                "done" => {
198                    break;
199                }
200                _ => {
201                    tracing::warn!("Unknown upload-pack command: {}", command);
202                }
203            }
204        }
205
206        let mut protocol_buf = BytesMut::new();
207
208        // Create pack generator for this operation
209        let pack_generator = PackGenerator::new(&self.repo_storage);
210
211        if have.is_empty() {
212            // Full pack
213            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
214            let pack_stream = pack_generator.generate_full_pack(want).await?;
215            return Ok((pack_stream, protocol_buf));
216        }
217
218        // Check for common commits
219        for hash in &have {
220            let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
221                ProtocolError::repository_error(format!("Failed to check commit existence: {e}"))
222            })?;
223            if exists {
224                add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
225                if last_common_commit.is_empty() {
226                    last_common_commit = hash.clone();
227                }
228            }
229        }
230
231        if last_common_commit.is_empty() {
232            // No common commits found
233            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
234            let pack_stream = pack_generator.generate_full_pack(want).await?;
235            return Ok((pack_stream, protocol_buf));
236        }
237
238        // Generate incremental pack
239        add_pkt_line_string(
240            &mut protocol_buf,
241            format!("ACK {last_common_commit} ready\n"),
242        );
243
244        let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
245
246        Ok((pack_stream, protocol_buf))
247    }
248
249    /// Parse receive pack commands from protocol bytes
250    pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
251        self.command_list.clear();
252        self.capabilities.clear();
253        self.set_wire_hash_kind(self.local_hash_kind);
254        let mut first_line = true;
255        loop {
256            let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
257
258            if bytes_take == 0 {
259                break;
260            }
261
262            if pkt_line.is_empty() {
263                break;
264            }
265
266            if first_line {
267                if let Some(pos) = pkt_line.iter().position(|b| *b == b'\0') {
268                    let caps = String::from_utf8_lossy(&pkt_line[(pos + 1)..]).to_string();
269                    self.parse_capabilities(&caps);
270                }
271                first_line = false;
272            }
273
274            let ref_command = self.parse_ref_command(&mut pkt_line.clone());
275            self.command_list.push(ref_command);
276        }
277    }
278
279    /// Handle git receive-pack operation (push)
280    pub async fn git_receive_pack_stream(
281        &mut self,
282        data_stream: ProtocolStream,
283    ) -> Result<Bytes, ProtocolError> {
284        // Collect all request data from stream
285        let mut request_data = BytesMut::new();
286        let mut stream = data_stream;
287
288        while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
289            let chunk = chunk_result
290                .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {e}")))?;
291            request_data.extend_from_slice(&chunk);
292        }
293
294        let mut protocol_bytes = request_data.freeze();
295        self.command_list.clear();
296        self.capabilities.clear();
297        self.set_wire_hash_kind(self.local_hash_kind);
298        let mut first_line = true;
299        let mut saw_flush = false;
300        loop {
301            let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
302
303            if bytes_take == 0 {
304                if protocol_bytes.is_empty() {
305                    break;
306                }
307                return Err(ProtocolError::invalid_request(
308                    "Invalid pkt-line in receive-pack request",
309                ));
310            }
311
312            if pkt_line.is_empty() {
313                saw_flush = true;
314                break;
315            }
316
317            if first_line {
318                if let Some(pos) = pkt_line.iter().position(|b| *b == b'\0') {
319                    let caps = String::from_utf8_lossy(&pkt_line[(pos + 1)..]).to_string();
320                    self.parse_capabilities(&caps);
321                }
322                first_line = false;
323            }
324
325            let ref_command = self.parse_ref_command(&mut pkt_line.clone());
326            self.command_list.push(ref_command);
327        }
328
329        if !saw_flush {
330            return Err(ProtocolError::invalid_request(
331                "Missing flush before pack data",
332            ));
333        }
334
335        // Remaining bytes (if any) are pack data.
336        let pack_data = if protocol_bytes.is_empty() {
337            None
338        } else {
339            Some(protocol_bytes)
340        };
341
342        if let Some(pack_data) = pack_data {
343            // Create pack generator for unpacking
344            let pack_generator = PackGenerator::new(&self.repo_storage);
345            // Unpack the received data
346            let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data).await?;
347
348            // Store the unpacked objects via the repository access trait
349            self.repo_storage
350                .handle_pack_objects(commits, trees, blobs)
351                .await
352                .map_err(|e| {
353                    ProtocolError::repository_error(format!("Failed to store pack objects: {e}"))
354                })?;
355        }
356
357        // Build status report
358        let mut report_status = BytesMut::new();
359        add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
360
361        let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
362            ProtocolError::repository_error(format!("Failed to check default branch: {e}"))
363        })?;
364
365        // Update refs with proper error handling
366        for command in &mut self.command_list {
367            if command.ref_type == RefTypeEnum::Tag {
368                // Just update if refs type is tag
369                // Convert zero_id to None for old hash
370                let old_hash = if command.old_hash == self.zero_id {
371                    None
372                } else {
373                    Some(command.old_hash.as_str())
374                };
375                if let Err(e) = self
376                    .repo_storage
377                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
378                    .await
379                {
380                    command.failed(e.to_string());
381                }
382            } else {
383                // Handle default branch setting for the first branch
384                if !default_exist {
385                    command.default_branch = true;
386                    default_exist = true;
387                }
388                // Convert zero_id to None for old hash
389                let old_hash = if command.old_hash == self.zero_id {
390                    None
391                } else {
392                    Some(command.old_hash.as_str())
393                };
394                if let Err(e) = self
395                    .repo_storage
396                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
397                    .await
398                {
399                    command.failed(e.to_string());
400                }
401            }
402            add_pkt_line_string(&mut report_status, command.get_status());
403        }
404
405        // Post-receive hook
406        self.repo_storage.post_receive_hook().await.map_err(|e| {
407            ProtocolError::repository_error(format!("Post-receive hook failed: {e}"))
408        })?;
409
410        report_status.put(&PKT_LINE_END_MARKER[..]);
411        Ok(report_status.freeze())
412    }
413
414    /// Builds the packet data in the sideband format if the SideBand/64k capability is enabled.
415    pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
416        let mut to_bytes = BytesMut::new();
417        if self.capabilities.contains(&Capability::SideBand)
418            || self.capabilities.contains(&Capability::SideBand64k)
419        {
420            let length = length + 5;
421            to_bytes.put(Bytes::from(format!("{length:04x}")));
422            to_bytes.put_u8(SideBand::PackfileData.value());
423            to_bytes.put(from_bytes);
424        } else {
425            to_bytes.put(from_bytes);
426        }
427        to_bytes
428    }
429
430    /// Parse capabilities from capability string
431    pub fn parse_capabilities(&mut self, cap_str: &str) {
432        for cap in cap_str.split_whitespace() {
433            if let Some(fmt) = cap.strip_prefix("object-format=") {
434                match fmt {
435                    "sha1" => self.set_wire_hash_kind(HashKind::Sha1),
436                    "sha256" => self.set_wire_hash_kind(HashKind::Sha256),
437                    _ => {
438                        tracing::warn!("Unknown object-format capability: {}", fmt);
439                    }
440                }
441                continue;
442            }
443            if let Ok(capability) = cap.parse::<Capability>() {
444                self.capabilities.push(capability);
445            }
446        }
447    }
448
449    /// Parse a reference command from packet line
450    pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
451        let old_id = read_until_white_space(pkt_line);
452        let new_id = read_until_white_space(pkt_line);
453        let ref_name = read_until_white_space(pkt_line);
454        let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
455
456        RefCommand::new(old_id, new_id, ref_name)
457    }
458}
459
460#[cfg(test)]
461mod tests {
462    use std::sync::{
463        Arc, Mutex,
464        atomic::{AtomicBool, Ordering},
465    };
466
467    use async_trait::async_trait;
468    use bytes::BytesMut;
469    use futures;
470    use tokio::sync::mpsc;
471
472    use super::*;
473    use crate::protocol::utils; // import sibling module
474    use crate::{
475        hash::{HashKind, ObjectHash, set_hash_kind_for_test},
476        internal::{
477            metadata::{EntryMeta, MetaAttached},
478            object::{
479                blob::Blob,
480                commit::Commit,
481                signature::{Signature, SignatureType},
482                tree::{Tree, TreeItem, TreeItemMode},
483            },
484            pack::{encode::PackEncoder, entry::Entry},
485        },
486    };
487
488    // Simplify complex type via aliases to satisfy clippy::type_complexity
489    type UpdateRecord = (String, Option<String>, String);
490    type UpdateList = Vec<UpdateRecord>;
491    type SharedUpdates = Arc<Mutex<UpdateList>>;
492
493    /// Test repository access implementation for testing
494    #[derive(Clone)]
495    struct TestRepoAccess {
496        updates: SharedUpdates,
497        stored_count: Arc<Mutex<usize>>,
498        default_branch_exists: Arc<Mutex<bool>>,
499        post_called: Arc<AtomicBool>,
500    }
501
502    impl TestRepoAccess {
503        fn new() -> Self {
504            Self {
505                updates: Arc::new(Mutex::new(vec![])),
506                stored_count: Arc::new(Mutex::new(0)),
507                default_branch_exists: Arc::new(Mutex::new(false)),
508                post_called: Arc::new(AtomicBool::new(false)),
509            }
510        }
511
512        fn updates_len(&self) -> usize {
513            self.updates.lock().unwrap().len()
514        }
515
516        fn post_hook_called(&self) -> bool {
517            self.post_called.load(Ordering::SeqCst)
518        }
519    }
520
521    #[async_trait]
522    impl RepositoryAccess for TestRepoAccess {
523        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
524            Ok(vec![
525                (
526                    "HEAD".to_string(),
527                    "0000000000000000000000000000000000000000".to_string(),
528                ),
529                (
530                    "refs/heads/main".to_string(),
531                    "1111111111111111111111111111111111111111".to_string(),
532                ),
533            ])
534        }
535
536        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
537            Ok(true)
538        }
539
540        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
541            Ok(vec![])
542        }
543
544        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
545            *self.stored_count.lock().unwrap() += 1;
546            Ok(())
547        }
548
549        async fn update_reference(
550            &self,
551            ref_name: &str,
552            old_hash: Option<&str>,
553            new_hash: &str,
554        ) -> Result<(), ProtocolError> {
555            self.updates.lock().unwrap().push((
556                ref_name.to_string(),
557                old_hash.map(|s| s.to_string()),
558                new_hash.to_string(),
559            ));
560            Ok(())
561        }
562
563        async fn get_objects_for_pack(
564            &self,
565            _wants: &[String],
566            _haves: &[String],
567        ) -> Result<Vec<String>, ProtocolError> {
568            Ok(vec![])
569        }
570
571        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
572            let mut exists = self.default_branch_exists.lock().unwrap();
573            let current = *exists;
574            *exists = true; // flip to true after first check
575            Ok(current)
576        }
577
578        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
579            self.post_called.store(true, Ordering::SeqCst);
580            Ok(())
581        }
582    }
583
584    /// Test authentication service implementation for testing
585    struct TestAuth;
586
587    #[async_trait]
588    impl AuthenticationService for TestAuth {
589        async fn authenticate_http(
590            &self,
591            _headers: &std::collections::HashMap<String, String>,
592        ) -> Result<(), ProtocolError> {
593            Ok(())
594        }
595
596        async fn authenticate_ssh(
597            &self,
598            _username: &str,
599            _public_key: &[u8],
600        ) -> Result<(), ProtocolError> {
601            Ok(())
602        }
603    }
604
605    /// Receive-pack stream decodes the pack, updates refs, and reports status (SHA-1).
606    #[tokio::test]
607    async fn test_receive_pack_stream_status_report() {
608        let _guard = set_hash_kind_for_test(HashKind::Sha1);
609        // Build simple objects
610        let blob1 = Blob::from_content("hello");
611        let blob2 = Blob::from_content("world");
612
613        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
614        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
615        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();
616
617        let author = Signature::new(
618            SignatureType::Author,
619            "tester".to_string(),
620            "tester@example.com".to_string(),
621        );
622        let committer = Signature::new(
623            SignatureType::Committer,
624            "tester".to_string(),
625            "tester@example.com".to_string(),
626        );
627        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");
628
629        // Encode pack bytes via PackEncoder
630        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
631        let (entry_tx, entry_rx) = mpsc::channel(1024);
632        let mut encoder = PackEncoder::new(4, 10, pack_tx);
633
634        tokio::spawn(async move {
635            if let Err(e) = encoder.encode(entry_rx).await {
636                panic!("Failed to encode pack: {}", e);
637            }
638        });
639
640        let commit_clone = commit.clone();
641        let tree_clone = tree.clone();
642        let blob1_clone = blob1.clone();
643        let blob2_clone = blob2.clone();
644        tokio::spawn(async move {
645            let _ = entry_tx
646                .send(MetaAttached {
647                    inner: Entry::from(commit_clone),
648                    meta: EntryMeta::new(),
649                })
650                .await;
651            let _ = entry_tx
652                .send(MetaAttached {
653                    inner: Entry::from(tree_clone),
654                    meta: EntryMeta::new(),
655                })
656                .await;
657            let _ = entry_tx
658                .send(MetaAttached {
659                    inner: Entry::from(blob1_clone),
660                    meta: EntryMeta::new(),
661                })
662                .await;
663            let _ = entry_tx
664                .send(MetaAttached {
665                    inner: Entry::from(blob2_clone),
666                    meta: EntryMeta::new(),
667                })
668                .await;
669            // sender drop indicates end
670        });
671
672        let mut pack_bytes: Vec<u8> = Vec::new();
673        while let Some(chunk) = pack_rx.recv().await {
674            pack_bytes.extend_from_slice(&chunk);
675        }
676
677        // Prepare protocol and request
678        let repo_access = TestRepoAccess::new();
679        let auth = TestAuth;
680        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), auth);
681        smart.set_wire_hash_kind(HashKind::Sha1);
682
683        let mut request = BytesMut::new();
684        add_pkt_line_string(
685            &mut request,
686            format!(
687                "{} {} refs/heads/main\0report-status\n",
688                smart.zero_id, commit.id
689            ),
690        );
691        request.put(&PKT_LINE_END_MARKER[..]);
692        request.extend_from_slice(&pack_bytes);
693
694        // Create request stream
695        let request_stream = Box::pin(futures::stream::once(async { Ok(request.freeze()) }));
696
697        // Execute receive-pack
698        let result_bytes = smart
699            .git_receive_pack_stream(request_stream)
700            .await
701            .expect("receive-pack should succeed");
702
703        // Verify pkt-lines
704        let mut out = result_bytes.clone();
705        let (_c1, l1) = utils::read_pkt_line(&mut out);
706        assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");
707
708        let (_c2, l2) = utils::read_pkt_line(&mut out);
709        assert_eq!(
710            String::from_utf8(l2.to_vec()).unwrap(),
711            "ok refs/heads/main"
712        );
713
714        let (c3, l3) = utils::read_pkt_line(&mut out);
715        assert_eq!(c3, 4);
716        assert!(l3.is_empty());
717
718        // Verify side effects
719        assert_eq!(repo_access.updates_len(), 1);
720        assert!(repo_access.post_hook_called());
721    }
722
723    /// info-refs rejects SHA-256 wire format when repository refs are still SHA-1.
724    #[tokio::test]
725    async fn info_refs_rejects_sha256_with_sha1_refs() {
726        let _guard = set_hash_kind_for_test(HashKind::Sha1); // avoid thread-local contamination
727        let repo_access = TestRepoAccess::new(); // still returns 40-char strings
728        let auth = TestAuth;
729        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
730        smart.set_wire_hash_kind(HashKind::Sha256); // claims wire uses SHA-256
731        // expect failure because refs are SHA-1
732        let res = smart.git_info_refs(ServiceType::UploadPack).await;
733        assert!(res.is_err(), "expected failure when hash lengths mismatch");
734
735        smart.set_wire_hash_kind(HashKind::Sha1);
736
737        let res = smart.git_info_refs(ServiceType::UploadPack).await;
738        assert!(
739            res.is_ok(),
740            "expected SHA1 refs to be accepted when wire is SHA1"
741        );
742    }
743
744    /// parse_capabilities should switch wire hash kind and record declared capabilities.
745    #[tokio::test]
746    async fn parse_capabilities_updates_hash_and_caps() {
747        let _guard = set_hash_kind_for_test(HashKind::Sha1);
748        let repo_access = TestRepoAccess::new();
749        let auth = TestAuth;
750        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
751
752        smart.parse_capabilities("object-format=sha256 side-band-64k multi_ack");
753
754        assert_eq!(smart.wire_hash_kind, HashKind::Sha256);
755        assert_eq!(smart.zero_id.len(), HashKind::Sha256.hex_len());
756        assert!(
757            smart.capabilities.contains(&Capability::SideBand64k),
758            "side-band-64k should be recorded"
759        );
760    }
761
762    /// info-refs should accept SHA-256 refs and emit the matching object-format capability.
763    #[tokio::test]
764    async fn info_refs_accepts_sha256_refs_and_emits_capability() {
765        // Define a repo access that returns SHA-256 refs
766        #[derive(Clone)]
767        struct Sha256Repo;
768
769        #[async_trait]
770        impl RepositoryAccess for Sha256Repo {
771            async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
772                Ok(vec![
773                    (
774                        "HEAD".to_string(),
775                        "0000000000000000000000000000000000000000000000000000000000000000"
776                            .to_string(),
777                    ),
778                    (
779                        "refs/heads/main".to_string(),
780                        "1111111111111111111111111111111111111111111111111111111111111111"
781                            .to_string(),
782                    ),
783                ])
784            }
785            async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
786                Ok(true)
787            }
788            async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
789                Ok(vec![])
790            }
791            async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
792                Ok(())
793            }
794            async fn update_reference(
795                &self,
796                _ref_name: &str,
797                _old_hash: Option<&str>,
798                _new_hash: &str,
799            ) -> Result<(), ProtocolError> {
800                Ok(())
801            }
802            async fn get_objects_for_pack(
803                &self,
804                _wants: &[String],
805                _haves: &[String],
806            ) -> Result<Vec<String>, ProtocolError> {
807                Ok(vec![])
808            }
809            async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
810                Ok(false)
811            }
812            async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
813                Ok(())
814            }
815        }
816
817        let _guard = set_hash_kind_for_test(HashKind::Sha1);
818        let repo_access = Sha256Repo;
819        let auth = TestAuth;
820        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
821        smart.set_wire_hash_kind(HashKind::Sha256);
822
823        let resp = smart
824            .git_info_refs(ServiceType::UploadPack)
825            .await
826            .expect("sha256 refs should be accepted");
827        let resp_str = String::from_utf8(resp.to_vec()).expect("pkt-line should be valid UTF-8");
828        assert!(
829            resp_str.contains("object-format=sha256"),
830            "capability line should advertise sha256"
831        );
832    }
833
834    /// parse_receive_pack_commands should decode multiple pkt-lines into RefCommand list.
835    #[tokio::test]
836    async fn parse_receive_pack_commands_decodes_commands() {
837        let _guard = set_hash_kind_for_test(HashKind::Sha1);
838        let repo_access = TestRepoAccess::new();
839        let auth = TestAuth;
840        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
841
842        let zero = ObjectHash::zero_str(HashKind::Sha1);
843        let mut pkt = BytesMut::new();
844        add_pkt_line_string(&mut pkt, format!("{zero} {zero} refs/heads/main\n"));
845        add_pkt_line_string(&mut pkt, format!("{zero} {zero} refs/tags/v1.0\n"));
846        pkt.put(&PKT_LINE_END_MARKER[..]);
847
848        smart.parse_receive_pack_commands(pkt.freeze());
849
850        assert_eq!(smart.command_list.len(), 2);
851        assert_eq!(smart.command_list[0].ref_name, "refs/heads/main");
852        assert_eq!(smart.command_list[1].ref_name, "refs/tags/v1.0");
853    }
854
855    /// receive-pack should error if ref commands are not terminated by a flush.
856    #[tokio::test]
857    async fn receive_pack_missing_flush_errors() {
858        let _guard = set_hash_kind_for_test(HashKind::Sha1);
859        let repo_access = TestRepoAccess::new();
860        let auth = TestAuth;
861        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
862
863        let zero = ObjectHash::zero_str(HashKind::Sha1);
864        let mut pkt = BytesMut::new();
865        add_pkt_line_string(&mut pkt, format!("{zero} {zero} refs/heads/main\n"));
866
867        let request_stream = Box::pin(futures::stream::once(async { Ok(pkt.freeze()) }));
868        let err = smart
869            .git_receive_pack_stream(request_stream)
870            .await
871            .unwrap_err();
872        assert!(matches!(err, ProtocolError::InvalidRequest(_)));
873    }
874}