git_internal/protocol/
smart.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use std::collections::HashMap;
3use tokio_stream::wrappers::ReceiverStream;
4
5use super::core::{AuthenticationService, RepositoryAccess};
6use super::pack::PackGenerator;
7use super::types::ProtocolError;
8use super::types::{
9    COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolStream, RECEIVE_CAP_LIST,
10    RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol, UPLOAD_CAP_LIST,
11    ZERO_ID,
12};
13use super::utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space};
14
15/// Smart Git Protocol implementation
16///
17/// This struct handles the Git smart protocol operations for both HTTP and SSH transports.
18/// It uses trait abstractions to decouple from specific business logic implementations.
19pub struct SmartProtocol<R, A>
20where
21    R: RepositoryAccess,
22    A: AuthenticationService,
23{
24    pub transport_protocol: TransportProtocol,
25    pub capabilities: Vec<Capability>,
26    pub side_band: Option<SideBand>,
27    pub command_list: Vec<RefCommand>,
28
29    // Trait-based dependencies
30    repo_storage: R,
31    auth_service: A,
32}
33
34impl<R, A> SmartProtocol<R, A>
35where
36    R: RepositoryAccess,
37    A: AuthenticationService,
38{
39    /// Create a new SmartProtocol instance
40    pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
41        Self {
42            transport_protocol,
43            capabilities: Vec::new(),
44            side_band: None,
45            command_list: Vec::new(),
46            repo_storage,
47            auth_service,
48        }
49    }
50
51    /// Authenticate an HTTP request using the injected auth service
52    pub async fn authenticate_http(
53        &self,
54        headers: &HashMap<String, String>,
55    ) -> Result<(), ProtocolError> {
56        self.auth_service.authenticate_http(headers).await
57    }
58
59    /// Authenticate an SSH session using username and public key
60    pub async fn authenticate_ssh(
61        &self,
62        username: &str,
63        public_key: &[u8],
64    ) -> Result<(), ProtocolError> {
65        self.auth_service
66            .authenticate_ssh(username, public_key)
67            .await
68    }
69
70    /// Set transport protocol (Http, Ssh, etc.)
71    pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
72        self.transport_protocol = protocol;
73    }
74
75    /// Get git info refs for the repository, with explicit service type
76    pub async fn git_info_refs(
77        &self,
78        service_type: ServiceType,
79    ) -> Result<BytesMut, ProtocolError> {
80        let refs =
81            self.repo_storage.get_repository_refs().await.map_err(|e| {
82                ProtocolError::repository_error(format!("Failed to get refs: {}", e))
83            })?;
84
85        // Convert to the expected format (head_hash, git_refs)
86        let head_hash = refs
87            .iter()
88            .find(|(name, _)| {
89                name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
90            })
91            .map(|(_, hash)| hash.clone())
92            .unwrap_or_else(|| "0000000000000000000000000000000000000000".to_string());
93
94        let git_refs: Vec<super::types::GitRef> = refs
95            .into_iter()
96            .map(|(name, hash)| super::types::GitRef { name, hash })
97            .collect();
98
99        // Determine capabilities based on service type
100        let cap_list = match service_type {
101            ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}"),
102            ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}"),
103        };
104
105        // The stream MUST include capability declarations behind a NUL on the first ref.
106        let name = if head_hash == ZERO_ID {
107            "capabilities^{}"
108        } else {
109            "HEAD"
110        };
111        let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
112        let mut ref_list = vec![pkt_line];
113
114        for git_ref in git_refs {
115            let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
116            ref_list.push(pkt_line);
117        }
118
119        let pkt_line_stream =
120            build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
121        tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
122        Ok(pkt_line_stream)
123    }
124
125    /// Handle git-upload-pack request
126    pub async fn git_upload_pack(
127        &mut self,
128        upload_request: Bytes,
129    ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
130        let mut upload_request = upload_request;
131        let mut want: Vec<String> = Vec::new();
132        let mut have: Vec<String> = Vec::new();
133        let mut last_common_commit = String::new();
134
135        let mut read_first_line = false;
136        loop {
137            let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
138
139            if bytes_take == 0 {
140                break;
141            }
142
143            if pkt_line.is_empty() {
144                break;
145            }
146
147            let mut pkt_line = pkt_line;
148            let command = read_until_white_space(&mut pkt_line);
149
150            match command.as_str() {
151                "want" => {
152                    let hash = read_until_white_space(&mut pkt_line);
153                    want.push(hash);
154                    if !read_first_line {
155                        let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
156                        self.parse_capabilities(&cap_str);
157                        read_first_line = true;
158                    }
159                }
160                "have" => {
161                    let hash = read_until_white_space(&mut pkt_line);
162                    have.push(hash);
163                }
164                "done" => {
165                    break;
166                }
167                _ => {
168                    tracing::warn!("Unknown upload-pack command: {}", command);
169                }
170            }
171        }
172
173        let mut protocol_buf = BytesMut::new();
174
175        // Create pack generator for this operation
176        let pack_generator = PackGenerator::new(&self.repo_storage);
177
178        if have.is_empty() {
179            // Full pack
180            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
181            let pack_stream = pack_generator.generate_full_pack(want).await?;
182            return Ok((pack_stream, protocol_buf));
183        }
184
185        // Check for common commits
186        for hash in &have {
187            let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
188                ProtocolError::repository_error(format!("Failed to check commit existence: {}", e))
189            })?;
190            if exists {
191                add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
192                if last_common_commit.is_empty() {
193                    last_common_commit = hash.clone();
194                }
195            }
196        }
197
198        if last_common_commit.is_empty() {
199            // No common commits found
200            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
201            let pack_stream = pack_generator.generate_full_pack(want).await?;
202            return Ok((pack_stream, protocol_buf));
203        }
204
205        // Generate incremental pack
206        add_pkt_line_string(
207            &mut protocol_buf,
208            format!("ACK {last_common_commit} ready\n"),
209        );
210        protocol_buf.put(&PKT_LINE_END_MARKER[..]);
211
212        add_pkt_line_string(&mut protocol_buf, format!("ACK {last_common_commit} \n"));
213
214        let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
215
216        Ok((pack_stream, protocol_buf))
217    }
218
219    /// Parse receive pack commands from protocol bytes
220    pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
221        loop {
222            let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
223
224            if bytes_take == 0 {
225                break;
226            }
227
228            if pkt_line.is_empty() {
229                break;
230            }
231
232            let ref_command = self.parse_ref_command(&mut pkt_line.clone());
233            self.command_list.push(ref_command);
234        }
235    }
236
237    /// Handle git receive-pack operation (push)
238    pub async fn git_receive_pack_stream(
239        &mut self,
240        data_stream: ProtocolStream,
241    ) -> Result<Bytes, ProtocolError> {
242        // Collect all pack data from stream
243        let mut pack_data = BytesMut::new();
244        let mut stream = data_stream;
245
246        while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
247            let chunk = chunk_result
248                .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {}", e)))?;
249            pack_data.extend_from_slice(&chunk);
250        }
251
252        // Create pack generator for unpacking
253        let pack_generator = PackGenerator::new(&self.repo_storage);
254
255        // Unpack the received data
256        let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data.freeze()).await?;
257
258        // Store the unpacked objects via the repository access trait
259        self.repo_storage
260            .handle_pack_objects(commits, trees, blobs)
261            .await
262            .map_err(|e| {
263                ProtocolError::repository_error(format!("Failed to store pack objects: {}", e))
264            })?;
265
266        // Build status report
267        let mut report_status = BytesMut::new();
268        add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
269
270        let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
271            ProtocolError::repository_error(format!("Failed to check default branch: {}", e))
272        })?;
273
274        // Update refs with proper error handling
275        for command in &mut self.command_list {
276            if command.ref_type == RefTypeEnum::Tag {
277                // Just update if refs type is tag
278                // Convert ZERO_ID to None for old hash
279                let old_hash = if command.old_hash == ZERO_ID {
280                    None
281                } else {
282                    Some(command.old_hash.as_str())
283                };
284                if let Err(e) = self
285                    .repo_storage
286                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
287                    .await
288                {
289                    command.failed(e.to_string());
290                }
291            } else {
292                // Handle default branch setting for the first branch
293                if !default_exist {
294                    command.default_branch = true;
295                    default_exist = true;
296                }
297                // Convert ZERO_ID to None for old hash
298                let old_hash = if command.old_hash == ZERO_ID {
299                    None
300                } else {
301                    Some(command.old_hash.as_str())
302                };
303                if let Err(e) = self
304                    .repo_storage
305                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
306                    .await
307                {
308                    command.failed(e.to_string());
309                }
310            }
311            add_pkt_line_string(&mut report_status, command.get_status());
312        }
313
314        // Post-receive hook
315        self.repo_storage.post_receive_hook().await.map_err(|e| {
316            ProtocolError::repository_error(format!("Post-receive hook failed: {}", e))
317        })?;
318
319        report_status.put(&PKT_LINE_END_MARKER[..]);
320        Ok(report_status.freeze())
321    }
322
323    /// Builds the packet data in the sideband format if the SideBand/64k capability is enabled.
324    pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
325        let mut to_bytes = BytesMut::new();
326        if self.capabilities.contains(&Capability::SideBand)
327            || self.capabilities.contains(&Capability::SideBand64k)
328        {
329            let length = length + 5;
330            to_bytes.put(Bytes::from(format!("{length:04x}")));
331            to_bytes.put_u8(SideBand::PackfileData.value());
332            to_bytes.put(from_bytes);
333        } else {
334            to_bytes.put(from_bytes);
335        }
336        to_bytes
337    }
338
339    /// Parse capabilities from capability string
340    pub fn parse_capabilities(&mut self, cap_str: &str) {
341        for cap in cap_str.split_whitespace() {
342            if let Ok(capability) = cap.parse::<Capability>() {
343                self.capabilities.push(capability);
344            }
345        }
346    }
347
348    /// Parse a reference command from packet line
349    pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
350        let old_id = read_until_white_space(pkt_line);
351        let new_id = read_until_white_space(pkt_line);
352        let ref_name = read_until_white_space(pkt_line);
353        let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
354
355        RefCommand::new(old_id, new_id, ref_name)
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::metadata::{EntryMeta, MetaAttached};
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use crate::internal::pack::{encode::PackEncoder, entry::Entry};
    use crate::protocol::types::{RefCommand, ZERO_ID}; // import sibling types
    use crate::protocol::utils; // import sibling module
    use async_trait::async_trait;
    use bytes::Bytes;
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };
    use tokio::sync::mpsc;

    // Type aliases keep the mock's field types readable (clippy::type_complexity).
    type UpdateRecord = (String, Option<String>, String);
    type UpdateList = Vec<UpdateRecord>;
    type SharedUpdates = Arc<Mutex<UpdateList>>;

    /// In-memory `RepositoryAccess` double that records the calls it receives.
    /// Clones share state, so the test can inspect what the protocol did.
    #[derive(Clone)]
    struct TestRepoAccess {
        /// Every `update_reference` call as `(ref_name, old_hash, new_hash)`.
        updates: SharedUpdates,
        /// Number of `store_pack_data` calls.
        stored_count: Arc<Mutex<usize>>,
        /// Answer for the next `has_default_branch` call.
        default_branch_exists: Arc<Mutex<bool>>,
        /// Set once `post_receive_hook` has run.
        post_called: Arc<AtomicBool>,
    }

    impl TestRepoAccess {
        fn new() -> Self {
            Self {
                updates: Arc::new(Mutex::new(Vec::new())),
                stored_count: Arc::new(Mutex::new(0)),
                default_branch_exists: Arc::new(Mutex::new(false)),
                post_called: Arc::new(AtomicBool::new(false)),
            }
        }

        /// Number of reference updates recorded so far.
        fn updates_len(&self) -> usize {
            let recorded = self.updates.lock().unwrap();
            recorded.len()
        }

        /// Whether the post-receive hook has been invoked.
        fn post_hook_called(&self) -> bool {
            self.post_called.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RepositoryAccess for TestRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            let head = (
                "HEAD".to_string(),
                "0000000000000000000000000000000000000000".to_string(),
            );
            let main = (
                "refs/heads/main".to_string(),
                "1111111111111111111111111111111111111111".to_string(),
            );
            Ok(vec![head, main])
        }

        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(Vec::new())
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            let mut count = self.stored_count.lock().unwrap();
            *count += 1;
            Ok(())
        }

        async fn update_reference(
            &self,
            ref_name: &str,
            old_hash: Option<&str>,
            new_hash: &str,
        ) -> Result<(), ProtocolError> {
            let record: UpdateRecord = (
                ref_name.to_owned(),
                old_hash.map(str::to_owned),
                new_hash.to_owned(),
            );
            self.updates.lock().unwrap().push(record);
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(Vec::new())
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            // Report the stored answer, then flip it to true for later calls.
            let mut exists = self.default_branch_exists.lock().unwrap();
            Ok(std::mem::replace(&mut *exists, true))
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            self.post_called.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Authentication double that accepts every request.
    struct TestAuth;

    #[async_trait]
    impl AuthenticationService for TestAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Builds a commit with a two-file tree and encodes commit, tree and blobs
    /// into pack bytes with `PackEncoder`. Returns the commit and the pack.
    async fn encode_single_commit_pack() -> (Commit, Vec<u8>) {
        let hello = Blob::from_content("hello");
        let world = Blob::from_content("world");

        let tree = Tree::from_tree_items(vec![
            TreeItem::new(TreeItemMode::Blob, hello.id, "hello.txt".to_string()),
            TreeItem::new(TreeItemMode::Blob, world.id, "world.txt".to_string()),
        ])
        .unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
        let (entry_tx, entry_rx) = mpsc::channel(1024);
        let mut encoder = PackEncoder::new(4, 10, pack_tx);

        // Encoder task: consumes entries, produces pack chunks.
        tokio::spawn(async move {
            if let Err(e) = encoder.encode(entry_rx).await {
                panic!("Failed to encode pack: {}", e);
            }
        });

        // Feeder task: commit first, then tree, then both blobs. Dropping the
        // sender at the end of the task signals the end of input.
        let entries = vec![
            Entry::from(commit.clone()),
            Entry::from(tree),
            Entry::from(hello),
            Entry::from(world),
        ];
        tokio::spawn(async move {
            for entry in entries {
                let _ = entry_tx
                    .send(MetaAttached {
                        inner: entry,
                        meta: EntryMeta::new(),
                    })
                    .await;
            }
        });

        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = pack_rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }

        (commit, pack_bytes)
    }

    /// Pushing one new branch yields `unpack ok`, an `ok` line for the ref and a
    /// closing flush-pkt, updates one reference and runs the post-receive hook.
    #[tokio::test]
    async fn test_receive_pack_stream_status_report() {
        let (commit, pack_bytes) = encode_single_commit_pack().await;

        // Protocol under test with one "create refs/heads/main" command queued.
        let repo_access = TestRepoAccess::new();
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), TestAuth);
        let create_main = RefCommand::new(
            ZERO_ID.to_string(),
            commit.id.to_string(),
            "refs/heads/main".to_string(),
        );
        smart.command_list.push(create_main);

        // The whole pack arrives as a single stream chunk.
        let request_stream = Box::pin(futures::stream::once(async { Ok(Bytes::from(pack_bytes)) }));

        let mut report = smart
            .git_receive_pack_stream(request_stream)
            .await
            .expect("receive-pack should succeed");

        // Status report, packet by packet.
        let (_, unpack_line) = utils::read_pkt_line(&mut report);
        assert_eq!(
            String::from_utf8(unpack_line.to_vec()).unwrap(),
            "unpack ok\n"
        );

        let (_, ref_line) = utils::read_pkt_line(&mut report);
        assert_eq!(
            String::from_utf8(ref_line.to_vec()).unwrap(),
            "ok refs/heads/main"
        );

        let (flush_len, flush_payload) = utils::read_pkt_line(&mut report);
        assert_eq!(flush_len, 4);
        assert!(flush_payload.is_empty());

        // Side effects on the repository double.
        assert_eq!(repo_access.updates_len(), 1);
        assert!(repo_access.post_hook_called());
    }
}