1mod protocol;
2
3use anyhow::{Context as _, Result};
4use futures_util::{io::AsyncReadExt, AsyncWriteExt};
5use log::debug;
6use opendal::Operator;
7use protocol::*;
8use std::path::PathBuf;
9use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _};
10
/// Size in bytes (8 MiB) of the copy buffer and of the chunk size requested
/// from the remote reader and writer.
const DEFAULT_BUF_SIZE: usize = 8 << 20;
12
13pub struct Agent {
14 remote_op: Operator,
15 sender: tokio::sync::mpsc::Sender<String>,
16 tasks: tokio::task::JoinSet<()>,
17 root: PathBuf,
18}
19
20impl Agent {
21 pub fn new(remote_op: Operator, sender: tokio::sync::mpsc::Sender<String>) -> Self {
22 Self {
23 remote_op,
24 sender,
25 tasks: tokio::task::JoinSet::new(),
26 root: PathBuf::from(""),
27 }
28 }
29
30 pub async fn process(&mut self, request: &str) -> Result<()> {
31 debug!("request: {}", request);
32 let request: Request = serde_json::from_str(request).context("invalid request")?;
33 match request {
34 Request::Init => self.init().await,
35 Request::Upload { oid, path } => self.upload(oid, path).await,
36 Request::Download { oid } => self.download(oid).await,
37 Request::Terminate => self.terminate().await,
38 };
39 Ok(())
40 }
41
42 async fn init(&mut self) {
43 send_response(&self.sender, InitResponse::new().json()).await;
44 }
45
46 async fn upload(&mut self, oid: String, path: String) {
47 let remote_op = self.remote_op.clone();
48 let sender = self.sender.clone();
49 self.tasks.spawn(async move {
50 let status: Result<Option<String>> = async {
51 let mut reader =
52 tokio::io::BufReader::new(tokio::fs::File::open(path).await?).compat();
53 let mut writer = remote_op
54 .writer_with(&oid)
55 .chunk(DEFAULT_BUF_SIZE)
56 .await?
57 .into_futures_async_write();
58 copy_with_progress(&sender, &oid, &mut reader, &mut writer).await?;
59 writer.close().await?;
60 Ok(None)
61 }
62 .await;
63 send_response(&sender, TransferResponse::new(oid, status).json()).await;
64 });
65 }
66
67 async fn download(&mut self, oid: String) {
68 let remote_op = self.remote_op.clone();
69 let sender = self.sender.clone();
70 let path = self.root.join(lfs_object_path(&oid));
71 self.tasks.spawn(async move {
72 let status: Result<Option<String>> = async {
73 tokio::fs::create_dir_all(path.parent().unwrap()).await?;
74 let meta = remote_op.stat(&oid).await?;
75 let mut reader = remote_op
76 .reader_with(&oid)
77 .chunk(DEFAULT_BUF_SIZE)
78 .await?
79 .into_futures_async_read(0..meta.content_length())
80 .await?;
81 let mut writer =
82 tokio::io::BufWriter::new(tokio::fs::File::create(&path).await?).compat_write();
83 copy_with_progress(&sender, &oid, &mut reader, &mut writer).await?;
84 writer.close().await?;
85 Ok(Some(path.to_string_lossy().into()))
86 }
87 .await;
88 send_response(&sender, TransferResponse::new(oid, status).json()).await;
89 });
90 }
91
92 async fn terminate(&mut self) {
93 while self.tasks.join_next().await.is_some() {}
94 }
95}
96
97async fn send_response(sender: &tokio::sync::mpsc::Sender<String>, msg: String) {
98 debug!("response: {}", &msg);
99 sender.send(msg).await.unwrap();
100}
101
102async fn copy_with_progress<R, W>(
103 progress_sender: &tokio::sync::mpsc::Sender<String>,
104 oid: &str,
105 mut reader: R,
106 mut writer: W,
107) -> tokio::io::Result<usize>
108where
109 R: AsyncReadExt + Unpin,
110 W: AsyncWriteExt + Unpin,
111{
112 let mut bytes_so_far: usize = 0;
113 let mut buf = vec![0; DEFAULT_BUF_SIZE];
114
115 loop {
116 let bytes_since_last = reader.read(&mut buf).await?;
117 if bytes_since_last == 0 {
118 break;
119 }
120 writer.write_all(&buf[..bytes_since_last]).await?;
121 bytes_so_far += bytes_since_last;
122 send_response(
123 progress_sender,
124 ProgressResponse::new(oid.into(), bytes_so_far, bytes_since_last).json(),
125 )
126 .await;
127 }
128
129 Ok(bytes_so_far)
130}
131
/// Returns the relative object path for `oid`:
/// `.git/lfs/objects/<oid[0..2]>/<oid[2..4]>/<oid>`.
///
/// The two shard directories are added only when both prefixes can be sliced
/// out of `oid`. If `oid` is shorter than four bytes, or a multi-byte
/// character straddles one of the slice boundaries, the object is placed
/// directly under `.git/lfs/objects` instead of panicking on the index.
fn lfs_object_path(oid: &str) -> PathBuf {
    let mut path = PathBuf::from(".git/lfs/objects");
    // `str::get` returns `None` for out-of-range or non-boundary slices,
    // where plain indexing would panic.
    if let (Some(first), Some(second)) = (oid.get(0..2), oid.get(2..4)) {
        path.push(first);
        path.push(second);
    }
    path.push(oid);
    path
}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as _;
    use tempfile::NamedTempFile;
    use tokio::sync::mpsc::error::TryRecvError;

    /// Builds an agent backed by an in-memory operator, together with the
    /// receiving half of its response channel.
    fn agent_with_buf() -> (Agent, tokio::sync::mpsc::Receiver<String>) {
        let builder = opendal::services::Memory::default();
        let remote_op = opendal::Operator::new(builder).unwrap().finish();
        let (tx, rx) = tokio::sync::mpsc::channel(32);
        (Agent::new(remote_op, tx), rx)
    }

    /// An init request produces exactly one empty-object response.
    #[tokio::test]
    async fn init() {
        let (mut agent, mut output) = agent_with_buf();
        agent.process(r#"{"event":"init"}"#).await.unwrap();

        let reply = output.recv().await.unwrap();
        assert_eq!(&reply, "{}");
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));
    }

    /// An upload reports progress, then completion, and stores the bytes remotely.
    #[tokio::test]
    async fn upload() {
        let (mut agent, mut output) = agent_with_buf();
        let mut file = NamedTempFile::new().unwrap();
        file.write_all("test".as_bytes()).unwrap();

        let request = serde_json::json!({
            "event": "upload",
            "oid": "aabbcc",
            "path": file.path(),
        })
        .to_string();
        agent.process(&request).await.unwrap();

        let progress = output.recv().await.unwrap();
        assert_eq!(
            progress,
            r#"{"event":"progress","oid":"aabbcc","bytesSoFar":4,"bytesSinceLast":4}"#
        );
        let complete = output.recv().await.unwrap();
        assert_eq!(complete, r#"{"event":"complete","oid":"aabbcc"}"#);
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));

        let stored = agent.remote_op.read("aabbcc").await.unwrap().to_bytes();
        assert_eq!(stored, "test".as_bytes());
    }

    /// A download reports progress, then completion with the object path, and
    /// writes the bytes below the agent's root.
    #[tokio::test]
    async fn download() {
        let tempdir = tempfile::tempdir().unwrap();
        let (mut agent, mut output) = agent_with_buf();
        agent.root = tempdir.path().into();
        agent.remote_op.write("aabbcc", "test").await.unwrap();

        agent
            .process(r#"{"event":"download","oid":"aabbcc"}"#)
            .await
            .unwrap();

        let progress = output.recv().await.unwrap();
        assert_eq!(
            progress,
            r#"{"event":"progress","oid":"aabbcc","bytesSoFar":4,"bytesSinceLast":4}"#
        );

        let mut file_name = tempdir.path().join(".git/lfs/objects");
        file_name.extend(["aa", "bb", "aabbcc"]);
        let expected_complete = serde_json::json!({
            "event": "complete",
            "oid": "aabbcc",
            "path": file_name,
        })
        .to_string();
        let complete = output.recv().await.unwrap();
        assert_eq!(complete, expected_complete);
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));

        let written = std::fs::read_to_string(file_name).unwrap();
        assert_eq!(written, "test".to_string());
    }

    /// Terminating with no transfers in flight returns without error.
    #[tokio::test]
    async fn terminate() {
        let (mut agent, _) = agent_with_buf();
        agent.process(r#"{"event":"terminate"}"#).await.unwrap();
    }
}