data_source/
lib.rs

1#[cfg(feature = "file_server")]
2pub mod file_server;
3
4use std::{collections::HashMap, io, path::Path, time::SystemTime};
5
6use log::{debug, warn};
7
8#[derive(thiserror::Error, Debug)]
9pub enum FetchError {
10    #[cfg(feature = "reqwest")]
11    #[error("reqwest err")]
12    R(#[from] reqwest::Error),
13    #[error("io err")]
14    I(#[from] io::Error),
15    #[error("time err")]
16    T(#[from] std::time::SystemTimeError),
17    #[error("size limit exceed")]
18    S,
19    #[error("no cache file")]
20    NC,
21    #[error("not found")]
22    NF,
23    #[error("not found in directories `{0:?}`")]
24    NFD(Vec<String>),
25}
26
27impl From<FetchError> for io::Error {
28    fn from(value: FetchError) -> Self {
29        match value {
30            FetchError::R(error) => io::Error::other(error),
31            FetchError::I(error) => error,
32            FetchError::T(error) => io::Error::other(error),
33            FetchError::S => io::Error::other(value.to_string()),
34            FetchError::NC => io::Error::other(value.to_string()),
35            FetchError::NF => io::Error::new(io::ErrorKind::NotFound, ""),
36            FetchError::NFD(_) => io::Error::other(value.to_string()),
37        }
38    }
39}
40
/// Settings for caching fetched data in a single file on disk.
#[derive(Clone, Debug)]
pub struct FileCache {
    /// Maximum age of the cache file, in seconds, before it counts as
    /// expired. `None` means an existing cache file never expires.
    pub update_interval_seconds: Option<u64>,
    /// Location of the cache file. `None` disables caching.
    pub cache_file_path: Option<String>,
}
46
47impl FileCache {
48    pub fn read_cache_file(&self) -> Result<Vec<u8>, FetchError> {
49        let cf = self.cache_file_path.as_ref().unwrap();
50        let s: Vec<u8> = std::fs::read(cf)?;
51        Ok(s)
52    }
53
54    #[cfg(feature = "tokio")]
55    pub async fn read_cache_file_async(&self) -> Result<Vec<u8>, FetchError> {
56        let cf = self.cache_file_path.as_ref().unwrap();
57
58        let content = tokio::fs::read(cf).await?;
59        Ok(content)
60    }
61
62    pub fn write_cache_file(&self, bytes: &[u8]) -> bool {
63        let cf = self.cache_file_path.as_ref().unwrap();
64        if let Err(err) = std::fs::write(cf, bytes) {
65            warn!("Failed to write cache file: {err}");
66            false
67        } else {
68            true
69        }
70    }
71
72    #[cfg(feature = "tokio")]
73    pub async fn write_cache_file_async(&self, bytes: &[u8]) -> bool {
74        let cf = self.cache_file_path.as_ref().unwrap();
75        if let Err(err) = tokio::fs::write(cf, bytes).await {
76            warn!("Failed to write cache file: {err}");
77            false
78        } else {
79            true
80        }
81    }
82
83    /// 检查缓存文件是否超时
84    pub fn is_cache_timeout(&self) -> Result<Option<bool>, FetchError> {
85        if let Some(cf) = &self.cache_file_path {
86            if std::fs::exists(cf)? {
87                let mut expired = false;
88                if let Some(interval) = self.update_interval_seconds {
89                    let metadata = std::fs::metadata(cf)?;
90                    let last_modified = metadata.modified()?;
91                    let elapsed = SystemTime::now().duration_since(last_modified)?.as_secs();
92                    expired = elapsed > interval;
93                }
94                return Ok(Some(expired));
95            }
96            Ok(None)
97        } else {
98            Ok(None)
99        }
100    }
101}
102
/// A source whose complete content can be fetched asynchronously.
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait AsyncSource: Send + Sync {
    /// Fetches the complete content of the source as raw bytes.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError>;
}
108
109pub trait SyncSource {
110    fn fetch(&self) -> Result<Vec<u8>, FetchError>;
111}
112
/// Fetches data from `s`, going through the file cache described by `fc`.
///
/// When the cache file exists and has not expired, its content is returned
/// and the source is not contacted. Otherwise the source is fetched and, if a
/// cache path is configured, the result is written to the cache file. A
/// failed cache write does not fail the fetch.
#[cfg(feature = "tokio")]
pub async fn fetch_with_cache_async(
    fc: &FileCache,
    s: &dyn AsyncSource,
) -> Result<Vec<u8>, FetchError> {
    let cache_is_fresh = matches!(fc.is_cache_timeout()?, Some(false));
    if cache_is_fresh {
        return fc.read_cache_file_async().await;
    }

    let data = s.fetch_async().await?;
    if fc.cache_file_path.is_some() {
        fc.write_cache_file_async(&data).await;
    }
    Ok(data)
}
128pub fn fetch_with_cache(fc: &FileCache, s: &dyn SyncSource) -> Result<Vec<u8>, FetchError> {
129    if fc.is_cache_timeout()?.is_some_and(|timeout| !timeout) {
130        fc.read_cache_file()
131    } else {
132        let d = s.fetch()?;
133        if fc.cache_file_path.is_some() {
134            fc.write_cache_file(&d);
135        }
136        Ok(d)
137    }
138}
139
/// A folder-like source that resolves a file name to its content
/// asynchronously.
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait AsyncFolderSource: std::fmt::Debug {
    /// Returns the content of `file_name` together with an optional string
    /// describing where it was found.
    async fn get_file_content_async(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError>;
}
148
149pub trait SyncFolderSource: std::fmt::Debug {
150    fn get_file_content(
151        &self,
152        file_name: &std::path::Path,
153    ) -> Result<(Vec<u8>, Option<String>), FetchError>;
154}
155
/// A tar archive on disk, identified by its file path, used as a folder
/// source.
#[cfg(feature = "tar")]
#[derive(Clone, Debug, Default)]
pub struct TarFile(
    /// Path of the tar archive.
    pub String,
);
159
#[cfg(feature = "tar")]
impl SyncFolderSource for TarFile {
    /// Opens the archive at the stored path and extracts `file_name` from it.
    fn get_file_content(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        let archive = std::fs::File::open(&self.0)?;
        get_file_from_tar_by_reader(file_name, archive)
    }
}
#[cfg(feature = "tokio-tar")]
#[async_trait::async_trait]
impl AsyncFolderSource for TarFile {
    /// Opens the archive at the stored path asynchronously and extracts
    /// `file_name` from it.
    async fn get_file_content_async(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        let archive = tokio::fs::File::open(&self.0).await?;
        get_file_from_tar_by_reader_async(file_name, archive).await
    }
}
181
/// A file served over HTTP(S).
#[cfg(feature = "reqwest")]
#[derive(Clone, Debug, Default)]
pub struct HttpSource {
    /// URL to request.
    pub url: String,
    /// Proxy URL, applied to both HTTP and HTTPS traffic when used.
    pub proxy: Option<String>,
    /// Extra `(name, value)` headers added to every request.
    pub custom_request_headers: Option<Vec<(String, String)>>,
    /// When `true` the request goes through the proxy right away. When
    /// `false` the proxy is only tried after a direct request has failed.
    pub should_use_proxy: bool,
    /// Upper bound for the response size in bytes; larger responses are
    /// rejected with [`FetchError::S`].
    pub size_limit_bytes: Option<usize>,
}
191
#[cfg(feature = "reqwest")]
impl HttpSource {
    /// Sends a blocking GET request for `self.url` with client `c`, adding
    /// every configured custom header.
    pub fn get(
        &self,
        c: reqwest::blocking::Client,
    ) -> reqwest::Result<reqwest::blocking::Response> {
        let mut request = c.get(&self.url);
        for (name, value) in self.custom_request_headers.iter().flatten() {
            request = request.header(name, value);
        }
        request.send()
    }

    /// Registers the configured proxy on `cb` for both HTTPS and HTTP.
    ///
    /// When no proxy is configured the builder is returned unchanged and a
    /// warning is logged; this used to panic.
    ///
    /// # Errors
    ///
    /// Fails when the proxy URL cannot be parsed.
    pub fn set_proxy(
        &self,
        cb: reqwest::blocking::ClientBuilder,
    ) -> reqwest::Result<reqwest::blocking::ClientBuilder> {
        let Some(proxy_url) = self.proxy.as_deref() else {
            warn!("proxy requested but none is configured; connecting directly");
            return Ok(cb);
        };
        let cb = cb.proxy(reqwest::Proxy::https(proxy_url)?);
        Ok(cb.proxy(reqwest::Proxy::http(proxy_url)?))
    }
}
#[cfg(feature = "reqwest")]
impl SyncSource for HttpSource {
    /// Downloads the body of `self.url` with a blocking client.
    ///
    /// With `should_use_proxy` set, the request goes through the proxy. When
    /// it is not set and the direct request fails, the request is retried
    /// once through the proxy if one is configured.
    ///
    /// # Errors
    ///
    /// * [`FetchError::R`] when building the client, sending the request or
    ///   reading the body fails.
    /// * [`FetchError::S`] when the response exceeds `size_limit_bytes`,
    ///   judged by the `Content-Length` header and by the number of bytes
    ///   actually received.
    fn fetch(&self) -> Result<Vec<u8>, FetchError> {
        let mut builder = reqwest::blocking::ClientBuilder::new();
        if self.should_use_proxy {
            builder = self.set_proxy(builder)?;
        }

        let response = match self.get(builder.build()?) {
            Ok(response) => response,
            Err(err) if !self.should_use_proxy && self.proxy.is_some() => {
                debug!("direct request failed ({err}); retrying through the proxy");
                let builder = self.set_proxy(reqwest::blocking::ClientBuilder::new())?;
                self.get(builder.build()?)?
            }
            Err(err) => return Err(FetchError::R(err)),
        };

        // Reject early when the server announces an oversized body. A length
        // that does not even fit in `usize` is necessarily over the limit.
        if let (Some(limit), Some(announced)) = (self.size_limit_bytes, response.content_length())
        {
            if usize::try_from(announced).map_or(true, |len| len > limit) {
                return Err(FetchError::S);
            }
        }

        let body = response.bytes()?.to_vec();

        // `Content-Length` may be absent (e.g. chunked encoding), so the
        // limit is also enforced on what was actually received.
        if self.size_limit_bytes.is_some_and(|limit| body.len() > limit) {
            return Err(FetchError::S);
        }

        Ok(body)
    }
}
252
#[cfg(feature = "tokio")]
#[cfg(feature = "reqwest")]
impl HttpSource {
    /// Sends an asynchronous GET request for `self.url` with `client`,
    /// adding every configured custom header.
    pub async fn get_async(&self, client: reqwest::Client) -> reqwest::Result<reqwest::Response> {
        let mut request = client.get(&self.url);
        for (key, value) in self.custom_request_headers.iter().flatten() {
            request = request.header(key, value);
        }
        request.send().await
    }

    /// Registers the configured proxy on `client_builder` for both HTTP and
    /// HTTPS.
    ///
    /// When no proxy is configured the builder is returned unchanged and a
    /// warning is logged; this used to panic.
    ///
    /// # Errors
    ///
    /// Fails when the proxy URL cannot be parsed.
    pub fn set_proxy_async(
        &self,
        client_builder: reqwest::ClientBuilder,
    ) -> reqwest::Result<reqwest::ClientBuilder> {
        let Some(proxy_url) = self.proxy.as_deref() else {
            warn!("proxy requested but none is configured; connecting directly");
            return Ok(client_builder);
        };
        let client_builder = client_builder.proxy(reqwest::Proxy::http(proxy_url)?);
        Ok(client_builder.proxy(reqwest::Proxy::https(proxy_url)?))
    }
}
276
#[cfg(feature = "tokio")]
#[cfg(feature = "reqwest")]
#[async_trait::async_trait]
impl AsyncSource for HttpSource {
    /// Downloads the body of `self.url` with an asynchronous client.
    ///
    /// With `should_use_proxy` set, the request goes through the proxy. When
    /// it is not set and the direct request fails, the request is retried
    /// once through the proxy if one is configured.
    ///
    /// # Errors
    ///
    /// * [`FetchError::R`] when building the client, sending the request or
    ///   reading the body fails.
    /// * [`FetchError::S`] when the response exceeds `size_limit_bytes`,
    ///   judged by the `Content-Length` header and by the number of bytes
    ///   actually received.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError> {
        let mut builder = reqwest::ClientBuilder::new();
        if self.should_use_proxy {
            builder = self.set_proxy_async(builder)?;
        }

        let response = match self.get_async(builder.build()?).await {
            Ok(response) => response,
            Err(err) if !self.should_use_proxy && self.proxy.is_some() => {
                debug!("direct request failed ({err}); retrying through the proxy");
                let builder = self.set_proxy_async(reqwest::ClientBuilder::new())?;
                self.get_async(builder.build()?).await?
            }
            Err(err) => return Err(FetchError::R(err)),
        };

        // Reject early when the server announces an oversized body. A length
        // that does not even fit in `usize` is necessarily over the limit.
        if let (Some(limit), Some(announced)) = (self.size_limit_bytes, response.content_length())
        {
            if usize::try_from(announced).map_or(true, |len| len > limit) {
                return Err(FetchError::S);
            }
        }

        let body = response.bytes().await?.to_vec();

        // `Content-Length` may be absent (e.g. chunked encoding), so the
        // limit is also enforced on what was actually received.
        if self.size_limit_bytes.is_some_and(|limit| body.len() > limit) {
            return Err(FetchError::S);
        }

        Ok(body)
    }
}
317
/// Gives access to a textual location (path or URL) of a source, if it has
/// one.
pub trait GetPath {
    /// Returns the location of the source. The default implementation
    /// reports no location.
    fn get_path(&self) -> Option<String> {
        Option::None
    }
}
323
/// Where the content of one single file comes from.
#[derive(Debug)]
pub enum SingleFileSource {
    /// Downloaded over HTTP(S), optionally cached in a local file.
    #[cfg(feature = "reqwest")]
    Http(HttpSource, FileCache),
    /// Read from the given path on the local file system.
    FilePath(String),
    /// Content held directly in memory.
    Inline(Vec<u8>),
}
331impl Default for SingleFileSource {
332    fn default() -> Self {
333        Self::Inline(Vec::new())
334    }
335}
336
337impl GetPath for SingleFileSource {
338    fn get_path(&self) -> Option<String> {
339        match self {
340            #[cfg(feature = "reqwest")]
341            SingleFileSource::Http(http_source, _fc) => Some(http_source.url.clone()),
342            SingleFileSource::FilePath(p) => Some(p.clone()),
343            SingleFileSource::Inline(_ec) => None,
344        }
345    }
346}
347
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncSource for SingleFileSource {
    /// Fetches the content asynchronously: HTTP sources go through the file
    /// cache, file paths are read from disk, and inline content is cloned.
    async fn fetch_async(&self) -> Result<Vec<u8>, FetchError> {
        match self {
            #[cfg(feature = "reqwest")]
            Self::Http(source, cache) => fetch_with_cache_async(cache, source).await,
            Self::FilePath(path) => Ok(tokio::fs::read(path).await?),
            Self::Inline(bytes) => Ok(bytes.clone()),
        }
    }
}
365
366impl SyncSource for SingleFileSource {
367    fn fetch(&self) -> Result<Vec<u8>, FetchError> {
368        match self {
369            #[cfg(feature = "reqwest")]
370            SingleFileSource::Http(http_source, fc) => fetch_with_cache(fc, http_source),
371            SingleFileSource::FilePath(f) => {
372                let s: Vec<u8> = std::fs::read(f)?;
373                Ok(s)
374            }
375            SingleFileSource::Inline(v) => Ok(v.clone()),
376        }
377    }
378}
379
380/// Defines where to get the content of the requested file name.
381///
382/// 很多配置中 都要再加载其他外部文件,
383/// FileSource 限定了 查找文件的 路径 和 来源, 读取文件时只会限制在这个范围内,
384/// 这样就增加了安全性
385#[derive(Debug, Default)]
386pub enum DataSource {
387    #[default]
388    StdReadFile,
389    ///从指定的一组路径来寻找文件
390    Folders(Vec<String>),
391    /// 从一个 已放到内存中的 tar 中 寻找文件
392    #[cfg(feature = "tar")]
393    TarInMemory(Vec<u8>),
394    #[cfg(feature = "tar")]
395    TarFile(TarFile),
396
397    /// 与其它方式不同,FileMap 存储名称的映射表, 无需遍历目录
398    FileMap(HashMap<String, SingleFileSource>),
399
400    Sync(Box<dyn SyncFolderSource + Send + Sync>),
401    #[cfg(feature = "tokio")]
402    Async(Box<dyn AsyncFolderSource + Send + Sync>),
403}
404
405impl DataSource {
406    pub fn insert_current_working_dir(&mut self) -> io::Result<()> {
407        if let DataSource::Folders(ref mut v) = self {
408            v.push(std::env::current_dir()?.to_string_lossy().to_string())
409        }
410        Ok(())
411    }
412
413    pub fn read_to_string<P>(&self, file_name: P) -> Result<String, FetchError>
414    where
415        P: AsRef<std::path::Path>,
416    {
417        let r = SyncFolderSource::get_file_content(self, file_name.as_ref())?;
418        Ok(String::from_utf8_lossy(r.0.as_slice()).to_string())
419    }
420}
#[cfg(feature = "tokio")]
#[async_trait::async_trait]
impl AsyncFolderSource for DataSource {
    /// Returns the data that was read, possibly together with the location
    /// in which the file was found.
    ///
    /// # Errors
    ///
    /// * [`FetchError::NFD`] when no directory of a `Folders` source contains
    ///   the file.
    /// * [`FetchError::NF`] when a `FileMap` has no entry for the file.
    /// * Whatever error the underlying source reports otherwise.
    async fn get_file_content_async(
        &self,
        file_name: &Path,
    ) -> Result<(Vec<u8>, Option<String>), FetchError> {
        match self {
            DataSource::Async(source) => source.get_file_content_async(file_name).await,

            DataSource::Sync(source) => source.get_file_content(file_name),
            #[cfg(feature = "tar")]
            DataSource::TarInMemory(tar_binary) => {
                get_file_from_tar_in_memory(file_name, tar_binary)
            }
            #[cfg(all(feature = "tar", feature = "tokio-tar"))]
            DataSource::TarFile(tf) => tf.get_file_content_async(file_name).await,
            // The `TarFile` variant exists whenever `tar` is enabled. Without
            // `tokio-tar` there is no asynchronous reader, so fall back to
            // the blocking one to keep this match exhaustive.
            #[cfg(all(feature = "tar", not(feature = "tokio-tar")))]
            DataSource::TarFile(tf) => tf.get_file_content(file_name),

            DataSource::Folders(possible_addrs) => {
                // NOTE(review): `Path::join` replaces the base when
                // `file_name` is absolute and `..` components are not
                // filtered, so lookups are not strictly confined to the
                // listed directories.
                for dir in possible_addrs {
                    let real_file_name = Path::new(dir).join(file_name);

                    if real_file_name.exists() {
                        let content = tokio::fs::read(&real_file_name).await?;
                        return Ok((content, Some(dir.to_owned())));
                    }
                }
                Err(FetchError::NFD(possible_addrs.clone()))
            }
            DataSource::StdReadFile => {
                let content = tokio::fs::read(file_name).await?;
                Ok((content, None))
            }

            DataSource::FileMap(map) => {
                // Borrow the key as `&str`; no owned `String` is needed for
                // the lookup.
                let key = file_name.to_string_lossy();
                match map.get(&*key) {
                    Some(sf) => sf.fetch_async().await.map(|d| (d, sf.get_path())),
                    None => Err(FetchError::NF),
                }
            }
        }
    }
}
468
469impl SyncFolderSource for DataSource {
470    /// 返回读到的 数据。可能还会返回 成功找到的路径
471    fn get_file_content(&self, file_name: &Path) -> Result<(Vec<u8>, Option<String>), FetchError> {
472        match self {
473            DataSource::Sync(source) => source.get_file_content(file_name),
474
475            #[cfg(feature = "tokio")]
476            DataSource::Async(source) => {
477                tokio::runtime::Handle::current().block_on(source.get_file_content_async(file_name))
478            }
479
480            #[cfg(feature = "tar")]
481            DataSource::TarInMemory(tar_binary) => {
482                get_file_from_tar_in_memory(file_name, tar_binary)
483            }
484            #[cfg(feature = "tar")]
485            DataSource::TarFile(tf) => tf.get_file_content(file_name),
486
487            DataSource::Folders(possible_addrs) => {
488                for dir in possible_addrs {
489                    let real_file_name = std::path::Path::new(dir).join(file_name);
490
491                    if real_file_name.exists() {
492                        return Ok(
493                            std::fs::read(&real_file_name).map(|v| (v, Some(dir.to_owned())))?
494                        );
495                    }
496                }
497                Err(FetchError::NFD(possible_addrs.clone()))
498            }
499            DataSource::StdReadFile => {
500                let s: Vec<u8> = std::fs::read(file_name)?;
501                Ok((s, None))
502            }
503
504            DataSource::FileMap(map) => {
505                let r = map.get(&file_name.to_string_lossy().to_string());
506
507                match r {
508                    Some(sf) => sf.fetch().map(|d| (d, sf.get_path())),
509                    None => Err(FetchError::NF),
510                }
511            }
512        }
513    }
514}
515
/// Searches the tar archive read from `reader` for the entry whose path
/// equals `file_name_in_tar` and returns its content together with the
/// entry's path inside the archive.
///
/// # Errors
///
/// * [`FetchError::NF`] when the archive has no such entry.
/// * [`FetchError::I`] when the archive cannot be read or an entry is
///   malformed. These cases used to panic.
#[cfg(feature = "tokio-tar")]
pub async fn get_file_from_tar_by_reader_async<P, R>(
    file_name_in_tar: P,
    reader: R,
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
    R: tokio::io::AsyncRead + Unpin,
{
    use futures::StreamExt;
    use tokio::io::AsyncReadExt;

    let mut archive = tokio_tar::Archive::new(reader);
    let mut entries = archive.entries()?;

    while let Some(entry) = entries.next().await {
        let mut entry = entry?;

        // Resolve the match in its own scope so that the borrow of the entry
        // header ends before the entry is read mutably.
        let matched_path = {
            let path = entry.path()?;
            (&*path == file_name_in_tar.as_ref()).then(|| path.to_string_lossy().into_owned())
        };

        if let Some(path_in_tar) = matched_path {
            // `display()` cannot panic on non-UTF-8 names.
            debug!("found {}", file_name_in_tar.as_ref().display());
            let mut content = Vec::new();
            entry.read_to_end(&mut content).await?;
            return Ok((content, Some(path_in_tar)));
        }
    }
    Err(FetchError::NF)
}
/// Searches the tar archive read from `reader` for the entry whose path
/// equals `file_name_in_tar` and returns its content together with the
/// entry's path inside the archive.
///
/// Entries that cannot be read are skipped during the search.
///
/// # Errors
///
/// Returns [`FetchError::I`] with kind [`io::ErrorKind::NotFound`] when the
/// archive has no such entry, and [`FetchError::I`] with the underlying error
/// when the archive or the entry cannot be read. Unreadable archives and
/// non-UTF-8 names used to panic.
#[cfg(feature = "tar")]
pub fn get_file_from_tar_by_reader<P, R>(
    file_name_in_tar: P,
    reader: R,
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
    R: std::io::Read,
{
    use std::io::Read;

    let wanted = file_name_in_tar.as_ref();
    let mut archive = tar::Archive::new(reader);

    let mut entry = archive
        .entries()?
        .filter_map(Result::ok)
        .find(|entry| entry.path().is_ok_and(|path| &*path == wanted))
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                format!(
                    "get_file_from_tar: can't find the file, {}",
                    wanted.display()
                ),
            )
        })?;

    // `display()` cannot panic on non-UTF-8 names.
    debug!("found {}", wanted.display());

    let mut content = Vec::new();
    entry.read_to_end(&mut content)?;
    let path_in_tar = entry.path()?.to_string_lossy().into_owned();
    Ok((content, Some(path_in_tar)))
}
/// Searches a tar archive held in memory for the entry whose path equals
/// `file_name_in_tar`.
///
/// `tar_binary` is the raw archive; any byte slice is accepted, so existing
/// callers passing `&Vec<u8>` keep working. The result and the errors are
/// those of [`get_file_from_tar_by_reader`].
#[cfg(feature = "tar")]
pub fn get_file_from_tar_in_memory<P>(
    file_name_in_tar: P,
    tar_binary: &[u8],
) -> Result<(Vec<u8>, Option<String>), FetchError>
where
    P: AsRef<std::path::Path>,
{
    // `display()` avoids a panic on non-UTF-8 names just for a log line.
    debug!(
        "finding {} from tar, tar whole size is {}",
        file_name_in_tar.as_ref().display(),
        tar_binary.len()
    );
    get_file_from_tar_by_reader(file_name_in_tar, std::io::Cursor::new(tar_binary))
}
600
/// Unit tests. The HTTP tests need network access to reach `URL`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    // Only the tar fixture needs these; gating them avoids unused-import
    // warnings when the `tar` feature is disabled.
    #[cfg(feature = "tar")]
    use std::{fs::File, io::Write, path::PathBuf};
    use tempfile::TempDir;

    #[cfg(feature = "reqwest")]
    use reqwest::blocking::Client;

    // Only the HTTP tests use the URL.
    #[cfg(feature = "reqwest")]
    const URL: &str = "https://www.rust-lang.org";

    /// Fetches `URL` asynchronously without a proxy and expects a non-empty
    /// body.
    #[cfg(feature = "tokio")]
    #[cfg(feature = "reqwest")]
    #[tokio::test]
    async fn test_http_source_fetch_async() {
        let http_source = HttpSource {
            url: URL.to_string(),
            should_use_proxy: false,
            ..Default::default()
        };

        let result = http_source.fetch_async().await;
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty());
    }

    /// Sends a blocking GET request to `URL` and expects it to succeed.
    #[cfg(feature = "reqwest")]
    #[test]
    fn test_http_source_fetch() {
        let http_source = HttpSource {
            url: URL.to_string(),
            should_use_proxy: false,
            ..Default::default()
        };

        let client = Client::new();
        let result = http_source.get(client);
        assert!(result.is_ok());
    }

    /// A `Folders` source finds a file placed in one of its directories.
    #[test]
    fn test_data_source_read_from_folders() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        fs::write(&file_path, "hello world").unwrap();

        let data_source = DataSource::Folders(vec![temp_dir.path().to_string_lossy().to_string()]);

        let content = data_source.read_to_string("test.txt").unwrap();
        assert_eq!(content, "hello world");
    }

    /// A `FileMap` source returns the inline content registered for a name.
    #[test]
    fn test_data_source_read_from_file_map() {
        let file_map = vec![(
            "config.json".to_string(),
            SingleFileSource::Inline(b"{\"key\": \"value\"}".to_vec()),
        )]
        .into_iter()
        .collect();

        let data_source = DataSource::FileMap(file_map);

        let content = data_source.read_to_string("config.json").unwrap();
        assert_eq!(content, "{\"key\": \"value\"}");
    }

    /// Builds a tar archive containing one file and returns the temporary
    /// directory guard, the archive path, the entry name and the entry
    /// content.
    #[cfg(feature = "tar")]
    fn gentar() -> (TempDir, PathBuf, &'static str, &'static str) {
        let temp_dir = TempDir::new().unwrap();
        let tar_path = temp_dir.path().join("test.tar");

        let mut tar_builder = tar::Builder::new(File::create(&tar_path).unwrap());

        let mut file = tempfile::NamedTempFile::new().unwrap();
        let c = "hello tar\n";
        write!(file, "{}", c).unwrap();
        let file_path = file.path().to_owned();

        let tfn = "test.txt";

        tar_builder.append_path_with_name(&file_path, tfn).unwrap();
        tar_builder.finish().unwrap();

        (temp_dir, tar_path, tfn, c)
    }

    /// Extracts the single entry from an in-memory copy of the archive.
    #[cfg(feature = "tar")]
    #[test]
    fn test_get_file_from_tar() {
        // The guard must not be bound to a bare `_`: that would drop it at
        // once, and dropping a `TempDir` deletes the directory and its
        // content.
        let (_td, tar_path, tfn, c) = gentar();

        let tar_data = fs::read(&tar_path).unwrap();
        let result = get_file_from_tar_in_memory(tfn, &tar_data);

        assert!(result.is_ok());
        let (content, path) = result.unwrap();
        assert_eq!(String::from_utf8_lossy(&content), c);
        assert_eq!(path.unwrap(), tfn);
    }

    /// Extracts the single entry through the asynchronous reader.
    // `gentar` is gated on `tar`, so this test needs both features.
    #[cfg(all(feature = "tar", feature = "tokio-tar"))]
    #[tokio::test]
    async fn test_get_file_from_tar_async() -> Result<(), FetchError> {
        let (_td, tar_path, tfn, c) = gentar();

        let f = tokio::fs::File::open(tar_path).await?;
        let result = get_file_from_tar_by_reader_async(tfn, f).await?;

        let (content, path) = result;
        assert_eq!(String::from_utf8_lossy(&content), c);
        assert_eq!(path.unwrap(), tfn);
        Ok(())
    }
}
717}