// rs_transfer/writer/s3_writer.rs

1use crate::{
2  StreamData,
3  endpoint::S3Endpoint,
4  error::Error,
5  writer::{StreamWriter, WriteJob},
6};
7use async_std::channel::Receiver;
8use async_trait::async_trait;
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
11use std::{
12  sync::mpsc::{Sender, channel},
13  thread,
14  time::Duration,
15};
16use threadpool::ThreadPool;
17
18impl S3Endpoint {
19  pub async fn start_multi_part_s3_upload(&self, path: &str) -> Result<String, Error> {
20    let client = self.connection();
21    let bucket = self.bucket().to_string();
22    let key = path.to_string().clone();
23
24    let object = client
25      .create_multipart_upload()
26      .bucket(bucket)
27      .key(key)
28      .send()
29      .await?;
30
31    object
32      .upload_id
33      .ok_or_else(|| Error::Other("Cannot retrieve upload ID from object".to_string()))
34  }
35
36  pub async fn upload_s3_part(
37    &self,
38    path: &str,
39    upload_id: &str,
40    part_number: i32,
41    data: Vec<u8>,
42  ) -> Result<CompletedPart, Error> {
43    let bucket = self.bucket().to_string();
44    let key = path.to_string();
45    let cloned_upload_id = upload_id.to_string();
46
47    let client = self.connection();
48
49    let object = client
50      .upload_part()
51      .body(ByteStream::from(data))
52      .bucket(bucket)
53      .key(key)
54      .upload_id(cloned_upload_id)
55      .part_number(part_number)
56      .send()
57      .await?;
58
59    Ok(
60      CompletedPart::builder()
61        .set_e_tag(object.e_tag)
62        .set_part_number(Some(part_number))
63        .build(),
64    )
65  }
66
67  async fn upload_s3_part_and_send(
68    &self,
69    cloned_path: &str,
70    upload_identifier: &str,
71    part_number: i32,
72    part_buffer: Vec<u8>,
73    part_sender: Sender<CompletedPart>,
74  ) -> Result<(), Error> {
75    let writer = self.clone();
76    let path = cloned_path.to_string();
77    let upload_identifier = upload_identifier.to_string();
78
79    writer
80      .upload_s3_part(&path, &upload_identifier, part_number, part_buffer)
81      .await
82      .and_then(|part_id| part_sender.send(part_id).map_err(|e| e.into()))
83  }
84
85  pub async fn complete_s3_upload(
86    &self,
87    path: &str,
88    upload_id: &str,
89    parts: Vec<CompletedPart>,
90  ) -> Result<(), Error> {
91    let bucket = self.bucket().to_string();
92    let key = path.to_string();
93    let cloned_upload_id = upload_id.to_string();
94    let multipart_upload = CompletedMultipartUpload::builder()
95      .set_parts(Some(parts))
96      .build();
97
98    let client = self.connection();
99
100    let _response = client
101      .complete_multipart_upload()
102      .bucket(bucket)
103      .key(key)
104      .upload_id(cloned_upload_id)
105      .multipart_upload(multipart_upload)
106      .send()
107      .await?;
108
109    Ok(())
110  }
111}
112
113#[async_trait]
114impl StreamWriter for S3Endpoint {
115  async fn write_stream(
116    &self,
117    path: &str,
118    receiver: Receiver<StreamData>,
119    job_and_notification: &dyn WriteJob,
120  ) -> Result<(), Error> {
121    let upload_identifier = self.start_multi_part_s3_upload(path).await?;
122
123    let mut part_number = 1;
124
125    // limited to 10000 parts
126    let part_size = std::env::var("S3_WRITER_PART_SIZE")
127      .map(|buffer_size| buffer_size.parse::<usize>())
128      .unwrap_or_else(|_| Ok(10 * 1024 * 1024))
129      .unwrap_or(10 * 1024 * 1024);
130
131    let mut part_buffer: Vec<u8> = Vec::with_capacity(part_size);
132
133    let n_workers = std::env::var("S3_WRITER_WORKERS")
134      .map(|buffer_size| buffer_size.parse::<usize>())
135      .unwrap_or_else(|_| Ok(4))
136      .unwrap_or(4);
137
138    let mut n_jobs = 0;
139    let pool = ThreadPool::new(n_workers);
140
141    let mut file_size = None;
142    let mut received_bytes = 0;
143    let mut prev_percent = 0;
144
145    let (part_sender, part_receiver) = channel();
146
147    while let Ok(mut stream_data) = receiver.recv().await {
148      match stream_data {
149        StreamData::Size(size) => file_size = Some(size),
150        StreamData::Stop => break,
151        StreamData::Eof => {
152          n_jobs += 1;
153          self
154            .upload_s3_part_and_send(
155              path,
156              &upload_identifier,
157              part_number,
158              part_buffer.clone(),
159              part_sender.clone(),
160            )
161            .await?;
162
163          let mut complete_parts = part_receiver
164            .iter()
165            .take(n_jobs)
166            .collect::<Vec<CompletedPart>>();
167          complete_parts.sort_by(|part1, part2| part1.part_number.cmp(&part2.part_number));
168
169          self
170            .complete_s3_upload(path, &upload_identifier, complete_parts)
171            .await?;
172
173          break;
174        }
175        StreamData::Data(ref mut data) => {
176          received_bytes += data.len();
177          if let Some(file_size) = file_size {
178            let percent = (received_bytes as f32 / file_size as f32 * 100.0) as u8;
179
180            if percent > prev_percent {
181              prev_percent = percent;
182              job_and_notification.progress(percent)?;
183            }
184          }
185
186          part_buffer.append(data);
187
188          if part_buffer.len() > part_size {
189            while pool.queued_count() > 1 {
190              thread::sleep(Duration::from_millis(500));
191            }
192
193            self
194              .upload_s3_part_and_send(
195                path,
196                &upload_identifier,
197                part_number,
198                part_buffer.clone(),
199                part_sender.clone(),
200              )
201              .await?;
202
203            n_jobs += 1;
204
205            part_number += 1;
206            part_buffer.clear();
207          }
208        }
209      }
210    }
211    Ok(())
212  }
213}