Skip to main content

notify_core/
config.rs

1use std::{
2    collections::BTreeMap,
3    fs,
4    path::{Path, PathBuf},
5};
6
7use serde::{Deserialize, Serialize};
8
9use crate::{NotifyError, Result};
10
11#[derive(Debug, Clone)]
12pub struct ConfigLoad {
13    pub config: Config,
14    pub path: PathBuf,
15}
16
17impl ConfigLoad {
18    pub fn load(explicit_path: Option<&Path>) -> Result<Self> {
19        let path = discover_config_path(explicit_path)?;
20        let contents = fs::read_to_string(&path).map_err(|source| NotifyError::ConfigRead {
21            path: path.clone(),
22            source,
23        })?;
24        let config = toml::from_str(&contents).map_err(|source| NotifyError::ConfigParse {
25            path: path.clone(),
26            source,
27        })?;
28
29        Ok(Self { config, path })
30    }
31}
32
33pub fn discover_config_path(explicit_path: Option<&Path>) -> Result<PathBuf> {
34    if let Some(path) = explicit_path {
35        return Ok(path.to_path_buf());
36    }
37
38    let local = PathBuf::from("notify.toml");
39    if local.exists() {
40        return Ok(local);
41    }
42
43    if let Some(home) = dirs::home_dir() {
44        let path = home
45            .join(".config")
46            .join("agent-notify")
47            .join("config.toml");
48        if path.exists() {
49            return Ok(path);
50        }
51    }
52
53    Err(NotifyError::ConfigNotFound)
54}
55
56#[derive(Debug, Clone, Deserialize)]
57pub struct Config {
58    pub default_channel: Option<String>,
59    #[serde(default)]
60    pub channels: BTreeMap<String, ChannelConfig>,
61}
62
63impl Config {
64    pub fn resolve_channel_name<'a>(&'a self, requested: Option<&'a str>) -> Result<&'a str> {
65        let name = match requested {
66            Some(name) => name,
67            None => self
68                .default_channel
69                .as_deref()
70                .ok_or(NotifyError::DefaultChannelMissing)?,
71        };
72
73        if self.channels.contains_key(name) {
74            Ok(name)
75        } else {
76            Err(NotifyError::ChannelNotFound(name.to_string()))
77        }
78    }
79
80    pub fn channel(&self, name: &str) -> Result<&ChannelConfig> {
81        self.channels
82            .get(name)
83            .ok_or_else(|| NotifyError::ChannelNotFound(name.to_string()))
84    }
85
86    pub fn validation_issues(&self) -> Vec<CheckIssue> {
87        self.validation_issues_with(&ProcessEnv)
88    }
89
90    pub fn validation_issues_with<E: EnvSource>(&self, env: &E) -> Vec<CheckIssue> {
91        let mut issues = Vec::new();
92
93        match self.default_channel.as_deref() {
94            Some(name) if !self.channels.contains_key(name) => {
95                issues.push(CheckIssue::error(
96                    None,
97                    "DEFAULT_CHANNEL_NOT_FOUND",
98                    format!("default_channel \"{name}\" does not exist"),
99                ));
100            }
101            None => {
102                issues.push(CheckIssue::error(
103                    None,
104                    "DEFAULT_CHANNEL_MISSING",
105                    "default_channel is not configured",
106                ));
107            }
108            Some(_) => {}
109        }
110
111        for (name, channel) in &self.channels {
112            issues.extend(channel.validation_issues(name, env));
113        }
114
115        issues
116    }
117
118    pub fn channel_statuses(&self) -> Vec<ChannelStatus> {
119        self.channel_statuses_with(&ProcessEnv)
120    }
121
122    pub fn channel_statuses_with<E: EnvSource>(&self, env: &E) -> Vec<ChannelStatus> {
123        self.channels
124            .iter()
125            .map(|(name, channel)| {
126                let issues = channel.validation_issues(name, env);
127                let missing_env = issues
128                    .iter()
129                    .filter(|issue| issue.code == "MISSING_ENV")
130                    .map(|issue| issue.message.clone())
131                    .collect::<Vec<_>>();
132                let warnings = issues
133                    .iter()
134                    .filter(|issue| issue.level == IssueLevel::Warning)
135                    .map(|issue| issue.message.clone())
136                    .collect::<Vec<_>>();
137                let errors = issues
138                    .iter()
139                    .filter(|issue| issue.level == IssueLevel::Error)
140                    .map(|issue| issue.message.clone())
141                    .collect::<Vec<_>>();
142                let status = if errors.is_empty() {
143                    "ready"
144                } else if !missing_env.is_empty() {
145                    "missing"
146                } else {
147                    "error"
148                };
149
150                ChannelStatus {
151                    name: name.clone(),
152                    channel_type: channel.type_name().to_string(),
153                    status: status.to_string(),
154                    missing_env,
155                    warnings,
156                    errors,
157                }
158            })
159            .collect()
160    }
161}
162
163#[derive(Debug, Clone, Deserialize)]
164#[serde(tag = "type", rename_all = "kebab-case")]
165pub enum ChannelConfig {
166    Telegram(TelegramConfig),
167    DiscordWebhook(DiscordWebhookConfig),
168    DiscordBot(DiscordBotConfig),
169    Ntfy(NtfyConfig),
170    SlackWebhook(SlackWebhookConfig),
171    Pushover(PushoverConfig),
172    Gotify(GotifyConfig),
173    Webhook(WebhookConfig),
174    FileLog(FileLogConfig),
175}
176
177impl ChannelConfig {
178    pub fn type_name(&self) -> &'static str {
179        match self {
180            Self::Telegram(_) => "telegram",
181            Self::DiscordWebhook(_) => "discord-webhook",
182            Self::DiscordBot(_) => "discord-bot",
183            Self::Ntfy(_) => "ntfy",
184            Self::SlackWebhook(_) => "slack-webhook",
185            Self::Pushover(_) => "pushover",
186            Self::Gotify(_) => "gotify",
187            Self::Webhook(_) => "webhook",
188            Self::FileLog(_) => "file-log",
189        }
190    }
191
192    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
193        match self {
194            Self::Telegram(config) => config.validation_issues(channel, env),
195            Self::DiscordWebhook(config) => config.validation_issues(channel, env),
196            Self::DiscordBot(config) => config.validation_issues(channel, env),
197            Self::Ntfy(config) => config.validation_issues(channel, env),
198            Self::SlackWebhook(config) => config.validation_issues(channel, env),
199            Self::Pushover(config) => config.validation_issues(channel, env),
200            Self::Gotify(config) => config.validation_issues(channel, env),
201            Self::Webhook(config) => config.validation_issues(channel, env),
202            Self::FileLog(config) => config.validation_issues(channel),
203        }
204    }
205}
206
207#[derive(Debug, Clone, Deserialize)]
208pub struct TelegramConfig {
209    pub bot_token: Option<String>,
210    pub bot_token_env: Option<String>,
211    pub chat_id: Option<String>,
212    pub chat_id_env: Option<String>,
213    pub parse_mode: Option<String>,
214}
215
216impl TelegramConfig {
217    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
218        let mut issues = Vec::new();
219        validate_secret_pair(
220            channel,
221            "bot_token",
222            self.bot_token.as_deref(),
223            self.bot_token_env.as_deref(),
224            true,
225            env,
226            &mut issues,
227        );
228        validate_secret_pair(
229            channel,
230            "chat_id",
231            self.chat_id.as_deref(),
232            self.chat_id_env.as_deref(),
233            true,
234            env,
235            &mut issues,
236        );
237
238        if let Some(parse_mode) = self.parse_mode.as_deref()
239            && !matches!(parse_mode, "plain" | "html" | "markdown-v2")
240        {
241            issues.push(CheckIssue::error(
242                Some(channel),
243                "INVALID_FIELD",
244                format!("channel \"{channel}\" has invalid parse_mode \"{parse_mode}\""),
245            ));
246        }
247
248        issues
249    }
250}
251
252#[derive(Debug, Clone, Deserialize)]
253pub struct DiscordWebhookConfig {
254    pub webhook_url: Option<String>,
255    pub webhook_url_env: Option<String>,
256    pub username: Option<String>,
257    pub avatar_url: Option<String>,
258    pub allow_mentions: Option<bool>,
259}
260
261impl DiscordWebhookConfig {
262    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
263        let mut issues = Vec::new();
264        validate_secret_pair(
265            channel,
266            "webhook_url",
267            self.webhook_url.as_deref(),
268            self.webhook_url_env.as_deref(),
269            true,
270            env,
271            &mut issues,
272        );
273        issues
274    }
275}
276
277#[derive(Debug, Clone, Deserialize)]
278pub struct DiscordBotConfig {
279    pub bot_token: Option<String>,
280    pub bot_token_env: Option<String>,
281    pub channel_id: Option<String>,
282    pub channel_id_env: Option<String>,
283    pub allow_mentions: Option<bool>,
284}
285
286impl DiscordBotConfig {
287    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
288        let mut issues = Vec::new();
289        validate_secret_pair(
290            channel,
291            "bot_token",
292            self.bot_token.as_deref(),
293            self.bot_token_env.as_deref(),
294            true,
295            env,
296            &mut issues,
297        );
298        validate_secret_pair(
299            channel,
300            "channel_id",
301            self.channel_id.as_deref(),
302            self.channel_id_env.as_deref(),
303            true,
304            env,
305            &mut issues,
306        );
307        issues
308    }
309}
310
311#[derive(Debug, Clone, Deserialize)]
312pub struct NtfyConfig {
313    pub server: Option<String>,
314    pub topic: Option<String>,
315    pub topic_env: Option<String>,
316    pub token: Option<String>,
317    pub token_env: Option<String>,
318}
319
320impl NtfyConfig {
321    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
322        let mut issues = Vec::new();
323        validate_secret_pair(
324            channel,
325            "topic",
326            self.topic.as_deref(),
327            self.topic_env.as_deref(),
328            true,
329            env,
330            &mut issues,
331        );
332        validate_secret_pair(
333            channel,
334            "token",
335            self.token.as_deref(),
336            self.token_env.as_deref(),
337            false,
338            env,
339            &mut issues,
340        );
341        issues
342    }
343}
344
345#[derive(Debug, Clone, Deserialize)]
346pub struct SlackWebhookConfig {
347    pub webhook_url: Option<String>,
348    pub webhook_url_env: Option<String>,
349    pub username: Option<String>,
350    pub icon_emoji: Option<String>,
351    pub icon_url: Option<String>,
352    pub allow_mentions: Option<bool>,
353}
354
355impl SlackWebhookConfig {
356    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
357        let mut issues = Vec::new();
358        validate_secret_pair(
359            channel,
360            "webhook_url",
361            self.webhook_url.as_deref(),
362            self.webhook_url_env.as_deref(),
363            true,
364            env,
365            &mut issues,
366        );
367        issues
368    }
369}
370
371#[derive(Debug, Clone, Deserialize)]
372pub struct PushoverConfig {
373    pub token: Option<String>,
374    pub token_env: Option<String>,
375    pub user: Option<String>,
376    pub user_env: Option<String>,
377    pub device: Option<String>,
378    pub sound: Option<String>,
379}
380
381impl PushoverConfig {
382    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
383        let mut issues = Vec::new();
384        validate_secret_pair(
385            channel,
386            "token",
387            self.token.as_deref(),
388            self.token_env.as_deref(),
389            true,
390            env,
391            &mut issues,
392        );
393        validate_secret_pair(
394            channel,
395            "user",
396            self.user.as_deref(),
397            self.user_env.as_deref(),
398            true,
399            env,
400            &mut issues,
401        );
402        issues
403    }
404}
405
406#[derive(Debug, Clone, Deserialize)]
407pub struct GotifyConfig {
408    pub server: String,
409    pub token: Option<String>,
410    pub token_env: Option<String>,
411    pub priority: Option<i64>,
412}
413
414impl GotifyConfig {
415    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
416        let mut issues = Vec::new();
417        if self.server.trim().is_empty() {
418            issues.push(CheckIssue::error(
419                Some(channel),
420                "MISSING_FIELD",
421                format!("channel \"{channel}\" is missing server"),
422            ));
423        }
424        validate_secret_pair(
425            channel,
426            "token",
427            self.token.as_deref(),
428            self.token_env.as_deref(),
429            true,
430            env,
431            &mut issues,
432        );
433        issues
434    }
435}
436
437#[derive(Debug, Clone, Deserialize)]
438pub struct WebhookConfig {
439    pub url: Option<String>,
440    pub url_env: Option<String>,
441    pub auth_header: Option<String>,
442    pub auth_header_env: Option<String>,
443    pub timeout_seconds: Option<u64>,
444}
445
446impl WebhookConfig {
447    fn validation_issues<E: EnvSource>(&self, channel: &str, env: &E) -> Vec<CheckIssue> {
448        let mut issues = Vec::new();
449        validate_secret_pair(
450            channel,
451            "url",
452            self.url.as_deref(),
453            self.url_env.as_deref(),
454            true,
455            env,
456            &mut issues,
457        );
458        validate_secret_pair(
459            channel,
460            "auth_header",
461            self.auth_header.as_deref(),
462            self.auth_header_env.as_deref(),
463            false,
464            env,
465            &mut issues,
466        );
467        if matches!(self.timeout_seconds, Some(0)) {
468            issues.push(CheckIssue::error(
469                Some(channel),
470                "INVALID_FIELD",
471                format!("channel \"{channel}\" timeout_seconds must be greater than 0"),
472            ));
473        }
474        issues
475    }
476}
477
478#[derive(Debug, Clone, Deserialize)]
479pub struct FileLogConfig {
480    pub path: PathBuf,
481}
482
483impl FileLogConfig {
484    fn validation_issues(&self, channel: &str) -> Vec<CheckIssue> {
485        if self.path.as_os_str().is_empty() {
486            vec![CheckIssue::error(
487                Some(channel),
488                "MISSING_FIELD",
489                format!("channel \"{channel}\" is missing path"),
490            )]
491        } else {
492            Vec::new()
493        }
494    }
495}
496
497#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
498#[serde(rename_all = "lowercase")]
499pub enum IssueLevel {
500    Error,
501    Warning,
502}
503
504#[derive(Debug, Clone, Serialize)]
505pub struct CheckIssue {
506    pub level: IssueLevel,
507    pub channel: Option<String>,
508    pub code: String,
509    pub message: String,
510}
511
512impl CheckIssue {
513    pub fn error(channel: Option<&str>, code: &str, message: impl Into<String>) -> Self {
514        Self {
515            level: IssueLevel::Error,
516            channel: channel.map(ToOwned::to_owned),
517            code: code.to_string(),
518            message: message.into(),
519        }
520    }
521
522    pub fn warning(channel: Option<&str>, code: &str, message: impl Into<String>) -> Self {
523        Self {
524            level: IssueLevel::Warning,
525            channel: channel.map(ToOwned::to_owned),
526            code: code.to_string(),
527            message: message.into(),
528        }
529    }
530
531    pub fn is_error(&self) -> bool {
532        self.level == IssueLevel::Error
533    }
534}
535
536#[derive(Debug, Clone, Serialize)]
537pub struct ChannelStatus {
538    pub name: String,
539    #[serde(rename = "type")]
540    pub channel_type: String,
541    pub status: String,
542    pub missing_env: Vec<String>,
543    pub warnings: Vec<String>,
544    pub errors: Vec<String>,
545}
546
/// Answers whether an environment variable is defined.
///
/// Validation is generic over this trait so tests can substitute a fixed set of names for
/// the real process environment.
pub trait EnvSource {
    /// Returns `true` when a variable called `name` is set.
    fn exists(&self, name: &str) -> bool;
}

/// [`EnvSource`] backed by the current process environment.
#[derive(Clone, Copy, Debug)]
pub struct ProcessEnv;

impl EnvSource for ProcessEnv {
    fn exists(&self, name: &str) -> bool {
        // `var_os` rather than `var`: a set but non-UTF-8 value still counts as present.
        matches!(std::env::var_os(name), Some(_))
    }
}
559
560fn validate_secret_pair<E: EnvSource>(
561    channel: &str,
562    field: &str,
563    inline: Option<&str>,
564    env_name: Option<&str>,
565    required: bool,
566    env: &E,
567    issues: &mut Vec<CheckIssue>,
568) {
569    match (inline, env_name) {
570        (Some(_), Some(_)) => issues.push(CheckIssue::error(
571            Some(channel),
572            "SECRET_CONFLICT",
573            format!("channel \"{channel}\" {field} and {field}_env cannot be set at the same time"),
574        )),
575        (Some(_), None) => issues.push(CheckIssue::warning(
576            Some(channel),
577            "INLINE_SECRET",
578            format!("channel \"{channel}\" uses inline {field}"),
579        )),
580        (None, Some(env_name)) if !env.exists(env_name) => issues.push(CheckIssue::error(
581            Some(channel),
582            "MISSING_ENV",
583            env_name.to_string(),
584        )),
585        (None, None) if required => issues.push(CheckIssue::error(
586            Some(channel),
587            "MISSING_FIELD",
588            format!("channel \"{channel}\" is missing {field} or {field}_env"),
589        )),
590        _ => {}
591    }
592}
593
#[cfg(test)]
mod tests {
    use std::{collections::BTreeSet, fs};

    use tempfile::tempdir;

    use super::*;

    /// Test double for [`EnvSource`]: a variable "exists" iff its name is in the set.
    struct MapEnv(BTreeSet<String>);

    impl EnvSource for MapEnv {
        fn exists(&self, name: &str) -> bool {
            self.0.contains(name)
        }
    }

    #[test]
    fn loads_config_from_explicit_path() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("notify.toml");
        fs::write(
            &path,
            r#"
default_channel = "local"

[channels.local]
type = "file-log"
path = "./notify-log"
"#,
        )
        .unwrap();

        let loaded = ConfigLoad::load(Some(&path)).unwrap();

        assert_eq!(loaded.path, path);
        assert_eq!(loaded.config.default_channel.as_deref(), Some("local"));
        assert!(matches!(
            loaded.config.channels.get("local"),
            Some(ChannelConfig::FileLog(_))
        ));
    }

    #[test]
    fn detects_default_channel_missing() {
        let config: Config = toml::from_str(
            r#"
[channels.local]
type = "file-log"
path = "./notify-log"
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert!(
            issues
                .iter()
                .any(|issue| issue.code == "DEFAULT_CHANNEL_MISSING")
        );
    }

    #[test]
    fn detects_default_channel_not_found() {
        let config: Config = toml::from_str(
            r#"
default_channel = "missing"

[channels.local]
type = "file-log"
path = "./notify-log"
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert!(issues.iter().any(|issue| {
            issue.code == "DEFAULT_CHANNEL_NOT_FOUND" && issue.level == IssueLevel::Error
        }));
    }

    #[test]
    fn detects_secret_conflict_and_missing_env() {
        let config: Config = toml::from_str(
            r#"
default_channel = "team"

[channels.team]
type = "discord-webhook"
webhook_url = "https://example.com"
webhook_url_env = "NOTIFY_DISCORD_WEBHOOK_URL"

[channels.phone]
type = "ntfy"
topic_env = "NOTIFY_NTFY_TOPIC"
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert!(issues.iter().any(|issue| issue.code == "SECRET_CONFLICT"));
        assert!(issues.iter().any(|issue| issue.code == "MISSING_ENV"));
    }

    #[test]
    fn detects_invalid_telegram_parse_mode() {
        let config: Config = toml::from_str(
            r#"
default_channel = "personal"

[channels.personal]
type = "telegram"
bot_token = "token"
chat_id = "chat"
parse_mode = "markdown"
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert!(issues.iter().any(|issue| {
            issue.code == "INVALID_FIELD"
                && issue.level == IssueLevel::Error
                && issue.channel.as_deref() == Some("personal")
        }));
    }

    #[test]
    fn loads_added_http_channel_types() {
        let config: Config = toml::from_str(
            r#"
default_channel = "slack"

[channels.slack]
type = "slack-webhook"
webhook_url_env = "NOTIFY_SLACK_WEBHOOK_URL"
username = "Agent Notify"

[channels.mobile]
type = "pushover"
token_env = "NOTIFY_PUSHOVER_TOKEN"
user_env = "NOTIFY_PUSHOVER_USER"
device = "phone"

[channels.self_hosted]
type = "gotify"
server = "https://gotify.example.com"
token_env = "NOTIFY_GOTIFY_TOKEN"
priority = 7
"#,
        )
        .unwrap();

        assert!(matches!(
            config.channels.get("slack"),
            Some(ChannelConfig::SlackWebhook(_))
        ));
        assert!(matches!(
            config.channels.get("mobile"),
            Some(ChannelConfig::Pushover(_))
        ));
        assert!(matches!(
            config.channels.get("self_hosted"),
            Some(ChannelConfig::Gotify(_))
        ));
        assert_eq!(config.channels["slack"].type_name(), "slack-webhook");
        assert_eq!(config.channels["mobile"].type_name(), "pushover");
        assert_eq!(config.channels["self_hosted"].type_name(), "gotify");
    }

    #[test]
    fn validates_added_http_channel_secrets() {
        let config: Config = toml::from_str(
            r#"
default_channel = "slack"

[channels.slack]
type = "slack-webhook"
webhook_url = "https://hooks.slack.com/services/test"
webhook_url_env = "NOTIFY_SLACK_WEBHOOK_URL"

[channels.mobile]
type = "pushover"
token = "app-token"
user_env = "NOTIFY_PUSHOVER_USER"

[channels.self_hosted]
type = "gotify"
server = "https://gotify.example.com"
token_env = "NOTIFY_GOTIFY_TOKEN"
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert_issue(&issues, "slack", "SECRET_CONFLICT", IssueLevel::Error);
        assert_issue(&issues, "mobile", "INLINE_SECRET", IssueLevel::Warning);
        assert_issue(&issues, "mobile", "MISSING_ENV", IssueLevel::Error);
        assert_issue(&issues, "self_hosted", "MISSING_ENV", IssueLevel::Error);
    }

    #[test]
    fn rejects_unsupported_type_during_deserialize() {
        let result = toml::from_str::<Config>(
            r#"
default_channel = "mail"

[channels.mail]
type = "email"
"#,
        );

        assert!(result.is_err());
    }

    #[test]
    fn resolves_requested_or_default_channel() {
        let config: Config = toml::from_str(
            r#"
default_channel = "local"

[channels.local]
type = "file-log"
path = "./notify-log"

[channels.archive]
type = "file-log"
path = "./archive-log"
"#,
        )
        .unwrap();

        assert_eq!(config.resolve_channel_name(None).unwrap(), "local");
        assert_eq!(config.resolve_channel_name(Some("archive")).unwrap(), "archive");
        assert!(matches!(
            config.resolve_channel_name(Some("missing")),
            Err(NotifyError::ChannelNotFound(name)) if name == "missing"
        ));
        assert!(matches!(config.channel("archive"), Ok(ChannelConfig::FileLog(_))));
        assert!(matches!(
            config.channel("missing"),
            Err(NotifyError::ChannelNotFound(_))
        ));
    }

    #[test]
    fn resolving_without_default_channel_needs_explicit_request() {
        let config: Config = toml::from_str(
            r#"
[channels.local]
type = "file-log"
path = "./notify-log"
"#,
        )
        .unwrap();

        assert!(matches!(
            config.resolve_channel_name(None),
            Err(NotifyError::DefaultChannelMissing)
        ));
        assert_eq!(config.resolve_channel_name(Some("local")).unwrap(), "local");
    }

    #[test]
    fn summarises_channel_statuses() {
        let config: Config = toml::from_str(
            r#"
default_channel = "local"

[channels.local]
type = "file-log"
path = "./notify-log"

[channels.phone]
type = "ntfy"
topic_env = "NOTIFY_NTFY_TOPIC"

[channels.team]
type = "discord-webhook"
webhook_url = "https://example.com"
webhook_url_env = "NOTIFY_DISCORD_WEBHOOK_URL"
"#,
        )
        .unwrap();

        let statuses = config.channel_statuses_with(&MapEnv(BTreeSet::new()));
        assert_eq!(statuses.len(), 3);

        let local = status_of(&statuses, "local");
        assert_eq!(local.channel_type, "file-log");
        assert_eq!(local.status, "ready");
        assert!(local.errors.is_empty() && local.warnings.is_empty());

        // Only problem is an unset variable -> "missing".
        let phone = status_of(&statuses, "phone");
        assert_eq!(phone.status, "missing");
        assert_eq!(phone.missing_env, ["NOTIFY_NTFY_TOPIC"]);

        // A configuration error with no missing variable -> "error".
        let team = status_of(&statuses, "team");
        assert_eq!(team.status, "error");
        assert!(team.missing_env.is_empty());

        // Providing the variable makes the ntfy channel ready.
        let with_topic = MapEnv(BTreeSet::from(["NOTIFY_NTFY_TOPIC".to_string()]));
        let statuses = config.channel_statuses_with(&with_topic);
        assert_eq!(status_of(&statuses, "phone").status, "ready");
    }

    #[test]
    fn detects_zero_webhook_timeout_and_inline_url() {
        let config: Config = toml::from_str(
            r#"
default_channel = "hook"

[channels.hook]
type = "webhook"
url = "https://example.com/hook"
timeout_seconds = 0
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert_issue(&issues, "hook", "INVALID_FIELD", IssueLevel::Error);
        assert_issue(&issues, "hook", "INLINE_SECRET", IssueLevel::Warning);
    }

    #[test]
    fn detects_empty_file_log_path() {
        let config: Config = toml::from_str(
            r#"
default_channel = "local"

[channels.local]
type = "file-log"
path = ""
"#,
        )
        .unwrap();

        let issues = config.validation_issues_with(&MapEnv(BTreeSet::new()));

        assert_issue(&issues, "local", "MISSING_FIELD", IssueLevel::Error);
    }

    /// Asserts that `issues` contains a finding with exactly this channel, code and level.
    fn assert_issue(issues: &[CheckIssue], channel: &str, code: &str, level: IssueLevel) {
        assert!(
            issues.iter().any(|issue| {
                issue.channel.as_deref() == Some(channel)
                    && issue.code == code
                    && issue.level == level
            }),
            "missing {level:?} issue {code} for channel {channel}: {issues:#?}"
        );
    }

    /// Returns the status entry for channel `name`, panicking with context if absent.
    fn status_of<'a>(statuses: &'a [ChannelStatus], name: &str) -> &'a ChannelStatus {
        statuses
            .iter()
            .find(|status| status.name == name)
            .unwrap_or_else(|| panic!("no status for channel {name}: {statuses:#?}"))
    }
}