1use std::io::ErrorKind;
12use std::num::NonZeroU32;
13use std::path::PathBuf;
14use std::sync::Arc;
15
16use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
17use tracing::{debug, error};
18
19use crate::object_store::ObjectStore;
20use crate::url::{RemoteUrl, StorageEngine};
21
22pub mod backend;
23pub(crate) mod bundle_uri;
24pub(crate) mod capabilities;
25pub mod fetch;
26pub(crate) mod list;
27pub(crate) mod option;
28pub mod push;
29pub mod tracing_init;
30
31use self::fetch::{FetchedRefs, fetch_batch};
32use self::option::{OptionEffect, handle_option};
33use self::push::{PushOutcome, push_batch};
34use self::tracing_init::ReloadHandle;
35
36async fn write_push_outcomes<W>(
42 writer: &mut W,
43 outcomes: &[PushOutcome],
44) -> Result<(), std::io::Error>
45where
46 W: AsyncWrite + Unpin,
47{
48 for outcome in outcomes {
49 writer
50 .write_all(outcome.to_protocol_line().as_bytes())
51 .await?;
52 }
53 Ok(())
54}
55
56pub(crate) fn append_source_chain<E: std::error::Error + ?Sized>(msg: &mut String, err: &E) {
81 let mut next = err.source();
82 while let Some(src) = next {
83 let rendered = src.to_string();
87 if !msg.ends_with(&rendered) {
88 msg.push_str(": ");
89 msg.push_str(&rendered);
90 }
91 next = src.source();
92 }
93}
94
/// Any failure that ends a remote-helper session driven by `run`.
///
/// Every wrapping variant renders its inner error inline through `{0}`. The
/// `#[from]` fields also expose that inner error as `source()`. Code that walks
/// the source chain should therefore de-duplicate already-inlined text, as
/// `append_source_chain` does.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// Reading a command line or writing a reply failed.
    #[error("protocol I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Handling a `list` / `list for-push` command failed.
    #[error("list failed: {0}")]
    List(#[from] list::ListError),

    /// Executing a batch of `fetch` commands failed.
    #[error("fetch failed: {0}")]
    Fetch(#[from] fetch::FetchError),

    /// Executing a batch of `push` commands failed.
    #[error("push failed: {0}")]
    Push(#[from] push::PushError),

    /// A line that `parse_command` does not recognise; holds the raw line
    /// (rendered with `Debug` so control characters stay visible).
    #[error("invalid command: {0:?}")]
    InvalidCommand(String),

    /// Backend resolution failed (raised from the `backend` module; not used
    /// directly in this file).
    #[error("backend resolution failed: {0}")]
    Backend(#[from] backend::BackendError),

    /// Handling a `bundle-uri` command failed.
    #[error("bundle-uri failed: {0}")]
    BundleUri(#[from] bundle_uri::BundleUriError),
}
126
127impl ProtocolError {
128 #[must_use]
132 pub fn is_broken_pipe(&self) -> bool {
133 matches!(self, Self::Io(e)
134 if matches!(e.kind(), ErrorKind::BrokenPipe | ErrorKind::WriteZero))
135 }
136}
137
/// One parsed line of the helper's command stream, as produced by `parse_command`.
#[derive(Debug, PartialEq, Eq)]
enum Command {
    /// The exact line `capabilities`.
    Capabilities,
    /// The exact line `bundle-uri`.
    BundleUri,
    /// The exact line `list` (`for_push == false`) or `list for-push` (`true`).
    List { for_push: bool },
    /// `option <args>`; holds everything after the `option ` prefix verbatim.
    Option(String),
    /// `fetch <args>`; holds everything after the `fetch ` prefix verbatim.
    Fetch(String),
    /// `push <args>`; holds everything after the `push ` prefix verbatim.
    Push(String),
    /// A blank line; `run` answers it by flushing the pending fetch/push batch.
    Empty,
}
149
/// Kind of command batch that `BatchState` is currently accumulating.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Collecting `fetch` command arguments.
    Fetch,
    /// Collecting `push` command arguments.
    Push,
}
159
/// Session-wide resources handed to the fetch/push batch implementations
/// (both the bundle and the packchain engines receive a `&BatchCtx`).
pub(crate) struct BatchCtx {
    /// Object store backing the remote.
    pub(crate) store: Arc<dyn ObjectStore>,
    /// Key prefix taken from the remote URL (`RemoteUrl::prefix`), if any.
    pub(crate) prefix: Option<Arc<str>>,
    /// Directory passed to `run` as `repo_dir`; presumably the local git
    /// repository being fetched into / pushed from — confirm with the caller.
    pub(crate) repo_dir: Arc<PathBuf>,
}
171
172struct BatchState {
179 mode: Option<Mode>,
180 fetch_cmds: Vec<String>,
181 push_cmds: Vec<String>,
182}
183
184impl BatchState {
185 fn new() -> Self {
186 Self {
187 mode: None,
188 fetch_cmds: Vec::new(),
189 push_cmds: Vec::new(),
190 }
191 }
192
193 fn accumulate(&mut self, incoming: Mode, cmd: String) {
196 if self.mode != Some(incoming) {
197 match incoming {
198 Mode::Fetch => self.push_cmds.clear(),
199 Mode::Push => self.fetch_cmds.clear(),
200 }
201 self.mode = Some(incoming);
202 }
203 match incoming {
204 Mode::Fetch => {
205 debug_assert!(
210 self.push_cmds.is_empty(),
211 "push_cmds must be empty when accumulating a Fetch command",
212 );
213 self.fetch_cmds.push(cmd);
214 }
215 Mode::Push => {
216 debug_assert!(
217 self.fetch_cmds.is_empty(),
218 "fetch_cmds must be empty when accumulating a Push command",
219 );
220 self.push_cmds.push(cmd);
221 }
222 }
223 }
224
225 fn take_pending(&mut self) -> Option<(Mode, Vec<String>)> {
231 match self.mode {
232 Some(Mode::Fetch) if !self.fetch_cmds.is_empty() => {
233 self.mode = None;
234 Some((Mode::Fetch, std::mem::take(&mut self.fetch_cmds)))
235 }
236 Some(Mode::Push) if !self.push_cmds.is_empty() => {
237 self.mode = None;
238 Some((Mode::Push, std::mem::take(&mut self.push_cmds)))
239 }
240 _ => None,
241 }
242 }
243}
244
245fn parse_command(line: &str) -> Option<Command> {
246 let trimmed = line.trim_end_matches(['\r', '\n']);
247 if trimmed.is_empty() {
248 return Some(Command::Empty);
249 }
250 if trimmed == "capabilities" {
251 return Some(Command::Capabilities);
252 }
253 if trimmed == "bundle-uri" {
254 return Some(Command::BundleUri);
255 }
256 if trimmed == "list for-push" {
258 return Some(Command::List { for_push: true });
259 }
260 if trimmed == "list" {
261 return Some(Command::List { for_push: false });
262 }
263 if let Some(rest) = trimmed.strip_prefix("option ") {
264 return Some(Command::Option(rest.to_owned()));
265 }
266 if let Some(rest) = trimmed.strip_prefix("fetch ") {
267 return Some(Command::Fetch(rest.to_owned()));
268 }
269 if let Some(rest) = trimmed.strip_prefix("push ") {
270 return Some(Command::Push(rest.to_owned()));
271 }
272 None
273}
274
/// Per-session values that `flush_batch` needs to execute a batch; built once
/// in `run` and borrowed for every flush.
struct FlushCtx<'a> {
    /// Resources shared with the batch implementations.
    batch_ctx: &'a BatchCtx,
    /// The remote being served; its `kind()` is passed to the bundle push path.
    remote: &'a RemoteUrl,
    /// Storage engine selecting the bundle or packchain implementation.
    engine: StorageEngine,
    /// `zip` flag from the remote URL; only the bundle push path receives it.
    zip: bool,
    /// State cloned into every fetch batch (the same value is reused across
    /// batches within one session).
    fetched_refs: &'a FetchedRefs,
}
285
286async fn flush_batch<W>(
291 flush: &FlushCtx<'_>,
292 batch: &mut BatchState,
293 depth: &mut Option<NonZeroU32>,
294 writer: &mut W,
295) -> Result<(), ProtocolError>
296where
297 W: AsyncWrite + Unpin,
298{
299 if let Some((mode, cmds)) = batch.take_pending() {
300 match (mode, flush.engine) {
301 (Mode::Fetch, StorageEngine::Bundle) => {
302 fetch_batch(
303 flush.batch_ctx,
304 cmds,
305 flush.fetched_refs.clone(),
306 depth.take(),
307 )
308 .await?;
309 }
310 (Mode::Fetch, StorageEngine::Packchain) => {
311 crate::packchain::fetch::fetch_batch(
312 flush.batch_ctx,
313 cmds,
314 flush.fetched_refs.clone(),
315 depth.take(),
316 )
317 .await?;
318 }
319 (Mode::Push, StorageEngine::Bundle) => {
320 let outcomes = push_batch(
321 flush.batch_ctx,
322 flush.remote.kind(),
323 flush.zip,
324 flush.engine,
325 cmds,
326 )
327 .await?;
328 write_push_outcomes(writer, &outcomes).await?;
329 }
330 (Mode::Push, StorageEngine::Packchain) => {
331 let outcomes =
332 crate::packchain::push::push_batch(flush.batch_ctx, flush.engine, cmds).await?;
333 write_push_outcomes(writer, &outcomes).await?;
334 }
335 }
336 }
337 writer.write_all(b"\n").await?;
338 writer.flush().await?;
339 Ok(())
340}
341
342pub async fn run<R, W>(
360 remote: RemoteUrl,
361 store: Arc<dyn ObjectStore>,
362 engine: StorageEngine,
363 reader: R,
364 mut writer: W,
365 reload: Option<ReloadHandle>,
366 repo_dir: PathBuf,
367) -> Result<(), ProtocolError>
368where
369 R: AsyncBufRead + Unpin,
370 W: AsyncWrite + Unpin,
371{
372 let mut lines = reader.lines();
378 let fetched_refs = FetchedRefs::new();
379 let mut batch = BatchState::new();
380 let mut depth: Option<NonZeroU32> = None;
385 let zip = remote.flags().zip;
386 let advertise_bundle_uri =
393 matches!(engine, StorageEngine::Packchain) && remote.flags().bundle_uri;
394 let ctx = BatchCtx {
395 store,
396 prefix: remote.prefix().map(Arc::from),
397 repo_dir: Arc::new(repo_dir),
398 };
399 let flush = FlushCtx {
400 batch_ctx: &ctx,
401 remote: &remote,
402 engine,
403 zip,
404 fetched_refs: &fetched_refs,
405 };
406
407 while let Some(line) = lines.next_line().await? {
408 debug!(cmd = %line, "received protocol command");
409 let Some(cmd) = parse_command(&line) else {
410 error!(cmd = %line, "fatal: invalid command");
411 return Err(ProtocolError::InvalidCommand(line));
412 };
413 match cmd {
414 Command::Capabilities => {
415 capabilities::handle_capabilities(&mut writer, advertise_bundle_uri).await?;
416 }
417 Command::BundleUri => {
418 let opts = bundle_uri::BundleUriOpts {
419 presign_ttl_seconds: remote.flags().bundle_uri_presign_ttl,
420 };
421 bundle_uri::handle_bundle_uri(
422 ctx.store.as_ref(),
423 &remote,
424 opts,
425 advertise_bundle_uri,
426 &mut writer,
427 )
428 .await?;
429 }
430 Command::List { for_push } => {
431 list::handle_list(
432 ctx.store.as_ref(),
433 ctx.prefix.as_deref(),
434 engine,
435 for_push,
436 &mut writer,
437 )
438 .await?;
439 }
440 Command::Option(args) => {
441 let effect = handle_option(&args, reload.as_ref(), &mut writer).await?;
442 if let OptionEffect::SetDepth(d) = effect {
443 depth = Some(d);
444 }
445 }
446 Command::Fetch(args) => batch.accumulate(Mode::Fetch, args),
447 Command::Push(args) => batch.accumulate(Mode::Push, args),
448 Command::Empty => {
449 flush_batch(&flush, &mut batch, &mut depth, &mut writer).await?;
450 }
451 }
452 }
453 Ok(())
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_command_recognises_each_form() {
        assert_eq!(parse_command("capabilities\n"), Some(Command::Capabilities));
        assert_eq!(parse_command("bundle-uri\n"), Some(Command::BundleUri));
        assert_eq!(
            parse_command("list\n"),
            Some(Command::List { for_push: false })
        );
        assert_eq!(
            parse_command("list for-push\n"),
            Some(Command::List { for_push: true })
        );
        assert_eq!(
            parse_command("option verbosity 2\n"),
            Some(Command::Option("verbosity 2".into()))
        );
        assert_eq!(
            parse_command("fetch deadbeef refs/heads/main\n"),
            Some(Command::Fetch("deadbeef refs/heads/main".into()))
        );
        assert_eq!(
            parse_command("push refs/heads/main:refs/heads/main\n"),
            Some(Command::Push("refs/heads/main:refs/heads/main".into()))
        );
        assert_eq!(parse_command("\n"), Some(Command::Empty));
    }

    #[test]
    fn parse_command_handles_crlf() {
        assert_eq!(
            parse_command("list\r\n"),
            Some(Command::List { for_push: false })
        );
        assert_eq!(parse_command("\r\n"), Some(Command::Empty));
    }

    #[test]
    fn parse_command_rejects_garbage() {
        assert_eq!(parse_command("nonsense\n"), None);
        assert_eq!(parse_command(" \n"), None);
        // Only the exact single-space spelling `list for-push` is accepted.
        assert_eq!(parse_command("list  for-push\n"), None);
        assert_eq!(parse_command("list \n"), None);
        // Prefix commands require the separating space.
        assert_eq!(parse_command("fetch\n"), None);
    }

    /// Arguments after the `fetch ` / `push ` prefix are passed through
    /// untouched, including extra leading spaces and the empty string.
    #[test]
    fn parse_command_passes_strip_prefix_args_verbatim() {
        assert_eq!(
            parse_command("fetch  abc def\n"),
            Some(Command::Fetch(" abc def".into())),
        );
        assert_eq!(
            parse_command("push  +ref:ref\n"),
            Some(Command::Push(" +ref:ref".into())),
        );
        assert_eq!(
            parse_command("fetch \n"),
            Some(Command::Fetch(String::new()))
        );
    }

    /// Wrapper whose Display inlines its source, mimicking the `ProtocolError`
    /// variants.
    #[derive(Debug, thiserror::Error)]
    #[error("layer: {0}")]
    struct LayerError(#[source] crate::object_store::BoxError);

    #[test]
    fn append_source_chain_skips_levels_already_in_display() {
        let inner: crate::object_store::BoxError = Box::new(std::io::Error::other("dns failure"));
        let mid: crate::object_store::BoxError = Box::new(LayerError(inner));
        let top = LayerError(mid);

        let mut msg = top.to_string();
        assert_eq!(msg, "layer: layer: dns failure");

        append_source_chain(&mut msg, &top);
        assert_eq!(
            msg, "layer: layer: dns failure",
            "append_source_chain must not duplicate already-inlined sources",
        );
    }

    #[test]
    fn append_source_chain_appends_when_source_text_is_not_in_display() {
        /// Wrapper whose Display hides its source entirely.
        #[derive(Debug, thiserror::Error)]
        #[error("opaque wrapper")]
        struct OpaqueWrapper(#[source] crate::object_store::BoxError);

        let inner: crate::object_store::BoxError = Box::new(std::io::Error::other("dns failure"));
        let top = OpaqueWrapper(inner);

        let mut msg = top.to_string();
        assert_eq!(msg, "opaque wrapper");
        append_source_chain(&mut msg, &top);
        assert_eq!(msg, "opaque wrapper: dns failure");
    }

    #[test]
    fn is_broken_pipe_matches_kinds() {
        let pipe = ProtocolError::Io(std::io::Error::from(ErrorKind::BrokenPipe));
        assert!(pipe.is_broken_pipe());
        let write_zero = ProtocolError::Io(std::io::Error::from(ErrorKind::WriteZero));
        assert!(write_zero.is_broken_pipe());
        let other = ProtocolError::Io(std::io::Error::from(ErrorKind::Other));
        assert!(!other.is_broken_pipe());
        let not_io = ProtocolError::InvalidCommand("bad".into());
        assert!(!not_io.is_broken_pipe());
    }

    #[test]
    fn batch_state_empty_take_returns_none() {
        let mut batch = BatchState::new();
        assert!(batch.take_pending().is_none());
    }

    #[test]
    fn batch_state_accumulate_and_take_round_trip() {
        let mut batch = BatchState::new();
        batch.accumulate(Mode::Fetch, "a".to_owned());
        batch.accumulate(Mode::Fetch, "b".to_owned());
        let (mode, cmds) = batch.take_pending().expect("non-empty fetch batch");
        assert_eq!(mode, Mode::Fetch);
        assert_eq!(cmds, ["a", "b"]);
        assert!(batch.take_pending().is_none());
    }

    #[test]
    fn batch_state_mode_switch_clears_prior_cmds() {
        let mut batch = BatchState::new();
        batch.accumulate(Mode::Fetch, "fetch-cmd".to_owned());
        batch.accumulate(Mode::Push, "push-cmd".to_owned());
        let (mode, cmds) = batch.take_pending().expect("non-empty push batch");
        assert_eq!(mode, Mode::Push);
        assert_eq!(cmds, ["push-cmd"]);
        assert!(batch.take_pending().is_none());
    }

    #[test]
    fn batch_state_accumulate_with_no_cmds_after_mode_set_takes_none() {
        let mut batch = BatchState::new();
        batch.accumulate(Mode::Fetch, "only-cmd".to_owned());
        // The first take drains the batch; a second take must find nothing.
        batch.take_pending();
        assert!(batch.take_pending().is_none());
    }
}