rs_transfer/writer/
s3_writer.rs1use crate::{
2 StreamData,
3 endpoint::S3Endpoint,
4 error::Error,
5 writer::{StreamWriter, WriteJob},
6};
7use async_std::channel::Receiver;
8use async_trait::async_trait;
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
11use std::{
12 sync::mpsc::{Sender, channel},
13 thread,
14 time::Duration,
15};
16use threadpool::ThreadPool;
17
18impl S3Endpoint {
19 pub async fn start_multi_part_s3_upload(&self, path: &str) -> Result<String, Error> {
20 let client = self.connection();
21 let bucket = self.bucket().to_string();
22 let key = path.to_string().clone();
23
24 let object = client
25 .create_multipart_upload()
26 .bucket(bucket)
27 .key(key)
28 .send()
29 .await?;
30
31 object
32 .upload_id
33 .ok_or_else(|| Error::Other("Cannot retrieve upload ID from object".to_string()))
34 }
35
36 pub async fn upload_s3_part(
37 &self,
38 path: &str,
39 upload_id: &str,
40 part_number: i32,
41 data: Vec<u8>,
42 ) -> Result<CompletedPart, Error> {
43 let bucket = self.bucket().to_string();
44 let key = path.to_string();
45 let cloned_upload_id = upload_id.to_string();
46
47 let client = self.connection();
48
49 let object = client
50 .upload_part()
51 .body(ByteStream::from(data))
52 .bucket(bucket)
53 .key(key)
54 .upload_id(cloned_upload_id)
55 .part_number(part_number)
56 .send()
57 .await?;
58
59 Ok(
60 CompletedPart::builder()
61 .set_e_tag(object.e_tag)
62 .set_part_number(Some(part_number))
63 .build(),
64 )
65 }
66
67 async fn upload_s3_part_and_send(
68 &self,
69 cloned_path: &str,
70 upload_identifier: &str,
71 part_number: i32,
72 part_buffer: Vec<u8>,
73 part_sender: Sender<CompletedPart>,
74 ) -> Result<(), Error> {
75 let writer = self.clone();
76 let path = cloned_path.to_string();
77 let upload_identifier = upload_identifier.to_string();
78
79 writer
80 .upload_s3_part(&path, &upload_identifier, part_number, part_buffer)
81 .await
82 .and_then(|part_id| part_sender.send(part_id).map_err(|e| e.into()))
83 }
84
85 pub async fn complete_s3_upload(
86 &self,
87 path: &str,
88 upload_id: &str,
89 parts: Vec<CompletedPart>,
90 ) -> Result<(), Error> {
91 let bucket = self.bucket().to_string();
92 let key = path.to_string();
93 let cloned_upload_id = upload_id.to_string();
94 let multipart_upload = CompletedMultipartUpload::builder()
95 .set_parts(Some(parts))
96 .build();
97
98 let client = self.connection();
99
100 let _response = client
101 .complete_multipart_upload()
102 .bucket(bucket)
103 .key(key)
104 .upload_id(cloned_upload_id)
105 .multipart_upload(multipart_upload)
106 .send()
107 .await?;
108
109 Ok(())
110 }
111}
112
113#[async_trait]
114impl StreamWriter for S3Endpoint {
115 async fn write_stream(
116 &self,
117 path: &str,
118 receiver: Receiver<StreamData>,
119 job_and_notification: &dyn WriteJob,
120 ) -> Result<(), Error> {
121 let upload_identifier = self.start_multi_part_s3_upload(path).await?;
122
123 let mut part_number = 1;
124
125 let part_size = std::env::var("S3_WRITER_PART_SIZE")
127 .map(|buffer_size| buffer_size.parse::<usize>())
128 .unwrap_or_else(|_| Ok(10 * 1024 * 1024))
129 .unwrap_or(10 * 1024 * 1024);
130
131 let mut part_buffer: Vec<u8> = Vec::with_capacity(part_size);
132
133 let n_workers = std::env::var("S3_WRITER_WORKERS")
134 .map(|buffer_size| buffer_size.parse::<usize>())
135 .unwrap_or_else(|_| Ok(4))
136 .unwrap_or(4);
137
138 let mut n_jobs = 0;
139 let pool = ThreadPool::new(n_workers);
140
141 let mut file_size = None;
142 let mut received_bytes = 0;
143 let mut prev_percent = 0;
144
145 let (part_sender, part_receiver) = channel();
146
147 while let Ok(mut stream_data) = receiver.recv().await {
148 match stream_data {
149 StreamData::Size(size) => file_size = Some(size),
150 StreamData::Stop => break,
151 StreamData::Eof => {
152 n_jobs += 1;
153 self
154 .upload_s3_part_and_send(
155 path,
156 &upload_identifier,
157 part_number,
158 part_buffer.clone(),
159 part_sender.clone(),
160 )
161 .await?;
162
163 let mut complete_parts = part_receiver
164 .iter()
165 .take(n_jobs)
166 .collect::<Vec<CompletedPart>>();
167 complete_parts.sort_by(|part1, part2| part1.part_number.cmp(&part2.part_number));
168
169 self
170 .complete_s3_upload(path, &upload_identifier, complete_parts)
171 .await?;
172
173 break;
174 }
175 StreamData::Data(ref mut data) => {
176 received_bytes += data.len();
177 if let Some(file_size) = file_size {
178 let percent = (received_bytes as f32 / file_size as f32 * 100.0) as u8;
179
180 if percent > prev_percent {
181 prev_percent = percent;
182 job_and_notification.progress(percent)?;
183 }
184 }
185
186 part_buffer.append(data);
187
188 if part_buffer.len() > part_size {
189 while pool.queued_count() > 1 {
190 thread::sleep(Duration::from_millis(500));
191 }
192
193 self
194 .upload_s3_part_and_send(
195 path,
196 &upload_identifier,
197 part_number,
198 part_buffer.clone(),
199 part_sender.clone(),
200 )
201 .await?;
202
203 n_jobs += 1;
204
205 part_number += 1;
206 part_buffer.clear();
207 }
208 }
209 }
210 }
211 Ok(())
212 }
213}