lfs_dal/
lib.rs

1mod protocol;
2
3use anyhow::{Context as _, Result};
4use futures_util::{io::AsyncReadExt, AsyncWriteExt};
5use log::debug;
6use opendal::Operator;
7use protocol::*;
8use std::path::PathBuf;
9use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _};
10
/// Chunk size (8 MiB) used for OpenDAL reader/writer chunking and for the
/// in-memory buffer of each copy loop.
const DEFAULT_BUF_SIZE: usize = 8 << 20;
12
/// Agent that serves transfer requests (`init` / `upload` / `download` /
/// `terminate`) against an OpenDAL [`Operator`].
pub struct Agent {
    /// Remote storage that objects are written to and read from, keyed by oid.
    remote_op: Operator,
    /// Channel on which JSON-encoded responses are emitted.
    sender: tokio::sync::mpsc::Sender<String>,
    /// Spawned upload/download tasks; drained by `terminate`.
    tasks: tokio::task::JoinSet<()>,
    /// Directory that local object paths (`.git/lfs/objects/...`) are joined onto.
    root: PathBuf,
}
19
20impl Agent {
21    pub fn new(remote_op: Operator, sender: tokio::sync::mpsc::Sender<String>) -> Self {
22        Self {
23            remote_op,
24            sender,
25            tasks: tokio::task::JoinSet::new(),
26            root: PathBuf::from(""),
27        }
28    }
29
30    pub async fn process(&mut self, request: &str) -> Result<()> {
31        debug!("request: {}", request);
32        let request: Request = serde_json::from_str(request).context("invalid request")?;
33        match request {
34            Request::Init => self.init().await,
35            Request::Upload { oid, path } => self.upload(oid, path).await,
36            Request::Download { oid } => self.download(oid).await,
37            Request::Terminate => self.terminate().await,
38        };
39        Ok(())
40    }
41
42    async fn init(&mut self) {
43        send_response(&self.sender, InitResponse::new().json()).await;
44    }
45
46    async fn upload(&mut self, oid: String, path: String) {
47        let remote_op = self.remote_op.clone();
48        let sender = self.sender.clone();
49        self.tasks.spawn(async move {
50            let status: Result<Option<String>> = async {
51                let mut reader =
52                    tokio::io::BufReader::new(tokio::fs::File::open(path).await?).compat();
53                let mut writer = remote_op
54                    .writer_with(&oid)
55                    .chunk(DEFAULT_BUF_SIZE)
56                    .await?
57                    .into_futures_async_write();
58                copy_with_progress(&sender, &oid, &mut reader, &mut writer).await?;
59                writer.close().await?;
60                Ok(None)
61            }
62            .await;
63            send_response(&sender, TransferResponse::new(oid, status).json()).await;
64        });
65    }
66
67    async fn download(&mut self, oid: String) {
68        let remote_op = self.remote_op.clone();
69        let sender = self.sender.clone();
70        let path = self.root.join(lfs_object_path(&oid));
71        self.tasks.spawn(async move {
72            let status: Result<Option<String>> = async {
73                tokio::fs::create_dir_all(path.parent().unwrap()).await?;
74                let meta = remote_op.stat(&oid).await?;
75                let mut reader = remote_op
76                    .reader_with(&oid)
77                    .chunk(DEFAULT_BUF_SIZE)
78                    .await?
79                    .into_futures_async_read(0..meta.content_length())
80                    .await?;
81                let mut writer =
82                    tokio::io::BufWriter::new(tokio::fs::File::create(&path).await?).compat_write();
83                copy_with_progress(&sender, &oid, &mut reader, &mut writer).await?;
84                writer.close().await?;
85                Ok(Some(path.to_string_lossy().into()))
86            }
87            .await;
88            send_response(&sender, TransferResponse::new(oid, status).json()).await;
89        });
90    }
91
92    async fn terminate(&mut self) {
93        while self.tasks.join_next().await.is_some() {}
94    }
95}
96
97async fn send_response(sender: &tokio::sync::mpsc::Sender<String>, msg: String) {
98    debug!("response: {}", &msg);
99    sender.send(msg).await.unwrap();
100}
101
102async fn copy_with_progress<R, W>(
103    progress_sender: &tokio::sync::mpsc::Sender<String>,
104    oid: &str,
105    mut reader: R,
106    mut writer: W,
107) -> tokio::io::Result<usize>
108where
109    R: AsyncReadExt + Unpin,
110    W: AsyncWriteExt + Unpin,
111{
112    let mut bytes_so_far: usize = 0;
113    let mut buf = vec![0; DEFAULT_BUF_SIZE];
114
115    loop {
116        let bytes_since_last = reader.read(&mut buf).await?;
117        if bytes_since_last == 0 {
118            break;
119        }
120        writer.write_all(&buf[..bytes_since_last]).await?;
121        bytes_so_far += bytes_since_last;
122        send_response(
123            progress_sender,
124            ProgressResponse::new(oid.into(), bytes_so_far, bytes_since_last).json(),
125        )
126        .await;
127    }
128
129    Ok(bytes_so_far)
130}
131
/// Returns the repository-relative path of object `oid`:
/// `.git/lfs/objects/<oid[0..2]>/<oid[2..4]>/<oid>`.
///
/// # Panics
/// Panics if `oid` is shorter than four bytes, or if byte offsets 2 or 4 are
/// not UTF-8 character boundaries.
fn lfs_object_path(oid: &str) -> PathBuf {
    let outer_shard = &oid[..2];
    let inner_shard = &oid[2..4];
    let mut object_path = PathBuf::from(".git/lfs/objects");
    object_path.push(outer_shard);
    object_path.push(inner_shard);
    object_path.push(oid);
    object_path
}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write as _;
    use tempfile::NamedTempFile;
    use tokio::sync::mpsc::error::TryRecvError;

    /// Builds an agent over an in-memory OpenDAL operator, together with the
    /// receiver that collects its responses.
    fn agent_with_buf() -> (Agent, tokio::sync::mpsc::Receiver<String>) {
        let remote_op = opendal::Operator::new(opendal::services::Memory::default())
            .unwrap()
            .finish();
        let (tx, rx) = tokio::sync::mpsc::channel(32);
        let agent = Agent::new(remote_op, tx);
        (agent, rx)
    }

    #[tokio::test]
    async fn init() {
        let (mut agent, mut output) = agent_with_buf();
        agent.process(r#"{"event":"init"}"#).await.unwrap();
        assert_eq!(&output.recv().await.unwrap(), "{}");
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));
    }

    #[tokio::test]
    async fn upload() {
        let (mut agent, mut output) = agent_with_buf();
        let mut file = NamedTempFile::new().unwrap();
        file.write_all("test".as_bytes()).unwrap();
        agent
            .process(
                &serde_json::json!({
                    "event": "upload",
                    "oid": "aabbcc",
                    "path": file.path(),
                })
                .to_string(),
            )
            .await
            .unwrap();
        assert_eq!(
            output.recv().await.unwrap(),
            r#"{"event":"progress","oid":"aabbcc","bytesSoFar":4,"bytesSinceLast":4}"#
        );
        assert_eq!(
            output.recv().await.unwrap(),
            r#"{"event":"complete","oid":"aabbcc"}"#
        );
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));
        assert_eq!(
            agent.remote_op.read("aabbcc").await.unwrap().to_bytes(),
            "test".as_bytes()
        );
    }

    #[tokio::test]
    async fn download() {
        let tempdir = tempfile::tempdir().unwrap();
        let (mut agent, mut output) = agent_with_buf();
        agent.root = tempdir.path().into();
        agent.remote_op.write("aabbcc", "test").await.unwrap();
        agent
            .process(r#"{"event":"download","oid":"aabbcc"}"#)
            .await
            .unwrap();
        assert_eq!(
            output.recv().await.unwrap(),
            r#"{"event":"progress","oid":"aabbcc","bytesSoFar":4,"bytesSinceLast":4}"#
        );
        let file_name = tempdir
            .path()
            .join(".git/lfs/objects")
            .join("aa")
            .join("bb")
            .join("aabbcc");
        assert_eq!(
            output.recv().await.unwrap(),
            serde_json::json!({
                "event": "complete",
                "oid": "aabbcc",
                "path": file_name,
            })
            .to_string()
        );
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));
        assert_eq!(
            std::fs::read_to_string(file_name).unwrap(),
            "test".to_string()
        );
    }

    #[tokio::test]
    async fn terminate() {
        // Keep the receiver alive so an unexpected response would be observed
        // here instead of making the sender panic.
        let (mut agent, mut output) = agent_with_buf();
        agent.process(r#"{"event":"terminate"}"#).await.unwrap();
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));
    }

    #[tokio::test]
    async fn terminate_waits_for_pending_transfers() {
        let (mut agent, mut output) = agent_with_buf();
        let mut file = NamedTempFile::new().unwrap();
        file.write_all("test".as_bytes()).unwrap();
        agent
            .process(
                &serde_json::json!({
                    "event": "upload",
                    "oid": "aabbcc",
                    "path": file.path(),
                })
                .to_string(),
            )
            .await
            .unwrap();
        agent.process(r#"{"event":"terminate"}"#).await.unwrap();
        // `terminate` returns only after the upload task has finished, so both
        // of its responses must already be queued (no awaiting needed).
        assert_eq!(
            output.try_recv().unwrap(),
            r#"{"event":"progress","oid":"aabbcc","bytesSoFar":4,"bytesSinceLast":4}"#
        );
        assert_eq!(
            output.try_recv().unwrap(),
            r#"{"event":"complete","oid":"aabbcc"}"#
        );
        assert_eq!(output.try_recv(), Err(TryRecvError::Empty));
        assert_eq!(
            agent.remote_op.read("aabbcc").await.unwrap().to_bytes(),
            "test".as_bytes()
        );
    }
}