git_internal/protocol/
smart.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use std::collections::HashMap;
3use tokio_stream::wrappers::ReceiverStream;
4
5use super::core::{AuthenticationService, RepositoryAccess};
6use super::pack::PackGenerator;
7use super::types::ProtocolError;
8use super::types::{
9    COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolStream, RECEIVE_CAP_LIST,
10    RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol, UPLOAD_CAP_LIST,
11    ZERO_ID,
12};
13use super::utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space};
14
/// Smart Git Protocol implementation
///
/// This struct handles the Git smart protocol operations for both HTTP and SSH transports.
/// It uses trait abstractions to decouple from specific business logic implementations.
///
/// `R` supplies repository storage and `A` supplies authentication; both are handed in
/// through `SmartProtocol::new`.
pub struct SmartProtocol<R, A>
where
    R: RepositoryAccess,
    A: AuthenticationService,
{
    /// Transport the replies are framed for; forwarded to `build_smart_reply` by
    /// `git_info_refs`.
    pub transport_protocol: TransportProtocol,
    /// Capabilities accumulated by `parse_capabilities` and consulted by
    /// `build_side_band_format`.
    pub capabilities: Vec<Capability>,
    /// Side-band selection; initialised to `None` by `new`.
    pub side_band: Option<SideBand>,
    /// Ref update commands queued by `parse_receive_pack_commands` and applied by
    /// `git_receive_pack_stream`.
    pub command_list: Vec<RefCommand>,

    // Trait-based dependencies
    /// Repository backend used for listing refs, storing objects and updating refs.
    repo_storage: R,
    /// Authentication backend that `authenticate_http` / `authenticate_ssh` delegate to.
    auth_service: A,
}
33
34impl<R, A> SmartProtocol<R, A>
35where
36    R: RepositoryAccess,
37    A: AuthenticationService,
38{
39    /// Create a new SmartProtocol instance
40    pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
41        Self {
42            transport_protocol,
43            capabilities: Vec::new(),
44            side_band: None,
45            command_list: Vec::new(),
46            repo_storage,
47            auth_service,
48        }
49    }
50
51    /// Authenticate an HTTP request using the injected auth service
52    pub async fn authenticate_http(
53        &self,
54        headers: &HashMap<String, String>,
55    ) -> Result<(), ProtocolError> {
56        self.auth_service.authenticate_http(headers).await
57    }
58
59    /// Authenticate an SSH session using username and public key
60    pub async fn authenticate_ssh(
61        &self,
62        username: &str,
63        public_key: &[u8],
64    ) -> Result<(), ProtocolError> {
65        self.auth_service
66            .authenticate_ssh(username, public_key)
67            .await
68    }
69
70    /// Set transport protocol (Http, Ssh, etc.)
71    pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
72        self.transport_protocol = protocol;
73    }
74
75    /// Get git info refs for the repository, with explicit service type
76    pub async fn git_info_refs(
77        &self,
78        service_type: ServiceType,
79    ) -> Result<BytesMut, ProtocolError> {
80        let refs =
81            self.repo_storage.get_repository_refs().await.map_err(|e| {
82                ProtocolError::repository_error(format!("Failed to get refs: {}", e))
83            })?;
84
85        // Convert to the expected format (head_hash, git_refs)
86        let head_hash = refs
87            .iter()
88            .find(|(name, _)| {
89                name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
90            })
91            .map(|(_, hash)| hash.clone())
92            .unwrap_or_else(|| "0000000000000000000000000000000000000000".to_string());
93
94        let git_refs: Vec<super::types::GitRef> = refs
95            .into_iter()
96            .map(|(name, hash)| super::types::GitRef { name, hash })
97            .collect();
98
99        // Determine capabilities based on service type
100        let cap_list = match service_type {
101            ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}"),
102            ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}"),
103        };
104
105        // The stream MUST include capability declarations behind a NUL on the first ref.
106        let name = if head_hash == ZERO_ID {
107            "capabilities^{}"
108        } else {
109            "HEAD"
110        };
111        let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
112        let mut ref_list = vec![pkt_line];
113
114        for git_ref in git_refs {
115            let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
116            ref_list.push(pkt_line);
117        }
118
119        let pkt_line_stream =
120            build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
121        tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
122        Ok(pkt_line_stream)
123    }
124
125    /// Handle git-upload-pack request
126    pub async fn git_upload_pack(
127        &mut self,
128        upload_request: Bytes,
129    ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
130        let mut upload_request = upload_request;
131        let mut want: Vec<String> = Vec::new();
132        let mut have: Vec<String> = Vec::new();
133        let mut last_common_commit = String::new();
134
135        let mut read_first_line = false;
136        loop {
137            let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
138
139            if bytes_take == 0 {
140                break;
141            }
142
143            if pkt_line.is_empty() {
144                break;
145            }
146
147            let mut pkt_line = pkt_line;
148            let command = read_until_white_space(&mut pkt_line);
149
150            match command.as_str() {
151                "want" => {
152                    let hash = read_until_white_space(&mut pkt_line);
153                    want.push(hash);
154                    if !read_first_line {
155                        let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
156                        self.parse_capabilities(&cap_str);
157                        read_first_line = true;
158                    }
159                }
160                "have" => {
161                    let hash = read_until_white_space(&mut pkt_line);
162                    have.push(hash);
163                }
164                "done" => {
165                    break;
166                }
167                _ => {
168                    tracing::warn!("Unknown upload-pack command: {}", command);
169                }
170            }
171        }
172
173        let mut protocol_buf = BytesMut::new();
174
175        // Create pack generator for this operation
176        let pack_generator = PackGenerator::new(&self.repo_storage);
177
178        if have.is_empty() {
179            // Full pack
180            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
181            let pack_stream = pack_generator.generate_full_pack(want).await?;
182            return Ok((pack_stream, protocol_buf));
183        }
184
185        // Check for common commits
186        for hash in &have {
187            let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
188                ProtocolError::repository_error(format!("Failed to check commit existence: {}", e))
189            })?;
190            if exists {
191                add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
192                if last_common_commit.is_empty() {
193                    last_common_commit = hash.clone();
194                }
195            }
196        }
197
198        if last_common_commit.is_empty() {
199            // No common commits found
200            add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
201            let pack_stream = pack_generator.generate_full_pack(want).await?;
202            return Ok((pack_stream, protocol_buf));
203        }
204
205        // Generate incremental pack
206        add_pkt_line_string(
207            &mut protocol_buf,
208            format!("ACK {last_common_commit} ready\n"),
209        );
210        protocol_buf.put(&PKT_LINE_END_MARKER[..]);
211
212        add_pkt_line_string(&mut protocol_buf, format!("ACK {last_common_commit} \n"));
213
214        let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
215
216        Ok((pack_stream, protocol_buf))
217    }
218
219    /// Parse receive pack commands from protocol bytes
220    pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
221        loop {
222            let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
223
224            if bytes_take == 0 {
225                break;
226            }
227
228            if pkt_line.is_empty() {
229                break;
230            }
231
232            let ref_command = self.parse_ref_command(&mut pkt_line.clone());
233            self.command_list.push(ref_command);
234        }
235    }
236
237    /// Handle git receive-pack operation (push)
238    pub async fn git_receive_pack_stream(
239        &mut self,
240        data_stream: ProtocolStream,
241    ) -> Result<Bytes, ProtocolError> {
242        // Collect all pack data from stream
243        let mut pack_data = BytesMut::new();
244        let mut stream = data_stream;
245
246        while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
247            let chunk = chunk_result
248                .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {}", e)))?;
249            pack_data.extend_from_slice(&chunk);
250        }
251
252        // Create pack generator for unpacking
253        let pack_generator = PackGenerator::new(&self.repo_storage);
254
255        // Unpack the received data
256        let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data.freeze()).await?;
257
258        // Store the unpacked objects via the repository access trait
259        self.repo_storage
260            .handle_pack_objects(commits, trees, blobs)
261            .await
262            .map_err(|e| {
263                ProtocolError::repository_error(format!("Failed to store pack objects: {}", e))
264            })?;
265
266        // Build status report
267        let mut report_status = BytesMut::new();
268        add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
269
270        let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
271            ProtocolError::repository_error(format!("Failed to check default branch: {}", e))
272        })?;
273
274        // Update refs with proper error handling
275        for command in &mut self.command_list {
276            if command.ref_type == RefTypeEnum::Tag {
277                // Just update if refs type is tag
278                // Convert ZERO_ID to None for old hash
279                let old_hash = if command.old_hash == ZERO_ID {
280                    None
281                } else {
282                    Some(command.old_hash.as_str())
283                };
284                if let Err(e) = self
285                    .repo_storage
286                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
287                    .await
288                {
289                    command.failed(e.to_string());
290                }
291            } else {
292                // Handle default branch setting for the first branch
293                if !default_exist {
294                    command.default_branch = true;
295                    default_exist = true;
296                }
297                // Convert ZERO_ID to None for old hash
298                let old_hash = if command.old_hash == ZERO_ID {
299                    None
300                } else {
301                    Some(command.old_hash.as_str())
302                };
303                if let Err(e) = self
304                    .repo_storage
305                    .update_reference(&command.ref_name, old_hash, &command.new_hash)
306                    .await
307                {
308                    command.failed(e.to_string());
309                }
310            }
311            add_pkt_line_string(&mut report_status, command.get_status());
312        }
313
314        // Post-receive hook
315        self.repo_storage.post_receive_hook().await.map_err(|e| {
316            ProtocolError::repository_error(format!("Post-receive hook failed: {}", e))
317        })?;
318
319        report_status.put(&PKT_LINE_END_MARKER[..]);
320        Ok(report_status.freeze())
321    }
322
323    /// Builds the packet data in the sideband format if the SideBand/64k capability is enabled.
324    pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
325        let mut to_bytes = BytesMut::new();
326        if self.capabilities.contains(&Capability::SideBand)
327            || self.capabilities.contains(&Capability::SideBand64k)
328        {
329            let length = length + 5;
330            to_bytes.put(Bytes::from(format!("{length:04x}")));
331            to_bytes.put_u8(SideBand::PackfileData.value());
332            to_bytes.put(from_bytes);
333        } else {
334            to_bytes.put(from_bytes);
335        }
336        to_bytes
337    }
338
339    /// Parse capabilities from capability string
340    pub fn parse_capabilities(&mut self, cap_str: &str) {
341        for cap in cap_str.split_whitespace() {
342            if let Ok(capability) = cap.parse::<Capability>() {
343                self.capabilities.push(capability);
344            }
345        }
346    }
347
348    /// Parse a reference command from packet line
349    pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
350        let old_id = read_until_white_space(pkt_line);
351        let new_id = read_until_white_space(pkt_line);
352        let ref_name = read_until_white_space(pkt_line);
353        let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
354
355        RefCommand::new(old_id, new_id, ref_name)
356    }
357}
358
359#[cfg(test)]
360mod tests {
361    use super::*;
362    use crate::internal::object::blob::Blob;
363    use crate::internal::object::commit::Commit;
364    use crate::internal::object::signature::{Signature, SignatureType};
365    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
366    use crate::internal::pack::{encode::PackEncoder, entry::Entry};
367    use crate::protocol::types::{RefCommand, ZERO_ID}; // import sibling types
368    use crate::protocol::utils; // import sibling module
369    use async_trait::async_trait;
370    use bytes::Bytes;
371    use futures;
372    use std::sync::{
373        Arc, Mutex,
374        atomic::{AtomicBool, Ordering},
375    };
376    use tokio::sync::mpsc;
377
378    // Simplify complex type via aliases to satisfy clippy::type_complexity
379    type UpdateRecord = (String, Option<String>, String);
380    type UpdateList = Vec<UpdateRecord>;
381    type SharedUpdates = Arc<Mutex<UpdateList>>;
382
383    #[derive(Clone)]
384    struct TestRepoAccess {
385        updates: SharedUpdates,
386        stored_count: Arc<Mutex<usize>>,
387        default_branch_exists: Arc<Mutex<bool>>,
388        post_called: Arc<AtomicBool>,
389    }
390
391    impl TestRepoAccess {
392        fn new() -> Self {
393            Self {
394                updates: Arc::new(Mutex::new(vec![])),
395                stored_count: Arc::new(Mutex::new(0)),
396                default_branch_exists: Arc::new(Mutex::new(false)),
397                post_called: Arc::new(AtomicBool::new(false)),
398            }
399        }
400
401        fn updates_len(&self) -> usize {
402            self.updates.lock().unwrap().len()
403        }
404
405        fn post_hook_called(&self) -> bool {
406            self.post_called.load(Ordering::SeqCst)
407        }
408    }
409
410    #[async_trait]
411    impl RepositoryAccess for TestRepoAccess {
412        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
413            Ok(vec![
414                (
415                    "HEAD".to_string(),
416                    "0000000000000000000000000000000000000000".to_string(),
417                ),
418                (
419                    "refs/heads/main".to_string(),
420                    "1111111111111111111111111111111111111111".to_string(),
421                ),
422            ])
423        }
424
425        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
426            Ok(true)
427        }
428
429        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
430            Ok(vec![])
431        }
432
433        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
434            *self.stored_count.lock().unwrap() += 1;
435            Ok(())
436        }
437
438        async fn update_reference(
439            &self,
440            ref_name: &str,
441            old_hash: Option<&str>,
442            new_hash: &str,
443        ) -> Result<(), ProtocolError> {
444            self.updates.lock().unwrap().push((
445                ref_name.to_string(),
446                old_hash.map(|s| s.to_string()),
447                new_hash.to_string(),
448            ));
449            Ok(())
450        }
451
452        async fn get_objects_for_pack(
453            &self,
454            _wants: &[String],
455            _haves: &[String],
456        ) -> Result<Vec<String>, ProtocolError> {
457            Ok(vec![])
458        }
459
460        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
461            let mut exists = self.default_branch_exists.lock().unwrap();
462            let current = *exists;
463            *exists = true; // flip to true after first check
464            Ok(current)
465        }
466
467        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
468            self.post_called.store(true, Ordering::SeqCst);
469            Ok(())
470        }
471    }
472
    /// Authentication stub that accepts every request.
    struct TestAuth;

    #[async_trait]
    impl AuthenticationService for TestAuth {
        /// Accepts any header set.
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        /// Accepts any username / public key pair.
        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }
492
    /// End-to-end check of `git_receive_pack_stream`: encodes a one-commit pack, pushes
    /// it as a single-chunk stream with one queued `refs/heads/main` command, and
    /// verifies the status report pkt-lines plus the mock repository's side effects.
    #[tokio::test]
    async fn test_receive_pack_stream_status_report() {
        // Build simple objects
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        // Encode pack bytes via PackEncoder
        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
        let (entry_tx, entry_rx) = mpsc::channel(1024);
        let mut encoder = PackEncoder::new(4, 10, pack_tx);

        // Encoder task: consumes entries and emits pack chunks on `pack_tx`.
        tokio::spawn(async move {
            if let Err(e) = encoder.encode(entry_rx).await {
                panic!("Failed to encode pack: {}", e);
            }
        });

        // Feeder task: sends the four objects, then drops the sender.
        let commit_clone = commit.clone();
        let tree_clone = tree.clone();
        let blob1_clone = blob1.clone();
        let blob2_clone = blob2.clone();
        tokio::spawn(async move {
            let _ = entry_tx.send(Entry::from(commit_clone)).await;
            let _ = entry_tx.send(Entry::from(tree_clone)).await;
            let _ = entry_tx.send(Entry::from(blob1_clone)).await;
            let _ = entry_tx.send(Entry::from(blob2_clone)).await;
            // sender drop indicates end
        });

        // Collect chunks until the pack channel closes.
        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = pack_rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }

        // Prepare protocol and command
        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        // The handler gets a clone; `repo_access` shares its state for the final checks.
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), auth);
        // ZERO_ID as the old hash: the handler passes `None` as the previous value.
        smart.command_list.push(RefCommand::new(
            ZERO_ID.to_string(),
            commit.id.to_string(),
            "refs/heads/main".to_string(),
        ));

        // Create request stream
        let request_stream = Box::pin(futures::stream::once(async { Ok(Bytes::from(pack_bytes)) }));

        // Execute receive-pack
        let result_bytes = smart
            .git_receive_pack_stream(request_stream)
            .await
            .expect("receive-pack should succeed");

        // Verify pkt-lines
        let mut out = result_bytes.clone();
        let (_c1, l1) = utils::read_pkt_line(&mut out);
        assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");

        let (_c2, l2) = utils::read_pkt_line(&mut out);
        assert_eq!(
            String::from_utf8(l2.to_vec()).unwrap(),
            "ok refs/heads/main"
        );

        // Closing end marker: four bytes consumed, empty payload.
        let (c3, l3) = utils::read_pkt_line(&mut out);
        assert_eq!(c3, 4);
        assert!(l3.is_empty());

        // Verify side effects
        assert_eq!(repo_access.updates_len(), 1);
        assert!(repo_access.post_hook_called());
    }
581}