use bytes::{Bytes, BytesMut};
use candid::{CandidType, Principal};
use ic_agent::Agent;
use ic_oss_types::{bucket::*, file::*, folder::*, format_error};
use serde::{Deserialize, Serialize};
use serde_bytes::{ByteArray, ByteBuf};
use sha3::{Digest, Sha3_256};
use std::{collections::BTreeSet, sync::Arc};
use tokio::io::AsyncRead;
use tokio::sync::{mpsc, RwLock, Semaphore};
use tokio_stream::StreamExt;
use tokio_util::codec::{Decoder, FramedRead};

use crate::agent::{query_call, update_call};

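/// A client for a single ic-oss bucket canister.
///
/// It wraps an `ic_agent::Agent` together with the bucket canister id, an
/// optional access token that is attached to every call, and the number of
/// concurrent chunk uploads to use.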
#[derive(Clone)]
pub struct Client {
    concurrency: u8,
    agent: Arc<Agent>,
    bucket: Principal,
    set_readonly: bool,
    access_token: Option<ByteBuf>,
}

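/// The outcome of a chunked upload: the file id, how many bytes were
/// accepted, which chunk indexes made it, and the first error (if any).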
#[derive(CandidType, Clone, Debug, Default, Deserialize, Serialize)]
pub struct UploadFileChunksResult {
    pub id: u32,
    pub filled: u64,
    pub uploaded_chunks: BTreeSet<u32>,
    pub error: Option<String>,
}

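/// A progress snapshot passed to the `on_progress` callback during uploads.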
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Progress {
    pub filled: u64,
    pub size: Option<u64>,
    pub chunk_index: u32,
    pub concurrency: u8,
}

impl Client {
    pub fn new(agent: Arc<Agent>, bucket: Principal) -> Client {
        Client {
            concurrency: 16,
            agent,
            bucket,
            set_readonly: false,
            access_token: None,
        }
    }

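    /// Sets the number of concurrent chunk uploads; values outside 1..=64 are
    /// ignored.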
    pub fn set_concurrency(&mut self, concurrency: u8) {
        if concurrency > 0 && concurrency <= 64 {
            self.concurrency = concurrency;
        }
    }

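    /// When enabled, uploaded files are marked readonly (status 1) once
    /// written.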
    pub fn set_readonly(&mut self, readonly: bool) {
        self.set_readonly = readonly;
    }

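    // Admin endpoints. Each forwards to the canister method of the same name;
    // the trailing `?` unwraps the transport-level result, leaving the
    // canister's own `Result` as the return value.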
    pub async fn admin_set_managers(&self, args: BTreeSet<Principal>) -> Result<(), String> {
        update_call(&self.agent, &self.bucket, "admin_set_managers", (args,)).await?
    }

    pub async fn admin_set_auditors(&self, args: BTreeSet<Principal>) -> Result<(), String> {
        update_call(&self.agent, &self.bucket, "admin_set_auditors", (args,)).await?
    }

    pub async fn admin_update_bucket(&self, args: UpdateBucketInput) -> Result<(), String> {
        update_call(&self.agent, &self.bucket, "admin_update_bucket", (args,)).await?
    }

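    // Read-only endpoints. Each wraps a `query_call` to the canister method
    // of the same name, passing the optional access token along.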
    pub async fn get_bucket_info(&self) -> Result<BucketInfo, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_bucket_info",
            (&self.access_token,),
        )
        .await?
    }

    pub async fn get_file_info(&self, id: u32) -> Result<FileInfo, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_file_info",
            (id, &self.access_token),
        )
        .await?
    }

    pub async fn get_file_info_by_hash(&self, hash: ByteArray<32>) -> Result<FileInfo, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_file_info_by_hash",
            (hash, &self.access_token),
        )
        .await?
    }

    pub async fn get_file_ancestors(&self, id: u32) -> Result<Vec<FolderName>, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_file_ancestors",
            (id, &self.access_token),
        )
        .await?
    }

    pub async fn get_file_chunks(
        &self,
        id: u32,
        index: u32,
        take: Option<u32>,
    ) -> Result<Vec<FileChunk>, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_file_chunks",
            (id, index, take, &self.access_token),
        )
        .await?
    }

    pub async fn list_files(
        &self,
        parent: u32,
        prev: Option<u32>,
        take: Option<u32>,
    ) -> Result<Vec<FileInfo>, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "list_files",
            (parent, prev, take, &self.access_token),
        )
        .await?
    }

    pub async fn get_folder_info(&self, id: u32) -> Result<FolderInfo, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_folder_info",
            (id, &self.access_token),
        )
        .await?
    }

    pub async fn get_folder_ancestors(&self, id: u32) -> Result<Vec<FolderName>, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "get_folder_ancestors",
            (id, &self.access_token),
        )
        .await?
    }

    pub async fn list_folders(
        &self,
        parent: u32,
        prev: Option<u32>,
        take: Option<u32>,
    ) -> Result<Vec<FolderInfo>, String> {
        query_call(
            &self.agent,
            &self.bucket,
            "list_folders",
            (parent, prev, take, &self.access_token),
        )
        .await?
    }

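    // Mutating endpoints. Each wraps an `update_call` to the canister method
    // of the same name, again passing the optional access token.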
    pub async fn create_file(&self, file: CreateFileInput) -> Result<CreateFileOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "create_file",
            (file, &self.access_token),
        )
        .await?
    }

    pub async fn update_file_chunk(
        &self,
        input: UpdateFileChunkInput,
    ) -> Result<UpdateFileChunkOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "update_file_chunk",
            (input, &self.access_token),
        )
        .await?
    }

    pub async fn update_file_info(
        &self,
        input: UpdateFileInput,
    ) -> Result<UpdateFileOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "update_file_info",
            (input, &self.access_token),
        )
        .await?
    }

    pub async fn move_file(&self, input: MoveInput) -> Result<UpdateFileOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "move_file",
            (input, &self.access_token),
        )
        .await?
    }

    pub async fn delete_file(&self, id: u32) -> Result<bool, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "delete_file",
            (id, &self.access_token),
        )
        .await?
    }

    pub async fn batch_delete_subfiles(
        &self,
        parent: u32,
        ids: BTreeSet<u32>,
    ) -> Result<Vec<u32>, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "batch_delete_subfiles",
            (parent, ids, &self.access_token),
        )
        .await?
    }

    pub async fn create_folder(
        &self,
        input: CreateFolderInput,
    ) -> Result<CreateFolderOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "create_folder",
            (input, &self.access_token),
        )
        .await?
    }

    pub async fn update_folder_info(
        &self,
        input: UpdateFolderInput,
    ) -> Result<UpdateFolderOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "update_folder_info",
            (input, &self.access_token),
        )
        .await?
    }

    pub async fn move_folder(&self, input: MoveInput) -> Result<UpdateFolderOutput, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "move_folder",
            (input, &self.access_token),
        )
        .await?
    }

    pub async fn delete_folder(&self, id: u32) -> Result<bool, String> {
        update_call(
            &self.agent,
            &self.bucket,
            "delete_folder",
            (id, &self.access_token),
        )
        .await?
    }

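    /// Uploads a file from an async byte stream.
    ///
    /// Files no larger than `MAX_FILE_SIZE_PER_CALL` are written in a single
    /// `create_file` call; anything larger is created first and then streamed
    /// with `upload_chunks`. If no hash is supplied, the SHA3-256 hash is
    /// computed on the fly. A minimal usage sketch (assumes an initialized
    /// `Agent` and bucket id; the `CreateFileInput` construction shown is
    /// illustrative):
    ///
    /// ```ignore
    /// let client = Client::new(agent, bucket_canister_id);
    /// let file = tokio::fs::File::open("photo.jpg").await?;
    /// let size = file.metadata().await?.len();
    /// // Hypothetical input; consult `CreateFileInput` for the full field list.
    /// let input = CreateFileInput {
    ///     size: Some(size),
    ///     ..Default::default()
    /// };
    /// let res = client
    ///     .upload(file, input, |p| println!("uploaded {} bytes", p.filled))
    ///     .await?;
    /// println!("file id: {}", res.id);
    /// ```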
    pub async fn upload<T, F>(
        &self,
        stream: T,
        mut file: CreateFileInput,
        on_progress: F,
    ) -> Result<UploadFileChunksResult, String>
    where
        T: AsyncRead,
        F: Fn(Progress),
    {
        if let Some(size) = file.size {
            if size <= MAX_FILE_SIZE_PER_CALL {
                // Small file: read everything and create it in a single call.
                let content = try_read_all(stream, size as u32).await?;
                if file.hash.is_none() {
                    let mut hasher = Sha3_256::new();
                    hasher.update(&content);
                    let hash: [u8; 32] = hasher.finalize().into();
                    file.hash = Some(hash.into());
                }
                file.content = Some(ByteBuf::from(content.to_vec()));
                file.status = if self.set_readonly { Some(1) } else { None };
                let res = self.create_file(file).await?;

                on_progress(Progress {
                    filled: size,
                    size: Some(size),
                    chunk_index: 0,
                    concurrency: 1,
                });
                return Ok(UploadFileChunksResult {
                    id: res.id,
                    filled: size,
                    uploaded_chunks: BTreeSet::new(),
                    error: None,
                });
            }
        }

        // Large (or unknown-size) file: create the metadata record first, then
        // stream the content in chunks.
        let hash = file.hash;
        let size = file.size;
        let res = self.create_file(file).await?;
        let res = self
            .upload_chunks(stream, res.id, size, hash, &BTreeSet::new(), on_progress)
            .await;
        Ok(res)
    }

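    /// Uploads file content in `CHUNK_SIZE` chunks, at most `self.concurrency`
    /// in flight at a time.
    ///
    /// Chunks listed in `exclude_chunks` are skipped, which allows resuming an
    /// interrupted upload. Chunks are read (and hashed, when no `hash` is
    /// given) sequentially, then uploaded concurrently; once every chunk is
    /// in, the file's hash, size and status are committed via
    /// `update_file_info`. The first error aborts the remaining work and is
    /// reported in the returned result's `error` field.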
    pub async fn upload_chunks<T, F>(
        &self,
        stream: T,
        id: u32,
        size: Option<u64>,
        hash: Option<ByteArray<32>>,
        exclude_chunks: &BTreeSet<u32>,
        on_progress: F,
    ) -> UploadFileChunksResult
    where
        T: AsyncRead,
        F: Fn(Progress),
    {
        let bucket = self.bucket;
        let has_hash = hash.is_some();
        let mut frames = Box::pin(FramedRead::new(stream, ChunksCodec::new(CHUNK_SIZE)));
        let (tx, mut rx) = mpsc::channel::<Result<Progress, String>>(self.concurrency as usize);
        let output = Arc::new(RwLock::new(UploadFileChunksResult {
            id,
            filled: 0,
            uploaded_chunks: exclude_chunks.clone(),
            error: None,
        }));

        // Read chunks sequentially (so the hash is computed in order), but
        // upload them concurrently, bounded by the semaphore.
        let uploading_loop = async {
            let mut index = 0;
            let mut hasher = Sha3_256::new();
            let semaphore = Arc::new(Semaphore::new(self.concurrency as usize));

            loop {
                let access_token = self.access_token.clone();
                let tx1 = tx.clone();
                let output = output.clone();
                let permit = semaphore
                    .clone()
                    .acquire_owned()
                    .await
                    .map_err(format_error)?;
                let concurrency = (self.concurrency as usize - semaphore.available_permits()) as u8;

                match frames.next().await {
                    None => {
                        // Stream exhausted: close the channel and return the hash.
                        drop(tx);
                        semaphore.close();
                        return Ok(Into::<[u8; 32]>::into(hasher.finalize()));
                    }
                    Some(Err(err)) => {
                        drop(tx);
                        semaphore.close();
                        return Err(err.to_string());
                    }
                    Some(Ok(chunk)) => {
                        let chunk_index = index;
                        index += 1;
                        let chunk_len = chunk.len() as u32;

                        if !has_hash {
                            hasher.update(&chunk);
                        }

                        // Skip chunks that were already uploaded (resume support).
                        if exclude_chunks.contains(&chunk_index) {
                            let mut r = output.write().await;
                            r.filled += chunk_len as u64;
                            on_progress(Progress {
                                filled: r.filled,
                                size,
                                chunk_index,
                                concurrency: 0,
                            });
                            drop(permit);
                            continue;
                        }

                        let agent = self.agent.clone();
                        tokio::spawn(async move {
                            let res = async {
                                let out: Result<UpdateFileChunkOutput, String> = update_call(
                                    &agent,
                                    &bucket,
                                    "update_file_chunk",
                                    (
                                        UpdateFileChunkInput {
                                            id,
                                            chunk_index,
                                            content: ByteBuf::from(chunk.to_vec()),
                                        },
                                        &access_token,
                                    ),
                                )
                                .await?;
                                let out = out?;
                                Ok(Progress {
                                    filled: out.filled,
                                    size,
                                    chunk_index,
                                    concurrency,
                                })
                            }
                            .await;

                            if res.is_ok() {
                                let mut r = output.write().await;
                                r.filled += chunk_len as u64;
                                r.uploaded_chunks.insert(chunk_index);
                                drop(permit);
                            }
                            let _ = tx1.send(res).await;
                        });
                    }
                }
            }
        };

        // Drain progress events; abort on the first upload error.
        let uploading_result = async {
            while let Some(res) = rx.recv().await {
                match res {
                    Ok(progress) => {
                        on_progress(progress);
                    }
                    Err(err) => return Err(err),
                }
            }

            Ok(())
        };

        let result = async {
            let (hash_new, _) = futures::future::try_join(uploading_loop, uploading_result).await?;

            // Commit the final hash, size and status once all chunks are in.
            let _ = self
                .update_file_info(UpdateFileInput {
                    id,
                    hash: Some(hash.unwrap_or(hash_new.into())),
                    status: if self.set_readonly { Some(1) } else { None },
                    size,
                    ..Default::default()
                })
                .await?;
            Ok::<(), String>(())
        }
        .await;

        let mut output = output.read().await.to_owned();
        if let Err(err) = result {
            output.error = Some(err);
        }

        output
    }
}

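/// A `Decoder` that splits a byte stream into fixed-size chunks, emitting the
/// final partial chunk (if any) at EOF.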
#[derive(Copy, Clone, Debug)]
pub struct ChunksCodec(u32);

impl ChunksCodec {
    pub fn new(len: u32) -> ChunksCodec {
        ChunksCodec(len)
    }
}

impl Decoder for ChunksCodec {
    type Item = Bytes;
    type Error = tokio::io::Error;

    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if buf.len() >= self.0 as usize {
            Ok(Some(buf.split_to(self.0 as usize).freeze()))
        } else {
            Ok(None)
        }
    }

    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if buf.is_empty() {
            Ok(None)
        } else {
            let len = buf.len();
            Ok(Some(buf.split_to(len).freeze()))
        }
    }
}

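/// Reads exactly `size` bytes from `stream`, erroring if the stream yields
/// fewer or more bytes.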
async fn try_read_all<T: AsyncRead>(stream: T, size: u32) -> Result<Bytes, String> {
    let mut frames = Box::pin(FramedRead::new(stream, ChunksCodec::new(size)));

    let res = frames.next().await.ok_or("no bytes to read".to_string())?;
    if frames.next().await.is_some() {
        return Err("too many bytes to read".to_string());
    }
    let res = res.map_err(format_error)?;
    if res.len() != size as usize {
        return Err("insufficient bytes to read".to_string());
    }
    Ok(res)
}