1use bytes::{BufMut, Bytes, BytesMut};
2use std::collections::HashMap;
3use tokio_stream::wrappers::ReceiverStream;
4
5use super::core::{AuthenticationService, RepositoryAccess};
6use super::pack::PackGenerator;
7use super::types::ProtocolError;
8use super::types::{
9 COMMON_CAP_LIST, Capability, LF, NUL, PKT_LINE_END_MARKER, ProtocolStream, RECEIVE_CAP_LIST,
10 RefCommand, RefTypeEnum, SP, ServiceType, SideBand, TransportProtocol, UPLOAD_CAP_LIST,
11 ZERO_ID,
12};
13use super::utils::{add_pkt_line_string, build_smart_reply, read_pkt_line, read_until_white_space};
14
/// Server-side handler for Git's smart protocol: ref advertisement
/// (`git_info_refs`), fetch (`git_upload_pack`) and push
/// (`parse_receive_pack_commands` + `git_receive_pack_stream`).
///
/// Generic over the repository backend `R` and the authentication backend `A`.
pub struct SmartProtocol<R, A>
where
    R: RepositoryAccess,
    A: AuthenticationService,
{
    /// Transport the request arrived on; passed to `build_smart_reply` when
    /// formatting the ref advertisement.
    pub transport_protocol: TransportProtocol,
    /// Capabilities requested by the client, filled by `parse_capabilities`
    /// and consulted by `build_side_band_format`.
    pub capabilities: Vec<Capability>,
    /// Side-band selection. Initialised to `None` and not otherwise read or
    /// written by the methods in this file. NOTE(review): confirm where it is used.
    pub side_band: Option<SideBand>,
    /// Ref-update commands parsed from a receive-pack request. Each one is
    /// applied and reported on by `git_receive_pack_stream`.
    pub command_list: Vec<RefCommand>,

    /// Repository backend used for refs, object storage and ref updates.
    repo_storage: R,
    /// Authentication backend used by `authenticate_http` / `authenticate_ssh`.
    auth_service: A,
}
33
34impl<R, A> SmartProtocol<R, A>
35where
36 R: RepositoryAccess,
37 A: AuthenticationService,
38{
39 pub fn new(transport_protocol: TransportProtocol, repo_storage: R, auth_service: A) -> Self {
41 Self {
42 transport_protocol,
43 capabilities: Vec::new(),
44 side_band: None,
45 command_list: Vec::new(),
46 repo_storage,
47 auth_service,
48 }
49 }
50
51 pub async fn authenticate_http(
53 &self,
54 headers: &HashMap<String, String>,
55 ) -> Result<(), ProtocolError> {
56 self.auth_service.authenticate_http(headers).await
57 }
58
59 pub async fn authenticate_ssh(
61 &self,
62 username: &str,
63 public_key: &[u8],
64 ) -> Result<(), ProtocolError> {
65 self.auth_service
66 .authenticate_ssh(username, public_key)
67 .await
68 }
69
70 pub fn set_transport_protocol(&mut self, protocol: TransportProtocol) {
72 self.transport_protocol = protocol;
73 }
74
75 pub async fn git_info_refs(
77 &self,
78 service_type: ServiceType,
79 ) -> Result<BytesMut, ProtocolError> {
80 let refs =
81 self.repo_storage.get_repository_refs().await.map_err(|e| {
82 ProtocolError::repository_error(format!("Failed to get refs: {}", e))
83 })?;
84
85 let head_hash = refs
87 .iter()
88 .find(|(name, _)| {
89 name == "HEAD" || name.ends_with("/main") || name.ends_with("/master")
90 })
91 .map(|(_, hash)| hash.clone())
92 .unwrap_or_else(|| "0000000000000000000000000000000000000000".to_string());
93
94 let git_refs: Vec<super::types::GitRef> = refs
95 .into_iter()
96 .map(|(name, hash)| super::types::GitRef { name, hash })
97 .collect();
98
99 let cap_list = match service_type {
101 ServiceType::UploadPack => format!("{UPLOAD_CAP_LIST}{COMMON_CAP_LIST}"),
102 ServiceType::ReceivePack => format!("{RECEIVE_CAP_LIST}{COMMON_CAP_LIST}"),
103 };
104
105 let name = if head_hash == ZERO_ID {
107 "capabilities^{}"
108 } else {
109 "HEAD"
110 };
111 let pkt_line = format!("{head_hash}{SP}{name}{NUL}{cap_list}{LF}");
112 let mut ref_list = vec![pkt_line];
113
114 for git_ref in git_refs {
115 let pkt_line = format!("{}{}{}{}", git_ref.hash, SP, git_ref.name, LF);
116 ref_list.push(pkt_line);
117 }
118
119 let pkt_line_stream =
120 build_smart_reply(self.transport_protocol, &ref_list, service_type.to_string());
121 tracing::debug!("git_info_refs, return: --------> {:?}", pkt_line_stream);
122 Ok(pkt_line_stream)
123 }
124
125 pub async fn git_upload_pack(
127 &mut self,
128 upload_request: Bytes,
129 ) -> Result<(ReceiverStream<Vec<u8>>, BytesMut), ProtocolError> {
130 let mut upload_request = upload_request;
131 let mut want: Vec<String> = Vec::new();
132 let mut have: Vec<String> = Vec::new();
133 let mut last_common_commit = String::new();
134
135 let mut read_first_line = false;
136 loop {
137 let (bytes_take, pkt_line) = read_pkt_line(&mut upload_request);
138
139 if bytes_take == 0 {
140 break;
141 }
142
143 if pkt_line.is_empty() {
144 break;
145 }
146
147 let mut pkt_line = pkt_line;
148 let command = read_until_white_space(&mut pkt_line);
149
150 match command.as_str() {
151 "want" => {
152 let hash = read_until_white_space(&mut pkt_line);
153 want.push(hash);
154 if !read_first_line {
155 let cap_str = String::from_utf8_lossy(&pkt_line).to_string();
156 self.parse_capabilities(&cap_str);
157 read_first_line = true;
158 }
159 }
160 "have" => {
161 let hash = read_until_white_space(&mut pkt_line);
162 have.push(hash);
163 }
164 "done" => {
165 break;
166 }
167 _ => {
168 tracing::warn!("Unknown upload-pack command: {}", command);
169 }
170 }
171 }
172
173 let mut protocol_buf = BytesMut::new();
174
175 let pack_generator = PackGenerator::new(&self.repo_storage);
177
178 if have.is_empty() {
179 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
181 let pack_stream = pack_generator.generate_full_pack(want).await?;
182 return Ok((pack_stream, protocol_buf));
183 }
184
185 for hash in &have {
187 let exists = self.repo_storage.commit_exists(hash).await.map_err(|e| {
188 ProtocolError::repository_error(format!("Failed to check commit existence: {}", e))
189 })?;
190 if exists {
191 add_pkt_line_string(&mut protocol_buf, format!("ACK {hash} common\n"));
192 if last_common_commit.is_empty() {
193 last_common_commit = hash.clone();
194 }
195 }
196 }
197
198 if last_common_commit.is_empty() {
199 add_pkt_line_string(&mut protocol_buf, String::from("NAK\n"));
201 let pack_stream = pack_generator.generate_full_pack(want).await?;
202 return Ok((pack_stream, protocol_buf));
203 }
204
205 add_pkt_line_string(
207 &mut protocol_buf,
208 format!("ACK {last_common_commit} ready\n"),
209 );
210 protocol_buf.put(&PKT_LINE_END_MARKER[..]);
211
212 add_pkt_line_string(&mut protocol_buf, format!("ACK {last_common_commit} \n"));
213
214 let pack_stream = pack_generator.generate_incremental_pack(want, have).await?;
215
216 Ok((pack_stream, protocol_buf))
217 }
218
219 pub fn parse_receive_pack_commands(&mut self, mut protocol_bytes: Bytes) {
221 loop {
222 let (bytes_take, pkt_line) = read_pkt_line(&mut protocol_bytes);
223
224 if bytes_take == 0 {
225 break;
226 }
227
228 if pkt_line.is_empty() {
229 break;
230 }
231
232 let ref_command = self.parse_ref_command(&mut pkt_line.clone());
233 self.command_list.push(ref_command);
234 }
235 }
236
237 pub async fn git_receive_pack_stream(
239 &mut self,
240 data_stream: ProtocolStream,
241 ) -> Result<Bytes, ProtocolError> {
242 let mut pack_data = BytesMut::new();
244 let mut stream = data_stream;
245
246 while let Some(chunk_result) = futures::StreamExt::next(&mut stream).await {
247 let chunk = chunk_result
248 .map_err(|e| ProtocolError::invalid_request(&format!("Stream error: {}", e)))?;
249 pack_data.extend_from_slice(&chunk);
250 }
251
252 let pack_generator = PackGenerator::new(&self.repo_storage);
254
255 let (commits, trees, blobs) = pack_generator.unpack_stream(pack_data.freeze()).await?;
257
258 self.repo_storage
260 .handle_pack_objects(commits, trees, blobs)
261 .await
262 .map_err(|e| {
263 ProtocolError::repository_error(format!("Failed to store pack objects: {}", e))
264 })?;
265
266 let mut report_status = BytesMut::new();
268 add_pkt_line_string(&mut report_status, "unpack ok\n".to_owned());
269
270 let mut default_exist = self.repo_storage.has_default_branch().await.map_err(|e| {
271 ProtocolError::repository_error(format!("Failed to check default branch: {}", e))
272 })?;
273
274 for command in &mut self.command_list {
276 if command.ref_type == RefTypeEnum::Tag {
277 let old_hash = if command.old_hash == ZERO_ID {
280 None
281 } else {
282 Some(command.old_hash.as_str())
283 };
284 if let Err(e) = self
285 .repo_storage
286 .update_reference(&command.ref_name, old_hash, &command.new_hash)
287 .await
288 {
289 command.failed(e.to_string());
290 }
291 } else {
292 if !default_exist {
294 command.default_branch = true;
295 default_exist = true;
296 }
297 let old_hash = if command.old_hash == ZERO_ID {
299 None
300 } else {
301 Some(command.old_hash.as_str())
302 };
303 if let Err(e) = self
304 .repo_storage
305 .update_reference(&command.ref_name, old_hash, &command.new_hash)
306 .await
307 {
308 command.failed(e.to_string());
309 }
310 }
311 add_pkt_line_string(&mut report_status, command.get_status());
312 }
313
314 self.repo_storage.post_receive_hook().await.map_err(|e| {
316 ProtocolError::repository_error(format!("Post-receive hook failed: {}", e))
317 })?;
318
319 report_status.put(&PKT_LINE_END_MARKER[..]);
320 Ok(report_status.freeze())
321 }
322
323 pub fn build_side_band_format(&self, from_bytes: BytesMut, length: usize) -> BytesMut {
325 let mut to_bytes = BytesMut::new();
326 if self.capabilities.contains(&Capability::SideBand)
327 || self.capabilities.contains(&Capability::SideBand64k)
328 {
329 let length = length + 5;
330 to_bytes.put(Bytes::from(format!("{length:04x}")));
331 to_bytes.put_u8(SideBand::PackfileData.value());
332 to_bytes.put(from_bytes);
333 } else {
334 to_bytes.put(from_bytes);
335 }
336 to_bytes
337 }
338
339 pub fn parse_capabilities(&mut self, cap_str: &str) {
341 for cap in cap_str.split_whitespace() {
342 if let Ok(capability) = cap.parse::<Capability>() {
343 self.capabilities.push(capability);
344 }
345 }
346 }
347
348 pub fn parse_ref_command(&self, pkt_line: &mut Bytes) -> RefCommand {
350 let old_id = read_until_white_space(pkt_line);
351 let new_id = read_until_white_space(pkt_line);
352 let ref_name = read_until_white_space(pkt_line);
353 let _capabilities = String::from_utf8_lossy(&pkt_line[..]).to_string();
354
355 RefCommand::new(old_id, new_id, ref_name)
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::object::blob::Blob;
    use crate::internal::object::commit::Commit;
    use crate::internal::object::signature::{Signature, SignatureType};
    use crate::internal::object::tree::{Tree, TreeItem, TreeItemMode};
    use crate::internal::pack::{encode::PackEncoder, entry::Entry};
    use crate::protocol::types::{RefCommand, ZERO_ID};
    use crate::protocol::utils;
    use async_trait::async_trait;
    use bytes::Bytes;
    use std::sync::{
        Arc, Mutex,
        atomic::{AtomicBool, Ordering},
    };
    use tokio::sync::mpsc;

    /// `(ref_name, old_hash, new_hash)` as passed to `update_reference`.
    type UpdateRecord = (String, Option<String>, String);
    type UpdateList = Vec<UpdateRecord>;
    type SharedUpdates = Arc<Mutex<UpdateList>>;

    /// In-memory `RepositoryAccess` double that records ref updates, pack
    /// stores and whether the post-receive hook ran.
    #[derive(Clone)]
    struct TestRepoAccess {
        updates: SharedUpdates,
        stored_count: Arc<Mutex<usize>>,
        default_branch_exists: Arc<Mutex<bool>>,
        post_called: Arc<AtomicBool>,
    }

    impl TestRepoAccess {
        fn new() -> Self {
            Self {
                updates: Arc::new(Mutex::new(vec![])),
                stored_count: Arc::new(Mutex::new(0)),
                default_branch_exists: Arc::new(Mutex::new(false)),
                post_called: Arc::new(AtomicBool::new(false)),
            }
        }

        fn updates_len(&self) -> usize {
            self.updates.lock().unwrap().len()
        }

        fn post_hook_called(&self) -> bool {
            self.post_called.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl RepositoryAccess for TestRepoAccess {
        async fn get_repository_refs(&self) -> Result<Vec<(String, String)>, ProtocolError> {
            Ok(vec![
                (
                    "HEAD".to_string(),
                    "0000000000000000000000000000000000000000".to_string(),
                ),
                (
                    "refs/heads/main".to_string(),
                    "1111111111111111111111111111111111111111".to_string(),
                ),
            ])
        }

        async fn has_object(&self, _object_hash: &str) -> Result<bool, ProtocolError> {
            Ok(true)
        }

        async fn get_object(&self, _object_hash: &str) -> Result<Vec<u8>, ProtocolError> {
            Ok(vec![])
        }

        async fn store_pack_data(&self, _pack_data: &[u8]) -> Result<(), ProtocolError> {
            *self.stored_count.lock().unwrap() += 1;
            Ok(())
        }

        async fn update_reference(
            &self,
            ref_name: &str,
            old_hash: Option<&str>,
            new_hash: &str,
        ) -> Result<(), ProtocolError> {
            self.updates.lock().unwrap().push((
                ref_name.to_string(),
                old_hash.map(|s| s.to_string()),
                new_hash.to_string(),
            ));
            Ok(())
        }

        async fn get_objects_for_pack(
            &self,
            _wants: &[String],
            _haves: &[String],
        ) -> Result<Vec<String>, ProtocolError> {
            Ok(vec![])
        }

        /// Reports the current flag, then marks a default branch as existing,
        /// so only the first call returns `false`.
        async fn has_default_branch(&self) -> Result<bool, ProtocolError> {
            let mut exists = self.default_branch_exists.lock().unwrap();
            let current = *exists;
            *exists = true;
            Ok(current)
        }

        async fn post_receive_hook(&self) -> Result<(), ProtocolError> {
            self.post_called.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Authentication double that accepts everything.
    struct TestAuth;

    #[async_trait]
    impl AuthenticationService for TestAuth {
        async fn authenticate_http(
            &self,
            _headers: &std::collections::HashMap<String, String>,
        ) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn authenticate_ssh(
            &self,
            _username: &str,
            _public_key: &[u8],
        ) -> Result<(), ProtocolError> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_receive_pack_stream_status_report() {
        let blob1 = Blob::from_content("hello");
        let blob2 = Blob::from_content("world");

        let item1 = TreeItem::new(TreeItemMode::Blob, blob1.id, "hello.txt".to_string());
        let item2 = TreeItem::new(TreeItemMode::Blob, blob2.id, "world.txt".to_string());
        let tree = Tree::from_tree_items(vec![item1, item2]).unwrap();

        let author = Signature::new(
            SignatureType::Author,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let committer = Signature::new(
            SignatureType::Committer,
            "tester".to_string(),
            "tester@example.com".to_string(),
        );
        let commit = Commit::new(author, committer, tree.id, vec![], "init commit");

        let (pack_tx, mut pack_rx) = mpsc::channel(1024);
        let (entry_tx, entry_rx) = mpsc::channel(1024);
        let mut encoder = PackEncoder::new(4, 10, pack_tx);

        // Keep the handle so an encoder failure fails the test with a clear
        // message instead of only truncating the pack.
        let encoder_task = tokio::spawn(async move {
            if let Err(e) = encoder.encode(entry_rx).await {
                panic!("Failed to encode pack: {}", e);
            }
        });

        let commit_clone = commit.clone();
        let tree_clone = tree.clone();
        let blob1_clone = blob1.clone();
        let blob2_clone = blob2.clone();
        tokio::spawn(async move {
            let _ = entry_tx.send(Entry::from(commit_clone)).await;
            let _ = entry_tx.send(Entry::from(tree_clone)).await;
            let _ = entry_tx.send(Entry::from(blob1_clone)).await;
            let _ = entry_tx.send(Entry::from(blob2_clone)).await;
        });

        let mut pack_bytes: Vec<u8> = Vec::new();
        while let Some(chunk) = pack_rx.recv().await {
            pack_bytes.extend_from_slice(&chunk);
        }
        encoder_task.await.expect("pack encoder task failed");

        let repo_access = TestRepoAccess::new();
        let auth = TestAuth;
        let mut smart = SmartProtocol::new(TransportProtocol::Http, repo_access.clone(), auth);
        smart.command_list.push(RefCommand::new(
            ZERO_ID.to_string(),
            commit.id.to_string(),
            "refs/heads/main".to_string(),
        ));

        let request_stream = Box::pin(futures::stream::once(async { Ok(Bytes::from(pack_bytes)) }));

        let result_bytes = smart
            .git_receive_pack_stream(request_stream)
            .await
            .expect("receive-pack should succeed");

        let mut out = result_bytes.clone();
        let (_c1, l1) = utils::read_pkt_line(&mut out);
        assert_eq!(String::from_utf8(l1.to_vec()).unwrap(), "unpack ok\n");

        let (_c2, l2) = utils::read_pkt_line(&mut out);
        assert_eq!(
            String::from_utf8(l2.to_vec()).unwrap(),
            "ok refs/heads/main"
        );

        // The report ends with a flush-pkt.
        let (c3, l3) = utils::read_pkt_line(&mut out);
        assert_eq!(c3, 4);
        assert!(l3.is_empty());

        assert_eq!(repo_access.updates_len(), 1);
        assert!(repo_access.post_hook_called());
    }
}