data_source/
lib.rs

1use std::{collections::HashMap, io, path::Path, time::SystemTime};
2
3use log::{debug, warn};
4
5#[derive(thiserror::Error, Debug)]
6pub enum FetchError {
7    #[cfg(feature = "reqwest")]
8    #[error("reqwest err")]
9    R(#[from] reqwest::Error),
10    #[error("io err")]
11    I(#[from] io::Error),
12    #[error("time err")]
13    T(#[from] std::time::SystemTimeError),
14    #[error("size limit exceed")]
15    S,
16    #[error("no cache file")]
17    NC,
18    #[error("not found")]
19    NF,
20    #[error("not found in directories `{0:?}`")]
21    NFD(Vec<String>),
22}
23
/// Configuration of an on-disk cache for fetched bytes.
#[derive(Clone, Debug)]
pub struct FileCache {
    /// Maximum age of the cache file in seconds before it counts as expired.
    /// With `None` an existing cache file never expires.
    pub update_interval_seconds: Option<u64>,
    /// Location of the cache file; `None` means no cache file is configured.
    pub cache_file_path: Option<String>,
}
29
30impl FileCache {
31    pub fn read_cache_file(&self) -> Result<Vec<u8>, FetchError> {
32        let cf = self.cache_file_path.as_ref().unwrap();
33        let s: Vec<u8> = std::fs::read(cf)?;
34        Ok(s)
35    }
36
37    #[cfg(feature = "tokio")]
38    pub async fn read_cache_file_async(&self) -> Result<Vec<u8>, FetchError> {
39        let cf = self.cache_file_path.as_ref().unwrap();
40
41        let content = tokio::fs::read(cf).await?;
42        Ok(content)
43    }
44
45    pub fn write_cache_file(&self, bytes: &[u8]) -> bool {
46        let cf = self.cache_file_path.as_ref().unwrap();
47        if let Err(err) = std::fs::write(cf, bytes) {
48            warn!("Failed to write cache file: {err}");
49            false
50        } else {
51            true
52        }
53    }
54
55    #[cfg(feature = "tokio")]
56    pub async fn write_cache_file_async(&self, bytes: &[u8]) -> bool {
57        let cf = self.cache_file_path.as_ref().unwrap();
58        if let Err(err) = tokio::fs::write(cf, bytes).await {
59            warn!("Failed to write cache file: {err}");
60            false
61        } else {
62            true
63        }
64    }
65
66    /// 检查缓存文件是否超时
67    pub fn is_cache_timeout(&self) -> Result<Option<bool>, FetchError> {
68        if let Some(cf) = &self.cache_file_path {
69            if std::fs::exists(cf)? {
70                let mut expired = false;
71                if let Some(interval) = self.update_interval_seconds {
72                    let metadata = std::fs::metadata(cf)?;
73                    let last_modified = metadata.modified()?;
74                    let elapsed = SystemTime::now().duration_since(last_modified)?.as_secs();
75                    expired = elapsed > interval;
76                }
77                return Ok(Some(expired));
78            }
79            Ok(None)
80        } else {
81            Ok(None)
82        }
83    }
84}
85
/// A source whose whole content can be fetched asynchronously as raw bytes.
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait AsyncSource: Send + Sync {
    /// Fetches the complete content of the source.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError>;
}
91
92pub trait SyncSource {
93    fn fetch(&self) -> Result<Vec<u8>, FetchError>;
94}
95
/// Returns the cached bytes while the cache file is still fresh; otherwise
/// fetches from `s` and, when a cache path is configured, stores the result.
///
/// A failed cache write does not fail the fetch; the fetched bytes are still
/// returned.
#[cfg(feature = "tokio")]
pub async fn fetch_with_cache_async(
    fc: &FileCache,
    s: &dyn AsyncSource,
) -> Result<Vec<u8>, FetchError> {
    // `Some(false)` means: the cache file exists and has not expired.
    let cache_is_fresh = matches!(fc.is_cache_timeout()?, Some(false));
    if cache_is_fresh {
        return fc.read_cache_file_async().await;
    }

    let fetched = s.fetch_async().await?;
    if fc.cache_file_path.is_some() {
        fc.write_cache_file_async(&fetched).await;
    }
    Ok(fetched)
}
111pub fn fetch_with_cache(fc: &FileCache, s: &dyn SyncSource) -> Result<Vec<u8>, FetchError> {
112    if fc.is_cache_timeout()?.is_some_and(|timeout| !timeout) {
113        fc.read_cache_file()
114    } else {
115        let d = s.fetch()?;
116        if fc.cache_file_path.is_some() {
117            fc.write_cache_file(&d);
118        }
119        Ok(d)
120    }
121}
122
/// A folder-like collection of files that can be queried asynchronously by name.
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait AsyncFolderSource: std::fmt::Debug {
    /// Looks up `file_name` and returns its bytes, optionally together with a
    /// string describing where the file was found.
    async fn get_file_content_async(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError>;
}
131
132pub trait SyncFolderSource: std::fmt::Debug {
133    fn get_file_content(
134        &self,
135        file_name: &std::path::Path,
136    ) -> Result<(Vec<u8>, Option<String>), FetchError>;
137}
138
/// A tar archive on disk; the wrapped string is the path passed to `File::open`.
#[cfg(feature = "tar")]
#[derive(Debug, Default, Clone)]
pub struct TarFile(pub String);
142
#[cfg(feature = "tar")]
impl SyncFolderSource for TarFile {
    /// Opens the archive file and extracts the entry named `file_name` from it.
    fn get_file_content(
        &self,
        file_name: &std::path::Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        let archive_file = std::fs::File::open(&self.0)?;
        get_file_from_tar_by_reader(file_name, archive_file)
    }
}
#[cfg(feature = "tokio-tar")]
#[async_trait::async_trait]
impl AsyncFolderSource for TarFile {
    /// Opens the archive file asynchronously and extracts the entry named
    /// `file_name` from it.
    async fn get_file_content_async(
        &self,
        file_name: &std::path::Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        let archive_file = tokio::fs::File::open(&self.0).await?;
        get_file_from_tar_by_reader_async(file_name, archive_file).await
    }
}
164
/// Description of a single HTTP(S) resource and how to request it.
#[cfg(feature = "reqwest")]
#[derive(Debug, Default, Clone)]
pub struct HttpSource {
    /// URL that is requested with GET.
    pub url: String,
    /// Proxy URL; when used it is installed for both HTTP and HTTPS traffic.
    pub proxy: Option<String>,
    /// Extra `(name, value)` headers added to every request.
    pub custom_request_headers: Option<Vec<(String, String)>>,
    /// `true`: use the proxy from the first attempt. `false`: try a direct
    /// request first and fall back to the proxy (if one is set) when it fails.
    pub should_use_proxy: bool,
    /// Upper bound on the response size in bytes; `None` disables the check.
    pub size_limit_bytes: Option<usize>,
}
174
#[cfg(feature = "reqwest")]
impl HttpSource {
    /// Sends a blocking GET request to `self.url` with the custom headers applied.
    pub fn get(
        &self,
        c: reqwest::blocking::Client,
    ) -> reqwest::Result<reqwest::blocking::Response> {
        self.custom_request_headers
            .iter()
            .flatten()
            .fold(c.get(&self.url), |request, (name, value)| {
                request.header(name, value)
            })
            .send()
    }

    /// Installs `self.proxy` on `cb` for both HTTPS and HTTP traffic.
    ///
    /// When no proxy is configured the builder is returned unchanged; this used
    /// to panic on the missing value.
    ///
    /// # Errors
    ///
    /// Returns an error when the proxy URL cannot be parsed.
    pub fn set_proxy(
        &self,
        mut cb: reqwest::blocking::ClientBuilder,
    ) -> reqwest::Result<reqwest::blocking::ClientBuilder> {
        let Some(proxy_url) = self.proxy.as_deref() else {
            return Ok(cb);
        };
        cb = cb.proxy(reqwest::Proxy::https(proxy_url)?);
        cb = cb.proxy(reqwest::Proxy::http(proxy_url)?);
        Ok(cb)
    }
}
#[cfg(feature = "reqwest")]
impl SyncSource for HttpSource {
    /// Downloads `self.url` with a blocking client.
    ///
    /// With `should_use_proxy` set (and a proxy configured) the proxy is used
    /// right away. Otherwise a direct request is tried first and, if it fails
    /// and a proxy is configured, the request is retried through the proxy.
    ///
    /// # Errors
    ///
    /// * [`FetchError::R`] when building the client or sending the request fails.
    /// * [`FetchError::S`] when `size_limit_bytes` is set and either the
    ///   announced content length or the actual body exceeds it.
    /// * [`FetchError::I`] when reading a size-limited body fails.
    fn fetch(&self) -> Result<Vec<u8>, FetchError> {
        use std::io::Read;

        let mut builder = reqwest::blocking::ClientBuilder::new();
        // Only touch the proxy when one is actually configured, so a missing
        // proxy degrades to a direct request instead of a panic.
        if self.should_use_proxy && self.proxy.is_some() {
            builder = self.set_proxy(builder)?;
        }

        let response = match self.get(builder.build()?) {
            Ok(response) => response,
            // The direct attempt failed: retry once through the proxy.
            Err(_) if !self.should_use_proxy && self.proxy.is_some() => {
                let builder = self.set_proxy(reqwest::blocking::ClientBuilder::new())?;
                self.get(builder.build()?)?
            }
            Err(e) => return Err(FetchError::R(e)),
        };

        let Some(size_limit) = self.size_limit_bytes else {
            return Ok(response.bytes()?.to_vec());
        };

        // Compare in u64 so a large content length is never truncated.
        let cap = u64::try_from(size_limit).unwrap_or(u64::MAX);
        if response.content_length().is_some_and(|len| len > cap) {
            return Err(FetchError::S);
        }

        // The content length may be absent (e.g. chunked encoding), so the
        // limit is also enforced on the body itself: read at most one byte
        // past the limit to detect an oversized response without buffering it.
        let mut body = Vec::new();
        response.take(cap.saturating_add(1)).read_to_end(&mut body)?;
        if body.len() > size_limit {
            return Err(FetchError::S);
        }
        Ok(body)
    }
}
235
#[cfg(feature = "tokio")]
#[cfg(feature = "reqwest")]
impl HttpSource {
    /// Sends an asynchronous GET request to `self.url` with the custom headers applied.
    pub async fn get_async(&self, client: reqwest::Client) -> reqwest::Result<reqwest::Response> {
        self.custom_request_headers
            .iter()
            .flatten()
            .fold(client.get(&self.url), |request, (key, value)| {
                request.header(key, value)
            })
            .send()
            .await
    }

    /// Installs `self.proxy` on `client_builder` for both HTTP and HTTPS traffic.
    ///
    /// When no proxy is configured the builder is returned unchanged; this used
    /// to panic on the missing value.
    ///
    /// # Errors
    ///
    /// Returns an error when the proxy URL cannot be parsed.
    pub fn set_proxy_async(
        &self,
        client_builder: reqwest::ClientBuilder,
    ) -> reqwest::Result<reqwest::ClientBuilder> {
        let Some(proxy_url) = self.proxy.as_deref() else {
            return Ok(client_builder);
        };
        Ok(client_builder
            .proxy(reqwest::Proxy::http(proxy_url)?)
            .proxy(reqwest::Proxy::https(proxy_url)?))
    }
}
259
#[cfg(feature = "tokio")]
#[cfg(feature = "reqwest")]
#[async_trait::async_trait]
impl AsyncSource for HttpSource {
    /// Downloads `self.url` with an asynchronous client.
    ///
    /// With `should_use_proxy` set (and a proxy configured) the proxy is used
    /// right away. Otherwise a direct request is tried first and, if it fails
    /// and a proxy is configured, the request is retried through the proxy.
    ///
    /// # Errors
    ///
    /// * [`FetchError::R`] when building the client, sending the request or
    ///   reading the body fails.
    /// * [`FetchError::S`] when `size_limit_bytes` is set and either the
    ///   announced content length or the actual body exceeds it.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError> {
        let mut client_builder = reqwest::ClientBuilder::new();
        // Only touch the proxy when one is actually configured, so a missing
        // proxy degrades to a direct request instead of a panic.
        if self.should_use_proxy && self.proxy.is_some() {
            client_builder = self.set_proxy_async(client_builder)?;
        }
        let client = client_builder.build()?;

        let mut response = match self.get_async(client).await {
            Ok(response) => response,
            // The direct attempt failed: retry once through the proxy.
            Err(_) if !self.should_use_proxy && self.proxy.is_some() => {
                let builder = self.set_proxy_async(reqwest::ClientBuilder::new())?;
                self.get_async(builder.build()?).await?
            }
            Err(e) => return Err(FetchError::R(e)),
        };

        let Some(size_limit) = self.size_limit_bytes else {
            return Ok(response.bytes().await?.to_vec());
        };

        // Compare in u64 so a large content length is never truncated.
        let cap = u64::try_from(size_limit).unwrap_or(u64::MAX);
        if response.content_length().is_some_and(|len| len > cap) {
            return Err(FetchError::S);
        }

        // The content length may be absent (e.g. chunked encoding), so the
        // limit is also enforced while the body is streamed in.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            // `body.len() <= size_limit` holds here, so the subtraction cannot underflow.
            if chunk.len() > size_limit - body.len() {
                return Err(FetchError::S);
            }
            body.extend_from_slice(&chunk);
        }
        Ok(body)
    }
}
300
/// Implemented by sources that may be able to name the location they read from.
pub trait GetPath {
    /// Returns the location as a string; the default implementation has none.
    fn get_path(&self) -> Option<String> {
        Option::None
    }
}
306
/// Where the content of one single file comes from.
#[derive(Debug)]
pub enum SingleFileSource {
    /// Downloaded over HTTP, with an optional on-disk cache in front of it.
    #[cfg(feature = "reqwest")]
    Http(HttpSource, FileCache),
    /// Read from the given path on the local file system.
    FilePath(String),
    /// The content itself, held in memory.
    Inline(Vec<u8>),
}
314impl Default for SingleFileSource {
315    fn default() -> Self {
316        Self::Inline(Vec::new())
317    }
318}
319
320impl GetPath for SingleFileSource {
321    fn get_path(&self) -> Option<String> {
322        match self {
323            #[cfg(feature = "reqwest")]
324            SingleFileSource::Http(http_source, _fc) => Some(http_source.url.clone()),
325            SingleFileSource::FilePath(p) => Some(p.clone()),
326            SingleFileSource::Inline(_ec) => None,
327        }
328    }
329}
330
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncSource for SingleFileSource {
    /// Fetches the content asynchronously: through the cache-aware HTTP path,
    /// from the file system, or by copying the inline bytes.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError> {
        match self {
            #[cfg(feature = "reqwest")]
            Self::Http(source, cache) => fetch_with_cache_async(cache, source).await,
            Self::FilePath(path) => Ok(tokio::fs::read(path).await?),
            Self::Inline(bytes) => Ok(bytes.to_vec()),
        }
    }
}
348
349impl SyncSource for SingleFileSource {
350    fn fetch(&self) -> Result<Vec<u8>, FetchError> {
351        match self {
352            #[cfg(feature = "reqwest")]
353            SingleFileSource::Http(http_source, fc) => fetch_with_cache(fc, http_source),
354            SingleFileSource::FilePath(f) => {
355                let s: Vec<u8> = std::fs::read(f)?;
356                Ok(s)
357            }
358            SingleFileSource::Inline(v) => Ok(v.clone()),
359        }
360    }
361}
362
363/// Defines where to get the content of the requested file name.
364///
365/// 很多配置中 都要再加载其他外部文件,
366/// FileSource 限定了 查找文件的 路径 和 来源, 读取文件时只会限制在这个范围内,
367/// 这样就增加了安全性
368#[derive(Debug, Default)]
369pub enum DataSource {
370    #[default]
371    StdReadFile,
372    ///从指定的一组路径来寻找文件
373    Folders(Vec<String>),
374    /// 从一个 已放到内存中的 tar 中 寻找文件
375    #[cfg(feature = "tar")]
376    TarInMemory(Vec<u8>),
377    #[cfg(feature = "tar")]
378    TarFile(TarFile),
379
380    /// 与其它方式不同,FileMap 存储名称的映射表, 无需遍历目录
381    FileMap(HashMap<String, SingleFileSource>),
382
383    Sync(Box<dyn SyncFolderSource + Send + Sync>),
384    #[cfg(feature = "tokio")]
385    Async(Box<dyn AsyncFolderSource + Send + Sync>),
386}
387
388impl DataSource {
389    pub fn insert_current_working_dir(&mut self) -> io::Result<()> {
390        if let DataSource::Folders(ref mut v) = self {
391            v.push(std::env::current_dir()?.to_string_lossy().to_string())
392        }
393        Ok(())
394    }
395
396    pub fn read_to_string<P>(&self, file_name: P) -> Result<String, FetchError>
397    where
398        P: AsRef<std::path::Path>,
399    {
400        let r = SyncFolderSource::get_file_content(self, file_name.as_ref())?;
401        Ok(String::from_utf8_lossy(r.0.as_slice()).to_string())
402    }
403}
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncFolderSource for DataSource {
    /// Returns the data that was read, possibly together with the location
    /// where the file was found.
    ///
    /// # Errors
    ///
    /// * [`FetchError::NFD`] when no folder of a `Folders` source has the file.
    /// * [`FetchError::NF`] when a `FileMap` has no entry for the name.
    /// * Any error reported by the underlying source.
    async fn get_file_content_async(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        match self {
            DataSource::Async(source) => source.get_file_content_async(file_name).await,

            DataSource::Sync(source) => source.get_file_content(file_name),

            #[cfg(feature = "tar")]
            DataSource::TarInMemory(tar_binary) => {
                get_file_from_tar_in_memory(file_name, tar_binary)
            }

            #[cfg(feature = "tokio-tar")]
            DataSource::TarFile(tf) => tf.get_file_content_async(file_name).await,
            // The variant exists whenever `tar` is enabled, so without
            // `tokio-tar` fall back to the blocking reader; otherwise this
            // match would not be exhaustive for that feature combination.
            #[cfg(all(feature = "tar", not(feature = "tokio-tar")))]
            DataSource::TarFile(tf) => tf.get_file_content(file_name),

            DataSource::Folders(possible_addrs) => {
                // NOTE(review): `Path::join` discards `dir` when `file_name` is
                // absolute, and `..` components are not filtered, so a lookup can
                // leave the listed folders. Confirm whether that is intended.
                for dir in possible_addrs {
                    let candidate = Path::new(dir).join(file_name);

                    // Non-blocking equivalent of `Path::exists`: any error
                    // counts as "does not exist".
                    if tokio::fs::try_exists(&candidate).await.unwrap_or(false) {
                        let bytes = tokio::fs::read(&candidate).await?;
                        return Ok((bytes, Some(dir.to_owned())));
                    }
                }
                Err(FetchError::NFD(possible_addrs.clone()))
            }

            DataSource::StdReadFile => {
                let bytes = tokio::fs::read(file_name).await?;
                Ok((bytes, None))
            }

            DataSource::FileMap(map) => {
                // Look up by `&str`; no owned `String` is needed for the key.
                let key = file_name.to_string_lossy();
                match map.get(&*key) {
                    Some(source) => source
                        .fetch_async()
                        .await
                        .map(|data| (data, source.get_path())),
                    None => Err(FetchError::NF),
                }
            }
        }
    }
}
451
452impl SyncFolderSource for DataSource {
453    /// 返回读到的 数据。可能还会返回 成功找到的路径
454    fn get_file_content(&self, file_name: &Path) -> Result<(Vec<u8>, Option<String>), FetchError> {
455        match self {
456            DataSource::Sync(source) => source.get_file_content(file_name),
457
458            #[cfg(feature = "tokio")]
459            DataSource::Async(source) => {
460                tokio::runtime::Handle::current().block_on(source.get_file_content_async(file_name))
461            }
462
463            #[cfg(feature = "tar")]
464            DataSource::TarInMemory(tar_binary) => {
465                get_file_from_tar_in_memory(file_name, tar_binary)
466            }
467            #[cfg(feature = "tar")]
468            DataSource::TarFile(tf) => tf.get_file_content(file_name),
469
470            DataSource::Folders(possible_addrs) => {
471                for dir in possible_addrs {
472                    let real_file_name = std::path::Path::new(dir).join(file_name);
473
474                    if real_file_name.exists() {
475                        return Ok(
476                            std::fs::read(&real_file_name).map(|v| (v, Some(dir.to_owned())))?
477                        );
478                    }
479                }
480                Err(FetchError::NFD(possible_addrs.clone()))
481            }
482            DataSource::StdReadFile => {
483                let s: Vec<u8> = std::fs::read(file_name)?;
484                Ok((s, None))
485            }
486
487            DataSource::FileMap(map) => {
488                let r = map.get(&file_name.to_string_lossy().to_string());
489
490                match r {
491                    Some(sf) => sf.fetch().map(|d| (d, sf.get_path())),
492                    None => Err(FetchError::NF),
493                }
494            }
495        }
496    }
497}
498
/// Asynchronously scans the tar stream in `reader` for the entry whose path
/// equals `file_name_in_tar` and returns its bytes together with that path.
///
/// # Errors
///
/// * [`FetchError::I`] when the archive or one of its entries cannot be read
///   (these cases used to panic).
/// * [`FetchError::NF`] when the archive has no such entry.
#[cfg(feature = "tokio-tar")]
pub async fn get_file_from_tar_by_reader_async<P, R>(
    file_name_in_tar: P,
    reader: R,
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
    R: tokio::io::AsyncRead + Unpin,
{
    use futures::StreamExt;
    use tokio::io::AsyncReadExt;

    let wanted = file_name_in_tar.as_ref();
    let mut archive = tokio_tar::Archive::new(reader);
    let mut entries = archive.entries()?;

    while let Some(entry) = entries.next().await {
        let mut entry = entry?;

        // Copy the path out first: the borrow of `entry` must end before its
        // content can be read.
        let path_in_tar = {
            let path = entry.path()?;
            if &*path != wanted {
                continue;
            }
            path.to_string_lossy().into_owned()
        };

        // `display()` cannot panic on non-UTF-8 names, unlike `to_str().unwrap()`.
        debug!("found {}", wanted.display());

        let mut content = Vec::new();
        entry.read_to_end(&mut content).await?;
        return Ok((content, Some(path_in_tar)));
    }
    Err(FetchError::NF)
}
/// Scans the tar stream in `reader` for the entry whose path equals
/// `file_name_in_tar` and returns its bytes together with that path.
///
/// # Errors
///
/// Returns [`FetchError::I`] when the archive or one of its entries cannot be
/// read (opening the archive used to panic, and broken entries used to be
/// reported as a missing file), and an I/O error of kind `NotFound` when the
/// archive has no such entry.
#[cfg(feature = "tar")]
pub fn get_file_from_tar_by_reader<P, R>(
    file_name_in_tar: P,
    reader: R,
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
    R: std::io::Read,
{
    use std::io::Read;

    let wanted = file_name_in_tar.as_ref();
    let mut archive = tar::Archive::new(reader);

    for entry in archive.entries()? {
        let mut entry = entry?;

        // Copy the path out first: the borrow of `entry` must end before its
        // content can be read.
        let path_in_tar = {
            let path = entry.path()?;
            if &*path != wanted {
                continue;
            }
            path.to_string_lossy().into_owned()
        };

        // `display()` cannot panic on non-UTF-8 names, unlike `to_str().unwrap()`.
        debug!("found {}", wanted.display());

        let mut content = Vec::new();
        entry.read_to_end(&mut content)?;
        return Ok((content, Some(path_in_tar)));
    }

    Err(FetchError::I(io::Error::new(
        io::ErrorKind::NotFound,
        format!(
            "get_file_from_tar: can't find the file, {}",
            wanted.display()
        ),
    )))
}
/// Looks for `file_name_in_tar` inside a tar archive that is held in memory.
///
/// `tar_binary` is the complete archive; existing callers that pass a
/// `&Vec<u8>` keep working through deref coercion. The result and the errors
/// are those of `get_file_from_tar_by_reader`.
#[cfg(feature = "tar")]
pub fn get_file_from_tar_in_memory<P>(
    file_name_in_tar: P,
    tar_binary: &[u8],
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
{
    // `display()` cannot panic on non-UTF-8 names, unlike `to_str().unwrap()`.
    debug!(
        "finding {} from tar, tar whole size is {}",
        file_name_in_tar.as_ref().display(),
        tar_binary.len()
    );
    get_file_from_tar_by_reader(file_name_in_tar, std::io::Cursor::new(tar_binary))
}
583
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    // Only the tar fixture needs these; gating them avoids unused-import
    // warnings when the `tar` feature is disabled.
    #[cfg(feature = "tar")]
    use std::{fs::File, io::Write, path::PathBuf};
    use tempfile::TempDir;

    #[cfg(feature = "reqwest")]
    use reqwest::blocking::Client;

    // Only the HTTP tests use the URL.
    #[cfg(feature = "reqwest")]
    const URL: &str = "https://www.rust-lang.org";

    /// Fetches a real page asynchronously; requires network access.
    #[cfg(feature = "tokio")]
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_http_source_fetch_async() {
        let http_source = HttpSource {
            url: URL.to_string(),
            should_use_proxy: false,
            ..Default::default()
        };

        let result = http_source.fetch_async().await;
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty());
    }

    /// Sends a real blocking request; requires network access.
    #[cfg(feature = "reqwest")]
    #[test]
    fn test_http_source_fetch() {
        let http_source = HttpSource {
            url: URL.to_string(),
            should_use_proxy: false,
            ..Default::default()
        };

        let client = Client::new();
        let result = http_source.get(client);
        assert!(result.is_ok());
    }

    #[test]
    fn test_data_source_read_from_folders() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        fs::write(&file_path, "hello world").unwrap();

        let data_source = DataSource::Folders(vec![temp_dir.path().to_string_lossy().to_string()]);

        let content = data_source.read_to_string("test.txt").unwrap();
        assert_eq!(content, "hello world");
    }

    #[test]
    fn test_data_source_read_from_file_map() {
        let file_map = vec![(
            "config.json".to_string(),
            SingleFileSource::Inline(b"{\"key\": \"value\"}".to_vec()),
        )]
        .into_iter()
        .collect();

        let data_source = DataSource::FileMap(file_map);

        let content = data_source.read_to_string("config.json").unwrap();
        assert_eq!(content, "{\"key\": \"value\"}");
    }

    /// Builds a tar archive with a single entry in a temporary directory.
    ///
    /// Returns the directory guard, the archive path, the entry name and the
    /// entry content.
    #[cfg(feature = "tar")]
    fn gentar() -> (TempDir, PathBuf, &'static str, &'static str) {
        let temp_dir = TempDir::new().unwrap();
        let tar_path = temp_dir.path().join("test.tar");

        let mut tar_builder = tar::Builder::new(File::create(&tar_path).unwrap());

        let mut file = tempfile::NamedTempFile::new().unwrap();
        let c = "hello tar\n";
        write!(file, "{}", c).unwrap();
        let file_path = file.path().to_owned();

        let tfn = "test.txt";

        tar_builder.append_path_with_name(&file_path, tfn).unwrap();
        tar_builder.finish().unwrap();

        (temp_dir, tar_path, tfn, c)
    }

    #[cfg(feature = "tar")]
    #[test]
    fn test_get_file_from_tar() {
        // The guard must not be bound to a bare `_`: that would drop it right
        // away, and dropping a `TempDir` deletes its contents.
        let (_td, tar_path, tfn, c) = gentar();

        let tar_data = fs::read(&tar_path).unwrap();
        let result = get_file_from_tar_in_memory(tfn, &tar_data);

        assert!(result.is_ok());
        let (content, path) = result.unwrap();
        assert_eq!(String::from_utf8_lossy(&content), c);
        assert_eq!(path.unwrap(), tfn);
    }

    // `gentar` only exists with the `tar` feature, so require both.
    #[cfg(all(feature = "tar", feature = "tokio-tar"))]
    #[tokio::test]
    async fn test_get_file_from_tar_async() -> Result<(), FetchError> {
        let (_td, tar_path, tfn, c) = gentar();

        let f = tokio::fs::File::open(tar_path).await?;
        let result = get_file_from_tar_by_reader_async(tfn, f).await?;

        let (content, path) = result;
        assert_eq!(String::from_utf8_lossy(&content), c);
        assert_eq!(path.unwrap(), tfn);
        Ok(())
    }
}