Skip to main content

chorus_core/
router.rs

1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::sms::SmsSender;
4use crate::types::{Channel, EmailMessage, SendResult, SmsMessage};
5use std::sync::Arc;
6
7/// A step in the waterfall routing chain.
8pub struct RouteStep {
9    pub channel: Channel,
10    sender: RouteSender,
11}
12
13enum RouteSender {
14    Sms(Arc<dyn SmsSender>),
15    Email(Arc<dyn EmailSender>),
16}
17
18/// Waterfall router: tries each step in order, falls back to next on failure.
19/// Optimizes cost by trying cheaper channels (email) before expensive ones (SMS).
20pub struct WaterfallRouter {
21    steps: Vec<RouteStep>,
22}
23
24impl WaterfallRouter {
25    pub fn new() -> Self {
26        Self { steps: Vec::new() }
27    }
28
29    pub fn add_sms(mut self, provider: Arc<dyn SmsSender>) -> Self {
30        self.steps.push(RouteStep {
31            channel: Channel::Sms,
32            sender: RouteSender::Sms(provider),
33        });
34        self
35    }
36
37    pub fn add_email(mut self, provider: Arc<dyn EmailSender>) -> Self {
38        self.steps.push(RouteStep {
39            channel: Channel::Email,
40            sender: RouteSender::Email(provider),
41        });
42        self
43    }
44
45    /// Send a message through the waterfall chain.
46    /// For OTP: recipient can be email or phone — tries each step in order.
47    pub async fn send_otp(
48        &self,
49        recipient: &str,
50        code: &str,
51        app_name: &str,
52    ) -> Result<SendResult, ChorusError> {
53        let mut errors = Vec::new();
54
55        for step in &self.steps {
56            let result = match &step.sender {
57                RouteSender::Email(sender) => {
58                    if !recipient.contains('@') {
59                        continue;
60                    }
61                    let msg = EmailMessage {
62                        to: recipient.to_string(),
63                        subject: format!("Your {} verification code", app_name),
64                        html_body: format!(
65                            "<p>Your verification code is: <strong>{}</strong>. It expires in 5 minutes.</p>",
66                            code
67                        ),
68                        text_body: format!(
69                            "Your verification code is: {}. It expires in 5 minutes.",
70                            code
71                        ),
72                        from: None,
73                    };
74                    sender.send(&msg).await
75                }
76                RouteSender::Sms(sender) => {
77                    if recipient.contains('@') {
78                        continue;
79                    }
80                    let msg = SmsMessage {
81                        to: recipient.to_string(),
82                        body: format!("Your {} code: {} (expires in 5 min)", app_name, code),
83                        from: None,
84                    };
85                    sender.send(&msg).await
86                }
87            };
88
89            match result {
90                Ok(send_result) => {
91                    tracing::info!(
92                        provider = %send_result.provider,
93                        channel = %send_result.channel,
94                        "Message sent successfully via waterfall"
95                    );
96                    return Ok(send_result);
97                }
98                Err(e) => {
99                    tracing::warn!(
100                        channel = %step.channel,
101                        error = %e,
102                        "Waterfall step failed, trying next"
103                    );
104                    errors.push(e);
105                }
106            }
107        }
108
109        Err(ChorusError::AllProvidersFailed)
110    }
111
112    /// Send SMS directly (bypass waterfall).
113    pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
114        for step in &self.steps {
115            if let RouteSender::Sms(sender) = &step.sender {
116                match sender.send(msg).await {
117                    Ok(result) => return Ok(result),
118                    Err(e) => {
119                        tracing::warn!(provider = sender.provider_name(), error = %e, "SMS provider failed, trying next");
120                        continue;
121                    }
122                }
123            }
124        }
125        Err(ChorusError::AllProvidersFailed)
126    }
127
128    /// Send email directly (bypass waterfall).
129    pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
130        for step in &self.steps {
131            if let RouteSender::Email(sender) = &step.sender {
132                match sender.send(msg).await {
133                    Ok(result) => return Ok(result),
134                    Err(e) => {
135                        tracing::warn!(provider = sender.provider_name(), error = %e, "Email provider failed, trying next");
136                        continue;
137                    }
138                }
139            }
140        }
141        Err(ChorusError::AllProvidersFailed)
142    }
143}
144
145impl Default for WaterfallRouter {
146    fn default() -> Self {
147        Self::new()
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DeliveryStatus;

    /// SMS stub that always succeeds.
    struct SuccessSms;
    #[async_trait::async_trait]
    impl SmsSender for SuccessSms {
        fn provider_name(&self) -> &str {
            "test-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "sms-1".to_string(),
                provider: "test-sms".to_string(),
                channel: Channel::Sms,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// SMS stub that always fails with a provider error.
    struct FailSms;
    #[async_trait::async_trait]
    impl SmsSender for FailSms {
        fn provider_name(&self) -> &str {
            "fail-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-sms".into(),
                message: "timeout".into(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Failed {
                reason: "timeout".into(),
            })
        }
    }

    /// Email stub that always succeeds.
    struct SuccessEmail;
    #[async_trait::async_trait]
    impl EmailSender for SuccessEmail {
        fn provider_name(&self) -> &str {
            "test-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "email-1".to_string(),
                provider: "test-email".to_string(),
                channel: Channel::Email,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
    }

    /// Email stub that always fails with a provider error.
    struct FailEmail;
    #[async_trait::async_trait]
    impl EmailSender for FailEmail {
        fn provider_name(&self) -> &str {
            "fail-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-email".into(),
                message: "bounce".into(),
            })
        }
    }

    #[tokio::test]
    async fn waterfall_sends_email_for_email_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("user@test.com", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_sends_sms_for_phone_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Sms);
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_all_fail_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(FailSms));

        let result = router.send_otp("+66812345678", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_empty_router_returns_error() {
        let router = WaterfallRouter::new();
        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_sms_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await.unwrap();
        assert_eq!(result.channel, Channel::Sms);
    }

    #[tokio::test]
    async fn send_email_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await.unwrap();
        assert_eq!(result.channel, Channel::Email);
    }

    #[tokio::test]
    async fn send_sms_no_sms_providers_returns_error() {
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_email_no_email_providers_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_sms_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await.unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn send_sms_all_fail_returns_error() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(FailSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_email_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let result = router
            .send_otp("user@test.com", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_skips_sms_steps_for_email_recipient() {
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_skips_email_steps_for_phone_recipient() {
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));

        let result = router.send_otp("+66812345678", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_does_not_fall_back_across_channels() {
        // A failing email step must not fall through to SMS for an email recipient.
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_email_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await.unwrap();
        assert_eq!(result.provider, "test-email");
    }
}