git_internal/protocol/
smart.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use std::collections::HashMap;
3use tokio_stream::wrappers::ReceiverStream;
4
5use super::core::{AuthenticationService, RepositoryAccess};
6use super::pack::PackGenerator;
7use super::types::ProtocolError;
8use super::types::{
9    COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolStream, RECEIVE_CAP_LIST,
10    RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol, UPLOAD_CAP_LIST,
11};
12use super::utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space};
13use crate::hash::{HashKind, ObjectHash, get_hash_kind};
/// Smart Git Protocol implementation
///
/// This struct handles the Git smart protocol operations for both HTTP and SSH transports.
/// It uses trait abstractions to decouple from specific business logic implementations.
pub struct SmartProtocol<R, A>
where
    R: RepositoryAccess,
    A: AuthenticationService,
{
    /// Transport in use; handed to `build_smart_reply` when advertising refs.
    pub transport_protocol: TransportProtocol,
    /// Capabilities accumulated by `parse_capabilities`.
    pub capabilities: Vec<Capability>,
    /// Optional [`SideBand`] value; `new` initialises it to `None`.
    pub side_band: Option<SideBand>,
    /// Ref update commands collected by `parse_receive_pack_commands` and applied by
    /// `git_receive_pack_stream`.
    pub command_list: Vec<RefCommand>,
    /// Hash kind used on the wire; changed through `set_wire_hash_kind`.
    pub wire_hash_kind: HashKind,
    /// Hash kind captured from `get_hash_kind()` at construction time.
    pub local_hash_kind: HashKind,
    /// `ObjectHash::zero_str` for the current wire hash kind; kept in sync by
    /// `set_wire_hash_kind`.
    pub zero_id: String,
    // Trait-based dependencies
    repo_storage: R,
    auth_service: A,
}
34
35impl<R, A> SmartProtocol<R, A>
36where
37    R: RepositoryAccess,
38    A: AuthenticationService,
39{
40    pub fn set_wire_hash_kind(&mut self, kind: HashKind) {
41        self.wire_hash_kind = kind;
42        self.zero_id = ObjectHash::zero_str(kind);
43    }
44
45    /// Create a new SmartProtocol instance
46    pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
47        Self {
48            transport_protocol,
49            capabilities: Vec::new(),
50            side_band: None,
51            command_list: Vec::new(),
52            repo_storage,
53            auth_service,
54            wire_hash_kind: HashKind::default(), // Default to SHA-1
55            local_hash_kind: get_hash_kind(),
56            zero_id: ObjectHash::zero_str(HashKind::default()),
57        }
58    }
59
    /// Authenticate an HTTP request using the injected auth service
    ///
    /// Forwards `headers` to the auth service's `authenticate_http` and returns its
    /// result unchanged.
    pub async fn authenticate_http(
        &self,
        headers: &HashMap<String, String>,
    ) -> Result<(), ProtocolError> {
        self.auth_service.authenticate_http(headers).await
    }
67
    /// Authenticate an SSH session using username and public key
    ///
    /// Forwards `username` and `public_key` to the auth service's `authenticate_ssh`
    /// and returns its result unchanged.
    pub async fn authenticate_ssh(
        &self,
        username: &str,
        public_key: &[u8],
    ) -> Result<(), ProtocolError> {
        self.auth_service
            .authenticate_ssh(username, public_key)
            .await
    }
78
    /// Set transport protocol (Http, Ssh, etc.)
    ///
    /// Overwrites the stored `transport_protocol`; no other field is touched.
    pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
        self.transport_protocol = protocol;
    }
83
84    /// Get git info refs for the repository, with explicit service type
85    pub async fn git_info_refs(
86        &self,
87        service_type: ServiceType,
88    ) -> Result<BytesMut, ProtocolError> {
89        let refs =
90            self.repo_storage.get_repository_refs().await.map_err(|e| {
91                ProtocolError::repository_error(format!("Failed to get refs: {}", e))
92            })?;
93        let hex_len = self.wire_hash_kind.hex_len();
94        for (name, h) in &refs {
95            if h.len() != hex_len {
96                return Err(ProtocolError::invalid_request(&format!(
97                    "Hash length mismatch for ref {}: expected {}, got {}",
98                    name,
99                    hex_len,
100                    h.len()
101                )));
102            }
103        } // Ensure refs match the expected wire hash kind
104        // Convert to the expected format (head_hash, git_refs)
105        let head_hash = refs
106            .iter()
107            .find(|(name, _)| {
108                name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
109            })
110            .map(|(_, hash)| hash.clone())
111            .unwrap_or_else(|| self.zero_id.clone());
112
113        let git_refs: Vec<super::types::GitRef> = refs
114            .into_iter()
115            .map(|(name, hash)| super::types::GitRef { name, hash })
116            .collect();
117        // capability add object-format,declare the wire hash kind
118        let format_cap = match self.wire_hash_kind {
119            HashKind::Sha1 => " object-format=sha1",
120            HashKind::Sha256 => " object-format=sha256",
121        };
122        // Determine capabilities based on service type
123        let cap_list = match service_type {
124            ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
125            ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
126        };
127
128        // The stream MUST include capability declarations behind a NUL on the first ref.
129        let name = if head_hash == self.zero_id {
130            "capabilities^{}"
131        } else {
132            "HEAD"
133        };
134        let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
135        let mut ref_list = vec![pkt_line];
136
137        for git_ref in git_refs {
138            let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
139            ref_list.push(pkt_line);
140        }
141
142        let pkt_line_stream =
143            build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
144        tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
145        Ok(pkt_line_stream)
146    }
147
148    /// Handle git-upload-pack request
149    pub async fn git_upload_pack(
150        &mut self,
151        upload_request: Bytes,
152    ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
153        let mut upload_request = upload_request;
154        let mut want: Vec<String> = Vec::new();
155        let mut have: Vec<String> = Vec::new();
156        let mut last_common_commit = String::new();
157
158        let mut read_first_line = false;
159        loop {
160            let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
161
162            if bytes_take == 0 {
163                break;
164            }
165
166            if pkt_line.is_empty() {
167                break;
168            }
169
170            let mut pkt_line = pkt_line;
171            let command = read_until_white_space(&mut pkt_line);
172
173            match command.as_str() {
174                "want" => {
175                    let hash = read_until_white_space(&mut pkt_line);
176                    want.push(hash);
177                    if !read_first_line {
178                        let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
179                        self.parse_capabilities(&cap_str);
180                        read_first_line = true;
181                    }
182                }
183                "have" => {
184                    let hash = read_until_white_space(&mut pkt_line);
185                    have.push(hash);
186                }
187                "done" => {
188                    break;
189                }
190                _ => {
191                    tracing::warn!("Unknown upload-pack command: {}", command);
192                }
193            }
194        }
195
196        let mut protocol_buf = BytesMut::new();
197
198        // Create pack generator for this operation
199        let pack_generator = PackGenerator::new(&self.repo_storage);
200
201        if have.is_empty() {
202            // Full pack
203            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
204            let pack_stream = pack_generator.generate_full_pack(want).await?;
205            return Ok((pack_stream, protocol_buf));
206        }
207
208        // Check for common commits
209        for hash in &have {
210            let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
211                ProtocolError::repository_error(format!("Failed to check commit existence: {}", e))
212            })?;
213            if exists {
214                add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
215                if last_common_commit.is_empty() {
216                    last_common_commit = hash.clone();
217                }
218            }
219        }
220
221        if last_common_commit.is_empty() {
222            // No common commits found
223            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
224            let pack_stream = pack_generator.generate_full_pack(want).await?;
225            return Ok((pack_stream, protocol_buf));
226        }
227
228        // Generate incremental pack
229        add_pkt_line_string(
230            &mut protocol_buf,
231            format!("ACK {last_common_commit} ready\n"),
232        );
233        protocol_buf.put(&PKT_LINE_END_MARKER[..]);
234
235        add_pkt_line_string(&mut protocol_buf, format!("ACK {last_common_commit} \n"));
236
237        let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
238
239        Ok((pack_stream, protocol_buf))
240    }
241
242    /// Parse receive pack commands from protocol bytes
243    pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
244        loop {
245            let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
246
247            if bytes_take == 0 {
248                break;
249            }
250
251            if pkt_line.is_empty() {
252                break;
253            }
254
255            let ref_command = self.parse_ref_command(&mut pkt_line.clone());
256            self.command_list.push(ref_command);
257        }
258    }
259
260    /// Handle git receive-pack operation (push)
261    pub async fn git_receive_pack_stream(
262        &mut self,
263        data_stream: ProtocolStream,
264    ) -> Result<Bytes, ProtocolError> {
265        // Collect all pack data from stream
266        let mut pack_data = BytesMut::new();
267        let mut stream = data_stream;
268
269        while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
270            let chunk = chunk_result
271                .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {}", e)))?;
272            pack_data.extend_from_slice(&chunk);
273        }
274
275        // Create pack generator for unpacking
276        let pack_generator = PackGenerator::new(&self.repo_storage);
277
278        // Unpack the received data
279        let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data.freeze()).await?;
280
281        // Store the unpacked objects via the repository access trait
282        self.repo_storage
283            .handle_pack_objects(commits, trees, blobs)
284            .await
285            .map_err(|e| {
286                ProtocolError::repository_error(format!("Failed to store pack objects: {}", e))
287            })?;
288
289        // Build status report
290        let mut report_status = BytesMut::new();
291        add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
292
293        let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
294            ProtocolError::repository_error(format!("Failed to check default branch: {}", e))
295        })?;
296
297        // Update refs with proper error handling
298        for command in &mut self.command_list {
299            if command.ref_type == RefTypeEnum::Tag {
300                // Just update if refs type is tag
301                // Convert zero_id to None for old hash
302                let old_hash = if command.old_hash == self.zero_id {
303                    None
304                } else {
305                    Some(command.old_hash.as_str())
306                };
307                if let Err(e) = self
308                    .repo_storage
309                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
310                    .await
311                {
312                    command.failed(e.to_string());
313                }
314            } else {
315                // Handle default branch setting for the first branch
316                if !default_exist {
317                    command.default_branch = true;
318                    default_exist = true;
319                }
320                // Convert zero_id to None for old hash
321                let old_hash = if command.old_hash == self.zero_id {
322                    None
323                } else {
324                    Some(command.old_hash.as_str())
325                };
326                if let Err(e) = self
327                    .repo_storage
328                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
329                    .await
330                {
331                    command.failed(e.to_string());
332                }
333            }
334            add_pkt_line_string(&mut report_status, command.get_status());
335        }
336
337        // Post-receive hook
338        self.repo_storage.post_receive_hook().await.map_err(|e| {
339            ProtocolError::repository_error(format!("Post-receive hook failed: {}", e))
340        })?;
341
342        report_status.put(&PKT_LINE_END_MARKER[..]);
343        Ok(report_status.freeze())
344    }
345
346    /// Builds the packet data in the sideband format if the SideBand/64k capability is enabled.
347    pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
348        let mut to_bytes = BytesMut::new();
349        if self.capabilities.contains(&Capability::SideBand)
350            || self.capabilities.contains(&Capability::SideBand64k)
351        {
352            let length = length + 5;
353            to_bytes.put(Bytes::from(format!("{length:04x}")));
354            to_bytes.put_u8(SideBand::PackfileData.value());
355            to_bytes.put(from_bytes);
356        } else {
357            to_bytes.put(from_bytes);
358        }
359        to_bytes
360    }
361
362    /// Parse capabilities from capability string
363    pub fn parse_capabilities(&mut self, cap_str: &str) {
364        for cap in cap_str.split_whitespace() {
365            if let Some(fmt) = cap.strip_prefix("object-format=") {
366                match fmt {
367                    "sha1" => self.set_wire_hash_kind(HashKind::Sha1),
368                    "sha256" => self.set_wire_hash_kind(HashKind::Sha256),
369                    _ => {
370                        tracing::warn!("Unknown object-format capability: {}", fmt);
371                    }
372                }
373                continue;
374            }
375            if let Ok(capability) = cap.parse::<Capability>() {
376                self.capabilities.push(capability);
377            }
378        }
379    }
380
381    /// Parse a reference command from packet line
382    pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
383        let old_id = read_until_white_space(pkt_line);
384        let new_id = read_until_white_space(pkt_line);
385        let ref_name = read_until_white_space(pkt_line);
386        let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
387
388        RefCommand::new(old_id, new_id, ref_name)
389    }
390}
391
392#[cfg(test)]
393mod tests {
394    use super::*;
395    use crate::hash::{HashKind, set_hash_kind_for_test};
396    use crate::internal::metadata::{EntryMeta, MetaAttached};
397    use crate::internal::object::blob::Blob;
398    use crate::internal::object::commit::Commit;
399    use crate::internal::object::signature::{Signature, SignatureType};
400    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
401    use crate::internal::pack::{encode::PackEncoder, entry::Entry};
402    use crate::protocol::types::RefCommand; // import sibling types
403    use crate::protocol::utils; // import sibling module
404    use async_trait::async_trait;
405    use bytes::Bytes;
406    use futures;
407    use std::sync::{
408        Arc, Mutex,
409        atomic::{AtomicBool, Ordering},
410    };
411    use tokio::sync::mpsc;
412
    // Simplify complex type via aliases to satisfy clippy::type_complexity
    // One recorded `update_reference` call: (ref name, old hash if any, new hash).
    type UpdateRecord = (String, Option<String>, String);
    // Recorded calls, in the order they were made.
    type UpdateList = Vec<UpdateRecord>;
    // Call log shared between clones of `TestRepoAccess`.
    type SharedUpdates = Arc<Mutex<UpdateList>>;
417
    // In-memory `RepositoryAccess` double. All state sits behind `Arc`, so a clone
    // kept by the test observes what the protocol under test did.
    #[derive(Clone)]
    struct TestRepoAccess {
        // Every `update_reference` call, in order.
        updates: SharedUpdates,
        // Number of `store_pack_data` calls.
        stored_count: Arc<Mutex<usize>>,
        // Value returned by the next `has_default_branch` call.
        default_branch_exists: Arc<Mutex<bool>>,
        // Set once `post_receive_hook` has run.
        post_called: Arc<AtomicBool>,
    }
425
426    impl TestRepoAccess {
427        fn new() -> Self {
428            Self {
429                updates: Arc::new(Mutex::new(vec![])),
430                stored_count: Arc::new(Mutex::new(0)),
431                default_branch_exists: Arc::new(Mutex::new(false)),
432                post_called: Arc::new(AtomicBool::new(false)),
433            }
434        }
435
436        fn updates_len(&self) -> usize {
437            self.updates.lock().unwrap().len()
438        }
439
440        fn post_hook_called(&self) -> bool {
441            self.post_called.load(Ordering::SeqCst)
442        }
443    }
444
445    #[async_trait]
446    impl RepositoryAccess for TestRepoAccess {
447        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
448            Ok(vec![
449                (
450                    "HEAD".to_string(),
451                    "0000000000000000000000000000000000000000".to_string(),
452                ),
453                (
454                    "refs/heads/main".to_string(),
455                    "1111111111111111111111111111111111111111".to_string(),
456                ),
457            ])
458        }
459
460        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
461            Ok(true)
462        }
463
464        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
465            Ok(vec![])
466        }
467
468        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
469            *self.stored_count.lock().unwrap() += 1;
470            Ok(())
471        }
472
473        async fn update_reference(
474            &self,
475            ref_name: &str,
476            old_hash: Option<&str>,
477            new_hash: &str,
478        ) -> Result<(), ProtocolError> {
479            self.updates.lock().unwrap().push((
480                ref_name.to_string(),
481                old_hash.map(|s| s.to_string()),
482                new_hash.to_string(),
483            ));
484            Ok(())
485        }
486
487        async fn get_objects_for_pack(
488            &self,
489            _wants: &[String],
490            _haves: &[String],
491        ) -> Result<Vec<String>, ProtocolError> {
492            Ok(vec![])
493        }
494
495        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
496            let mut exists = self.default_branch_exists.lock().unwrap();
497            let current = *exists;
498            *exists = true; // flip to true after first check
499            Ok(current)
500        }
501
502        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
503            self.post_called.store(true, Ordering::SeqCst);
504            Ok(())
505        }
506    }
507
508    struct TestAuth;
509
510    #[async_trait]
511    impl AuthenticationService for TestAuth {
512        async fn authenticate_http(
513            &self,
514            _headers: &std::collections::HashMap<String, String>,
515        ) -> Result<(), ProtocolError> {
516            Ok(())
517        }
518
519        async fn authenticate_ssh(
520            &self,
521            _username: &str,
522            _public_key: &[u8],
523        ) -> Result<(), ProtocolError> {
524            Ok(())
525        }
526    }
527
528    #[tokio::test]
529    async fn test_receive_pack_stream_status_report() {
530        let _guard = set_hash_kind_for_test(HashKind::Sha1);
531        // Build simple objects
532        let blob1 = Blob::from_content("hello");
533        let blob2 = Blob::from_content("world");
534
535        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
536        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
537        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();
538
539        let author = Signature::new(
540            SignatureType::Author,
541            "tester".to_string(),
542            "tester@example.com".to_string(),
543        );
544        let committer = Signature::new(
545            SignatureType::Committer,
546            "tester".to_string(),
547            "tester@example.com".to_string(),
548        );
549        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");
550
551        // Encode pack bytes via PackEncoder
552        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
553        let (entry_tx, entry_rx) = mpsc::channel(1024);
554        let mut encoder = PackEncoder::new(4, 10, pack_tx);
555
556        tokio::spawn(async move {
557            if let Err(e) = encoder.encode(entry_rx).await {
558                panic!("Failed to encode pack: {}", e);
559            }
560        });
561
562        let commit_clone = commit.clone();
563        let tree_clone = tree.clone();
564        let blob1_clone = blob1.clone();
565        let blob2_clone = blob2.clone();
566        tokio::spawn(async move {
567            let _ = entry_tx
568                .send(MetaAttached {
569                    inner: Entry::from(commit_clone),
570                    meta: EntryMeta::new(),
571                })
572                .await;
573            let _ = entry_tx
574                .send(MetaAttached {
575                    inner: Entry::from(tree_clone),
576                    meta: EntryMeta::new(),
577                })
578                .await;
579            let _ = entry_tx
580                .send(MetaAttached {
581                    inner: Entry::from(blob1_clone),
582                    meta: EntryMeta::new(),
583                })
584                .await;
585            let _ = entry_tx
586                .send(MetaAttached {
587                    inner: Entry::from(blob2_clone),
588                    meta: EntryMeta::new(),
589                })
590                .await;
591            // sender drop indicates end
592        });
593
594        let mut pack_bytes: Vec<u8> = Vec::new();
595        while let Some(chunk) = pack_rx.recv().await {
596            pack_bytes.extend_from_slice(&chunk);
597        }
598
599        // Prepare protocol and command
600        let repo_access = TestRepoAccess::new();
601        let auth = TestAuth;
602        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), auth);
603        smart.set_wire_hash_kind(HashKind::Sha1);
604        smart.command_list.push(RefCommand::new(
605            smart.zero_id.to_string(),
606            commit.id.to_string(),
607            "refs/heads/main".to_string(),
608        ));
609
610        // Create request stream
611        let request_stream = Box::pin(futures::stream::once(async { Ok(Bytes::from(pack_bytes)) }));
612
613        // Execute receive-pack
614        let result_bytes = smart
615            .git_receive_pack_stream(request_stream)
616            .await
617            .expect("receive-pack should succeed");
618
619        // Verify pkt-lines
620        let mut out = result_bytes.clone();
621        let (_c1, l1) = utils::read_pkt_line(&mut out);
622        assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");
623
624        let (_c2, l2) = utils::read_pkt_line(&mut out);
625        assert_eq!(
626            String::from_utf8(l2.to_vec()).unwrap(),
627            "ok refs/heads/main"
628        );
629
630        let (c3, l3) = utils::read_pkt_line(&mut out);
631        assert_eq!(c3, 4);
632        assert!(l3.is_empty());
633
634        // Verify side effects
635        assert_eq!(repo_access.updates_len(), 1);
636        assert!(repo_access.post_hook_called());
637    }
638
639    #[tokio::test]
640    async fn info_refs_rejects_sha256_with_sha1_refs() {
641        let _guard = set_hash_kind_for_test(HashKind::Sha1); // avoid thread-local contamination
642        let repo_access = TestRepoAccess::new(); // still returns 40-char strings
643        let auth = TestAuth;
644        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
645        smart.set_wire_hash_kind(HashKind::Sha256); // claims wire uses SHA-256
646        // expect failure because refs are SHA-1
647        let res = smart.git_info_refs(ServiceType::UploadPack).await;
648        assert!(res.is_err(), "expected failure when hash lengths mismatch");
649
650        smart.set_wire_hash_kind(HashKind::Sha1);
651
652        let res = smart.git_info_refs(ServiceType::UploadPack).await;
653        assert!(
654            res.is_ok(),
655            "expected SHA1 refs to be accepted when wire is SHA1"
656        );
657    }
658}