grammers_client/client/
files.rs

1// Copyright 2020 - developers of the `grammers` project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::Client;
10use crate::types::{Downloadable, Uploaded};
11use crate::utils::generate_random_id;
12use futures_util::stream::{FuturesUnordered, StreamExt as _};
13use grammers_mtsender::InvocationError;
14use grammers_tl_types as tl;
15use std::mem;
16use std::sync::Arc;
17use tokio::{
18    io::{self, AsyncRead, AsyncReadExt},
19    sync::Mutex as AsyncMutex,
20};
21
22#[cfg(feature = "fs")]
23use {
24    std::{io::SeekFrom, path::Path},
25    tokio::{
26        fs,
27        io::{AsyncSeekExt, AsyncWriteExt},
28        sync::mpsc::unbounded_channel,
29    },
30};
31
/// Smallest chunk size, in bytes, accepted by `DownloadIter::chunk_size`. Chunk sizes must also
/// be a multiple of this value.
pub const MIN_CHUNK_SIZE: i32 = 4 * 1024;
/// Largest chunk size, in bytes. Used as the default download chunk size and as the size of
/// every uploaded part except possibly the last one.
pub const MAX_CHUNK_SIZE: i32 = 512 * 1024;
// RPC error code on which a download request is retried in the datacenter given by the error's
// value.
const FILE_MIGRATE_ERROR: i32 = 303;
// Size, in bytes, above which downloads take the concurrent path and uploads use
// `SaveBigFilePart` instead of `SaveFilePart`.
const BIG_FILE_SIZE: usize = 10 * 1024 * 1024;
// Number of concurrent workers used for big downloads and uploads.
const WORKER_COUNT: usize = 4;
37
/// Iterator-like handle that yields the contents of a downloadable item chunk by chunk.
///
/// Call `next` repeatedly until it returns `Ok(None)`.
pub struct DownloadIter {
    // Client used to send the download requests.
    client: Client,
    // Set once the final chunk has been produced; `next` then keeps returning `Ok(None)`.
    done: bool,
    // Where the bytes come from: a network request, in-memory data, or a stored failure.
    variant: DownloadIterVariant,
}
43
/// Source of the data produced by a `DownloadIter`.
enum DownloadIterVariant {
    // Data is fetched from the server with this request; its `offset` advances after each chunk.
    Request(tl::functions::upload::GetFile),
    // The whole content is already in memory and is returned as a single chunk.
    PreDownloaded(Vec<u8>),
    // The item cannot be downloaded; the stored error is returned by the first call to `next`.
    PreFailed(io::Error),
    // Placeholder left in place while `next` owns the previous variant; `next` yields `Ok(None)`
    // when it finds it.
    Empty,
}
50
51impl DownloadIter {
52    /// Changes the chunk size, in bytes, used to make requests. Useful if you only need to get a
53    /// small part of a file. By default, `MAX_CHUNK_SIZE` is used.
54    ///
55    /// # Panics
56    ///
57    /// Panics if `size` is not divisible by `MIN_CHUNK_SIZE`, or if `size` is not in contained in
58    /// the range `MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE`.
59    pub fn chunk_size(mut self, size: i32) -> Self {
60        assert!((MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&size) && size % MIN_CHUNK_SIZE == 0);
61        match &mut self.variant {
62            DownloadIterVariant::Request(request) => request.limit = size,
63            _ => {}
64        }
65        self
66    }
67
68    /// Skips `n` chunks to start downloading a different offset from the file. If you want to
69    /// skip less data, modify the `chunk_size` before calling this method, and then reset it to
70    /// any value you want.
71    pub fn skip_chunks(mut self, n: i32) -> Self {
72        match &mut self.variant {
73            DownloadIterVariant::Request(request) => {
74                request.offset += request.limit as i64 * (n as i64)
75            }
76            _ => {}
77        }
78        self
79    }
80
81    /// Fetch and return the next chunk.
82    pub async fn next(&mut self) -> Result<Option<Vec<u8>>, InvocationError> {
83        if self.done {
84            return Ok(None);
85        }
86
87        let variant = mem::replace(&mut self.variant, DownloadIterVariant::Empty);
88
89        let mut request = match variant {
90            DownloadIterVariant::Request(r) => r,
91            DownloadIterVariant::PreDownloaded(data) => {
92                self.done = true;
93                return Ok(Some(data.clone()));
94            }
95            DownloadIterVariant::PreFailed(error) => {
96                return Err(InvocationError::Io(error));
97            }
98            DownloadIterVariant::Empty => return Ok(None),
99        };
100
101        use tl::enums::upload::File;
102
103        // TODO handle maybe FILEREF_UPGRADE_NEEDED
104        let mut dc = self.client.0.session.home_dc_id();
105        loop {
106            break match self.client.invoke_in_dc(dc, &request).await {
107                Ok(File::File(f)) => {
108                    if f.bytes.len() < request.limit as usize {
109                        self.done = true;
110                        if f.bytes.is_empty() {
111                            break Ok(None);
112                        }
113                    }
114
115                    request.offset += request.limit as i64;
116                    self.variant = DownloadIterVariant::Request(request);
117
118                    Ok(Some(f.bytes))
119                }
120                Ok(File::CdnRedirect(_)) => {
121                    panic!("API returned File::CdnRedirect even though cdn_supported = false");
122                }
123                Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
124                    match self.client.copy_auth_to_dc(dc).await {
125                        Ok(_) => continue,
126                        Err(e) => Err(e),
127                    }
128                }
129                Err(InvocationError::Rpc(err)) if err.code == FILE_MIGRATE_ERROR => {
130                    dc = err.value.unwrap() as _;
131                    continue;
132                }
133                Err(e) => Err(e),
134            };
135        }
136    }
137}
138
139/// Method implementations related to uploading or downloading files.
140impl Client {
141    /// Returns a new iterator over the contents of a media document that will be downloaded.
142    ///
143    /// # Examples
144    ///
145    /// ```
146    /// # async fn f(downloadable: grammers_client::types::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
147    /// let mut file_bytes = Vec::new();
148    /// let mut download = client.iter_download(&downloadable);
149    ///
150    /// while let Some(chunk) = download.next().await? {
151    ///     file_bytes.extend(chunk);
152    /// }
153    ///
154    /// // The file is now downloaded in-memory, inside `file_bytes`!
155    /// # Ok(())
156    /// # }
157    /// ```
158    pub fn iter_download<D: Downloadable>(&self, downloadable: &D) -> DownloadIter {
159        if let Some(data) = downloadable.to_data() {
160            DownloadIter {
161                client: self.clone(),
162                done: false,
163                variant: DownloadIterVariant::PreDownloaded(data),
164            }
165        } else if let Some(location) = downloadable.to_raw_input_location() {
166            DownloadIter {
167                client: self.clone(),
168                done: false,
169                variant: DownloadIterVariant::Request(tl::functions::upload::GetFile {
170                    precise: false,
171                    cdn_supported: false,
172                    location,
173                    offset: 0,
174                    limit: MAX_CHUNK_SIZE,
175                }),
176            }
177        } else {
178            DownloadIter {
179                client: self.clone(),
180                done: false,
181                variant: DownloadIterVariant::PreFailed(io::Error::new(
182                    io::ErrorKind::Other,
183                    "media not downloadable",
184                )),
185            }
186        }
187    }
188
189    /// Downloads a media file into the specified path.
190    ///
191    /// If the file already exists, it will be overwritten.
192    ///
193    /// This is a small wrapper around [`Client::iter_download`] for the common case of
194    /// wanting to save the file locally.
195    ///
196    /// # Examples
197    ///
198    /// ```
199    /// # async fn f(downloadable: grammers_client::types::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
200    /// client.download_media(&downloadable, "/home/username/photos/holidays.jpg").await?;
201    /// # Ok(())
202    /// # }
203    /// ```
204    #[cfg(feature = "fs")]
205    pub async fn download_media<D: Downloadable, P: AsRef<Path>>(
206        &self,
207        downloadable: &D,
208        path: P,
209    ) -> Result<(), io::Error> {
210        // Concurrent downloader
211        if let Some((location, size)) = downloadable
212            .to_raw_input_location()
213            .zip(downloadable.size())
214        {
215            if size > BIG_FILE_SIZE {
216                return self
217                    .download_media_concurrent(location, size, path, WORKER_COUNT)
218                    .await;
219            }
220        }
221
222        let mut download = self.iter_download(downloadable);
223        Client::load(path, &mut download).await
224    }
225
226    #[cfg(feature = "fs")]
227    async fn load<P: AsRef<Path>>(path: P, download: &mut DownloadIter) -> Result<(), io::Error> {
228        let mut file = fs::File::create(path).await?;
229        while let Some(chunk) = download
230            .next()
231            .await
232            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
233        {
234            file.write_all(&chunk).await?;
235        }
236
237        Ok(())
238    }
239
240    /// Downloads a `Document` to specified path using multiple connections
241    #[cfg(feature = "fs")]
242    async fn download_media_concurrent<P: AsRef<Path>>(
243        &self,
244        location: tl::enums::InputFileLocation,
245        size: usize,
246        path: P,
247        workers: usize,
248    ) -> Result<(), io::Error> {
249        // Allocate
250        let mut file = fs::File::create(path).await?;
251        file.set_len(size as u64).await?;
252        file.seek(SeekFrom::Start(0)).await?;
253
254        // Start workers
255        let (tx, mut rx) = unbounded_channel();
256        let part_index = Arc::new(tokio::sync::Mutex::new(0));
257        let mut tasks = vec![];
258        let home_dc_id = self.0.session.home_dc_id();
259        for _ in 0..workers {
260            let location = location.clone();
261            let tx = tx.clone();
262            let part_index = part_index.clone();
263            let client = self.clone();
264            let task = tokio::task::spawn(async move {
265                let mut retry_offset = None;
266                let mut dc = home_dc_id;
267                loop {
268                    // Calculate file offset
269                    let offset: i64 = {
270                        if let Some(offset) = retry_offset {
271                            retry_offset = None;
272                            offset
273                        } else {
274                            let mut i = part_index.lock().await;
275                            *i += 1;
276                            (MAX_CHUNK_SIZE * (*i - 1)) as i64
277                        }
278                    };
279                    if offset > size as i64 {
280                        break;
281                    }
282                    // Fetch from telegram
283                    let request = &tl::functions::upload::GetFile {
284                        precise: true,
285                        cdn_supported: false,
286                        location: location.clone(),
287                        offset,
288                        limit: MAX_CHUNK_SIZE,
289                    };
290                    match client.invoke_in_dc(dc, request).await {
291                        Ok(tl::enums::upload::File::File(file)) => {
292                            tx.send((offset as u64, file.bytes)).unwrap();
293                        }
294                        Ok(tl::enums::upload::File::CdnRedirect(_)) => {
295                            panic!(
296                                "API returned File::CdnRedirect even though cdn_supported = false"
297                            );
298                        }
299                        Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
300                            match client.copy_auth_to_dc(dc).await {
301                                Ok(_) => {
302                                    retry_offset = Some(offset);
303                                    continue;
304                                }
305                                Err(e) => return Err(e),
306                            }
307                        }
308                        Err(InvocationError::Rpc(err)) => {
309                            if err.code == FILE_MIGRATE_ERROR {
310                                dc = err.value.unwrap() as _;
311                                retry_offset = Some(offset);
312                                continue;
313                            }
314                            return Err(InvocationError::Rpc(err));
315                        }
316                        Err(e) => return Err(e),
317                    }
318                }
319                Ok::<(), InvocationError>(())
320            });
321            tasks.push(task);
322        }
323        drop(tx);
324
325        // File write loop
326        let mut pos = 0;
327        while let Some((offset, data)) = rx.recv().await {
328            if offset != pos {
329                file.seek(SeekFrom::Start(offset)).await?;
330            }
331            file.write_all(&data).await?;
332            pos = offset + data.len() as u64;
333        }
334
335        // Check if all tasks finished succesfully
336        for task in tasks {
337            task.await?
338                .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
339        }
340        Ok(())
341    }
342
343    /// Uploads an async stream to Telegram servers.
344    ///
345    /// The file is not sent to any peer, but can be used as media when sending messages for a
346    /// certain period of time (less than a day). You can use this uploaded file multiple times.
347    ///
348    /// Refer to [`InputMessage`] to learn more uses for `uploaded_file`.
349    ///
350    /// The stream size must be known beforehand. If this is not possible, you might need to
351    /// process the entire async stream to determine its size, and then use the size and the
352    /// downloaded buffer.
353    ///
354    /// The stream size may be less or equal to the actual length of the stream, but not more.
355    /// If it's less, you may continue to read from the stream after the method returns.
356    /// If it's more, the method will fail because it does not have enough data to read.
357    ///
358    /// Note that Telegram uses the file name in certain methods, for example, to make sure the
359    /// file is an image when trying to use send the file as photo media, so it is important that
360    /// the file name at least uses the right extension, even if the name is a dummy value.
361    /// If the input file name is empty, the non-empty dummy value "a" will be used instead.
362    /// Because it has no extension, you may not be able to use the file in certain methods.
363    ///
364    /// # Examples
365    ///
366    /// ```
367    /// # async fn f(peer: grammers_session::defs::PeerRef, client: grammers_client::Client, some_vec: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
368    /// use grammers_client::InputMessage;
369    ///
370    /// // In-memory `Vec<u8>` buffers can be used as async streams
371    /// let size = some_vec.len();
372    /// let mut stream = std::io::Cursor::new(some_vec);
373    /// let uploaded_file = client.upload_stream(&mut stream, size, "sleep.jpg".to_string()).await?;
374    ///
375    /// client.send_message(peer, InputMessage::new().text("Zzz...").photo(uploaded_file)).await?;
376    /// # Ok(())
377    /// # }
378    /// ```
379    ///
380    /// [`InputMessage`]: crate::types::InputMessage
381    pub async fn upload_stream<S: AsyncRead + Unpin>(
382        &self,
383        stream: &mut S,
384        size: usize,
385        name: String,
386    ) -> Result<Uploaded, io::Error> {
387        let file_id = generate_random_id();
388        let name = if name.is_empty() {
389            "a".to_string()
390        } else {
391            name
392        };
393
394        let big_file = size > BIG_FILE_SIZE;
395        let parts = PartStream::new(stream, size);
396        let total_parts = parts.total_parts();
397
398        if big_file {
399            let parts = Arc::new(parts);
400            let mut tasks = FuturesUnordered::new();
401            for _ in 0..WORKER_COUNT {
402                let handle = self.clone();
403                let parts = Arc::clone(&parts);
404                let task = async move {
405                    while let Some((part, bytes)) = parts.next_part().await? {
406                        let ok = handle
407                            .invoke(&tl::functions::upload::SaveBigFilePart {
408                                file_id,
409                                file_part: part,
410                                file_total_parts: total_parts,
411                                bytes,
412                            })
413                            .await
414                            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
415
416                        if !ok {
417                            return Err(io::Error::new(
418                                io::ErrorKind::Other,
419                                "server failed to store uploaded data",
420                            ));
421                        }
422                    }
423                    Ok(())
424                };
425                tasks.push(task);
426            }
427
428            while let Some(res) = tasks.next().await {
429                res?;
430            }
431
432            Ok(Uploaded::from_raw(
433                tl::types::InputFileBig {
434                    id: file_id,
435                    parts: total_parts,
436                    name,
437                }
438                .into(),
439            ))
440        } else {
441            let mut md5 = md5::Context::new();
442            while let Some((part, bytes)) = parts.next_part().await? {
443                md5.consume(&bytes);
444                let ok = self
445                    .invoke(&tl::functions::upload::SaveFilePart {
446                        file_id,
447                        file_part: part,
448                        bytes,
449                    })
450                    .await
451                    .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
452
453                if !ok {
454                    return Err(io::Error::new(
455                        io::ErrorKind::Other,
456                        "server failed to store uploaded data",
457                    ));
458                }
459            }
460            Ok(Uploaded::from_raw(
461                tl::types::InputFile {
462                    id: file_id,
463                    parts: total_parts,
464                    name,
465                    md5_checksum: format!("{:x}", md5.finalize()),
466                }
467                .into(),
468            ))
469        }
470    }
471
472    /// Uploads a local file to Telegram servers.
473    ///
474    /// The file is not sent to any peer, but can be used as media when sending messages for a
475    /// certain period of time (less than a day). You can use this uploaded file multiple times.
476    ///
477    /// Refer to [`InputMessage`] to learn more uses for `uploaded_file`.
478    ///
479    /// If you need more control over the uploaded data, such as performing only a partial upload
480    /// or with a different name, use [`Client::upload_stream`] instead.
481    ///
482    /// # Examples
483    ///
484    /// ```
485    /// # async fn f(peer: grammers_session::defs::PeerRef, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
486    /// use grammers_client::InputMessage;
487    ///
488    /// let uploaded_file = client.upload_file("/home/username/photos/holidays.jpg").await?;
489    ///
490    /// client.send_message(peer, InputMessage::new().text("Check this out!").photo(uploaded_file)).await?;
491    /// # Ok(())
492    /// # }
493    /// ```
494    ///
495    /// [`InputMessage`]: crate::InputMessage
496    #[cfg(feature = "fs")]
497    pub async fn upload_file<P: AsRef<Path>>(&self, path: P) -> Result<Uploaded, io::Error> {
498        let path = path.as_ref();
499
500        let mut file = fs::File::open(path).await?;
501        let size = file.seek(SeekFrom::End(0)).await? as usize;
502        file.seek(SeekFrom::Start(0)).await?;
503
504        // File name will only be `None` for `..` path, and directories cannot be uploaded as
505        // files, so it's fine to unwrap.
506        let name = path.file_name().unwrap().to_string_lossy().to_string();
507
508        self.upload_stream(&mut file, size, name).await
509    }
510}
511
512struct PartStreamInner<'a, S: AsyncRead + Unpin> {
513    stream: &'a mut S,
514    current_part: i32,
515}
516
517struct PartStream<'a, S: AsyncRead + Unpin> {
518    inner: AsyncMutex<PartStreamInner<'a, S>>,
519    total_parts: i32,
520}
521
522impl<'a, S: AsyncRead + Unpin> PartStream<'a, S> {
523    fn new(stream: &'a mut S, size: usize) -> Self {
524        let total_parts = ((size + MAX_CHUNK_SIZE as usize - 1) / MAX_CHUNK_SIZE as usize) as i32;
525        Self {
526            inner: AsyncMutex::new(PartStreamInner {
527                stream,
528                current_part: 0,
529            }),
530            total_parts,
531        }
532    }
533
534    fn total_parts(&self) -> i32 {
535        self.total_parts
536    }
537
538    async fn next_part(&self) -> Result<Option<(i32, Vec<u8>)>, io::Error> {
539        let mut lock = self.inner.lock().await;
540        if lock.current_part >= self.total_parts {
541            return Ok(None);
542        }
543        let mut read = 0;
544        let mut buffer = vec![0; MAX_CHUNK_SIZE as usize];
545
546        while read != buffer.len() {
547            let n = lock.stream.read(&mut buffer[read..]).await?;
548            if n == 0 {
549                if lock.current_part == self.total_parts - 1 {
550                    break;
551                } else {
552                    return Err(io::Error::new(
553                        io::ErrorKind::UnexpectedEof,
554                        "reached EOF before reaching the last file part",
555                    ));
556                }
557            }
558            read += n;
559        }
560
561        let bytes = if read == buffer.len() {
562            buffer
563        } else {
564            buffer[..read].to_vec()
565        };
566
567        let res = Ok(Some((lock.current_part, bytes)));
568        lock.current_part += 1;
569        res
570    }
571}