1use std::{
2 collections::HashMap, io::Write, ops::ControlFlow, path::PathBuf, process::Stdio, sync::Arc,
3 time::Duration,
4};
5
6use chrono::{DateTime, Utc};
7use miette::{Context as _, IntoDiagnostic, Result, miette};
8use tera::Context as TeraCtx;
9use tokio::io::AsyncReadExt as _;
10use tokio_postgres::types::ToSql;
11use tracing::{debug, error, info, instrument, warn};
12
13use crate::{
14 EmailConfig, LogError, events::EventType, targets::ExternalTarget, templates::build_context,
15};
16
/// Serde default for `enabled`: an alert is active unless the file says otherwise.
fn enabled() -> bool {
    let active_by_default = true;
    active_by_default
}
20
/// Serde default for `interval` when the definition file omits it.
fn default_interval() -> String {
    String::from("1 minute")
}
24
25#[derive(serde::Deserialize, Debug, Clone)]
26#[serde(rename_all = "kebab-case")]
27pub struct NumericalThreshold {
28 pub field: String,
29 pub alert_at: f64,
30 pub clear_at: Option<f64>,
31}
32
33#[derive(serde::Deserialize, Debug, Clone)]
34#[serde(rename_all = "kebab-case")]
35#[serde(untagged)]
36pub enum WhenChanged {
37 Boolean(bool),
38 Detailed(WhenChangedConfig),
39}
40
41impl Default for WhenChanged {
42 fn default() -> Self {
43 WhenChanged::Boolean(false)
44 }
45}
46
47#[derive(serde::Deserialize, Debug, Clone)]
48#[serde(rename_all = "kebab-case")]
49pub struct WhenChangedConfig {
50 #[serde(default)]
51 pub except: Vec<String>,
52 #[serde(default)]
53 pub only: Vec<String>,
54}
55
56#[derive(serde::Deserialize, Debug, Clone)]
57#[serde(rename_all = "kebab-case")]
58#[serde(untagged)]
59pub enum AlwaysSend {
60 Boolean(bool),
61 Timed(AlwaysSendConfig),
62}
63
64impl Default for AlwaysSend {
65 fn default() -> Self {
66 AlwaysSend::Boolean(false)
67 }
68}
69
70#[derive(serde::Deserialize, Debug, Clone)]
71#[serde(rename_all = "kebab-case")]
72pub struct AlwaysSendConfig {
73 pub after: String,
74 #[serde(skip)]
75 pub after_duration: Duration,
76}
77
78#[derive(serde::Deserialize, Debug, Default, Clone)]
79#[serde(rename_all = "kebab-case")]
80pub struct AlertDefinition {
81 #[serde(default, skip)]
82 pub file: PathBuf,
83
84 #[serde(default = "enabled")]
85 pub enabled: bool,
86
87 #[serde(default = "default_interval")]
88 pub interval: String,
89
90 #[serde(skip)]
91 pub interval_duration: Duration,
92
93 #[serde(default)]
94 pub always_send: AlwaysSend,
95
96 #[serde(default)]
97 pub when_changed: WhenChanged,
98
99 #[serde(default)]
100 pub send: Vec<crate::targets::SendTarget>,
101
102 #[serde(flatten)]
103 pub source: TicketSource,
104}
105
106#[derive(serde::Deserialize, Debug, Default, Clone)]
107#[serde(untagged, deny_unknown_fields)]
108pub enum TicketSource {
109 Sql {
110 sql: String,
111 #[serde(default)]
112 numerical: Vec<NumericalThreshold>,
113 },
114 Shell {
115 shell: String,
116 run: String,
117 },
118 Event {
119 event: EventType,
120 },
121
122 #[default]
123 None,
124}
125
126impl AlertDefinition {
127 pub fn normalise(
128 mut self,
129 external_targets: &HashMap<String, Vec<ExternalTarget>>,
130 ) -> Result<(Self, Vec<crate::targets::ResolvedTarget>)> {
131 self.interval_duration = parse_interval(&self.interval)
133 .wrap_err_with(|| format!("failed to parse interval: {}", self.interval))?;
134
135 if let AlwaysSend::Timed(ref mut config) = self.always_send {
137 config.after_duration = parse_interval(&config.after)
138 .wrap_err_with(|| format!("failed to parse always-send after: {}", config.after))?;
139 }
140
141 for (idx, target) in self.send.iter().enumerate() {
144 crate::templates::load_templates(target.subject(), target.template()).wrap_err_with(
145 || {
146 format!(
147 "validating templates for send target #{} (id: {})",
148 idx + 1,
149 target.id()
150 )
151 },
152 )?;
153 }
154
155 let resolved = self
156 .send
157 .iter()
158 .flat_map(|target| {
159 let resolved_targets = target.resolve_external(external_targets);
160 if resolved_targets.is_empty() {
161 error!(
162 file=?self.file,
163 id = %target.id(),
164 available_targets=?external_targets.keys().collect::<Vec<_>>(),
165 "external target not found"
166 );
167 }
168 resolved_targets
169 })
170 .collect();
171
172 self.send.clear(); Ok((self, resolved))
174 }
175
176 #[instrument(skip(self, pool, not_before, context))]
177 pub async fn read_sources(
178 &self,
179 pool: &bestool_postgres::pool::PgPool,
180 not_before: DateTime<Utc>,
181 context: &mut TeraCtx,
182 was_triggered: bool,
183 ) -> Result<ControlFlow<(), ()>> {
184 match &self.source {
185 TicketSource::None => {
186 debug!(?self.file, "no source, skipping");
187 return Ok(ControlFlow::Break(()));
188 }
189 TicketSource::Event { .. } => {
190 debug!(?self.file, "event source, skipping normal execution");
192 return Ok(ControlFlow::Break(()));
193 }
194 TicketSource::Sql { sql, numerical } => {
195 let client = pool
196 .get()
197 .await
198 .map_err(|e| miette!("getting connection from pool: {e}"))?;
199 let statement = client.prepare(sql).await.into_diagnostic()?;
200
201 let interval = bestool_postgres::pg_interval::Interval(self.interval_duration);
202 let all_params: Vec<&(dyn ToSql + Sync)> = vec![¬_before, &interval];
203
204 let rows = client
205 .query(&statement, &all_params[..statement.params().len()])
206 .await
207 .into_diagnostic()
208 .wrap_err("querying database")?;
209
210 if rows.is_empty() {
211 debug!(?self.file, "no rows returned, skipping");
212 return Ok(ControlFlow::Break(()));
213 }
214
215 let context_rows = rows_to_value_map(&rows);
216
217 if !numerical.is_empty() {
219 let triggered =
220 check_numerical_thresholds(&context_rows, numerical, was_triggered)?;
221 if !triggered {
222 debug!(?self.file, "numerical thresholds not met, skipping");
223 return Ok(ControlFlow::Break(()));
224 }
225 }
226
227 info!(?self.file, rows=%rows.len(), "alert triggered");
228 context.insert("rows", &context_rows);
229 }
230 TicketSource::Shell { shell, run } => {
231 let mut script = tempfile::Builder::new().tempfile().into_diagnostic()?;
232 write!(script.as_file_mut(), "{run}").into_diagnostic()?;
233
234 let mut shell = tokio::process::Command::new(shell)
235 .arg(script.path())
236 .stdin(Stdio::null())
237 .stdout(Stdio::piped())
238 .spawn()
239 .into_diagnostic()?;
240
241 let mut output = Vec::new();
242 let mut stdout = shell
243 .stdout
244 .take()
245 .ok_or_else(|| miette!("getting the child stdout handle"))?;
246 let output_future =
247 futures::future::try_join(shell.wait(), stdout.read_to_end(&mut output));
248
249 let Ok(res) = tokio::time::timeout(self.interval_duration, output_future).await
250 else {
251 warn!(?self.file, "the script timed out, skipping");
252 shell.kill().await.into_diagnostic()?;
253 return Ok(ControlFlow::Break(()));
254 };
255
256 let (status, output_size) = res.into_diagnostic().wrap_err("running the shell")?;
257
258 if status.success() {
259 debug!(?self.file, "the script succeeded, skipping");
260 return Ok(ControlFlow::Break(()));
261 }
262 info!(?self.file, ?status, ?output_size, "alert triggered");
263
264 context.insert("output", &String::from_utf8_lossy(&output));
265 }
266 }
267 Ok(ControlFlow::Continue(()))
268 }
269
270 pub async fn execute(
271 &self,
272 ctx: Arc<InternalContext>,
273 email: Option<&EmailConfig>,
274 dry_run: bool,
275 resolved_targets: &[crate::targets::ResolvedTarget],
276 ) -> Result<()> {
277 info!(?self.file, "executing alert");
278
279 let now = chrono::Utc::now();
280 let not_before = now - self.interval_duration;
281 info!(?now, ?not_before, interval=?self.interval_duration, "date range for alert");
282
283 let mut tera_ctx = build_context(self, now);
284 if self
285 .read_sources(&ctx.pg_pool, not_before, &mut tera_ctx, false)
286 .await?
287 .is_break()
288 {
289 return Ok(());
290 }
291
292 for target in resolved_targets {
293 if let Err(err) = target.send(self, &mut tera_ctx, email, dry_run).await {
294 error!("sending: {}", LogError(&err));
295 }
296 }
297
298 Ok(())
299 }
300}
301
302#[derive(Debug, Clone)]
303pub struct InternalContext {
304 pub pg_pool: bestool_postgres::pool::PgPool,
305}
306
307fn rows_to_value_map(
308 rows: &[tokio_postgres::Row],
309) -> Vec<serde_json::Map<String, serde_json::Value>> {
310 rows.iter()
311 .map(|row| {
312 let mut map = serde_json::Map::new();
313 for (idx, column) in row.columns().iter().enumerate() {
314 let value = bestool_postgres::stringify::postgres_to_json_value(row, idx);
315 map.insert(column.name().to_string(), value);
316 }
317 map
318 })
319 .collect()
320}
321
322fn check_numerical_thresholds(
323 rows: &[serde_json::Map<String, serde_json::Value>],
324 thresholds: &[NumericalThreshold],
325 was_triggered: bool,
326) -> Result<bool> {
327 for threshold in thresholds {
328 for row in rows {
329 let value = match row.get(&threshold.field) {
330 Some(serde_json::Value::Number(n)) => n
331 .as_f64()
332 .ok_or_else(|| miette!("field '{}' is not a valid number", threshold.field))?,
333 Some(_) => {
334 return Err(miette!(
335 "field '{}' exists but is not a number",
336 threshold.field
337 ));
338 }
339 None => {
340 return Err(miette!(
341 "field '{}' not found in query results",
342 threshold.field
343 ));
344 }
345 };
346
347 let is_inverted = threshold
349 .clear_at
350 .is_some_and(|clear| clear > threshold.alert_at);
351
352 if was_triggered {
353 if let Some(clear_at) = threshold.clear_at {
355 let should_clear = if is_inverted {
356 value >= clear_at
358 } else {
359 value <= clear_at
361 };
362
363 if should_clear {
364 continue;
366 } else {
367 return Ok(true);
369 }
370 } else {
371 let still_triggered = if is_inverted {
373 value <= threshold.alert_at
374 } else {
375 value >= threshold.alert_at
376 };
377
378 if still_triggered {
379 return Ok(true);
380 }
381 }
382 } else {
383 let should_trigger = if is_inverted {
385 value <= threshold.alert_at
387 } else {
388 value >= threshold.alert_at
390 };
391
392 if should_trigger {
393 return Ok(true);
394 }
395 }
396 }
397 }
398
399 Ok(false)
400}
401
402fn parse_interval(s: &str) -> Result<Duration> {
403 let s = s.trim();
404
405 if let Ok(secs) = s.parse::<u64>() {
407 return Ok(Duration::from_secs(secs));
408 }
409
410 let parts: Vec<&str> = s.split_whitespace().collect();
412 if parts.len() != 2 {
413 return Err(miette!(
414 "interval must be in format '<number> <unit>' or just '<seconds>'"
415 ));
416 }
417
418 let value: u64 = parts[0]
419 .parse()
420 .into_diagnostic()
421 .wrap_err("interval value must be a number")?;
422 let unit = parts[1].to_lowercase();
423
424 let duration = match unit.as_str() {
425 "second" | "seconds" | "s" | "sec" | "secs" => Duration::from_secs(value),
426 "minute" | "minutes" | "m" | "min" | "mins" => Duration::from_secs(value * 60),
427 "hour" | "hours" | "h" | "hr" | "hrs" => Duration::from_secs(value * 3600),
428 "day" | "days" | "d" => Duration::from_secs(value * 86400),
429 _ => {
430 return Err(miette!(
431 "unknown interval unit: {}, expected: seconds, minutes, hours, or days",
432 unit
433 ));
434 }
435 };
436
437 Ok(duration)
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialises an alert definition from YAML, panicking on failure.
    fn parse(yaml: &str) -> AlertDefinition {
        serde_yaml::from_str(yaml).unwrap()
    }

    /// Builds a single-row result set holding one integer column.
    fn one_row(column: &str, value: i64) -> Vec<serde_json::Map<String, serde_json::Value>> {
        let mut row = serde_json::Map::new();
        row.insert(column.to_string(), serde_json::Value::Number(value.into()));
        vec![row]
    }

    #[test]
    fn test_alert_with_event_source() {
        let alert = parse(
            r#"
event: source-error
send:
  - id: test-target
    subject: Test
    template: Test template
"#,
        );
        assert!(matches!(alert.source, TicketSource::Event { .. }));
        if let TicketSource::Event { event } = alert.source {
            assert_eq!(event, EventType::SourceError);
        }
    }

    #[test]
    fn test_parse_interval() {
        let cases = [
            ("60", 60),
            ("1 minute", 60),
            ("5 minutes", 300),
            ("2 hours", 7200),
            ("1 day", 86400),
            ("30 seconds", 30),
        ];
        for (input, secs) in cases {
            assert_eq!(parse_interval(input).unwrap(), Duration::from_secs(secs));
        }
    }

    #[test]
    fn test_default_interval() {
        let alert = parse(
            r#"
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert_eq!(alert.interval, "1 minute");
    }

    #[test]
    fn test_default_always_send() {
        let alert = parse(
            r#"
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.always_send, AlwaysSend::Boolean(false)));
    }

    #[test]
    fn test_always_send_true() {
        let alert = parse(
            r#"
always-send: true
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.always_send, AlwaysSend::Boolean(true)));
    }

    #[test]
    fn test_always_send_timed() {
        let alert = parse(
            r#"
always-send:
  after: 8h
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let AlwaysSend::Timed(config) = alert.always_send else {
            panic!("Expected AlwaysSend::Timed");
        };
        assert_eq!(config.after, "8h");
    }

    #[test]
    fn test_always_send_timed_normalised() {
        let alert = parse(
            r#"
always-send:
  after: 8 hours
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let (normalised, _) = alert.normalise(&HashMap::new()).unwrap();

        let AlwaysSend::Timed(config) = normalised.always_send else {
            panic!("Expected AlwaysSend::Timed");
        };
        assert_eq!(config.after, "8 hours");
        assert_eq!(config.after_duration, Duration::from_secs(8 * 3600));
    }

    #[test]
    fn test_numerical_threshold_normal() {
        let alert = parse(
            r#"
sql: "SELECT 1"
numerical:
  - field: count
    alert-at: 100
    clear-at: 50
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let TicketSource::Sql { numerical, .. } = &alert.source else {
            panic!("Expected Sql source");
        };
        assert_eq!(numerical.len(), 1);
        assert_eq!(numerical[0].field, "count");
        assert_eq!(numerical[0].alert_at, 100.0);
        assert_eq!(numerical[0].clear_at, Some(50.0));
    }

    #[test]
    fn test_numerical_threshold_inverted() {
        let alert = parse(
            r#"
sql: "SELECT 1"
numerical:
  - field: free_space
    alert-at: 10
    clear-at: 50
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let TicketSource::Sql { numerical, .. } = &alert.source else {
            panic!("Expected Sql source");
        };
        assert_eq!(numerical.len(), 1);
        assert_eq!(numerical[0].field, "free_space");
        assert_eq!(numerical[0].alert_at, 10.0);
        assert_eq!(numerical[0].clear_at, Some(50.0));
    }

    #[test]
    fn test_numerical_threshold_no_clear() {
        let alert = parse(
            r#"
sql: "SELECT 1"
numerical:
  - field: errors
    alert-at: 5
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let TicketSource::Sql { numerical, .. } = &alert.source else {
            panic!("Expected Sql source");
        };
        assert_eq!(numerical.len(), 1);
        assert_eq!(numerical[0].field, "errors");
        assert_eq!(numerical[0].alert_at, 5.0);
        assert_eq!(numerical[0].clear_at, None);
    }

    #[test]
    fn test_check_numerical_thresholds_normal_trigger() {
        let threshold = NumericalThreshold {
            field: "count".to_string(),
            alert_at: 100.0,
            clear_at: Some(50.0),
        };
        let thresholds = std::slice::from_ref(&threshold);

        let high = one_row("count", 150);
        assert!(check_numerical_thresholds(&high, thresholds, false).unwrap());
        assert!(check_numerical_thresholds(&high, thresholds, true).unwrap());

        let low = one_row("count", 30);
        assert!(!check_numerical_thresholds(&low, thresholds, true).unwrap());
    }

    #[test]
    fn test_check_numerical_thresholds_inverted_trigger() {
        let threshold = NumericalThreshold {
            field: "free_space".to_string(),
            alert_at: 10.0,
            clear_at: Some(50.0),
        };
        let thresholds = std::slice::from_ref(&threshold);

        let scarce = one_row("free_space", 5);
        assert!(check_numerical_thresholds(&scarce, thresholds, false).unwrap());
        assert!(check_numerical_thresholds(&scarce, thresholds, true).unwrap());

        let plenty = one_row("free_space", 60);
        assert!(!check_numerical_thresholds(&plenty, thresholds, true).unwrap());
    }

    #[test]
    fn test_check_numerical_thresholds_no_clear_at() {
        let threshold = NumericalThreshold {
            field: "errors".to_string(),
            alert_at: 5.0,
            clear_at: None,
        };
        let thresholds = std::slice::from_ref(&threshold);

        let many = one_row("errors", 10);
        assert!(check_numerical_thresholds(&many, thresholds, false).unwrap());
        assert!(check_numerical_thresholds(&many, thresholds, true).unwrap());

        let few = one_row("errors", 3);
        assert!(!check_numerical_thresholds(&few, thresholds, true).unwrap());
    }

    #[test]
    fn test_when_changed_boolean_true() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed: true
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.when_changed, WhenChanged::Boolean(true)));
    }

    #[test]
    fn test_when_changed_boolean_false() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed: false
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.when_changed, WhenChanged::Boolean(false)));
    }

    #[test]
    fn test_when_changed_default() {
        let alert = parse(
            r#"
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.when_changed, WhenChanged::Boolean(false)));
    }

    #[test]
    fn test_when_changed_except() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed:
  except: [created_at, updated_at]
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let WhenChanged::Detailed(config) = &alert.when_changed else {
            panic!("Expected Detailed variant");
        };
        assert_eq!(config.except, vec!["created_at", "updated_at"]);
        assert!(config.only.is_empty());
    }

    #[test]
    fn test_when_changed_only() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed:
  only: [error, message]
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let WhenChanged::Detailed(config) = &alert.when_changed else {
            panic!("Expected Detailed variant");
        };
        assert_eq!(config.only, vec!["error", "message"]);
        assert!(config.except.is_empty());
    }
}