1use crate::dispatch::error::DispatchError;
2use scouter_types::{
3 AlertDispatchConfig, AlertDispatchType, DispatchAlertDescription, DispatchDriftConfig,
4 DriftArgs, OpsGenieDispatchConfig, SlackDispatchConfig,
5};
6use serde_json::{json, Value};
7use std::result::Result;
8use std::{collections::HashMap, env};
9use tracing::error;
10
11trait DispatchHelpers {
12 fn construct_alert_description<T: DispatchAlertDescription>(
13 &self,
14 feature_alerts: &T,
15 ) -> String;
16}
17pub trait Dispatch {
18 fn process_alerts<T: DispatchAlertDescription + std::marker::Sync>(
19 &self,
20 feature_alerts: &T,
21 ) -> impl std::future::Future<Output = Result<(), DispatchError>> + Send;
22}
23pub trait HttpAlertWrapper {
24 fn api_url(&self) -> &str;
25 fn header_auth_value(&self) -> &str;
26 fn construct_alert_body(&self, alert_description: &str) -> Value;
27}
28
/// Delivers drift alerts to OpsGenie over HTTP.
///
/// `Debug` is implemented by hand rather than derived so that the API key carried in
/// `header_auth_value` is never printed by `{:?}` formatting (for example when the
/// enclosing dispatcher is logged).
pub struct OpsGenieAlerter {
    /// Value sent in the `Authorization` header (`GenieKey <api key>`). Secret.
    header_auth_value: String,
    /// Endpoint alerts are POSTed to.
    api_url: String,
    /// Team listed as both responder and visibility target of each alert.
    team_name: String,
    /// Priority forwarded verbatim in the alert payload (e.g. `P1`).
    priority: String,
    /// Model name, used in the alert message.
    name: String,
    /// Model space, used in the alert message.
    space: String,
    /// Model version, used in the alert message.
    version: String,
}

impl std::fmt::Debug for OpsGenieAlerter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape and field order a derived impl would produce, with the credential masked.
        f.debug_struct("OpsGenieAlerter")
            .field("header_auth_value", &"<redacted>")
            .field("api_url", &self.api_url)
            .field("team_name", &self.team_name)
            .field("priority", &self.priority)
            .field("name", &self.name)
            .field("space", &self.space)
            .field("version", &self.version)
            .finish()
    }
}
39
40impl OpsGenieAlerter {
41 pub fn new(
51 name: &str,
52 space: &str,
53 version: &str,
54 dispatch_config: &OpsGenieDispatchConfig,
55 ) -> Result<Self, DispatchError> {
56 let api_key = env::var("OPSGENIE_API_KEY")
57 .map_err(|_| DispatchError::OpsGenieError("OPSGENIE_API_KEY is not set".to_string()))?;
58
59 let api_url = env::var("OPSGENIE_API_URL")
60 .map_err(|_| DispatchError::OpsGenieError("OPSGENIE_API_URL is not set".to_string()))?;
61
62 let team_name = dispatch_config.team.clone();
63 let priority = dispatch_config.priority.clone();
64
65 Ok(Self {
66 header_auth_value: format!("GenieKey {}", api_key),
67 api_url,
68 team_name,
69 name: name.to_string(),
70 space: space.to_string(),
71 version: version.to_string(),
72 priority,
73 })
74 }
75}
76
77impl HttpAlertWrapper for OpsGenieAlerter {
78 fn api_url(&self) -> &str {
79 &self.api_url
80 }
81
82 fn header_auth_value(&self) -> &str {
83 &self.header_auth_value
84 }
85
86 fn construct_alert_body(&self, alert_description: &str) -> Value {
87 let mut mapping: HashMap<&str, Value> = HashMap::new();
88 mapping.insert(
89 "message",
90 format!(
91 "Model drift detected for {}/{}/{}",
92 self.space, self.name, self.version
93 )
94 .into(),
95 );
96 mapping.insert("description", alert_description.to_string().into());
97 mapping.insert(
98 "responders",
99 json!([{"name": self.team_name, "type": "team"}]),
100 );
101 mapping.insert(
102 "visibleTo",
103 json!([{"name": self.team_name, "type": "team"}]),
104 );
105
106 mapping.insert("tags", json!(["Model Drift", "Scouter"]));
107 mapping.insert("priority", self.priority.clone().into());
108
109 json!(mapping)
110 }
111}
112impl DispatchHelpers for OpsGenieAlerter {
113 fn construct_alert_description<T: DispatchAlertDescription>(
114 &self,
115 feature_alerts: &T,
116 ) -> String {
117 feature_alerts.create_alert_description(AlertDispatchType::OpsGenie)
118 }
119}
120
/// Delivers drift alerts to a Slack channel via `chat.postMessage`.
///
/// `Debug` is implemented by hand rather than derived so that the app token carried in
/// `header_auth_value` is never printed by `{:?}` formatting (for example when the
/// enclosing dispatcher is logged).
pub struct SlackAlerter {
    /// Value sent in the `Authorization` header (`Bearer <app token>`). Secret.
    header_auth_value: String,
    /// Full `chat.postMessage` endpoint.
    api_url: String,
    /// Model name, shown in the message.
    name: String,
    /// Model space, shown in the message.
    space: String,
    /// Model version, shown in the message.
    version: String,
    /// Channel the message is posted to.
    channel: String,
}

impl std::fmt::Debug for SlackAlerter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape and field order a derived impl would produce, with the credential masked.
        f.debug_struct("SlackAlerter")
            .field("header_auth_value", &"<redacted>")
            .field("api_url", &self.api_url)
            .field("name", &self.name)
            .field("space", &self.space)
            .field("version", &self.version)
            .field("channel", &self.channel)
            .finish()
    }
}
130
131impl SlackAlerter {
132 pub fn new(
142 name: &str,
143 space: &str,
144 version: &str,
145 dispatch_config: &SlackDispatchConfig,
146 ) -> Result<Self, DispatchError> {
147 let app_token = env::var("SLACK_APP_TOKEN")
148 .map_err(|_| DispatchError::SlackError("SLACK_APP_TOKEN not set".to_string()))?;
149
150 let api_url = env::var("SLACK_API_URL")
151 .map_err(|_| DispatchError::SlackError("SLACK_API_URL not set".to_string()))?;
152
153 let slack_channel = dispatch_config.channel.clone();
154
155 Ok(Self {
156 header_auth_value: format!("Bearer {}", app_token),
157 api_url: format!("{}/chat.postMessage", api_url),
158 name: name.to_string(),
159 space: space.to_string(),
160 version: version.to_string(),
161 channel: slack_channel,
162 })
163 }
164}
165
166impl HttpAlertWrapper for SlackAlerter {
167 fn api_url(&self) -> &str {
168 &self.api_url
169 }
170
171 fn header_auth_value(&self) -> &str {
172 &self.header_auth_value
173 }
174
175 fn construct_alert_body(&self, alert_description: &str) -> Value {
176 json!({
177 "channel": self.channel,
178 "blocks": [
179 {
180 "type": "header",
181 "text": {
182 "type": "plain_text",
183 "text": ":rotating_light: Drift Detected :rotating_light:",
184 "emoji": true
185 }
186 },
187 {
188 "type": "section",
189 "text": {
190 "type": "mrkdwn",
191 "text": format!("*Name*: {} *Space*: {} *Version*: {}", self.name, self.space, self.version),
192 }
193 },
194 {
195 "type": "section",
196 "text": {
197 "type": "mrkdwn",
198 "text": alert_description
199 },
200
201 }
202 ]
203 })
204 }
205}
206
207impl DispatchHelpers for SlackAlerter {
208 fn construct_alert_description<T: DispatchAlertDescription>(
209 &self,
210 feature_alerts: &T,
211 ) -> String {
212 feature_alerts.create_alert_description(AlertDispatchType::Slack)
213 }
214}
215
216#[derive(Debug)]
217pub struct HttpAlertDispatcher<T: HttpAlertWrapper> {
218 http_client: reqwest::Client,
219 alerter: T,
220}
221
222impl<T: HttpAlertWrapper> HttpAlertDispatcher<T> {
223 pub fn new(alerter: T) -> Self {
224 Self {
225 http_client: reqwest::Client::new(),
226 alerter,
227 }
228 }
229
230 async fn send_alerts(&self, body: Value) -> Result<(), DispatchError> {
231 let response = self
232 .http_client
233 .post(self.alerter.api_url())
234 .header("Authorization", self.alerter.header_auth_value())
235 .json(&body)
236 .send()
237 .await
238 .map_err(|e| DispatchError::HttpError(e.to_string()))?;
239
240 if response.status().is_success() {
241 Ok(())
242 } else {
243 let text = response
244 .text()
245 .await
246 .unwrap_or("Failed to parse response".to_string());
247 error!("Failed to send alert: {}. Continuing", text);
248 Ok(())
249 }
250 }
251}
252
253impl<T: HttpAlertWrapper + DispatchHelpers + std::marker::Sync> Dispatch
254 for HttpAlertDispatcher<T>
255{
256 async fn process_alerts<J: DispatchAlertDescription>(
257 &self,
258 feature_alerts: &J,
259 ) -> Result<(), DispatchError> {
260 let alert_description = self.alerter.construct_alert_description(feature_alerts);
261
262 let alert_body = self.alerter.construct_alert_body(&alert_description);
263
264 self.send_alerts(alert_body)
265 .await
266 .map_err(|e| DispatchError::HttpError(format!("Failed to send alerts: {}", e)))?;
267
268 Ok(())
269 }
270}
271
/// Dispatcher that prints alerts to stdout.
///
/// Also used as the fallback when an HTTP-based alerter cannot be configured.
#[derive(Debug)]
pub struct ConsoleAlertDispatcher {
    /// Model name shown in the printed banner.
    name: String,
    /// Model space shown in the printed banner.
    space: String,
    /// Model version shown in the printed banner.
    version: String,
}
278
279impl ConsoleAlertDispatcher {
280 pub fn new(name: &str, space: &str, version: &str) -> Self {
281 Self {
282 name: name.to_string(),
283 space: space.to_string(),
284 version: version.to_string(),
285 }
286 }
287}
288
289impl Dispatch for ConsoleAlertDispatcher {
290 async fn process_alerts<T: DispatchAlertDescription>(
291 &self,
292 feature_alerts: &T,
293 ) -> Result<(), DispatchError> {
294 let alert_description = self.construct_alert_description(feature_alerts);
295 if !alert_description.is_empty() {
296 let msg1 = "Drift detected for";
297 let msg2 = format!("{}/{}/{}!", self.space, self.name, self.version);
298 let mut body = format!("\n{} {} \n", msg1, msg2);
299 body.push_str(&alert_description);
300
301 println!("{}", body);
302 }
303 Ok(())
304 }
305}
306
307impl DispatchHelpers for ConsoleAlertDispatcher {
308 fn construct_alert_description<T: DispatchAlertDescription>(
309 &self,
310 feature_alerts: &T,
311 ) -> String {
312 feature_alerts.create_alert_description(AlertDispatchType::Console)
313 }
314}
315
316#[derive(Debug)]
317pub enum AlertDispatcher {
318 Console(ConsoleAlertDispatcher),
319 OpsGenie(HttpAlertDispatcher<OpsGenieAlerter>),
320 Slack(HttpAlertDispatcher<SlackAlerter>),
321}
322
323impl AlertDispatcher {
324 pub async fn process_alerts<T: DispatchAlertDescription + std::marker::Sync>(
326 &self,
327 feature_alerts: &T,
328 ) -> Result<(), DispatchError> {
329 match self {
330 AlertDispatcher::Console(dispatcher) => dispatcher
331 .process_alerts(feature_alerts)
332 .await
333 .map_err(|e| DispatchError::AlertProcessError(e.to_string())),
334 AlertDispatcher::OpsGenie(dispatcher) => dispatcher
335 .process_alerts(feature_alerts)
336 .await
337 .map_err(|e| DispatchError::AlertProcessError(e.to_string())),
338 AlertDispatcher::Slack(dispatcher) => dispatcher
339 .process_alerts(feature_alerts)
340 .await
341 .map_err(|e| DispatchError::AlertProcessError(e.to_string())),
342 }
343 }
344
345 pub fn new<T: DispatchDriftConfig>(config: &T) -> Result<Self, DispatchError> {
351 let args: DriftArgs = config.get_drift_args();
352
353 let result = match args.dispatch_config {
354 AlertDispatchConfig::Slack(config) => {
355 SlackAlerter::new(&args.name, &args.space, &args.version, &config)
356 .map(|alerter| AlertDispatcher::Slack(HttpAlertDispatcher::new(alerter)))
357 }
358 AlertDispatchConfig::OpsGenie(config) => {
359 OpsGenieAlerter::new(&args.name, &args.space, &args.version, &config)
360 .map(|alerter| AlertDispatcher::OpsGenie(HttpAlertDispatcher::new(alerter)))
361 }
362 AlertDispatchConfig::Console(_) => Ok(AlertDispatcher::Console(
363 ConsoleAlertDispatcher::new(&args.name, &args.space, &args.version),
364 )),
365 };
366
367 match result {
368 Ok(dispatcher) => Ok(dispatcher),
369 Err(e) => {
370 error!("Failed to create Alerter: {:?}", e);
371 Ok(AlertDispatcher::Console(ConsoleAlertDispatcher::new(
372 &args.name,
373 &args.space,
374 &args.version,
375 )))
376 }
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::spc::{
        AlertZone, SpcAlert, SpcAlertConfig, SpcAlertType, SpcDriftConfig, SpcFeatureAlert,
        SpcFeatureAlerts,
    };

    use std::collections::HashMap;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes every test in this module that touches process environment variables.
    ///
    /// The test harness runs tests on parallel threads, and several tests here set or
    /// remove the same `OPSGENIE_*` / `SLACK_*` variables. Without a lock, one test can
    /// clear or overwrite a variable between another test setting it and an alerter
    /// reading it, which makes the suite flaky.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Exclusive, scoped access to the process environment for one test.
    ///
    /// Holds [`ENV_LOCK`] for its whole lifetime. Variables set through
    /// [`ScopedEnv::set`] are removed again on drop (also when the test panics), before
    /// the lock is released.
    struct ScopedEnv {
        keys_to_clear: Vec<&'static str>,
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedEnv {
        /// Acquire [`ENV_LOCK`].
        fn lock() -> MutexGuard<'static, ()> {
            // A panicking test poisons the lock. The `()` it guards cannot be left
            // inconsistent, so recover the guard instead of failing every later test.
            ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }

        /// Take the lock and set each `(key, value)` pair; the keys are unset on drop.
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let lock = Self::lock();
            for &(key, value) in vars {
                // SAFETY: ENV_LOCK is held, so no other test in this module reads or
                // writes the environment concurrently.
                unsafe { env::set_var(key, value) };
            }
            Self {
                keys_to_clear: vars.iter().map(|&(key, _)| key).collect(),
                _lock: lock,
            }
        }

        /// Take the lock and make sure each of `keys` is unset for the guard's lifetime.
        fn unset(keys: &[&'static str]) -> Self {
            let lock = Self::lock();
            for &key in keys {
                // SAFETY: ENV_LOCK is held, so no other test in this module reads or
                // writes the environment concurrently.
                unsafe { env::remove_var(key) };
            }
            Self {
                keys_to_clear: Vec::new(),
                _lock: lock,
            }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            for &key in &self.keys_to_clear {
                // SAFETY: `_lock` is still held here; struct fields are dropped only after
                // this method returns.
                unsafe { env::remove_var(key) };
            }
        }
    }

    fn test_features_map() -> HashMap<String, SpcFeatureAlert> {
        let mut features: HashMap<String, SpcFeatureAlert> = HashMap::new();

        features.insert(
            "test_feature_1".to_string(),
            SpcFeatureAlert {
                feature: "test_feature_1".to_string(),
                alerts: vec![SpcAlert {
                    zone: AlertZone::Zone4,
                    kind: SpcAlertType::OutOfBounds,
                }]
                .into_iter()
                .collect(),
            },
        );
        features.insert(
            "test_feature_2".to_string(),
            SpcFeatureAlert {
                feature: "test_feature_2".to_string(),
                alerts: vec![SpcAlert {
                    zone: AlertZone::Zone1,
                    kind: SpcAlertType::Consecutive,
                }]
                .into_iter()
                .collect(),
            },
        );
        features
    }

    #[test]
    fn test_construct_opsgenie_alert_description() {
        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", "api_url"),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);
        let features = test_features_map();
        let alerter = OpsGenieAlerter::new(
            "name",
            "space",
            "1.0.0",
            &OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            },
        )
        .unwrap();
        let alert_description = alerter.construct_alert_description(&SpcFeatureAlerts {
            features,
            has_alerts: true,
        });
        let expected_alert_description = "Drift has been detected for the following features:\n test_feature_2: \n Kind: Consecutive\n Zone: Zone 1\n test_feature_1: \n Kind: Out of bounds\n Zone: Zone 4\n".to_string();
        // Features come from a HashMap, so their order in the description is unspecified;
        // compare lengths rather than exact text.
        assert_eq!(alert_description.len(), expected_alert_description.len());
    }

    #[test]
    fn test_construct_opsgenie_alert_description_empty() {
        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", "api_url"),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);
        let features: HashMap<String, SpcFeatureAlert> = HashMap::new();
        let alerter = OpsGenieAlerter::new(
            "name",
            "space",
            "1.0.0",
            &OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            },
        )
        .unwrap();
        let alert_description = alerter.construct_alert_description(&SpcFeatureAlerts {
            features,
            has_alerts: true,
        });
        let expected_alert_description = "".to_string();
        assert_eq!(alert_description, expected_alert_description);
    }

    #[tokio::test]
    async fn test_construct_opsgenie_alert_body() {
        let download_server = mockito::Server::new_async().await;
        let url = download_server.url();

        // No `.await` follows, so holding the std mutex for the rest of the test is fine.
        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", url.as_str()),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);

        let ops_genie_team = "test-team";

        let expected_alert_body = json!(
            {
                "message": "Model drift detected for test_repo/test_ml_model/1.0.0",
                "description": "Features have drifted",
                "responders":[
                    {"name":ops_genie_team, "type":"team"}
                ],
                "visibleTo":[
                    {"name":ops_genie_team, "type":"team"}
                ],
                "tags": ["Model Drift", "Scouter"],
                "priority": "P1"
            }
        );
        let alerter = OpsGenieAlerter::new(
            "test_ml_model",
            "test_repo",
            "1.0.0",
            &OpsGenieDispatchConfig {
                team: ops_genie_team.to_string(),
                priority: "P1".to_string(),
            },
        )
        .unwrap();
        let alert_body = alerter.construct_alert_body("Features have drifted");
        assert_eq!(alert_body, expected_alert_body);
    }

    #[tokio::test]
    async fn test_send_opsgenie_alerts() {
        let mut download_server = mockito::Server::new_async().await;
        let url = format!("{}/alerts", download_server.url());

        let mock_get_path = download_server
            .mock("Post", "/alerts")
            .with_status(201)
            .create();

        let features = test_features_map();

        // The alerter copies the URL and key at construction, so the env lock is released
        // before the request is awaited (never hold a std mutex across `.await`).
        let dispatcher = {
            let _env = ScopedEnv::set(&[
                ("OPSGENIE_API_URL", url.as_str()),
                ("OPSGENIE_API_KEY", "api_key"),
            ]);
            AlertDispatcher::OpsGenie(HttpAlertDispatcher::new(
                OpsGenieAlerter::new(
                    "name",
                    "space",
                    "1.0.0",
                    &OpsGenieDispatchConfig {
                        team: "test-team".to_string(),
                        priority: "P5".to_string(),
                    },
                )
                .unwrap(),
            ))
        };
        let _ = dispatcher
            .process_alerts(&SpcFeatureAlerts {
                features,
                has_alerts: true,
            })
            .await;

        mock_get_path.assert();
    }

    #[tokio::test]
    async fn test_send_console_alerts() {
        let features = test_features_map();
        let dispatcher =
            AlertDispatcher::Console(ConsoleAlertDispatcher::new("name", "space", "1.0.0"));
        let result = dispatcher
            .process_alerts(&SpcFeatureAlerts {
                features,
                has_alerts: true,
            })
            .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_send_slack_alerts() {
        let mut download_server = mockito::Server::new_async().await;
        let url = download_server.url();

        let mock_get_path = download_server
            .mock("Post", "/chat.postMessage")
            .with_status(201)
            .create();

        let features = test_features_map();

        // The alerter copies the URL and token at construction, so the env lock is
        // released before the request is awaited.
        let dispatcher = {
            let _env = ScopedEnv::set(&[
                ("SLACK_API_URL", url.as_str()),
                ("SLACK_APP_TOKEN", "bot_token"),
            ]);
            AlertDispatcher::Slack(HttpAlertDispatcher::new(
                SlackAlerter::new(
                    "name",
                    "space",
                    "1.0.0",
                    &SlackDispatchConfig {
                        channel: "test-channel".to_string(),
                    },
                )
                .unwrap(),
            ))
        };
        let _ = dispatcher
            .process_alerts(&SpcFeatureAlerts {
                features,
                has_alerts: true,
            })
            .await;

        mock_get_path.assert();
    }

    #[tokio::test]
    async fn test_construct_slack_alert_body() {
        let download_server = mockito::Server::new_async().await;
        let url = download_server.url();
        let slack_channel = "test_channel";

        // No `.await` follows, so holding the std mutex for the rest of the test is fine.
        let _env = ScopedEnv::set(&[
            ("SLACK_API_URL", url.as_str()),
            ("SLACK_APP_TOKEN", "bot_token"),
        ]);
        let expected_alert_body = json!({
            "channel": slack_channel,
            "blocks": [
                {
                    "type": "header",
                    "text": {
                        "type": "plain_text",
                        "text": ":rotating_light: Drift Detected :rotating_light:",
                        "emoji": true
                    }
                },
                {
                    "type": "section",
                    "text": {
                        "type": "mrkdwn",
                        "text": "*Name*: name *Space*: space *Version*: 1.0.0",
                    }
                },
                {
                    "type": "section",
                    "text": {
                        "type": "mrkdwn",
                        "text": "*Features have drifted*"
                    },
                }
            ]
        });
        let alerter = SlackAlerter::new(
            "name",
            "space",
            "1.0.0",
            &SlackDispatchConfig {
                channel: slack_channel.to_string(),
            },
        )
        .unwrap();
        let alert_body = alerter.construct_alert_body("*Features have drifted*");
        assert_eq!(alert_body, expected_alert_body);
    }

    #[test]
    fn test_console_dispatcher_returned_when_env_vars_not_set_opsgenie() {
        let _env = ScopedEnv::unset(&["OPSGENIE_API_KEY"]);
        let alert_config = SpcAlertConfig {
            dispatch_config: AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            }),
            ..Default::default()
        };

        let config = SpcDriftConfig::new(
            Some("name".to_string()),
            Some("space".to_string()),
            Some("1.0.0".to_string()),
            None,
            None,
            Some(alert_config),
            None,
        )
        .unwrap();
        let dispatcher = AlertDispatcher::new(&config).unwrap();

        assert!(
            matches!(dispatcher, AlertDispatcher::Console(_)),
            "Expected Console Dispatcher"
        );
    }

    #[test]
    fn test_console_dispatcher_returned_when_env_vars_not_set_slack() {
        let _env = ScopedEnv::unset(&["SLACK_API_URL", "SLACK_APP_TOKEN"]);

        let alert_config = SpcAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(SlackDispatchConfig {
                channel: "test-channel".to_string(),
            }),
            ..Default::default()
        };

        let config = SpcDriftConfig::new(
            Some("name".to_string()),
            Some("space".to_string()),
            Some("1.0.0".to_string()),
            None,
            None,
            Some(alert_config),
            None,
        )
        .unwrap();

        let dispatcher = AlertDispatcher::new(&config).unwrap();
        assert!(
            matches!(dispatcher, AlertDispatcher::Console(_)),
            "Expected Console Dispatcher"
        );
    }

    #[test]
    fn test_slack_dispatcher_returned_when_env_vars_set() {
        let _env = ScopedEnv::set(&[("SLACK_API_URL", "url"), ("SLACK_APP_TOKEN", "bot_token")]);
        let alert_config = SpcAlertConfig {
            dispatch_config: AlertDispatchConfig::Slack(SlackDispatchConfig {
                channel: "test-channel".to_string(),
            }),
            ..Default::default()
        };

        let config = SpcDriftConfig::new(
            Some("name".to_string()),
            Some("space".to_string()),
            Some("1.0.0".to_string()),
            None,
            None,
            Some(alert_config),
            None,
        )
        .unwrap();

        let dispatcher = AlertDispatcher::new(&config).unwrap();

        assert!(
            matches!(dispatcher, AlertDispatcher::Slack(_)),
            "Expected Slack Dispatcher"
        );
    }
}