1use std::mem;
10use std::sync::Arc;
11#[cfg(feature = "fs")]
12use std::{io::SeekFrom, path::Path};
13
14use futures_util::stream::{FuturesUnordered, StreamExt as _};
15use grammers_mtsender::InvocationError;
16use grammers_tl_types as tl;
17use tokio::io::{self, AsyncRead, AsyncReadExt};
18use tokio::sync::Mutex as AsyncMutex;
19#[cfg(feature = "fs")]
20use tokio::{
21 fs,
22 io::{AsyncSeekExt, AsyncWriteExt},
23 sync::mpsc::unbounded_channel,
24};
25
26use super::Client;
27use crate::media::{Downloadable, Uploaded};
28use crate::utils::generate_random_id;
29
/// Smallest chunk size, in bytes, accepted by `DownloadIter::chunk_size`.
///
/// Chunk sizes must also be a multiple of this value.
pub const MIN_CHUNK_SIZE: i32 = 4 * 1024;

/// Largest chunk size, in bytes, accepted by `DownloadIter::chunk_size`.
///
/// It is also the default download chunk size and the size of every upload part.
pub const MAX_CHUNK_SIZE: i32 = 512 * 1024;

/// RPC error code that makes the download code retry the request in the
/// datacenter named by the error's value.
const FILE_MIGRATE_ERROR: i32 = 303;

/// Files strictly larger than this many bytes are transferred by several
/// concurrent workers instead of sequentially.
const BIG_FILE_SIZE: usize = 10 * 1024 * 1024;

/// Number of concurrent workers used for big-file transfers.
const WORKER_COUNT: usize = 4;
35
36pub struct DownloadIter {
38 client: Client,
39 done: bool,
40 size: Option<usize>,
41 variant: DownloadIterVariant,
42}
43
44enum DownloadIterVariant {
45 Request(tl::functions::upload::GetFile),
46 PreDownloaded(Vec<u8>),
47 PreFailed(io::Error),
48 Empty,
49}
50
51impl DownloadIter {
52 pub fn chunk_size(mut self, size: i32) -> Self {
60 assert!((MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&size) && size % MIN_CHUNK_SIZE == 0);
61 match &mut self.variant {
62 DownloadIterVariant::Request(request) => request.limit = size,
63 _ => {}
64 }
65 self
66 }
67
68 pub fn skip_chunks(mut self, n: i32) -> Self {
72 match &mut self.variant {
73 DownloadIterVariant::Request(request) => {
74 request.offset += request.limit as i64 * (n as i64)
75 }
76 _ => {}
77 }
78 self
79 }
80
81 pub async fn next(&mut self) -> Result<Option<Vec<u8>>, InvocationError> {
83 if self.done {
84 return Ok(None);
85 }
86
87 let variant = mem::replace(&mut self.variant, DownloadIterVariant::Empty);
88
89 let mut request = match variant {
90 DownloadIterVariant::Request(r) => r,
91 DownloadIterVariant::PreDownloaded(data) => {
92 self.done = true;
93 return Ok(Some(data.clone()));
94 }
95 DownloadIterVariant::PreFailed(error) => {
96 return Err(InvocationError::Io(error));
97 }
98 DownloadIterVariant::Empty => return Ok(None),
99 };
100
101 use tl::enums::upload::File;
102
103 let mut dc = self.client.0.session.home_dc_id()?;
105 loop {
106 break match self.client.invoke_in_dc(dc, &request).await {
107 Ok(File::File(f)) => {
108 let reached_known_size = self
109 .size
110 .is_some_and(|size| request.offset as usize + f.bytes.len() >= size);
111
112 if reached_known_size || f.bytes.len() < request.limit as usize {
113 self.done = true;
114 if f.bytes.is_empty() {
115 break Ok(None);
116 }
117 }
118
119 request.offset += request.limit as i64;
120 self.variant = DownloadIterVariant::Request(request);
121
122 Ok(Some(f.bytes))
123 }
124 Ok(File::CdnRedirect(_)) => {
125 panic!("API returned File::CdnRedirect even though cdn_supported = false");
126 }
127 Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
128 match self.client.copy_auth_to_dc(dc).await {
129 Ok(_) => continue,
130 Err(e) => Err(e),
131 }
132 }
133 Err(InvocationError::Rpc(err)) if err.code == FILE_MIGRATE_ERROR => {
134 dc = err.value.unwrap() as _;
135 continue;
136 }
137 Err(e) => Err(e),
138 };
139 }
140 }
141}
142
143impl Client {
145 pub fn iter_download<D: Downloadable>(&self, downloadable: &D) -> DownloadIter {
163 if let Some(data) = downloadable.to_data() {
164 DownloadIter {
165 client: self.clone(),
166 done: false,
167 size: Some(data.len()),
168 variant: DownloadIterVariant::PreDownloaded(data),
169 }
170 } else if let Some(location) = downloadable.to_raw_input_location() {
171 DownloadIter {
172 client: self.clone(),
173 done: false,
174 size: downloadable.size(),
175 variant: DownloadIterVariant::Request(tl::functions::upload::GetFile {
176 precise: false,
177 cdn_supported: false,
178 location,
179 offset: 0,
180 limit: MAX_CHUNK_SIZE,
181 }),
182 }
183 } else {
184 DownloadIter {
185 client: self.clone(),
186 done: false,
187 size: None,
188 variant: DownloadIterVariant::PreFailed(io::Error::new(
189 io::ErrorKind::Other,
190 "media not downloadable",
191 )),
192 }
193 }
194 }
195
/// Downloads `downloadable` into the file at `path`.
///
/// When the item exposes both a raw input location and a size, and that size
/// exceeds `BIG_FILE_SIZE`, the download is split across `WORKER_COUNT`
/// concurrent workers. Otherwise the chunks are fetched sequentially.
///
/// # Errors
///
/// Returns any request error, or any I/O error raised while writing the file.
#[cfg(feature = "fs")]
pub async fn download_media<D: Downloadable, P: AsRef<Path>>(
    &self,
    downloadable: &D,
    path: P,
) -> Result<(), InvocationError> {
    let big_file = downloadable
        .to_raw_input_location()
        .zip(downloadable.size())
        .filter(|(_, size)| *size > BIG_FILE_SIZE);

    match big_file {
        Some((location, size)) => {
            self.download_media_concurrent(location, size, path, WORKER_COUNT)
                .await
        }
        None => {
            let mut download = self.iter_download(downloadable);
            Client::load(path, &mut download).await?;
            Ok(())
        }
    }
}
232
/// Writes every chunk produced by `download` into a newly created file at `path`.
///
/// # Errors
///
/// Returns any I/O error from creating, writing or flushing the file.
/// Download errors are wrapped into an `io::Error`.
#[cfg(feature = "fs")]
async fn load<P: AsRef<Path>>(path: P, download: &mut DownloadIter) -> Result<(), io::Error> {
    let mut file = fs::File::create(path).await?;
    while let Some(chunk) = download.next().await.map_err(io::Error::other)? {
        file.write_all(&chunk).await?;
    }

    // Dropping a tokio `File` does not wait for pending writes. Flush so the
    // data is written out and any late write error is reported to the caller.
    file.flush().await?;

    Ok(())
}
242
/// Downloads the file at `location` into `path` using `workers` concurrent tasks.
///
/// Each task repeatedly claims the next `MAX_CHUNK_SIZE` chunk from a shared
/// counter, requests it, and sends `(offset, bytes)` through a channel. This
/// function is the only writer: it receives the chunks in whatever order they
/// arrive and writes each at its offset in the file, which is created with a
/// length of `size` bytes.
///
/// # Errors
///
/// Returns any I/O error from the file, any request error from a worker, or a
/// wrapped join error if a worker task panicked or was cancelled.
///
/// # Panics
///
/// A worker panics if the server answers with a CDN redirect, which is never
/// requested. The panic surfaces here as a join error.
#[cfg(feature = "fs")]
async fn download_media_concurrent<P: AsRef<Path>>(
    &self,
    location: tl::enums::InputFileLocation,
    size: usize,
    path: P,
    workers: usize,
) -> Result<(), InvocationError> {
    let mut file = fs::File::create(path).await.map_err(InvocationError::Io)?;
    file.set_len(size as u64)
        .await
        .map_err(InvocationError::Io)?;
    file.seek(SeekFrom::Start(0))
        .await
        .map_err(InvocationError::Io)?;

    let (tx, mut rx) = unbounded_channel();
    // Index of the next chunk that has not been claimed by any worker yet.
    let part_index = Arc::new(tokio::sync::Mutex::<i64>::new(0));
    let mut tasks = vec![];
    let home_dc_id = self.0.session.home_dc_id()?;
    for _ in 0..workers {
        let location = location.clone();
        let tx = tx.clone();
        let part_index = part_index.clone();
        let client = self.clone();
        let task = tokio::task::spawn(async move {
            // Offset to request again after a recoverable error.
            let mut retry_offset: Option<i64> = None;
            let mut dc = home_dc_id;
            loop {
                let offset: i64 = match retry_offset.take() {
                    Some(offset) => offset,
                    None => {
                        let mut next_index = part_index.lock().await;
                        *next_index += 1;
                        MAX_CHUNK_SIZE as i64 * (*next_index - 1)
                    }
                };
                if offset >= size as i64 {
                    break;
                }
                let request = tl::functions::upload::GetFile {
                    precise: true,
                    cdn_supported: false,
                    location: location.clone(),
                    offset,
                    limit: MAX_CHUNK_SIZE,
                };
                match client.invoke_in_dc(dc, &request).await {
                    Ok(tl::enums::upload::File::File(chunk)) => {
                        if tx.send((offset as u64, chunk.bytes)).is_err() {
                            // The receiver is gone, so the writer has already
                            // returned (for example on an I/O error). Stop
                            // quietly instead of panicking in a detached task.
                            break;
                        }
                    }
                    Ok(tl::enums::upload::File::CdnRedirect(_)) => {
                        panic!(
                            "API returned File::CdnRedirect even though cdn_supported = false"
                        );
                    }
                    Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
                        if let Err(e) = client.copy_auth_to_dc(dc).await {
                            return Err(e);
                        }
                        retry_offset = Some(offset);
                    }
                    Err(InvocationError::Rpc(err)) if err.code == FILE_MIGRATE_ERROR => {
                        match err.value {
                            Some(target_dc) => {
                                dc = target_dc as _;
                                retry_offset = Some(offset);
                            }
                            // No target datacenter to retry against; report
                            // the error instead of panicking on `unwrap`.
                            None => return Err(InvocationError::Rpc(err)),
                        }
                    }
                    Err(e) => return Err(e),
                }
            }
            Ok::<(), InvocationError>(())
        });
        tasks.push(task);
    }
    // Only the workers hold senders now, so `recv` ends once they all finish.
    drop(tx);

    let mut pos = 0;
    while let Some((offset, data)) = rx.recv().await {
        // Chunks can arrive out of order; seek only when not already in place.
        if offset != pos {
            file.seek(SeekFrom::Start(offset)).await?;
        }
        file.write_all(&data).await?;
        pos = offset + data.len() as u64;
    }

    // Dropping a tokio `File` does not wait for pending writes. Flush so the
    // data is written out and any late write error is reported to the caller.
    file.flush().await?;

    for task in tasks {
        let outcome = task.await.map_err(io::Error::other)?;
        outcome?;
    }
    Ok(())
}
349
350 pub async fn upload_stream<S: AsyncRead + Unpin>(
389 &self,
390 stream: &mut S,
391 size: usize,
392 name: String,
393 ) -> Result<Uploaded, io::Error> {
394 let file_id = generate_random_id();
395 let name = if name.is_empty() {
396 "a".to_string()
397 } else {
398 name
399 };
400
401 let big_file = size > BIG_FILE_SIZE;
402 let parts = PartStream::new(stream, size);
403 let total_parts = parts.total_parts();
404
405 if big_file {
406 let parts = Arc::new(parts);
407 let mut tasks = FuturesUnordered::new();
408 for _ in 0..WORKER_COUNT {
409 let handle = self.clone();
410 let parts = Arc::clone(&parts);
411 let task = async move {
412 while let Some((part, bytes)) = parts.next_part().await? {
413 let ok = handle
414 .invoke(&tl::functions::upload::SaveBigFilePart {
415 file_id,
416 file_part: part,
417 file_total_parts: total_parts,
418 bytes,
419 })
420 .await
421 .map_err(io::Error::other)?;
422
423 if !ok {
424 return Err(io::Error::new(
425 io::ErrorKind::Other,
426 "server failed to store uploaded data",
427 ));
428 }
429 }
430 Ok(())
431 };
432 tasks.push(task);
433 }
434
435 while let Some(res) = tasks.next().await {
436 res?;
437 }
438
439 Ok(Uploaded::from_raw(
440 tl::types::InputFileBig {
441 id: file_id,
442 parts: total_parts,
443 name,
444 }
445 .into(),
446 ))
447 } else {
448 let mut md5 = md5::Context::new();
449 while let Some((part, bytes)) = parts.next_part().await? {
450 md5.consume(&bytes);
451 let ok = self
452 .invoke(&tl::functions::upload::SaveFilePart {
453 file_id,
454 file_part: part,
455 bytes,
456 })
457 .await
458 .map_err(io::Error::other)?;
459
460 if !ok {
461 return Err(io::Error::new(
462 io::ErrorKind::Other,
463 "server failed to store uploaded data",
464 ));
465 }
466 }
467 Ok(Uploaded::from_raw(
468 tl::types::InputFile {
469 id: file_id,
470 parts: total_parts,
471 name,
472 md5_checksum: format!("{:x}", md5.finalize()),
473 }
474 .into(),
475 ))
476 }
477 }
478
/// Uploads the file at `path`, using the final path component as its name.
///
/// The size is determined by seeking to the end of the opened file, which is
/// then rewound and passed to `upload_stream`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or seeked, if its size does
/// not fit in `usize`, if `path` has no final component to use as a name (for
/// example when it ends in `..`), or if the upload fails.
#[cfg(feature = "fs")]
pub async fn upload_file<P: AsRef<Path>>(&self, path: P) -> Result<Uploaded, io::Error> {
    let path = path.as_ref();

    let mut file = fs::File::open(path).await?;
    let len = file.seek(SeekFrom::End(0)).await?;
    // Refuse instead of silently truncating on targets where `usize` is narrower than `u64`.
    let size = usize::try_from(len).map_err(|_| {
        io::Error::new(io::ErrorKind::InvalidInput, "file is too large to upload")
    })?;
    file.seek(SeekFrom::Start(0)).await?;

    // `file_name` is `None` for paths ending in `..`; report it rather than panic.
    let name = path
        .file_name()
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "path does not end in a file name",
            )
        })?
        .to_string_lossy()
        .into_owned();

    self.upload_stream(&mut file, size, name).await
}
517}
518
519struct PartStreamInner<'a, S: AsyncRead + Unpin> {
520 stream: &'a mut S,
521 current_part: i32,
522}
523
524struct PartStream<'a, S: AsyncRead + Unpin> {
525 inner: AsyncMutex<PartStreamInner<'a, S>>,
526 total_parts: i32,
527}
528
529impl<'a, S: AsyncRead + Unpin> PartStream<'a, S> {
530 fn new(stream: &'a mut S, size: usize) -> Self {
531 let total_parts = ((size + MAX_CHUNK_SIZE as usize - 1) / MAX_CHUNK_SIZE as usize) as i32;
532 Self {
533 inner: AsyncMutex::new(PartStreamInner {
534 stream,
535 current_part: 0,
536 }),
537 total_parts,
538 }
539 }
540
541 fn total_parts(&self) -> i32 {
542 self.total_parts
543 }
544
545 async fn next_part(&self) -> Result<Option<(i32, Vec<u8>)>, io::Error> {
546 let mut lock = self.inner.lock().await;
547 if lock.current_part >= self.total_parts {
548 return Ok(None);
549 }
550 let mut read = 0;
551 let mut buffer = vec![0; MAX_CHUNK_SIZE as usize];
552
553 while read != buffer.len() {
554 let n = lock.stream.read(&mut buffer[read..]).await?;
555 if n == 0 {
556 if lock.current_part == self.total_parts - 1 {
557 break;
558 } else {
559 return Err(io::Error::new(
560 io::ErrorKind::UnexpectedEof,
561 "reached EOF before reaching the last file part",
562 ));
563 }
564 }
565 read += n;
566 }
567
568 let bytes = if read == buffer.len() {
569 buffer
570 } else {
571 buffer[..read].to_vec()
572 };
573
574 let res = Ok(Some((lock.current_part, bytes)));
575 lock.current_part += 1;
576 res
577 }
578}