Skip to main content

grammers_client/client/
files.rs

1// Copyright 2020 - developers of the `grammers` project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::mem;
10use std::sync::Arc;
11#[cfg(feature = "fs")]
12use std::{io::SeekFrom, path::Path};
13
14use futures_util::stream::{FuturesUnordered, StreamExt as _};
15use grammers_mtsender::InvocationError;
16use grammers_tl_types as tl;
17use tokio::io::{self, AsyncRead, AsyncReadExt};
18use tokio::sync::Mutex as AsyncMutex;
19#[cfg(feature = "fs")]
20use tokio::{
21    fs,
22    io::{AsyncSeekExt, AsyncWriteExt},
23    sync::mpsc::unbounded_channel,
24};
25
26use super::Client;
27use crate::media::{Downloadable, Uploaded};
28use crate::utils::generate_random_id;
29
/// Smallest chunk size, in bytes, accepted by [`DownloadIter::chunk_size`]. Every chunk size
/// must also be a multiple of this value.
pub const MIN_CHUNK_SIZE: i32 = 4 * 1024;
/// Largest chunk size, in bytes (512 KiB). It is used both for download requests and for
/// uploaded file parts.
pub const MAX_CHUNK_SIZE: i32 = 128 * MIN_CHUNK_SIZE;
/// RPC error code which, when carrying a value, makes downloads retry in the datacenter with
/// that id.
const FILE_MIGRATE_ERROR: i32 = 303;
/// Files strictly larger than this many bytes use the concurrent / "big file" code paths.
const BIG_FILE_SIZE: usize = 10 * 1024 * 1024;
/// Number of concurrent workers used for big uploads and downloads.
const WORKER_COUNT: usize = 4;
35
36/// Iterator returned by [`Client::iter_download`].
37pub struct DownloadIter {
38    client: Client,
39    done: bool,
40    variant: DownloadIterVariant,
41}
42
43enum DownloadIterVariant {
44    Request(tl::functions::upload::GetFile),
45    PreDownloaded(Vec<u8>),
46    PreFailed(io::Error),
47    Empty,
48}
49
50impl DownloadIter {
51    /// Changes the chunk size, in bytes, used to make requests. Useful if you only need to get a
52    /// small part of a file. By default, `MAX_CHUNK_SIZE` is used.
53    ///
54    /// # Panics
55    ///
56    /// Panics if `size` is not divisible by `MIN_CHUNK_SIZE`, or if `size` is not in contained in
57    /// the range `MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE`.
58    pub fn chunk_size(mut self, size: i32) -> Self {
59        assert!((MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&size) && size % MIN_CHUNK_SIZE == 0);
60        match &mut self.variant {
61            DownloadIterVariant::Request(request) => request.limit = size,
62            _ => {}
63        }
64        self
65    }
66
67    /// Skips `n` chunks to start downloading a different offset from the file. If you want to
68    /// skip less data, modify the `chunk_size` before calling this method, and then reset it to
69    /// any value you want.
70    pub fn skip_chunks(mut self, n: i32) -> Self {
71        match &mut self.variant {
72            DownloadIterVariant::Request(request) => {
73                request.offset += request.limit as i64 * (n as i64)
74            }
75            _ => {}
76        }
77        self
78    }
79
80    /// Fetch and return the next chunk.
81    pub async fn next(&mut self) -> Result<Option<Vec<u8>>, InvocationError> {
82        if self.done {
83            return Ok(None);
84        }
85
86        let variant = mem::replace(&mut self.variant, DownloadIterVariant::Empty);
87
88        let mut request = match variant {
89            DownloadIterVariant::Request(r) => r,
90            DownloadIterVariant::PreDownloaded(data) => {
91                self.done = true;
92                return Ok(Some(data.clone()));
93            }
94            DownloadIterVariant::PreFailed(error) => {
95                return Err(InvocationError::Io(error));
96            }
97            DownloadIterVariant::Empty => return Ok(None),
98        };
99
100        use tl::enums::upload::File;
101
102        // TODO handle maybe FILEREF_UPGRADE_NEEDED
103        let mut dc = self.client.0.session.home_dc_id();
104        loop {
105            break match self.client.invoke_in_dc(dc, &request).await {
106                Ok(File::File(f)) => {
107                    if f.bytes.len() < request.limit as usize {
108                        self.done = true;
109                        if f.bytes.is_empty() {
110                            break Ok(None);
111                        }
112                    }
113
114                    request.offset += request.limit as i64;
115                    self.variant = DownloadIterVariant::Request(request);
116
117                    Ok(Some(f.bytes))
118                }
119                Ok(File::CdnRedirect(_)) => {
120                    panic!("API returned File::CdnRedirect even though cdn_supported = false");
121                }
122                Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
123                    match self.client.copy_auth_to_dc(dc).await {
124                        Ok(_) => continue,
125                        Err(e) => Err(e),
126                    }
127                }
128                Err(InvocationError::Rpc(err)) if err.code == FILE_MIGRATE_ERROR => {
129                    dc = err.value.unwrap() as _;
130                    continue;
131                }
132                Err(e) => Err(e),
133            };
134        }
135    }
136}
137
138/// Method implementations related to uploading or downloading files.
139impl Client {
140    /// Returns a new iterator over the contents of a media document that will be downloaded.
141    ///
142    /// # Examples
143    ///
144    /// ```
145    /// # async fn f(downloadable: grammers_client::media::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
146    /// let mut file_bytes = Vec::new();
147    /// let mut download = client.iter_download(&downloadable);
148    ///
149    /// while let Some(chunk) = download.next().await? {
150    ///     file_bytes.extend(chunk);
151    /// }
152    ///
153    /// // The file is now downloaded in-memory, inside `file_bytes`!
154    /// # Ok(())
155    /// # }
156    /// ```
157    pub fn iter_download<D: Downloadable>(&self, downloadable: &D) -> DownloadIter {
158        if let Some(data) = downloadable.to_data() {
159            DownloadIter {
160                client: self.clone(),
161                done: false,
162                variant: DownloadIterVariant::PreDownloaded(data),
163            }
164        } else if let Some(location) = downloadable.to_raw_input_location() {
165            DownloadIter {
166                client: self.clone(),
167                done: false,
168                variant: DownloadIterVariant::Request(tl::functions::upload::GetFile {
169                    precise: false,
170                    cdn_supported: false,
171                    location,
172                    offset: 0,
173                    limit: MAX_CHUNK_SIZE,
174                }),
175            }
176        } else {
177            DownloadIter {
178                client: self.clone(),
179                done: false,
180                variant: DownloadIterVariant::PreFailed(io::Error::new(
181                    io::ErrorKind::Other,
182                    "media not downloadable",
183                )),
184            }
185        }
186    }
187
188    /// Downloads a media file into the specified path.
189    ///
190    /// If the file already exists, it will be overwritten.
191    ///
192    /// This is a small wrapper around [`Client::iter_download`] for the common case of
193    /// wanting to save the file locally.
194    ///
195    /// # Examples
196    ///
197    /// ```
198    /// # async fn f(downloadable: grammers_client::media::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
199    /// client.download_media(&downloadable, "/home/username/photos/holidays.jpg").await?;
200    /// # Ok(())
201    /// # }
202    /// ```
203    #[cfg(feature = "fs")]
204    pub async fn download_media<D: Downloadable, P: AsRef<Path>>(
205        &self,
206        downloadable: &D,
207        path: P,
208    ) -> Result<(), io::Error> {
209        // Concurrent downloader
210        if let Some((location, size)) = downloadable
211            .to_raw_input_location()
212            .zip(downloadable.size())
213        {
214            if size > BIG_FILE_SIZE {
215                return self
216                    .download_media_concurrent(location, size, path, WORKER_COUNT)
217                    .await;
218            }
219        }
220
221        let mut download = self.iter_download(downloadable);
222        Client::load(path, &mut download).await
223    }
224
225    #[cfg(feature = "fs")]
226    async fn load<P: AsRef<Path>>(path: P, download: &mut DownloadIter) -> Result<(), io::Error> {
227        let mut file = fs::File::create(path).await?;
228        while let Some(chunk) = download
229            .next()
230            .await
231            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
232        {
233            file.write_all(&chunk).await?;
234        }
235
236        Ok(())
237    }
238
239    /// Downloads a `Document` to specified path using multiple connections
240    #[cfg(feature = "fs")]
241    async fn download_media_concurrent<P: AsRef<Path>>(
242        &self,
243        location: tl::enums::InputFileLocation,
244        size: usize,
245        path: P,
246        workers: usize,
247    ) -> Result<(), io::Error> {
248        // Allocate
249        let mut file = fs::File::create(path).await?;
250        file.set_len(size as u64).await?;
251        file.seek(SeekFrom::Start(0)).await?;
252
253        // Start workers
254        let (tx, mut rx) = unbounded_channel();
255        let part_index = Arc::new(tokio::sync::Mutex::new(0));
256        let mut tasks = vec![];
257        let home_dc_id = self.0.session.home_dc_id();
258        for _ in 0..workers {
259            let location = location.clone();
260            let tx = tx.clone();
261            let part_index = part_index.clone();
262            let client = self.clone();
263            let task = tokio::task::spawn(async move {
264                let mut retry_offset = None;
265                let mut dc = home_dc_id;
266                loop {
267                    // Calculate file offset
268                    let offset: i64 = {
269                        if let Some(offset) = retry_offset {
270                            retry_offset = None;
271                            offset
272                        } else {
273                            let mut i = part_index.lock().await;
274                            *i += 1;
275                            (MAX_CHUNK_SIZE * (*i - 1)) as i64
276                        }
277                    };
278                    if offset > size as i64 {
279                        break;
280                    }
281                    // Fetch from telegram
282                    let request = &tl::functions::upload::GetFile {
283                        precise: true,
284                        cdn_supported: false,
285                        location: location.clone(),
286                        offset,
287                        limit: MAX_CHUNK_SIZE,
288                    };
289                    match client.invoke_in_dc(dc, request).await {
290                        Ok(tl::enums::upload::File::File(file)) => {
291                            tx.send((offset as u64, file.bytes)).unwrap();
292                        }
293                        Ok(tl::enums::upload::File::CdnRedirect(_)) => {
294                            panic!(
295                                "API returned File::CdnRedirect even though cdn_supported = false"
296                            );
297                        }
298                        Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
299                            match client.copy_auth_to_dc(dc).await {
300                                Ok(_) => {
301                                    retry_offset = Some(offset);
302                                    continue;
303                                }
304                                Err(e) => return Err(e),
305                            }
306                        }
307                        Err(InvocationError::Rpc(err)) => {
308                            if err.code == FILE_MIGRATE_ERROR {
309                                dc = err.value.unwrap() as _;
310                                retry_offset = Some(offset);
311                                continue;
312                            }
313                            return Err(InvocationError::Rpc(err));
314                        }
315                        Err(e) => return Err(e),
316                    }
317                }
318                Ok::<(), InvocationError>(())
319            });
320            tasks.push(task);
321        }
322        drop(tx);
323
324        // File write loop
325        let mut pos = 0;
326        while let Some((offset, data)) = rx.recv().await {
327            if offset != pos {
328                file.seek(SeekFrom::Start(offset)).await?;
329            }
330            file.write_all(&data).await?;
331            pos = offset + data.len() as u64;
332        }
333
334        // Check if all tasks finished succesfully
335        for task in tasks {
336            task.await?
337                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
338        }
339        Ok(())
340    }
341
342    /// Uploads an async stream to Telegram servers.
343    ///
344    /// The file is not sent to any peer, but can be used as media when sending messages for a
345    /// certain period of time (less than a day). You can use this uploaded file multiple times.
346    ///
347    /// Refer to [`InputMessage`] to learn more uses for `uploaded_file`.
348    ///
349    /// The stream size must be known beforehand. If this is not possible, you might need to
350    /// process the entire async stream to determine its size, and then use the size and the
351    /// downloaded buffer.
352    ///
353    /// The stream size may be less or equal to the actual length of the stream, but not more.
354    /// If it's less, you may continue to read from the stream after the method returns.
355    /// If it's more, the method will fail because it does not have enough data to read.
356    ///
357    /// Note that Telegram uses the file name in certain methods, for example, to make sure the
358    /// file is an image when trying to use send the file as photo media, so it is important that
359    /// the file name at least uses the right extension, even if the name is a dummy value.
360    /// If the input file name is empty, the non-empty dummy value "a" will be used instead.
361    /// Because it has no extension, you may not be able to use the file in certain methods.
362    ///
363    /// # Examples
364    ///
365    /// ```
366    /// # async fn f(peer: grammers_session::types::PeerRef, client: grammers_client::Client, some_vec: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
367    /// use grammers_client::message::InputMessage;
368    ///
369    /// // In-memory `Vec<u8>` buffers can be used as async streams
370    /// let size = some_vec.len();
371    /// let mut stream = std::io::Cursor::new(some_vec);
372    /// let uploaded_file = client.upload_stream(&mut stream, size, "sleep.jpg".to_string()).await?;
373    ///
374    /// client.send_message(peer, InputMessage::new().text("Zzz...").photo(uploaded_file)).await?;
375    /// # Ok(())
376    /// # }
377    /// ```
378    ///
379    /// [`InputMessage`]: crate::message::InputMessage
380    pub async fn upload_stream<S: AsyncRead + Unpin>(
381        &self,
382        stream: &mut S,
383        size: usize,
384        name: String,
385    ) -> Result<Uploaded, io::Error> {
386        let file_id = generate_random_id();
387        let name = if name.is_empty() {
388            "a".to_string()
389        } else {
390            name
391        };
392
393        let big_file = size > BIG_FILE_SIZE;
394        let parts = PartStream::new(stream, size);
395        let total_parts = parts.total_parts();
396
397        if big_file {
398            let parts = Arc::new(parts);
399            let mut tasks = FuturesUnordered::new();
400            for _ in 0..WORKER_COUNT {
401                let handle = self.clone();
402                let parts = Arc::clone(&parts);
403                let task = async move {
404                    while let Some((part, bytes)) = parts.next_part().await? {
405                        let ok = handle
406                            .invoke(&tl::functions::upload::SaveBigFilePart {
407                                file_id,
408                                file_part: part,
409                                file_total_parts: total_parts,
410                                bytes,
411                            })
412                            .await
413                            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
414
415                        if !ok {
416                            return Err(io::Error::new(
417                                io::ErrorKind::Other,
418                                "server failed to store uploaded data",
419                            ));
420                        }
421                    }
422                    Ok(())
423                };
424                tasks.push(task);
425            }
426
427            while let Some(res) = tasks.next().await {
428                res?;
429            }
430
431            Ok(Uploaded::from_raw(
432                tl::types::InputFileBig {
433                    id: file_id,
434                    parts: total_parts,
435                    name,
436                }
437                .into(),
438            ))
439        } else {
440            let mut md5 = md5::Context::new();
441            while let Some((part, bytes)) = parts.next_part().await? {
442                md5.consume(&bytes);
443                let ok = self
444                    .invoke(&tl::functions::upload::SaveFilePart {
445                        file_id,
446                        file_part: part,
447                        bytes,
448                    })
449                    .await
450                    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
451
452                if !ok {
453                    return Err(io::Error::new(
454                        io::ErrorKind::Other,
455                        "server failed to store uploaded data",
456                    ));
457                }
458            }
459            Ok(Uploaded::from_raw(
460                tl::types::InputFile {
461                    id: file_id,
462                    parts: total_parts,
463                    name,
464                    md5_checksum: format!("{:x}", md5.finalize()),
465                }
466                .into(),
467            ))
468        }
469    }
470
471    /// Uploads a local file to Telegram servers.
472    ///
473    /// The file is not sent to any peer, but can be used as media when sending messages for a
474    /// certain period of time (less than a day). You can use this uploaded file multiple times.
475    ///
476    /// Refer to [`InputMessage`] to learn more uses for `uploaded_file`.
477    ///
478    /// If you need more control over the uploaded data, such as performing only a partial upload
479    /// or with a different name, use [`Client::upload_stream`] instead.
480    ///
481    /// # Examples
482    ///
483    /// ```
484    /// # async fn f(peer: grammers_session::types::PeerRef, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
485    /// use grammers_client::message::InputMessage;
486    ///
487    /// let uploaded_file = client.upload_file("/home/username/photos/holidays.jpg").await?;
488    ///
489    /// client.send_message(peer, InputMessage::new().text("Check this out!").photo(uploaded_file)).await?;
490    /// # Ok(())
491    /// # }
492    /// ```
493    ///
494    /// [`InputMessage`]: crate::message::InputMessage
495    #[cfg(feature = "fs")]
496    pub async fn upload_file<P: AsRef<Path>>(&self, path: P) -> Result<Uploaded, io::Error> {
497        let path = path.as_ref();
498
499        let mut file = fs::File::open(path).await?;
500        let size = file.seek(SeekFrom::End(0)).await? as usize;
501        file.seek(SeekFrom::Start(0)).await?;
502
503        // File name will only be `None` for `..` path, and directories cannot be uploaded as
504        // files, so it's fine to unwrap.
505        let name = path.file_name().unwrap().to_string_lossy().to_string();
506
507        self.upload_stream(&mut file, size, name).await
508    }
509}
510
511struct PartStreamInner<'a, S: AsyncRead + Unpin> {
512    stream: &'a mut S,
513    current_part: i32,
514}
515
516struct PartStream<'a, S: AsyncRead + Unpin> {
517    inner: AsyncMutex<PartStreamInner<'a, S>>,
518    total_parts: i32,
519}
520
521impl<'a, S: AsyncRead + Unpin> PartStream<'a, S> {
522    fn new(stream: &'a mut S, size: usize) -> Self {
523        let total_parts = ((size + MAX_CHUNK_SIZE as usize - 1) / MAX_CHUNK_SIZE as usize) as i32;
524        Self {
525            inner: AsyncMutex::new(PartStreamInner {
526                stream,
527                current_part: 0,
528            }),
529            total_parts,
530        }
531    }
532
533    fn total_parts(&self) -> i32 {
534        self.total_parts
535    }
536
537    async fn next_part(&self) -> Result<Option<(i32, Vec<u8>)>, io::Error> {
538        let mut lock = self.inner.lock().await;
539        if lock.current_part >= self.total_parts {
540            return Ok(None);
541        }
542        let mut read = 0;
543        let mut buffer = vec![0; MAX_CHUNK_SIZE as usize];
544
545        while read != buffer.len() {
546            let n = lock.stream.read(&mut buffer[read..]).await?;
547            if n == 0 {
548                if lock.current_part == self.total_parts - 1 {
549                    break;
550                } else {
551                    return Err(io::Error::new(
552                        io::ErrorKind::UnexpectedEof,
553                        "reached EOF before reaching the last file part",
554                    ));
555                }
556            }
557            read += n;
558        }
559
560        let bytes = if read == buffer.len() {
561            buffer
562        } else {
563            buffer[..read].to_vec()
564        };
565
566        let res = Ok(Some((lock.current_part, bytes)));
567        lock.current_part += 1;
568        res
569    }
570}