use std::fmt;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc;
use tracing::{debug, warn};
use crate::keys;
use crate::lfs::oid::LfsOid;
use crate::lfs::protocol::{CompleteEvent, ErrorPayload, ProgressEvent};
use crate::object_store::{GetOpts, ObjectStore, ObjectStoreError, ProgressSink, PutOpts};
/// Error code placed in the `error` payload of every failed `complete` event written by
/// `write_complete` in this module.
pub(crate) const ERR_CODE_GENERIC: u32 = 2;
/// NOTE(review): presumably the error code reported when the agent fails to initialize
/// (inferred from the name only). It is not referenced in this file — confirm against callers.
pub(crate) const ERR_CODE_INIT: u32 = 32;
/// Serves LFS `upload` and `download` requests by moving objects between local files and an
/// [`ObjectStore`], reporting progress and completion to the caller as JSON event lines.
pub(crate) struct Agent {
/// Backing store holding the LFS objects.
store: Arc<dyn ObjectStore>,
/// Key prefix placed in front of `lfs/<oid>`; empty when no prefix was configured.
prefix: String,
/// Directory that downloads are written into, one file per OID (`<tmp_dir>/<oid>`).
tmp_dir: PathBuf,
}
/// Errors that abort handling of an LFS request: failures to produce or emit protocol events.
///
/// A failed transfer (missing object, store error, ...) is *not* an `AgentError`. The
/// agent reports it to the client inside the `complete` event instead.
#[derive(Debug, Error)]
pub enum AgentError {
/// An I/O operation on the protocol stream failed (in this file: writing or flushing an event).
#[error("LFS protocol I/O error: {0}")]
Io(#[from] std::io::Error),
/// An event could not be serialized to JSON.
#[error("LFS event serialization failed: {0}")]
Serialize(#[from] serde_json::Error),
}
impl Agent {
    /// Creates an agent backed by `store`.
    ///
    /// `prefix` is placed in front of every object key (see [`Agent::key`]); `None` is
    /// treated as an empty prefix. Downloaded objects are written as `<tmp_dir>/<oid>`.
    pub(crate) fn new(
        store: Arc<dyn ObjectStore>,
        prefix: Option<String>,
        tmp_dir: PathBuf,
    ) -> Self {
        Self {
            store,
            prefix: prefix.unwrap_or_default(),
            tmp_dir,
        }
    }

    /// Object-store key for `oid`: `lfs/<oid>` joined onto the configured prefix.
    fn key(&self, oid: &LfsOid) -> String {
        keys::join(Some(&self.prefix), &format!("lfs/{oid}"))
    }

    /// Handles one `upload` request for the file at `path`.
    ///
    /// Writes `progress` events to `writer` while the object is transferred, then exactly
    /// one `complete` event. A transfer failure is reported inside that `complete` event.
    /// `_size` is currently unused.
    ///
    /// # Errors
    ///
    /// Returns an [`AgentError`] only when an event cannot be serialized or written to `writer`.
    pub(crate) async fn upload<W: AsyncWrite + Unpin>(
        &self,
        oid: &LfsOid,
        _size: u64,
        path: &Path,
        writer: &mut W,
    ) -> Result<(), AgentError> {
        let result = with_progress_stream(writer, oid.as_str().to_owned(), |sink| async move {
            self.try_upload(oid, path, sink).await
        })
        .await?;
        let error = result.err().map(|e| e.message);
        write_complete(writer, oid.as_str(), None, error.as_deref()).await
    }

    /// Handles one `download` request.
    ///
    /// Writes `progress` events to `writer` while the object is fetched into the temp
    /// directory, then exactly one `complete` event. On success it carries the local path;
    /// on failure it carries the error message. `_size` is currently unused.
    ///
    /// # Errors
    ///
    /// Returns an [`AgentError`] only when an event cannot be serialized or written to `writer`.
    pub(crate) async fn download<W: AsyncWrite + Unpin>(
        &self,
        oid: &LfsOid,
        _size: u64,
        writer: &mut W,
    ) -> Result<(), AgentError> {
        let result = with_progress_stream(writer, oid.as_str().to_owned(), |sink| async move {
            self.try_download(oid, sink).await
        })
        .await?;
        let (path, error) = match result {
            Ok(dest) => (Some(dest), None),
            Err(OpError { message }) => (None, Some(message)),
        };
        write_complete(writer, oid.as_str(), path.as_deref(), error.as_deref()).await
    }

    /// Uploads the file at `path` under `oid`'s key unless an object already exists there.
    ///
    /// Existence is checked with `head`. When the object is already present, nothing is
    /// transferred and `progress` is never invoked.
    async fn try_upload(
        &self,
        oid: &LfsOid,
        path: &Path,
        progress: ProgressSink,
    ) -> Result<(), OpError> {
        let key = self.key(oid);
        debug!(oid = %oid, key = %key, "lfs upload");
        match self.store.head(&key).await {
            Ok(_) => {
                debug!(oid = %oid, "object already present; skipping upload");
                return Ok(());
            }
            Err(ObjectStoreError::NotFound(_)) => {}
            Err(e) => {
                warn!(oid = %oid, error = %e, "head failed during upload");
                return Err(OpError::with_cause(&e));
            }
        }
        let opts = PutOpts {
            progress: Some(progress),
            ..PutOpts::default()
        };
        self.store.put_path(&key, path, opts).await.map_err(|e| {
            warn!(oid = %oid, error = %e, "upload failed");
            OpError::with_cause(&e)
        })?;
        Ok(())
    }

    /// Downloads `oid` to `<tmp_dir>/<oid>` and returns that path as a string.
    async fn try_download(&self, oid: &LfsOid, progress: ProgressSink) -> Result<String, OpError> {
        let key = self.key(oid);
        let dest = self.tmp_dir.join(oid.as_str());
        debug!(oid = %oid, key = %key, dest = %dest.display(), "lfs download");
        // The path is reported back to the client as a JSON string. Reject a non-UTF-8
        // destination before transferring anything rather than after the whole object has
        // been downloaded (which would also leave an unreported file behind).
        let dest_str = dest.to_str().map(str::to_owned).ok_or_else(|| {
            warn!(oid = %oid, dest = %dest.display(), "download destination is not valid UTF-8");
            OpError::with_cause(&"download destination is not valid UTF-8")
        })?;
        if let Some(parent) = dest.parent() {
            tokio::fs::create_dir_all(parent).await.map_err(|e| {
                warn!(oid = %oid, error = %e, "create_dir_all failed");
                OpError::with_cause(&e)
            })?;
        }
        let opts = GetOpts {
            progress: Some(progress),
        };
        self.store
            .get_to_file(&key, &dest, opts)
            .await
            .map_err(|e| {
                warn!(oid = %oid, error = %e, "download failed");
                OpError::with_cause(&e)
            })?;
        Ok(dest_str)
    }
}
/// Runs `op` and forwards every amount it reports through its [`ProgressSink`] to `writer`
/// as an LFS `progress` event for `oid_raw`.
///
/// Amounts are added to a running total, and zero-byte updates are skipped. When a progress
/// update and the completion of `op` are both ready, the progress update is handled first.
/// Updates still queued when `op` finishes are flushed before its output is returned.
///
/// # Errors
///
/// Returns an [`AgentError`] as soon as writing a progress event fails; `op` is dropped then.
async fn with_progress_stream<W, F, Fut, R>(
    writer: &mut W,
    oid_raw: String,
    op: F,
) -> Result<R, AgentError>
where
    W: AsyncWrite + Unpin,
    F: FnOnce(ProgressSink) -> Fut,
    Fut: Future<Output = R>,
{
    /// Adds `amount` to `total` and emits one progress event; zero-byte updates are ignored.
    async fn forward<Out: AsyncWrite + Unpin>(
        out: &mut Out,
        oid: &str,
        total: &mut u64,
        amount: u64,
    ) -> Result<(), AgentError> {
        if amount == 0 {
            return Ok(());
        }
        *total = total.saturating_add(amount);
        write_progress(out, oid, *total, amount).await
    }

    let (progress_tx, mut progress_rx) = mpsc::unbounded_channel::<u64>();
    let operation = op(ProgressSink::new(move |amount| {
        // The receiver is dropped when this function returns (including on an early error
        // return), so a failed send only means nobody is listening any more.
        let _ = progress_tx.send(amount);
    }));
    tokio::pin!(operation);

    let mut total: u64 = 0;
    let output = loop {
        tokio::select! {
            biased;
            Some(amount) = progress_rx.recv() => {
                forward(writer, &oid_raw, &mut total, amount).await?;
            }
            finished = &mut operation => break finished,
        }
    };
    // Updates sent just before `op` completed may still be sitting in the channel.
    while let Ok(amount) = progress_rx.try_recv() {
        forward(writer, &oid_raw, &mut total, amount).await?;
    }
    Ok(output)
}
/// Failure of a single transfer.
///
/// Unlike [`AgentError`], this does not abort the agent. Its message is sent to the client
/// in the `error` payload of the `complete` event.
struct OpError {
    /// `Display` rendering of whatever caused the failure.
    message: String,
}

impl OpError {
    /// Builds an error whose message is the `Display` output of `cause`.
    fn with_cause(cause: &dyn fmt::Display) -> Self {
        let message = cause.to_string();
        Self { message }
    }
}
/// Serializes `evt` as one line of compact JSON, writes it to `writer` and flushes.
///
/// The `'\n'` terminator is appended to the serialized text, so each event goes out in a
/// single `write_all` rather than two separate writes. Compact `serde_json` output escapes
/// newlines inside strings, so the only raw newline is the terminator.
///
/// # Errors
///
/// Returns [`AgentError::Serialize`] if `evt` cannot be serialized, and [`AgentError::Io`]
/// if writing or flushing fails.
pub(crate) async fn write_event<W, E>(writer: &mut W, evt: &E) -> Result<(), AgentError>
where
    W: AsyncWrite + Unpin,
    E: serde::Serialize,
{
    let mut line = serde_json::to_string(evt)?;
    line.push('\n');
    writer.write_all(line.as_bytes()).await?;
    // Flush so the event is delivered now even when `writer` buffers internally.
    writer.flush().await?;
    Ok(())
}
/// Emits a `progress` event for `oid` carrying the running total and the latest increment.
async fn write_progress<W: AsyncWrite + Unpin>(
    writer: &mut W,
    oid: &str,
    bytes_so_far: u64,
    bytes_since_last: u64,
) -> Result<(), AgentError> {
    let progress = ProgressEvent {
        event: "progress",
        oid,
        bytes_so_far,
        bytes_since_last,
    };
    write_event(writer, &progress).await
}
async fn write_complete<W: AsyncWrite + Unpin>(
writer: &mut W,
oid: &str,
path: Option<&str>,
error_message: Option<&str>,
) -> Result<(), AgentError> {
write_event(
writer,
&CompleteEvent {
event: "complete",
oid,
path,
error: error_message.map(|message| ErrorPayload {
code: ERR_CODE_GENERIC,
message,
}),
},
)
.await
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::object_store::mock::MockStore;
    use bytes::Bytes;
    use std::str::FromStr;
    use tempfile::TempDir;

    /// A syntactically valid OID shared by every test.
    fn good_oid() -> LfsOid {
        LfsOid::from_str("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
            .expect("hard-coded oid is valid")
    }

    /// Builds an [`Agent`] over `store` that downloads into `tmp`.
    fn agent(store: MockStore, prefix: Option<&str>, tmp: &TempDir) -> Agent {
        Agent::new(
            Arc::new(store),
            prefix.map(str::to_owned),
            tmp.path().to_owned(),
        )
    }

    /// Extracts the unsigned integer value of `"field":<n>` from one JSON event line.
    fn json_u64_field(line: &str, field: &str) -> u64 {
        let needle = format!("\"{field}\":");
        line.split(needle.as_str())
            .nth(1)
            .and_then(|tail| tail.split([',', '}']).next())
            .and_then(|n| n.parse().ok())
            .unwrap_or_else(|| panic!("{field} missing: {line}"))
    }

    /// Returns `(bytesSoFar, bytesSinceLast)` from a progress event line.
    fn parse_progress(line: &str) -> (u64, u64) {
        (
            json_u64_field(line, "bytesSoFar"),
            json_u64_field(line, "bytesSinceLast"),
        )
    }

    /// Keeps only the `progress` event lines of the agent's output.
    fn progress_lines(out: &str) -> Vec<&str> {
        out.lines()
            .filter(|l| l.contains("\"event\":\"progress\""))
            .collect()
    }

    #[tokio::test]
    async fn upload_skips_when_present() {
        let store = MockStore::new();
        let oid = good_oid();
        store.insert(format!("repo/lfs/{oid}"), Bytes::from_static(b"hello"));
        let tmp = TempDir::new().unwrap();
        let a = agent(store.clone(), Some("repo"), &tmp);
        let src = tmp.path().join("body");
        tokio::fs::write(&src, b"hello").await.unwrap();
        let mut out = Vec::new();
        a.upload(&oid, 5, &src, &mut out).await.expect("upload");
        let got = String::from_utf8(out).unwrap();
        assert_eq!(got, format!("{{\"event\":\"complete\",\"oid\":\"{oid}\"}}\n"));
    }

    #[tokio::test]
    async fn upload_streams_when_absent_and_emits_progress_then_complete() {
        let store = MockStore::new();
        let tmp = TempDir::new().unwrap();
        let oid = good_oid();
        let a = agent(store.clone(), Some("repo"), &tmp);
        let src = tmp.path().join("body");
        let body = b"the quick brown fox";
        tokio::fs::write(&src, body).await.unwrap();
        let mut out = Vec::new();
        a.upload(&oid, body.len() as u64, &src, &mut out)
            .await
            .expect("upload");
        let got = String::from_utf8(out).unwrap();
        let lines: Vec<&str> = got.lines().collect();
        assert_eq!(lines.len(), 2, "expected progress + complete: {got}");
        assert!(lines[0].contains("\"event\":\"progress\""));
        assert!(lines[0].contains(&format!("\"oid\":\"{oid}\"")));
        assert!(lines[0].contains(&format!("\"bytesSoFar\":{}", body.len())));
        assert_eq!(
            lines[1],
            format!("{{\"event\":\"complete\",\"oid\":\"{oid}\"}}")
        );
        assert!(store.contains(&format!("repo/lfs/{oid}")));
    }

    #[tokio::test]
    async fn download_writes_file_and_emits_progress_then_complete() {
        let store = MockStore::new();
        let oid = good_oid();
        let body = b"payload bytes";
        store.insert(format!("repo/lfs/{oid}"), Bytes::from_static(body));
        let tmp = TempDir::new().unwrap();
        let a = agent(store, Some("repo"), &tmp);
        let mut out = Vec::new();
        a.download(&oid, body.len() as u64, &mut out)
            .await
            .expect("download");
        let got = String::from_utf8(out).unwrap();
        let lines: Vec<&str> = got.lines().collect();
        assert_eq!(lines.len(), 2, "expected progress + complete: {got}");
        assert!(lines[0].contains("\"event\":\"progress\""));
        let dest = tmp.path().join(oid.as_str());
        let dest_str = dest.to_str().unwrap();
        assert!(
            lines[1].contains(&format!("\"path\":\"{dest_str}\"")),
            "complete should include path: {got}"
        );
        let read = tokio::fs::read(&dest).await.unwrap();
        assert_eq!(read, body);
    }

    #[tokio::test]
    async fn download_emits_error_on_missing_object() {
        let store = MockStore::new();
        let oid = good_oid();
        let tmp = TempDir::new().unwrap();
        let a = agent(store, Some("repo"), &tmp);
        let mut out = Vec::new();
        a.download(&oid, 0, &mut out).await.expect("dispatch ok");
        let got = String::from_utf8(out).unwrap();
        assert!(got.contains("\"error\""));
        assert!(got.contains(&format!("\"oid\":\"{oid}\"")));
    }

    #[tokio::test]
    async fn empty_prefix_yields_top_level_lfs_key() {
        let store = MockStore::new();
        let tmp = TempDir::new().unwrap();
        let oid = good_oid();
        let a = agent(store.clone(), None, &tmp);
        let src = tmp.path().join("body");
        tokio::fs::write(&src, b"x").await.unwrap();
        let mut out = Vec::new();
        a.upload(&oid, 1, &src, &mut out).await.expect("upload");
        assert!(store.contains(&format!("lfs/{oid}")));
    }

    #[tokio::test]
    async fn upload_emits_chunked_progress_for_multipart_body() {
        let store = MockStore::new();
        store.set_progress_chunk_size(Some(8));
        let oid = good_oid();
        let tmp = TempDir::new().unwrap();
        let a = agent(store.clone(), Some("repo"), &tmp);
        let src = tmp.path().join("body");
        let body = b"abcdefghijklmnopqrstuvwxyz0123456789";
        tokio::fs::write(&src, body).await.unwrap();
        let mut out = Vec::new();
        a.upload(&oid, body.len() as u64, &src, &mut out)
            .await
            .expect("upload");
        let got = String::from_utf8(out).unwrap();
        let progress = progress_lines(&got);
        assert!(
            progress.len() >= 2,
            "expected ≥ 2 progress events for a body of {} bytes at chunk=8: {got}",
            body.len()
        );
        let mut last_so_far = 0u64;
        for line in &progress {
            let (so_far, since) = parse_progress(line);
            assert!(
                so_far >= last_so_far,
                "bytesSoFar must be monotonic non-decreasing: {got}"
            );
            assert!(since > 0, "bytesSinceLast must be positive: {line}");
            last_so_far = so_far;
        }
        assert_eq!(
            last_so_far,
            body.len() as u64,
            "final bytesSoFar must equal size: {got}"
        );
        assert!(store.contains(&format!("repo/lfs/{oid}")));
    }

    #[tokio::test]
    async fn download_emits_chunked_progress_for_multipart_body() {
        let store = MockStore::new();
        store.set_progress_chunk_size(Some(4));
        let oid = good_oid();
        let body: Vec<u8> = (0u8..=20).collect();
        store.insert(format!("repo/lfs/{oid}"), Bytes::from(body.clone()));
        let tmp = TempDir::new().unwrap();
        let a = agent(store, Some("repo"), &tmp);
        let mut out = Vec::new();
        a.download(&oid, body.len() as u64, &mut out)
            .await
            .expect("download");
        let got = String::from_utf8(out).unwrap();
        let progress = progress_lines(&got);
        assert!(progress.len() >= 2, "expected ≥ 2 progress events: {got}");
        let mut last_so_far = 0u64;
        for line in &progress {
            let (so_far, since) = parse_progress(line);
            assert!(so_far >= last_so_far, "monotonic: {got}");
            assert!(since > 0, "bytesSinceLast must be positive: {line}");
            last_so_far = so_far;
        }
        assert_eq!(last_so_far, body.len() as u64, "final equals size: {got}");
    }
}