Skip to main content

anni_provider/providers/
drive.rs

1use crate::{
2    common::content_range_to_range, AnniProvider, AudioInfo, AudioResourceReader, ProviderError,
3    Range, ResourceReader,
4};
5use anni_google_drive3::{
6    hyper, hyper::client::HttpConnector, hyper_rustls::HttpsConnector, oauth2, DriveHub,
7};
8use async_trait::async_trait;
9use std::borrow::Cow;
10use std::collections::{HashMap, HashSet};
11use std::num::NonZeroU8;
12use std::path::PathBuf;
13
14use self::oauth2::authenticator::Authenticator;
15use self::oauth2::authenticator_delegate::DefaultInstalledFlowDelegate;
16use crate::utils::read_duration;
17use anni_google_drive3::api::{FileList, FileListCall};
18use anni_google_drive3::hyper_rustls::HttpsConnectorBuilder;
19use anni_repo::db::RepoDatabaseRead;
20use anni_repo::library::{AlbumFolderInfo, DiscFolderInfo};
21use dashmap::DashMap;
22use futures::TryStreamExt;
23use parking_lot::Mutex;
24use std::str::FromStr;
25use tokio::sync::Semaphore;
26
27pub enum DriveAuth {
28    InstalledFlow {
29        client_id: String,
30        client_secret: String,
31        project_id: Option<String>,
32    },
33    ServiceAccount(oauth2::ServiceAccountKey),
34}
35
36impl Default for DriveAuth {
37    fn default() -> Self {
38        DriveAuth::InstalledFlow {
39            client_id: "175511611598-ot9agsmf6v3lf1jc3qbsf1vcru7saop7.apps.googleusercontent.com"
40                .to_string(),
41            client_secret: "mW1neo-JSSwzYz5Syqiiset1".to_string(),
42            project_id: Some("anni-provider".to_string()),
43        }
44    }
45}
46
47impl DriveAuth {
48    pub async fn build(
49        self,
50        token_storage: TokenStorage,
51    ) -> std::io::Result<Authenticator<HttpsConnector<HttpConnector>>> {
52        match self {
53            DriveAuth::InstalledFlow {
54                client_id,
55                client_secret,
56                project_id,
57            } => {
58                let builder = oauth2::InstalledFlowAuthenticator::builder(
59                    oauth2::ApplicationSecret {
60                        client_id,
61                        project_id,
62                        auth_uri: "https://accounts.google.com/o/oauth2/auth".to_string(),
63                        token_uri: "https://oauth2.googleapis.com/token".to_string(),
64                        auth_provider_x509_cert_url: Some(
65                            "https://www.googleapis.com/oauth2/v1/certs".to_string(),
66                        ),
67                        client_secret,
68                        redirect_uris: vec!["urn:ietf:wg:oauth:2.0:oob".to_string()],
69                        client_email: None,
70                        client_x509_cert_url: None,
71                    },
72                    oauth2::InstalledFlowReturnMethod::Interactive,
73                )
74                .flow_delegate(Box::new(DefaultInstalledFlowDelegate));
75                match token_storage {
76                    TokenStorage::Disk(path) => builder.persist_tokens_to_disk(path),
77                    TokenStorage::Custom(storage) => builder.with_storage(storage),
78                    TokenStorage::Memory => builder,
79                }
80                .build()
81                .await
82            }
83            DriveAuth::ServiceAccount(sa) => {
84                let builder = oauth2::ServiceAccountAuthenticator::builder(sa);
85                match token_storage {
86                    TokenStorage::Disk(path) => builder.persist_tokens_to_disk(path),
87                    TokenStorage::Custom(storage) => builder.with_storage(storage),
88                    TokenStorage::Memory => builder,
89                }
90                .build()
91                .await
92            }
93        }
94    }
95}
96
97pub enum TokenStorage {
98    Memory,
99    Disk(PathBuf),
100    Custom(Box<dyn oauth2::storage::TokenStorage>),
101}
102
103impl From<PathBuf> for TokenStorage {
104    fn from(p: PathBuf) -> Self {
105        Self::Disk(p)
106    }
107}
108
/// Options that control which part of Google Drive is searched.
pub struct DriveProviderSettings {
    /// Value sent as the `corpora` parameter of file-list requests.
    pub corpora: String,
    /// Optional drive id; when present it is attached to file-list requests.
    pub drive_id: Option<String>,
}

impl DriveProviderSettings {
    /// Creates settings from a `corpora` value and an optional drive id.
    pub fn new(corpora: String, drive_id: Option<String>) -> Self {
        DriveProviderSettings { corpora, drive_id }
    }
}
119pub struct DriveClient {
120    hub: Box<DriveHub<HttpsConnector<HttpConnector>>>,
121    settings: DriveProviderSettings,
122    /// Semaphore for rate limiting
123    semaphore: Semaphore,
124
125    // parent_id <-> file_id
126    covers: DashMap<String, String>,
127}
128
129impl DriveClient {
130    pub async fn new(
131        auth: DriveAuth,
132        settings: DriveProviderSettings,
133        token_storage: impl Into<TokenStorage>,
134    ) -> Result<Self, ProviderError> {
135        let auth = auth.build(token_storage.into()).await?;
136        auth.token(&[
137            "https://www.googleapis.com/auth/drive.metadata.readonly",
138            "https://www.googleapis.com/auth/drive.readonly",
139        ])
140        .await?;
141        let hub = DriveHub::new(
142            hyper::Client::builder().build(
143                HttpsConnectorBuilder::new()
144                    .with_native_roots()
145                    .https_or_http()
146                    .enable_http1()
147                    .enable_http2()
148                    .build(),
149            ),
150            auth,
151        );
152        Ok(Self {
153            hub: Box::new(hub),
154            settings,
155            covers: DashMap::new(),
156            semaphore: Semaphore::new(200),
157        })
158    }
159
160    fn prepare_list(&self) -> FileListCall<HttpsConnector<HttpConnector>> {
161        let result = self
162            .hub
163            .files()
164            .list()
165            .corpora(&self.settings.corpora)
166            .supports_all_drives(true)
167            .include_items_from_all_drives(true)
168            .page_size(500);
169        match &self.settings.drive_id {
170            Some(drive_id) => result.drive_id(drive_id),
171            None => result,
172        }
173    }
174
175    async fn list_folder(&self, parent_id: &str) -> Result<FileList, ProviderError> {
176        let permit = self.semaphore.acquire().await.unwrap();
177        let (_, list) = self.prepare_list()
178            .q(&format!("mimeType = 'application/vnd.google-apps.folder' and trashed = false and '{parent_id}' in parents"))
179            .param("fields", "nextPageToken, files(id,name)")
180            .doit().await?;
181        drop(permit);
182        Ok(list)
183    }
184
185    async fn get_file(
186        &self,
187        file_id: &str,
188        range: &Range,
189    ) -> Result<(ResourceReader, Range), ProviderError> {
190        let permit = self.semaphore.acquire().await.unwrap();
191        let (resp, _) = self
192            .hub
193            .files()
194            .get(file_id)
195            .supports_all_drives(true)
196            .acknowledge_abuse(true)
197            .param("alt", "media")
198            .range(range.to_range_header())
199            .doit()
200            .await?;
201        drop(permit);
202        let content_range = resp
203            .headers()
204            .get("Content-Range")
205            .map(|v| v.to_str().unwrap().to_string());
206        let body = resp
207            .into_body()
208            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "Error!"))
209            .into_async_read();
210        let body = tokio_util::compat::FuturesAsyncReadCompatExt::compat(body);
211        Ok((
212            Box::pin(body),
213            content_range_to_range(content_range.as_deref()),
214        ))
215    }
216
217    async fn get_cover_id_in(&self, parent_id: &str) -> Result<String, ProviderError> {
218        if self.covers.contains_key(parent_id) {
219            return self
220                .covers
221                .get(parent_id)
222                .map(|v| v.to_string())
223                .ok_or(ProviderError::FileNotFound);
224        }
225
226        let permit = self.semaphore.acquire().await.unwrap();
227        let (_, list) = self.prepare_list()
228            .q(&format!("trashed = false and mimeType = 'image/jpeg' and name = 'cover.jpg' and '{}' in parents", parent_id))
229            .param("fields", "nextPageToken, files(id,name)")
230            .doit().await?;
231        drop(permit);
232
233        let files = list.files.unwrap();
234        let file = files.get(0).ok_or(ProviderError::FileNotFound)?;
235        let id = file.id.as_ref().unwrap().to_string();
236        self.covers.insert(parent_id.to_string(), id.clone());
237        Ok(id)
238    }
239}
240
241pub struct DriveProvider {
242    /// Google Drive API Client
243    client: DriveClient,
244    /// HashMap mapping album_id and folder_id
245    folders: HashMap<String, String>,
246    /// Cache for mapping album_id and its discs if multiple discs exists
247    /// All albums with multiple discs must be in this map
248    /// If the value is None, it means the album is not cached
249    /// If the value is Some, then the value of index is the folder_id of the disc
250    discs: DashMap<String, Option<Vec<String>>>,
251    /// Cache file id
252    /// "{album_id}/cover" <-> file_id
253    /// "{album_id}/{disc_id}/cover" <-> file_id
254    /// "{album_id}/{disc_id}/track_id" <-> file_id
255    files: DashMap<String, String>,
256    /// file_id <-> (extension, filesize)
257    audios: DashMap<String, (String, usize)>,
258
259    // properties
260    strict: bool,
261    repo: Mutex<Option<RepoDatabaseRead>>,
262}
263
264impl DriveProvider {
265    pub async fn new(
266        auth: DriveAuth,
267        settings: DriveProviderSettings,
268        repo: Option<RepoDatabaseRead>,
269        token_storage: impl Into<TokenStorage>,
270    ) -> Result<Self, ProviderError> {
271        let mut this = Self {
272            client: DriveClient::new(auth, settings, token_storage).await?,
273            folders: Default::default(),
274            discs: Default::default(),
275            files: Default::default(),
276            audios: Default::default(),
277            strict: repo.is_none(),
278            repo: Mutex::new(repo),
279        };
280        this.reload().await?;
281        Ok(this)
282    }
283
284    async fn cache_discs(&self, album_id: &str) -> Result<(), ProviderError> {
285        if self.folders.contains_key(album_id)
286            && self.discs.contains_key(album_id)
287            && self.discs.get(album_id).unwrap().is_none()
288        {
289            let list = self.client.list_folder(&self.folders[album_id]).await?;
290            let mut discs: Vec<_> = list
291                .files
292                .unwrap()
293                .iter()
294                .filter_map(|file| {
295                    let file_id = file.id.as_deref().unwrap().to_string();
296                    return if self.strict {
297                        let disc_index: usize = file.name.as_ref().unwrap().parse().ok()?;
298                        Some((disc_index, file_id))
299                    } else {
300                        let DiscFolderInfo { disc_id, .. } =
301                            DiscFolderInfo::from_str(file.name.as_deref().unwrap()).ok()?;
302                        Some((disc_id, file_id))
303                    };
304                })
305                .collect();
306            discs.sort();
307            self.discs.insert(
308                album_id.to_string(),
309                Some(discs.into_iter().map(|(_, id)| id).collect()),
310            );
311        }
312
313        Ok(())
314    }
315
316    fn get_parent_folder(&self, album_id: &str, disc_id: Option<NonZeroU8>) -> Cow<str> {
317        match disc_id {
318            Some(disc_id) => {
319                if self.discs.contains_key(album_id) {
320                    Cow::Owned(
321                        self.discs.get(album_id).unwrap().as_deref().unwrap()
322                            [(disc_id.get() - 1) as usize]
323                            .clone(),
324                    )
325                } else {
326                    Cow::Borrowed(&self.folders[album_id])
327                }
328            }
329            None => Cow::Borrowed(&self.folders[album_id]),
330        }
331    }
332}
333
334#[async_trait]
335impl AnniProvider for DriveProvider {
336    async fn albums(&self) -> Result<HashSet<Cow<str>>, ProviderError> {
337        Ok(self
338            .folders
339            .keys()
340            .map(|a| Cow::Borrowed(a.as_str()))
341            .collect())
342    }
343
344    async fn get_audio(
345        &self,
346        album_id: &str,
347        disc_id: NonZeroU8,
348        track_id: NonZeroU8,
349        range: Range,
350    ) -> Result<AudioResourceReader, ProviderError> {
351        // catalog not found
352        if !self.folders.contains_key(album_id) {
353            return Err(ProviderError::FileNotFound);
354        }
355
356        let key = format!("{album_id}/{disc_id}/{track_id}");
357        if !self.files.contains_key(&key) {
358            // get folder_id
359            self.cache_discs(album_id).await?;
360            let folder_id = self.get_parent_folder(album_id, Some(disc_id));
361
362            // get audio file id
363            let permit = self.client.semaphore.acquire().await.unwrap();
364            let q = if self.strict {
365                format!("trashed = false and name = '{track_id}.flac' and '{folder_id}' in parents")
366            } else {
367                format!("trashed = false and name contains '{track_id:02}.' and '{folder_id}' in parents")
368            };
369            let (_, list) = self
370                .client
371                .prepare_list()
372                .q(&q)
373                .param("fields", "nextPageToken, files(id,name,fileExtension,size)")
374                .doit()
375                .await?;
376            drop(permit);
377
378            let files = list.files.unwrap();
379            let id = if self.strict {
380                Some(files.first().ok_or_else(|| ProviderError::FileNotFound)?)
381            } else {
382                files.iter().reduce(|a, b| {
383                    if a.name
384                        .as_ref()
385                        .unwrap()
386                        .starts_with(&format!("{track_id:02}."))
387                    {
388                        a
389                    } else {
390                        b
391                    }
392                })
393            };
394            if let Some(file) = id {
395                let id = file.id.as_ref().unwrap();
396                self.audios.insert(
397                    id.to_string(),
398                    (
399                        file.file_extension.as_ref().unwrap().to_string(),
400                        usize::from_str(file.size.as_ref().unwrap()).unwrap(),
401                    ),
402                );
403                self.files.insert(key.to_string(), id.to_string());
404            } else {
405                return Err(ProviderError::FileNotFound);
406            }
407        }
408
409        match self.files.get(&key) {
410            Some(id) => {
411                let file_id = id.value().to_string();
412                drop(id); // drop lock immediately
413                let metadata = self.audios.get(&file_id).unwrap().value().clone(); // drop lock inline
414
415                let (reader, range) = self.client.get_file(&file_id, &range).await?;
416                let (duration, reader) = read_duration(reader, range).await?;
417                Ok(AudioResourceReader {
418                    info: AudioInfo {
419                        extension: metadata.0,
420                        size: metadata.1,
421                        duration,
422                    },
423                    range,
424                    reader,
425                })
426            }
427            None => Err(ProviderError::FileNotFound),
428        }
429    }
430
431    async fn get_cover(
432        &self,
433        album_id: &str,
434        disc_id: Option<NonZeroU8>,
435    ) -> Result<ResourceReader, ProviderError> {
436        // album_id not found
437        if !self.folders.contains_key(album_id) ||
438            // disc not found
439            (disc_id.is_some() && disc_id != NonZeroU8::new(1) && !self.discs.contains_key(album_id))
440        {
441            return Err(ProviderError::FileNotFound);
442        }
443
444        let key = match disc_id {
445            Some(disc_id) => format!("{album_id}/{disc_id}/cover"),
446            None => format!("{album_id}/cover"),
447        };
448        let id = match self.files.get(&key) {
449            Some(id) => id.to_string(),
450            None => {
451                // get folder_id
452                self.cache_discs(album_id).await?;
453                let folder_id = self.get_parent_folder(album_id, disc_id);
454
455                // get cover file id
456                self.client.get_cover_id_in(&folder_id).await?
457            }
458        };
459
460        Ok(self.client.get_file(&id, &Range::FULL).await?.0)
461    }
462
463    async fn reload(&mut self) -> Result<(), ProviderError> {
464        self.folders.clear();
465        self.discs.clear();
466        self.files.clear();
467        self.audios.clear();
468
469        if let Some(repo) = &mut *self.repo.lock() {
470            repo.reload()?;
471        }
472
473        let mut page_token = String::new();
474        loop {
475            let permit = self.client.semaphore.acquire().await.unwrap();
476            let (_, list) = self
477                .client
478                .prepare_list()
479                .page_token(&page_token)
480                .q(if self.strict {
481                    "mimeType = 'application/vnd.google-apps.folder' and name != '0' and name != '1' and name != '2' and name != '3' and name != '4' and name != '5' and name != '6' and name != '7' and name != '8' and name != '9' and trashed = false"
482                } else {
483                    "mimeType = 'application/vnd.google-apps.folder' and trashed = false"
484                })
485                .param("fields", "nextPageToken, files(id,name)")
486                .page_size(1000)
487                .doit()
488                .await?;
489            drop(permit);
490            for file in list.files.unwrap() {
491                let name = file.name.unwrap();
492                if self.strict {
493                    if name.len() != 36 {
494                        continue;
495                    }
496                    self.folders.insert(name.to_string(), file.id.unwrap());
497                    self.discs.insert(name, None);
498                } else {
499                    if let Ok(AlbumFolderInfo {
500                        release_date,
501                        catalog,
502                        title,
503                        edition,
504                        disc_count,
505                    }) = AlbumFolderInfo::from_str(&name)
506                    {
507                        let album_id = self.repo.lock().as_ref().unwrap().match_album(
508                            &catalog,
509                            &release_date,
510                            disc_count as u8,
511                            &title,
512                            edition.as_deref(),
513                        )?;
514                        match album_id {
515                            Some(album_id) => {
516                                self.folders.insert(album_id.to_string(), file.id.unwrap());
517                                if disc_count > 1 {
518                                    self.discs.insert(album_id.to_string(), None);
519                                }
520                            }
521                            None => {
522                                log::warn!("Album ID not found for {}, ignoring...", catalog);
523                            }
524                        }
525                    };
526                }
527            }
528            if list.next_page_token.is_none() {
529                break;
530            } else {
531                page_token = list.next_page_token.unwrap();
532            }
533        }
534        Ok(())
535    }
536}