grammers_client/client/files.rs
// Copyright 2020 - developers of the `grammers` project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::Client;
use crate::types::{Downloadable, Uploaded};
use crate::utils::generate_random_id;
use futures_util::stream::{FuturesUnordered, StreamExt as _};
use grammers_mtsender::InvocationError;
use grammers_tl_types as tl;
use std::mem;
use std::sync::Arc;
use tokio::{
    io::{self, AsyncRead, AsyncReadExt},
    sync::Mutex as AsyncMutex,
};

#[cfg(feature = "fs")]
use {
    std::{io::SeekFrom, path::Path},
    tokio::{
        fs,
        io::{AsyncSeekExt, AsyncWriteExt},
        sync::mpsc::unbounded_channel,
    },
};

pub const MIN_CHUNK_SIZE: i32 = 4 * 1024;
pub const MAX_CHUNK_SIZE: i32 = 512 * 1024;
const FILE_MIGRATE_ERROR: i32 = 303;
const BIG_FILE_SIZE: usize = 10 * 1024 * 1024;
const WORKER_COUNT: usize = 4;

pub struct DownloadIter {
    client: Client,
    done: bool,
    variant: DownloadIterVariant,
}

enum DownloadIterVariant {
    Request(tl::functions::upload::GetFile),
    PreDownloaded(Vec<u8>),
    PreFailed(io::Error),
    Empty,
}

impl DownloadIter {
    /// Changes the chunk size, in bytes, used to make requests. Useful if you only need to get a
    /// small part of a file. By default, `MAX_CHUNK_SIZE` is used.
    ///
    /// # Panics
    ///
    /// Panics if `size` is not divisible by `MIN_CHUNK_SIZE`, or if `size` is not contained in
    /// the range `MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE`.
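    ///
    /// # Examples
    ///
    /// A minimal sketch reusing the setup from [`Client::iter_download`]; the 128 KiB value is
    /// only illustrative:
    ///
    /// ```
    /// # async fn f(downloadable: grammers_client::types::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
    /// // Request 128 KiB chunks instead of the default `MAX_CHUNK_SIZE` (512 KiB).
    /// let mut download = client.iter_download(&downloadable).chunk_size(128 * 1024);
    ///
    /// // Fetch only the first chunk, e.g. to inspect the file header.
    /// if let Some(header) = download.next().await? {
    ///     println!("first {} bytes downloaded", header.len());
    /// }
    /// # Ok(())
    /// # }
    /// ```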
    pub fn chunk_size(mut self, size: i32) -> Self {
        assert!((MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&size) && size % MIN_CHUNK_SIZE == 0);
        if let DownloadIterVariant::Request(request) = &mut self.variant {
            request.limit = size;
        }
        self
    }

    /// Skips `n` chunks to start downloading at a different offset in the file. If you want to
    /// skip less data, change the `chunk_size` before calling this method, and then reset it to
    /// any value you want.
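    ///
    /// # Examples
    ///
    /// A minimal sketch reusing the setup from [`Client::iter_download`]; skipping 2 chunks with
    /// the default chunk size starts the download 1 MiB into the file:
    ///
    /// ```
    /// # async fn f(downloadable: grammers_client::types::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
    /// let mut download = client.iter_download(&downloadable).skip_chunks(2);
    ///
    /// while let Some(chunk) = download.next().await? {
    ///     println!("got {} more bytes", chunk.len());
    /// }
    /// # Ok(())
    /// # }
    /// ```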
    pub fn skip_chunks(mut self, n: i32) -> Self {
        if let DownloadIterVariant::Request(request) = &mut self.variant {
            request.offset += request.limit as i64 * (n as i64);
        }
        self
    }

    /// Fetch and return the next chunk.
    pub async fn next(&mut self) -> Result<Option<Vec<u8>>, InvocationError> {
        if self.done {
            return Ok(None);
        }

        let variant = mem::replace(&mut self.variant, DownloadIterVariant::Empty);

        let mut request = match variant {
            DownloadIterVariant::Request(r) => r,
            DownloadIterVariant::PreDownloaded(data) => {
                self.done = true;
                return Ok(Some(data));
            }
            DownloadIterVariant::PreFailed(error) => {
                return Err(InvocationError::Io(error));
            }
            DownloadIterVariant::Empty => return Ok(None),
        };

        use tl::enums::upload::File;

        // TODO handle maybe FILEREF_UPGRADE_NEEDED
        let mut dc = self.client.0.session.home_dc_id();
        loop {
            break match self.client.invoke_in_dc(dc, &request).await {
                Ok(File::File(f)) => {
                    if f.bytes.len() < request.limit as usize {
                        self.done = true;
                        if f.bytes.is_empty() {
                            break Ok(None);
                        }
                    }

                    request.offset += request.limit as i64;
                    self.variant = DownloadIterVariant::Request(request);

                    Ok(Some(f.bytes))
                }
                Ok(File::CdnRedirect(_)) => {
                    panic!("API returned File::CdnRedirect even though cdn_supported = false");
                }
                Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
                    match self.client.copy_auth_to_dc(dc).await {
                        Ok(_) => continue,
                        Err(e) => Err(e),
                    }
                }
                Err(InvocationError::Rpc(err)) if err.code == FILE_MIGRATE_ERROR => {
                    dc = err.value.unwrap() as _;
                    continue;
                }
                Err(e) => Err(e),
            };
        }
    }
}

/// Method implementations related to uploading or downloading files.
impl Client {
    /// Returns a new iterator over the contents of a media document that will be downloaded.
    ///
    /// # Examples
    ///
    /// ```
    /// # async fn f(downloadable: grammers_client::types::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
    /// let mut file_bytes = Vec::new();
    /// let mut download = client.iter_download(&downloadable);
    ///
    /// while let Some(chunk) = download.next().await? {
    ///     file_bytes.extend(chunk);
    /// }
    ///
    /// // The file is now downloaded in-memory, inside `file_bytes`!
    /// # Ok(())
    /// # }
    /// ```
    pub fn iter_download<D: Downloadable>(&self, downloadable: &D) -> DownloadIter {
        if let Some(data) = downloadable.to_data() {
            DownloadIter {
                client: self.clone(),
                done: false,
                variant: DownloadIterVariant::PreDownloaded(data),
            }
        } else if let Some(location) = downloadable.to_raw_input_location() {
            DownloadIter {
                client: self.clone(),
                done: false,
                variant: DownloadIterVariant::Request(tl::functions::upload::GetFile {
                    precise: false,
                    cdn_supported: false,
                    location,
                    offset: 0,
                    limit: MAX_CHUNK_SIZE,
                }),
            }
        } else {
            DownloadIter {
                client: self.clone(),
                done: false,
                variant: DownloadIterVariant::PreFailed(io::Error::new(
                    io::ErrorKind::Other,
                    "media not downloadable",
                )),
            }
        }
    }

    /// Downloads a media file into the specified path.
    ///
    /// If the file already exists, it will be overwritten.
    ///
    /// This is a small wrapper around [`Client::iter_download`] for the common case of
    /// wanting to save the file locally.
    ///
    /// # Examples
    ///
    /// ```
    /// # async fn f(downloadable: grammers_client::types::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
    /// client.download_media(&downloadable, "/home/username/photos/holidays.jpg").await?;
    /// # Ok(())
    /// # }
    /// ```
    #[cfg(feature = "fs")]
    pub async fn download_media<D: Downloadable, P: AsRef<Path>>(
        &self,
        downloadable: &D,
        path: P,
    ) -> Result<(), io::Error> {
        // Use the concurrent downloader for files bigger than BIG_FILE_SIZE.
        if let Some((location, size)) = downloadable
            .to_raw_input_location()
            .zip(downloadable.size())
        {
            if size > BIG_FILE_SIZE {
                return self
                    .download_media_concurrent(location, size, path, WORKER_COUNT)
                    .await;
            }
        }

        let mut download = self.iter_download(downloadable);
        Client::load(path, &mut download).await
    }

    #[cfg(feature = "fs")]
    async fn load<P: AsRef<Path>>(path: P, download: &mut DownloadIter) -> Result<(), io::Error> {
        let mut file = fs::File::create(path).await?;
        while let Some(chunk) = download
            .next()
            .await
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
        {
            file.write_all(&chunk).await?;
        }

        Ok(())
    }

    /// Downloads a `Document` to the specified path using multiple connections.
    #[cfg(feature = "fs")]
    async fn download_media_concurrent<P: AsRef<Path>>(
        &self,
        location: tl::enums::InputFileLocation,
        size: usize,
        path: P,
        workers: usize,
    ) -> Result<(), io::Error> {
        // Pre-allocate the output file so chunks can be written at their final offsets.
        let mut file = fs::File::create(path).await?;
        file.set_len(size as u64).await?;
        file.seek(SeekFrom::Start(0)).await?;

        // Start workers
        let (tx, mut rx) = unbounded_channel();
        let part_index = Arc::new(tokio::sync::Mutex::new(0));
        let mut tasks = vec![];
        let home_dc_id = self.0.session.home_dc_id();
        for _ in 0..workers {
            let location = location.clone();
            let tx = tx.clone();
            let part_index = part_index.clone();
            let client = self.clone();
            let task = tokio::task::spawn(async move {
                let mut retry_offset = None;
                let mut dc = home_dc_id;
                loop {
                    // Calculate the file offset of the next part to fetch
                    let offset: i64 = {
                        if let Some(offset) = retry_offset {
                            retry_offset = None;
                            offset
                        } else {
                            let mut i = part_index.lock().await;
                            *i += 1;
                            MAX_CHUNK_SIZE as i64 * (*i - 1)
                        }
                    };
                    if offset > size as i64 {
                        break;
                    }
                    // Fetch the part from Telegram
                    let request = &tl::functions::upload::GetFile {
                        precise: true,
                        cdn_supported: false,
                        location: location.clone(),
                        offset,
                        limit: MAX_CHUNK_SIZE,
                    };
                    match client.invoke_in_dc(dc, request).await {
                        Ok(tl::enums::upload::File::File(file)) => {
                            tx.send((offset as u64, file.bytes)).unwrap();
                        }
                        Ok(tl::enums::upload::File::CdnRedirect(_)) => {
                            panic!(
                                "API returned File::CdnRedirect even though cdn_supported = false"
                            );
                        }
                        Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
                            match client.copy_auth_to_dc(dc).await {
                                Ok(_) => {
                                    retry_offset = Some(offset);
                                    continue;
                                }
                                Err(e) => return Err(e),
                            }
                        }
                        Err(InvocationError::Rpc(err)) => {
                            if err.code == FILE_MIGRATE_ERROR {
                                dc = err.value.unwrap() as _;
                                retry_offset = Some(offset);
                                continue;
                            }
                            return Err(InvocationError::Rpc(err));
                        }
                        Err(e) => return Err(e),
                    }
                }
                Ok::<(), InvocationError>(())
            });
            tasks.push(task);
        }
        drop(tx);

        // File write loop
        let mut pos = 0;
        while let Some((offset, data)) = rx.recv().await {
            if offset != pos {
                file.seek(SeekFrom::Start(offset)).await?;
            }
            file.write_all(&data).await?;
            pos = offset + data.len() as u64;
        }

        // Check if all tasks finished successfully
        for task in tasks {
            task.await?
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
        }
        Ok(())
    }

    /// Uploads an async stream to Telegram servers.
    ///
    /// The file is not sent to any peer, but can be used as media when sending messages for a
    /// certain period of time (less than a day). You can use this uploaded file multiple times.
    ///
    /// Refer to [`InputMessage`] to learn more about the uses for the uploaded file.
    ///
    /// The stream size must be known beforehand. If this is not possible, you might need to
    /// process the entire async stream to determine its size, and then use the size and the
    /// downloaded buffer.
    ///
    /// The stream size may be less than or equal to the actual length of the stream, but not
    /// more. If it's less, you may continue to read from the stream after the method returns.
    /// If it's more, the method will fail because it does not have enough data to read.
    ///
    /// Note that Telegram uses the file name in certain methods, for example, to make sure the
    /// file is an image when trying to send the file as photo media, so it is important that
    /// the file name at least uses the right extension, even if the name is a dummy value.
    /// If the input file name is empty, the non-empty dummy value "a" will be used instead.
    /// Because it has no extension, you may not be able to use the file in certain methods.
    ///
    /// # Examples
    ///
    /// ```
    /// # async fn f(peer: grammers_session::defs::PeerRef, client: grammers_client::Client, some_vec: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
    /// use grammers_client::InputMessage;
    ///
    /// // In-memory `Vec<u8>` buffers can be used as async streams
    /// let size = some_vec.len();
    /// let mut stream = std::io::Cursor::new(some_vec);
    /// let uploaded_file = client.upload_stream(&mut stream, size, "sleep.jpg".to_string()).await?;
    ///
    /// client.send_message(peer, InputMessage::new().text("Zzz...").photo(uploaded_file)).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`InputMessage`]: crate::types::InputMessage
    pub async fn upload_stream<S: AsyncRead + Unpin>(
        &self,
        stream: &mut S,
        size: usize,
        name: String,
    ) -> Result<Uploaded, io::Error> {
        let file_id = generate_random_id();
        let name = if name.is_empty() {
            "a".to_string()
        } else {
            name
        };

        let big_file = size > BIG_FILE_SIZE;
        let parts = PartStream::new(stream, size);
        let total_parts = parts.total_parts();

        if big_file {
            let parts = Arc::new(parts);
            let mut tasks = FuturesUnordered::new();
            for _ in 0..WORKER_COUNT {
                let handle = self.clone();
                let parts = Arc::clone(&parts);
                let task = async move {
                    while let Some((part, bytes)) = parts.next_part().await? {
                        let ok = handle
                            .invoke(&tl::functions::upload::SaveBigFilePart {
                                file_id,
                                file_part: part,
                                file_total_parts: total_parts,
                                bytes,
                            })
                            .await
                            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

                        if !ok {
                            return Err(io::Error::new(
                                io::ErrorKind::Other,
                                "server failed to store uploaded data",
                            ));
                        }
                    }
                    Ok(())
                };
                tasks.push(task);
            }

            while let Some(res) = tasks.next().await {
                res?;
            }

            Ok(Uploaded::from_raw(
                tl::types::InputFileBig {
                    id: file_id,
                    parts: total_parts,
                    name,
                }
                .into(),
            ))
        } else {
            let mut md5 = md5::Context::new();
            while let Some((part, bytes)) = parts.next_part().await? {
                md5.consume(&bytes);
                let ok = self
                    .invoke(&tl::functions::upload::SaveFilePart {
                        file_id,
                        file_part: part,
                        bytes,
                    })
                    .await
                    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

                if !ok {
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "server failed to store uploaded data",
                    ));
                }
            }
            Ok(Uploaded::from_raw(
                tl::types::InputFile {
                    id: file_id,
                    parts: total_parts,
                    name,
                    md5_checksum: format!("{:x}", md5.finalize()),
                }
                .into(),
            ))
        }
    }

    /// Uploads a local file to Telegram servers.
    ///
    /// The file is not sent to any peer, but can be used as media when sending messages for a
    /// certain period of time (less than a day). You can use this uploaded file multiple times.
    ///
    /// Refer to [`InputMessage`] to learn more about the uses for the uploaded file.
    ///
    /// If you need more control over the uploaded data, such as performing only a partial upload
    /// or using a different name, use [`Client::upload_stream`] instead.
    ///
    /// # Examples
    ///
    /// ```
    /// # async fn f(peer: grammers_session::defs::PeerRef, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
    /// use grammers_client::InputMessage;
    ///
    /// let uploaded_file = client.upload_file("/home/username/photos/holidays.jpg").await?;
    ///
    /// client.send_message(peer, InputMessage::new().text("Check this out!").photo(uploaded_file)).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// [`InputMessage`]: crate::InputMessage
    #[cfg(feature = "fs")]
    pub async fn upload_file<P: AsRef<Path>>(&self, path: P) -> Result<Uploaded, io::Error> {
        let path = path.as_ref();

        let mut file = fs::File::open(path).await?;
        let size = file.seek(SeekFrom::End(0)).await? as usize;
        file.seek(SeekFrom::Start(0)).await?;

        // File name will only be `None` for the `..` path, and directories cannot be uploaded as
        // files, so it's fine to unwrap.
        let name = path.file_name().unwrap().to_string_lossy().to_string();

        self.upload_stream(&mut file, size, name).await
    }
}

/// Mutable state shared by the upload workers: the borrowed stream and the next part number.
struct PartStreamInner<'a, S: AsyncRead + Unpin> {
    stream: &'a mut S,
    current_part: i32,
}

/// Splits a borrowed async stream into sequentially numbered parts of up to `MAX_CHUNK_SIZE`
/// bytes, so that several upload tasks can pull parts from it concurrently.
struct PartStream<'a, S: AsyncRead + Unpin> {
    inner: AsyncMutex<PartStreamInner<'a, S>>,
    total_parts: i32,
}

impl<'a, S: AsyncRead + Unpin> PartStream<'a, S> {
    fn new(stream: &'a mut S, size: usize) -> Self {
        let total_parts = ((size + MAX_CHUNK_SIZE as usize - 1) / MAX_CHUNK_SIZE as usize) as i32;
        Self {
            inner: AsyncMutex::new(PartStreamInner {
                stream,
                current_part: 0,
            }),
            total_parts,
        }
    }

    fn total_parts(&self) -> i32 {
        self.total_parts
    }

    async fn next_part(&self) -> Result<Option<(i32, Vec<u8>)>, io::Error> {
        let mut lock = self.inner.lock().await;
        if lock.current_part >= self.total_parts {
            return Ok(None);
        }
        let mut read = 0;
        let mut buffer = vec![0; MAX_CHUNK_SIZE as usize];

        // Fill the buffer completely; only the last part is allowed to be shorter.
        while read != buffer.len() {
            let n = lock.stream.read(&mut buffer[read..]).await?;
            if n == 0 {
                if lock.current_part == self.total_parts - 1 {
                    break;
                } else {
                    return Err(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        "reached EOF before reaching the last file part",
                    ));
                }
            }
            read += n;
        }

        let bytes = if read == buffer.len() {
            buffer
        } else {
            buffer[..read].to_vec()
        };

        let res = Ok(Some((lock.current_part, bytes)));
        lock.current_part += 1;
        res
    }
}