1use std::path::Path;
2
3use axum::extract::ws::{Message as AxumMessage, WebSocket};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7 fs::{create_dir_all, rename, write, File},
8 io::{AsyncReadExt, AsyncWriteExt},
9};
10
11use crate::{
12 fs::{RepoGuard, WsvcFsError},
13 model::{Blob, Record, Repository, Tree},
14 WsvcError,
15};
16
17#[derive(Error, Debug)]
19pub enum WsvcServerError {
20 #[error("fs error: {0}")]
21 WsvcError(#[from] WsvcError),
22 #[error("serialization error: {0}")]
23 SerializationError(#[from] serde_json::Error),
24 #[error("network error: {0}")]
25 NetworkError(#[from] axum::Error),
26 #[error("data error: {0}")]
27 DataError(String),
28}
29
30#[derive(Serialize, Deserialize, Clone, Debug)]
31pub struct RecordWithState {
32 pub record: Record,
33 pub state: i32,
35}
36
37#[derive(Serialize, Deserialize, Clone, Debug)]
38pub struct TreeWithState {
39 pub tree: Tree,
40 pub state: i32,
42}
43
44#[derive(Serialize, Deserialize, Clone, Debug)]
45pub struct BlobWithState {
46 pub blob: Blob,
47 pub state: i32,
49}
50
51async fn send_data(ws: &mut WebSocket, data: Vec<u8>) -> Result<(), WsvcServerError> {
52 let mut header_buf = [0x33u8, 0x07u8, 0u8, 0u8, 0u8, 0u8];
53 let size = data.len();
54 header_buf[2] = (size >> 24) as u8;
55 header_buf[3] = (size >> 16) as u8;
56 header_buf[4] = (size >> 8) as u8;
57 header_buf[5] = size as u8;
58 ws.send(header_buf[..].into()).await?;
59 let mut offset = 0;
61 while offset < data.len() {
62 let end = offset + 16384;
63 let end = if end > data.len() { data.len() } else { end };
64 ws.send(data[offset..end].into()).await?;
65 offset = end;
66 }
67 Ok(())
68}
69
70async fn recv_data(ws: &mut WebSocket) -> Result<Vec<u8>, WsvcServerError> {
71 if let Some(Ok(AxumMessage::Binary(msg))) = ws.recv().await {
73 let mut header_buf = [0u8; 6];
74 header_buf.copy_from_slice(&msg[..6]);
75 if header_buf[0] != 0x33 || header_buf[1] != 0x07 {
76 return Err(WsvcServerError::DataError(
77 "invalid packet header".to_owned(),
78 ));
79 }
80 let size = ((header_buf[2] as usize) << 24)
81 + ((header_buf[3] as usize) << 16)
82 + ((header_buf[4] as usize) << 8)
83 + (header_buf[5] as usize);
84 let mut data = Vec::with_capacity(size);
85 data.extend_from_slice(&msg[6..]);
86 let mut offset = data.len();
87 while offset < size {
88 if let Some(Ok(AxumMessage::Binary(msg))) = ws.recv().await {
89 data.extend_from_slice(&msg);
90 offset = data.len();
91 }
92 }
93 Ok(data)
94 } else {
95 Err(WsvcServerError::DataError(
96 "invalid packet header".to_owned(),
97 ))
98 }
99}
100
101async fn send_file(
102 ws: &mut WebSocket,
103 file_name: &str,
104 mut file: File,
105) -> Result<(), WsvcServerError> {
106 let mut header_buf = [0x09u8, 0x28u8, 0u8, 0u8];
108 let file_name_size = file_name.len();
109 if file_name_size > 16384 {
110 return Err(WsvcServerError::DataError("file name too long".to_owned()));
111 }
112 header_buf[2] = (file_name_size >> 8) as u8;
113 header_buf[3] = file_name_size as u8;
114 ws.send(header_buf[..].into()).await?;
115 let mut file_header_buf = [0x07u8, 0x15u8, 0u8, 0u8, 0u8, 0u8];
116 let mut buf = [0u8; 16384];
117 let size = file
118 .metadata()
119 .await
120 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?
121 .len() as usize;
122 file_header_buf[2] = (size >> 24) as u8;
123 file_header_buf[3] = (size >> 16) as u8;
124 file_header_buf[4] = (size >> 8) as u8;
125 file_header_buf[5] = size as u8;
126 ws.send(file_header_buf[..].into()).await?;
127 let mut offset = 0;
128 while offset != size {
129 let read_size = file
130 .read(&mut buf)
131 .await
132 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
133 ws.send(buf[..read_size].into()).await?;
134 offset += read_size;
135 }
136 Ok(())
137}
138
139async fn recv_file(
140 ws: &mut WebSocket,
141 storage_dir: impl AsRef<Path>,
142) -> Result<(), WsvcServerError> {
143 let file_name_header = ws
144 .recv()
145 .await
146 .ok_or(WsvcServerError::DataError(format!(
147 "invalid file name header: {}",
148 "none"
149 )))?
150 .map_err(WsvcServerError::NetworkError)?;
151 let mut file_name_header_buf = [0u8; 4];
152 if let AxumMessage::Binary(msg) = file_name_header {
153 file_name_header_buf.copy_from_slice(&msg[..4]);
154 } else {
155 return Err(WsvcServerError::DataError(format!(
156 "invalid file name header: {:?}",
157 file_name_header
158 )));
159 }
160 if file_name_header_buf[0] != 0x09 || file_name_header_buf[1] != 0x28 {
161 return Err(WsvcServerError::DataError(format!(
162 "invalid file name header: {:?}",
163 file_name_header_buf
164 )));
165 }
166 let file_name_size =
167 ((file_name_header_buf[2] as usize) << 8) + (file_name_header_buf[3] as usize);
168 let file_name = ws
169 .recv()
170 .await
171 .ok_or(WsvcServerError::DataError(format!(
172 "invalid file name: {}",
173 "none"
174 )))?
175 .map_err(WsvcServerError::NetworkError)?;
176 let file_name = if let AxumMessage::Binary(msg) = file_name {
177 String::from_utf8(msg[..file_name_size].to_vec())
178 .map_err(|err| WsvcServerError::DataError(err.to_string()))?
179 } else {
180 return Err(WsvcServerError::DataError(format!(
181 "invalid file name: {:?}",
182 file_name
183 )));
184 };
185 let file_path = storage_dir.as_ref().join(file_name);
186 let file_header = ws
187 .recv()
188 .await
189 .ok_or(WsvcServerError::DataError("invalid file header".to_owned()))?
190 .map_err(WsvcServerError::NetworkError)?;
191 let mut file_header_buf = [0u8; 6];
192 if let AxumMessage::Binary(msg) = file_header {
193 file_header_buf.copy_from_slice(&msg[..6]);
194 } else {
195 return Err(WsvcServerError::DataError("invalid file header".to_owned()));
196 }
197 if file_header_buf[0] != 0x07 || file_header_buf[1] != 0x15 {
198 return Err(WsvcServerError::DataError("invalid file header".to_owned()));
199 }
200 let size = ((file_header_buf[2] as usize) << 24)
201 + ((file_header_buf[3] as usize) << 16)
202 + ((file_header_buf[4] as usize) << 8)
203 + (file_header_buf[5] as usize);
204 let mut file = File::create(&file_path)
205 .await
206 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
207 let mut offset = 0;
208 while offset != size {
209 let data = ws
210 .recv()
211 .await
212 .ok_or(WsvcServerError::DataError("invalid file data".to_owned()))?
213 .map_err(WsvcServerError::NetworkError)?;
214 if let AxumMessage::Binary(data) = data {
215 offset += data.len();
216 file.write(&data)
217 .await
218 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
219 } else {
220 return Err(WsvcServerError::DataError("invalid file data".to_owned()));
221 }
222 }
223
224 Ok(())
225}
226
227async fn sync_records(
232 repo: &Repository,
233 ws: &mut WebSocket,
234) -> Result<(Vec<Record>, Vec<Record>), WsvcServerError> {
235 tracing::debug!("ROUND 1: sync records...");
238 let records = repo.get_records().await.map_err(WsvcError::FsError)?;
239 let packet_body = serde_json::to_string(&records)?;
240 tracing::trace!("send records: {:?}", records);
241 send_data(ws, packet_body.into_bytes()).await?;
242 let diff_records = recv_data(ws).await?;
243 tracing::trace!("recv diff records: {:?}", diff_records);
244 let diff_records: Vec<RecordWithState> = serde_json::from_slice(&diff_records)?;
245 let wanted_records = diff_records
246 .iter()
247 .filter(|r| r.state == 1)
248 .map(|r| r.record.clone())
249 .collect::<Vec<_>>();
250 let will_given_records = diff_records
252 .iter()
253 .filter(|r| r.state == 2)
254 .map(|r| r.record.clone())
255 .collect::<Vec<_>>();
256 Ok((wanted_records, will_given_records))
257}
258
259async fn sync_trees(
260 repo: &Repository,
261 ws: &mut WebSocket,
262 wanted_records: &[Record],
263) -> Result<(Vec<Tree>, Vec<Tree>), WsvcServerError> {
264 tracing::debug!("ROUND 2: sync trees...");
265 let mut trees = Vec::new();
266 for record in wanted_records {
267 trees.extend_from_slice(
268 &repo
269 .get_trees_of_record(&record.hash)
270 .await
271 .map_err(WsvcError::FsError)?,
272 );
273 }
274 let packet_body = serde_json::to_string(&trees)?;
275 tracing::trace!("send trees: {:?}", trees);
276 send_data(ws, packet_body.into_bytes()).await?;
277 let diff_trees = recv_data(ws).await?;
278 tracing::trace!("recv diff trees: {:?}", diff_trees);
279 let diff_trees: Vec<TreeWithState> = serde_json::from_slice(&diff_trees)?;
280 let wanted_trees = diff_trees
281 .iter()
282 .filter(|t| t.state == 1)
283 .map(|t| t.tree.clone())
284 .collect::<Vec<_>>();
285 let will_given_trees = diff_trees
286 .iter()
287 .filter(|t| t.state == 2)
288 .map(|t| t.tree.clone())
289 .collect::<Vec<_>>();
290 Ok((wanted_trees, will_given_trees))
291}
292
293async fn sync_blobs_meta(
294 repo: &Repository,
295 ws: &mut WebSocket,
296 wanted_trees: &[Tree],
297) -> Result<(Vec<Blob>, Vec<Blob>), WsvcServerError> {
298 tracing::debug!("ROUND 3: sync blobs meta...");
299 let mut blobs = Vec::new();
300 for tree in wanted_trees {
301 blobs.extend_from_slice(
302 &repo
303 .get_blobs_of_tree(&tree.hash)
304 .await
305 .map_err(WsvcError::FsError)?,
306 );
307 }
308 let packet_body = serde_json::to_string(&blobs)?;
309 tracing::trace!("send blobs meta: {:?}", blobs);
310 send_data(ws, packet_body.into_bytes()).await?;
311 let diff_blobs = recv_data(ws).await?;
312 tracing::trace!("recv diff blobs meta: {:?}", diff_blobs);
313 let diff_blobs: Vec<BlobWithState> = serde_json::from_slice(&diff_blobs)?;
314 let wanted_blobs = diff_blobs
315 .iter()
316 .filter(|b| b.state == 1)
317 .map(|b| b.blob.clone())
318 .collect::<Vec<_>>();
319 let will_given_blobs = diff_blobs
320 .iter()
321 .filter(|b| b.state == 2)
322 .map(|b| b.blob.clone())
323 .collect::<Vec<_>>();
324 Ok((wanted_blobs, will_given_blobs))
325}
326
327async fn sync_blobs(
328 repo: &Repository,
329 ws: &mut WebSocket,
330 wanted_blobs: &[Blob],
331 will_given_blobs: &[Blob],
332) -> Result<(), WsvcServerError> {
333 tracing::debug!("ROUND 4: sync blobs...");
334 let objects_dir = repo
335 .objects_dir()
336 .await
337 .map_err(|err| WsvcError::from(err))?;
338 let temp_objects_dir = repo
339 .temp_dir()
340 .await
341 .map_err(|err| WsvcError::from(err))?
342 .join("objects");
343 if !temp_objects_dir.exists() {
344 create_dir_all(&temp_objects_dir)
345 .await
346 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
347 }
348 for i in wanted_blobs {
349 let object_file = objects_dir.join(i.hash.0.to_string());
350 let file = File::open(&object_file)
351 .await
352 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
353 tracing::trace!("send blob file: {:?}", i);
354 send_file(ws, &i.hash.0.to_string(), file).await?;
355 }
356 for _ in 0..will_given_blobs.len() {
357 recv_file(ws, &temp_objects_dir).await?;
358 }
359 for i in will_given_blobs {
360 let object_file = temp_objects_dir.join(i.hash.0.to_string());
361 if !object_file.exists() {
362 return Err(WsvcServerError::DataError(format!(
363 "blob file not exists: {:?}",
364 object_file
365 )));
366 }
367 }
368 for i in will_given_blobs {
369 rename(
370 temp_objects_dir.join(i.hash.0.to_string()),
371 objects_dir.join(i.hash.0.to_string()),
372 )
373 .await
374 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
375 }
376 Ok(())
377}
378
379pub async fn sync_with(repo: &Repository, ws: &mut WebSocket) -> Result<(), WsvcServerError> {
399 let guard = RepoGuard::new(repo)
400 .await
401 .map_err(|err| WsvcError::FsError(err))?;
402 let (wanted_records, given_records) = sync_records(repo, ws).await?;
403 let (wanted_trees, given_trees) = sync_trees(repo, ws, wanted_records.as_slice()).await?;
404 let (wanted_blobs, will_given_blobs) =
405 sync_blobs_meta(repo, ws, wanted_trees.as_slice()).await?;
406 sync_blobs(
408 repo,
409 ws,
410 wanted_blobs.as_slice(),
411 will_given_blobs.as_slice(),
412 )
413 .await?;
414
415 tracing::debug!("write trees to tree database...");
417 let trees_dir = repo.trees_dir().await.map_err(WsvcError::FsError)?;
418 for tree in &given_trees {
419 write(
420 trees_dir.join(tree.hash.0.to_hex().as_str()),
421 serde_json::to_string(tree)
422 .map_err(|err| WsvcError::FsError(WsvcFsError::SerializationFailed(err)))?,
423 )
424 .await
425 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
426 }
427
428 tracing::debug!("write records to record database...");
430 let records_dir = repo.records_dir().await.map_err(WsvcError::FsError)?;
431 for record in &given_records {
432 write(
433 records_dir.join(record.hash.0.to_hex().as_str()),
434 serde_json::to_string(record)
435 .map_err(|err| WsvcError::FsError(WsvcFsError::SerializationFailed(err)))?,
436 )
437 .await
438 .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
439 }
440
441 drop(guard);
442 Ok(())
443}