git_internal/protocol/
core.rs

//! Core Git protocol implementation
//!
//! This module provides the main `GitProtocol` struct together with the
//! `RepositoryAccess` and `AuthenticationService` traits that form the core
//! interface of the git-internal library.
5use std::{collections::HashMap, str::FromStr};
6
7use async_trait::async_trait;
8use bytes::{BufMut, Bytes, BytesMut};
9use futures::stream::StreamExt;
10
11use crate::{
12    hash::ObjectHash,
13    internal::object::ObjectTrait,
14    protocol::{
15        smart::SmartProtocol,
16        types::{Capability, ProtocolError, ProtocolStream, ServiceType, SideBand},
17    },
18};
19
20/// Repository access trait for storage operations
21///
22/// This trait only handles storage-level operations, not Git protocol details.
23/// The git-internal library handles all Git protocol formatting and parsing.
24#[async_trait]
25pub trait RepositoryAccess: Send + Sync + Clone {
26    /// Get repository references as raw (name, hash) pairs
27    async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError>;
28
29    /// Check if an object exists in the repository
30    async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError>;
31
32    /// Get raw object data by hash
33    async fn get_object(&self, object_hash: &str) -> Result<Vec<u8>, ProtocolError>;
34
35    /// Store pack data in the repository
36    async fn store_pack_data(&self, pack_data: &[u8]) -> Result<(), ProtocolError>;
37
38    /// Update a single reference
39    async fn update_reference(
40        &self,
41        ref_name: &str,
42        old_hash: Option<&str>,
43        new_hash: &str,
44    ) -> Result<(), ProtocolError>;
45
46    /// Get objects needed for pack generation
47    async fn get_objects_for_pack(
48        &self,
49        wants: &[String],
50        haves: &[String],
51    ) -> Result<Vec<String>, ProtocolError>;
52
53    /// Check if repository has a default branch
54    async fn has_default_branch(&self) -> Result<bool, ProtocolError>;
55
56    /// Post-receive hook after successful push
57    async fn post_receive_hook(&self) -> Result<(), ProtocolError>;
58
59    /// Get blob data by hash
60    ///
61    /// Default implementation parses the object data using the internal object module.
62    /// Override this method if you need custom blob handling logic.
63    async fn get_blob(
64        &self,
65        object_hash: &str,
66    ) -> Result<crate::internal::object::blob::Blob, ProtocolError> {
67        let data = self.get_object(object_hash).await?;
68        let hash = ObjectHash::from_str(object_hash)
69            .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
70
71        crate::internal::object::blob::Blob::from_bytes(&data, hash)
72            .map_err(|e| ProtocolError::repository_error(format!("Failed to parse blob: {e}")))
73    }
74
75    /// Get commit data by hash
76    ///
77    /// Default implementation parses the object data using the internal object module.
78    /// Override this method if you need custom commit handling logic.
79    async fn get_commit(
80        &self,
81        commit_hash: &str,
82    ) -> Result<crate::internal::object::commit::Commit, ProtocolError> {
83        let data = self.get_object(commit_hash).await?;
84        let hash = ObjectHash::from_str(commit_hash)
85            .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
86
87        crate::internal::object::commit::Commit::from_bytes(&data, hash)
88            .map_err(|e| ProtocolError::repository_error(format!("Failed to parse commit: {e}")))
89    }
90
91    /// Get tree data by hash
92    ///
93    /// Default implementation parses the object data using the internal object module.
94    /// Override this method if you need custom tree handling logic.
95    async fn get_tree(
96        &self,
97        tree_hash: &str,
98    ) -> Result<crate::internal::object::tree::Tree, ProtocolError> {
99        let data = self.get_object(tree_hash).await?;
100        let hash = ObjectHash::from_str(tree_hash)
101            .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
102
103        crate::internal::object::tree::Tree::from_bytes(&data, hash)
104            .map_err(|e| ProtocolError::repository_error(format!("Failed to parse tree: {e}")))
105    }
106
107    /// Check if a commit exists
108    ///
109    /// Default implementation checks object existence and validates it's a commit.
110    /// Override this method if you have more efficient commit existence checking.
111    async fn commit_exists(&self, commit_hash: &str) -> Result<bool, ProtocolError> {
112        match self.has_object(commit_hash).await {
113            Ok(exists) => {
114                if !exists {
115                    return Ok(false);
116                }
117
118                // Verify it's actually a commit by trying to parse it
119                match self.get_commit(commit_hash).await {
120                    Ok(_) => Ok(true),
121                    Err(_) => Ok(false), // Object exists but is not a valid commit
122                }
123            }
124            Err(e) => Err(e),
125        }
126    }
127
128    /// Handle pack objects after unpacking
129    ///
130    /// Default implementation stores each object individually using store_pack_data.
131    /// Override this method if you need batch processing or custom storage logic.
132    async fn handle_pack_objects(
133        &self,
134        commits: Vec<crate::internal::object::commit::Commit>,
135        trees: Vec<crate::internal::object::tree::Tree>,
136        blobs: Vec<crate::internal::object::blob::Blob>,
137    ) -> Result<(), ProtocolError> {
138        // Store blobs
139        for blob in blobs {
140            let data = blob.to_data().map_err(|e| {
141                ProtocolError::repository_error(format!("Failed to serialize blob: {e}"))
142            })?;
143            self.store_pack_data(&data).await.map_err(|e| {
144                ProtocolError::repository_error(format!("Failed to store blob {}: {}", blob.id, e))
145            })?;
146        }
147
148        // Store trees
149        for tree in trees {
150            let data = tree.to_data().map_err(|e| {
151                ProtocolError::repository_error(format!("Failed to serialize tree: {e}"))
152            })?;
153            self.store_pack_data(&data).await.map_err(|e| {
154                ProtocolError::repository_error(format!("Failed to store tree {}: {}", tree.id, e))
155            })?;
156        }
157
158        // Store commits
159        for commit in commits {
160            let data = commit.to_data().map_err(|e| {
161                ProtocolError::repository_error(format!("Failed to serialize commit: {e}"))
162            })?;
163            self.store_pack_data(&data).await.map_err(|e| {
164                ProtocolError::repository_error(format!(
165                    "Failed to store commit {}: {}",
166                    commit.id, e
167                ))
168            })?;
169        }
170
171        Ok(())
172    }
173}
174
175/// Authentication service trait
176#[async_trait]
177pub trait AuthenticationService: Send + Sync {
178    /// Authenticate HTTP request
179    async fn authenticate_http(
180        &self,
181        headers: &std::collections::HashMap<String, String>,
182    ) -> Result<(), ProtocolError>;
183
184    /// Authenticate SSH public key
185    async fn authenticate_ssh(
186        &self,
187        username: &str,
188        public_key: &[u8],
189    ) -> Result<(), ProtocolError>;
190}
191
192/// Transport-agnostic Git smart protocol handler
193/// Main Git protocol handler
194///
195/// This struct provides the core Git protocol implementation that works
196/// across HTTP, SSH, and other transports. It uses SmartProtocol internally
197/// to handle all Git protocol details.
198pub struct GitProtocol<R: RepositoryAccess, A: AuthenticationService> {
199    smart_protocol: SmartProtocol<R, A>,
200}
201
202impl<R: RepositoryAccess, A: AuthenticationService> GitProtocol<R, A> {
203    /// Create a new GitProtocol instance
204    pub fn new(repo_access: R, auth_service: A) -> Self {
205        Self {
206            smart_protocol: SmartProtocol::new(
207                super::types::TransportProtocol::Http,
208                repo_access,
209                auth_service,
210            ),
211        }
212    }
213
214    /// Authenticate HTTP request before serving Git operations
215    pub async fn authenticate_http(
216        &self,
217        headers: &HashMap<String, String>,
218    ) -> Result<(), ProtocolError> {
219        self.smart_protocol.authenticate_http(headers).await
220    }
221
222    /// Authenticate SSH session before serving Git operations
223    pub async fn authenticate_ssh(
224        &self,
225        username: &str,
226        public_key: &[u8],
227    ) -> Result<(), ProtocolError> {
228        self.smart_protocol
229            .authenticate_ssh(username, public_key)
230            .await
231    }
232
233    /// Set transport protocol (Http, Ssh, etc.)
234    pub fn set_transport(&mut self, protocol: super::types::TransportProtocol) {
235        self.smart_protocol.set_transport_protocol(protocol);
236    }
237
238    /// Handle git info-refs request
239    pub async fn info_refs(&self, service: &str) -> Result<Vec<u8>, ProtocolError> {
240        let service_type = match service {
241            "git-upload-pack" => ServiceType::UploadPack,
242            "git-receive-pack" => ServiceType::ReceivePack,
243            _ => return Err(ProtocolError::invalid_service(service)),
244        };
245
246        let bytes = self.smart_protocol.git_info_refs(service_type).await?;
247        Ok(bytes.to_vec())
248    }
249
250    /// Handle git-upload-pack request (for clone/fetch)
251    pub async fn upload_pack(
252        &mut self,
253        request_data: &[u8],
254    ) -> Result<ProtocolStream, ProtocolError> {
255        const SIDE_BAND_PACKET_LEN: usize = 1000;
256        const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
257        const SIDE_BAND_HEADER_LEN: usize = 5; // 4-byte length + 1-byte band
258
259        let request_bytes = bytes::Bytes::from(request_data.to_vec());
260        let (pack_stream, protocol_buf) =
261            self.smart_protocol.git_upload_pack(request_bytes).await?;
262        let ack_bytes = protocol_buf.freeze();
263
264        let ack_stream: ProtocolStream = if ack_bytes.is_empty() {
265            Box::pin(futures::stream::empty::<Result<Bytes, ProtocolError>>())
266        } else {
267            Box::pin(futures::stream::once(async move { Ok(ack_bytes) }))
268        };
269
270        let sideband_max = if self
271            .smart_protocol
272            .capabilities
273            .contains(&Capability::SideBand64k)
274        {
275            Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
276        } else if self
277            .smart_protocol
278            .capabilities
279            .contains(&Capability::SideBand)
280        {
281            Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
282        } else {
283            None
284        };
285
286        let data_stream: ProtocolStream = if let Some(max_payload) = sideband_max {
287            let stream = pack_stream.flat_map(move |chunk| {
288                let packets = build_side_band_packets(&chunk, max_payload);
289                futures::stream::iter(packets.into_iter().map(Ok))
290            });
291            let stream = stream.chain(futures::stream::once(async {
292                Ok(Bytes::from_static(b"0000"))
293            }));
294            Box::pin(stream)
295        } else {
296            Box::pin(pack_stream.map(|data| Ok(Bytes::from(data))))
297        };
298
299        Ok(Box::pin(ack_stream.chain(data_stream)))
300    }
301
302    /// Handle git-receive-pack request (for push)
303    pub async fn receive_pack(
304        &mut self,
305        request_stream: ProtocolStream,
306    ) -> Result<ProtocolStream, ProtocolError> {
307        const SIDE_BAND_PACKET_LEN: usize = 1000;
308        const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
309        const SIDE_BAND_HEADER_LEN: usize = 5; // 4-byte length + 1-byte band
310
311        let result_bytes = self
312            .smart_protocol
313            .git_receive_pack_stream(request_stream)
314            .await?;
315
316        let sideband_max = if self
317            .smart_protocol
318            .capabilities
319            .contains(&Capability::SideBand64k)
320        {
321            Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
322        } else if self
323            .smart_protocol
324            .capabilities
325            .contains(&Capability::SideBand)
326        {
327            Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
328        } else {
329            None
330        };
331
332        // Wrap report-status in side-band if negotiated by the client.
333        if let Some(max_payload) = sideband_max {
334            let packets = build_side_band_packets(result_bytes.as_ref(), max_payload);
335            let stream = futures::stream::iter(packets.into_iter().map(Ok)).chain(
336                futures::stream::once(async { Ok(Bytes::from_static(b"0000")) }),
337            );
338            Ok(Box::pin(stream))
339        } else {
340            // Return the report status as a single-chunk stream
341            Ok(Box::pin(futures::stream::once(async { Ok(result_bytes) })))
342        }
343    }
344}
345
346fn build_side_band_packets(chunk: &[u8], max_payload: usize) -> Vec<Bytes> {
347    if chunk.is_empty() {
348        return Vec::new();
349    }
350
351    let mut out = Vec::new();
352    let mut offset = 0;
353
354    while offset < chunk.len() {
355        let end = (offset + max_payload).min(chunk.len());
356        let payload = &chunk[offset..end];
357        let length = payload.len() + 5; // 4-byte length + 1-byte band
358        let mut pkt = BytesMut::with_capacity(length);
359        pkt.put(Bytes::from(format!("{length:04x}")));
360        pkt.put_u8(SideBand::PackfileData.value());
361        pkt.put(payload);
362        out.push(pkt.freeze());
363        offset = end;
364    }
365
366    out
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::{HashKind, set_hash_kind_for_test};
    use crate::internal::object::{
        blob::Blob,
        commit::Commit,
        signature::{Signature, SignatureType},
        tree::{Tree, TreeItem, TreeItemMode},
    };
    use crate::protocol::types::TransportProtocol;
    use crate::protocol::utils;
    use async_trait::async_trait;
    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;

    /// Simple mock repository that serves fixed refs and echoes wants.
    #[derive(Clone)]
    struct MockRepo {
        refs: Vec<(String, String)>,
    }

    #[async_trait]
    impl RepositoryAccess for MockRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(self.refs.clone())
        }
        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(Vec::new())
        }
        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn get_objects_for_pack(
            &self,
            wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(wants.to_vec())
        }
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }
        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// No-op auth service for tests.
    struct MockAuth;
    #[async_trait]
    impl AuthenticationService for MockAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Convenience builder for GitProtocol with mock repo/auth.
    ///
    /// The ref hashes come from `ObjectHash::default()`, so callers should pin
    /// the hash kind before calling this.
    fn make_protocol() -> GitProtocol<MockRepo, MockAuth> {
        GitProtocol::new(
            MockRepo {
                refs: vec![
                    (
                        "refs/heads/main".to_string(),
                        ObjectHash::default().to_string(),
                    ),
                    ("HEAD".to_string(), ObjectHash::default().to_string()),
                ],
            },
            MockAuth,
        )
    }

    /// Mock repo that serves a single commit, tree, and blobs.
    #[derive(Clone)]
    struct SideBandRepo {
        commit: Commit,
        tree: Tree,
        blobs: Vec<Blob>,
    }
    #[async_trait]
    impl RepositoryAccess for SideBandRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![(
                "refs/heads/main".to_string(),
                self.commit.id.to_string(),
            )])
        }

        async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError> {
            let known = object_hash == self.commit.id.to_string()
                || object_hash == self.tree.id.to_string()
                || self.blobs.iter().any(|b| b.id.to_string() == object_hash);
            Ok(known)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(Vec::new())
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(Vec::new())
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_commit(&self, commit_hash: &str) -> Result<Commit, ProtocolError> {
            if commit_hash == self.commit.id.to_string() {
                Ok(self.commit.clone())
            } else {
                Err(ProtocolError::ObjectNotFound(commit_hash.to_string()))
            }
        }

        async fn get_tree(&self, tree_hash: &str) -> Result<Tree, ProtocolError> {
            if tree_hash == self.tree.id.to_string() {
                Ok(self.tree.clone())
            } else {
                Err(ProtocolError::ObjectNotFound(tree_hash.to_string()))
            }
        }

        async fn get_blob(&self, blob_hash: &str) -> Result<Blob, ProtocolError> {
            self.blobs
                .iter()
                .find(|b| b.id.to_string() == blob_hash)
                .cloned()
                .ok_or_else(|| ProtocolError::ObjectNotFound(blob_hash.to_string()))
        }
    }

    /// Build a repo holding one blob, one tree, and one root commit.
    fn build_repo_with_objects() -> (SideBandRepo, Commit) {
        let blob = Blob::from_content("hello");
        let item = TreeItem::new(TreeItemMode::Blob, blob.id, "hello.txt".to_string());
        let tree = Tree::from_tree_items(vec![item]).unwrap();
        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        let repo = SideBandRepo {
            commit: commit.clone(),
            tree,
            blobs: vec![blob],
        };

        (repo, commit)
    }

    /// Build an upload-pack request consisting of `want_line` followed by `done`.
    fn want_request(want_line: String) -> BytesMut {
        let mut request = BytesMut::new();
        utils::add_pkt_line_string(&mut request, want_line);
        utils::add_pkt_line_string(&mut request, "done\n".to_string());
        request
    }

    /// Drain a protocol stream into one contiguous buffer, panicking on any error item.
    async fn drain(mut stream: ProtocolStream) -> Bytes {
        let mut out = BytesMut::new();
        while let Some(chunk) = stream.next().await {
            out.extend_from_slice(&chunk.expect("stream chunk"));
        }
        out.freeze()
    }

    /// upload-pack should emit NAK before sending pack data.
    #[tokio::test]
    async fn upload_pack_emits_ack_before_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();
        let mut proto = GitProtocol::new(repo, MockAuth);
        let request = want_request(format!("want {}\n", commit.id));

        let stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut out_bytes = drain(stream).await;

        let (_len, line) = utils::read_pkt_line(&mut out_bytes);
        assert_eq!(line, Bytes::from_static(b"NAK\n"));
        assert!(
            out_bytes.as_ref().starts_with(b"PACK"),
            "pack should follow ack"
        );
    }

    /// upload-pack with side-band should wrap pack data in side-band packets.
    #[tokio::test]
    async fn upload_pack_sideband_frames_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();

        let mut proto = GitProtocol::new(repo, MockAuth);
        let request = want_request(format!("want {} side-band-64k\n", commit.id));

        let stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut out_bytes = drain(stream).await;

        let (_len, line) = utils::read_pkt_line(&mut out_bytes);
        assert_eq!(line, Bytes::from_static(b"NAK\n"));

        let raw = out_bytes.as_ref();
        assert!(raw.len() > 9, "side-band packet should include PACK header");
        let len_hex = std::str::from_utf8(&raw[..4]).expect("hex length");
        let pkt_len = usize::from_str_radix(len_hex, 16).expect("parse length");
        assert!(pkt_len > 5, "side-band packet should contain data");
        assert_eq!(raw[4], SideBand::PackfileData.value());
        assert_eq!(&raw[5..9], b"PACK");
        assert!(raw.ends_with(b"0000"), "side-band stream should flush");
    }

    /// info_refs should include refs, capabilities, and object-format.
    #[tokio::test]
    async fn info_refs_includes_refs_and_caps() {
        // Pin the hash kind: make_protocol derives its ref hashes from it.
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let proto = make_protocol();
        let bytes = proto.info_refs("git-upload-pack").await.expect("info_refs");
        let text = String::from_utf8(bytes).expect("utf8");
        assert!(text.contains("refs/heads/main"));
        assert!(text.contains("capabilities"));
        assert!(text.contains("object-format"));
    }

    /// Invalid service name should return InvalidService.
    #[tokio::test]
    async fn info_refs_invalid_service_errors() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let proto = make_protocol();
        let err = proto.info_refs("git-invalid").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidService(_)));
    }

    /// Ensure set_transport can switch protocols without panic.
    #[tokio::test]
    async fn can_switch_transport() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let mut proto = make_protocol();
        proto.set_transport(TransportProtocol::Ssh);
        // if set_transport did not panic, we consider this path covered
    }

    /// Wire hash kind expects SHA1 length; providing SHA256 refs should error.
    #[tokio::test]
    async fn info_refs_hash_length_mismatch_errors() {
        // Pin SHA-1 explicitly so the mismatch does not rely on the ambient default.
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let proto = GitProtocol::new(
            MockRepo {
                refs: vec![(
                    "refs/heads/main".to_string(),
                    "f".repeat(HashKind::Sha256.hex_len()),
                )],
            },
            MockAuth,
        );
        let err = proto.info_refs("git-upload-pack").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidRequest(_)));
    }
}