Skip to main content

chorus_core/
router.rs

1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::sms::SmsSender;
4use crate::types::{Channel, EmailMessage, SendResult, SmsMessage};
5use std::sync::Arc;
6
7/// A step in the waterfall routing chain.
8pub struct RouteStep {
9    pub channel: Channel,
10    sender: RouteSender,
11}
12
13enum RouteSender {
14    Sms(Arc<dyn SmsSender>),
15    Email(Arc<dyn EmailSender>),
16}
17
18/// Waterfall router: tries each provider in order, falls back to the next on failure.
19///
20/// Optimizes cost by trying cheaper channels (email) before expensive ones (SMS).
21/// For OTP delivery, automatically detects whether the recipient is an email address
22/// or phone number and routes accordingly.
23pub struct WaterfallRouter {
24    steps: Vec<RouteStep>,
25}
26
27impl WaterfallRouter {
28    /// Creates an empty router with no providers.
29    pub fn new() -> Self {
30        Self { steps: Vec::new() }
31    }
32
33    /// Adds an SMS provider to the routing chain. Providers are tried in insertion order.
34    pub fn add_sms(mut self, provider: Arc<dyn SmsSender>) -> Self {
35        self.steps.push(RouteStep {
36            channel: Channel::Sms,
37            sender: RouteSender::Sms(provider),
38        });
39        self
40    }
41
42    /// Adds an email provider to the routing chain. Providers are tried in insertion order.
43    pub fn add_email(mut self, provider: Arc<dyn EmailSender>) -> Self {
44        self.steps.push(RouteStep {
45            channel: Channel::Email,
46            sender: RouteSender::Email(provider),
47        });
48        self
49    }
50
51    /// Sends an OTP via waterfall routing. Routes to email if recipient contains `@`, otherwise SMS.
52    pub async fn send_otp(
53        &self,
54        recipient: &str,
55        code: &str,
56        app_name: &str,
57    ) -> Result<SendResult, ChorusError> {
58        let mut errors = Vec::new();
59
60        for step in &self.steps {
61            let result = match &step.sender {
62                RouteSender::Email(sender) => {
63                    if !recipient.contains('@') {
64                        continue;
65                    }
66                    let msg = EmailMessage {
67                        to: recipient.to_string(),
68                        subject: format!("Your {} verification code", app_name),
69                        html_body: format!(
70                            "<p>Your verification code is: <strong>{}</strong>. It expires in 5 minutes.</p>",
71                            code
72                        ),
73                        text_body: format!(
74                            "Your verification code is: {}. It expires in 5 minutes.",
75                            code
76                        ),
77                        from: None,
78                    };
79                    sender.send(&msg).await
80                }
81                RouteSender::Sms(sender) => {
82                    if recipient.contains('@') {
83                        continue;
84                    }
85                    let msg = SmsMessage {
86                        to: recipient.to_string(),
87                        body: format!("Your {} code: {} (expires in 5 min)", app_name, code),
88                        from: None,
89                    };
90                    sender.send(&msg).await
91                }
92            };
93
94            match result {
95                Ok(send_result) => {
96                    tracing::info!(
97                        provider = %send_result.provider,
98                        channel = %send_result.channel,
99                        "Message sent successfully via waterfall"
100                    );
101                    return Ok(send_result);
102                }
103                Err(e) => {
104                    tracing::warn!(
105                        channel = %step.channel,
106                        error = %e,
107                        "Waterfall step failed, trying next"
108                    );
109                    errors.push(e);
110                }
111            }
112        }
113
114        Err(ChorusError::AllProvidersFailed)
115    }
116
117    /// Sends an SMS directly, trying each SMS provider in order until one succeeds.
118    pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
119        for step in &self.steps {
120            if let RouteSender::Sms(sender) = &step.sender {
121                match sender.send(msg).await {
122                    Ok(result) => return Ok(result),
123                    Err(e) => {
124                        tracing::warn!(provider = sender.provider_name(), error = %e, "SMS provider failed, trying next");
125                        continue;
126                    }
127                }
128            }
129        }
130        Err(ChorusError::AllProvidersFailed)
131    }
132
133    /// Sends an email directly, trying each email provider in order until one succeeds.
134    pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
135        for step in &self.steps {
136            if let RouteSender::Email(sender) = &step.sender {
137                match sender.send(msg).await {
138                    Ok(result) => return Ok(result),
139                    Err(e) => {
140                        tracing::warn!(provider = sender.provider_name(), error = %e, "Email provider failed, trying next");
141                        continue;
142                    }
143                }
144            }
145        }
146        Err(ChorusError::AllProvidersFailed)
147    }
148}
149
150impl Default for WaterfallRouter {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DeliveryStatus;

    // SMS provider that always succeeds.
    struct SuccessSms;
    #[async_trait::async_trait]
    impl SmsSender for SuccessSms {
        fn provider_name(&self) -> &str {
            "test-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "sms-1".to_string(),
                provider: "test-sms".to_string(),
                channel: Channel::Sms,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    // SMS provider that always fails.
    struct FailSms;
    #[async_trait::async_trait]
    impl SmsSender for FailSms {
        fn provider_name(&self) -> &str {
            "fail-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-sms".into(),
                message: "timeout".into(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Failed {
                reason: "timeout".into(),
            })
        }
    }

    // Email provider that always succeeds.
    struct SuccessEmail;
    #[async_trait::async_trait]
    impl EmailSender for SuccessEmail {
        fn provider_name(&self) -> &str {
            "test-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "email-1".to_string(),
                provider: "test-email".to_string(),
                channel: Channel::Email,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
    }

    // Email provider that always fails; exercises email-side failover.
    struct FailEmail;
    #[async_trait::async_trait]
    impl EmailSender for FailEmail {
        fn provider_name(&self) -> &str {
            "fail-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-email".into(),
                message: "bounced".into(),
            })
        }
    }

    #[tokio::test]
    async fn waterfall_sends_email_for_email_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("user@test.com", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_sends_sms_for_phone_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Sms);
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_email_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let result = router
            .send_otp("user@test.com", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_never_falls_back_across_channels() {
        // An email recipient must not be sent an SMS, even if an SMS provider exists.
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));
        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));

        // A phone recipient must not be sent an email, even if an email provider exists.
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));
        let result = router.send_otp("+66812345678", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_all_fail_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(FailSms));

        let result = router.send_otp("+66812345678", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_empty_router_returns_error() {
        let router = WaterfallRouter::new();
        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_sms_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await.unwrap();
        assert_eq!(result.channel, Channel::Sms);
    }

    #[tokio::test]
    async fn send_email_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await.unwrap();
        assert_eq!(result.channel, Channel::Email);
    }

    #[tokio::test]
    async fn send_sms_no_sms_providers_returns_error() {
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_email_no_email_providers_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_email_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await.unwrap();
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn send_sms_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await.unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn send_sms_all_fail_returns_error() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(FailSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }
}
358        };
359        let result = router.send_sms(&msg).await;
360        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
361    }
362}