scouter_dispatch/dispatch/
dispatcher.rs

1use crate::dispatch::error::DispatchError;
2use scouter_types::{
3    AlertDispatchConfig, AlertDispatchType, DispatchAlertDescription, DispatchDriftConfig,
4    DriftArgs, OpsGenieDispatchConfig, SlackDispatchConfig,
5};
6use serde_json::{json, Value};
7use std::result::Result;
8use std::{collections::HashMap, env};
9use tracing::error;
10
11trait DispatchHelpers {
12    fn construct_alert_description<T: DispatchAlertDescription>(
13        &self,
14        feature_alerts: &T,
15    ) -> String;
16}
17pub trait Dispatch {
18    fn process_alerts<T: DispatchAlertDescription + std::marker::Sync>(
19        &self,
20        feature_alerts: &T,
21    ) -> impl std::future::Future<Output = Result<(), DispatchError>> + Send;
22}
23pub trait HttpAlertWrapper {
24    fn api_url(&self) -> &str;
25    fn header_auth_value(&self) -> &str;
26    fn construct_alert_body(&self, alert_description: &str) -> Value;
27}
28
/// Connection settings and model identity used to build OpsGenie alerts.
///
/// `Debug` is implemented by hand instead of derived: `header_auth_value`
/// embeds the OpsGenie API key, so it is redacted to keep the credential out
/// of logs and panic messages.
pub struct OpsGenieAlerter {
    header_auth_value: String,
    api_url: String,
    team_name: String,
    priority: String,
    name: String,
    space: String,
    version: String,
}

impl std::fmt::Debug for OpsGenieAlerter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpsGenieAlerter")
            // Never print the credential itself.
            .field("header_auth_value", &"<redacted>")
            .field("api_url", &self.api_url)
            .field("team_name", &self.team_name)
            .field("priority", &self.priority)
            .field("name", &self.name)
            .field("space", &self.space)
            .field("version", &self.version)
            .finish()
    }
}
39
40impl OpsGenieAlerter {
41    /// Create a new OpsGenieAlerter
42    ///
43    /// # Arguments
44    ///
45    /// * `name` - Name of the model
46    /// * `space` - Space of the model
47    /// * `version` - Version of the model
48    /// * `dispatch_config` - OpsGenieAlerter dispatch configuration
49    ///
50    pub fn new(
51        name: &str,
52        space: &str,
53        version: &str,
54        dispatch_config: &OpsGenieDispatchConfig,
55    ) -> Result<Self, DispatchError> {
56        let api_key = env::var("OPSGENIE_API_KEY")
57            .map_err(|_| DispatchError::OpsGenieError("OPSGENIE_API_KEY is not set".to_string()))?;
58
59        let api_url = env::var("OPSGENIE_API_URL")
60            .map_err(|_| DispatchError::OpsGenieError("OPSGENIE_API_URL is not set".to_string()))?;
61
62        let team_name = dispatch_config.team.clone();
63        let priority = dispatch_config.priority.clone();
64
65        Ok(Self {
66            header_auth_value: format!("GenieKey {api_key}"),
67            api_url,
68            team_name,
69            name: name.to_string(),
70            space: space.to_string(),
71            version: version.to_string(),
72            priority,
73        })
74    }
75}
76
77impl HttpAlertWrapper for OpsGenieAlerter {
78    fn api_url(&self) -> &str {
79        &self.api_url
80    }
81
82    fn header_auth_value(&self) -> &str {
83        &self.header_auth_value
84    }
85
86    fn construct_alert_body(&self, alert_description: &str) -> Value {
87        let mut mapping: HashMap<&str, Value> = HashMap::new();
88        mapping.insert(
89            "message",
90            format!(
91                "Model drift detected for {}/{}/{}",
92                self.space, self.name, self.version
93            )
94            .into(),
95        );
96        mapping.insert("description", alert_description.to_string().into());
97        mapping.insert(
98            "responders",
99            json!([{"name": self.team_name, "type": "team"}]),
100        );
101        mapping.insert(
102            "visibleTo",
103            json!([{"name": self.team_name, "type": "team"}]),
104        );
105
106        mapping.insert("tags", json!(["Model Drift", "Scouter"]));
107        mapping.insert("priority", self.priority.clone().into());
108
109        json!(mapping)
110    }
111}
112impl DispatchHelpers for OpsGenieAlerter {
113    fn construct_alert_description<T: DispatchAlertDescription>(
114        &self,
115        feature_alerts: &T,
116    ) -> String {
117        feature_alerts.create_alert_description(AlertDispatchType::OpsGenie)
118    }
119}
120
/// Connection settings and model identity used to build Slack alerts.
///
/// `Debug` is implemented by hand instead of derived: `header_auth_value`
/// embeds the Slack app token, so it is redacted to keep the credential out
/// of logs and panic messages.
pub struct SlackAlerter {
    header_auth_value: String,
    api_url: String,
    name: String,
    space: String,
    version: String,
    channel: String,
}

impl std::fmt::Debug for SlackAlerter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SlackAlerter")
            // Never print the credential itself.
            .field("header_auth_value", &"<redacted>")
            .field("api_url", &self.api_url)
            .field("name", &self.name)
            .field("space", &self.space)
            .field("version", &self.version)
            .field("channel", &self.channel)
            .finish()
    }
}
130
131impl SlackAlerter {
132    /// Create a new SlackAlerter
133    ///
134    /// # Arguments
135    ///
136    /// * `name` - Name of the model
137    /// * `space` - Space of the model
138    /// * `version` - Version of the model
139    /// * `dispatch_config` - slack dispatch configuration
140    ///
141    pub fn new(
142        name: &str,
143        space: &str,
144        version: &str,
145        dispatch_config: &SlackDispatchConfig,
146    ) -> Result<Self, DispatchError> {
147        let app_token = env::var("SLACK_APP_TOKEN")
148            .map_err(|_| DispatchError::SlackError("SLACK_APP_TOKEN not set".to_string()))?;
149
150        let api_url = env::var("SLACK_API_URL")
151            .map_err(|_| DispatchError::SlackError("SLACK_API_URL not set".to_string()))?;
152
153        let slack_channel = dispatch_config.channel.clone();
154
155        Ok(Self {
156            header_auth_value: format!("Bearer {app_token}"),
157            api_url: format!("{api_url}/chat.postMessage"),
158            name: name.to_string(),
159            space: space.to_string(),
160            version: version.to_string(),
161            channel: slack_channel,
162        })
163    }
164}
165
166impl HttpAlertWrapper for SlackAlerter {
167    fn api_url(&self) -> &str {
168        &self.api_url
169    }
170
171    fn header_auth_value(&self) -> &str {
172        &self.header_auth_value
173    }
174
175    fn construct_alert_body(&self, alert_description: &str) -> Value {
176        json!({
177            "channel": self.channel,
178            "blocks": [
179                {
180                    "type": "header",
181                    "text": {
182                      "type": "plain_text",
183                      "text": ":rotating_light: Drift Detected :rotating_light:",
184                      "emoji": true
185                    }
186                },
187                {
188                    "type": "section",
189                    "text": {
190                      "type": "mrkdwn",
191                      "text": format!("*Name*: {} *Space*: {} *Version*: {}", self.name, self.space, self.version),
192                    }
193                },
194                {
195                    "type": "section",
196                    "text": {
197                        "type": "mrkdwn",
198                        "text": alert_description
199                    },
200
201                }
202            ]
203        })
204    }
205}
206
207impl DispatchHelpers for SlackAlerter {
208    fn construct_alert_description<T: DispatchAlertDescription>(
209        &self,
210        feature_alerts: &T,
211    ) -> String {
212        feature_alerts.create_alert_description(AlertDispatchType::Slack)
213    }
214}
215
216#[derive(Debug)]
217pub struct HttpAlertDispatcher<T: HttpAlertWrapper> {
218    http_client: reqwest::Client,
219    alerter: T,
220}
221
222impl<T: HttpAlertWrapper> HttpAlertDispatcher<T> {
223    pub fn new(alerter: T) -> Self {
224        Self {
225            http_client: reqwest::Client::new(),
226            alerter,
227        }
228    }
229
230    async fn send_alerts(&self, body: Value) -> Result<(), DispatchError> {
231        let response = self
232            .http_client
233            .post(self.alerter.api_url())
234            .header("Authorization", self.alerter.header_auth_value())
235            .json(&body)
236            .send()
237            .await
238            .map_err(|e| DispatchError::HttpError(e.to_string()))?;
239
240        if response.status().is_success() {
241            Ok(())
242        } else {
243            let text = response
244                .text()
245                .await
246                .unwrap_or("Failed to parse response".to_string());
247            error!("Failed to send alert: {}. Continuing", text);
248            Ok(())
249        }
250    }
251}
252
253impl<T: HttpAlertWrapper + DispatchHelpers + std::marker::Sync> Dispatch
254    for HttpAlertDispatcher<T>
255{
256    async fn process_alerts<J: DispatchAlertDescription>(
257        &self,
258        feature_alerts: &J,
259    ) -> Result<(), DispatchError> {
260        let alert_description = self.alerter.construct_alert_description(feature_alerts);
261
262        let alert_body = self.alerter.construct_alert_body(&alert_description);
263
264        self.send_alerts(alert_body)
265            .await
266            .map_err(|e| DispatchError::HttpError(format!("Failed to send alerts: {e}")))?;
267
268        Ok(())
269    }
270}
271
/// Holds the model identity (`name`, `space`, `version`) used when alerts are
/// written to the console.
#[derive(Debug)]
pub struct ConsoleAlertDispatcher {
    name: String,
    space: String,
    version: String,
}

impl ConsoleAlertDispatcher {
    /// Creates a dispatcher for the model identified by `name`, `space` and `version`.
    pub fn new(name: &str, space: &str, version: &str) -> Self {
        let (name, space, version) = (name.to_owned(), space.to_owned(), version.to_owned());
        Self {
            name,
            space,
            version,
        }
    }
}
288
289impl Dispatch for ConsoleAlertDispatcher {
290    async fn process_alerts<T: DispatchAlertDescription>(
291        &self,
292        feature_alerts: &T,
293    ) -> Result<(), DispatchError> {
294        let alert_description = self.construct_alert_description(feature_alerts);
295        if !alert_description.is_empty() {
296            let msg1 = "Drift detected for";
297            let msg2 = format!("{}/{}/{}!", self.space, self.name, self.version);
298            let mut body = format!("\n{msg1} {msg2} \n");
299            body.push_str(&alert_description);
300
301            println!("{body}");
302        }
303        Ok(())
304    }
305}
306
307impl DispatchHelpers for ConsoleAlertDispatcher {
308    fn construct_alert_description<T: DispatchAlertDescription>(
309        &self,
310        feature_alerts: &T,
311    ) -> String {
312        feature_alerts.create_alert_description(AlertDispatchType::Console)
313    }
314}
315
316#[derive(Debug)]
317pub enum AlertDispatcher {
318    Console(ConsoleAlertDispatcher),
319    OpsGenie(HttpAlertDispatcher<OpsGenieAlerter>),
320    Slack(HttpAlertDispatcher<SlackAlerter>),
321}
322
323impl AlertDispatcher {
324    // process alerts can be called asynchronously
325    pub async fn process_alerts<T: DispatchAlertDescription + std::marker::Sync>(
326        &self,
327        feature_alerts: &T,
328    ) -> Result<(), DispatchError> {
329        match self {
330            AlertDispatcher::Console(dispatcher) => dispatcher
331                .process_alerts(feature_alerts)
332                .await
333                .map_err(|e| DispatchError::AlertProcessError(e.to_string())),
334            AlertDispatcher::OpsGenie(dispatcher) => dispatcher
335                .process_alerts(feature_alerts)
336                .await
337                .map_err(|e| DispatchError::AlertProcessError(e.to_string())),
338            AlertDispatcher::Slack(dispatcher) => dispatcher
339                .process_alerts(feature_alerts)
340                .await
341                .map_err(|e| DispatchError::AlertProcessError(e.to_string())),
342        }
343    }
344
345    // create a new alert dispatcher based on the configuration
346    //
347    // # Arguments
348    //
349    // * `config` - DriftConfig (this is an enum wrapper for the different drift configurations)
350    pub fn new<T: DispatchDriftConfig>(config: &T) -> Result<Self, DispatchError> {
351        let args: DriftArgs = config.get_drift_args();
352
353        let result = match args.dispatch_config {
354            AlertDispatchConfig::Slack(config) => {
355                SlackAlerter::new(&args.name, &args.space, &args.version, &config)
356                    .map(|alerter| AlertDispatcher::Slack(HttpAlertDispatcher::new(alerter)))
357            }
358            AlertDispatchConfig::OpsGenie(config) => {
359                OpsGenieAlerter::new(&args.name, &args.space, &args.version, &config)
360                    .map(|alerter| AlertDispatcher::OpsGenie(HttpAlertDispatcher::new(alerter)))
361            }
362            AlertDispatchConfig::Console(_) => Ok(AlertDispatcher::Console(
363                ConsoleAlertDispatcher::new(&args.name, &args.space, &args.version),
364            )),
365        };
366
367        match result {
368            Ok(dispatcher) => Ok(dispatcher),
369            Err(e) => {
370                error!("Failed to create Alerter: {:?}", e);
371                Ok(AlertDispatcher::Console(ConsoleAlertDispatcher::new(
372                    &args.name,
373                    &args.space,
374                    &args.version,
375                )))
376            }
377        }
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_types::spc::{
        AlertZone, SpcAlert, SpcAlertConfig, SpcAlertType, SpcDriftConfig, SpcFeatureAlert,
        SpcFeatureAlerts,
    };

    use std::collections::HashMap;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that touch process environment
    /// variables. The test harness runs tests on parallel threads, and several
    /// tests set and remove the same variables, so without this lock they race.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// RAII guard: holds `ENV_LOCK` for the duration of a test and removes the
    /// variables it manages when dropped, including when the test panics.
    struct ScopedEnv {
        keys: Vec<&'static str>,
        _guard: MutexGuard<'static, ()>,
    }

    impl ScopedEnv {
        /// Acquires the lock, recovering it if an earlier test panicked while holding it.
        fn lock() -> MutexGuard<'static, ()> {
            ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
        }

        /// Sets every `(key, value)` pair and removes the keys again on drop.
        fn set(vars: &[(&'static str, &str)]) -> Self {
            let guard = Self::lock();
            for &(key, value) in vars {
                // SAFETY: env-mutating tests in this module are serialized by `ENV_LOCK`.
                unsafe { env::set_var(key, value) };
            }
            Self {
                keys: vars.iter().map(|&(key, _)| key).collect(),
                _guard: guard,
            }
        }

        /// Makes sure every key is absent for the lifetime of the guard.
        fn cleared(keys: &[&'static str]) -> Self {
            let guard = Self::lock();
            for &key in keys {
                // SAFETY: env-mutating tests in this module are serialized by `ENV_LOCK`.
                unsafe { env::remove_var(key) };
            }
            Self {
                keys: keys.to_vec(),
                _guard: guard,
            }
        }
    }

    impl Drop for ScopedEnv {
        fn drop(&mut self) {
            // Runs before `_guard` is released, so cleanup is still serialized.
            for &key in &self.keys {
                // SAFETY: env-mutating tests in this module are serialized by `ENV_LOCK`.
                unsafe { env::remove_var(key) };
            }
        }
    }

    /// Two features, each with a single alert.
    fn test_features_map() -> HashMap<String, SpcFeatureAlert> {
        let mut features: HashMap<String, SpcFeatureAlert> = HashMap::new();

        features.insert(
            "test_feature_1".to_string(),
            SpcFeatureAlert {
                feature: "test_feature_1".to_string(),
                alerts: vec![SpcAlert {
                    zone: AlertZone::Zone4,
                    kind: SpcAlertType::OutOfBounds,
                }]
                .into_iter()
                .collect(),
            },
        );
        features.insert(
            "test_feature_2".to_string(),
            SpcFeatureAlert {
                feature: "test_feature_2".to_string(),
                alerts: vec![SpcAlert {
                    zone: AlertZone::Zone1,
                    kind: SpcAlertType::Consecutive,
                }]
                .into_iter()
                .collect(),
            },
        );
        features
    }

    /// Wraps `features` in an `SpcFeatureAlerts` flagged as having alerts.
    fn alerts_for(features: HashMap<String, SpcFeatureAlert>) -> SpcFeatureAlerts {
        SpcFeatureAlerts {
            features,
            has_alerts: true,
        }
    }

    /// OpsGenie alerter for `space/name/1.0.0` routed to `test-team` at priority P5.
    fn opsgenie_alerter() -> OpsGenieAlerter {
        OpsGenieAlerter::new(
            "name",
            "space",
            "1.0.0",
            &OpsGenieDispatchConfig {
                team: "test-team".to_string(),
                priority: "P5".to_string(),
            },
        )
        .unwrap()
    }

    /// Drift config for `space/name/1.0.0` using the given dispatch configuration.
    fn drift_config_with(dispatch_config: AlertDispatchConfig) -> SpcDriftConfig {
        let alert_config = SpcAlertConfig {
            dispatch_config,
            ..Default::default()
        };

        SpcDriftConfig::new(
            Some("name".to_string()),
            Some("space".to_string()),
            Some("1.0.0".to_string()),
            None,
            None,
            Some(alert_config),
            None,
        )
        .unwrap()
    }

    #[test]
    fn test_construct_opsgenie_alert_description() {
        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", "api_url"),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);

        let alerter = opsgenie_alerter();
        let alert_description =
            alerter.construct_alert_description(&alerts_for(test_features_map()));
        let expected_alert_description = "Drift has been detected for the following features:\n    test_feature_2: \n        Kind: Consecutive\n        Zone: Zone 1\n    test_feature_1: \n        Kind: Out of bounds\n        Zone: Zone 4\n".to_string();

        // Only lengths are compared: the features come from a `HashMap`, so the
        // order in which they appear in the description is not fixed.
        assert_eq!(alert_description.len(), expected_alert_description.len());
    }

    #[test]
    fn test_construct_opsgenie_alert_description_empty() {
        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", "api_url"),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);

        let alerter = opsgenie_alerter();
        let alert_description = alerter.construct_alert_description(&alerts_for(HashMap::new()));

        assert_eq!(alert_description, "");
    }

    #[test]
    fn test_construct_opsgenie_alert_body() {
        // Building the body performs no network I/O, so a placeholder URL is enough.
        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", "api_url"),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);

        let ops_genie_team = "test-team";

        let expected_alert_body = json!(
                {
                    "message": "Model drift detected for test_repo/test_ml_model/1.0.0",
                    "description": "Features have drifted",
                    "responders":[
                        {"name":ops_genie_team, "type":"team"}
                    ],
                    "visibleTo":[
                        {"name":ops_genie_team, "type":"team"}
                    ],
                    "tags": ["Model Drift", "Scouter"],
                    "priority": "P1"
                }
        );
        let alerter = OpsGenieAlerter::new(
            "test_ml_model",
            "test_repo",
            "1.0.0",
            &OpsGenieDispatchConfig {
                team: ops_genie_team.to_string(),
                priority: "P1".to_string(),
            },
        )
        .unwrap();

        let alert_body = alerter.construct_alert_body("Features have drifted");
        assert_eq!(alert_body, expected_alert_body);
    }

    // The env guard is intentionally held across `.await`: the variables must
    // stay set (and other env tests stay blocked) until the test finishes.
    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn test_send_opsgenie_alerts() {
        let mut mock_server = mockito::Server::new_async().await;
        let url = format!("{}/alerts", mock_server.url());

        let _env = ScopedEnv::set(&[
            ("OPSGENIE_API_URL", url.as_str()),
            ("OPSGENIE_API_KEY", "api_key"),
        ]);

        let post_mock = mock_server.mock("POST", "/alerts").with_status(201).create();

        let dispatcher = AlertDispatcher::OpsGenie(HttpAlertDispatcher::new(opsgenie_alerter()));
        let result = dispatcher
            .process_alerts(&alerts_for(test_features_map()))
            .await;

        assert!(result.is_ok());
        post_mock.assert();
    }

    #[tokio::test]
    async fn test_send_console_alerts() {
        let dispatcher =
            AlertDispatcher::Console(ConsoleAlertDispatcher::new("name", "space", "1.0.0"));
        let result = dispatcher
            .process_alerts(&alerts_for(test_features_map()))
            .await;

        assert!(result.is_ok());
    }

    // The env guard is intentionally held across `.await`: the variables must
    // stay set (and other env tests stay blocked) until the test finishes.
    #[allow(clippy::await_holding_lock)]
    #[tokio::test]
    async fn test_send_slack_alerts() {
        let mut mock_server = mockito::Server::new_async().await;
        let url = mock_server.url();

        let _env = ScopedEnv::set(&[
            ("SLACK_API_URL", url.as_str()),
            ("SLACK_APP_TOKEN", "bot_token"),
        ]);

        let post_mock = mock_server
            .mock("POST", "/chat.postMessage")
            .with_status(201)
            .create();

        let dispatcher = AlertDispatcher::Slack(HttpAlertDispatcher::new(
            SlackAlerter::new(
                "name",
                "space",
                "1.0.0",
                &SlackDispatchConfig {
                    channel: "test-channel".to_string(),
                },
            )
            .unwrap(),
        ));
        let result = dispatcher
            .process_alerts(&alerts_for(test_features_map()))
            .await;

        assert!(result.is_ok());
        post_mock.assert();
    }

    #[test]
    fn test_construct_slack_alert_body() {
        // Building the body performs no network I/O, so a placeholder URL is enough.
        let _env = ScopedEnv::set(&[
            ("SLACK_API_URL", "url"),
            ("SLACK_APP_TOKEN", "bot_token"),
        ]);
        let slack_channel = "test_channel";

        let expected_alert_body = json!({
            "channel": slack_channel,
            "blocks": [
                {
                    "type": "header",
                    "text": {
                        "type": "plain_text",
                        "text": ":rotating_light: Drift Detected :rotating_light:",
                        "emoji": true
                    }
                },
                {
                    "type": "section",
                    "text": {
                      "type": "mrkdwn",
                      "text": "*Name*: name *Space*: space *Version*: 1.0.0",
                    }
                },
                {
                    "type": "section",
                    "text": {
                        "type": "mrkdwn",
                        "text": "*Features have drifted*"
                    },
                }
            ]
        });
        let alerter = SlackAlerter::new(
            "name",
            "space",
            "1.0.0",
            &SlackDispatchConfig {
                channel: slack_channel.to_string(),
            },
        )
        .unwrap();

        let alert_body = alerter.construct_alert_body("*Features have drifted*");
        assert_eq!(alert_body, expected_alert_body);
    }

    #[test]
    fn test_console_dispatcher_returned_when_env_vars_not_set_opsgenie() {
        let _env = ScopedEnv::cleared(&["OPSGENIE_API_KEY", "OPSGENIE_API_URL"]);

        let config = drift_config_with(AlertDispatchConfig::OpsGenie(OpsGenieDispatchConfig {
            team: "test-team".to_string(),
            priority: "P5".to_string(),
        }));
        let dispatcher = AlertDispatcher::new(&config).unwrap();

        assert!(
            matches!(dispatcher, AlertDispatcher::Console(_)),
            "Expected Console Dispatcher"
        );
    }

    #[test]
    fn test_console_dispatcher_returned_when_env_vars_not_set_slack() {
        let _env = ScopedEnv::cleared(&["SLACK_API_URL", "SLACK_APP_TOKEN"]);

        let config = drift_config_with(AlertDispatchConfig::Slack(SlackDispatchConfig {
            channel: "test-channel".to_string(),
        }));
        let dispatcher = AlertDispatcher::new(&config).unwrap();

        assert!(
            matches!(dispatcher, AlertDispatcher::Console(_)),
            "Expected Console Dispatcher"
        );
    }

    #[test]
    fn test_slack_dispatcher_returned_when_env_vars_set() {
        let _env = ScopedEnv::set(&[
            ("SLACK_API_URL", "url"),
            ("SLACK_APP_TOKEN", "bot_token"),
        ]);

        let config = drift_config_with(AlertDispatchConfig::Slack(SlackDispatchConfig {
            channel: "test-channel".to_string(),
        }));
        let dispatcher = AlertDispatcher::new(&config).unwrap();

        assert!(
            matches!(dispatcher, AlertDispatcher::Slack(_)),
            "Expected Slack Dispatcher"
        );
    }
}