// wsvc/server/mod.rs

1use std::path::Path;
2
3use axum::extract::ws::{Message as AxumMessage, WebSocket};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7    fs::{create_dir_all, rename, write, File},
8    io::{AsyncReadExt, AsyncWriteExt},
9};
10
11use crate::{
12    fs::{RepoGuard, WsvcFsError},
13    model::{Blob, Record, Repository, Tree},
14    WsvcError,
15};
16
17/// `WsvcServerError` stand for server error.
18#[derive(Error, Debug)]
19pub enum WsvcServerError {
20    #[error("fs error: {0}")]
21    WsvcError(#[from] WsvcError),
22    #[error("serialization error: {0}")]
23    SerializationError(#[from] serde_json::Error),
24    #[error("network error: {0}")]
25    NetworkError(#[from] axum::Error),
26    #[error("data error: {0}")]
27    DataError(String),
28}
29
30#[derive(Serialize, Deserialize, Clone, Debug)]
31pub struct RecordWithState {
32    pub record: Record,
33    /// 0: same, 1: wanted, 2: will-give
34    pub state: i32,
35}
36
37#[derive(Serialize, Deserialize, Clone, Debug)]
38pub struct TreeWithState {
39    pub tree: Tree,
40    /// 0: same, 1: wanted, 2: will-give
41    pub state: i32,
42}
43
44#[derive(Serialize, Deserialize, Clone, Debug)]
45pub struct BlobWithState {
46    pub blob: Blob,
47    /// 0: same, 1: wanted, 2: will-give
48    pub state: i32,
49}
50
51async fn send_data(ws: &mut WebSocket, data: Vec<u8>) -> Result<(), WsvcServerError> {
52    let mut header_buf = [0x33u8, 0x07u8, 0u8, 0u8, 0u8, 0u8];
53    let size = data.len();
54    header_buf[2] = (size >> 24) as u8;
55    header_buf[3] = (size >> 16) as u8;
56    header_buf[4] = (size >> 8) as u8;
57    header_buf[5] = size as u8;
58    ws.send(header_buf[..].into()).await?;
59    // split data into 16384 bytes
60    let mut offset = 0;
61    while offset < data.len() {
62        let end = offset + 16384;
63        let end = if end > data.len() { data.len() } else { end };
64        ws.send(data[offset..end].into()).await?;
65        offset = end;
66    }
67    Ok(())
68}
69
70async fn recv_data(ws: &mut WebSocket) -> Result<Vec<u8>, WsvcServerError> {
71    // match header and get size
72    if let Some(Ok(AxumMessage::Binary(msg))) = ws.recv().await {
73        let mut header_buf = [0u8; 6];
74        header_buf.copy_from_slice(&msg[..6]);
75        if header_buf[0] != 0x33 || header_buf[1] != 0x07 {
76            return Err(WsvcServerError::DataError(
77                "invalid packet header".to_owned(),
78            ));
79        }
80        let size = ((header_buf[2] as usize) << 24)
81            + ((header_buf[3] as usize) << 16)
82            + ((header_buf[4] as usize) << 8)
83            + (header_buf[5] as usize);
84        let mut data = Vec::with_capacity(size);
85        data.extend_from_slice(&msg[6..]);
86        let mut offset = data.len();
87        while offset < size {
88            if let Some(Ok(AxumMessage::Binary(msg))) = ws.recv().await {
89                data.extend_from_slice(&msg);
90                offset = data.len();
91            }
92        }
93        Ok(data)
94    } else {
95        Err(WsvcServerError::DataError(
96            "invalid packet header".to_owned(),
97        ))
98    }
99}
100
101async fn send_file(
102    ws: &mut WebSocket,
103    file_name: &str,
104    mut file: File,
105) -> Result<(), WsvcServerError> {
106    // file name packet header: 0x09 0x28 [size], 9.28 is Kamisato Ayaka's birthday
107    let mut header_buf = [0x09u8, 0x28u8, 0u8, 0u8];
108    let file_name_size = file_name.len();
109    if file_name_size > 16384 {
110        return Err(WsvcServerError::DataError("file name too long".to_owned()));
111    }
112    header_buf[2] = (file_name_size >> 8) as u8;
113    header_buf[3] = file_name_size as u8;
114    ws.send(header_buf[..].into()).await?;
115    let mut file_header_buf = [0x07u8, 0x15u8, 0u8, 0u8, 0u8, 0u8];
116    let mut buf = [0u8; 16384];
117    let size = file
118        .metadata()
119        .await
120        .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?
121        .len() as usize;
122    file_header_buf[2] = (size >> 24) as u8;
123    file_header_buf[3] = (size >> 16) as u8;
124    file_header_buf[4] = (size >> 8) as u8;
125    file_header_buf[5] = size as u8;
126    ws.send(file_header_buf[..].into()).await?;
127    let mut offset = 0;
128    while offset != size {
129        let read_size = file
130            .read(&mut buf)
131            .await
132            .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
133        ws.send(buf[..read_size].into()).await?;
134        offset += read_size;
135    }
136    Ok(())
137}
138
139async fn recv_file(
140    ws: &mut WebSocket,
141    storage_dir: impl AsRef<Path>,
142) -> Result<(), WsvcServerError> {
143    let file_name_header = ws
144        .recv()
145        .await
146        .ok_or(WsvcServerError::DataError(format!(
147            "invalid file name header: {}",
148            "none"
149        )))?
150        .map_err(WsvcServerError::NetworkError)?;
151    let mut file_name_header_buf = [0u8; 4];
152    if let AxumMessage::Binary(msg) = file_name_header {
153        file_name_header_buf.copy_from_slice(&msg[..4]);
154    } else {
155        return Err(WsvcServerError::DataError(format!(
156            "invalid file name header: {:?}",
157            file_name_header
158        )));
159    }
160    if file_name_header_buf[0] != 0x09 || file_name_header_buf[1] != 0x28 {
161        return Err(WsvcServerError::DataError(format!(
162            "invalid file name header: {:?}",
163            file_name_header_buf
164        )));
165    }
166    let file_name_size =
167        ((file_name_header_buf[2] as usize) << 8) + (file_name_header_buf[3] as usize);
168    let file_name = ws
169        .recv()
170        .await
171        .ok_or(WsvcServerError::DataError(format!(
172            "invalid file name: {}",
173            "none"
174        )))?
175        .map_err(WsvcServerError::NetworkError)?;
176    let file_name = if let AxumMessage::Binary(msg) = file_name {
177        String::from_utf8(msg[..file_name_size].to_vec())
178            .map_err(|err| WsvcServerError::DataError(err.to_string()))?
179    } else {
180        return Err(WsvcServerError::DataError(format!(
181            "invalid file name: {:?}",
182            file_name
183        )));
184    };
185    let file_path = storage_dir.as_ref().join(file_name);
186    let file_header = ws
187        .recv()
188        .await
189        .ok_or(WsvcServerError::DataError("invalid file header".to_owned()))?
190        .map_err(WsvcServerError::NetworkError)?;
191    let mut file_header_buf = [0u8; 6];
192    if let AxumMessage::Binary(msg) = file_header {
193        file_header_buf.copy_from_slice(&msg[..6]);
194    } else {
195        return Err(WsvcServerError::DataError("invalid file header".to_owned()));
196    }
197    if file_header_buf[0] != 0x07 || file_header_buf[1] != 0x15 {
198        return Err(WsvcServerError::DataError("invalid file header".to_owned()));
199    }
200    let size = ((file_header_buf[2] as usize) << 24)
201        + ((file_header_buf[3] as usize) << 16)
202        + ((file_header_buf[4] as usize) << 8)
203        + (file_header_buf[5] as usize);
204    let mut file = File::create(&file_path)
205        .await
206        .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
207    let mut offset = 0;
208    while offset != size {
209        let data = ws
210            .recv()
211            .await
212            .ok_or(WsvcServerError::DataError("invalid file data".to_owned()))?
213            .map_err(WsvcServerError::NetworkError)?;
214        if let AxumMessage::Binary(data) = data {
215            offset += data.len();
216            file.write(&data)
217                .await
218                .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
219        } else {
220            return Err(WsvcServerError::DataError("invalid file data".to_owned()));
221        }
222    }
223
224    Ok(())
225}
226
227/// `sync_records` syncs records with client.
228///
229/// ## returns
230/// (wanted_records, will_given_records)
231async fn sync_records(
232    repo: &Repository,
233    ws: &mut WebSocket,
234) -> Result<(Vec<Record>, Vec<Record>), WsvcServerError> {
235    // packet header: 0x33 0x07 [size]
236    // the first round for server, pack all record and send it to client
237    tracing::debug!("ROUND 1: sync records...");
238    let records = repo.get_records().await.map_err(WsvcError::FsError)?;
239    let packet_body = serde_json::to_string(&records)?;
240    tracing::trace!("send records: {:?}", records);
241    send_data(ws, packet_body.into_bytes()).await?;
242    let diff_records = recv_data(ws).await?;
243    tracing::trace!("recv diff records: {:?}", diff_records);
244    let diff_records: Vec<RecordWithState> = serde_json::from_slice(&diff_records)?;
245    let wanted_records = diff_records
246        .iter()
247        .filter(|r| r.state == 1)
248        .map(|r| r.record.clone())
249        .collect::<Vec<_>>();
250    // do not store records until trees and blobs are synced.
251    let will_given_records = diff_records
252        .iter()
253        .filter(|r| r.state == 2)
254        .map(|r| r.record.clone())
255        .collect::<Vec<_>>();
256    Ok((wanted_records, will_given_records))
257}
258
259async fn sync_trees(
260    repo: &Repository,
261    ws: &mut WebSocket,
262    wanted_records: &[Record],
263) -> Result<(Vec<Tree>, Vec<Tree>), WsvcServerError> {
264    tracing::debug!("ROUND 2: sync trees...");
265    let mut trees = Vec::new();
266    for record in wanted_records {
267        trees.extend_from_slice(
268            &repo
269                .get_trees_of_record(&record.hash)
270                .await
271                .map_err(WsvcError::FsError)?,
272        );
273    }
274    let packet_body = serde_json::to_string(&trees)?;
275    tracing::trace!("send trees: {:?}", trees);
276    send_data(ws, packet_body.into_bytes()).await?;
277    let diff_trees = recv_data(ws).await?;
278    tracing::trace!("recv diff trees: {:?}", diff_trees);
279    let diff_trees: Vec<TreeWithState> = serde_json::from_slice(&diff_trees)?;
280    let wanted_trees = diff_trees
281        .iter()
282        .filter(|t| t.state == 1)
283        .map(|t| t.tree.clone())
284        .collect::<Vec<_>>();
285    let will_given_trees = diff_trees
286        .iter()
287        .filter(|t| t.state == 2)
288        .map(|t| t.tree.clone())
289        .collect::<Vec<_>>();
290    Ok((wanted_trees, will_given_trees))
291}
292
293async fn sync_blobs_meta(
294    repo: &Repository,
295    ws: &mut WebSocket,
296    wanted_trees: &[Tree],
297) -> Result<(Vec<Blob>, Vec<Blob>), WsvcServerError> {
298    tracing::debug!("ROUND 3: sync blobs meta...");
299    let mut blobs = Vec::new();
300    for tree in wanted_trees {
301        blobs.extend_from_slice(
302            &repo
303                .get_blobs_of_tree(&tree.hash)
304                .await
305                .map_err(WsvcError::FsError)?,
306        );
307    }
308    let packet_body = serde_json::to_string(&blobs)?;
309    tracing::trace!("send blobs meta: {:?}", blobs);
310    send_data(ws, packet_body.into_bytes()).await?;
311    let diff_blobs = recv_data(ws).await?;
312    tracing::trace!("recv diff blobs meta: {:?}", diff_blobs);
313    let diff_blobs: Vec<BlobWithState> = serde_json::from_slice(&diff_blobs)?;
314    let wanted_blobs = diff_blobs
315        .iter()
316        .filter(|b| b.state == 1)
317        .map(|b| b.blob.clone())
318        .collect::<Vec<_>>();
319    let will_given_blobs = diff_blobs
320        .iter()
321        .filter(|b| b.state == 2)
322        .map(|b| b.blob.clone())
323        .collect::<Vec<_>>();
324    Ok((wanted_blobs, will_given_blobs))
325}
326
327async fn sync_blobs(
328    repo: &Repository,
329    ws: &mut WebSocket,
330    wanted_blobs: &[Blob],
331    will_given_blobs: &[Blob],
332) -> Result<(), WsvcServerError> {
333    tracing::debug!("ROUND 4: sync blobs...");
334    let objects_dir = repo
335        .objects_dir()
336        .await
337        .map_err(|err| WsvcError::from(err))?;
338    let temp_objects_dir = repo
339        .temp_dir()
340        .await
341        .map_err(|err| WsvcError::from(err))?
342        .join("objects");
343    if !temp_objects_dir.exists() {
344        create_dir_all(&temp_objects_dir)
345            .await
346            .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
347    }
348    for i in wanted_blobs {
349        let object_file = objects_dir.join(i.hash.0.to_string());
350        let file = File::open(&object_file)
351            .await
352            .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
353        tracing::trace!("send blob file: {:?}", i);
354        send_file(ws, &i.hash.0.to_string(), file).await?;
355    }
356    for _ in 0..will_given_blobs.len() {
357        recv_file(ws, &temp_objects_dir).await?;
358    }
359    for i in will_given_blobs {
360        let object_file = temp_objects_dir.join(i.hash.0.to_string());
361        if !object_file.exists() {
362            return Err(WsvcServerError::DataError(format!(
363                "blob file not exists: {:?}",
364                object_file
365            )));
366        }
367    }
368    for i in will_given_blobs {
369        rename(
370            temp_objects_dir.join(i.hash.0.to_string()),
371            objects_dir.join(i.hash.0.to_string()),
372        )
373        .await
374        .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
375    }
376    Ok(())
377}
378
379/// `sync_with` syncs repository with client.
380///
381/// - round 1: sync records. server send all records to client, client get records,
382///     and diff its own records with server's records, then send diff records to server.
383///     then server got the `wanted_records` and `will_given_records`
384/// - round 2: sync trees. server send all trees of `wanted_records` recursively to client,
385///     client get trees, and diff its own trees with server's trees, then send diff trees to server.
386/// - round 3: sync blobs list. server send all blobs meta of diff tree to client,
387///     client get blobs meta, and diff its own blobs meta with server's blobs meta, then send diff blobs meta to server.
388/// - round 4: sync blobs. server send all blobs of diff blobs meta to client,
389///     client send all blobs of diff blobs meta to server.
390/// - end process: server store all trees and blobs, then store all records.
391///
392/// when failed, both server and client should cleanup all temp files.
393///
394/// ## arguments
395///
396/// * `repo` - repository to sync with.
397/// * `ws` - websocket connection from axum.
398pub async fn sync_with(repo: &Repository, ws: &mut WebSocket) -> Result<(), WsvcServerError> {
399    let guard = RepoGuard::new(repo)
400        .await
401        .map_err(|err| WsvcError::FsError(err))?;
402    let (wanted_records, given_records) = sync_records(repo, ws).await?;
403    let (wanted_trees, given_trees) = sync_trees(repo, ws, wanted_records.as_slice()).await?;
404    let (wanted_blobs, will_given_blobs) =
405        sync_blobs_meta(repo, ws, wanted_trees.as_slice()).await?;
406    // now all wanted trees and blobs are ready in server's and client's memory, now we should sync blob files.
407    sync_blobs(
408        repo,
409        ws,
410        wanted_blobs.as_slice(),
411        will_given_blobs.as_slice(),
412    )
413    .await?;
414
415    // store trees
416    tracing::debug!("write trees to tree database...");
417    let trees_dir = repo.trees_dir().await.map_err(WsvcError::FsError)?;
418    for tree in &given_trees {
419        write(
420            trees_dir.join(tree.hash.0.to_hex().as_str()),
421            serde_json::to_string(tree)
422                .map_err(|err| WsvcError::FsError(WsvcFsError::SerializationFailed(err)))?,
423        )
424        .await
425        .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
426    }
427
428    // store records
429    tracing::debug!("write records to record database...");
430    let records_dir = repo.records_dir().await.map_err(WsvcError::FsError)?;
431    for record in &given_records {
432        write(
433            records_dir.join(record.hash.0.to_hex().as_str()),
434            serde_json::to_string(record)
435                .map_err(|err| WsvcError::FsError(WsvcFsError::SerializationFailed(err)))?,
436        )
437        .await
438        .map_err(|err| WsvcError::FsError(WsvcFsError::Os(err)))?;
439    }
440
441    drop(guard);
442    Ok(())
443}