cooklang_sync_client/
remote.rs

1use path_slash::PathExt as _;
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4use uuid::Uuid;
5
6use log::trace;
7
8use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
9use reqwest::StatusCode;
10use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
11
12use futures::{Stream, StreamExt};
13
14use crate::errors::SyncError;
15type Result<T, E = SyncError> = std::result::Result<T, E>;
16
17pub const REQUEST_TIMEOUT_SECS: u64 = 60;
18
19#[derive(Deserialize, Serialize, Debug)]
20pub struct ResponseFileRecord {
21    pub id: i32,
22    pub path: String,
23    pub deleted: bool,
24    pub chunk_ids: String,
25}
26
27#[derive(Debug, Deserialize, Serialize)]
28pub enum CommitResultStatus {
29    Success(i32),
30    NeedChunks(String),
31}
32
33pub struct Remote {
34    api_endpoint: String,
35    token: String,
36    uuid: String,
37    client: ClientWithMiddleware,
38}
39
40impl Remote {
41    pub fn new(api_endpoint: &str, token: &str) -> Remote {
42        let rc = reqwest::ClientBuilder::new()
43            .gzip(true)
44            .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
45            .build()
46            .unwrap();
47        let client = ClientBuilder::new(rc)
48            // .with(OriginalHeadersMiddleware)
49            .build();
50
51        Self {
52            api_endpoint: api_endpoint.into(),
53            uuid: Uuid::new_v4().into(),
54            token: token.into(),
55            client,
56        }
57    }
58}
59impl Remote {
60    fn auth_headers(&self) -> HeaderMap {
61        let auth_value = format!("Bearer {}", self.token);
62
63        let mut headers = HeaderMap::new();
64        headers.insert(AUTHORIZATION, HeaderValue::from_str(&auth_value).unwrap());
65
66        headers
67    }
68
69    pub async fn upload(&self, chunk: &str, content: Vec<u8>) -> Result<()> {
70        trace!("uploading chunk {:?}", chunk);
71
72        let response = self
73            .client
74            .post(self.api_endpoint.clone() + "/chunks/" + chunk)
75            .headers(self.auth_headers())
76            .body(content)
77            .send()
78            .await?;
79
80        match response.status() {
81            StatusCode::OK => Ok(()),
82            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
83            status => Err(SyncError::Unknown(format!(
84                "Upload chunk failed with status: {}",
85                status
86            ))),
87        }
88    }
89
90    pub async fn upload_batch(&self, chunks: Vec<(String, Vec<u8>)>) -> Result<()> {
91        trace!(
92            "uploading chunks {:?}",
93            chunks.iter().map(|(c, _)| c).collect::<Vec<_>>()
94        );
95
96        // Generate a random boundary string
97        let boundary = format!("------------------------{}", Uuid::new_v4());
98        let mut headers = self.auth_headers();
99        headers.insert(
100            "content-type",
101            HeaderValue::from_str(&format!("multipart/form-data; boundary={}", &boundary)).unwrap(),
102        );
103
104        let final_boundary = format!("--{}--\r\n", &boundary).into_bytes();
105
106        // Create a stream of chunk data
107        let stream = futures::stream::iter(chunks)
108            .map(move |(chunk_id, content)| {
109                let part = format!(
110                    "--{boundary}\r\n\
111                 Content-Disposition: form-data; name=\"{chunk_id}\"\r\n\
112                 Content-Type: application/octet-stream\r\n\r\n",
113                    boundary = &boundary,
114                    chunk_id = chunk_id
115                );
116
117                let end = "\r\n".to_string();
118
119                // Combine part header, content, and end into a single stream
120                futures::stream::iter(vec![
121                    Ok::<_, SyncError>(part.into_bytes()),
122                    Ok::<_, SyncError>(content),
123                    Ok::<_, SyncError>(end.into_bytes()),
124                ])
125            })
126            .flatten();
127
128        // Add final boundary
129
130        let stream = stream.chain(futures::stream::once(async move { Ok(final_boundary) }));
131
132        let response = self
133            .client
134            .post(self.api_endpoint.clone() + "/chunks/upload")
135            .headers(headers)
136            .body(reqwest::Body::wrap_stream(stream))
137            .send()
138            .await?;
139
140        match response.status() {
141            StatusCode::OK => Ok(()),
142            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
143            status => Err(SyncError::Unknown(format!(
144                "Upload batch failed with status: {}",
145                status
146            ))),
147        }
148    }
149
150    pub async fn download(&self, chunk: &str) -> Result<Vec<u8>> {
151        trace!("downloading chunk {:?}", chunk);
152
153        let response = self
154            .client
155            .get(self.api_endpoint.clone() + "/chunks/" + chunk)
156            .headers(self.auth_headers())
157            .send()
158            .await?;
159
160        match response.status() {
161            StatusCode::OK => match response.bytes().await {
162                Ok(bytes) => Ok(bytes.to_vec()),
163                Err(_) => Err(SyncError::BodyExtractError),
164            },
165            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
166            status => Err(SyncError::Unknown(format!(
167                "Download chunk failed with status: {}",
168                status
169            ))),
170        }
171    }
172
173    pub async fn list(&self, local_jid: i32) -> Result<Vec<ResponseFileRecord>> {
174        trace!("list after {:?}", local_jid);
175
176        let jid_string = local_jid.to_string();
177
178        let response = self
179            .client
180            .get(self.api_endpoint.clone() + "/metadata/list?jid=" + &jid_string)
181            .headers(self.auth_headers())
182            .send()
183            .await?;
184
185        match response.status() {
186            StatusCode::OK => {
187                let records = response.json::<Vec<ResponseFileRecord>>().await?;
188
189                Ok(records)
190            }
191            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
192            status => Err(SyncError::Unknown(format!(
193                "List metadata failed with status: {}",
194                status
195            ))),
196        }
197    }
198
199    pub async fn poll(&self) -> Result<()> {
200        trace!("started poll");
201
202        // setting its larger than the request timeout to avoid timeouts from the server
203        let seconds = REQUEST_TIMEOUT_SECS + 10;
204
205        let seconds_string = seconds.to_string();
206
207        let response = self
208            .client
209            .get(
210                self.api_endpoint.clone()
211                    + "/metadata/poll?seconds="
212                    + &seconds_string
213                    + "&uuid="
214                    + &self.uuid,
215            )
216            .headers(self.auth_headers())
217            .send()
218            .await;
219
220        // Handle the response, ignoring timeout errors
221        match response {
222            Ok(response) => match response.status() {
223                StatusCode::OK => Ok(()),
224                StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
225                status => Err(SyncError::Unknown(format!(
226                    "Poll metadata failed with status: {}",
227                    status
228                ))),
229            },
230            Err(e) if e.is_timeout() => Ok(()), // Ignore timeout errors
231            Err(e) => Err(e.into()),
232        }
233    }
234
235    pub async fn commit(
236        &self,
237        path: &str,
238        deleted: bool,
239        chunk_ids: &str,
240    ) -> Result<CommitResultStatus> {
241        trace!("commit {:?}", path);
242
243        let path = Path::new(path);
244
245        let params = [
246            ("deleted", if deleted { "true" } else { "false" }),
247            ("chunk_ids", chunk_ids),
248            ("path", &path.to_slash().unwrap()),
249        ];
250
251        let response = self
252            .client
253            .post(self.api_endpoint.clone() + "/metadata/commit" + "?uuid=" + &self.uuid)
254            .headers(self.auth_headers())
255            .form(&params)
256            .send()
257            .await?;
258
259        match response.status() {
260            StatusCode::OK => {
261                let records = response.json::<CommitResultStatus>().await?;
262
263                Ok(records)
264            }
265            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
266            status => Err(SyncError::Unknown(format!(
267                "Commit metadata failed with status: {}",
268                status
269            ))),
270        }
271    }
272
273    pub async fn download_batch<'a>(
274        &'a self,
275        chunk_ids: Vec<&'a str>,
276    ) -> impl Stream<Item = Result<(String, Vec<u8>)>> + Unpin + 'a {
277        Box::pin(async_stream::try_stream! {
278            trace!("Starting download_batch with chunk_ids: {:?}", chunk_ids);
279
280            let params: Vec<(&str, &str)> = chunk_ids.iter().map(|&id| ("chunk_ids[]", id)).collect();
281
282            let response = self
283                .client
284                .post(self.api_endpoint.clone() + "/chunks/download")
285                .headers(self.auth_headers())
286                .form(&params)
287                .send()
288                .await?;
289            trace!("Received response with status: {:?}", response.status());
290
291            match response.status() {
292                StatusCode::OK => {
293                    let content_type = response
294                        .headers()
295                        .get("content-type")
296                        .and_then(|v| v.to_str().ok())
297                        .ok_or(SyncError::BatchDownloadError(
298                            "No content-type header".to_string(),
299                        ))?
300                        .to_string();
301
302                    let boundary = content_type
303                        .split("boundary=")
304                        .nth(1)
305                        .ok_or(SyncError::BatchDownloadError(
306                            "No boundary in content-type header".to_string(),
307                        ))?;
308
309                    let boundary_bytes = format!("--{}", boundary).into_bytes();
310
311                    let mut stream = response.bytes_stream();
312                    let mut buffer = Vec::new();
313
314                    while let Some(chunk) = stream.next().await {
315                        let chunk = chunk?;
316                        buffer.extend_from_slice(&chunk);
317
318                        // Process complete parts from buffer
319                        while let Some((part, remaining)) = extract_next_part(&buffer, &boundary_bytes)? {
320                            if let Some((chunk_id, content)) = process_part(&part)? {
321                                yield (chunk_id, content);
322                            }
323                            buffer = remaining;
324                        }
325                    }
326                }
327                StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized)?,
328                status => Err(SyncError::Unknown(format!("Download batch failed with status: {}", status)))?,
329            }
330        })
331    }
332}
333
334// Helper function to extract the next complete part from the buffer
335fn extract_next_part(buffer: &[u8], boundary: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
336    if let Some(start) = find_boundary(buffer, boundary) {
337        if let Some(next_boundary) = find_boundary(&buffer[start + boundary.len()..], boundary) {
338            let part =
339                buffer[start + boundary.len()..start + boundary.len() + next_boundary].to_vec();
340            let remaining = buffer[start + boundary.len() + next_boundary..].to_vec();
341            Ok(Some((part, remaining)))
342        } else {
343            Ok(None) // Need more data
344        }
345    } else {
346        Ok(None) // Need more data
347    }
348}
349
350// Helper function to process a single part
351fn process_part(part: &[u8]) -> Result<Option<(String, Vec<u8>)>> {
352    if let Some(headers_end) = find_double_crlf(part) {
353        let headers = std::str::from_utf8(&part[..headers_end])
354            .map_err(|_| SyncError::BatchDownloadError("Invalid headers".to_string()))?;
355
356        let chunk_id = headers
357            .lines()
358            .find(|line| line.starts_with("X-Chunk-ID:"))
359            .and_then(|line| line.split(": ").nth(1))
360            .ok_or(SyncError::BatchDownloadError(
361                "No chunk ID found".to_string(),
362            ))?
363            .trim()
364            .to_string();
365
366        // remove last 2 bytes as they are the boundary
367        let content = part[headers_end + 4..part.len() - 2].to_vec();
368        Ok(Some((chunk_id, content)))
369    } else {
370        Ok(None)
371    }
372}
373
/// Returns the byte offset of the first occurrence of `boundary` in `data`, or `None` if it
/// does not occur (including when `data` is shorter than `boundary`).
///
/// # Panics
///
/// Panics if `boundary` is empty, because slice windows must be non-empty. The caller in this
/// module always passes a `--`-prefixed delimiter.
fn find_boundary(data: &[u8], boundary: &[u8]) -> Option<usize> {
    data.windows(boundary.len())
        .enumerate()
        .find(|(_, candidate)| *candidate == boundary)
        .map(|(offset, _)| offset)
}
379
/// Returns the offset of the first `\r\n\r\n` (the blank line that ends a header block) in
/// `data`, or `None` if there is none.
fn find_double_crlf(data: &[u8]) -> Option<usize> {
    const BLANK_LINE: &[u8] = b"\r\n\r\n";
    data.windows(BLANK_LINE.len())
        .enumerate()
        .find_map(|(offset, window)| if window == BLANK_LINE { Some(offset) } else { None })
}