Skip to main content

cooklang_sync_client/
remote.rs

1use path_slash::PathExt as _;
2use serde::{Deserialize, Serialize};
3use std::path::Path;
4use uuid::Uuid;
5
6use log::trace;
7
8use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
9use reqwest::{Client, StatusCode};
10
11use futures::{Stream, StreamExt};
12
13use crate::errors::SyncError;
14type Result<T, E = SyncError> = std::result::Result<T, E>;
15
16pub const REQUEST_TIMEOUT_SECS: u64 = 60;
17
18#[derive(Deserialize, Serialize, Debug)]
19pub struct ResponseFileRecord {
20    pub id: i32,
21    pub path: String,
22    pub deleted: bool,
23    pub chunk_ids: String,
24}
25
26#[derive(Debug, Deserialize, Serialize)]
27pub enum CommitResultStatus {
28    Success(i32),
29    NeedChunks(String),
30}
31
32pub struct Remote {
33    api_endpoint: String,
34    token: String,
35    uuid: String,
36    client: Client,
37}
38
39impl Remote {
40    pub fn new(api_endpoint: &str, token: &str) -> Remote {
41        let client = Client::builder()
42            .gzip(true)
43            .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
44            .build()
45            .unwrap();
46
47        Self {
48            api_endpoint: api_endpoint.into(),
49            uuid: Uuid::new_v4().into(),
50            token: token.into(),
51            client,
52        }
53    }
54}
55impl Remote {
56    fn auth_headers(&self) -> HeaderMap {
57        let auth_value = format!("Bearer {}", self.token);
58
59        let mut headers = HeaderMap::new();
60        headers.insert(AUTHORIZATION, HeaderValue::from_str(&auth_value).unwrap());
61
62        headers
63    }
64
65    pub async fn upload(&self, chunk: &str, content: Vec<u8>) -> Result<()> {
66        trace!("uploading chunk {:?}", chunk);
67
68        let response = self
69            .client
70            .post(self.api_endpoint.clone() + "/chunks/" + chunk)
71            .headers(self.auth_headers())
72            .body(content)
73            .send()
74            .await?;
75
76        match response.status() {
77            StatusCode::OK => Ok(()),
78            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
79            status => Err(SyncError::Unknown(format!(
80                "Upload chunk failed with status: {}",
81                status
82            ))),
83        }
84    }
85
86    pub async fn upload_batch(&self, chunks: Vec<(String, Vec<u8>)>) -> Result<()> {
87        trace!(
88            "uploading chunks {:?}",
89            chunks.iter().map(|(c, _)| c).collect::<Vec<_>>()
90        );
91
92        // Generate a random boundary string
93        let boundary = format!("------------------------{}", Uuid::new_v4());
94        let mut headers = self.auth_headers();
95        headers.insert(
96            "content-type",
97            HeaderValue::from_str(&format!("multipart/form-data; boundary={}", &boundary)).unwrap(),
98        );
99
100        let final_boundary = format!("--{}--\r\n", &boundary).into_bytes();
101
102        // Create a stream of chunk data
103        let stream = futures::stream::iter(chunks)
104            .map(move |(chunk_id, content)| {
105                let part = format!(
106                    "--{boundary}\r\n\
107                 Content-Disposition: form-data; name=\"{chunk_id}\"\r\n\
108                 Content-Type: application/octet-stream\r\n\r\n",
109                    boundary = &boundary,
110                    chunk_id = chunk_id
111                );
112
113                let end = "\r\n".to_string();
114
115                // Combine part header, content, and end into a single stream
116                futures::stream::iter(vec![
117                    Ok::<_, SyncError>(part.into_bytes()),
118                    Ok::<_, SyncError>(content),
119                    Ok::<_, SyncError>(end.into_bytes()),
120                ])
121            })
122            .flatten();
123
124        // Add final boundary
125
126        let stream = stream.chain(futures::stream::once(async move { Ok(final_boundary) }));
127
128        let response = self
129            .client
130            .post(self.api_endpoint.clone() + "/chunks/upload")
131            .headers(headers)
132            .body(reqwest::Body::wrap_stream(stream))
133            .send()
134            .await?;
135
136        match response.status() {
137            StatusCode::OK => Ok(()),
138            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
139            status => Err(SyncError::Unknown(format!(
140                "Upload batch failed with status: {}",
141                status
142            ))),
143        }
144    }
145
146    pub async fn download(&self, chunk: &str) -> Result<Vec<u8>> {
147        trace!("downloading chunk {:?}", chunk);
148
149        let response = self
150            .client
151            .get(self.api_endpoint.clone() + "/chunks/" + chunk)
152            .headers(self.auth_headers())
153            .send()
154            .await?;
155
156        match response.status() {
157            StatusCode::OK => match response.bytes().await {
158                Ok(bytes) => Ok(bytes.to_vec()),
159                Err(_) => Err(SyncError::BodyExtractError),
160            },
161            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
162            status => Err(SyncError::Unknown(format!(
163                "Download chunk failed with status: {}",
164                status
165            ))),
166        }
167    }
168
169    pub async fn list(&self, local_jid: i32) -> Result<Vec<ResponseFileRecord>> {
170        trace!("list after {:?}", local_jid);
171
172        let jid_string = local_jid.to_string();
173
174        let response = self
175            .client
176            .get(self.api_endpoint.clone() + "/metadata/list?jid=" + &jid_string)
177            .headers(self.auth_headers())
178            .send()
179            .await?;
180
181        match response.status() {
182            StatusCode::OK => {
183                let records = response.json::<Vec<ResponseFileRecord>>().await?;
184
185                Ok(records)
186            }
187            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
188            status => Err(SyncError::Unknown(format!(
189                "List metadata failed with status: {}",
190                status
191            ))),
192        }
193    }
194
195    pub async fn poll(&self) -> Result<()> {
196        trace!("started poll");
197
198        // setting its larger than the request timeout to avoid timeouts from the server
199        let seconds = REQUEST_TIMEOUT_SECS + 10;
200
201        let seconds_string = seconds.to_string();
202
203        let response = self
204            .client
205            .get(
206                self.api_endpoint.clone()
207                    + "/metadata/poll?seconds="
208                    + &seconds_string
209                    + "&uuid="
210                    + &self.uuid,
211            )
212            .headers(self.auth_headers())
213            .send()
214            .await;
215
216        // Handle the response, ignoring timeout errors
217        match response {
218            Ok(response) => match response.status() {
219                StatusCode::OK => Ok(()),
220                StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
221                status => Err(SyncError::Unknown(format!(
222                    "Poll metadata failed with status: {}",
223                    status
224                ))),
225            },
226            Err(e) if e.is_timeout() => Ok(()), // Ignore timeout errors
227            Err(e) => Err(e.into()),
228        }
229    }
230
231    pub async fn commit(
232        &self,
233        path: &str,
234        deleted: bool,
235        chunk_ids: &str,
236    ) -> Result<CommitResultStatus> {
237        trace!("commit {:?}", path);
238
239        let path = Path::new(path);
240
241        let params = [
242            ("deleted", if deleted { "true" } else { "false" }),
243            ("chunk_ids", chunk_ids),
244            ("path", &path.to_slash().unwrap()),
245        ];
246
247        let response = self
248            .client
249            .post(self.api_endpoint.clone() + "/metadata/commit" + "?uuid=" + &self.uuid)
250            .headers(self.auth_headers())
251            .form(&params)
252            .send()
253            .await?;
254
255        match response.status() {
256            StatusCode::OK => {
257                let records = response.json::<CommitResultStatus>().await?;
258
259                Ok(records)
260            }
261            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
262            status => Err(SyncError::Unknown(format!(
263                "Commit metadata failed with status: {}",
264                status
265            ))),
266        }
267    }
268
269    pub async fn download_batch<'a>(
270        &'a self,
271        chunk_ids: Vec<&'a str>,
272    ) -> impl Stream<Item = Result<(String, Vec<u8>)>> + Unpin + 'a {
273        Box::pin(async_stream::try_stream! {
274            trace!("Starting download_batch with chunk_ids: {:?}", chunk_ids);
275
276            let params: Vec<(&str, &str)> = chunk_ids.iter().map(|&id| ("chunk_ids[]", id)).collect();
277
278            let response = self
279                .client
280                .post(self.api_endpoint.clone() + "/chunks/download")
281                .headers(self.auth_headers())
282                .form(&params)
283                .send()
284                .await?;
285            trace!("Received response with status: {:?}", response.status());
286
287            match response.status() {
288                StatusCode::OK => {
289                    let content_type = response
290                        .headers()
291                        .get("content-type")
292                        .and_then(|v| v.to_str().ok())
293                        .ok_or(SyncError::BatchDownloadError(
294                            "No content-type header".to_string(),
295                        ))?
296                        .to_string();
297
298                    let boundary = content_type
299                        .split("boundary=")
300                        .nth(1)
301                        .ok_or(SyncError::BatchDownloadError(
302                            "No boundary in content-type header".to_string(),
303                        ))?;
304
305                    let boundary_bytes = format!("--{}", boundary).into_bytes();
306
307                    let mut stream = response.bytes_stream();
308                    let mut buffer = Vec::new();
309
310                    while let Some(chunk) = stream.next().await {
311                        let chunk = chunk?;
312                        buffer.extend_from_slice(&chunk);
313
314                        // Process complete parts from buffer
315                        while let Some((part, remaining)) = extract_next_part(&buffer, &boundary_bytes)? {
316                            if let Some((chunk_id, content)) = process_part(&part)? {
317                                yield (chunk_id, content);
318                            }
319                            buffer = remaining;
320                        }
321                    }
322                }
323                StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized)?,
324                status => Err(SyncError::Unknown(format!("Download batch failed with status: {}", status)))?,
325            }
326        })
327    }
328}
329
330// Helper function to extract the next complete part from the buffer
331fn extract_next_part(buffer: &[u8], boundary: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
332    if let Some(start) = find_boundary(buffer, boundary) {
333        if let Some(next_boundary) = find_boundary(&buffer[start + boundary.len()..], boundary) {
334            let part =
335                buffer[start + boundary.len()..start + boundary.len() + next_boundary].to_vec();
336            let remaining = buffer[start + boundary.len() + next_boundary..].to_vec();
337            Ok(Some((part, remaining)))
338        } else {
339            Ok(None) // Need more data
340        }
341    } else {
342        Ok(None) // Need more data
343    }
344}
345
346// Helper function to process a single part
347fn process_part(part: &[u8]) -> Result<Option<(String, Vec<u8>)>> {
348    if let Some(headers_end) = find_double_crlf(part) {
349        let headers = std::str::from_utf8(&part[..headers_end])
350            .map_err(|_| SyncError::BatchDownloadError("Invalid headers".to_string()))?;
351
352        let chunk_id = headers
353            .lines()
354            .find(|line| line.starts_with("X-Chunk-ID:"))
355            .and_then(|line| line.split(": ").nth(1))
356            .ok_or(SyncError::BatchDownloadError(
357                "No chunk ID found".to_string(),
358            ))?
359            .trim()
360            .to_string();
361
362        // remove last 2 bytes as they are the boundary
363        let content = part[headers_end + 4..part.len() - 2].to_vec();
364        Ok(Some((chunk_id, content)))
365    } else {
366        Ok(None)
367    }
368}
369
/// Returns the offset of the first occurrence of `boundary` in `data`, or `None`
/// if it does not occur.
///
/// An empty `boundary` never matches. `slice::windows(0)` panics, and a
/// zero-length delimiter would "match" at every offset, letting the
/// part-splitting loop spin without consuming input.
fn find_boundary(data: &[u8], boundary: &[u8]) -> Option<usize> {
    if boundary.is_empty() {
        return None;
    }
    data.windows(boundary.len())
        .position(|window| window == boundary)
}
375
/// Returns the offset of the first blank line (`\r\n\r\n`) in `data`, i.e. the
/// end of a header section, or `None` if there is none.
fn find_double_crlf(data: &[u8]) -> Option<usize> {
    const BLANK_LINE: &[u8] = b"\r\n\r\n";
    data.windows(BLANK_LINE.len())
        .enumerate()
        .find(|(_, window)| *window == BLANK_LINE)
        .map(|(offset, _)| offset)
}