r_tftpd_proxy/
cache.rs

1use std::collections::HashMap;
2use std::mem::MaybeUninit;
3use std::os::unix::prelude::AsRawFd;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use tokio::sync::RwLock;
8
9use super::http;
10use http::Time;
11
12use crate::{ Result, Error };
13use crate::util::{pretty_dump_wrap as pretty, AsInit, CopyInit};
14
/// Read timeout handed to the HTTP client builder.
const READ_TIMEOUT: Duration = Duration::from_secs(30);
/// Connect timeout handed to the HTTP client builder.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// TCP keepalive interval handed to the HTTP client builder.
const KEEPALIVE: Duration = Duration::from_secs(300);
18
// Process-wide cache singleton.  It is guarded by a blocking
// `std::sync::RwLock`, whereas the individual entries are wrapped in the
// async `tokio::sync::RwLock`.
lazy_static::lazy_static!{
    static ref CACHE: std::sync::RwLock<CacheImpl> = std::sync::RwLock::new(CacheImpl::new());
}
22
/// Timing statistics of a download.
#[derive(Clone, Copy, Debug, Default)]
struct Stats {
    /// Accumulated time spent waiting for response chunks in `Stats::chunk()`.
    pub tm:		Duration,
}
27
28impl Stats {
29    pub async fn chunk(&mut self, response: &mut reqwest::Response) -> reqwest::Result<Option<bytes::Bytes>>
30    {
31	let start = std::time::Instant::now();
32	let chunk = response.chunk().await;
33	self.tm += start.elapsed();
34
35	chunk
36    }
37}
38
39impl crate::util::PrettyDump for Stats {
40    fn pretty_dump(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.write_fmt(format_args!("{:.3}s", self.tm.as_secs_f32()))
42    }
43}
44
/// Download/caching state of a single cache entry.
#[derive(Debug)]
enum State {
    /// No response attached; this is the state of a freshly created entry
    /// and of an invalidated one.
    None,

    /// Placeholder left behind by `State::take()`.  The string names the
    /// operation which took the state; it stays visible when that
    /// operation returned early without storing a successor state.
    Error(&'static str),

    /// A response has been attached but its headers were not evaluated yet.
    Init {
	response:	reqwest::Response,
    },

    /// Headers were evaluated by `EntryData::fill_meta()`; no body data has
    /// been written to a file yet.
    HaveMeta {
	response:	reqwest::Response,
	cache_info:	http::CacheInfo,
	// content length reported by the response, if any
	file_size:	Option<u64>,
	stats:		Stats,
    },

    /// Body is being streamed into `file`; `file_pos` is the number of bytes
    /// written so far.
    Downloading {
	response:	reqwest::Response,
	cache_info:	http::CacheInfo,
	// content length reported by the response, if any
	file_size:	Option<u64>,
	file:		std::fs::File,
	file_pos:	u64,
	stats:		Stats,
    },

    /// The whole body is available in `file`.
    Complete {
	cache_info:	http::CacheInfo,
	file:		std::fs::File,
	file_size:	u64,
    },

    /// Like `Complete`, but a new response has been attached whose headers
    /// still have to be applied to `cache_info` by `EntryData::fill_meta()`.
    Refresh {
	response:	reqwest::Response,
	cache_info:	http::CacheInfo,
	file:		std::fs::File,
	file_size:	u64,
    },
}
84
85impl crate::util::PrettyDump for State {
86    fn pretty_dump(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            State::None =>
89		f.write_str("no state"),
90            State::Error(e) =>
91		f.write_fmt(format_args!("error {e:?}")),
92            State::Init { response } =>
93		f.write_fmt(format_args!("INIT({})", pretty(response))),
94            State::HaveMeta { response, cache_info, file_size, stats } =>
95		f.write_fmt(format_args!("META({}, {}, {}, {})",
96					 pretty(response), pretty(cache_info),
97					 pretty(file_size), pretty(stats))),
98            State::Downloading { response, cache_info, file_size, file, file_pos, stats } =>
99		f.write_fmt(format_args!("DOWNLOADING({}, {}, {}, {}@{}, {})",
100					 pretty(response), pretty(cache_info),
101					 pretty(file_size), pretty(file),
102					 file_pos, pretty(stats))),
103            State::Complete { cache_info, file, file_size } =>
104		f.write_fmt(format_args!("COMPLETE({}, {}/{})",
105					 pretty(cache_info), pretty(file), file_size)),
106            State::Refresh { response, cache_info, file, file_size } =>
107		f.write_fmt(format_args!("REFRESH({},.{}, [{}/{}])",
108					 pretty(response), pretty(cache_info), pretty(file),
109					 file_size)),
110        }
111    }
112}
113
114impl State {
115    pub fn take(&mut self, hint: &'static str) -> Self {
116	std::mem::replace(self, State::Error(hint))
117    }
118
119    pub fn is_none(&self) -> bool {
120	matches!(self, Self::None)
121    }
122
123    pub fn is_init(&self) -> bool {
124	matches!(self, Self::Init { .. })
125    }
126
127    pub fn is_error(&self) -> bool {
128	matches!(self, Self::Error(_))
129    }
130
131    pub fn is_refresh(&self) -> bool {
132	matches!(self, Self::Refresh { .. })
133    }
134
135    pub fn is_have_meta(&self) -> bool {
136	matches!(self, Self::HaveMeta { .. })
137    }
138
139    pub fn is_downloading(&self) -> bool {
140	matches!(self, Self::Downloading { .. })
141    }
142
143    pub fn is_complete(&self) -> bool {
144	matches!(self, Self::Complete { .. })
145    }
146
147    pub fn get_file_size(&self) -> Option<u64> {
148	match self {
149	    Self::None |
150	    Self::Init { .. }	=> None,
151
152	    Self::HaveMeta { file_size, .. }	=> *file_size,
153	    Self::Downloading { file_size, .. }	=> *file_size,
154
155	    Self::Complete { file_size, .. }	=> Some(*file_size),
156	    Self::Refresh { file_size, .. }	=> Some(*file_size),
157
158	    Self::Error(hint)	=> panic!("get_file_size called in error state ({hint})"),
159	}
160    }
161
162    pub fn get_cache_info(&self) -> Option<&http::CacheInfo> {
163	match self {
164	    Self::None |
165	    Self::Error(_) |
166	    Self::Init { .. }	=> None,
167
168	    Self::HaveMeta { cache_info, .. } |
169	    Self::Downloading { cache_info, .. } |
170	    Self::Complete { cache_info, .. } |
171	    Self::Refresh { cache_info, .. }	=> Some(cache_info),
172	}
173    }
174
175    fn read_file<'a>(file: &std::fs::File, ofs: u64, buf: &'a mut [MaybeUninit<u8>], max: u64) -> Result<&'a [u8]> {
176	use nix::libc;
177
178	assert!(max > ofs);
179
180	let len = (buf.len() as u64).min(max - ofs) as usize;
181	let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
182
183	// TODO: this would be nice, but does not work because we can not get
184	// a mutable reference to 'file'
185	//file.flush()?;
186
187	let rc = unsafe { libc::pread(file.as_raw_fd(), buf_ptr, len, ofs as i64) };
188
189	if rc < 0 {
190	    return Err(std::io::Error::last_os_error().into());
191	}
192
193        assert_eq!(rc as usize, len);
194
195	Ok(unsafe { buf[..rc as usize].assume_init() })
196    }
197
198    pub fn read<'a>(&self, ofs: u64, buf: &'a mut [MaybeUninit<u8>]) -> Result<Option<&'a [u8]>> {
199	match &self {
200	    State::Downloading { file, file_pos, .. } if ofs < *file_pos	=> {
201		Self::read_file(file, ofs, buf, *file_pos)
202	    },
203
204	    State::Complete { file, file_size, .. } if ofs < *file_size		=> {
205		Self::read_file(file, ofs, buf, *file_size)
206	    }
207
208	    State::Complete { file_size, .. } if ofs == *file_size	=> Ok(&[] as &[u8]),
209
210	    State::Complete { file_size, .. } if ofs >= *file_size	=>
211		Err(Error::Internal("file out-of-bound read")),
212
213	    _	=> return Ok(None)
214	}.map(Some)
215    }
216
217    pub fn is_outdated(&self, reftm: Instant, max_lt: Duration) -> bool {
218	match self.get_cache_info() {
219	    None	=> true,
220	    Some(info)	=> info.is_outdated(reftm, max_lt),
221	}
222    }
223}
224
/// Payload of a cache entry: the URL it belongs to plus its download state.
#[derive(Debug)]
pub struct EntryData {
    /// URL this entry was created for.
    pub key:		url::Url,
    /// Current download/caching state.
    state:		State,
    /// Reference time; set on creation and by `update_localtm()`.
    reftm:		Time,
}
231
232impl std::fmt::Display for EntryData {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234	f.write_fmt(format_args!("{}: reftm={}, state={}", self.key,
235		    pretty(&self.reftm.local), pretty(&self.state)))
236    }
237}
238
239impl EntryData {
240    pub fn new(url: &url::Url) -> Self {
241	Self {
242	    key:		url.clone(),
243	    state:		State::None,
244	    reftm:		Time::now(),
245	}
246    }
247
248    pub fn is_complete(&self) -> bool {
249	self.state.is_complete()
250    }
251
252    pub fn is_error(&self) -> bool {
253	self.state.is_error()
254    }
255
256    pub fn is_running(&self) -> bool {
257	self.state.is_have_meta() || self.state.is_downloading()
258    }
259
260    pub fn update_localtm(&mut self) {
261	self.reftm = Time::now();
262    }
263
264    pub fn set_response(&mut self, response: reqwest::Response) {
265	self.state = match self.state.take("set_respone") {
266	    State::None |
267	    State::Error(_)	=> State::Init { response },
268
269	    State::Complete { cache_info, file, file_size } |
270	    State::Refresh { cache_info, file, file_size, .. } => State::Refresh {
271		cache_info:	cache_info,
272		file:		file,
273		file_size:	file_size,
274		response:	response,
275	    },
276
277	    s			=> panic!("unexpected state {s:?}"),
278	}
279    }
280
281    pub fn is_outdated(&self, reftm: Instant, max_lt: Duration) -> bool {
282	self.state.is_outdated(reftm, max_lt)
283    }
284
285    pub fn get_cache_info(&self) -> Option<&http::CacheInfo> {
286	self.state.get_cache_info()
287    }
288
289    pub async fn fill_meta(&mut self) -> Result<()> {
290	if !self.state.is_init() && !self.state.is_none() && !self.state.is_refresh() {
291	    return Ok(());
292	}
293
294	self.state = match self.state.take("fill_meta") {
295	    State::None			=> panic!("unexpected state"),
296
297	    State::Init{ response }	=> {
298		let hdrs = response.headers();
299
300		State::HaveMeta {
301		    cache_info:	http::CacheInfo::new(self.reftm, hdrs)?,
302		    file_size:	response.content_length(),
303		    response:	response,
304		    stats:	Stats::default(),
305		}
306	    },
307
308	    State::Refresh { file, file_size, response, cache_info }	=> {
309		let hdrs = response.headers();
310
311		State::Complete {
312		    cache_info:	cache_info.update(self.reftm, hdrs)?,
313		    file:	file,
314		    file_size:	file_size,
315		}
316	    },
317
318	    _				=> unreachable!(),
319	};
320
321	Ok(())
322    }
323
324    fn signal_complete(&self, stats: Stats) {
325	if let State::Complete { file_size, .. } = self.state {
326	    info!("downloaded {} with {} bytes in {}ms", self.key, file_size, stats.tm.as_millis());
327	}
328    }
329
330    #[instrument(level = "trace")]
331    pub async fn get_filesize(&mut self) -> Result<u64> {
332	use std::io::Write;
333
334	if let Some(sz) = self.state.get_file_size() {
335	    return Ok(sz);
336	}
337
338	match self.state.take("get_filesize") {
339	    State::HaveMeta { mut response, file_size: None, mut stats, cache_info }	=> {
340		let mut file = Cache::new_file()?;
341		let mut pos = 0;
342
343
344		while let Some(chunk) = stats.chunk(&mut response).await? {
345		    pos += chunk.len() as u64;
346		    file.write_all(&chunk)?;
347		}
348
349		self.state = State::Complete {
350		    file:	file,
351		    file_size:	pos,
352		    cache_info:	cache_info,
353		};
354
355		self.signal_complete(stats);
356
357		Ok(pos)
358	    },
359
360	    State::Downloading { mut response, mut file, file_pos, file_size: None, mut stats, cache_info } => {
361		let mut pos = file_pos;
362
363		while let Some(chunk) = stats.chunk(&mut response).await? {
364		    pos += chunk.len() as u64;
365		    file.write_all(&chunk)?;
366		}
367
368		self.state = State::Complete {
369		    file:	file,
370		    file_size:	pos,
371		    cache_info:	cache_info,
372		};
373
374		self.signal_complete(stats);
375
376		Ok(pos)
377	    }
378
379	    s		=> panic!("unexpected state: {s:?}"),
380	}
381    }
382
383    pub fn fill_request(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
384	match self.state.get_cache_info() {
385	    Some(info)	=> info.fill_request(self.reftm.mono, req),
386	    None	=> req,
387	}
388    }
389
390    pub fn matches(&self, etag: Option<&str>) -> bool {
391	let cache_info = self.state.get_cache_info();
392
393	match cache_info.and_then(|c| c.not_after) {
394	    Some(t) if t < Time::now().mono		=> return false,
395	    _						=> {},
396	}
397
398	let self_etag = match cache_info {
399	    Some(c)	=> c.etag.as_ref(),
400	    None	=> None,
401	};
402
403	match (self_etag, etag) {
404	    (Some(a), Some(b)) if a == b		=> {},
405	    (None, None)				=> {},
406	    _						=> return false,
407	}
408
409	true
410    }
411
412    pub fn invalidate(&mut self)
413    {
414	match &self.state {
415	    State::Refresh { .. }	=> self.state = State::None,
416	    State::Complete { .. }	=> self.state = State::None,
417	    _				=> {},
418	}
419    }
420
421    pub async fn read_some<'a>(&mut self, ofs: u64, buf: &'a mut [MaybeUninit<u8>]) -> Result<&'a [u8]>
422    {
423	use std::io::Write;
424
425	trace!("state={:?}, ofs={}, #buf={}", self.state, ofs, buf.len());
426
427	async fn fetch<'a>(response: &mut reqwest::Response, file: &mut std::fs::File,
428		       buf: &'a mut [MaybeUninit<u8>], stats: &mut Stats) -> Result<(&'a[u8], usize)> {
429	    match stats.chunk(response).await? {
430		Some(data)	=> {
431		    let len = buf.len().min(data.len());
432                    let buf = buf[0..len].write_copy_of_slice_x(&data.as_ref()[0..len]);
433
434		    file.write_all(&data)?;
435
436		    // TODO: it would be better to do this in State::read_file()
437		    file.flush()?;
438
439		    Ok((buf, data.len()))
440		},
441
442		None		=> Ok((&[], 0))
443	    }
444	}
445
446	if self.state.is_init() {
447	    self.fill_meta().await?;
448	}
449
450	if let Some(data) = self.state.read(ofs, buf)? {
451            // TODO: this workarounds a bogus E0499 error; revisit after polonius
452            let data = {
453                let len = data.len();
454                unsafe { buf[..len].assume_init() }
455            };
456
457	    return Ok(data);
458	}
459
460	match self.state.take("read_some") {
461	    State::HaveMeta { mut response, cache_info, file_size, mut stats }	=> {
462		let mut file = Cache::new_file()?;
463
464		let res = fetch(&mut response, &mut file, buf, &mut stats).await?;
465
466		self.state = match res {
467		    (_, 0)	=> State::Complete {
468			cache_info:	cache_info,
469			file:		file,
470			file_size:	0,
471		    },
472
473		    (_, sz)	=> State::Downloading {
474			response:	response,
475			cache_info:	cache_info,
476			file_size:	file_size,
477			file:		file,
478			file_pos:	sz as u64,
479			stats:		stats,
480		    }
481		};
482
483		self.signal_complete(stats);
484
485		Ok(res.0)
486	    },
487
488	    // catched by 'self.state.read()' above
489	    State::Downloading { file_pos, .. } if ofs < file_pos	=> unreachable!(),
490
491	    State::Downloading { mut response, cache_info, file_size, mut file, file_pos, mut stats } => {
492		let res = fetch(&mut response, &mut file, buf, &mut stats).await?;
493
494		self.state = match res {
495		    (_, 0)	=> State::Complete {
496			cache_info:	cache_info,
497			file:		file,
498			file_size:	file_pos,
499		    },
500
501		    (_, sz)	=> State::Downloading {
502			response:	response,
503			cache_info:	cache_info,
504			file_size:	file_size,
505			file:		file,
506			file_pos:	file_pos + (sz as u64),
507			stats:		stats,
508		    }
509		};
510
511		self.signal_complete(stats);
512
513		Ok(res.0)
514	    }
515
516	    s		=> panic!("unexpected state: {s:?}"),
517	}
518
519    }
520}
521
/// Shared handle to a cache entry, guarded by the async `tokio::sync::RwLock`.
pub type Entry = Arc<RwLock<EntryData>>;
523
/// State behind the global `CACHE` singleton.
struct CacheImpl {
    /// Directory in which the cache files are created.
    tmpdir:	std::path::PathBuf,
    /// Cache entries, keyed by their URL.
    entries:	HashMap<url::Url, Entry>,
    /// HTTP client handed out by `get_client()`.
    client:	Arc<reqwest::Client>,
    /// Number of `Cache::instanciate()` calls not yet balanced by
    /// `Cache::close()`.
    refcnt:	u32,

    /// Sender used to tell the gc task to terminate.
    abort_ch:	Option<tokio::sync::watch::Sender<()>>,
    /// Handle of the spawned gc task.
    gc:		Option<tokio::task::JoinHandle<()>>,
}
533
/// Result of a cache lookup.
///
/// NOTE(review): not referenced by the code visible in this file.
pub enum LookupResult {
    /// An entry exists.
    Found(Entry),
    /// No entry exists.
    Missing,
}
538
539impl CacheImpl {
540    fn new() -> Self {
541        let client = reqwest::Client::builder()
542            .read_timeout(READ_TIMEOUT)
543            .connect_timeout(CONNECT_TIMEOUT)
544            .tcp_keepalive(KEEPALIVE)
545            .build()
546            .unwrap();
547
548	Self {
549	    tmpdir:	std::env::temp_dir(),
550	    entries:	HashMap::new(),
551	    client:	Arc::new(client),
552	    abort_ch:	None,
553	    refcnt:	0,
554	    gc:		None,
555	}
556    }
557
558    pub fn is_empty(&self) -> bool {
559	self.entries.is_empty()
560    }
561
562    pub fn clear(&mut self) {
563	self.entries.clear();
564    }
565
566    pub fn get_client(&self) -> Arc<reqwest::Client> {
567	self.client.clone()
568    }
569
570    pub fn lookup_or_create(&mut self, key: &url::Url) -> Entry {
571	match self.entries.get(key) {
572	    Some(v)	=> v.clone(),
573	    None	=> self.create(key),
574	}
575    }
576
577    pub fn create(&mut self, key: &url::Url) -> Entry {
578	Entry::new(RwLock::new(EntryData::new(key)))
579    }
580
581    pub fn replace(&mut self, key: &url::Url, entry: &Entry) {
582	self.entries.insert(key.clone(), entry.clone());
583    }
584
585    pub fn remove(&mut self, key: &url::Url) {
586	self.entries.remove(key);
587    }
588
589    /// Removes the `num` oldest cache entries
590    pub fn gc_oldest(&mut self, mut num: usize) {
591	if num == 0 {
592	    return;
593	}
594
595	let mut tmp = Vec::with_capacity(self.entries.len());
596
597	for (key, e) in &self.entries {
598	    let entry = match e.try_read() {
599		Ok(e)	=> e,
600		_	=> continue,
601	    };
602
603	    tmp.push((key.clone(), entry.get_cache_info().map(|c| c.local_time)));
604	}
605
606	tmp.sort_by(|(_, tm_a), (_, tm_b)| tm_a.cmp(tm_b));
607
608	let mut rm_cnt = 0;
609
610	for (key, _) in tmp {
611	    if num == 0 {
612		break;
613	    }
614
615	    debug!("gc: removing old {}", key);
616	    self.entries.remove(&key);
617	    num -= 1;
618	    rm_cnt += 1;
619	}
620
621	if rm_cnt > 0 {
622	    info!("gc: removed {} old entries", rm_cnt);
623	}
624    }
625
626    /// Removes cache entries which are older than `max_lt`.
627    ///
628    /// Returns the number of remaining cache entries.
629    pub fn gc_outdated(&mut self, max_lt: Duration) -> usize {
630	let now = Instant::now();
631	let mut outdated = Vec::new();
632	let mut cnt = 0;
633
634	for (key, e) in &self.entries {
635	    match e.try_read().map(|v| v.is_outdated(now, max_lt)) {
636		Ok(true)	=> outdated.push(key.clone()),
637		_		=> cnt += 1,
638	    }
639	}
640
641	let rm_cnt = outdated.len();
642
643	for e in outdated {
644	    debug!("gc: removing outdated {}", e);
645	    self.entries.remove(&e);
646	}
647
648	if rm_cnt > 0 {
649	    info!("gc: removed {} obsolete entries", rm_cnt);
650	}
651
652	cnt
653    }
654}
655
/// Settings of the cache garbage collector.
#[derive(Debug)]
pub struct GcProperties {
    /// Number of entries above which the gc removes the oldest ones.
    pub max_elements:	usize,
    /// Lifetime passed to `CacheImpl::gc_outdated()`.
    pub max_lifetime:	Duration,
    /// Pause between two gc runs.
    pub sleep:		Duration,
}
662
663async fn gc_runner(props: GcProperties, mut abort_ch: tokio::sync::watch::Receiver<()>) {
664    const RETRY_DELAY: std::time::Duration = std::time::Duration::from_secs(1);
665
666    loop {
667	use std::sync::TryLockError;
668
669	let sleep = {
670	    let cache = CACHE.try_write();
671
672	    match cache {
673		Ok(mut cache) if !cache.is_empty()	=> {
674		    let cache_cnt = cache.gc_outdated(props.max_lifetime);
675
676		    if cache_cnt > props.max_elements {
677			cache.gc_oldest(props.max_elements - cache_cnt)
678		    }
679
680		    props.sleep
681		}
682		Ok(_)				=> props.sleep,
683		Err(TryLockError::WouldBlock)	=> RETRY_DELAY,
684		Err(e)				=> {
685		    error!("cache gc failed with {:?}", e);
686		    break;
687		}
688	    }
689	};
690
691	if tokio::time::timeout(sleep, abort_ch.changed()).await.is_ok() {
692	    debug!("cache gc runner gracefully closed");
693	    break;
694	}
695    }
696}
697
/// Public facade of the global cache; all functionality is provided by
/// associated functions which operate on the `CACHE` singleton.
pub struct Cache();
699
700impl Cache {
701    #[instrument(level = "trace")]
702    pub fn instanciate(tmpdir: &std::path::Path, props: GcProperties) {
703	let mut cache = CACHE.write().unwrap();
704
705	if cache.refcnt == 0 {
706	    let (tx, rx) = tokio::sync::watch::channel(());
707
708	    cache.tmpdir = tmpdir.into();
709	    cache.abort_ch = Some(tx);
710
711	    cache.gc = Some(tokio::task::spawn(gc_runner(props, rx)));
712	}
713
714	cache.refcnt += 1;
715    }
716
717    #[instrument(level = "trace")]
718    // https://github.com/rust-lang/rust-clippy/issues/6446
719    #[allow(clippy::await_holding_lock)]
720    pub async fn close() {
721	let mut cache = CACHE.write().unwrap();
722
723	assert!(cache.refcnt > 0);
724
725	cache.refcnt -= 1;
726
727	if cache.refcnt == 0 {
728	    cache.entries.clear();
729
730	    let abort_ch = cache.abort_ch.take().unwrap();
731	    let gc = cache.gc.take().unwrap();
732
733	    drop(cache);
734
735	    abort_ch.send(()).unwrap();
736	    gc.await.unwrap();
737	}
738    }
739
740    #[instrument(level = "trace", ret)]
741    pub fn lookup_or_create(key: &url::Url) -> Entry {
742	let mut cache = CACHE.write().unwrap();
743
744	cache.lookup_or_create(key)
745    }
746
747    #[instrument(level = "trace", ret)]
748    pub fn create(key: &url::Url) -> Entry {
749	let mut cache = CACHE.write().unwrap();
750
751	cache.create(key)
752    }
753
754    #[instrument(level = "trace")]
755    pub fn replace(key: &url::Url, entry: &Entry) {
756	let mut cache = CACHE.write().unwrap();
757
758	cache.replace(key, entry)
759    }
760
761    #[instrument(level = "trace")]
762    pub fn remove(key: &url::Url) {
763	let mut cache = CACHE.write().unwrap();
764
765	cache.remove(key)
766    }
767
768    pub fn get_client() -> Arc<reqwest::Client> {
769	let cache = CACHE.read().unwrap();
770
771	cache.get_client()
772    }
773
774    pub fn new_file() -> Result<std::fs::File> {
775	let cache = CACHE.read().unwrap();
776
777	Ok(tempfile::tempfile_in(&cache.tmpdir)?)
778    }
779
780    pub async fn dump() {
781	let mut entries = Vec::new();
782
783	{
784	    let cache = CACHE.read().unwrap();
785
786	    entries.reserve(cache.entries.len());
787	    entries.extend(cache.entries.values().cloned());
788	}
789
790	println!("Cache information ({} entries)", entries.len());
791
792	for e in entries {
793	    println!("{}", e.read().await);
794	}
795    }
796
797    pub async fn clear() {
798	let mut cache = CACHE.write().unwrap();
799
800	cache.clear();
801    }
802}