1use bytes::{BufMut, Bytes, BytesMut};
2use std::collections::HashMap;
3use tokio_stream::wrappers::ReceiverStream;
4
5use super::core::{AuthenticationService, RepositoryAccess};
6use super::pack::PackGenerator;
7use super::types::ProtocolError;
8use super::types::{
9 COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolStream, RECEIVE_CAP_LIST,
10 RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol, UPLOAD_CAP_LIST,
11};
12use super::utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space};
13use crate::hash::{HashKind, ObjectHash, get_hash_kind};
14pub struct SmartProtocol<R, A>
19where
20 R: RepositoryAccess,
21 A: AuthenticationService,
22{
23 pub transport_protocol: TransportProtocol,
24 pub capabilities: Vec<Capability>,
25 pub side_band: Option<SideBand>,
26 pub command_list: Vec<RefCommand>,
27 pub wire_hash_kind: HashKind,
28 pub local_hash_kind: HashKind,
29 pub zero_id: String,
30 repo_storage: R,
32 auth_service: A,
33}
34
35impl<R, A> SmartProtocol<R, A>
36where
37 R: RepositoryAccess,
38 A: AuthenticationService,
39{
40 pub fn set_wire_hash_kind(&mut self, kind: HashKind) {
41 self.wire_hash_kind = kind;
42 self.zero_id = ObjectHash::zero_str(kind);
43 }
44
45 pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
47 Self {
48 transport_protocol,
49 capabilities: Vec::new(),
50 side_band: None,
51 command_list: Vec::new(),
52 repo_storage,
53 auth_service,
54 wire_hash_kind: HashKind::default(), local_hash_kind: get_hash_kind(),
56 zero_id: ObjectHash::zero_str(HashKind::default()),
57 }
58 }
59
60 pub async fn authenticate_http(
62 &self,
63 headers: &HashMap<String, String>,
64 ) -> Result<(), ProtocolError> {
65 self.auth_service.authenticate_http(headers).await
66 }
67
68 pub async fn authenticate_ssh(
70 &self,
71 username: &str,
72 public_key: &[u8],
73 ) -> Result<(), ProtocolError> {
74 self.auth_service
75 .authenticate_ssh(username, public_key)
76 .await
77 }
78
79 pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
81 self.transport_protocol = protocol;
82 }
83
84 pub async fn git_info_refs(
86 &self,
87 service_type: ServiceType,
88 ) -> Result<BytesMut, ProtocolError> {
89 let refs =
90 self.repo_storage.get_repository_refs().await.map_err(|e| {
91 ProtocolError::repository_error(format!("Failed to get refs: {}", e))
92 })?;
93 let hex_len = self.wire_hash_kind.hex_len();
94 for (name, h) in &refs {
95 if h.len() != hex_len {
96 return Err(ProtocolError::invalid_request(&format!(
97 "Hash length mismatch for ref {}: expected {}, got {}",
98 name,
99 hex_len,
100 h.len()
101 )));
102 }
103 } let head_hash = refs
106 .iter()
107 .find(|(name, _)| {
108 name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
109 })
110 .map(|(_, hash)| hash.clone())
111 .unwrap_or_else(|| self.zero_id.clone());
112
113 let git_refs: Vec<super::types::GitRef> = refs
114 .into_iter()
115 .map(|(name, hash)| super::types::GitRef { name, hash })
116 .collect();
117 let format_cap = match self.wire_hash_kind {
119 HashKind::Sha1 => " object-format=sha1",
120 HashKind::Sha256 => " object-format=sha256",
121 };
122 let cap_list = match service_type {
124 ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
125 ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}{format_cap}"),
126 };
127
128 let name = if head_hash == self.zero_id {
130 "capabilities^{}"
131 } else {
132 "HEAD"
133 };
134 let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
135 let mut ref_list = vec![pkt_line];
136
137 for git_ref in git_refs {
138 let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
139 ref_list.push(pkt_line);
140 }
141
142 let pkt_line_stream =
143 build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
144 tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
145 Ok(pkt_line_stream)
146 }
147
148 pub async fn git_upload_pack(
150 &mut self,
151 upload_request: Bytes,
152 ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
153 let mut upload_request = upload_request;
154 let mut want: Vec<String> = Vec::new();
155 let mut have: Vec<String> = Vec::new();
156 let mut last_common_commit = String::new();
157
158 let mut read_first_line = false;
159 loop {
160 let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
161
162 if bytes_take == 0 {
163 break;
164 }
165
166 if pkt_line.is_empty() {
167 break;
168 }
169
170 let mut pkt_line = pkt_line;
171 let command = read_until_white_space(&mut pkt_line);
172
173 match command.as_str() {
174 "want" => {
175 let hash = read_until_white_space(&mut pkt_line);
176 want.push(hash);
177 if !read_first_line {
178 let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
179 self.parse_capabilities(&cap_str);
180 read_first_line = true;
181 }
182 }
183 "have" => {
184 let hash = read_until_white_space(&mut pkt_line);
185 have.push(hash);
186 }
187 "done" => {
188 break;
189 }
190 _ => {
191 tracing::warn!("Unknown upload-pack command: {}", command);
192 }
193 }
194 }
195
196 let mut protocol_buf = BytesMut::new();
197
198 let pack_generator = PackGenerator::new(&self.repo_storage);
200
201 if have.is_empty() {
202 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
204 let pack_stream = pack_generator.generate_full_pack(want).await?;
205 return Ok((pack_stream, protocol_buf));
206 }
207
208 for hash in &have {
210 let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
211 ProtocolError::repository_error(format!("Failed to check commit existence: {}", e))
212 })?;
213 if exists {
214 add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
215 if last_common_commit.is_empty() {
216 last_common_commit = hash.clone();
217 }
218 }
219 }
220
221 if last_common_commit.is_empty() {
222 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
224 let pack_stream = pack_generator.generate_full_pack(want).await?;
225 return Ok((pack_stream, protocol_buf));
226 }
227
228 add_pkt_line_string(
230 &mut protocol_buf,
231 format!("ACK {last_common_commit} ready\n"),
232 );
233 protocol_buf.put(&PKT_LINE_END_MARKER[..]);
234
235 add_pkt_line_string(&mut protocol_buf, format!("ACK {last_common_commit} \n"));
236
237 let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
238
239 Ok((pack_stream, protocol_buf))
240 }
241
242 pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
244 loop {
245 let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
246
247 if bytes_take == 0 {
248 break;
249 }
250
251 if pkt_line.is_empty() {
252 break;
253 }
254
255 let ref_command = self.parse_ref_command(&mut pkt_line.clone());
256 self.command_list.push(ref_command);
257 }
258 }
259
260 pub async fn git_receive_pack_stream(
262 &mut self,
263 data_stream: ProtocolStream,
264 ) -> Result<Bytes, ProtocolError> {
265 let mut pack_data = BytesMut::new();
267 let mut stream = data_stream;
268
269 while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
270 let chunk = chunk_result
271 .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {}", e)))?;
272 pack_data.extend_from_slice(&chunk);
273 }
274
275 let pack_generator = PackGenerator::new(&self.repo_storage);
277
278 let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data.freeze()).await?;
280
281 self.repo_storage
283 .handle_pack_objects(commits, trees, blobs)
284 .await
285 .map_err(|e| {
286 ProtocolError::repository_error(format!("Failed to store pack objects: {}", e))
287 })?;
288
289 let mut report_status = BytesMut::new();
291 add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
292
293 let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
294 ProtocolError::repository_error(format!("Failed to check default branch: {}", e))
295 })?;
296
297 for command in &mut self.command_list {
299 if command.ref_type == RefTypeEnum::Tag {
300 let old_hash = if command.old_hash == self.zero_id {
303 None
304 } else {
305 Some(command.old_hash.as_str())
306 };
307 if let Err(e) = self
308 .repo_storage
309 .update_reference(&command.ref_name, old_hash, &command.new_hash)
310 .await
311 {
312 command.failed(e.to_string());
313 }
314 } else {
315 if !default_exist {
317 command.default_branch = true;
318 default_exist = true;
319 }
320 let old_hash = if command.old_hash == self.zero_id {
322 None
323 } else {
324 Some(command.old_hash.as_str())
325 };
326 if let Err(e) = self
327 .repo_storage
328 .update_reference(&command.ref_name, old_hash, &command.new_hash)
329 .await
330 {
331 command.failed(e.to_string());
332 }
333 }
334 add_pkt_line_string(&mut report_status, command.get_status());
335 }
336
337 self.repo_storage.post_receive_hook().await.map_err(|e| {
339 ProtocolError::repository_error(format!("Post-receive hook failed: {}", e))
340 })?;
341
342 report_status.put(&PKT_LINE_END_MARKER[..]);
343 Ok(report_status.freeze())
344 }
345
346 pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
348 let mut to_bytes = BytesMut::new();
349 if self.capabilities.contains(&Capability::SideBand)
350 || self.capabilities.contains(&Capability::SideBand64k)
351 {
352 let length = length + 5;
353 to_bytes.put(Bytes::from(format!("{length:04x}")));
354 to_bytes.put_u8(SideBand::PackfileData.value());
355 to_bytes.put(from_bytes);
356 } else {
357 to_bytes.put(from_bytes);
358 }
359 to_bytes
360 }
361
362 pub fn parse_capabilities(&mut self, cap_str: &str) {
364 for cap in cap_str.split_whitespace() {
365 if let Some(fmt) = cap.strip_prefix("object-format=") {
366 match fmt {
367 "sha1" => self.set_wire_hash_kind(HashKind::Sha1),
368 "sha256" => self.set_wire_hash_kind(HashKind::Sha256),
369 _ => {
370 tracing::warn!("Unknown object-format capability: {}", fmt);
371 }
372 }
373 continue;
374 }
375 if let Ok(capability) = cap.parse::<Capability>() {
376 self.capabilities.push(capability);
377 }
378 }
379 }
380
381 pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
383 let old_id = read_until_white_space(pkt_line);
384 let new_id = read_until_white_space(pkt_line);
385 let ref_name = read_until_white_space(pkt_line);
386 let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
387
388 RefCommand::new(old_id, new_id, ref_name)
389 }
390}
391
392#[cfg(test)]
393mod tests {
394 use super::*;
395 use crate::hash::{HashKind, set_hash_kind_for_test};
396 use crate::internal::metadata::{EntryMeta, MetaAttached};
397 use crate::internal::object::blob::Blob;
398 use crate::internal::object::commit::Commit;
399 use crate::internal::object::signature::{Signature, SignatureType};
400 use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
401 use crate::internal::pack::{encode::PackEncoder, entry::Entry};
402 use crate::protocol::types::RefCommand; use crate::protocol::utils; use async_trait::async_trait;
405 use bytes::Bytes;
406 use futures;
407 use std::sync::{
408 Arc, Mutex,
409 atomic::{AtomicBool, Ordering},
410 };
411 use tokio::sync::mpsc;
412
413 type UpdateRecord = (String, Option<String>, String);
415 type UpdateList = Vec<UpdateRecord>;
416 type SharedUpdates = Arc<Mutex<UpdateList>>;
417
418 #[derive(Clone)]
419 struct TestRepoAccess {
420 updates: SharedUpdates,
421 stored_count: Arc<Mutex<usize>>,
422 default_branch_exists: Arc<Mutex<bool>>,
423 post_called: Arc<AtomicBool>,
424 }
425
426 impl TestRepoAccess {
427 fn new() -> Self {
428 Self {
429 updates: Arc::new(Mutex::new(vec![])),
430 stored_count: Arc::new(Mutex::new(0)),
431 default_branch_exists: Arc::new(Mutex::new(false)),
432 post_called: Arc::new(AtomicBool::new(false)),
433 }
434 }
435
436 fn updates_len(&self) -> usize {
437 self.updates.lock().unwrap().len()
438 }
439
440 fn post_hook_called(&self) -> bool {
441 self.post_called.load(Ordering::SeqCst)
442 }
443 }
444
445 #[async_trait]
446 impl RepositoryAccess for TestRepoAccess {
447 async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
448 Ok(vec![
449 (
450 "HEAD".to_string(),
451 "0000000000000000000000000000000000000000".to_string(),
452 ),
453 (
454 "refs/heads/main".to_string(),
455 "1111111111111111111111111111111111111111".to_string(),
456 ),
457 ])
458 }
459
460 async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
461 Ok(true)
462 }
463
464 async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
465 Ok(vec![])
466 }
467
468 async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
469 *self.stored_count.lock().unwrap() += 1;
470 Ok(())
471 }
472
473 async fn update_reference(
474 &self,
475 ref_name: &str,
476 old_hash: Option<&str>,
477 new_hash: &str,
478 ) -> Result<(), ProtocolError> {
479 self.updates.lock().unwrap().push((
480 ref_name.to_string(),
481 old_hash.map(|s| s.to_string()),
482 new_hash.to_string(),
483 ));
484 Ok(())
485 }
486
487 async fn get_objects_for_pack(
488 &self,
489 _wants: &[String],
490 _haves: &[String],
491 ) -> Result<Vec<String>, ProtocolError> {
492 Ok(vec![])
493 }
494
495 async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
496 let mut exists = self.default_branch_exists.lock().unwrap();
497 let current = *exists;
498 *exists = true; Ok(current)
500 }
501
502 async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
503 self.post_called.store(true, Ordering::SeqCst);
504 Ok(())
505 }
506 }
507
508 struct TestAuth;
509
510 #[async_trait]
511 impl AuthenticationService for TestAuth {
512 async fn authenticate_http(
513 &self,
514 _headers: &std::collections::HashMap<String, String>,
515 ) -> Result<(), ProtocolError> {
516 Ok(())
517 }
518
519 async fn authenticate_ssh(
520 &self,
521 _username: &str,
522 _public_key: &[u8],
523 ) -> Result<(), ProtocolError> {
524 Ok(())
525 }
526 }
527
528 #[tokio::test]
529 async fn test_receive_pack_stream_status_report() {
530 let _guard = set_hash_kind_for_test(HashKind::Sha1);
531 let blob1 = Blob::from_content("hello");
533 let blob2 = Blob::from_content("world");
534
535 let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
536 let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
537 let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();
538
539 let author = Signature::new(
540 SignatureType::Author,
541 "tester".to_string(),
542 "tester@example.com".to_string(),
543 );
544 let committer = Signature::new(
545 SignatureType::Committer,
546 "tester".to_string(),
547 "tester@example.com".to_string(),
548 );
549 let commit = Commit::new(author, committer, tree.id, vec![], "init commit");
550
551 let (pack_tx, mut pack_rx) = mpsc::channel(1024);
553 let (entry_tx, entry_rx) = mpsc::channel(1024);
554 let mut encoder = PackEncoder::new(4, 10, pack_tx);
555
556 tokio::spawn(async move {
557 if let Err(e) = encoder.encode(entry_rx).await {
558 panic!("Failed to encode pack: {}", e);
559 }
560 });
561
562 let commit_clone = commit.clone();
563 let tree_clone = tree.clone();
564 let blob1_clone = blob1.clone();
565 let blob2_clone = blob2.clone();
566 tokio::spawn(async move {
567 let _ = entry_tx
568 .send(MetaAttached {
569 inner: Entry::from(commit_clone),
570 meta: EntryMeta::new(),
571 })
572 .await;
573 let _ = entry_tx
574 .send(MetaAttached {
575 inner: Entry::from(tree_clone),
576 meta: EntryMeta::new(),
577 })
578 .await;
579 let _ = entry_tx
580 .send(MetaAttached {
581 inner: Entry::from(blob1_clone),
582 meta: EntryMeta::new(),
583 })
584 .await;
585 let _ = entry_tx
586 .send(MetaAttached {
587 inner: Entry::from(blob2_clone),
588 meta: EntryMeta::new(),
589 })
590 .await;
591 });
593
594 let mut pack_bytes: Vec<u8> = Vec::new();
595 while let Some(chunk) = pack_rx.recv().await {
596 pack_bytes.extend_from_slice(&chunk);
597 }
598
599 let repo_access = TestRepoAccess::new();
601 let auth = TestAuth;
602 let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), auth);
603 smart.set_wire_hash_kind(HashKind::Sha1);
604 smart.command_list.push(RefCommand::new(
605 smart.zero_id.to_string(),
606 commit.id.to_string(),
607 "refs/heads/main".to_string(),
608 ));
609
610 let request_stream = Box::pin(futures::stream::once(async { Ok(Bytes::from(pack_bytes)) }));
612
613 let result_bytes = smart
615 .git_receive_pack_stream(request_stream)
616 .await
617 .expect("receive-pack should succeed");
618
619 let mut out = result_bytes.clone();
621 let (_c1, l1) = utils::read_pkt_line(&mut out);
622 assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");
623
624 let (_c2, l2) = utils::read_pkt_line(&mut out);
625 assert_eq!(
626 String::from_utf8(l2.to_vec()).unwrap(),
627 "ok refs/heads/main"
628 );
629
630 let (c3, l3) = utils::read_pkt_line(&mut out);
631 assert_eq!(c3, 4);
632 assert!(l3.is_empty());
633
634 assert_eq!(repo_access.updates_len(), 1);
636 assert!(repo_access.post_hook_called());
637 }
638
639 #[tokio::test]
640 async fn info_refs_rejects_sha256_with_sha1_refs() {
641 let _guard = set_hash_kind_for_test(HashKind::Sha1); let repo_access = TestRepoAccess::new(); let auth = TestAuth;
644 let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access, auth);
645 smart.set_wire_hash_kind(HashKind::Sha256); let res = smart.git_info_refs(ServiceType::UploadPack).await;
648 assert!(res.is_err(), "expected failure when hash lengths mismatch");
649
650 smart.set_wire_hash_kind(HashKind::Sha1);
651
652 let res = smart.git_info_refs(ServiceType::UploadPack).await;
653 assert!(
654 res.is_ok(),
655 "expected SHA1 refs to be accepted when wire is SHA1"
656 );
657 }
658}