1use std::collections::HashMap;
2use std::mem::MaybeUninit;
3use std::os::unix::prelude::AsRawFd;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use tokio::sync::RwLock;
8
9use super::http;
10use http::Time;
11
12use crate::{ Result, Error };
13use crate::util::{pretty_dump_wrap as pretty, AsInit, CopyInit};
14
/// Read timeout configured on the shared HTTP client.
const READ_TIMEOUT: Duration = Duration::from_secs(30);
/// Connect timeout configured on the shared HTTP client.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// TCP keepalive interval configured on the shared HTTP client (five minutes).
const KEEPALIVE: Duration = Duration::from_secs(5 * 60);
18
// Process-wide cache singleton, constructed on first access.  It is guarded
// by the blocking `std::sync::RwLock`, not by the tokio `RwLock` that is
// imported at the top of the file and used for the individual entries.
lazy_static::lazy_static!{
    static ref CACHE: std::sync::RwLock<CacheImpl> = std::sync::RwLock::new(CacheImpl::new());
}
22
/// Transfer statistics collected while reading a response body.
#[derive(Clone, Copy, Debug, Default)]
struct Stats {
    /// Total time spent awaiting body chunks; accumulated by `Stats::chunk()`.
    pub tm: Duration,
}
27
28impl Stats {
29 pub async fn chunk(&mut self, response: &mut reqwest::Response) -> reqwest::Result<Option<bytes::Bytes>>
30 {
31 let start = std::time::Instant::now();
32 let chunk = response.chunk().await;
33 self.tm += start.elapsed();
34
35 chunk
36 }
37}
38
39impl crate::util::PrettyDump for Stats {
40 fn pretty_dump(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.write_fmt(format_args!("{:.3}s", self.tm.as_secs_f32()))
42 }
43}
44
/// Lifecycle of a cache entry, from "nothing known" to a fully downloaded file.
#[derive(Debug)]
enum State {
    /// No response attached; `EntryData::new()` starts here and
    /// `EntryData::invalidate()` returns here.
    None,

    /// Placeholder left behind by `State::take()`.  The string is the hint
    /// naming the operation that took the state without storing a new one.
    Error(&'static str),

    /// A response was attached by `set_response()`; its headers have not been
    /// evaluated yet.
    Init {
        response: reqwest::Response,
    },

    /// Headers were evaluated by `fill_meta()`; nothing of the body has been
    /// written to a file yet.
    HaveMeta {
        response: reqwest::Response,
        cache_info: http::CacheInfo,
        // content length reported by the response, when there is one
        file_size: Option<u64>,
        stats: Stats,
    },

    /// The body is being streamed into `file`.
    Downloading {
        response: reqwest::Response,
        cache_info: http::CacheInfo,
        // content length reported by the response, when there is one
        file_size: Option<u64>,
        file: std::fs::File,
        // number of bytes written to `file` so far; reads below this offset
        // are served from the file
        file_pos: u64,
        stats: Stats,
    },

    /// The whole body is stored in `file`.
    Complete {
        cache_info: http::CacheInfo,
        file: std::fs::File,
        file_size: u64,
    },

    /// A completed entry for which `set_response()` attached a new response;
    /// `fill_meta()` updates `cache_info` from it and goes back to `Complete`.
    Refresh {
        response: reqwest::Response,
        cache_info: http::CacheInfo,
        file: std::fs::File,
        file_size: u64,
    },
}
84
85impl crate::util::PrettyDump for State {
86 fn pretty_dump(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 State::None =>
89 f.write_str("no state"),
90 State::Error(e) =>
91 f.write_fmt(format_args!("error {e:?}")),
92 State::Init { response } =>
93 f.write_fmt(format_args!("INIT({})", pretty(response))),
94 State::HaveMeta { response, cache_info, file_size, stats } =>
95 f.write_fmt(format_args!("META({}, {}, {}, {})",
96 pretty(response), pretty(cache_info),
97 pretty(file_size), pretty(stats))),
98 State::Downloading { response, cache_info, file_size, file, file_pos, stats } =>
99 f.write_fmt(format_args!("DOWNLOADING({}, {}, {}, {}@{}, {})",
100 pretty(response), pretty(cache_info),
101 pretty(file_size), pretty(file),
102 file_pos, pretty(stats))),
103 State::Complete { cache_info, file, file_size } =>
104 f.write_fmt(format_args!("COMPLETE({}, {}/{})",
105 pretty(cache_info), pretty(file), file_size)),
106 State::Refresh { response, cache_info, file, file_size } =>
107 f.write_fmt(format_args!("REFRESH({},.{}, [{}/{}])",
108 pretty(response), pretty(cache_info), pretty(file),
109 file_size)),
110 }
111 }
112}
113
114impl State {
115 pub fn take(&mut self, hint: &'static str) -> Self {
116 std::mem::replace(self, State::Error(hint))
117 }
118
119 pub fn is_none(&self) -> bool {
120 matches!(self, Self::None)
121 }
122
123 pub fn is_init(&self) -> bool {
124 matches!(self, Self::Init { .. })
125 }
126
127 pub fn is_error(&self) -> bool {
128 matches!(self, Self::Error(_))
129 }
130
131 pub fn is_refresh(&self) -> bool {
132 matches!(self, Self::Refresh { .. })
133 }
134
135 pub fn is_have_meta(&self) -> bool {
136 matches!(self, Self::HaveMeta { .. })
137 }
138
139 pub fn is_downloading(&self) -> bool {
140 matches!(self, Self::Downloading { .. })
141 }
142
143 pub fn is_complete(&self) -> bool {
144 matches!(self, Self::Complete { .. })
145 }
146
147 pub fn get_file_size(&self) -> Option<u64> {
148 match self {
149 Self::None |
150 Self::Init { .. } => None,
151
152 Self::HaveMeta { file_size, .. } => *file_size,
153 Self::Downloading { file_size, .. } => *file_size,
154
155 Self::Complete { file_size, .. } => Some(*file_size),
156 Self::Refresh { file_size, .. } => Some(*file_size),
157
158 Self::Error(hint) => panic!("get_file_size called in error state ({hint})"),
159 }
160 }
161
162 pub fn get_cache_info(&self) -> Option<&http::CacheInfo> {
163 match self {
164 Self::None |
165 Self::Error(_) |
166 Self::Init { .. } => None,
167
168 Self::HaveMeta { cache_info, .. } |
169 Self::Downloading { cache_info, .. } |
170 Self::Complete { cache_info, .. } |
171 Self::Refresh { cache_info, .. } => Some(cache_info),
172 }
173 }
174
175 fn read_file<'a>(file: &std::fs::File, ofs: u64, buf: &'a mut [MaybeUninit<u8>], max: u64) -> Result<&'a [u8]> {
176 use nix::libc;
177
178 assert!(max > ofs);
179
180 let len = (buf.len() as u64).min(max - ofs) as usize;
181 let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
182
183 let rc = unsafe { libc::pread(file.as_raw_fd(), buf_ptr, len, ofs as i64) };
188
189 if rc < 0 {
190 return Err(std::io::Error::last_os_error().into());
191 }
192
193 assert_eq!(rc as usize, len);
194
195 Ok(unsafe { buf[..rc as usize].assume_init() })
196 }
197
198 pub fn read<'a>(&self, ofs: u64, buf: &'a mut [MaybeUninit<u8>]) -> Result<Option<&'a [u8]>> {
199 match &self {
200 State::Downloading { file, file_pos, .. } if ofs < *file_pos => {
201 Self::read_file(file, ofs, buf, *file_pos)
202 },
203
204 State::Complete { file, file_size, .. } if ofs < *file_size => {
205 Self::read_file(file, ofs, buf, *file_size)
206 }
207
208 State::Complete { file_size, .. } if ofs == *file_size => Ok(&[] as &[u8]),
209
210 State::Complete { file_size, .. } if ofs >= *file_size =>
211 Err(Error::Internal("file out-of-bound read")),
212
213 _ => return Ok(None)
214 }.map(Some)
215 }
216
217 pub fn is_outdated(&self, reftm: Instant, max_lt: Duration) -> bool {
218 match self.get_cache_info() {
219 None => true,
220 Some(info) => info.is_outdated(reftm, max_lt),
221 }
222 }
223}
224
/// A single cache entry: the URL it belongs to plus its download state.
#[derive(Debug)]
pub struct EntryData {
    /// URL the entry was created for.
    pub key: url::Url,
    // download state machine; see `State`
    state: State,
    // reference time taken at construction and refreshed by
    // `update_localtm()`; handed to `http::CacheInfo` when headers are
    // evaluated
    reftm: Time,
}
231
232impl std::fmt::Display for EntryData {
233 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234 f.write_fmt(format_args!("{}: reftm={}, state={}", self.key,
235 pretty(&self.reftm.local), pretty(&self.state)))
236 }
237}
238
239impl EntryData {
240 pub fn new(url: &url::Url) -> Self {
241 Self {
242 key: url.clone(),
243 state: State::None,
244 reftm: Time::now(),
245 }
246 }
247
248 pub fn is_complete(&self) -> bool {
249 self.state.is_complete()
250 }
251
252 pub fn is_error(&self) -> bool {
253 self.state.is_error()
254 }
255
256 pub fn is_running(&self) -> bool {
257 self.state.is_have_meta() || self.state.is_downloading()
258 }
259
260 pub fn update_localtm(&mut self) {
261 self.reftm = Time::now();
262 }
263
264 pub fn set_response(&mut self, response: reqwest::Response) {
265 self.state = match self.state.take("set_respone") {
266 State::None |
267 State::Error(_) => State::Init { response },
268
269 State::Complete { cache_info, file, file_size } |
270 State::Refresh { cache_info, file, file_size, .. } => State::Refresh {
271 cache_info: cache_info,
272 file: file,
273 file_size: file_size,
274 response: response,
275 },
276
277 s => panic!("unexpected state {s:?}"),
278 }
279 }
280
281 pub fn is_outdated(&self, reftm: Instant, max_lt: Duration) -> bool {
282 self.state.is_outdated(reftm, max_lt)
283 }
284
285 pub fn get_cache_info(&self) -> Option<&http::CacheInfo> {
286 self.state.get_cache_info()
287 }
288
289 pub async fn fill_meta(&mut self) -> Result<()> {
290 if !self.state.is_init() && !self.state.is_none() && !self.state.is_refresh() {
291 return Ok(());
292 }
293
294 self.state = match self.state.take("fill_meta") {
295 State::None => panic!("unexpected state"),
296
297 State::Init{ response } => {
298 let hdrs = response.headers();
299
300 State::HaveMeta {
301 cache_info: http::CacheInfo::new(self.reftm, hdrs)?,
302 file_size: response.content_length(),
303 response: response,
304 stats: Stats::default(),
305 }
306 },
307
308 State::Refresh { file, file_size, response, cache_info } => {
309 let hdrs = response.headers();
310
311 State::Complete {
312 cache_info: cache_info.update(self.reftm, hdrs)?,
313 file: file,
314 file_size: file_size,
315 }
316 },
317
318 _ => unreachable!(),
319 };
320
321 Ok(())
322 }
323
324 fn signal_complete(&self, stats: Stats) {
325 if let State::Complete { file_size, .. } = self.state {
326 info!("downloaded {} with {} bytes in {}ms", self.key, file_size, stats.tm.as_millis());
327 }
328 }
329
330 #[instrument(level = "trace")]
331 pub async fn get_filesize(&mut self) -> Result<u64> {
332 use std::io::Write;
333
334 if let Some(sz) = self.state.get_file_size() {
335 return Ok(sz);
336 }
337
338 match self.state.take("get_filesize") {
339 State::HaveMeta { mut response, file_size: None, mut stats, cache_info } => {
340 let mut file = Cache::new_file()?;
341 let mut pos = 0;
342
343
344 while let Some(chunk) = stats.chunk(&mut response).await? {
345 pos += chunk.len() as u64;
346 file.write_all(&chunk)?;
347 }
348
349 self.state = State::Complete {
350 file: file,
351 file_size: pos,
352 cache_info: cache_info,
353 };
354
355 self.signal_complete(stats);
356
357 Ok(pos)
358 },
359
360 State::Downloading { mut response, mut file, file_pos, file_size: None, mut stats, cache_info } => {
361 let mut pos = file_pos;
362
363 while let Some(chunk) = stats.chunk(&mut response).await? {
364 pos += chunk.len() as u64;
365 file.write_all(&chunk)?;
366 }
367
368 self.state = State::Complete {
369 file: file,
370 file_size: pos,
371 cache_info: cache_info,
372 };
373
374 self.signal_complete(stats);
375
376 Ok(pos)
377 }
378
379 s => panic!("unexpected state: {s:?}"),
380 }
381 }
382
383 pub fn fill_request(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
384 match self.state.get_cache_info() {
385 Some(info) => info.fill_request(self.reftm.mono, req),
386 None => req,
387 }
388 }
389
390 pub fn matches(&self, etag: Option<&str>) -> bool {
391 let cache_info = self.state.get_cache_info();
392
393 match cache_info.and_then(|c| c.not_after) {
394 Some(t) if t < Time::now().mono => return false,
395 _ => {},
396 }
397
398 let self_etag = match cache_info {
399 Some(c) => c.etag.as_ref(),
400 None => None,
401 };
402
403 match (self_etag, etag) {
404 (Some(a), Some(b)) if a == b => {},
405 (None, None) => {},
406 _ => return false,
407 }
408
409 true
410 }
411
412 pub fn invalidate(&mut self)
413 {
414 match &self.state {
415 State::Refresh { .. } => self.state = State::None,
416 State::Complete { .. } => self.state = State::None,
417 _ => {},
418 }
419 }
420
421 pub async fn read_some<'a>(&mut self, ofs: u64, buf: &'a mut [MaybeUninit<u8>]) -> Result<&'a [u8]>
422 {
423 use std::io::Write;
424
425 trace!("state={:?}, ofs={}, #buf={}", self.state, ofs, buf.len());
426
427 async fn fetch<'a>(response: &mut reqwest::Response, file: &mut std::fs::File,
428 buf: &'a mut [MaybeUninit<u8>], stats: &mut Stats) -> Result<(&'a[u8], usize)> {
429 match stats.chunk(response).await? {
430 Some(data) => {
431 let len = buf.len().min(data.len());
432 let buf = buf[0..len].write_copy_of_slice_x(&data.as_ref()[0..len]);
433
434 file.write_all(&data)?;
435
436 file.flush()?;
438
439 Ok((buf, data.len()))
440 },
441
442 None => Ok((&[], 0))
443 }
444 }
445
446 if self.state.is_init() {
447 self.fill_meta().await?;
448 }
449
450 if let Some(data) = self.state.read(ofs, buf)? {
451 let data = {
453 let len = data.len();
454 unsafe { buf[..len].assume_init() }
455 };
456
457 return Ok(data);
458 }
459
460 match self.state.take("read_some") {
461 State::HaveMeta { mut response, cache_info, file_size, mut stats } => {
462 let mut file = Cache::new_file()?;
463
464 let res = fetch(&mut response, &mut file, buf, &mut stats).await?;
465
466 self.state = match res {
467 (_, 0) => State::Complete {
468 cache_info: cache_info,
469 file: file,
470 file_size: 0,
471 },
472
473 (_, sz) => State::Downloading {
474 response: response,
475 cache_info: cache_info,
476 file_size: file_size,
477 file: file,
478 file_pos: sz as u64,
479 stats: stats,
480 }
481 };
482
483 self.signal_complete(stats);
484
485 Ok(res.0)
486 },
487
488 State::Downloading { file_pos, .. } if ofs < file_pos => unreachable!(),
490
491 State::Downloading { mut response, cache_info, file_size, mut file, file_pos, mut stats } => {
492 let res = fetch(&mut response, &mut file, buf, &mut stats).await?;
493
494 self.state = match res {
495 (_, 0) => State::Complete {
496 cache_info: cache_info,
497 file: file,
498 file_size: file_pos,
499 },
500
501 (_, sz) => State::Downloading {
502 response: response,
503 cache_info: cache_info,
504 file_size: file_size,
505 file: file,
506 file_pos: file_pos + (sz as u64),
507 stats: stats,
508 }
509 };
510
511 self.signal_complete(stats);
512
513 Ok(res.0)
514 }
515
516 s => panic!("unexpected state: {s:?}"),
517 }
518
519 }
520}
521
/// Shared handle to a cache entry, protected by the async (tokio) `RwLock`.
pub type Entry = Arc<RwLock<EntryData>>;
523
/// State behind the global `CACHE` singleton.
struct CacheImpl {
    // directory in which `Cache::new_file()` creates the backing files
    tmpdir: std::path::PathBuf,
    // registered entries, keyed by their URL
    entries: HashMap<url::Url, Entry>,
    // HTTP client handed out by `get_client()`
    client: Arc<reqwest::Client>,
    // number of `Cache::instanciate()` calls not yet matched by `Cache::close()`
    refcnt: u32,

    // sender used by `Cache::close()` to tell the gc task to stop
    abort_ch: Option<tokio::sync::watch::Sender<()>>,
    // handle of the spawned gc task
    gc: Option<tokio::task::JoinHandle<()>>,
}
533
/// Outcome of a cache lookup: the entry which was found, or a marker that
/// there is none.
// NOTE(review): not referenced in the visible part of this file.
pub enum LookupResult {
    Found(Entry),
    Missing,
}
538
539impl CacheImpl {
540 fn new() -> Self {
541 let client = reqwest::Client::builder()
542 .read_timeout(READ_TIMEOUT)
543 .connect_timeout(CONNECT_TIMEOUT)
544 .tcp_keepalive(KEEPALIVE)
545 .build()
546 .unwrap();
547
548 Self {
549 tmpdir: std::env::temp_dir(),
550 entries: HashMap::new(),
551 client: Arc::new(client),
552 abort_ch: None,
553 refcnt: 0,
554 gc: None,
555 }
556 }
557
558 pub fn is_empty(&self) -> bool {
559 self.entries.is_empty()
560 }
561
562 pub fn clear(&mut self) {
563 self.entries.clear();
564 }
565
566 pub fn get_client(&self) -> Arc<reqwest::Client> {
567 self.client.clone()
568 }
569
570 pub fn lookup_or_create(&mut self, key: &url::Url) -> Entry {
571 match self.entries.get(key) {
572 Some(v) => v.clone(),
573 None => self.create(key),
574 }
575 }
576
577 pub fn create(&mut self, key: &url::Url) -> Entry {
578 Entry::new(RwLock::new(EntryData::new(key)))
579 }
580
581 pub fn replace(&mut self, key: &url::Url, entry: &Entry) {
582 self.entries.insert(key.clone(), entry.clone());
583 }
584
585 pub fn remove(&mut self, key: &url::Url) {
586 self.entries.remove(key);
587 }
588
589 pub fn gc_oldest(&mut self, mut num: usize) {
591 if num == 0 {
592 return;
593 }
594
595 let mut tmp = Vec::with_capacity(self.entries.len());
596
597 for (key, e) in &self.entries {
598 let entry = match e.try_read() {
599 Ok(e) => e,
600 _ => continue,
601 };
602
603 tmp.push((key.clone(), entry.get_cache_info().map(|c| c.local_time)));
604 }
605
606 tmp.sort_by(|(_, tm_a), (_, tm_b)| tm_a.cmp(tm_b));
607
608 let mut rm_cnt = 0;
609
610 for (key, _) in tmp {
611 if num == 0 {
612 break;
613 }
614
615 debug!("gc: removing old {}", key);
616 self.entries.remove(&key);
617 num -= 1;
618 rm_cnt += 1;
619 }
620
621 if rm_cnt > 0 {
622 info!("gc: removed {} old entries", rm_cnt);
623 }
624 }
625
626 pub fn gc_outdated(&mut self, max_lt: Duration) -> usize {
630 let now = Instant::now();
631 let mut outdated = Vec::new();
632 let mut cnt = 0;
633
634 for (key, e) in &self.entries {
635 match e.try_read().map(|v| v.is_outdated(now, max_lt)) {
636 Ok(true) => outdated.push(key.clone()),
637 _ => cnt += 1,
638 }
639 }
640
641 let rm_cnt = outdated.len();
642
643 for e in outdated {
644 debug!("gc: removing outdated {}", e);
645 self.entries.remove(&e);
646 }
647
648 if rm_cnt > 0 {
649 info!("gc: removed {} obsolete entries", rm_cnt);
650 }
651
652 cnt
653 }
654}
655
/// Settings of the background garbage collector (`gc_runner()`).
#[derive(Debug)]
pub struct GcProperties {
    /// Limit against which the number of entries left after removing the
    /// outdated ones is compared.
    pub max_elements: usize,
    /// Lifetime passed to the `is_outdated()` checks of the entries.
    pub max_lifetime: Duration,
    /// Pause between two garbage collection runs.
    pub sleep: Duration,
}
662
663async fn gc_runner(props: GcProperties, mut abort_ch: tokio::sync::watch::Receiver<()>) {
664 const RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(1);
665
666 loop {
667 use std::sync::TryLockError;
668
669 let sleep = {
670 let cache = CACHE.try_write();
671
672 match cache {
673 Ok(mut cache) if !cache.is_empty() => {
674 let cache_cnt = cache.gc_outdated(props.max_lifetime);
675
676 if cache_cnt > props.max_elements {
677 cache.gc_oldest(props.max_elements - cache_cnt)
678 }
679
680 props.sleep
681 }
682 Ok(_) => props.sleep,
683 Err(TryLockError::WouldBlock) => RETRY_DELAY,
684 Err(e) => {
685 error!("cache gc failed with {:?}", e);
686 break;
687 }
688 }
689 };
690
691 if tokio::time::timeout(sleep, abort_ch.changed()).await.is_ok() {
692 debug!("cache gc runner gracefully closed");
693 break;
694 }
695 }
696}
697
/// Namespace for the public cache API; all functions operate on the global
/// `CACHE` singleton.
pub struct Cache();
699
700impl Cache {
701 #[instrument(level = "trace")]
702 pub fn instanciate(tmpdir: &std::path::Path, props: GcProperties) {
703 let mut cache = CACHE.write().unwrap();
704
705 if cache.refcnt == 0 {
706 let (tx, rx) = tokio::sync::watch::channel(());
707
708 cache.tmpdir = tmpdir.into();
709 cache.abort_ch = Some(tx);
710
711 cache.gc = Some(tokio::task::spawn(gc_runner(props, rx)));
712 }
713
714 cache.refcnt += 1;
715 }
716
717 #[instrument(level = "trace")]
718 #[allow(clippy::await_holding_lock)]
720 pub async fn close() {
721 let mut cache = CACHE.write().unwrap();
722
723 assert!(cache.refcnt > 0);
724
725 cache.refcnt -= 1;
726
727 if cache.refcnt == 0 {
728 cache.entries.clear();
729
730 let abort_ch = cache.abort_ch.take().unwrap();
731 let gc = cache.gc.take().unwrap();
732
733 drop(cache);
734
735 abort_ch.send(()).unwrap();
736 gc.await.unwrap();
737 }
738 }
739
740 #[instrument(level = "trace", ret)]
741 pub fn lookup_or_create(key: &url::Url) -> Entry {
742 let mut cache = CACHE.write().unwrap();
743
744 cache.lookup_or_create(key)
745 }
746
747 #[instrument(level = "trace", ret)]
748 pub fn create(key: &url::Url) -> Entry {
749 let mut cache = CACHE.write().unwrap();
750
751 cache.create(key)
752 }
753
754 #[instrument(level = "trace")]
755 pub fn replace(key: &url::Url, entry: &Entry) {
756 let mut cache = CACHE.write().unwrap();
757
758 cache.replace(key, entry)
759 }
760
761 #[instrument(level = "trace")]
762 pub fn remove(key: &url::Url) {
763 let mut cache = CACHE.write().unwrap();
764
765 cache.remove(key)
766 }
767
768 pub fn get_client() -> Arc<reqwest::Client> {
769 let cache = CACHE.read().unwrap();
770
771 cache.get_client()
772 }
773
774 pub fn new_file() -> Result<std::fs::File> {
775 let cache = CACHE.read().unwrap();
776
777 Ok(tempfile::tempfile_in(&cache.tmpdir)?)
778 }
779
780 pub async fn dump() {
781 let mut entries = Vec::new();
782
783 {
784 let cache = CACHE.read().unwrap();
785
786 entries.reserve(cache.entries.len());
787 entries.extend(cache.entries.values().cloned());
788 }
789
790 println!("Cache information ({} entries)", entries.len());
791
792 for e in entries {
793 println!("{}", e.read().await);
794 }
795 }
796
797 pub async fn clear() {
798 let mut cache = CACHE.write().unwrap();
799
800 cache.clear();
801 }
802}