Skip to main content

grammers_client/client/
files.rs

1// Copyright 2020 - developers of the `grammers` project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use std::mem;
10use std::sync::Arc;
11#[cfg(feature = "fs")]
12use std::{io::SeekFrom, path::Path};
13
14use futures_util::stream::{FuturesUnordered, StreamExt as _};
15use grammers_mtsender::InvocationError;
16use grammers_tl_types as tl;
17use tokio::io::{self, AsyncRead, AsyncReadExt};
18use tokio::sync::Mutex as AsyncMutex;
19#[cfg(feature = "fs")]
20use tokio::{
21    fs,
22    io::{AsyncSeekExt, AsyncWriteExt},
23    sync::mpsc::unbounded_channel,
24};
25
26use super::Client;
27use crate::media::{Downloadable, Uploaded};
28use crate::utils::generate_random_id;
29
/// Smallest chunk size, in bytes, accepted by [`DownloadIter::chunk_size`]. Chunk sizes must
/// also be a multiple of this value.
pub const MIN_CHUNK_SIZE: i32 = 4 * 1024;
/// Largest chunk size, in bytes, accepted by [`DownloadIter::chunk_size`]. It is also the chunk
/// size used by default for downloads and the part size used for uploads.
pub const MAX_CHUNK_SIZE: i32 = 512 * 1024;
// RPC error code whose `value` is used as the datacenter to retry the download against.
const FILE_MIGRATE_ERROR: i32 = 303;
// Size, in bytes, above which files are downloaded concurrently and uploaded as "big" files.
const BIG_FILE_SIZE: usize = 10 * 1024 * 1024;
// Number of concurrent workers used for big file transfers.
const WORKER_COUNT: usize = 4;
35
/// Iterator returned by [`Client::iter_download`].
pub struct DownloadIter {
    /// Client handle used to invoke the download requests.
    client: Client,
    /// Set once the final chunk has been produced; `next` returns `Ok(None)` afterwards.
    done: bool,
    /// Total size of the file in bytes, when known. Used to detect the last chunk.
    size: Option<usize>,
    /// Where the data yielded by the iterator comes from.
    variant: DownloadIterVariant,
}
43
/// Source of the data yielded by a [`DownloadIter`].
enum DownloadIterVariant {
    /// Chunks are fetched with this request, which tracks the current `offset` and the chunk
    /// size (`limit`).
    Request(tl::functions::upload::GetFile),
    /// The data is already available and is yielded as a single chunk.
    PreDownloaded(Vec<u8>),
    /// The media cannot be downloaded; the error is reported by `next`.
    PreFailed(io::Error),
    /// Placeholder left behind while `next` owns the previous state. Yields nothing.
    Empty,
}
50
51impl DownloadIter {
52    /// Changes the chunk size, in bytes, used to make requests. Useful if you only need to get a
53    /// small part of a file. By default, `MAX_CHUNK_SIZE` is used.
54    ///
55    /// # Panics
56    ///
57    /// Panics if `size` is not divisible by `MIN_CHUNK_SIZE`, or if `size` is not in contained in
58    /// the range `MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE`.
59    pub fn chunk_size(mut self, size: i32) -> Self {
60        assert!((MIN_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&size) && size % MIN_CHUNK_SIZE == 0);
61        match &mut self.variant {
62            DownloadIterVariant::Request(request) => request.limit = size,
63            _ => {}
64        }
65        self
66    }
67
68    /// Skips `n` chunks to start downloading a different offset from the file. If you want to
69    /// skip less data, modify the `chunk_size` before calling this method, and then reset it to
70    /// any value you want.
71    pub fn skip_chunks(mut self, n: i32) -> Self {
72        match &mut self.variant {
73            DownloadIterVariant::Request(request) => {
74                request.offset += request.limit as i64 * (n as i64)
75            }
76            _ => {}
77        }
78        self
79    }
80
81    /// Fetch and return the next chunk.
82    pub async fn next(&mut self) -> Result<Option<Vec<u8>>, InvocationError> {
83        if self.done {
84            return Ok(None);
85        }
86
87        let variant = mem::replace(&mut self.variant, DownloadIterVariant::Empty);
88
89        let mut request = match variant {
90            DownloadIterVariant::Request(r) => r,
91            DownloadIterVariant::PreDownloaded(data) => {
92                self.done = true;
93                return Ok(Some(data.clone()));
94            }
95            DownloadIterVariant::PreFailed(error) => {
96                return Err(InvocationError::Io(error));
97            }
98            DownloadIterVariant::Empty => return Ok(None),
99        };
100
101        use tl::enums::upload::File;
102
103        // TODO handle maybe FILEREF_UPGRADE_NEEDED
104        let mut dc = self.client.0.session.home_dc_id()?;
105        loop {
106            break match self.client.invoke_in_dc(dc, &request).await {
107                Ok(File::File(f)) => {
108                    let reached_known_size = self
109                        .size
110                        .is_some_and(|size| request.offset as usize + f.bytes.len() >= size);
111
112                    if reached_known_size || f.bytes.len() < request.limit as usize {
113                        self.done = true;
114                        if f.bytes.is_empty() {
115                            break Ok(None);
116                        }
117                    }
118
119                    request.offset += request.limit as i64;
120                    self.variant = DownloadIterVariant::Request(request);
121
122                    Ok(Some(f.bytes))
123                }
124                Ok(File::CdnRedirect(_)) => {
125                    panic!("API returned File::CdnRedirect even though cdn_supported = false");
126                }
127                Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
128                    match self.client.copy_auth_to_dc(dc).await {
129                        Ok(_) => continue,
130                        Err(e) => Err(e),
131                    }
132                }
133                Err(InvocationError::Rpc(err)) if err.code == FILE_MIGRATE_ERROR => {
134                    dc = err.value.unwrap() as _;
135                    continue;
136                }
137                Err(e) => Err(e),
138            };
139        }
140    }
141}
142
143/// Method implementations related to uploading or downloading files.
144impl Client {
145    /// Returns a new iterator over the contents of a media document that will be downloaded.
146    ///
147    /// # Examples
148    ///
149    /// ```
150    /// # async fn f(downloadable: grammers_client::media::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
151    /// let mut file_bytes = Vec::new();
152    /// let mut download = client.iter_download(&downloadable);
153    ///
154    /// while let Some(chunk) = download.next().await? {
155    ///     file_bytes.extend(chunk);
156    /// }
157    ///
158    /// // The file is now downloaded in-memory, inside `file_bytes`!
159    /// # Ok(())
160    /// # }
161    /// ```
162    pub fn iter_download<D: Downloadable>(&self, downloadable: &D) -> DownloadIter {
163        if let Some(data) = downloadable.to_data() {
164            DownloadIter {
165                client: self.clone(),
166                done: false,
167                size: Some(data.len()),
168                variant: DownloadIterVariant::PreDownloaded(data),
169            }
170        } else if let Some(location) = downloadable.to_raw_input_location() {
171            DownloadIter {
172                client: self.clone(),
173                done: false,
174                size: downloadable.size(),
175                variant: DownloadIterVariant::Request(tl::functions::upload::GetFile {
176                    precise: false,
177                    cdn_supported: false,
178                    location,
179                    offset: 0,
180                    limit: MAX_CHUNK_SIZE,
181                }),
182            }
183        } else {
184            DownloadIter {
185                client: self.clone(),
186                done: false,
187                size: None,
188                variant: DownloadIterVariant::PreFailed(io::Error::new(
189                    io::ErrorKind::Other,
190                    "media not downloadable",
191                )),
192            }
193        }
194    }
195
196    /// Downloads a media file into the specified path.
197    ///
198    /// If the file already exists, it will be overwritten.
199    ///
200    /// This is a small wrapper around [`Client::iter_download`] for the common case of
201    /// wanting to save the file locally.
202    ///
203    /// # Examples
204    ///
205    /// ```
206    /// # async fn f(downloadable: grammers_client::media::Media, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
207    /// client.download_media(&downloadable, "/home/username/photos/holidays.jpg").await?;
208    /// # Ok(())
209    /// # }
210    /// ```
211    #[cfg(feature = "fs")]
212    pub async fn download_media<D: Downloadable, P: AsRef<Path>>(
213        &self,
214        downloadable: &D,
215        path: P,
216    ) -> Result<(), InvocationError> {
217        // Concurrent downloader
218        if let Some((location, size)) = downloadable
219            .to_raw_input_location()
220            .zip(downloadable.size())
221        {
222            if size > BIG_FILE_SIZE {
223                return self
224                    .download_media_concurrent(location, size, path, WORKER_COUNT)
225                    .await;
226            }
227        }
228
229        let mut download = self.iter_download(downloadable);
230        Ok(Client::load(path, &mut download).await?)
231    }
232
233    #[cfg(feature = "fs")]
234    async fn load<P: AsRef<Path>>(path: P, download: &mut DownloadIter) -> Result<(), io::Error> {
235        let mut file = fs::File::create(path).await?;
236        while let Some(chunk) = download.next().await.map_err(io::Error::other)? {
237            file.write_all(&chunk).await?;
238        }
239
240        Ok(())
241    }
242
243    /// Downloads a `Document` to specified path using multiple connections
244    #[cfg(feature = "fs")]
245    async fn download_media_concurrent<P: AsRef<Path>>(
246        &self,
247        location: tl::enums::InputFileLocation,
248        size: usize,
249        path: P,
250        workers: usize,
251    ) -> Result<(), InvocationError> {
252        // Allocate
253        let mut file = fs::File::create(path).await.map_err(InvocationError::Io)?;
254        file.set_len(size as u64)
255            .await
256            .map_err(InvocationError::Io)?;
257        file.seek(SeekFrom::Start(0))
258            .await
259            .map_err(InvocationError::Io)?;
260
261        // Start workers
262        let (tx, mut rx) = unbounded_channel();
263        let part_index = Arc::new(tokio::sync::Mutex::<i64>::new(0));
264        let mut tasks = vec![];
265        let home_dc_id = self.0.session.home_dc_id()?;
266        for _ in 0..workers {
267            let location = location.clone();
268            let tx = tx.clone();
269            let part_index = part_index.clone();
270            let client = self.clone();
271            let task = tokio::task::spawn(async move {
272                let mut retry_offset = None;
273                let mut dc = home_dc_id;
274                loop {
275                    // Calculate file offset
276                    let offset: i64 = {
277                        if let Some(offset) = retry_offset {
278                            retry_offset = None;
279                            offset
280                        } else {
281                            let mut i = part_index.lock().await;
282                            *i += 1;
283                            MAX_CHUNK_SIZE as i64 * (*i - 1)
284                        }
285                    };
286                    if offset >= size as i64 {
287                        break;
288                    }
289                    // Fetch from telegram
290                    let request = &tl::functions::upload::GetFile {
291                        precise: true,
292                        cdn_supported: false,
293                        location: location.clone(),
294                        offset,
295                        limit: MAX_CHUNK_SIZE,
296                    };
297                    match client.invoke_in_dc(dc, request).await {
298                        Ok(tl::enums::upload::File::File(file)) => {
299                            tx.send((offset as u64, file.bytes)).unwrap();
300                        }
301                        Ok(tl::enums::upload::File::CdnRedirect(_)) => {
302                            panic!(
303                                "API returned File::CdnRedirect even though cdn_supported = false"
304                            );
305                        }
306                        Err(InvocationError::Rpc(err)) if &err.name == "AUTH_KEY_UNREGISTERED" => {
307                            match client.copy_auth_to_dc(dc).await {
308                                Ok(_) => {
309                                    retry_offset = Some(offset);
310                                    continue;
311                                }
312                                Err(e) => return Err(e),
313                            }
314                        }
315                        Err(InvocationError::Rpc(err)) => {
316                            if err.code == FILE_MIGRATE_ERROR {
317                                dc = err.value.unwrap() as _;
318                                retry_offset = Some(offset);
319                                continue;
320                            }
321                            return Err(InvocationError::Rpc(err));
322                        }
323                        Err(e) => return Err(e),
324                    }
325                }
326                Ok::<(), InvocationError>(())
327            });
328            tasks.push(task);
329        }
330        drop(tx);
331
332        // File write loop
333        let mut pos = 0;
334        while let Some((offset, data)) = rx.recv().await {
335            if offset != pos {
336                file.seek(SeekFrom::Start(offset)).await?;
337            }
338            file.write_all(&data).await?;
339            pos = offset + data.len() as u64;
340        }
341
342        // Check if all tasks finished succesfully
343        for task in tasks {
344            let res = task.await.map_err(io::Error::other)?;
345            res?;
346        }
347        Ok(())
348    }
349
350    /// Uploads an async stream to Telegram servers.
351    ///
352    /// The file is not sent to any peer, but can be used as media when sending messages for a
353    /// certain period of time (less than a day). You can use this uploaded file multiple times.
354    ///
355    /// Refer to [`InputMessage`] to learn more uses for `uploaded_file`.
356    ///
357    /// The stream size must be known beforehand. If this is not possible, you might need to
358    /// process the entire async stream to determine its size, and then use the size and the
359    /// downloaded buffer.
360    ///
361    /// The stream size may be less or equal to the actual length of the stream, but not more.
362    /// If it's less, you may continue to read from the stream after the method returns.
363    /// If it's more, the method will fail because it does not have enough data to read.
364    ///
365    /// Note that Telegram uses the file name in certain methods, for example, to make sure the
366    /// file is an image when trying to use send the file as photo media, so it is important that
367    /// the file name at least uses the right extension, even if the name is a dummy value.
368    /// If the input file name is empty, the non-empty dummy value "a" will be used instead.
369    /// Because it has no extension, you may not be able to use the file in certain methods.
370    ///
371    /// # Examples
372    ///
373    /// ```
374    /// # async fn f(peer: grammers_session::types::PeerRef, client: grammers_client::Client, some_vec: &[u8]) -> Result<(), Box<dyn std::error::Error>> {
375    /// use grammers_client::message::InputMessage;
376    ///
377    /// // In-memory `Vec<u8>` buffers can be used as async streams
378    /// let size = some_vec.len();
379    /// let mut stream = std::io::Cursor::new(some_vec);
380    /// let uploaded_file = client.upload_stream(&mut stream, size, "sleep.jpg".to_string()).await?;
381    ///
382    /// client.send_message(peer, InputMessage::new().text("Zzz...").photo(uploaded_file)).await?;
383    /// # Ok(())
384    /// # }
385    /// ```
386    ///
387    /// [`InputMessage`]: crate::message::InputMessage
388    pub async fn upload_stream<S: AsyncRead + Unpin>(
389        &self,
390        stream: &mut S,
391        size: usize,
392        name: String,
393    ) -> Result<Uploaded, io::Error> {
394        let file_id = generate_random_id();
395        let name = if name.is_empty() {
396            "a".to_string()
397        } else {
398            name
399        };
400
401        let big_file = size > BIG_FILE_SIZE;
402        let parts = PartStream::new(stream, size);
403        let total_parts = parts.total_parts();
404
405        if big_file {
406            let parts = Arc::new(parts);
407            let mut tasks = FuturesUnordered::new();
408            for _ in 0..WORKER_COUNT {
409                let handle = self.clone();
410                let parts = Arc::clone(&parts);
411                let task = async move {
412                    while let Some((part, bytes)) = parts.next_part().await? {
413                        let ok = handle
414                            .invoke(&tl::functions::upload::SaveBigFilePart {
415                                file_id,
416                                file_part: part,
417                                file_total_parts: total_parts,
418                                bytes,
419                            })
420                            .await
421                            .map_err(io::Error::other)?;
422
423                        if !ok {
424                            return Err(io::Error::new(
425                                io::ErrorKind::Other,
426                                "server failed to store uploaded data",
427                            ));
428                        }
429                    }
430                    Ok(())
431                };
432                tasks.push(task);
433            }
434
435            while let Some(res) = tasks.next().await {
436                res?;
437            }
438
439            Ok(Uploaded::from_raw(
440                tl::types::InputFileBig {
441                    id: file_id,
442                    parts: total_parts,
443                    name,
444                }
445                .into(),
446            ))
447        } else {
448            let mut md5 = md5::Context::new();
449            while let Some((part, bytes)) = parts.next_part().await? {
450                md5.consume(&bytes);
451                let ok = self
452                    .invoke(&tl::functions::upload::SaveFilePart {
453                        file_id,
454                        file_part: part,
455                        bytes,
456                    })
457                    .await
458                    .map_err(io::Error::other)?;
459
460                if !ok {
461                    return Err(io::Error::new(
462                        io::ErrorKind::Other,
463                        "server failed to store uploaded data",
464                    ));
465                }
466            }
467            Ok(Uploaded::from_raw(
468                tl::types::InputFile {
469                    id: file_id,
470                    parts: total_parts,
471                    name,
472                    md5_checksum: format!("{:x}", md5.finalize()),
473                }
474                .into(),
475            ))
476        }
477    }
478
479    /// Uploads a local file to Telegram servers.
480    ///
481    /// The file is not sent to any peer, but can be used as media when sending messages for a
482    /// certain period of time (less than a day). You can use this uploaded file multiple times.
483    ///
484    /// Refer to [`InputMessage`] to learn more uses for `uploaded_file`.
485    ///
486    /// If you need more control over the uploaded data, such as performing only a partial upload
487    /// or with a different name, use [`Client::upload_stream`] instead.
488    ///
489    /// # Examples
490    ///
491    /// ```
492    /// # async fn f(peer: grammers_session::types::PeerRef, client: grammers_client::Client) -> Result<(), Box<dyn std::error::Error>> {
493    /// use grammers_client::message::InputMessage;
494    ///
495    /// let uploaded_file = client.upload_file("/home/username/photos/holidays.jpg").await?;
496    ///
497    /// client.send_message(peer, InputMessage::new().text("Check this out!").photo(uploaded_file)).await?;
498    /// # Ok(())
499    /// # }
500    /// ```
501    ///
502    /// [`InputMessage`]: crate::message::InputMessage
503    #[cfg(feature = "fs")]
504    pub async fn upload_file<P: AsRef<Path>>(&self, path: P) -> Result<Uploaded, io::Error> {
505        let path = path.as_ref();
506
507        let mut file = fs::File::open(path).await?;
508        let size = file.seek(SeekFrom::End(0)).await? as usize;
509        file.seek(SeekFrom::Start(0)).await?;
510
511        // File name will only be `None` for `..` path, and directories cannot be uploaded as
512        // files, so it's fine to unwrap.
513        let name = path.file_name().unwrap().to_string_lossy().to_string();
514
515        self.upload_stream(&mut file, size, name).await
516    }
517}
518
519struct PartStreamInner<'a, S: AsyncRead + Unpin> {
520    stream: &'a mut S,
521    current_part: i32,
522}
523
524struct PartStream<'a, S: AsyncRead + Unpin> {
525    inner: AsyncMutex<PartStreamInner<'a, S>>,
526    total_parts: i32,
527}
528
529impl<'a, S: AsyncRead + Unpin> PartStream<'a, S> {
530    fn new(stream: &'a mut S, size: usize) -> Self {
531        let total_parts = ((size + MAX_CHUNK_SIZE as usize - 1) / MAX_CHUNK_SIZE as usize) as i32;
532        Self {
533            inner: AsyncMutex::new(PartStreamInner {
534                stream,
535                current_part: 0,
536            }),
537            total_parts,
538        }
539    }
540
541    fn total_parts(&self) -> i32 {
542        self.total_parts
543    }
544
545    async fn next_part(&self) -> Result<Option<(i32, Vec<u8>)>, io::Error> {
546        let mut lock = self.inner.lock().await;
547        if lock.current_part >= self.total_parts {
548            return Ok(None);
549        }
550        let mut read = 0;
551        let mut buffer = vec![0; MAX_CHUNK_SIZE as usize];
552
553        while read != buffer.len() {
554            let n = lock.stream.read(&mut buffer[read..]).await?;
555            if n == 0 {
556                if lock.current_part == self.total_parts - 1 {
557                    break;
558                } else {
559                    return Err(io::Error::new(
560                        io::ErrorKind::UnexpectedEof,
561                        "reached EOF before reaching the last file part",
562                    ));
563                }
564            }
565            read += n;
566        }
567
568        let bytes = if read == buffer.len() {
569            buffer
570        } else {
571            buffer[..read].to_vec()
572        };
573
574        let res = Ok(Some((lock.current_part, bytes)));
575        lock.current_part += 1;
576        res
577    }
578}