use std::io::ErrorKind;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
use tracing::{debug, error};
use crate::object_store::ObjectStore;
use crate::url::{RemoteUrl, StorageEngine};
pub mod backend;
pub(crate) mod bundle_uri;
pub(crate) mod capabilities;
pub mod fetch;
pub(crate) mod list;
pub(crate) mod option;
pub mod push;
pub mod tracing_init;
use self::fetch::{FetchedRefs, fetch_batch};
use self::option::{OptionEffect, handle_option};
use self::push::{PushOutcome, push_batch};
use self::tracing_init::ReloadHandle;
/// Writes every push outcome's protocol line to `writer`, in slice order.
///
/// Only the text produced by `PushOutcome::to_protocol_line` is written: no
/// separator is added here and `writer` is not flushed.
///
/// # Errors
/// Returns the first I/O error reported by `writer`; outcomes after the
/// failing one are not written.
async fn write_push_outcomes<W>(
    writer: &mut W,
    outcomes: &[PushOutcome],
) -> Result<(), std::io::Error>
where
    W: AsyncWrite + Unpin,
{
    let rendered = outcomes.iter().map(|outcome| outcome.to_protocol_line());
    for line in rendered {
        writer.write_all(line.as_bytes()).await?;
    }
    Ok(())
}
/// Extends `msg` with the text of each error in `err`'s `source()` chain.
///
/// Every cause is appended as `": <cause>"`, except when `msg` already ends
/// with that cause's rendered text. This avoids repeating a cause whose
/// message was already interpolated into its parent's `Display` output.
/// `err` itself is not rendered; callers are expected to have put it into
/// `msg` already.
pub(crate) fn append_source_chain<E: std::error::Error + ?Sized>(msg: &mut String, err: &E) {
    let causes = std::iter::successors(err.source(), |&cause| cause.source());
    for rendered in causes.map(|cause| cause.to_string()) {
        if msg.ends_with(rendered.as_str()) {
            // Already inlined by an outer `Display`; skip to avoid duplication.
            continue;
        }
        msg.push_str(": ");
        msg.push_str(&rendered);
    }
}
/// Errors that end the remote-helper command loop driven by `run`.
///
/// Most variants wrap a lower-level error through `#[from]`, so `?` converts
/// them automatically. thiserror treats a `#[from]` field as the error's
/// `source()`. Because these messages also interpolate `{0}`, the inner text
/// appears both in `Display` and in the source chain.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// I/O failure, e.g. while reading a command line or writing/flushing a
    /// response.
    #[error("protocol I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Failure propagated from the `list` command handler.
    #[error("list failed: {0}")]
    List(#[from] list::ListError),
    /// Failure while executing a queued fetch batch.
    #[error("fetch failed: {0}")]
    Fetch(#[from] fetch::FetchError),
    /// Failure while executing a queued push batch.
    #[error("push failed: {0}")]
    Push(#[from] push::PushError),
    /// A line that `parse_command` did not recognise; holds the raw line.
    #[error("invalid command: {0:?}")]
    InvalidCommand(String),
    /// Wraps a `backend::BackendError`. NOTE(review): no conversion site is
    /// visible in this file — presumably produced by callers of `run`.
    #[error("backend resolution failed: {0}")]
    Backend(#[from] backend::BackendError),
    /// Failure propagated from the `bundle-uri` command handler.
    #[error("bundle-uri failed: {0}")]
    BundleUri(#[from] bundle_uri::BundleUriError),
}
impl ProtocolError {
    /// Reports whether this is an I/O error of kind `BrokenPipe` or
    /// `WriteZero`.
    ///
    /// These kinds typically mean the other end stopped reading our output.
    /// Every non-`Io` variant returns `false`.
    #[must_use]
    pub fn is_broken_pipe(&self) -> bool {
        match self {
            Self::Io(io_err) => {
                let kind = io_err.kind();
                kind == ErrorKind::BrokenPipe || kind == ErrorKind::WriteZero
            }
            _ => false,
        }
    }
}
/// One parsed line of the remote-helper protocol, as produced by
/// `parse_command`.
#[derive(Debug, PartialEq, Eq)]
enum Command {
    /// The `capabilities` line.
    Capabilities,
    /// The `bundle-uri` line.
    BundleUri,
    /// `list`, or `list for-push` when `for_push` is true.
    List { for_push: bool },
    /// `option <args>`; holds the text after `"option "`.
    Option(String),
    /// `fetch <args>`; holds the text after `"fetch "`.
    Fetch(String),
    /// `push <args>`; holds the text after `"push "`.
    Push(String),
    /// A blank line; `run` responds by flushing the pending batch.
    Empty,
}
/// Which kind of command a `BatchState` is currently collecting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Mode {
    /// Collecting `fetch` commands.
    Fetch,
    /// Collecting `push` commands.
    Push,
}
/// Shared resources handed to the fetch/push batch implementations.
///
/// `run` builds one of these per session.
pub(crate) struct BatchCtx {
    /// Object store backing the remote.
    pub(crate) store: Arc<dyn ObjectStore>,
    /// Taken from `RemoteUrl::prefix`; presumably a key prefix within
    /// `store` — confirm.
    pub(crate) prefix: Option<Arc<str>>,
    /// The `repo_dir` passed to `run`. NOTE(review): likely the local git
    /// directory — confirm.
    pub(crate) repo_dir: Arc<PathBuf>,
}
struct BatchState {
mode: Option<Mode>,
fetch_cmds: Vec<String>,
push_cmds: Vec<String>,
}
impl BatchState {
fn new() -> Self {
Self {
mode: None,
fetch_cmds: Vec::new(),
push_cmds: Vec::new(),
}
}
fn accumulate(&mut self, incoming: Mode, cmd: String) {
if self.mode != Some(incoming) {
match incoming {
Mode::Fetch => self.push_cmds.clear(),
Mode::Push => self.fetch_cmds.clear(),
}
self.mode = Some(incoming);
}
match incoming {
Mode::Fetch => {
debug_assert!(
self.push_cmds.is_empty(),
"push_cmds must be empty when accumulating a Fetch command",
);
self.fetch_cmds.push(cmd);
}
Mode::Push => {
debug_assert!(
self.fetch_cmds.is_empty(),
"fetch_cmds must be empty when accumulating a Push command",
);
self.push_cmds.push(cmd);
}
}
}
fn take_pending(&mut self) -> Option<(Mode, Vec<String>)> {
match self.mode {
Some(Mode::Fetch) if !self.fetch_cmds.is_empty() => {
self.mode = None;
Some((Mode::Fetch, std::mem::take(&mut self.fetch_cmds)))
}
Some(Mode::Push) if !self.push_cmds.is_empty() => {
self.mode = None;
Some((Mode::Push, std::mem::take(&mut self.push_cmds)))
}
_ => None,
}
}
}
/// Parses one line of the remote-helper protocol into a [`Command`].
///
/// Trailing `'\r'`/`'\n'` characters are ignored, so LF and CRLF endings are
/// both accepted. Keyword commands (`capabilities`, `bundle-uri`, `list`,
/// `list for-push`) must match exactly.
///
/// `option`, `fetch` and `push` must be followed by a space. Everything after
/// that first space is kept verbatim, including any further spaces. A blank
/// line yields [`Command::Empty`].
///
/// Returns `None` for anything else, including lines with leading or stray
/// whitespace around a keyword.
fn parse_command(line: &str) -> Option<Command> {
    let text = line.trim_end_matches(['\r', '\n']);
    match text {
        "" => return Some(Command::Empty),
        "capabilities" => return Some(Command::Capabilities),
        "bundle-uri" => return Some(Command::BundleUri),
        "list" => return Some(Command::List { for_push: false }),
        "list for-push" => return Some(Command::List { for_push: true }),
        _ => {}
    }
    // Argument-carrying commands: split at the first space only, so `rest`
    // is exactly what follows `"<verb> "`.
    let (verb, rest) = text.split_once(' ')?;
    match verb {
        "option" => Some(Command::Option(rest.to_owned())),
        "fetch" => Some(Command::Fetch(rest.to_owned())),
        "push" => Some(Command::Push(rest.to_owned())),
        _ => None,
    }
}
/// Per-session values that `flush_batch` needs to execute a queued batch.
struct FlushCtx<'a> {
    /// Resources passed through to the fetch/push implementations.
    batch_ctx: &'a BatchCtx,
    /// Remote being served; its `kind()` is forwarded to the bundle push path.
    remote: &'a RemoteUrl,
    /// Selects between the bundle and packchain implementations.
    engine: StorageEngine,
    /// The remote's `zip` flag, forwarded to the bundle push path.
    zip: bool,
    /// Passed (via `clone()`) to every fetch batch.
    fetched_refs: &'a FetchedRefs,
}
/// Executes the pending fetch/push batch, if any, and then terminates the
/// response.
///
/// - A pending batch is dispatched to the implementation matching
///   `flush.engine`.
/// - A fetch batch consumes any pending `depth`, leaving it `None`.
/// - A push batch has each returned outcome written to `writer` via
///   `write_push_outcomes`.
/// - In every case, including when nothing was pending, a single `"\n"` is
///   written and `writer` is flushed.
///
/// # Errors
/// Propagates fetch/push failures and any I/O error from writing or flushing.
async fn flush_batch<W>(
    flush: &FlushCtx<'_>,
    batch: &mut BatchState,
    depth: &mut Option<NonZeroU32>,
    writer: &mut W,
) -> Result<(), ProtocolError>
where
    W: AsyncWrite + Unpin,
{
    match batch.take_pending() {
        None => {}
        Some((Mode::Fetch, cmds)) => {
            let refs = flush.fetched_refs.clone();
            let requested_depth = depth.take();
            match flush.engine {
                StorageEngine::Bundle => {
                    fetch_batch(flush.batch_ctx, cmds, refs, requested_depth).await?;
                }
                StorageEngine::Packchain => {
                    crate::packchain::fetch::fetch_batch(
                        flush.batch_ctx,
                        cmds,
                        refs,
                        requested_depth,
                    )
                    .await?;
                }
            }
        }
        Some((Mode::Push, cmds)) => match flush.engine {
            StorageEngine::Bundle => {
                let outcomes = push_batch(
                    flush.batch_ctx,
                    flush.remote.kind(),
                    flush.zip,
                    flush.engine,
                    cmds,
                )
                .await?;
                write_push_outcomes(writer, &outcomes).await?;
            }
            StorageEngine::Packchain => {
                let outcomes =
                    crate::packchain::push::push_batch(flush.batch_ctx, flush.engine, cmds)
                        .await?;
                write_push_outcomes(writer, &outcomes).await?;
            }
        },
    }
    // The blank line terminates this batch's response.
    writer.write_all(b"\n").await?;
    writer.flush().await?;
    Ok(())
}
pub async fn run<R, W>(
remote: RemoteUrl,
store: Arc<dyn ObjectStore>,
engine: StorageEngine,
reader: R,
mut writer: W,
reload: Option<ReloadHandle>,
repo_dir: PathBuf,
) -> Result<(), ProtocolError>
where
R: AsyncBufRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut lines = reader.lines();
let fetched_refs = FetchedRefs::new();
let mut batch = BatchState::new();
let mut depth: Option<NonZeroU32> = None;
let zip = remote.flags().zip;
let advertise_bundle_uri =
matches!(engine, StorageEngine::Packchain) && remote.flags().bundle_uri;
let ctx = BatchCtx {
store,
prefix: remote.prefix().map(Arc::from),
repo_dir: Arc::new(repo_dir),
};
let flush = FlushCtx {
batch_ctx: &ctx,
remote: &remote,
engine,
zip,
fetched_refs: &fetched_refs,
};
while let Some(line) = lines.next_line().await? {
debug!(cmd = %line, "received protocol command");
let Some(cmd) = parse_command(&line) else {
error!(cmd = %line, "fatal: invalid command");
return Err(ProtocolError::InvalidCommand(line));
};
match cmd {
Command::Capabilities => {
capabilities::handle_capabilities(&mut writer, advertise_bundle_uri).await?;
}
Command::BundleUri => {
let opts = bundle_uri::BundleUriOpts {
presign_ttl_seconds: remote.flags().bundle_uri_presign_ttl,
};
bundle_uri::handle_bundle_uri(
ctx.store.as_ref(),
&remote,
opts,
advertise_bundle_uri,
&mut writer,
)
.await?;
}
Command::List { for_push } => {
list::handle_list(
ctx.store.as_ref(),
ctx.prefix.as_deref(),
engine,
for_push,
&mut writer,
)
.await?;
}
Command::Option(args) => {
let effect = handle_option(&args, reload.as_ref(), &mut writer).await?;
if let OptionEffect::SetDepth(d) = effect {
depth = Some(d);
}
}
Command::Fetch(args) => batch.accumulate(Mode::Fetch, args),
Command::Push(args) => batch.accumulate(Mode::Push, args),
Command::Empty => {
flush_batch(&flush, &mut batch, &mut depth, &mut writer).await?;
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for command parsing, batch accumulation and error helpers.
    // Literals that need two consecutive spaces spell the second one as `\x20`
    // so the intent survives whitespace-normalising tools.
    use super::*;
    #[test]
    fn parse_command_recognises_each_form() {
        assert_eq!(parse_command("capabilities\n"), Some(Command::Capabilities));
        assert_eq!(parse_command("bundle-uri\n"), Some(Command::BundleUri));
        assert_eq!(
            parse_command("list\n"),
            Some(Command::List { for_push: false })
        );
        assert_eq!(
            parse_command("list for-push\n"),
            Some(Command::List { for_push: true })
        );
        assert_eq!(
            parse_command("option verbosity 2\n"),
            Some(Command::Option("verbosity 2".into()))
        );
        assert_eq!(
            parse_command("fetch deadbeef refs/heads/main\n"),
            Some(Command::Fetch("deadbeef refs/heads/main".into()))
        );
        assert_eq!(
            parse_command("push refs/heads/main:refs/heads/main\n"),
            Some(Command::Push("refs/heads/main:refs/heads/main".into()))
        );
        assert_eq!(parse_command("\n"), Some(Command::Empty));
    }
    #[test]
    fn parse_command_handles_crlf() {
        assert_eq!(
            parse_command("list\r\n"),
            Some(Command::List { for_push: false })
        );
        assert_eq!(parse_command("\r\n"), Some(Command::Empty));
    }
    #[test]
    fn parse_command_rejects_garbage() {
        assert_eq!(parse_command("nonsense\n"), None);
        assert_eq!(parse_command(" \n"), None);
        // A doubled separator must not be accepted as `list for-push`.
        assert_eq!(parse_command("list \x20for-push\n"), None);
        assert_eq!(parse_command("list \n"), None);
    }
    #[test]
    fn parse_command_passes_strip_prefix_args_verbatim() {
        // Only the single separator space is stripped; the extra `\x20`
        // must survive into the argument.
        assert_eq!(
            parse_command("fetch \x20abc def\n"),
            Some(Command::Fetch(" abc def".into())),
        );
        assert_eq!(
            parse_command("push \x20+ref:ref\n"),
            Some(Command::Push(" +ref:ref".into())),
        );
        assert_eq!(
            parse_command("fetch \n"),
            Some(Command::Fetch(String::new()))
        );
    }
    #[derive(Debug, thiserror::Error)]
    #[error("layer: {0}")]
    struct LayerError(#[source] crate::object_store::BoxError);
    #[test]
    fn append_source_chain_skips_levels_already_in_display() {
        let inner: crate::object_store::BoxError = Box::new(std::io::Error::other("dns failure"));
        let mid: crate::object_store::BoxError = Box::new(LayerError(inner));
        let top = LayerError(mid);
        let mut msg = top.to_string();
        assert_eq!(msg, "layer: layer: dns failure");
        append_source_chain(&mut msg, &top);
        assert_eq!(
            msg, "layer: layer: dns failure",
            "append_source_chain must not duplicate already-inlined sources",
        );
    }
    #[test]
    fn append_source_chain_appends_when_source_text_is_not_in_display() {
        #[derive(Debug, thiserror::Error)]
        #[error("opaque wrapper")]
        struct OpaqueWrapper(#[source] crate::object_store::BoxError);
        let inner: crate::object_store::BoxError = Box::new(std::io::Error::other("dns failure"));
        let top = OpaqueWrapper(inner);
        let mut msg = top.to_string();
        assert_eq!(msg, "opaque wrapper");
        append_source_chain(&mut msg, &top);
        assert_eq!(msg, "opaque wrapper: dns failure");
    }
    #[test]
    fn is_broken_pipe_matches_kinds() {
        let pipe = ProtocolError::Io(std::io::Error::from(ErrorKind::BrokenPipe));
        assert!(pipe.is_broken_pipe());
        let write_zero = ProtocolError::Io(std::io::Error::from(ErrorKind::WriteZero));
        assert!(write_zero.is_broken_pipe());
        let other = ProtocolError::Io(std::io::Error::from(ErrorKind::Other));
        assert!(!other.is_broken_pipe());
        let not_io = ProtocolError::InvalidCommand("bad".into());
        assert!(!not_io.is_broken_pipe());
    }
    #[test]
    fn batch_state_empty_take_returns_none() {
        let mut batch = BatchState::new();
        assert!(batch.take_pending().is_none());
    }
    #[test]
    fn batch_state_accumulate_and_take_round_trip() {
        let mut batch = BatchState::new();
        batch.accumulate(Mode::Fetch, "a".to_owned());
        batch.accumulate(Mode::Fetch, "b".to_owned());
        let (mode, cmds) = batch.take_pending().expect("non-empty fetch batch");
        assert_eq!(mode, Mode::Fetch);
        assert_eq!(cmds, ["a", "b"]);
        assert!(batch.take_pending().is_none());
    }
    #[test]
    fn batch_state_mode_switch_clears_prior_cmds() {
        let mut batch = BatchState::new();
        batch.accumulate(Mode::Fetch, "fetch-cmd".to_owned());
        batch.accumulate(Mode::Push, "push-cmd".to_owned());
        let (mode, cmds) = batch.take_pending().expect("non-empty push batch");
        assert_eq!(mode, Mode::Push);
        assert_eq!(cmds, ["push-cmd"]);
        assert!(batch.take_pending().is_none());
    }
    #[test]
    fn batch_state_accumulate_with_no_cmds_after_mode_set_takes_none() {
        let mut batch = BatchState::new();
        batch.accumulate(Mode::Fetch, "only-cmd".to_owned());
        // The first take drains the batch; a second take must find nothing.
        assert!(batch.take_pending().is_some());
        assert!(batch.take_pending().is_none());
    }
}