Skip to main content

anni_provider/
cache.rs

1use crate::{AnniProvider, AudioInfo, AudioResourceReader, ProviderError, Range, ResourceReader};
2use anni_common::models::{RawTrackIdentifier, TrackIdentifier};
3use async_trait::async_trait;
4use dashmap::DashMap;
5use lru::LruCache;
6use parking_lot::RwLock;
7use std::borrow::{Borrow, Cow};
8use std::collections::HashSet;
9use std::future::Future;
10use std::num::NonZeroU8;
11use std::path::{Path, PathBuf};
12use std::pin::Pin;
13use std::sync::Arc;
14use std::task::{Context, Poll};
15use tokio::fs::File;
16use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
17use tokio::sync::Mutex;
18use tokio::time::Duration;
19
20pub struct CacheProvider<T>
21where
22    T: AnniProvider + Send,
23{
24    inner: T,
25    pool: Arc<CachePool>,
26}
27
28impl<T> CacheProvider<T>
29where
30    T: AnniProvider + Send,
31{
32    pub fn new(inner: T, pool: Arc<CachePool>) -> Self {
33        Self { inner, pool }
34    }
35
36    pub async fn invalidate(&self, album_id: &str, disc_id: NonZeroU8, track_id: NonZeroU8) {
37        let key = RawTrackIdentifier::new(album_id, disc_id, track_id);
38        self.pool.remove(&key).await;
39    }
40}
41
42#[async_trait]
43impl<T> AnniProvider for CacheProvider<T>
44where
45    T: AnniProvider + Send,
46{
47    async fn albums(&self) -> Result<HashSet<Cow<str>>, ProviderError> {
48        self.inner.albums().await
49    }
50
51    async fn get_audio_info(
52        &self,
53        album_id: &str,
54        disc_id: NonZeroU8,
55        track_id: NonZeroU8,
56    ) -> Result<AudioInfo, ProviderError> {
57        if let Some(info) = self
58            .pool
59            .get_cached_audio_info(album_id, disc_id, track_id)
60            .await
61        {
62            Ok(info)
63        } else {
64            self.inner.get_audio_info(album_id, disc_id, track_id).await
65        }
66    }
67
68    async fn get_audio(
69        &self,
70        album_id: &str,
71        disc_id: NonZeroU8,
72        track_id: NonZeroU8,
73        range: Range,
74    ) -> Result<AudioResourceReader, ProviderError> {
75        self.pool
76            .fetch_audio(
77                album_id,
78                disc_id,
79                track_id,
80                range,
81                self.inner.get_audio(
82                    album_id,
83                    disc_id,
84                    track_id,
85                    Range::FULL, // cache does not pass range to the underlying provider
86                ),
87            )
88            .await
89    }
90
91    async fn get_cover(
92        &self,
93        album_id: &str,
94        disc_id: Option<NonZeroU8>,
95    ) -> Result<ResourceReader, ProviderError> {
96        // TODO: cache cover
97        self.inner.get_cover(album_id, disc_id).await
98    }
99
100    async fn reload(&mut self) -> Result<(), ProviderError> {
101        // reload the inner provider
102        self.inner.reload().await
103    }
104}
105
106pub struct CachePool {
107    /// Root of cache folder
108    root: PathBuf,
109    /// Maximum space used by cache
110    max_size: Option<usize>,
111    cache: DashMap<TrackIdentifier, Arc<CacheItem>>,
112    // https://github.com/xacrimon/dashmap/issues/189
113    // TODO: Use LFU instead of LRU
114    last_used: Mutex<LruCache<TrackIdentifier, Arc<Mutex<u8>>>>,
115}
116
117impl CachePool {
118    pub fn new<P>(root: P, max_size: Option<usize>) -> Self
119    where
120        P: AsRef<Path>,
121    {
122        Self {
123            root: PathBuf::from(root.as_ref()),
124            max_size,
125            cache: Default::default(),
126            last_used: Mutex::new(LruCache::unbounded()),
127        }
128    }
129
130    async fn fetch_audio(
131        &self,
132        album_id: &str,
133        disc_id: NonZeroU8,
134        track_id: NonZeroU8,
135        range: Range,
136        on_miss: impl Future<Output = Result<AudioResourceReader, ProviderError>>,
137    ) -> Result<AudioResourceReader, ProviderError> {
138        let key = RawTrackIdentifier::new(album_id, disc_id, track_id);
139        let item = if !self.has_cache(album_id, disc_id, track_id).await {
140            // on miss, set state to cached first
141            let mutex = Arc::new(Mutex::new(0));
142            let handle = mutex.clone().lock_owned().await;
143            // TODO: The following procedure may fail, but the entry would still be added to `last_used`
144            //       We should remove it from `last_used` and fail all correspoding requests.
145            self.last_used.lock().await.put(key.to_owned(), mutex);
146
147            // get data, return directly if it's a partial request
148            let result = on_miss.await?;
149
150            // prepare for new item
151            let mut path = self.root.join(key.album_id.as_ref());
152            tokio::fs::create_dir_all(&path).await?;
153            path.push(format!("{}_{}", key.disc_id.get(), key.track_id.get()));
154            let mut file = File::create(&path).await?;
155
156            let AudioResourceReader {
157                info, mut reader, ..
158            } = result;
159            let item = Arc::new(CacheItem::new(path, info, false));
160
161            // remove old item if space is full
162            if let Some(max_size) = self.max_size {
163                if self.space_used() > max_size {
164                    // get the first item of BTreeMap
165                    let mut write = self.last_used.lock().await;
166                    let key = write.pop_lru().unwrap();
167                    // remove it from cache map
168                    // drop would do the removal
169                    self.remove(&key.0.borrow()).await;
170                }
171            }
172
173            // write to map
174            self.cache.insert(key.to_owned(), item.clone());
175            // item is set to cached, release lock
176            drop(handle);
177
178            // cache
179            let item_spawn = item.clone();
180            tokio::spawn(async move {
181                let actual_size = tokio::io::copy(&mut reader, &mut file).await.unwrap() as usize;
182                if item_spawn.size() != actual_size {
183                    // TODO: should not happen, throw error here
184                    item_spawn.set_size(actual_size);
185                }
186                item_spawn.set_cached(true);
187            });
188            item
189        } else {
190            // resource requested, but not added to cache map yet
191            if !self.cache.contains_key(&key) {
192                // await cache mutex
193                let mutex = {
194                    let mut map = self.last_used.lock().await;
195                    map.get(&key).unwrap().clone()
196                };
197                let _ = mutex.lock().await;
198            }
199            // update last_used time
200            self.last_used.lock().await.get(&key).unwrap();
201            self.cache.get(&key).unwrap().clone()
202        };
203
204        Ok(item
205            .to_audio_resource_reader(File::open(&item.path).await?, range)
206            .await)
207    }
208
209    async fn remove<'a>(&self, key: &RawTrackIdentifier<'a>) {
210        self.cache.remove(key).map(|r| r.1.set_cached(false));
211        self.last_used.lock().await.pop(key);
212    }
213
214    async fn get_cached_audio_info(
215        &self,
216        album_id: &str,
217        disc_id: NonZeroU8,
218        track_id: NonZeroU8,
219    ) -> Option<AudioInfo> {
220        self.cache
221            .get(&RawTrackIdentifier::new(album_id, disc_id, track_id))
222            .and_then(|r| {
223                if *r.cached.read() {
224                    Some(AudioInfo {
225                        extension: r.ext.clone(),
226                        size: *r.size.read(),
227                        duration: r.duration,
228                    })
229                } else {
230                    None
231                }
232            })
233    }
234
235    async fn has_cache(&self, album_id: &str, disc_id: NonZeroU8, track_id: NonZeroU8) -> bool {
236        self.last_used
237            .lock()
238            .await
239            .contains(&RawTrackIdentifier::new(album_id, disc_id, track_id))
240    }
241
242    fn space_used(&self) -> usize {
243        self.cache
244            .iter()
245            .map(|i| i.size())
246            .reduce(|a, b| a + b)
247            .unwrap_or(0)
248    }
249}
250
251struct CacheItem {
252    ext: String,
253    path: PathBuf,
254    size: RwLock<usize>,
255    duration: u64,
256    cached: RwLock<bool>,
257}
258
259impl CacheItem {
260    fn new(path: PathBuf, info: AudioInfo, cached: bool) -> Self {
261        let AudioInfo {
262            extension: ext,
263            duration,
264            size,
265        } = info;
266        CacheItem {
267            path,
268            ext,
269            size: RwLock::new(size),
270            duration,
271            cached: RwLock::new(cached),
272        }
273    }
274
275    fn size(&self) -> usize {
276        *self.size.read()
277    }
278
279    fn set_size(&self, size: usize) {
280        *self.size.write() = size;
281    }
282
283    fn cached(&self) -> bool {
284        *self.cached.read()
285    }
286
287    fn set_cached(&self, cached: bool) {
288        *self.cached.write() = cached
289    }
290}
291
292#[async_trait::async_trait]
293trait CacheReader {
294    fn to_reader(&self, file: File) -> CacheItemReader;
295
296    async fn to_audio_resource_reader(&self, file: File, range: Range) -> AudioResourceReader;
297}
298
299#[async_trait::async_trait]
300impl CacheReader for Arc<CacheItem> {
301    fn to_reader(&self, file: File) -> CacheItemReader {
302        CacheItemReader {
303            item: self.clone(),
304            file: Box::pin(file),
305            filled: 0,
306            timer: None,
307        }
308    }
309
310    async fn to_audio_resource_reader(&self, file: File, range: Range) -> AudioResourceReader {
311        let mut reader = self.to_reader(file);
312        if range.start > 0 {
313            let reader = &mut reader;
314            let _ = tokio::io::copy(&mut reader.take(range.start), &mut tokio::io::sink()).await;
315        }
316        let length = range.length();
317        let reader: ResourceReader = match length {
318            Some(length) => Box::pin(reader.take(length)),
319            None => Box::pin(reader),
320        };
321
322        AudioResourceReader {
323            info: AudioInfo {
324                extension: self.ext.clone(),
325                size: self.size(),
326                duration: self.duration,
327            },
328            range,
329            reader,
330        }
331    }
332}
333
334impl Drop for CacheItem {
335    fn drop(&mut self) {
336        // not cached, means:
337        // a. file not fully cached and program reaches program termination
338        // b. manually set cached to false
339        if !self.cached() {
340            if let Err(e) = std::fs::remove_file(&self.path) {
341                log::error!("Failed to drop CacheItem: {}", e);
342            }
343        }
344    }
345}
346
347struct CacheItemReader {
348    item: Arc<CacheItem>,
349    file: Pin<Box<File>>,
350    filled: usize,
351
352    timer: Option<Pin<Box<dyn Future<Output = ()> + Send>>>,
353}
354
355impl AsyncRead for CacheItemReader {
356    fn poll_read(
357        mut self: Pin<&mut Self>,
358        cx: &mut Context<'_>,
359        buf: &mut ReadBuf<'_>,
360    ) -> Poll<std::io::Result<()>> {
361        // Wait mode
362        if self.timer.is_some() {
363            let task = self.timer.as_mut().unwrap();
364            // poll the saved timer
365            let result = task.as_mut().poll(cx);
366            match result {
367                // timer ready, stop waiting
368                Poll::Ready(_) => self.timer = None,
369                // timer pending, wait
370                Poll::Pending => return Poll::Pending,
371            }
372        }
373
374        // Read mode
375        // save filled buf length before poll_read
376        let before = buf.filled().len();
377        let result = self.file.as_mut().poll_read(cx, buf);
378        match result {
379            Poll::Ready(result) => {
380                match result {
381                    Ok(_) => {
382                        let now = buf.filled().len();
383                        if before != now {
384                            self.filled += now - before;
385                            Poll::Ready(Ok(()))
386                        } else if self.item.cached() {
387                            if self.filled != self.item.size() {
388                                // caching finished just now
389                                // wake immediately to finish the last part
390                                cx.waker().wake_by_ref();
391                                Poll::Pending
392                            } else {
393                                // EOF
394                                Poll::Ready(Ok(()))
395                            }
396                        } else {
397                            // not done, wait for more data
398                            // set up timer to wait
399                            self.timer =
400                                Some(Box::pin(tokio::time::sleep(Duration::from_millis(100))));
401                            // wait immediately to poll the timer
402                            cx.waker().wake_by_ref();
403                            Poll::Pending
404                        }
405                    }
406                    // poll error
407                    Err(e) => Poll::Ready(Err(e)),
408                }
409            }
410            // wait
411            Poll::Pending => Poll::Pending,
412        }
413    }
414}