1use bytes::{BufMut, Bytes, BytesMut};
2use std::collections::HashMap;
3use tokio_stream::wrappers::ReceiverStream;
4
5use super::core::{AuthenticationService, RepositoryAccess};
6use super::pack::PackGenerator;
7use super::types::ProtocolError;
8use super::types::{
9 COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolStream, RECEIVE_CAP_LIST,
10 RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol, UPLOAD_CAP_LIST,
11 ZERO_ID,
12};
13use super::utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space};
14
15pub struct SmartProtocol<R, A>
20where
21 R: RepositoryAccess,
22 A: AuthenticationService,
23{
24 pub transport_protocol: TransportProtocol,
25 pub capabilities: Vec<Capability>,
26 pub side_band: Option<SideBand>,
27 pub command_list: Vec<RefCommand>,
28
29 repo_storage: R,
31 auth_service: A,
32}
33
34impl<R, A> SmartProtocol<R, A>
35where
36 R: RepositoryAccess,
37 A: AuthenticationService,
38{
39 pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
41 Self {
42 transport_protocol,
43 capabilities: Vec::new(),
44 side_band: None,
45 command_list: Vec::new(),
46 repo_storage,
47 auth_service,
48 }
49 }
50
51 pub async fn authenticate_http(
53 &self,
54 headers: &HashMap<String, String>,
55 ) -> Result<(), ProtocolError> {
56 self.auth_service.authenticate_http(headers).await
57 }
58
59 pub async fn authenticate_ssh(
61 &self,
62 username: &str,
63 public_key: &[u8],
64 ) -> Result<(), ProtocolError> {
65 self.auth_service
66 .authenticate_ssh(username, public_key)
67 .await
68 }
69
70 pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
72 self.transport_protocol = protocol;
73 }
74
75 pub async fn git_info_refs(
77 &self,
78 service_type: ServiceType,
79 ) -> Result<BytesMut, ProtocolError> {
80 let refs =
81 self.repo_storage.get_repository_refs().await.map_err(|e| {
82 ProtocolError::repository_error(format!("Failed to get refs: {}", e))
83 })?;
84
85 let head_hash = refs
87 .iter()
88 .find(|(name, _)| {
89 name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
90 })
91 .map(|(_, hash)| hash.clone())
92 .unwrap_or_else(|| "0000000000000000000000000000000000000000".to_string());
93
94 let git_refs: Vec<super::types::GitRef> = refs
95 .into_iter()
96 .map(|(name, hash)| super::types::GitRef { name, hash })
97 .collect();
98
99 let cap_list = match service_type {
101 ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}"),
102 ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}"),
103 };
104
105 let name = if head_hash == ZERO_ID {
107 "capabilities^{}"
108 } else {
109 "HEAD"
110 };
111 let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
112 let mut ref_list = vec![pkt_line];
113
114 for git_ref in git_refs {
115 let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
116 ref_list.push(pkt_line);
117 }
118
119 let pkt_line_stream =
120 build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
121 tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
122 Ok(pkt_line_stream)
123 }
124
125 pub async fn git_upload_pack(
127 &mut self,
128 upload_request: Bytes,
129 ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
130 let mut upload_request = upload_request;
131 let mut want: Vec<String> = Vec::new();
132 let mut have: Vec<String> = Vec::new();
133 let mut last_common_commit = String::new();
134
135 let mut read_first_line = false;
136 loop {
137 let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
138
139 if bytes_take == 0 {
140 break;
141 }
142
143 if pkt_line.is_empty() {
144 break;
145 }
146
147 let mut pkt_line = pkt_line;
148 let command = read_until_white_space(&mut pkt_line);
149
150 match command.as_str() {
151 "want" => {
152 let hash = read_until_white_space(&mut pkt_line);
153 want.push(hash);
154 if !read_first_line {
155 let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
156 self.parse_capabilities(&cap_str);
157 read_first_line = true;
158 }
159 }
160 "have" => {
161 let hash = read_until_white_space(&mut pkt_line);
162 have.push(hash);
163 }
164 "done" => {
165 break;
166 }
167 _ => {
168 tracing::warn!("Unknown upload-pack command: {}", command);
169 }
170 }
171 }
172
173 let mut protocol_buf = BytesMut::new();
174
175 let pack_generator = PackGenerator::new(&self.repo_storage);
177
178 if have.is_empty() {
179 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
181 let pack_stream = pack_generator.generate_full_pack(want).await?;
182 return Ok((pack_stream, protocol_buf));
183 }
184
185 for hash in &have {
187 let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
188 ProtocolError::repository_error(format!("Failed to check commit existence: {}", e))
189 })?;
190 if exists {
191 add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
192 if last_common_commit.is_empty() {
193 last_common_commit = hash.clone();
194 }
195 }
196 }
197
198 if last_common_commit.is_empty() {
199 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
201 let pack_stream = pack_generator.generate_full_pack(want).await?;
202 return Ok((pack_stream, protocol_buf));
203 }
204
205 add_pkt_line_string(
207 &mut protocol_buf,
208 format!("ACK {last_common_commit} ready\n"),
209 );
210 protocol_buf.put(&PKT_LINE_END_MARKER[..]);
211
212 add_pkt_line_string(&mut protocol_buf, format!("ACK {last_common_commit} \n"));
213
214 let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
215
216 Ok((pack_stream, protocol_buf))
217 }
218
219 pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
221 loop {
222 let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
223
224 if bytes_take == 0 {
225 break;
226 }
227
228 if pkt_line.is_empty() {
229 break;
230 }
231
232 let ref_command = self.parse_ref_command(&mut pkt_line.clone());
233 self.command_list.push(ref_command);
234 }
235 }
236
237 pub async fn git_receive_pack_stream(
239 &mut self,
240 data_stream: ProtocolStream,
241 ) -> Result<Bytes, ProtocolError> {
242 let mut pack_data = BytesMut::new();
244 let mut stream = data_stream;
245
246 while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
247 let chunk = chunk_result
248 .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {}", e)))?;
249 pack_data.extend_from_slice(&chunk);
250 }
251
252 let pack_generator = PackGenerator::new(&self.repo_storage);
254
255 let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data.freeze()).await?;
257
258 self.repo_storage
260 .handle_pack_objects(commits, trees, blobs)
261 .await
262 .map_err(|e| {
263 ProtocolError::repository_error(format!("Failed to store pack objects: {}", e))
264 })?;
265
266 let mut report_status = BytesMut::new();
268 add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
269
270 let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
271 ProtocolError::repository_error(format!("Failed to check default branch: {}", e))
272 })?;
273
274 for command in &mut self.command_list {
276 if command.ref_type == RefTypeEnum::Tag {
277 let old_hash = if command.old_hash == ZERO_ID {
280 None
281 } else {
282 Some(command.old_hash.as_str())
283 };
284 if let Err(e) = self
285 .repo_storage
286 .update_reference(&command.ref_name, old_hash, &command.new_hash)
287 .await
288 {
289 command.failed(e.to_string());
290 }
291 } else {
292 if !default_exist {
294 command.default_branch = true;
295 default_exist = true;
296 }
297 let old_hash = if command.old_hash == ZERO_ID {
299 None
300 } else {
301 Some(command.old_hash.as_str())
302 };
303 if let Err(e) = self
304 .repo_storage
305 .update_reference(&command.ref_name, old_hash, &command.new_hash)
306 .await
307 {
308 command.failed(e.to_string());
309 }
310 }
311 add_pkt_line_string(&mut report_status, command.get_status());
312 }
313
314 self.repo_storage.post_receive_hook().await.map_err(|e| {
316 ProtocolError::repository_error(format!("Post-receive hook failed: {}", e))
317 })?;
318
319 report_status.put(&PKT_LINE_END_MARKER[..]);
320 Ok(report_status.freeze())
321 }
322
323 pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
325 let mut to_bytes = BytesMut::new();
326 if self.capabilities.contains(&Capability::SideBand)
327 || self.capabilities.contains(&Capability::SideBand64k)
328 {
329 let length = length + 5;
330 to_bytes.put(Bytes::from(format!("{length:04x}")));
331 to_bytes.put_u8(SideBand::PackfileData.value());
332 to_bytes.put(from_bytes);
333 } else {
334 to_bytes.put(from_bytes);
335 }
336 to_bytes
337 }
338
339 pub fn parse_capabilities(&mut self, cap_str: &str) {
341 for cap in cap_str.split_whitespace() {
342 if let Ok(capability) = cap.parse::<Capability>() {
343 self.capabilities.push(capability);
344 }
345 }
346 }
347
348 pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
350 let old_id = read_until_white_space(pkt_line);
351 let new_id = read_until_white_space(pkt_line);
352 let ref_name = read_until_white_space(pkt_line);
353 let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
354
355 RefCommand::new(old_id, new_id, ref_name)
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::metadata::{EntryMeta, MetaAttached};
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use crate::internal::pack::{encode::PackEncoder, entry::Entry};
    use crate::protocol::types::{RefCommand, ZERO_ID};
    use crate::protocol::utils;
    use async_trait::async_trait;
    use bytes::Bytes;
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };
    use tokio::sync::mpsc;

    /// `(ref name, expected old hash, new hash)` as passed to `update_reference`.
    type UpdateRecord = (String, Option<String>, String);
    type UpdateList = Vec<UpdateRecord>;
    type SharedUpdates = Arc<Mutex<UpdateList>>;

    /// In-memory repository double. Clones share their state, so the test can
    /// hand one clone to the protocol handler and inspect another afterwards.
    #[derive(Clone)]
    struct TestRepoAccess {
        updates: SharedUpdates,
        stored_count: Arc<Mutex<usize>>,
        default_branch_exists: Arc<Mutex<bool>>,
        post_called: Arc<AtomicBool>,
    }

    impl TestRepoAccess {
        fn new() -> Self {
            Self {
                updates: Arc::default(),
                stored_count: Arc::default(),
                default_branch_exists: Arc::default(),
                post_called: Arc::default(),
            }
        }

        /// Number of `update_reference` calls recorded so far.
        fn updates_len(&self) -> usize {
            self.updates.lock().unwrap().len()
        }

        /// Whether `post_receive_hook` has been invoked.
        fn post_hook_called(&self) -> bool {
            self.post_called.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RepositoryAccess for TestRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            let zero = "0000000000000000000000000000000000000000";
            let main = "1111111111111111111111111111111111111111";
            Ok(vec![
                ("HEAD".to_string(), zero.to_string()),
                ("refs/heads/main".to_string(), main.to_string()),
            ])
        }

        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            *self.stored_count.lock().unwrap() += 1;
            Ok(())
        }

        async fn update_reference(
            &self,
            ref_name: &str,
            old_hash: Option<&str>,
            new_hash: &str,
        ) -> Result<(), ProtocolError> {
            let record = (
                ref_name.to_string(),
                old_hash.map(|s| s.to_string()),
                new_hash.to_string(),
            );
            self.updates.lock().unwrap().push(record);
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }

        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            // Report the previous state and flip it, so only the first query sees `false`.
            let mut flag = self.default_branch_exists.lock().unwrap();
            let previously = *flag;
            *flag = true;
            Ok(previously)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            self.post_called.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Authentication double that accepts every request.
    struct TestAuth;

    #[async_trait]
    impl AuthenticationService for TestAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    /// Signature of the fixed test identity with the given role.
    fn tester(kind: SignatureType) -> Signature {
        Signature::new(kind, "tester".to_string(), "tester@example.com".to_string())
    }

    /// Runs `entries` through a `PackEncoder` and collects the encoded pack bytes.
    async fn encode_pack(entries: Vec<Entry>) -> Vec<u8> {
        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
        let (entry_tx, entry_rx) = mpsc::channel(1024);
        let mut encoder = PackEncoder::new(4, 10, pack_tx);

        tokio::spawn(async move {
            if let Err(e) = encoder.encode(entry_rx).await {
                panic!("Failed to encode pack: {}", e);
            }
        });

        tokio::spawn(async move {
            for inner in entries {
                let _ = entry_tx
                    .send(MetaAttached {
                        inner,
                        meta: EntryMeta::new(),
                    })
                    .await;
            }
        });

        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = pack_rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        pack_bytes
    }

    #[tokio::test]
    async fn test_receive_pack_stream_status_report() {
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");
        let tree = Tree::from_tree_items(vec![
            TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string()),
            TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string()),
        ])
        .unwrap();
        let commit = Commit::new(
            tester(SignatureType::Author),
            tester(SignatureType::Committer),
            tree.id,
            vec![],
            "init commit",
        );

        let pack_bytes = encode_pack(vec![
            Entry::from(commit.clone()),
            Entry::from(tree.clone()),
            Entry::from(blob1.clone()),
            Entry::from(blob2.clone()),
        ])
        .await;

        let repo_access = TestRepoAccess::new();
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), TestAuth);
        smart.command_list.push(RefCommand::new(
            ZERO_ID.to_string(),
            commit.id.to_string(),
            "refs/heads/main".to_string(),
        ));

        let request_stream = Box::pin(futures::stream::once(async { Ok(Bytes::from(pack_bytes)) }));

        let result_bytes = smart
            .git_receive_pack_stream(request_stream)
            .await
            .expect("receive-pack should succeed");

        let mut out = result_bytes.clone();
        let (_c1, l1) = utils::read_pkt_line(&mut out);
        assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");

        let (_c2, l2) = utils::read_pkt_line(&mut out);
        assert_eq!(
            String::from_utf8(l2.to_vec()).unwrap(),
            "ok refs/heads/main"
        );

        let (c3, l3) = utils::read_pkt_line(&mut out);
        assert_eq!(c3, 4);
        assert!(l3.is_empty());

        assert_eq!(repo_access.updates_len(), 1);
        assert!(repo_access.post_hook_called());
    }
}