// url_bot_rs/config.rs

/*
 * Application configuration
 *
 */
5use std::{
6    fs::{self, File},
7    io::Write,
8    path::{Path, PathBuf},
9    collections::BTreeMap,
10};
11use irc::client::data::Config as IrcConfig;
12use failure::{Error, bail};
13use directories::{BaseDirs, ProjectDirs};
14use serde_derive::{Serialize, Deserialize};
15use log::info;
16
17use crate::{
18    VERSION,
19    plugins::PluginConfig,
20    http::{Retriever, RetrieverBuilder},
21};
22
23#[derive(Serialize, Deserialize, Clone)]
24pub struct Network {
25    pub name: String,
26    pub enable: bool,
27}
28
29impl Default for Network {
30    fn default() -> Self {
31        Self {
32            name: "default".into(),
33            enable: true,
34        }
35    }
36}
37
38#[derive(Default, Serialize, Deserialize, Clone)]
39#[serde(default)]
40pub struct Features {
41    pub report_metadata: bool,
42    pub report_mime: bool,
43    pub mask_highlights: bool,
44    pub send_notice: bool,
45    pub history: bool,
46    pub cross_channel_history: bool,
47    pub invite: bool,
48    pub autosave: bool,
49    pub send_errors_to_poster: bool,
50    pub reply_with_errors: bool,
51    pub partial_urls: bool,
52    pub nick_response: bool,
53    pub reconnect: bool,
54}
55
/// Read a feature switch from a run-time data value.
///
/// `feat!(rtd, history)` expands to `rtd.conf.features.history`.
#[macro_export]
macro_rules! feat {
    ($data:expr, $switch:ident) => {
        $data.conf.features.$switch
    };
}
62
63
64#[derive(Serialize, Deserialize, Clone, PartialEq)]
65#[serde(rename_all = "kebab-case")]
66pub enum DbType {
67    InMemory,
68    Sqlite,
69}
70
71impl Default for DbType {
72    fn default() -> Self {
73        Self::InMemory
74    }
75}
76
77#[derive(Serialize, Deserialize, Default, Clone)]
78#[serde(default)]
79pub struct Database {
80    #[serde(rename = "type")]
81    pub db_type: DbType,
82    pub path: Option<String>,
83}
84
85#[derive(Serialize, Deserialize, Clone)]
86#[serde(default)]
87pub struct Parameters {
88    pub url_limit: u8,
89    pub status_channels: Vec<String>,
90    pub nick_response_str: String,
91    pub reconnect_timeout: u64,
92    pub ignore_nicks: Vec<String>,
93}
94
95impl Default for Parameters {
96    fn default() -> Self {
97        Self {
98            url_limit: 10,
99            status_channels: vec![],
100            nick_response_str: "".to_string(),
101            reconnect_timeout: 10,
102            ignore_nicks: vec![],
103        }
104    }
105}
106
/// Read a value from the `[parameters]` section of a run-time data value.
///
/// `param!(rtd, url_limit)` expands to `rtd.conf.params.url_limit`.
#[macro_export]
macro_rules! param {
    ($data:expr, $field:ident) => {
        $data.conf.params.$field
    };
}
113
114#[derive(Serialize, Deserialize, Clone)]
115#[serde(default)]
116pub struct Http {
117    pub timeout_s: u64,
118    pub max_redirections: u8,
119    pub max_retries: u8,
120    pub retry_delay_s: u64,
121    pub accept_lang: String,
122    pub user_agent: Option<String>,
123}
124
125impl Default for Http {
126    fn default() -> Self {
127        Self {
128            timeout_s: 10,
129            max_redirections: 10,
130            max_retries: 3,
131            retry_delay_s: 5,
132            accept_lang: "en".to_string(),
133            user_agent: None,
134        }
135    }
136}
137
/// Read a value from the `[http]` section of a run-time data value.
///
/// `http!(rtd, timeout_s)` expands to `rtd.conf.http_params.timeout_s`.
#[macro_export]
macro_rules! http {
    ($data:expr, $field:ident) => {
        $data.conf.http_params.$field
    };
}
144
145#[derive(Serialize, Deserialize, Clone)]
146pub struct Conf {
147    #[serde(default)]
148    pub plugins: PluginConfig,
149    #[serde(default)]
150    pub network: Network,
151    #[serde(default)]
152    pub features: Features,
153    #[serde(default, rename = "parameters")]
154    pub params: Parameters,
155    #[serde(default, rename = "http")]
156    pub http_params: Http,
157    #[serde(default)]
158    pub database: Database,
159    #[serde(rename = "connection")]
160    pub client: IrcConfig,
161    #[serde(skip)]
162    pub path: Option<PathBuf>,
163}
164
165impl Conf {
166    /// load configuration TOML from a file
167    pub fn load(path: impl AsRef<Path>) -> Result<Self, Error> {
168        let conf = fs::read_to_string(path.as_ref())?;
169        let mut conf: Conf = toml::de::from_str(&conf)?;
170        // insert the path the config was loaded from
171        conf.path = Some(path.as_ref().to_path_buf());
172        Ok(conf)
173    }
174
175    /// write configuration to a file
176    pub fn write(&self, path: impl AsRef<Path>) -> Result<(), Error> {
177        let mut file = File::create(path)?;
178        file.write_all(toml::ser::to_string(&self)?.as_bytes())?;
179        Ok(())
180    }
181
182    /// add an IRC channel to the list of channels in the configuration
183    pub fn add_channel(&mut self, name: String) {
184        if let Some(ref mut c) = self.client.channels {
185            if !c.contains(&name) {
186                c.push(name);
187            }
188        }
189    }
190
191    /// remove an IRC channel from the list of channels in the configuration
192    pub fn remove_channel(&mut self, name: &str) {
193        if let Some(ref mut c) = self.client.channels {
194            if let Some(index) = c.iter().position(|c| c == name) {
195                c.remove(index);
196            }
197        }
198    }
199}
200
201impl Default for Conf {
202    fn default() -> Self {
203        Self {
204            plugins: PluginConfig::default(),
205            network: Network::default(),
206            features: Features::default(),
207            params: Parameters::default(),
208            http_params: Http::default(),
209            database: Database::default(),
210            client: IrcConfig {
211                nickname: Some("url-bot-rs".to_string()),
212                alt_nicks: Some(vec!["url-bot-rs_".to_string()]),
213                nick_password: Some("".to_string()),
214                username: Some("url-bot-rs".to_string()),
215                realname: Some("url-bot-rs".to_string()),
216                server: Some("127.0.0.1".to_string()),
217                port: Some(6667),
218                password: Some("".to_string()),
219                use_ssl: Some(false),
220                channels: Some(vec!["#url-bot-rs".to_string()]),
221                user_info: Some("Feed me URLs.".to_string()),
222                ..IrcConfig::default()
223            },
224            path: None,
225        }
226    }
227}
228
229#[derive(Serialize, Deserialize, Clone)]
230pub struct ConfSet {
231    #[serde(flatten)]
232    pub configs: BTreeMap<String, Conf>,
233}
234
235impl ConfSet {
236    /// load configuration TOML from a file
237    pub fn load(path: impl AsRef<Path>) -> Result<Self, Error> {
238        let conf_string = fs::read_to_string(path.as_ref())?;
239        let mut conf_set: ConfSet = toml::de::from_str(&conf_string)?;
240
241        // populate path field of all configs
242        conf_set.configs
243            .iter_mut()
244            .for_each(|(_, c)| c.path = Some(path.as_ref().to_path_buf()));
245
246        Ok(conf_set)
247    }
248
249    /// write configuration to a file
250    pub fn write(&self, path: impl AsRef<Path>) -> Result<(), Error> {
251        let mut file = File::create(path)?;
252        file.write_all(toml::ser::to_string(&self)?.as_bytes())?;
253        Ok(())
254    }
255}
256
257/// Run-time configuration data.
258#[derive(Default, Clone)]
259pub struct Rtd {
260    /// paths
261    pub paths: Paths,
262    /// configuration file data
263    pub conf: Conf,
264    /// HTTP client
265    client: Option<Retriever>,
266}
267
268#[derive(Default, Clone)]
269pub struct Paths {
270    pub db: Option<PathBuf>,
271}
272
273impl Rtd {
274    pub fn new() -> Self {
275        Rtd::default()
276    }
277
278    /// Set the configuration
279    pub fn conf(mut self, c: Conf) -> Self {
280        self.conf = c;
281        self
282    }
283
284    pub fn db(mut self, path: Option<&PathBuf>) -> Self {
285        self.paths.db = path.map(|p| expand_tilde(p));
286        self
287    }
288
289    pub fn init_http_client(mut self) -> Result<Self, Error> {
290        let conf = &self.conf.http_params;
291
292        let mut builder = RetrieverBuilder::new()
293            .retry(conf.max_retries.into(), conf.retry_delay_s)
294            .timeout(conf.timeout_s)
295            .accept_lang(&conf.accept_lang)
296            .redirect_limit(conf.max_redirections.into());
297
298        if let Some(ref user_agent) = conf.user_agent {
299            builder = builder.user_agent(user_agent);
300        };
301
302        self.client = Some(builder.build()?);
303
304        Ok(self)
305    }
306
307    pub fn get_client(&self) -> Result<&Retriever, Error> {
308        let client = match self.client.as_ref() {
309            None => bail!("HTTP client not initialised"),
310            Some(c) => c,
311        };
312
313        Ok(client)
314    }
315
316    /// Load the configuration file and return an Rtd.
317    pub fn load(mut self) -> Result<Self, Error> {
318        // get a database path
319        self.paths.db = self.get_db_path().map(|p| expand_tilde(&p));
320
321        if let Some(dp) = &self.paths.db {
322            ensure_parent_dir(dp)?;
323        }
324
325        // set url-bot-rs version number in the irc client configuration
326        self.conf.client.version = Some(VERSION.to_string());
327
328        Ok(self)
329    }
330
331    fn get_db_path(&mut self) -> Option<PathBuf> {
332        if self.conf.features.history {
333            match self.conf.database.db_type {
334                DbType::InMemory => None,
335                DbType::Sqlite => self.get_sqlite_path(),
336            }
337        } else {
338            None
339        }
340    }
341
342    fn get_sqlite_path(&self) -> Option<PathBuf> {
343        let mut path = self.conf.database.path.as_ref()
344            .filter(|p| !p.is_empty())
345            .map(PathBuf::from);
346
347        if self.paths.db.is_some() && path.is_none() {
348            path = self.paths.db.clone();
349        };
350
351        if path.is_none() {
352            // generate and use a default database path
353            let dirs = ProjectDirs::from("org", "", "url-bot-rs").unwrap();
354            let db = format!("history.{}.db", self.conf.network.name);
355            let db = dirs.data_local_dir().join(&db);
356            path = Some(db);
357        };
358
359        path
360    }
361}
362
363pub fn ensure_parent_dir(file: &Path) -> Result<bool, Error> {
364    let without_path = file.components().count() == 1;
365
366    match file.parent() {
367        Some(dir) if !without_path => {
368            let create = !dir.exists();
369            if create {
370                info!(
371                    "directory `{}` doesn't exist, creating it", dir.display()
372                );
373                fs::create_dir_all(dir)?;
374            }
375            Ok(create)
376        },
377        _ => Ok(false),
378    }
379}
380
381fn expand_tilde(path: &Path) -> PathBuf {
382    match (BaseDirs::new(), path.strip_prefix("~")) {
383        (Some(bd), Ok(stripped)) => bd.home_dir().join(stripped),
384        _ => path.to_owned(),
385    }
386}
387
388/// non-recursively search for configuration files in a directory
389pub fn find_configs_in_dir(dir: &Path) -> Result<impl Iterator<Item = PathBuf>, Error> {
390    Ok(fs::read_dir(dir)?
391        .flatten()
392        .map(|e| e.path())
393        .filter(|e| !e.is_dir() && (Conf::load(e).is_ok() || ConfSet::load(e).is_ok()))
394        .take(32))
395}
396
397/// Take a vector of paths to either configurations, or configuration sets,
398/// and return a vector of configurations
399pub fn load_flattened_configs(paths: Vec<PathBuf>) -> Vec<Conf> {
400    let mut configs: Vec<Conf> = paths.iter()
401        .filter_map(|p| Conf::load(p).ok())
402        .collect();
403
404    let mut set_configs: Vec<Conf> = paths.into_iter()
405        .filter_map(|p| ConfSet::load(p).ok())
406        .flat_map(|s| s.configs.values().cloned().collect::<Vec<Conf>>())
407        .collect();
408
409    configs.append(&mut set_configs);
410    configs
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;
    use std::env;
    use std::iter;
    use std::panic;

    /// Absolute path of a file in the crate root, where the example
    /// configurations live.
    ///
    /// `test_ensure_parent_relative` temporarily changes the process-wide
    /// working directory, and the test harness runs tests on parallel
    /// threads. Tests must therefore not open files through CWD-relative
    /// paths.
    fn crate_file(name: &str) -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join(name)
    }

    #[test]
    /// test that the example configuration file parses without error
    fn load_example_configs() {
        Conf::load(crate_file("example.config.toml")).unwrap();
        ConfSet::load(crate_file("example.multi.config.toml")).unwrap();
    }

    #[test]
    fn load_write_default() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");

        let conf = Conf::default();
        conf.write(&cfg_path).unwrap();

        let example = fs::read_to_string(crate_file("example.config.toml")).unwrap();
        let written = fs::read_to_string(cfg_path).unwrap();

        example.lines()
            .zip(written.lines())
            .for_each(|(a, b)| assert_eq!(a, b));
    }

    fn get_test_confset() -> ConfSet {
        let mut confset = ConfSet {
            configs: BTreeMap::new()
        };

        let mut conf = Conf::default();
        conf.network.name = String::from("foo");
        confset.configs.insert("foo".to_string(), conf.clone());
        conf.network.name = String::from("bar");
        confset.configs.insert("bar".to_string(), conf);

        confset
    }

    #[test]
    fn load_write_default_set() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");

        let conf = get_test_confset();
        conf.write(&cfg_path).unwrap();

        let example = fs::read_to_string(crate_file("example.multi.config.toml")).unwrap();
        let written = fs::read_to_string(cfg_path).unwrap();

        example.lines()
            .zip(written.lines())
            .for_each(|(a, b)| assert_eq!(a, b));
    }

    #[test]
    fn sqlite_path_explicit() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");
        let db_path = tmp_dir.path().join("test.db");

        let mut cfg = Conf::default();
        cfg.features.history = true;
        cfg.database.db_type = DbType::Sqlite;
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .db(Some(&db_path))
            .load()
            .unwrap();

        assert_eq!(rtd.paths.db, Some(db_path));
    }

    #[test]
    fn sqlite_path_config_overrides_explicit() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");
        let db_path = tmp_dir.path().join("test.db");
        let db_path_cfg = tmp_dir.path().join("cfg.test.db");

        let mut cfg = Conf::default();
        cfg.features.history = true;
        cfg.database.db_type = DbType::Sqlite;
        cfg.database.path = Some(db_path_cfg.to_str().unwrap().to_string());
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .db(Some(&db_path))
            .load()
            .unwrap();

        assert_eq!(rtd.paths.db, Some(db_path_cfg));
    }

    #[test]
    fn sqlite_path_empty_config_path_does_not_override_explicit() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");
        let db_path = tmp_dir.path().join("cfg.test.db");

        let mut cfg = Conf::default();
        cfg.features.history = true;
        cfg.database.db_type = DbType::Sqlite;
        cfg.database.path = Some("".to_string());
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .db(Some(&db_path))
            .load()
            .unwrap();

        assert_eq!(rtd.paths.db, Some(db_path));
    }

    #[test]
    fn sqlite_path_default() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");

        let mut cfg = Conf::default();
        cfg.features.history = true;
        cfg.database.db_type = DbType::Sqlite;
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .load()
            .unwrap();

        let dirs = ProjectDirs::from("org", "", "url-bot-rs").unwrap();
        let default = dirs.data_local_dir().join("history.default.db");
        println!("database path: {}", default.to_str().unwrap());
        assert_eq!(rtd.paths.db, Some(default));

        cfg.network.name = "test_net".to_string();
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .load()
            .unwrap();

        let default = dirs.data_local_dir().join("history.test_net.db");
        println!("database path: {}", default.to_str().unwrap());
        assert_eq!(rtd.paths.db, Some(default.clone()));

        cfg.database.path = Some("".to_string());
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .load()
            .unwrap();

        assert_eq!(rtd.paths.db, Some(default));
    }

    #[test]
    fn sqlite_path_config_overrides_default() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("config.toml");
        let db_path_cfg = tmp_dir.path().join("cfg.test.db");

        let mut cfg = Conf::default();
        cfg.features.history = true;
        cfg.database.db_type = DbType::Sqlite;
        cfg.database.path = Some(db_path_cfg.to_str().unwrap().to_string());
        cfg.write(&cfg_path).unwrap();

        let conf = Conf::load(&cfg_path).unwrap();
        let rtd = Rtd::new()
            .conf(conf)
            .load()
            .unwrap();

        assert_eq!(rtd.paths.db, Some(db_path_cfg));
    }

    #[test]
    fn test_ensure_parent() {
        let tmp_dir = tempdir().unwrap();
        let tmp_path = tmp_dir.path().join("test/test.file");

        assert_eq!(ensure_parent_dir(&tmp_path).unwrap(), true);
        assert_eq!(ensure_parent_dir(&tmp_path).unwrap(), false);
        assert_eq!(ensure_parent_dir(&tmp_path).unwrap(), false);
    }

    #[test]
    /// CWD should always exist, so don't try to create it
    fn test_ensure_parent_file_in_cwd() {
        assert_eq!(ensure_parent_dir(Path::new("test.f")).unwrap(), false);
        assert_eq!(ensure_parent_dir(Path::new("./test.f")).unwrap(), false);
    }

    #[test]
    fn test_ensure_parent_relative() {
        let tmp_dir = tempdir().unwrap();
        let test_dir = tmp_dir.path().join("subdir");
        println!("creating temp path: {}", test_dir.display());
        fs::create_dir_all(&test_dir).unwrap();

        // NOTE: this changes the working directory of the whole test process
        // while other tests run concurrently, so the other tests in this
        // module use absolute paths (see `crate_file`).
        let cwd = env::current_dir().unwrap();
        env::set_current_dir(test_dir).unwrap();

        let result = panic::catch_unwind(|| {
            assert_eq!(ensure_parent_dir(Path::new("../dir/file")).unwrap(), true);
            assert_eq!(ensure_parent_dir(Path::new("../dir/file")).unwrap(), false);
            assert_eq!(ensure_parent_dir(Path::new("./dir/file")).unwrap(), true);
            assert_eq!(ensure_parent_dir(Path::new("./dir/file")).unwrap(), false);
            assert_eq!(ensure_parent_dir(Path::new("dir2/file")).unwrap(), true);
            assert_eq!(ensure_parent_dir(Path::new("dir2/file")).unwrap(), false);
            assert_eq!(ensure_parent_dir(Path::new("./dir3/file")).unwrap(), true);
            assert_eq!(ensure_parent_dir(Path::new("dir3/file2")).unwrap(), false);
        });

        // always restore the original working directory, even on failure
        env::set_current_dir(cwd).unwrap();
        assert!(result.is_ok());
    }

    fn print_diff(example: &str, default: &str) {
        // print diff (on failure)
        println!("Configuration diff (- example, + default):");
        for diff in diff::lines(&example, &default) {
            match diff {
                diff::Result::Left(l) => println!("-{}", l),
                diff::Result::Both(l, _) => println!(" {}", l),
                diff::Result::Right(r) => println!("+{}", r)
            }
        }
    }

    #[test]
    /// test that the example configuration matches default values
    fn example_conf_data_matches_generated_default_values() {
        let example = fs::read_to_string(crate_file("example.config.toml")).unwrap();
        let default = toml::ser::to_string(&Conf::default()).unwrap();

        print_diff(&example, &default);

        default.lines()
            .zip(example.lines())
            .for_each(|(a, b)| assert_eq!(a, b));
    }

    #[test]
    /// test that the example multi-network configuration matches the
    /// generated test configuration set
    fn example_conf_data_matches_generated_expected_values() {
        // construct the example
        let confset = get_test_confset();

        let example = fs::read_to_string(crate_file("example.multi.config.toml")).unwrap();
        let default = toml::ser::to_string(&confset).unwrap();

        print_diff(&example, &default);

        default.lines()
            .zip(example.lines())
            .for_each(|(a, b)| assert_eq!(a, b));
    }

    #[test]
    fn conf_add_remove_channel() {
        let mut rtd = Rtd::default();
        check_channels(&rtd, "#url-bot-rs", 1);

        rtd.conf.add_channel("#cheese".to_string());
        check_channels(&rtd, "#cheese", 2);

        rtd.conf.add_channel("#cheese-2".to_string());
        check_channels(&rtd, "#cheese-2", 3);

        rtd.conf.remove_channel("#cheese-2");
        let c = rtd.conf.client.channels.clone().unwrap();

        assert!(!c.contains(&"#cheese-2".to_string()));
        assert_eq!(2, c.len());
    }

    fn check_channels(rtd: &Rtd, contains: &str, len: usize) {
        let c = rtd.conf.client.channels.clone().unwrap();
        println!("{:?}", c);

        assert!(c.contains(&contains.to_string()));
        assert_eq!(len, c.len());
    }

    #[test]
    fn test_expand_tilde() {
        let homedir: PathBuf = BaseDirs::new()
            .unwrap()
            .home_dir()
            .to_owned();

        assert_eq!(expand_tilde(&PathBuf::from("/")),
            PathBuf::from("/"));
        assert_eq!(expand_tilde(&PathBuf::from("/abc/~def/ghi/")),
            PathBuf::from("/abc/~def/ghi/"));
        assert_eq!(expand_tilde(&PathBuf::from("~/")),
            PathBuf::from(format!("{}/", homedir.to_str().unwrap())));
        assert_eq!(expand_tilde(&PathBuf::from("~/ac/df/gi/")),
            PathBuf::from(format!("{}/ac/df/gi/", homedir.to_str().unwrap())));
    }

    fn write_n_configs(n: usize, dir: &Path) {
        iter::repeat(dir)
            .take(n)
            .enumerate()
            .map(|(i, p)| p.join(i.to_string() + ".conf"))
            .for_each(|p| Conf::default().write(p).unwrap());
    }

    #[test]
    fn test_find_configs_in_dir() {
        let tmp_dir = tempdir().unwrap();
        let cfg_dir = tmp_dir.path();

        assert_eq!(find_configs_in_dir(cfg_dir).unwrap().count(), 0);

        write_n_configs(10, cfg_dir);
        assert_eq!(find_configs_in_dir(cfg_dir).unwrap().count(), 10);

        let mut f = File::create(cfg_dir.join("fake.conf")).unwrap();
        f.write_all(b"not a config").unwrap();
        assert_eq!(find_configs_in_dir(cfg_dir).unwrap().count(), 10);

        let mut f = File::create(cfg_dir.join("fake.toml")).unwrap();
        f.write_all(b"[this]\nis = \"valid toml\"").unwrap();
        assert_eq!(find_configs_in_dir(cfg_dir).unwrap().count(), 10);

        fs::create_dir(cfg_dir.join("fake.dir")).unwrap();
        assert_eq!(find_configs_in_dir(cfg_dir).unwrap().count(), 10);

        write_n_configs(33, cfg_dir);
        assert_eq!(find_configs_in_dir(cfg_dir).unwrap().count(), 32);
    }

    #[test]
    fn test_do_not_promiscuously_load_any_toml() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("fake.toml");
        let mut f = File::create(&cfg_path).unwrap();
        f.write_all(b"[this]\nis = \"valid toml\"").unwrap();

        assert!(Conf::load(&cfg_path).is_err());
    }

    #[test]
    fn test_allow_loading_with_missing_optional_fields() {
        let tmp_dir = tempdir().unwrap();
        let cfg_path = tmp_dir.path().join("fake.toml");
        let mut f = File::create(&cfg_path).unwrap();
        f.write_all(b"[connection]\n[features]\n[parameters]\n").unwrap();

        Conf::load(&cfg_path).unwrap();
    }

    #[test]
    fn test_macros() {
        let mut rtd = Rtd::default();
        assert_eq!(10, param!(rtd, url_limit));
        assert_eq!(10, http!(rtd, max_redirections));
        assert!(!feat!(rtd, reconnect));

        rtd.conf.params.url_limit = 100;
        assert_eq!(100, param!(rtd, url_limit));

        rtd.conf.http_params.max_redirections = 100;
        assert_eq!(100, http!(rtd, max_redirections));

        rtd.conf.features.reconnect = true;
        assert!(feat!(rtd, reconnect));
    }

    #[test]
    fn test_load_flattened_configs() {
        let tmp_dir = tempdir().unwrap();
        let mut paths: Vec<PathBuf> = vec![];

        // make 10 normal configuration files
        for c in 0..10 {
            let path = tmp_dir.path().join(format!("conf_{}.toml", c));
            Conf::default().write(&path).unwrap();
            paths.push(path);
        }

        // make 10 configuration sets
        let set = get_test_confset();
        for c in 0..10 {
            let path = tmp_dir.path().join(format!("conf_multi_{}.toml", c));
            set.write(&path).unwrap();
            paths.push(path);
        }

        let res = load_flattened_configs(paths);
        assert_eq!(res.len(), 30);
    }
}