Skip to main content

cooklang_sync_client/
remote.rs

1use futures::{Stream, StreamExt};
2use path_slash::PathExt as _;
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5use uuid::Uuid;
6
7use log::trace;
8
9use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
10use reqwest::{Client, StatusCode};
11
12use crate::errors::SyncError;
13
/// User-Agent sent on every request, e.g. "cooklang-sync-client/0.4.11".
/// Lets the server identify the client version in logs when diagnosing
/// misbehaving clients.
const CLIENT_USER_AGENT: &str = concat!("cooklang-sync-client/", env!("CARGO_PKG_VERSION"));
/// Duplicates the version from User-Agent into a dedicated header so log
/// pipelines can extract the client version without parsing User-Agent.
const CLIENT_VERSION_HEADER: &str = "x-client-version";
/// Module-local `Result` alias whose error type defaults to [`SyncError`].
type Result<T, E = SyncError> = std::result::Result<T, E>;

/// Timeout, in seconds, applied to every request sent through the shared
/// HTTP client. The long-poll request asks the server to wait longer than
/// this value, so a client-side timeout on that request is expected.
pub const REQUEST_TIMEOUT_SECS: u64 = 60;
24
/// One entry of the JSON array returned by the `/metadata/list` endpoint.
#[derive(Deserialize, Serialize, Debug)]
pub struct ResponseFileRecord {
    // NOTE(review): presumably the server-side journal id of this record
    // (the value later passed back as `jid`) — confirm against the server.
    pub id: i32,
    /// Path of the file this record describes.
    pub path: String,
    /// Whether the record marks the file as deleted.
    pub deleted: bool,
    // NOTE(review): chunk ids are carried as a single string; the separator
    // and ordering are not visible in this file — confirm against the caller.
    pub chunk_ids: String,
}
32
/// Body of a successful (HTTP 200) response from the `/metadata/commit`
/// endpoint, deserialized from JSON.
#[derive(Debug, Deserialize, Serialize)]
pub enum CommitResultStatus {
    // NOTE(review): presumably the id assigned to the committed record — confirm.
    Success(i32),
    // NOTE(review): presumably the chunk ids the server still needs uploaded,
    // encoded as one string — confirm the format against the server.
    NeedChunks(String),
}
38
/// HTTP client for the sync server API.
pub struct Remote {
    /// Base URL that every endpoint path (e.g. `/chunks/...`) is appended to.
    api_endpoint: String,
    /// Token sent as `Authorization: Bearer <token>` on every request.
    token: String,
    /// Random UUID generated when the `Remote` is created; sent as the `uuid`
    /// query parameter on poll and commit requests.
    uuid: String,
    /// Shared `reqwest` client carrying the default headers and timeout.
    client: Client,
}
45
46impl Remote {
47    pub fn new(api_endpoint: &str, token: &str) -> Remote {
48        let mut default_headers = HeaderMap::new();
49        default_headers.insert(USER_AGENT, HeaderValue::from_static(CLIENT_USER_AGENT));
50        default_headers.insert(
51            CLIENT_VERSION_HEADER,
52            HeaderValue::from_static(env!("CARGO_PKG_VERSION")),
53        );
54
55        let client = Client::builder()
56            .gzip(true)
57            .timeout(std::time::Duration::from_secs(REQUEST_TIMEOUT_SECS))
58            .default_headers(default_headers)
59            .build()
60            .unwrap();
61
62        Self {
63            api_endpoint: api_endpoint.into(),
64            uuid: Uuid::new_v4().into(),
65            token: token.into(),
66            client,
67        }
68    }
69}
70impl Remote {
71    fn auth_headers(&self) -> HeaderMap {
72        let auth_value = format!("Bearer {}", self.token);
73
74        let mut headers = HeaderMap::new();
75        headers.insert(AUTHORIZATION, HeaderValue::from_str(&auth_value).unwrap());
76
77        headers
78    }
79
80    pub async fn upload(&self, chunk: &str, content: Vec<u8>) -> Result<()> {
81        trace!("uploading chunk {:?}", chunk);
82
83        let response = self
84            .client
85            .post(self.api_endpoint.clone() + "/chunks/" + chunk)
86            .headers(self.auth_headers())
87            .body(content)
88            .send()
89            .await?;
90
91        match response.status() {
92            StatusCode::OK => Ok(()),
93            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
94            status => Err(SyncError::Unknown(format!(
95                "Upload chunk failed with status: {}",
96                status
97            ))),
98        }
99    }
100
101    pub async fn upload_batch(&self, chunks: Vec<(String, Vec<u8>)>) -> Result<()> {
102        trace!(
103            "uploading chunks {:?}",
104            chunks.iter().map(|(c, _)| c).collect::<Vec<_>>()
105        );
106
107        // Generate a random boundary string
108        let boundary = format!("------------------------{}", Uuid::new_v4());
109        let mut headers = self.auth_headers();
110        headers.insert(
111            "content-type",
112            HeaderValue::from_str(&format!("multipart/form-data; boundary={}", &boundary)).unwrap(),
113        );
114
115        let final_boundary = format!("--{}--\r\n", &boundary).into_bytes();
116
117        // Create a stream of chunk data
118        let stream = futures::stream::iter(chunks)
119            .map(move |(chunk_id, content)| {
120                let part = format!(
121                    "--{boundary}\r\n\
122                 Content-Disposition: form-data; name=\"{chunk_id}\"\r\n\
123                 Content-Type: application/octet-stream\r\n\r\n",
124                    boundary = &boundary,
125                    chunk_id = chunk_id
126                );
127
128                let end = "\r\n".to_string();
129
130                // Combine part header, content, and end into a single stream
131                futures::stream::iter(vec![
132                    Ok::<_, SyncError>(part.into_bytes()),
133                    Ok::<_, SyncError>(content),
134                    Ok::<_, SyncError>(end.into_bytes()),
135                ])
136            })
137            .flatten();
138
139        // Add final boundary
140
141        let stream = stream.chain(futures::stream::once(async move { Ok(final_boundary) }));
142
143        let response = self
144            .client
145            .post(self.api_endpoint.clone() + "/chunks/upload")
146            .headers(headers)
147            .body(reqwest::Body::wrap_stream(stream))
148            .send()
149            .await?;
150
151        match response.status() {
152            StatusCode::OK => Ok(()),
153            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
154            status => Err(SyncError::Unknown(format!(
155                "Upload batch failed with status: {}",
156                status
157            ))),
158        }
159    }
160
161    pub async fn download(&self, chunk: &str) -> Result<Vec<u8>> {
162        trace!("downloading chunk {:?}", chunk);
163
164        let response = self
165            .client
166            .get(self.api_endpoint.clone() + "/chunks/" + chunk)
167            .headers(self.auth_headers())
168            .send()
169            .await?;
170
171        match response.status() {
172            StatusCode::OK => match response.bytes().await {
173                Ok(bytes) => Ok(bytes.to_vec()),
174                Err(_) => Err(SyncError::BodyExtractError),
175            },
176            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
177            status => Err(SyncError::Unknown(format!(
178                "Download chunk failed with status: {}",
179                status
180            ))),
181        }
182    }
183
184    pub async fn list(&self, local_jid: i32) -> Result<Vec<ResponseFileRecord>> {
185        trace!("list after {:?}", local_jid);
186
187        let jid_string = local_jid.to_string();
188
189        let response = self
190            .client
191            .get(self.api_endpoint.clone() + "/metadata/list?jid=" + &jid_string)
192            .headers(self.auth_headers())
193            .send()
194            .await?;
195
196        match response.status() {
197            StatusCode::OK => {
198                let records = response.json::<Vec<ResponseFileRecord>>().await?;
199
200                Ok(records)
201            }
202            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
203            status => Err(SyncError::Unknown(format!(
204                "List metadata failed with status: {}",
205                status
206            ))),
207        }
208    }
209
210    pub async fn poll(&self) -> Result<()> {
211        trace!("started poll");
212
213        // setting its larger than the request timeout to avoid timeouts from the server
214        let seconds = REQUEST_TIMEOUT_SECS + 10;
215
216        let seconds_string = seconds.to_string();
217
218        let response = self
219            .client
220            .get(
221                self.api_endpoint.clone()
222                    + "/metadata/poll?seconds="
223                    + &seconds_string
224                    + "&uuid="
225                    + &self.uuid,
226            )
227            .headers(self.auth_headers())
228            .send()
229            .await;
230
231        // Handle the response, ignoring timeout errors
232        match response {
233            Ok(response) => match response.status() {
234                StatusCode::OK => Ok(()),
235                StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
236                status => Err(SyncError::Unknown(format!(
237                    "Poll metadata failed with status: {}",
238                    status
239                ))),
240            },
241            Err(e) if e.is_timeout() => Ok(()), // Ignore timeout errors
242            Err(e) => Err(e.into()),
243        }
244    }
245
246    pub async fn commit(
247        &self,
248        path: &str,
249        deleted: bool,
250        chunk_ids: &str,
251    ) -> Result<CommitResultStatus> {
252        trace!("commit {:?}", path);
253
254        let path = Path::new(path);
255
256        let params = [
257            ("deleted", if deleted { "true" } else { "false" }),
258            ("chunk_ids", chunk_ids),
259            ("path", &path.to_slash().unwrap()),
260        ];
261
262        let response = self
263            .client
264            .post(self.api_endpoint.clone() + "/metadata/commit" + "?uuid=" + &self.uuid)
265            .headers(self.auth_headers())
266            .form(&params)
267            .send()
268            .await?;
269
270        match response.status() {
271            StatusCode::OK => {
272                let records = response.json::<CommitResultStatus>().await?;
273
274                Ok(records)
275            }
276            StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized),
277            status => Err(SyncError::Unknown(format!(
278                "Commit metadata failed with status: {}",
279                status
280            ))),
281        }
282    }
283
284    pub async fn download_batch<'a>(
285        &'a self,
286        chunk_ids: Vec<&'a str>,
287    ) -> impl Stream<Item = Result<(String, Vec<u8>)>> + Unpin + 'a {
288        Box::pin(async_stream::try_stream! {
289            trace!("Starting download_batch with chunk_ids: {:?}", chunk_ids);
290
291            let params: Vec<(&str, &str)> = chunk_ids.iter().map(|&id| ("chunk_ids[]", id)).collect();
292
293            let response = self
294                .client
295                .post(self.api_endpoint.clone() + "/chunks/download")
296                .headers(self.auth_headers())
297                .form(&params)
298                .send()
299                .await?;
300            trace!("Received response with status: {:?}", response.status());
301
302            match response.status() {
303                StatusCode::OK => {
304                    let content_type = response
305                        .headers()
306                        .get("content-type")
307                        .and_then(|v| v.to_str().ok())
308                        .ok_or(SyncError::BatchDownloadError(
309                            "No content-type header".to_string(),
310                        ))?
311                        .to_string();
312
313                    let boundary = content_type
314                        .split("boundary=")
315                        .nth(1)
316                        .ok_or(SyncError::BatchDownloadError(
317                            "No boundary in content-type header".to_string(),
318                        ))?;
319
320                    let boundary_bytes = format!("--{}", boundary).into_bytes();
321
322                    let mut stream = response.bytes_stream();
323                    let mut buffer = Vec::new();
324
325                    while let Some(chunk) = stream.next().await {
326                        let chunk = chunk?;
327                        buffer.extend_from_slice(&chunk);
328
329                        // Process complete parts from buffer
330                        while let Some((part, remaining)) = extract_next_part(&buffer, &boundary_bytes)? {
331                            if let Some((chunk_id, content)) = process_part(&part)? {
332                                yield (chunk_id, content);
333                            }
334                            buffer = remaining;
335                        }
336                    }
337                }
338                StatusCode::UNAUTHORIZED => Err(SyncError::Unauthorized)?,
339                status => Err(SyncError::Unknown(format!("Download batch failed with status: {}", status)))?,
340            }
341        })
342    }
343}
344
345// Helper function to extract the next complete part from the buffer
346fn extract_next_part(buffer: &[u8], boundary: &[u8]) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
347    if let Some(start) = find_boundary(buffer, boundary) {
348        if let Some(next_boundary) = find_boundary(&buffer[start + boundary.len()..], boundary) {
349            let part =
350                buffer[start + boundary.len()..start + boundary.len() + next_boundary].to_vec();
351            let remaining = buffer[start + boundary.len() + next_boundary..].to_vec();
352            Ok(Some((part, remaining)))
353        } else {
354            Ok(None) // Need more data
355        }
356    } else {
357        Ok(None) // Need more data
358    }
359}
360
361// Helper function to process a single part
362fn process_part(part: &[u8]) -> Result<Option<(String, Vec<u8>)>> {
363    if let Some(headers_end) = find_double_crlf(part) {
364        let headers = std::str::from_utf8(&part[..headers_end])
365            .map_err(|_| SyncError::BatchDownloadError("Invalid headers".to_string()))?;
366
367        let chunk_id = headers
368            .lines()
369            .find(|line| line.starts_with("X-Chunk-ID:"))
370            .and_then(|line| line.split(": ").nth(1))
371            .ok_or(SyncError::BatchDownloadError(
372                "No chunk ID found".to_string(),
373            ))?
374            .trim()
375            .to_string();
376
377        // remove last 2 bytes as they are the boundary
378        let content = part[headers_end + 4..part.len() - 2].to_vec();
379        Ok(Some((chunk_id, content)))
380    } else {
381        Ok(None)
382    }
383}
384
/// Returns the index of the first occurrence of `boundary` in `data`.
///
/// Returns `None` when `boundary` does not occur, when `data` is shorter
/// than `boundary`, or when `boundary` is empty (an empty delimiter cannot
/// separate anything, and `windows(0)` would panic).
fn find_boundary(data: &[u8], boundary: &[u8]) -> Option<usize> {
    if boundary.is_empty() {
        return None;
    }

    data.windows(boundary.len())
        .position(|window| window == boundary)
}
390
/// Returns the index of the first `\r\n\r\n` sequence in `data`, if any.
fn find_double_crlf(data: &[u8]) -> Option<usize> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    data.windows(HEADER_TERMINATOR.len())
        .position(|candidate| candidate == HEADER_TERMINATOR)
}