1use std::{
2 collections::HashMap, io::Write, ops::ControlFlow, path::PathBuf, process::Stdio, sync::Arc,
3 time::Duration,
4};
5
6use chrono::{DateTime, Utc};
7use miette::{Context as _, IntoDiagnostic, Result, miette};
8use tera::Context as TeraCtx;
9use tokio::io::AsyncReadExt as _;
10use tokio_postgres::types::ToSql;
11use tracing::{debug, error, info, instrument, warn};
12
13use crate::{
14 EmailConfig, LogError, events::EventType, targets::ExternalTarget, templates::build_context,
15};
16
/// Serde default for the `enabled` field: a definition that does not say
/// otherwise is enabled.
fn enabled() -> bool {
    true
}
20
/// Serde default for the `interval` field, used when a definition omits it.
fn default_interval() -> String {
    String::from("1 minute")
}
24
25#[derive(serde::Deserialize, facet::Facet, Debug, Clone)]
26#[facet(rename_all = "kebab-case")]
27#[serde(rename_all = "kebab-case")]
28pub struct NumericalThreshold {
29 pub field: String,
30 pub alert_at: f64,
31 pub clear_at: Option<f64>,
32}
33
34#[derive(serde::Deserialize, Debug, Clone)]
35#[serde(rename_all = "kebab-case")]
36#[serde(untagged)]
37pub enum WhenChanged {
38 Boolean(bool),
39 Detailed(WhenChangedConfig),
40}
41
42impl Default for WhenChanged {
43 fn default() -> Self {
44 WhenChanged::Boolean(false)
45 }
46}
47
48#[derive(serde::Deserialize, facet::Facet, Debug, Clone)]
49#[facet(rename_all = "kebab-case")]
50#[serde(rename_all = "kebab-case")]
51pub struct WhenChangedConfig {
52 #[serde(default)]
53 pub except: Vec<String>,
54 #[serde(default)]
55 pub only: Vec<String>,
56}
57
58#[derive(serde::Deserialize, Debug, Clone)]
59#[serde(rename_all = "kebab-case")]
60#[serde(untagged)]
61pub enum AlwaysSend {
62 Boolean(bool),
63 Timed(AlwaysSendConfig),
64}
65
66impl Default for AlwaysSend {
67 fn default() -> Self {
68 AlwaysSend::Boolean(false)
69 }
70}
71
72#[derive(serde::Deserialize, Debug, Clone)]
73#[serde(rename_all = "kebab-case")]
74pub struct AlwaysSendConfig {
75 pub after: String,
76 #[serde(skip)]
77 pub after_duration: Duration,
78}
79
80#[derive(serde::Deserialize, Debug, Default, Clone)]
81#[serde(rename_all = "kebab-case")]
82pub struct AlertDefinition {
83 #[serde(default, skip)]
84 pub file: PathBuf,
85
86 #[serde(default = "enabled")]
87 pub enabled: bool,
88
89 #[serde(default = "default_interval")]
90 pub interval: String,
91
92 #[serde(skip)]
93 pub interval_duration: Duration,
94
95 #[serde(default)]
96 pub always_send: AlwaysSend,
97
98 #[serde(default)]
99 pub when_changed: WhenChanged,
100
101 #[serde(default)]
102 pub send: Vec<crate::targets::SendTarget>,
103
104 #[serde(flatten)]
105 pub source: TicketSource,
106}
107
108#[derive(serde::Deserialize, Debug, Default, Clone)]
109#[serde(untagged, deny_unknown_fields)]
110pub enum TicketSource {
111 Sql {
112 sql: String,
113 #[serde(default)]
114 numerical: Vec<NumericalThreshold>,
115 },
116 Shell {
117 shell: String,
118 run: String,
119 },
120 Event {
121 event: EventType,
122 },
123
124 #[default]
125 None,
126}
127
128impl AlertDefinition {
129 pub fn normalise(
130 mut self,
131 external_targets: &HashMap<String, Vec<ExternalTarget>>,
132 ) -> Result<(Self, Vec<crate::targets::ResolvedTarget>)> {
133 self.interval_duration = parse_interval(&self.interval)
135 .wrap_err_with(|| format!("failed to parse interval: {}", self.interval))?;
136
137 if let AlwaysSend::Timed(ref mut config) = self.always_send {
139 config.after_duration = parse_interval(&config.after)
140 .wrap_err_with(|| format!("failed to parse always-send after: {}", config.after))?;
141 }
142
143 for (idx, target) in self.send.iter().enumerate() {
146 crate::templates::load_templates(target.subject(), target.template()).wrap_err_with(
147 || {
148 format!(
149 "validating templates for send target #{} (id: {})",
150 idx + 1,
151 target.id()
152 )
153 },
154 )?;
155 }
156
157 let resolved = self
158 .send
159 .iter()
160 .flat_map(|target| {
161 let resolved_targets = target.resolve_external(external_targets);
162 if resolved_targets.is_empty() {
163 error!(
164 file=?self.file,
165 id = %target.id(),
166 available_targets=?external_targets.keys().collect::<Vec<_>>(),
167 "external target not found"
168 );
169 }
170 resolved_targets
171 })
172 .collect();
173
174 self.send.clear(); Ok((self, resolved))
176 }
177
178 #[instrument(skip(self, pool, not_before, context))]
179 pub async fn read_sources(
180 &self,
181 pool: &bestool_postgres::pool::PgPool,
182 not_before: DateTime<Utc>,
183 context: &mut TeraCtx,
184 was_triggered: bool,
185 ) -> Result<ControlFlow<(), ()>> {
186 match &self.source {
187 TicketSource::None => {
188 debug!(?self.file, "no source, skipping");
189 return Ok(ControlFlow::Break(()));
190 }
191 TicketSource::Event { .. } => {
192 debug!(?self.file, "event source, skipping normal execution");
194 return Ok(ControlFlow::Break(()));
195 }
196 TicketSource::Sql { sql, numerical } => {
197 let client = pool
198 .get()
199 .await
200 .map_err(|e| miette!("getting connection from pool: {e}"))?;
201 let statement = client.prepare(sql).await.into_diagnostic()?;
202
203 let interval = bestool_postgres::pg_interval::Interval(self.interval_duration);
204 let all_params: Vec<&(dyn ToSql + Sync)> = vec![¬_before, &interval];
205
206 let rows = client
207 .query(&statement, &all_params[..statement.params().len()])
208 .await
209 .into_diagnostic()
210 .wrap_err("querying database")?;
211
212 if rows.is_empty() {
213 debug!(?self.file, "no rows returned, skipping");
214 return Ok(ControlFlow::Break(()));
215 }
216
217 let context_rows = rows_to_value_map(&rows);
218
219 if !numerical.is_empty() {
221 let triggered =
222 check_numerical_thresholds(&context_rows, numerical, was_triggered)?;
223 if !triggered {
224 debug!(?self.file, "numerical thresholds not met, skipping");
225 return Ok(ControlFlow::Break(()));
226 }
227 }
228
229 info!(?self.file, rows=%rows.len(), "alert triggered");
230 context.insert("rows", &context_rows);
231 }
232 TicketSource::Shell { shell, run } => {
233 let mut script = tempfile::Builder::new().tempfile().into_diagnostic()?;
234 write!(script.as_file_mut(), "{run}").into_diagnostic()?;
235
236 let mut shell = tokio::process::Command::new(shell)
237 .arg(script.path())
238 .stdin(Stdio::null())
239 .stdout(Stdio::piped())
240 .spawn()
241 .into_diagnostic()?;
242
243 let mut output = Vec::new();
244 let mut stdout = shell
245 .stdout
246 .take()
247 .ok_or_else(|| miette!("getting the child stdout handle"))?;
248 let output_future =
249 futures::future::try_join(shell.wait(), stdout.read_to_end(&mut output));
250
251 let Ok(res) = tokio::time::timeout(self.interval_duration, output_future).await
252 else {
253 warn!(?self.file, "the script timed out, skipping");
254 shell.kill().await.into_diagnostic()?;
255 return Ok(ControlFlow::Break(()));
256 };
257
258 let (status, output_size) = res.into_diagnostic().wrap_err("running the shell")?;
259
260 if status.success() {
261 debug!(?self.file, "the script succeeded, skipping");
262 return Ok(ControlFlow::Break(()));
263 }
264 info!(?self.file, ?status, ?output_size, "alert triggered");
265
266 context.insert("output", &String::from_utf8_lossy(&output));
267 }
268 }
269 Ok(ControlFlow::Continue(()))
270 }
271
272 pub async fn execute(
273 &self,
274 ctx: Arc<InternalContext>,
275 email: Option<&EmailConfig>,
276 dry_run: bool,
277 resolved_targets: &[crate::targets::ResolvedTarget],
278 ) -> Result<()> {
279 info!(?self.file, "executing alert");
280
281 let now = chrono::Utc::now();
282 let not_before = now - self.interval_duration;
283 info!(?now, ?not_before, interval=?self.interval_duration, "date range for alert");
284
285 let mut tera_ctx = build_context(self, now);
286 if self
287 .read_sources(&ctx.pg_pool, not_before, &mut tera_ctx, false)
288 .await?
289 .is_break()
290 {
291 return Ok(());
292 }
293
294 for target in resolved_targets {
295 if let Err(err) = target.send(self, &mut tera_ctx, email, dry_run).await {
296 error!("sending: {}", LogError(&err));
297 }
298 }
299
300 Ok(())
301 }
302}
303
304#[derive(Debug, Clone)]
305pub struct InternalContext {
306 pub pg_pool: bestool_postgres::pool::PgPool,
307}
308
309fn rows_to_value_map(
310 rows: &[tokio_postgres::Row],
311) -> Vec<serde_json::Map<String, serde_json::Value>> {
312 rows.iter()
313 .map(|row| {
314 let mut map = serde_json::Map::new();
315 for (idx, column) in row.columns().iter().enumerate() {
316 let value = bestool_postgres::stringify::postgres_to_json_value(row, idx);
317 map.insert(column.name().to_string(), value);
318 }
319 map
320 })
321 .collect()
322}
323
324fn check_numerical_thresholds(
325 rows: &[serde_json::Map<String, serde_json::Value>],
326 thresholds: &[NumericalThreshold],
327 was_triggered: bool,
328) -> Result<bool> {
329 for threshold in thresholds {
330 for row in rows {
331 let value = match row.get(&threshold.field) {
332 Some(serde_json::Value::Number(n)) => n
333 .as_f64()
334 .ok_or_else(|| miette!("field '{}' is not a valid number", threshold.field))?,
335 Some(_) => {
336 return Err(miette!(
337 "field '{}' exists but is not a number",
338 threshold.field
339 ));
340 }
341 None => {
342 return Err(miette!(
343 "field '{}' not found in query results",
344 threshold.field
345 ));
346 }
347 };
348
349 let is_inverted = threshold
351 .clear_at
352 .is_some_and(|clear| clear > threshold.alert_at);
353
354 if was_triggered {
355 if let Some(clear_at) = threshold.clear_at {
357 let should_clear = if is_inverted {
358 value >= clear_at
360 } else {
361 value <= clear_at
363 };
364
365 if should_clear {
366 continue;
368 } else {
369 return Ok(true);
371 }
372 } else {
373 let still_triggered = if is_inverted {
375 value <= threshold.alert_at
376 } else {
377 value >= threshold.alert_at
378 };
379
380 if still_triggered {
381 return Ok(true);
382 }
383 }
384 } else {
385 let should_trigger = if is_inverted {
387 value <= threshold.alert_at
389 } else {
390 value >= threshold.alert_at
392 };
393
394 if should_trigger {
395 return Ok(true);
396 }
397 }
398 }
399 }
400
401 Ok(false)
402}
403
404fn parse_interval(s: &str) -> Result<Duration> {
405 let s = s.trim();
406
407 if let Ok(secs) = s.parse::<u64>() {
409 return Ok(Duration::from_secs(secs));
410 }
411
412 let parts: Vec<&str> = s.split_whitespace().collect();
414 if parts.len() != 2 {
415 return Err(miette!(
416 "interval must be in format '<number> <unit>' or just '<seconds>'"
417 ));
418 }
419
420 let value: u64 = parts[0]
421 .parse()
422 .into_diagnostic()
423 .wrap_err("interval value must be a number")?;
424 let unit = parts[1].to_lowercase();
425
426 let duration = match unit.as_str() {
427 "second" | "seconds" | "s" | "sec" | "secs" => Duration::from_secs(value),
428 "minute" | "minutes" | "m" | "min" | "mins" => Duration::from_secs(value * 60),
429 "hour" | "hours" | "h" | "hr" | "hrs" => Duration::from_secs(value * 3600),
430 "day" | "days" | "d" => Duration::from_secs(value * 86400),
431 _ => {
432 return Err(miette!(
433 "unknown interval unit: {}, expected: seconds, minutes, hours, or days",
434 unit
435 ));
436 }
437 };
438
439 Ok(duration)
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialises an alert definition from YAML, panicking on invalid input.
    fn parse(yaml: &str) -> AlertDefinition {
        serde_yaml::from_str(yaml).unwrap()
    }

    /// Returns the numerical thresholds of an alert that must have an SQL source.
    fn sql_thresholds(alert: &AlertDefinition) -> &[NumericalThreshold] {
        match &alert.source {
            TicketSource::Sql { numerical, .. } => numerical,
            _ => panic!("Expected Sql source"),
        }
    }

    /// Builds a result set of one row holding a single integer column.
    fn single_row(field: &str, value: i64) -> Vec<serde_json::Map<String, serde_json::Value>> {
        let mut row = serde_json::Map::new();
        row.insert(field.to_string(), serde_json::Value::Number(value.into()));
        vec![row]
    }

    #[test]
    fn test_alert_with_event_source() {
        let alert = parse(
            r#"
event: source-error
send:
  - id: test-target
    subject: Test
    template: Test template
"#,
        );
        assert!(matches!(alert.source, TicketSource::Event { .. }));
        if let TicketSource::Event { event } = alert.source {
            assert_eq!(event, EventType::SourceError);
        }
    }

    #[test]
    fn test_parse_interval() {
        let cases = [
            ("60", 60),
            ("1 minute", 60),
            ("5 minutes", 300),
            ("2 hours", 7200),
            ("1 day", 86400),
            ("30 seconds", 30),
        ];
        for (input, seconds) in cases {
            assert_eq!(parse_interval(input).unwrap(), Duration::from_secs(seconds));
        }
    }

    #[test]
    fn test_default_interval() {
        let alert = parse(
            r#"
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert_eq!(alert.interval, "1 minute");
    }

    #[test]
    fn test_default_always_send() {
        let alert = parse(
            r#"
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.always_send, AlwaysSend::Boolean(false)));
    }

    #[test]
    fn test_always_send_true() {
        let alert = parse(
            r#"
always-send: true
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.always_send, AlwaysSend::Boolean(true)));
    }

    #[test]
    fn test_always_send_timed() {
        let alert = parse(
            r#"
always-send:
  after: 8h
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let AlwaysSend::Timed(config) = alert.always_send else {
            panic!("Expected AlwaysSend::Timed");
        };
        assert_eq!(config.after, "8h");
    }

    #[test]
    fn test_always_send_timed_normalised() {
        let alert = parse(
            r#"
always-send:
  after: 8 hours
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let external_targets = HashMap::new();
        let (normalised, _) = alert.normalise(&external_targets).unwrap();

        let AlwaysSend::Timed(config) = normalised.always_send else {
            panic!("Expected AlwaysSend::Timed");
        };
        assert_eq!(config.after, "8 hours");
        assert_eq!(config.after_duration, Duration::from_secs(8 * 3600));
    }

    #[test]
    fn test_numerical_threshold_normal() {
        let alert = parse(
            r#"
sql: "SELECT 1"
numerical:
  - field: count
    alert-at: 100
    clear-at: 50
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let numerical = sql_thresholds(&alert);
        assert_eq!(numerical.len(), 1);
        assert_eq!(numerical[0].field, "count");
        assert_eq!(numerical[0].alert_at, 100.0);
        assert_eq!(numerical[0].clear_at, Some(50.0));
    }

    #[test]
    fn test_numerical_threshold_inverted() {
        let alert = parse(
            r#"
sql: "SELECT 1"
numerical:
  - field: free_space
    alert-at: 10
    clear-at: 50
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let numerical = sql_thresholds(&alert);
        assert_eq!(numerical.len(), 1);
        assert_eq!(numerical[0].field, "free_space");
        assert_eq!(numerical[0].alert_at, 10.0);
        assert_eq!(numerical[0].clear_at, Some(50.0));
    }

    #[test]
    fn test_numerical_threshold_no_clear() {
        let alert = parse(
            r#"
sql: "SELECT 1"
numerical:
  - field: errors
    alert-at: 5
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let numerical = sql_thresholds(&alert);
        assert_eq!(numerical.len(), 1);
        assert_eq!(numerical[0].field, "errors");
        assert_eq!(numerical[0].alert_at, 5.0);
        assert_eq!(numerical[0].clear_at, None);
    }

    #[test]
    fn test_check_numerical_thresholds_normal_trigger() {
        let threshold = NumericalThreshold {
            field: "count".to_string(),
            alert_at: 100.0,
            clear_at: Some(50.0),
        };
        let thresholds = std::slice::from_ref(&threshold);

        // Above the alert level: triggers, and stays triggered.
        let rows = single_row("count", 150);
        assert!(check_numerical_thresholds(&rows, thresholds, false).unwrap());
        assert!(check_numerical_thresholds(&rows, thresholds, true).unwrap());

        // Below the clear level: a triggered alert clears.
        let rows = single_row("count", 30);
        assert!(!check_numerical_thresholds(&rows, thresholds, true).unwrap());
    }

    #[test]
    fn test_check_numerical_thresholds_inverted_trigger() {
        let threshold = NumericalThreshold {
            field: "free_space".to_string(),
            alert_at: 10.0,
            clear_at: Some(50.0),
        };
        let thresholds = std::slice::from_ref(&threshold);

        // Below the alert level of an inverted threshold: triggers, and stays
        // triggered.
        let rows = single_row("free_space", 5);
        assert!(check_numerical_thresholds(&rows, thresholds, false).unwrap());
        assert!(check_numerical_thresholds(&rows, thresholds, true).unwrap());

        // Above the clear level: a triggered alert clears.
        let rows = single_row("free_space", 60);
        assert!(!check_numerical_thresholds(&rows, thresholds, true).unwrap());
    }

    #[test]
    fn test_check_numerical_thresholds_no_clear_at() {
        let threshold = NumericalThreshold {
            field: "errors".to_string(),
            alert_at: 5.0,
            clear_at: None,
        };
        let thresholds = std::slice::from_ref(&threshold);

        let rows = single_row("errors", 10);
        assert!(check_numerical_thresholds(&rows, thresholds, false).unwrap());
        assert!(check_numerical_thresholds(&rows, thresholds, true).unwrap());

        // Without a clear level, dropping below the alert level clears.
        let rows = single_row("errors", 3);
        assert!(!check_numerical_thresholds(&rows, thresholds, true).unwrap());
    }

    #[test]
    fn test_when_changed_boolean_true() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed: true
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.when_changed, WhenChanged::Boolean(true)));
    }

    #[test]
    fn test_when_changed_boolean_false() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed: false
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.when_changed, WhenChanged::Boolean(false)));
    }

    #[test]
    fn test_when_changed_default() {
        let alert = parse(
            r#"
sql: "SELECT 1"
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        assert!(matches!(alert.when_changed, WhenChanged::Boolean(false)));
    }

    #[test]
    fn test_when_changed_except() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed:
  except: [created_at, updated_at]
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let WhenChanged::Detailed(config) = &alert.when_changed else {
            panic!("Expected Detailed variant");
        };
        assert_eq!(config.except, vec!["created_at", "updated_at"]);
        assert!(config.only.is_empty());
    }

    #[test]
    fn test_when_changed_only() {
        let alert = parse(
            r#"
sql: "SELECT 1"
when-changed:
  only: [error, message]
send:
  - id: test
    subject: Test
    template: Test
"#,
        );
        let WhenChanged::Detailed(config) = &alert.when_changed else {
            panic!("Expected Detailed variant");
        };
        assert_eq!(config.only, vec!["error", "message"]);
        assert!(config.except.is_empty());
    }
}