data_source/
lib.rs

1use std::{collections::HashMap, io, path::Path, time::SystemTime};
2
3use log::{debug, warn};
4
/// Errors produced while fetching data from a source, a cache file, or an archive.
#[derive(thiserror::Error, Debug)]
pub enum FetchError {
    /// An HTTP client/request error (only with the `reqwest` feature).
    #[cfg(feature = "reqwest")]
    #[error("reqwest err")]
    R(#[from] reqwest::Error),
    /// A filesystem or other I/O error.
    #[error("io err")]
    I(#[from] io::Error),
    /// A `SystemTime` computation failed.
    #[error("time err")]
    T(#[from] std::time::SystemTimeError),
    /// The response is larger than the configured size limit.
    #[error("size limit exceed")]
    S,
    /// There is no cache file.
    #[error("no cache file")]
    NC,
    /// The requested item was not found.
    #[error("not found")]
    NF,
    /// The requested file was not found in any of the listed directories.
    #[error("not found in directories `{0:?}`")]
    NFD(Vec<String>),
}
23
24impl From<FetchError> for io::Error {
25    fn from(value: FetchError) -> Self {
26        match value {
27            FetchError::R(error) => io::Error::other(error),
28            FetchError::I(error) => error,
29            FetchError::T(error) => io::Error::other(error),
30            FetchError::S => io::Error::other(value.to_string()),
31            FetchError::NC => io::Error::other(value.to_string()),
32            FetchError::NF => io::Error::new(io::ErrorKind::NotFound, ""),
33            FetchError::NFD(_) => io::Error::other(value.to_string()),
34        }
35    }
36}
37
/// Settings for an optional on-disk cache of fetched bytes.
#[derive(Debug, Clone)]
pub struct FileCache {
    /// Maximum age in seconds (time since the cache file was last modified) before the
    /// cache counts as expired. `None` means the cache file never expires.
    pub update_interval_seconds: Option<u64>,
    /// Path of the cache file. `None` means no cache file is used.
    pub cache_file_path: Option<String>,
}
43
44impl FileCache {
45    pub fn read_cache_file(&self) -> Result<Vec<u8>, FetchError> {
46        let cf = self.cache_file_path.as_ref().unwrap();
47        let s: Vec<u8> = std::fs::read(cf)?;
48        Ok(s)
49    }
50
51    #[cfg(feature = "tokio")]
52    pub async fn read_cache_file_async(&self) -> Result<Vec<u8>, FetchError> {
53        let cf = self.cache_file_path.as_ref().unwrap();
54
55        let content = tokio::fs::read(cf).await?;
56        Ok(content)
57    }
58
59    pub fn write_cache_file(&self, bytes: &[u8]) -> bool {
60        let cf = self.cache_file_path.as_ref().unwrap();
61        if let Err(err) = std::fs::write(cf, bytes) {
62            warn!("Failed to write cache file: {err}");
63            false
64        } else {
65            true
66        }
67    }
68
69    #[cfg(feature = "tokio")]
70    pub async fn write_cache_file_async(&self, bytes: &[u8]) -> bool {
71        let cf = self.cache_file_path.as_ref().unwrap();
72        if let Err(err) = tokio::fs::write(cf, bytes).await {
73            warn!("Failed to write cache file: {err}");
74            false
75        } else {
76            true
77        }
78    }
79
80    /// 检查缓存文件是否超时
81    pub fn is_cache_timeout(&self) -> Result<Option<bool>, FetchError> {
82        if let Some(cf) = &self.cache_file_path {
83            if std::fs::exists(cf)? {
84                let mut expired = false;
85                if let Some(interval) = self.update_interval_seconds {
86                    let metadata = std::fs::metadata(cf)?;
87                    let last_modified = metadata.modified()?;
88                    let elapsed = SystemTime::now().duration_since(last_modified)?.as_secs();
89                    expired = elapsed > interval;
90                }
91                return Ok(Some(expired));
92            }
93            Ok(None)
94        } else {
95            Ok(None)
96        }
97    }
98}
99
/// A data source whose bytes are fetched asynchronously.
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait AsyncSource: Send + Sync {
    /// Fetches the full contents of the source.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError>;
}
105
/// A data source whose bytes are fetched synchronously (blocking).
pub trait SyncSource {
    /// Fetches the full contents of the source.
    fn fetch(&self) -> Result<Vec<u8>, FetchError>;
}
109
/// Fetches from `s` through the file cache described by `fc` (async version).
///
/// A cache file that exists and has not expired is returned directly. In every other case
/// `s` is fetched instead: no cache, an expired cache, or a cache that cannot be checked or
/// read. When a cache path is configured, the fetched bytes are then written back. That
/// write is best-effort; failures are only logged.
///
/// # Errors
/// Only errors from `s.fetch_async()` are returned; cache problems fall back to the source.
#[cfg(feature = "tokio")]
pub async fn fetch_with_cache_async(
    fc: &FileCache,
    s: &dyn AsyncSource,
) -> Result<Vec<u8>, FetchError> {
    match fc.is_cache_timeout() {
        Ok(Some(false)) => match fc.read_cache_file_async().await {
            Ok(cached) => return Ok(cached),
            // The cache looked fresh but could not be read (e.g. removed in between).
            Err(err) => warn!("Failed to read cache file, fetching from source: {err:?}"),
        },
        // Missing or expired cache: fetch from the source below.
        Ok(_) => {}
        Err(err) => warn!("Failed to check cache file, fetching from source: {err:?}"),
    }

    let data = s.fetch_async().await?;
    if fc.cache_file_path.is_some() {
        fc.write_cache_file_async(&data).await;
    }
    Ok(data)
}
125pub fn fetch_with_cache(fc: &FileCache, s: &dyn SyncSource) -> Result<Vec<u8>, FetchError> {
126    if fc.is_cache_timeout()?.is_some_and(|timeout| !timeout) {
127        fc.read_cache_file()
128    } else {
129        let d = s.fetch()?;
130        if fc.cache_file_path.is_some() {
131            fc.write_cache_file(&d);
132        }
133        Ok(d)
134    }
135}
136
/// Asynchronously looks up files by name within some container (directories, archives, ...).
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait AsyncFolderSource: std::fmt::Debug {
    /// Returns the bytes of `file_name` and, if the implementation provides one, a
    /// string describing where the file was found.
    async fn get_file_content_async(
        &self,
        file_name: &std::path::Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError>;
}
145
/// Synchronously looks up files by name within some container (directories, archives, ...).
pub trait SyncFolderSource: std::fmt::Debug {
    /// Returns the bytes of `file_name` and, if the implementation provides one, a
    /// string describing where the file was found.
    fn get_file_content(
        &self,
        file_name: &std::path::Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError>;
}
152
/// A tar archive on disk; the wrapped string is the archive's file path.
#[cfg(feature = "tar")]
#[derive(Clone, Debug, Default)]
pub struct TarFile(pub String);
156
#[cfg(feature = "tar")]
impl SyncFolderSource for TarFile {
    /// Opens the archive at `self.0` and extracts the entry whose path equals `file_name`.
    fn get_file_content(
        &self,
        file_name: &std::path::Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        let archive_path = &self.0;
        let reader = std::fs::File::open(archive_path).map_err(FetchError::from)?;
        get_file_from_tar_by_reader(file_name, reader)
    }
}
// `TarFile` only exists with the `tar` feature and `AsyncFolderSource` only with `tokio`,
// so gate on every feature this impl actually needs rather than `tokio-tar` alone.
#[cfg(all(feature = "tar", feature = "tokio", feature = "tokio-tar"))]
#[async_trait::async_trait]
impl AsyncFolderSource for TarFile {
    /// Opens the archive at `self.0` asynchronously and extracts the entry whose path
    /// equals `file_name`.
    async fn get_file_content_async(
        &self,
        file_name: &std::path::Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        let f = tokio::fs::File::open(&self.0).await?;
        get_file_from_tar_by_reader_async(file_name, f).await
    }
}
178
/// An HTTP(S) resource fetched with `reqwest`.
#[cfg(feature = "reqwest")]
#[derive(Clone, Debug, Default)]
pub struct HttpSource {
    /// URL requested with GET.
    pub url: String,
    /// Proxy URL, applied to both HTTP and HTTPS traffic when a proxy is used.
    pub proxy: Option<String>,
    /// Extra `(name, value)` headers added to every request.
    pub custom_request_headers: Option<Vec<(String, String)>>,
    /// Use `proxy` from the first request. When `false`, the fetch implementations in
    /// this file try a direct request first and retry through `proxy` (if set) on failure.
    pub should_use_proxy: bool,
    /// Maximum accepted response size in bytes; responses whose `Content-Length`
    /// exceeds it are rejected.
    pub size_limit_bytes: Option<usize>,
}
188
#[cfg(feature = "reqwest")]
impl HttpSource {
    /// Sends a GET request for `self.url` using `c`, adding any `custom_request_headers`.
    ///
    /// The response is returned as-is; its HTTP status is not checked here.
    pub fn get(
        &self,
        c: reqwest::blocking::Client,
    ) -> reqwest::Result<reqwest::blocking::Response> {
        let mut rb = c.get(&self.url);
        for (name, value) in self.custom_request_headers.iter().flatten() {
            rb = rb.header(name, value);
        }
        rb.send()
    }

    /// Configures `cb` to send both HTTP and HTTPS traffic through `self.proxy`.
    ///
    /// If no proxy is configured, the builder is returned unchanged.
    ///
    /// # Errors
    /// Returns the `reqwest` error if the proxy URL is invalid.
    pub fn set_proxy(
        &self,
        cb: reqwest::blocking::ClientBuilder,
    ) -> reqwest::Result<reqwest::blocking::ClientBuilder> {
        let Some(ps) = self.proxy.as_deref() else {
            return Ok(cb);
        };
        let cb = cb.proxy(reqwest::Proxy::https(ps)?);
        Ok(cb.proxy(reqwest::Proxy::http(ps)?))
    }
}
#[cfg(feature = "reqwest")]
impl SyncSource for HttpSource {
    /// Downloads `self.url` with a blocking `reqwest` client.
    ///
    /// With `should_use_proxy`, the proxy is used from the start. Otherwise a direct request
    /// is made first; if it fails and `proxy` is set, it is retried once through the proxy.
    ///
    /// # Errors
    /// - [`FetchError::R`]: client construction, transport failure, or a 4xx/5xx status.
    /// - [`FetchError::S`]: the body is larger than `size_limit_bytes`.
    /// - [`FetchError::I`]: reading a size-limited body failed.
    fn fetch(&self) -> Result<Vec<u8>, FetchError> {
        let mut cb = reqwest::blocking::ClientBuilder::new();
        if self.should_use_proxy {
            cb = self.set_proxy(cb)?;
        }
        let response = match self.get(cb.build()?) {
            Ok(response) => response,
            Err(_) if !self.should_use_proxy && self.proxy.is_some() => {
                // The direct request failed; retry once through the configured proxy.
                let cb = self.set_proxy(reqwest::blocking::ClientBuilder::new())?;
                self.get(cb.build()?)?
            }
            Err(e) => return Err(FetchError::R(e)),
        };

        // Treat 4xx/5xx as failures so error pages are never returned (or cached) as data.
        let response = response.error_for_status()?;

        let Some(limit) = self.size_limit_bytes else {
            return Ok(response.bytes()?.to_vec());
        };
        // Compare as u64 so a huge Content-Length cannot wrap when cast on 32-bit targets.
        let limit = limit as u64;
        if response.content_length().is_some_and(|len| len > limit) {
            return Err(FetchError::S);
        }
        // Content-Length may be absent (chunked encoding) or wrong, so also cap the bytes
        // actually read. Reading one byte past the limit is enough to detect an oversize body.
        let mut body = Vec::new();
        {
            use std::io::Read as _;
            response.take(limit.saturating_add(1)).read_to_end(&mut body)?;
        }
        if body.len() as u64 > limit {
            return Err(FetchError::S);
        }
        Ok(body)
    }
}
249
#[cfg(feature = "tokio")]
#[cfg(feature = "reqwest")]
impl HttpSource {
    /// Sends a GET request for `self.url` using the async `client`, adding any
    /// `custom_request_headers`.
    ///
    /// The response is returned as-is; its HTTP status is not checked here.
    pub async fn get_async(&self, client: reqwest::Client) -> reqwest::Result<reqwest::Response> {
        let mut request = client.get(&self.url);
        if let Some(headers) = &self.custom_request_headers {
            for (key, value) in headers {
                request = request.header(key, value);
            }
        }
        request.send().await
    }

    /// Configures the async `client_builder` to send both HTTP and HTTPS traffic through
    /// `self.proxy`.
    ///
    /// If no proxy is configured, the builder is returned unchanged.
    ///
    /// # Errors
    /// Returns the `reqwest` error if the proxy URL is invalid.
    pub fn set_proxy_async(
        &self,
        client_builder: reqwest::ClientBuilder,
    ) -> reqwest::Result<reqwest::ClientBuilder> {
        let Some(proxy) = self.proxy.as_deref() else {
            return Ok(client_builder);
        };
        let client_builder = client_builder.proxy(reqwest::Proxy::http(proxy)?);
        let client_builder = client_builder.proxy(reqwest::Proxy::https(proxy)?);
        Ok(client_builder)
    }
}
273
#[cfg(feature = "tokio")]
#[cfg(feature = "reqwest")]
#[async_trait::async_trait]
impl AsyncSource for HttpSource {
    /// Downloads `self.url` with an async `reqwest` client.
    ///
    /// With `should_use_proxy`, the proxy is used from the start. Otherwise a direct request
    /// is made first; if it fails and `proxy` is set, it is retried once through the proxy.
    ///
    /// # Errors
    /// - [`FetchError::R`]: client construction, transport failure, or a 4xx/5xx status.
    /// - [`FetchError::S`]: the body is larger than `size_limit_bytes`.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError> {
        let client_builder = reqwest::ClientBuilder::new();
        let client_builder = if self.should_use_proxy {
            self.set_proxy_async(client_builder)?
        } else {
            client_builder
        };
        let client = client_builder.build()?;

        let response = match self.get_async(client).await {
            Ok(response) => response,
            Err(_) if !self.should_use_proxy && self.proxy.is_some() => {
                // The direct request failed; retry once through the configured proxy.
                let client = self.set_proxy_async(reqwest::ClientBuilder::new())?.build()?;
                self.get_async(client).await?
            }
            Err(e) => return Err(FetchError::R(e)),
        };

        // Treat 4xx/5xx as failures so error pages are never returned (or cached) as data.
        let mut response = response.error_for_status()?;

        let Some(size_limit) = self.size_limit_bytes else {
            return Ok(response.bytes().await?.to_vec());
        };
        // Compare as u64 so a huge Content-Length cannot wrap when cast on 32-bit targets.
        if response
            .content_length()
            .is_some_and(|len| len > size_limit as u64)
        {
            return Err(FetchError::S);
        }
        // Content-Length may be absent (chunked encoding) or wrong, so also enforce the
        // limit on the bytes actually received.
        let mut body = Vec::new();
        while let Some(chunk) = response.chunk().await? {
            // Invariant: body.len() <= size_limit, so this subtraction cannot underflow.
            if chunk.len() > size_limit - body.len() {
                return Err(FetchError::S);
            }
            body.extend_from_slice(&chunk);
        }
        Ok(body)
    }
}
314
/// Something that may have a path or URL describing where its data comes from.
pub trait GetPath {
    /// Returns that path/URL, or `None` if there is none. The default returns `None`.
    fn get_path(&self) -> Option<String> {
        None
    }
}
320
/// Where the content of a single file comes from.
#[derive(Debug)]
pub enum SingleFileSource {
    /// Downloaded over HTTP, going through the given file cache.
    #[cfg(feature = "reqwest")]
    Http(HttpSource, FileCache),
    /// Read from the file at this path.
    FilePath(String),
    /// Bytes held directly in memory.
    Inline(Vec<u8>),
}

impl Default for SingleFileSource {
    /// An empty inline source.
    fn default() -> Self {
        SingleFileSource::Inline(vec![])
    }
}
333
334impl GetPath for SingleFileSource {
335    fn get_path(&self) -> Option<String> {
336        match self {
337            #[cfg(feature = "reqwest")]
338            SingleFileSource::Http(http_source, _fc) => Some(http_source.url.clone()),
339            SingleFileSource::FilePath(p) => Some(p.clone()),
340            SingleFileSource::Inline(_ec) => None,
341        }
342    }
343}
344
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncSource for SingleFileSource {
    /// Produces this source's bytes. HTTP sources go through their file cache, file paths
    /// are read with `tokio::fs::read`, and inline data is copied.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError> {
        match self {
            SingleFileSource::Inline(bytes) => Ok(bytes.to_vec()),
            SingleFileSource::FilePath(path) => {
                let contents = tokio::fs::read(path).await.map_err(FetchError::from)?;
                Ok(contents)
            }
            #[cfg(feature = "reqwest")]
            SingleFileSource::Http(http_source, cache) => {
                fetch_with_cache_async(cache, http_source).await
            }
        }
    }
}
362
363impl SyncSource for SingleFileSource {
364    fn fetch(&self) -> Result<Vec<u8>, FetchError> {
365        match self {
366            #[cfg(feature = "reqwest")]
367            SingleFileSource::Http(http_source, fc) => fetch_with_cache(fc, http_source),
368            SingleFileSource::FilePath(f) => {
369                let s: Vec<u8> = std::fs::read(f)?;
370                Ok(s)
371            }
372            SingleFileSource::Inline(v) => Ok(v.clone()),
373        }
374    }
375}
376
/// Defines where to get the content of a requested file name.
///
/// Many configurations need to load further external files. A `DataSource` specifies
/// where, and from what kind of source, those files are looked up. It is intended to limit
/// which files a configuration can pull in.
#[derive(Debug, Default)]
pub enum DataSource {
    /// Read the requested path directly with the standard file APIs (no lookup restriction).
    #[default]
    StdReadFile,
    /// Look the file up in a list of directories, trying each one in order.
    Folders(Vec<String>),
    /// Look the file up inside a tar archive that is already held in memory.
    #[cfg(feature = "tar")]
    TarInMemory(Vec<u8>),
    /// Look the file up inside a tar archive on disk.
    #[cfg(feature = "tar")]
    TarFile(TarFile),

    /// Unlike the other variants, `FileMap` stores an explicit name-to-source mapping,
    /// so no directory traversal is needed.
    FileMap(HashMap<String, SingleFileSource>),

    /// Delegate lookups to a custom synchronous folder source.
    Sync(Box<dyn SyncFolderSource + Send + Sync>),
    /// Delegate lookups to a custom asynchronous folder source.
    #[cfg(feature = "tokio")]
    Async(Box<dyn AsyncFolderSource + Send + Sync>),
}
401
402impl DataSource {
403    pub fn insert_current_working_dir(&mut self) -> io::Result<()> {
404        if let DataSource::Folders(ref mut v) = self {
405            v.push(std::env::current_dir()?.to_string_lossy().to_string())
406        }
407        Ok(())
408    }
409
410    pub fn read_to_string<P>(&self, file_name: P) -> Result<String, FetchError>
411    where
412        P: AsRef<std::path::Path>,
413    {
414        let r = SyncFolderSource::get_file_content(self, file_name.as_ref())?;
415        Ok(String::from_utf8_lossy(r.0.as_slice()).to_string())
416    }
417}
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncFolderSource for DataSource {
    /// Returns the bytes read for `file_name` and, when known, where it was found:
    /// - `Folders`: the matching directory;
    /// - tar sources: the entry path;
    /// - `FileMap`: the source path or URL.
    ///
    /// `Folders` only considers names that stay inside each configured directory. Names
    /// with `..` components, and absolute names outside the directory, are skipped.
    ///
    /// # Errors
    /// [`FetchError::NFD`] if no directory contains the file, [`FetchError::NF`] if a
    /// `FileMap` has no such entry, otherwise whatever the underlying source returns.
    async fn get_file_content_async(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        // Joins `file_name` onto `dir`, refusing names that would escape `dir`:
        // `Path::join` discards `dir` entirely for absolute names, and `..` walks out of it.
        fn resolve_in_dir(dir: &str, file_name: &Path) -> Option<std::path::PathBuf> {
            use std::path::Component;

            let dir = Path::new(dir);
            if file_name
                .components()
                .any(|c| matches!(c, Component::ParentDir))
            {
                return None;
            }
            if file_name
                .components()
                .all(|c| matches!(c, Component::Normal(_) | Component::CurDir))
            {
                Some(dir.join(file_name))
            } else {
                // Rooted or drive-prefixed name: accept it only if it already lies inside `dir`.
                file_name.starts_with(dir).then(|| file_name.to_path_buf())
            }
        }

        match self {
            DataSource::Async(source) => source.get_file_content_async(file_name).await,

            DataSource::Sync(source) => source.get_file_content(file_name),
            #[cfg(feature = "tar")]
            DataSource::TarInMemory(tar_binary) => {
                get_file_from_tar_in_memory(file_name, tar_binary)
            }
            #[cfg(all(feature = "tar", feature = "tokio-tar"))]
            DataSource::TarFile(tf) => tf.get_file_content_async(file_name).await,
            // Without `tokio-tar` there is no async tar reader. Fall back to the blocking one
            // so the match stays exhaustive when only `tar` is enabled.
            #[cfg(all(feature = "tar", not(feature = "tokio-tar")))]
            DataSource::TarFile(tf) => tf.get_file_content(file_name),

            DataSource::Folders(possible_addrs) => {
                for dir in possible_addrs {
                    let Some(real_file_name) = resolve_in_dir(dir, file_name) else {
                        continue;
                    };
                    // `is_file` (not `exists`) so a same-named directory is skipped instead of
                    // failing the read.
                    if real_file_name.is_file() {
                        let content = tokio::fs::read(&real_file_name).await?;
                        return Ok((content, Some(dir.to_owned())));
                    }
                }
                Err(FetchError::NFD(possible_addrs.clone()))
            }
            DataSource::StdReadFile => {
                let s: Vec<u8> = tokio::fs::read(file_name).await?;
                Ok((s, None))
            }

            DataSource::FileMap(map) => {
                let key = file_name.to_string_lossy();
                match map.get(&*key) {
                    Some(sf) => sf.fetch_async().await.map(|d| (d, sf.get_path())),
                    None => Err(FetchError::NF),
                }
            }
        }
    }
}
465
466impl SyncFolderSource for DataSource {
467    /// 返回读到的 数据。可能还会返回 成功找到的路径
468    fn get_file_content(&self, file_name: &Path) -> Result<(Vec<u8>, Option<String>), FetchError> {
469        match self {
470            DataSource::Sync(source) => source.get_file_content(file_name),
471
472            #[cfg(feature = "tokio")]
473            DataSource::Async(source) => {
474                tokio::runtime::Handle::current().block_on(source.get_file_content_async(file_name))
475            }
476
477            #[cfg(feature = "tar")]
478            DataSource::TarInMemory(tar_binary) => {
479                get_file_from_tar_in_memory(file_name, tar_binary)
480            }
481            #[cfg(feature = "tar")]
482            DataSource::TarFile(tf) => tf.get_file_content(file_name),
483
484            DataSource::Folders(possible_addrs) => {
485                for dir in possible_addrs {
486                    let real_file_name = std::path::Path::new(dir).join(file_name);
487
488                    if real_file_name.exists() {
489                        return Ok(
490                            std::fs::read(&real_file_name).map(|v| (v, Some(dir.to_owned())))?
491                        );
492                    }
493                }
494                Err(FetchError::NFD(possible_addrs.clone()))
495            }
496            DataSource::StdReadFile => {
497                let s: Vec<u8> = std::fs::read(file_name)?;
498                Ok((s, None))
499            }
500
501            DataSource::FileMap(map) => {
502                let r = map.get(&file_name.to_string_lossy().to_string());
503
504                match r {
505                    Some(sf) => sf.fetch().map(|d| (d, sf.get_path())),
506                    None => Err(FetchError::NF),
507                }
508            }
509        }
510    }
511}
512
/// Streams a tar archive from `reader` and returns the content of the first entry whose
/// path equals `file_name_in_tar`, together with that entry's path (lossily decoded).
///
/// # Errors
/// - [`FetchError::NF`] if no entry matches.
/// - [`FetchError::I`] if the archive cannot be read or is corrupt.
///
/// Entries whose path cannot be decoded are skipped rather than aborting the search.
#[cfg(feature = "tokio-tar")]
pub async fn get_file_from_tar_by_reader_async<P, R>(
    file_name_in_tar: P,
    reader: R,
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
    R: tokio::io::AsyncRead + Unpin,
{
    use futures::StreamExt;
    use tokio::io::AsyncReadExt;

    let wanted = file_name_in_tar.as_ref();
    let mut archive = tokio_tar::Archive::new(reader);
    let mut entries = archive.entries()?;

    while let Some(entry) = entries.next().await {
        let mut entry = entry?;
        let found_at = match entry.path() {
            Ok(path) if *path == *wanted => path.to_string_lossy().into_owned(),
            _ => continue,
        };
        // `display()` instead of `to_str().unwrap()` so non-UTF-8 names cannot panic.
        debug!("found {}", wanted.display());

        let mut content = Vec::new();
        entry.read_to_end(&mut content).await?;
        return Ok((content, Some(found_at)));
    }
    Err(FetchError::NF)
}
/// Reads a tar archive from `reader` and returns the content of the first entry whose path
/// equals `file_name_in_tar`, together with that entry's path (lossily decoded).
///
/// Entries that fail to parse, or whose path cannot be decoded, are treated as non-matching.
///
/// # Errors
/// [`FetchError::I`]: an `io::ErrorKind::NotFound` error if no entry matches, or the
/// underlying I/O error if the archive cannot be read.
#[cfg(feature = "tar")]
pub fn get_file_from_tar_by_reader<P, R>(
    file_name_in_tar: P,
    reader: R,
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
    R: std::io::Read,
{
    use std::io::Read;

    let wanted = file_name_in_tar.as_ref();
    let mut archive = tar::Archive::new(reader);

    let mut entry = archive
        .entries()?
        .filter_map(Result::ok)
        .find(|entry| entry.path().is_ok_and(|path| path == wanted))
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!(
                    "get_file_from_tar: can't find the file, {}",
                    wanted.display()
                ),
            )
        })?;

    // `display()` / `to_string_lossy()` instead of `to_str().unwrap()` so non-UTF-8 names
    // cannot panic.
    debug!("found {}", wanted.display());
    let found_at = entry.path()?.to_string_lossy().into_owned();

    let mut content = Vec::new();
    entry.read_to_end(&mut content)?;
    Ok((content, Some(found_at)))
}
/// Looks up `file_name_in_tar` in a tar archive held in memory.
///
/// Accepts any byte slice; existing callers passing `&Vec<u8>` still work via deref coercion.
/// See [`get_file_from_tar_by_reader`] for the lookup and error semantics.
#[cfg(feature = "tar")]
pub fn get_file_from_tar_in_memory<P>(
    file_name_in_tar: P,
    tar_binary: &[u8],
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
{
    // `display()` instead of `to_str().unwrap()` so non-UTF-8 names cannot panic.
    debug!(
        "finding {} from tar, tar whole size is {}",
        file_name_in_tar.as_ref().display(),
        tar_binary.len()
    );
    get_file_from_tar_by_reader(file_name_in_tar, std::io::Cursor::new(tar_binary))
}
597
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::{self, File};
    use std::io::Write;
    use tempfile::TempDir;

    #[cfg(feature = "reqwest")]
    use reqwest::blocking::Client;

    // Target of the HTTP tests below; those tests need network access.
    const URL: &str = "https://www.rust-lang.org";

    // Fetches URL asynchronously (no proxy) and expects a non-empty body.
    #[cfg(feature = "tokio")]
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_http_source_fetch_async() {
        let http_source = HttpSource {
            url: URL.to_string(),
            should_use_proxy: false,
            ..Default::default()
        };

        let result = http_source.fetch_async().await;
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty());
    }

    // Only exercises `HttpSource::get` (sending the request), not `fetch`.
    #[cfg(feature = "reqwest")]
    #[test]
    fn test_http_source_fetch() {
        let http_source = HttpSource {
            url: URL.to_string(),
            should_use_proxy: false,
            ..Default::default()
        };

        let client = Client::new();
        let result = http_source.get(client);
        assert!(result.is_ok());
    }

    // A file written into a temp dir is found through `DataSource::Folders`.
    #[test]
    fn test_data_source_read_from_folders() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        fs::write(&file_path, "hello world").unwrap();

        let data_source = DataSource::Folders(vec![temp_dir.path().to_string_lossy().to_string()]);

        let content = data_source.read_to_string("test.txt").unwrap();
        assert_eq!(content, "hello world");
    }

    // An inline entry in a `FileMap` is returned under its key.
    #[test]
    fn test_data_source_read_from_file_map() {
        let file_map = vec![(
            "config.json".to_string(),
            SingleFileSource::Inline(b"{\"key\": \"value\"}".to_vec()),
        )]
        .into_iter()
        .collect();

        let data_source = DataSource::FileMap(file_map);

        let content = data_source.read_to_string("config.json").unwrap();
        assert_eq!(content, "{\"key\": \"value\"}");
    }
    use std::path::PathBuf;
    // Builds `test.tar` inside a fresh temp dir, containing one entry named "test.txt" with
    // content "hello tar\n". Returns (temp dir guard, tar path, entry name, entry content).
    #[cfg(feature = "tar")]
    fn gentar() -> (TempDir, PathBuf, &'static str, &'static str) {
        let temp_dir = TempDir::new().unwrap();
        let tar_path = temp_dir.path().join("test.tar");

        let mut tar_builder = tar::Builder::new(File::create(&tar_path).unwrap());

        let mut file = tempfile::NamedTempFile::new().unwrap();
        let c = "hello tar\n";
        write!(file, "{}", c).unwrap();
        let file_path = file.path().to_owned();

        let tfn = "test.txt";

        tar_builder.append_path_with_name(&file_path, tfn).unwrap();
        tar_builder.finish().unwrap();

        (temp_dir, tar_path, tfn, c)
    }

    // Reads the generated tar into memory and extracts its single entry.
    #[cfg(feature = "tar")]
    #[test]
    fn test_get_file_from_tar() {
        // Bind the TempDir to `_td`, not `_`: a `_` pattern drops it at the end of this
        // statement, and dropping a TempDir deletes its contents. The tar file must stay
        // alive for the rest of the test.
        let (_td, tar_path, tfn, c) = gentar();

        let tar_data = fs::read(&tar_path).unwrap();
        let result = get_file_from_tar_in_memory(tfn, &tar_data);

        assert!(result.is_ok());
        let (content, path) = result.unwrap();
        assert_eq!(String::from_utf8_lossy(&content), c);
        assert_eq!(path.unwrap(), tfn);
    }
    // Streams the generated tar from disk with the async reader and extracts its entry.
    #[cfg(feature = "tokio-tar")]
    #[tokio::test]
    async fn test_get_file_from_tar_async() -> Result<(), FetchError> {
        let (_td, tar_path, tfn, c) = gentar();

        let f = tokio::fs::File::open(tar_path).await?;
        let result = get_file_from_tar_by_reader_async(tfn, f).await?;

        let (content, path) = result;
        assert_eq!(String::from_utf8_lossy(&content), c);
        assert_eq!(path.unwrap(), tfn);
        Ok(())
    }
}