Skip to main content

git_internal/protocol/
core.rs

1//! Core Git protocol implementation
2//!
3//! This module provides the main `GitProtocol` struct and `RepositoryAccess` trait
4//! that form the core interface of the git-internal library.
5use std::{collections::HashMap, str::FromStr};
6
7use async_trait::async_trait;
8use bytes::{BufMut, Bytes, BytesMut};
9use futures::stream::StreamExt;
10
11use crate::{
12    hash::ObjectHash,
13    internal::object::ObjectTrait,
14    protocol::{
15        smart::SmartProtocol,
16        types::{Capability, ProtocolError, ProtocolStream, ServiceType, SideBand},
17    },
18};
19
20/// Repository access trait for storage operations
21///
22/// This trait only handles storage-level operations, not Git protocol details.
23/// The git-internal library handles all Git protocol formatting and parsing.
24#[async_trait]
25pub trait RepositoryAccess: Send + Sync + Clone {
26    /// Get repository references as raw (name, hash) pairs
27    async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError>;
28
29    /// Check if an object exists in the repository
30    async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError>;
31
32    /// Get raw object data by hash
33    async fn get_object(&self, object_hash: &str) -> Result<Vec<u8>, ProtocolError>;
34
35    /// Store pack data in the repository
36    async fn store_pack_data(&self, pack_data: &[u8]) -> Result<(), ProtocolError>;
37
38    /// Update a single reference
39    async fn update_reference(
40        &self,
41        ref_name: &str,
42        old_hash: Option<&str>,
43        new_hash: &str,
44    ) -> Result<(), ProtocolError>;
45
46    /// Get objects needed for pack generation
47    async fn get_objects_for_pack(
48        &self,
49        wants: &[String],
50        haves: &[String],
51    ) -> Result<Vec<String>, ProtocolError>;
52
53    /// Check if repository has a default branch
54    async fn has_default_branch(&self) -> Result<bool, ProtocolError>;
55
56    /// Post-receive hook after successful push
57    async fn post_receive_hook(&self) -> Result<(), ProtocolError>;
58
59    /// Get blob data by hash
60    ///
61    /// Default implementation parses the object data using the internal object module.
62    /// Override this method if you need custom blob handling logic.
63    async fn get_blob(
64        &self,
65        object_hash: &str,
66    ) -> Result<crate::internal::object::blob::Blob, ProtocolError> {
67        let data = self.get_object(object_hash).await?;
68        let hash = ObjectHash::from_str(object_hash)
69            .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
70
71        crate::internal::object::blob::Blob::from_bytes(&data, hash)
72            .map_err(|e| ProtocolError::repository_error(format!("Failed to parse blob: {e}")))
73    }
74
75    /// Get commit data by hash
76    ///
77    /// Default implementation parses the object data using the internal object module.
78    /// Override this method if you need custom commit handling logic.
79    async fn get_commit(
80        &self,
81        commit_hash: &str,
82    ) -> Result<crate::internal::object::commit::Commit, ProtocolError> {
83        let data = self.get_object(commit_hash).await?;
84        let hash = ObjectHash::from_str(commit_hash)
85            .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
86
87        crate::internal::object::commit::Commit::from_bytes(&data, hash)
88            .map_err(|e| ProtocolError::repository_error(format!("Failed to parse commit: {e}")))
89    }
90
91    /// Get tree data by hash
92    ///
93    /// Default implementation parses the object data using the internal object module.
94    /// Override this method if you need custom tree handling logic.
95    async fn get_tree(
96        &self,
97        tree_hash: &str,
98    ) -> Result<crate::internal::object::tree::Tree, ProtocolError> {
99        let data = self.get_object(tree_hash).await?;
100        let hash = ObjectHash::from_str(tree_hash)
101            .map_err(|e| ProtocolError::repository_error(format!("Invalid hash format: {e}")))?;
102
103        crate::internal::object::tree::Tree::from_bytes(&data, hash)
104            .map_err(|e| ProtocolError::repository_error(format!("Failed to parse tree: {e}")))
105    }
106
107    /// Check if a commit exists
108    ///
109    /// Default implementation checks object existence and validates it's a commit.
110    /// Override this method if you have more efficient commit existence checking.
111    async fn commit_exists(&self, commit_hash: &str) -> Result<bool, ProtocolError> {
112        match self.has_object(commit_hash).await {
113            Ok(exists) => {
114                if !exists {
115                    return Ok(false);
116                }
117
118                // Verify it's actually a commit by trying to parse it
119                match self.get_commit(commit_hash).await {
120                    Ok(_) => Ok(true),
121                    Err(_) => Ok(false), // Object exists but is not a valid commit
122                }
123            }
124            Err(e) => Err(e),
125        }
126    }
127
128    /// Handle pack objects after unpacking
129    ///
130    /// Default implementation stores each object individually using store_pack_data.
131    /// Override this method if you need batch processing or custom storage logic.
132    async fn handle_pack_objects(
133        &self,
134        commits: Vec<crate::internal::object::commit::Commit>,
135        trees: Vec<crate::internal::object::tree::Tree>,
136        blobs: Vec<crate::internal::object::blob::Blob>,
137    ) -> Result<(), ProtocolError> {
138        // Store blobs
139        for blob in blobs {
140            let data = blob.to_data().map_err(|e| {
141                ProtocolError::repository_error(format!("Failed to serialize blob: {e}"))
142            })?;
143            self.store_pack_data(&data).await.map_err(|e| {
144                ProtocolError::repository_error(format!("Failed to store blob {}: {}", blob.id, e))
145            })?;
146        }
147
148        // Store trees
149        for tree in trees {
150            let data = tree.to_data().map_err(|e| {
151                ProtocolError::repository_error(format!("Failed to serialize tree: {e}"))
152            })?;
153            self.store_pack_data(&data).await.map_err(|e| {
154                ProtocolError::repository_error(format!("Failed to store tree {}: {}", tree.id, e))
155            })?;
156        }
157
158        // Store commits
159        for commit in commits {
160            let data = commit.to_data().map_err(|e| {
161                ProtocolError::repository_error(format!("Failed to serialize commit: {e}"))
162            })?;
163            self.store_pack_data(&data).await.map_err(|e| {
164                ProtocolError::repository_error(format!(
165                    "Failed to store commit {}: {}",
166                    commit.id, e
167                ))
168            })?;
169        }
170
171        Ok(())
172    }
173}
174
175/// Authentication service trait
176#[async_trait]
177pub trait AuthenticationService: Send + Sync {
178    /// Authenticate HTTP request
179    async fn authenticate_http(
180        &self,
181        headers: &std::collections::HashMap<String, String>,
182    ) -> Result<(), ProtocolError>;
183
184    /// Authenticate SSH public key
185    async fn authenticate_ssh(
186        &self,
187        username: &str,
188        public_key: &[u8],
189    ) -> Result<(), ProtocolError>;
190}
191
192/// Transport-agnostic Git smart protocol handler
193/// Main Git protocol handler
194///
195/// This struct provides the core Git protocol implementation that works
196/// across HTTP, SSH, and other transports. It uses SmartProtocol internally
197/// to handle all Git protocol details.
198pub struct GitProtocol<R: RepositoryAccess, A: AuthenticationService> {
199    smart_protocol: SmartProtocol<R, A>,
200}
201
202impl<R: RepositoryAccess, A: AuthenticationService> GitProtocol<R, A> {
203    /// Create a new GitProtocol instance
204    pub fn new(repo_access: R, auth_service: A) -> Self {
205        Self {
206            smart_protocol: SmartProtocol::new(
207                super::types::TransportProtocol::Http,
208                repo_access,
209                auth_service,
210            ),
211        }
212    }
213
214    /// Authenticate HTTP request before serving Git operations
215    pub async fn authenticate_http(
216        &self,
217        headers: &HashMap<String, String>,
218    ) -> Result<(), ProtocolError> {
219        self.smart_protocol.authenticate_http(headers).await
220    }
221
222    /// Authenticate SSH session before serving Git operations
223    pub async fn authenticate_ssh(
224        &self,
225        username: &str,
226        public_key: &[u8],
227    ) -> Result<(), ProtocolError> {
228        self.smart_protocol
229            .authenticate_ssh(username, public_key)
230            .await
231    }
232
233    /// Set transport protocol (Http, Ssh, etc.)
234    pub fn set_transport(&mut self, protocol: super::types::TransportProtocol) {
235        self.smart_protocol.set_transport_protocol(protocol);
236    }
237
238    /// Handle git info-refs request
239    pub async fn info_refs(&self, service: &str) -> Result<Vec<u8>, ProtocolError> {
240        let service_type = match service {
241            "git-upload-pack" => ServiceType::UploadPack,
242            "git-receive-pack" => ServiceType::ReceivePack,
243            _ => return Err(ProtocolError::invalid_service(service)),
244        };
245
246        let bytes = self.smart_protocol.git_info_refs(service_type).await?;
247        Ok(bytes.to_vec())
248    }
249
250    /// Handle git-upload-pack request (for clone/fetch)
251    pub async fn upload_pack(
252        &mut self,
253        request_data: &[u8],
254    ) -> Result<ProtocolStream, ProtocolError> {
255        const SIDE_BAND_PACKET_LEN: usize = 1000;
256        const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
257        const SIDE_BAND_HEADER_LEN: usize = 5; // 4-byte length + 1-byte band
258
259        let request_bytes = bytes::Bytes::from(request_data.to_vec());
260        let (pack_stream, protocol_buf) =
261            self.smart_protocol.git_upload_pack(request_bytes).await?;
262        let ack_bytes = protocol_buf.freeze();
263
264        let ack_stream: ProtocolStream = if ack_bytes.is_empty() {
265            Box::pin(futures::stream::empty::<Result<Bytes, ProtocolError>>())
266        } else {
267            Box::pin(futures::stream::once(async move { Ok(ack_bytes) }))
268        };
269
270        let sideband_max = if self
271            .smart_protocol
272            .capabilities
273            .contains(&Capability::SideBand64k)
274        {
275            Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
276        } else if self
277            .smart_protocol
278            .capabilities
279            .contains(&Capability::SideBand)
280        {
281            Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
282        } else {
283            None
284        };
285
286        let data_stream: ProtocolStream = if let Some(max_payload) = sideband_max {
287            let stream = pack_stream.flat_map(move |chunk| {
288                let packets = build_side_band_packets(&chunk, max_payload);
289                futures::stream::iter(packets.into_iter().map(Ok))
290            });
291            let stream = stream.chain(futures::stream::once(async {
292                Ok(Bytes::from_static(b"0000"))
293            }));
294            Box::pin(stream)
295        } else {
296            Box::pin(pack_stream.map(|data| Ok(Bytes::from(data))))
297        };
298
299        Ok(Box::pin(ack_stream.chain(data_stream)))
300    }
301
302    /// Handle git-receive-pack request (for push)
303    pub async fn receive_pack(
304        &mut self,
305        request_stream: ProtocolStream,
306    ) -> Result<ProtocolStream, ProtocolError> {
307        const SIDE_BAND_PACKET_LEN: usize = 1000;
308        const SIDE_BAND_64K_PACKET_LEN: usize = 65520;
309        const SIDE_BAND_HEADER_LEN: usize = 5; // 4-byte length + 1-byte band
310
311        let result_bytes = self
312            .smart_protocol
313            .git_receive_pack_stream(request_stream)
314            .await?;
315
316        let sideband_max = if self
317            .smart_protocol
318            .capabilities
319            .contains(&Capability::SideBand64k)
320        {
321            Some(SIDE_BAND_64K_PACKET_LEN - SIDE_BAND_HEADER_LEN)
322        } else if self
323            .smart_protocol
324            .capabilities
325            .contains(&Capability::SideBand)
326        {
327            Some(SIDE_BAND_PACKET_LEN - SIDE_BAND_HEADER_LEN)
328        } else {
329            None
330        };
331
332        // Wrap report-status in side-band if negotiated by the client.
333        if let Some(max_payload) = sideband_max {
334            let packets = build_side_band_packets(result_bytes.as_ref(), max_payload);
335            let stream = futures::stream::iter(packets.into_iter().map(Ok)).chain(
336                futures::stream::once(async { Ok(Bytes::from_static(b"0000")) }),
337            );
338            Ok(Box::pin(stream))
339        } else {
340            // Return the report status as a single-chunk stream
341            Ok(Box::pin(futures::stream::once(async { Ok(result_bytes) })))
342        }
343    }
344}
345
346fn build_side_band_packets(chunk: &[u8], max_payload: usize) -> Vec<Bytes> {
347    if chunk.is_empty() {
348        return Vec::new();
349    }
350
351    let mut out = Vec::new();
352    let mut offset = 0;
353
354    while offset < chunk.len() {
355        let end = (offset + max_payload).min(chunk.len());
356        let payload = &chunk[offset..end];
357        let length = payload.len() + 5; // 4-byte length + 1-byte band
358        let mut pkt = BytesMut::with_capacity(length);
359        pkt.put(Bytes::from(format!("{length:04x}")));
360        pkt.put_u8(SideBand::PackfileData.value());
361        pkt.put(payload);
362        out.push(pkt.freeze());
363        offset = end;
364    }
365
366    out
367}
368
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use bytes::{Bytes, BytesMut};
    use futures::StreamExt;

    use super::*;
    use crate::{
        hash::{HashKind, set_hash_kind_for_test},
        internal::object::{
            blob::Blob,
            commit::Commit,
            signature::{Signature, SignatureType},
            tree::{Tree, TreeItem, TreeItemMode},
        },
        protocol::{types::TransportProtocol, utils},
    };

    /// Minimal repository mock: returns a fixed ref list and answers pack
    /// requests by echoing the wanted hashes back.
    #[derive(Clone)]
    struct MockRepo {
        refs: Vec<(String, String)>,
    }

    #[async_trait]
    impl RepositoryAccess for MockRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(self.refs.to_vec())
        }

        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(false)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(Vec::from(wants))
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(false)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Authentication mock that accepts everything.
    struct MockAuth;

    #[async_trait]
    impl AuthenticationService for MockAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Build a `GitProtocol` over `MockRepo` / `MockAuth` with two refs that
    /// both point at the default object hash.
    fn make_protocol() -> GitProtocol<MockRepo, MockAuth> {
        let default_hash = || ObjectHash::default().to_string();
        let refs = vec![
            ("refs/heads/main".to_string(), default_hash()),
            ("HEAD".to_string(), default_hash()),
        ];
        GitProtocol::new(MockRepo { refs }, MockAuth)
    }

    /// Repository mock holding exactly one commit, one tree, and a set of blobs.
    #[derive(Clone)]
    struct SideBandRepo {
        commit: Commit,
        tree: Tree,
        blobs: Vec<Blob>,
    }

    #[async_trait]
    impl RepositoryAccess for SideBandRepo {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            let main = ("refs/heads/main".to_string(), self.commit.id.to_string());
            Ok(vec![main])
        }

        async fn has_object(&self, object_hash: &str) -> Result<bool, ProtocolError> {
            let is_commit = self.commit.id.to_string() == object_hash;
            let is_tree = self.tree.id.to_string() == object_hash;
            let is_blob = self
                .blobs
                .iter()
                .any(|blob| blob.id.to_string() == object_hash);
            Ok(is_commit || is_tree || is_blob)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn update_reference(
            &self,
            _ref_name: &str,
            _old_hash: Option<&str>,
            _new_hash: &str,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn get_commit(&self, commit_hash: &str) -> Result<Commit, ProtocolError> {
            if self.commit.id.to_string() != commit_hash {
                return Err(ProtocolError::ObjectNotFound(commit_hash.to_string()));
            }
            Ok(self.commit.clone())
        }

        async fn get_tree(&self, tree_hash: &str) -> Result<Tree, ProtocolError> {
            if self.tree.id.to_string() != tree_hash {
                return Err(ProtocolError::ObjectNotFound(tree_hash.to_string()));
            }
            Ok(self.tree.clone())
        }

        async fn get_blob(&self, blob_hash: &str) -> Result<Blob, ProtocolError> {
            match self.blobs.iter().find(|blob| blob.id.to_string() == blob_hash) {
                Some(blob) => Ok(blob.clone()),
                None => Err(ProtocolError::ObjectNotFound(blob_hash.to_string())),
            }
        }
    }

    /// Build a `SideBandRepo` holding one blob, a tree referencing it, and a
    /// parentless commit on that tree; the commit is also returned separately.
    fn build_repo_with_objects() -> (SideBandRepo, Commit) {
        let signature = |kind: SignatureType| {
            Signature::new(kind, "tester".to_string(), "tester@example.com".to_string())
        };

        let blob = Blob::from_content("hello");
        let entry = TreeItem::new(TreeItemMode::Blob, blob.id, "hello.txt".to_string());
        let tree = Tree::from_tree_items(vec![entry]).unwrap();
        let author = signature(SignatureType::Author);
        let committer = signature(SignatureType::Committer);
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        let repo = SideBandRepo {
            commit: commit.clone(),
            tree,
            blobs: vec![blob],
        };
        (repo, commit)
    }

    /// Build an upload-pack request made of `want_line` followed by `done`.
    fn upload_request(want_line: String) -> BytesMut {
        let mut request = BytesMut::new();
        utils::add_pkt_line_string(&mut request, want_line);
        utils::add_pkt_line_string(&mut request, "done\n".to_string());
        request
    }

    /// Collect every chunk of `stream` into one buffer, panicking on the
    /// first stream error.
    async fn drain(mut stream: ProtocolStream) -> Bytes {
        let mut collected = BytesMut::new();
        while let Some(item) = stream.next().await {
            collected.extend_from_slice(&item.expect("stream chunk"));
        }
        collected.freeze()
    }

    /// Without side-band, the response starts with a NAK pkt-line that is
    /// immediately followed by the raw pack.
    #[tokio::test]
    async fn upload_pack_emits_ack_before_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();
        let mut proto = GitProtocol::new(repo, MockAuth);
        let request = upload_request(format!("want {}\n", commit.id));

        let stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut response = drain(stream).await;

        let (_len, first_line) = utils::read_pkt_line(&mut response);
        assert_eq!(first_line, Bytes::from_static(b"NAK\n"));
        assert!(response.starts_with(b"PACK"), "pack should follow ack");
    }

    /// With side-band-64k requested, the pack is framed in side-band packets
    /// and the stream ends with a flush packet.
    #[tokio::test]
    async fn upload_pack_sideband_frames_pack() {
        let _guard = set_hash_kind_for_test(HashKind::Sha1);
        let (repo, commit) = build_repo_with_objects();
        let mut proto = GitProtocol::new(repo, MockAuth);
        let request = upload_request(format!("want {} side-band-64k\n", commit.id));

        let stream = proto.upload_pack(&request).await.expect("upload-pack");
        let mut response = drain(stream).await;

        let (_len, first_line) = utils::read_pkt_line(&mut response);
        assert_eq!(first_line, Bytes::from_static(b"NAK\n"));

        let raw: &[u8] = &response;
        assert!(raw.len() > 9, "side-band packet should include PACK header");

        // First packet layout: 4 hex length digits, 1 band byte, payload.
        let len_hex = std::str::from_utf8(&raw[..4]).expect("hex length");
        let pkt_len = usize::from_str_radix(len_hex, 16).expect("parse length");
        assert!(pkt_len > 5, "side-band packet should contain data");

        let band = raw[4];
        assert_eq!(band, SideBand::PackfileData.value());

        let magic = &raw[5..9];
        assert_eq!(magic, b"PACK");

        assert!(raw.ends_with(b"0000"), "side-band stream should flush");
    }

    /// The advertisement must mention the ref, the capability list, and
    /// object-format.
    #[tokio::test]
    async fn info_refs_includes_refs_and_caps() {
        let proto = make_protocol();
        let advertisement = proto.info_refs("git-upload-pack").await.expect("info_refs");
        let text = String::from_utf8(advertisement).expect("utf8");

        for needle in ["refs/heads/main", "capabilities", "object-format"] {
            assert!(text.contains(needle));
        }
    }

    /// An unknown service name is rejected with `InvalidService`.
    #[tokio::test]
    async fn info_refs_invalid_service_errors() {
        let proto = make_protocol();
        let outcome = proto.info_refs("git-invalid").await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidService(_)));
    }

    /// Switching the transport must not panic.
    #[tokio::test]
    async fn can_switch_transport() {
        let mut proto = make_protocol();
        // Reaching the end of the test without a panic is the assertion.
        proto.set_transport(TransportProtocol::Ssh);
    }

    /// A ref whose hash has SHA-256 length is rejected with `InvalidRequest`.
    #[tokio::test]
    async fn info_refs_hash_length_mismatch_errors() {
        let sha256_len_hash = "f".repeat(HashKind::Sha256.hex_len());
        let repo = MockRepo {
            refs: vec![("refs/heads/main".to_string(), sha256_len_hash)],
        };
        let proto = GitProtocol::new(repo, MockAuth);

        let err = proto.info_refs("git-upload-pack").await.unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidRequest(_)));
    }
}