1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::sms::SmsSender;
4use crate::types::{Channel, EmailMessage, SendResult, SmsMessage};
5use std::sync::Arc;
6
7pub struct RouteStep {
9 pub channel: Channel,
10 sender: RouteSender,
11}
12
13enum RouteSender {
14 Sms(Arc<dyn SmsSender>),
15 Email(Arc<dyn EmailSender>),
16}
17
18pub struct WaterfallRouter {
21 steps: Vec<RouteStep>,
22}
23
24impl WaterfallRouter {
25 pub fn new() -> Self {
26 Self { steps: Vec::new() }
27 }
28
29 pub fn add_sms(mut self, provider: Arc<dyn SmsSender>) -> Self {
30 self.steps.push(RouteStep {
31 channel: Channel::Sms,
32 sender: RouteSender::Sms(provider),
33 });
34 self
35 }
36
37 pub fn add_email(mut self, provider: Arc<dyn EmailSender>) -> Self {
38 self.steps.push(RouteStep {
39 channel: Channel::Email,
40 sender: RouteSender::Email(provider),
41 });
42 self
43 }
44
45 pub async fn send_otp(
48 &self,
49 recipient: &str,
50 code: &str,
51 app_name: &str,
52 ) -> Result<SendResult, ChorusError> {
53 let mut errors = Vec::new();
54
55 for step in &self.steps {
56 let result = match &step.sender {
57 RouteSender::Email(sender) => {
58 if !recipient.contains('@') {
59 continue;
60 }
61 let msg = EmailMessage {
62 to: recipient.to_string(),
63 subject: format!("Your {} verification code", app_name),
64 html_body: format!(
65 "<p>Your verification code is: <strong>{}</strong>. It expires in 5 minutes.</p>",
66 code
67 ),
68 text_body: format!(
69 "Your verification code is: {}. It expires in 5 minutes.",
70 code
71 ),
72 from: None,
73 };
74 sender.send(&msg).await
75 }
76 RouteSender::Sms(sender) => {
77 if recipient.contains('@') {
78 continue;
79 }
80 let msg = SmsMessage {
81 to: recipient.to_string(),
82 body: format!("Your {} code: {} (expires in 5 min)", app_name, code),
83 from: None,
84 };
85 sender.send(&msg).await
86 }
87 };
88
89 match result {
90 Ok(send_result) => {
91 tracing::info!(
92 provider = %send_result.provider,
93 channel = %send_result.channel,
94 "Message sent successfully via waterfall"
95 );
96 return Ok(send_result);
97 }
98 Err(e) => {
99 tracing::warn!(
100 channel = %step.channel,
101 error = %e,
102 "Waterfall step failed, trying next"
103 );
104 errors.push(e);
105 }
106 }
107 }
108
109 Err(ChorusError::AllProvidersFailed)
110 }
111
112 pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
114 for step in &self.steps {
115 if let RouteSender::Sms(sender) = &step.sender {
116 match sender.send(msg).await {
117 Ok(result) => return Ok(result),
118 Err(e) => {
119 tracing::warn!(provider = sender.provider_name(), error = %e, "SMS provider failed, trying next");
120 continue;
121 }
122 }
123 }
124 }
125 Err(ChorusError::AllProvidersFailed)
126 }
127
128 pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
130 for step in &self.steps {
131 if let RouteSender::Email(sender) = &step.sender {
132 match sender.send(msg).await {
133 Ok(result) => return Ok(result),
134 Err(e) => {
135 tracing::warn!(provider = sender.provider_name(), error = %e, "Email provider failed, trying next");
136 continue;
137 }
138 }
139 }
140 }
141 Err(ChorusError::AllProvidersFailed)
142 }
143}
144
145impl Default for WaterfallRouter {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DeliveryStatus;

    /// Phone-number recipient used by the SMS tests.
    const PHONE: &str = "+66812345678";
    /// Email recipient used by the email tests.
    const EMAIL: &str = "user@test.com";

    /// SMS provider that always succeeds.
    struct SuccessSms;
    #[async_trait::async_trait]
    impl SmsSender for SuccessSms {
        fn provider_name(&self) -> &str {
            "test-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "sms-1".to_string(),
                provider: "test-sms".to_string(),
                channel: Channel::Sms,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// SMS provider that always fails.
    struct FailSms;
    #[async_trait::async_trait]
    impl SmsSender for FailSms {
        fn provider_name(&self) -> &str {
            "fail-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-sms".into(),
                message: "timeout".into(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Failed {
                reason: "timeout".into(),
            })
        }
    }

    /// Email provider that always succeeds.
    struct SuccessEmail;
    #[async_trait::async_trait]
    impl EmailSender for SuccessEmail {
        fn provider_name(&self) -> &str {
            "test-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "email-1".to_string(),
                provider: "test-email".to_string(),
                channel: Channel::Email,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
    }

    /// Email provider that always fails; used to exercise email failover.
    struct FailEmail;
    #[async_trait::async_trait]
    impl EmailSender for FailEmail {
        fn provider_name(&self) -> &str {
            "fail-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-email".into(),
                message: "timeout".into(),
            })
        }
    }

    /// A minimal SMS addressed to [`PHONE`].
    fn sample_sms() -> SmsMessage {
        SmsMessage {
            to: PHONE.into(),
            body: "Hi".into(),
            from: None,
        }
    }

    /// A minimal email addressed to [`EMAIL`].
    fn sample_email() -> EmailMessage {
        EmailMessage {
            to: EMAIL.into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        }
    }

    #[tokio::test]
    async fn waterfall_sends_email_for_email_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_otp(EMAIL, "123456", "TestApp").await.unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_sends_sms_for_phone_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_otp(PHONE, "123456", "TestApp").await.unwrap();
        assert_eq!(result.channel, Channel::Sms);
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_otp(PHONE, "123456", "TestApp").await.unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_email_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let result = router.send_otp(EMAIL, "123456", "TestApp").await.unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_all_fail_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(FailSms));

        let result = router.send_otp(PHONE, "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_empty_router_returns_error() {
        let router = WaterfallRouter::new();
        let result = router.send_otp(EMAIL, "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_email_recipient_ignores_sms_providers() {
        // Only SMS steps exist, so an email recipient has no usable step.
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let result = router.send_otp(EMAIL, "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_phone_recipient_ignores_email_providers() {
        // Only email steps exist, so a phone recipient has no usable step.
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));

        let result = router.send_otp(PHONE, "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_sms_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_sms(&sample_sms()).await.unwrap();
        assert_eq!(result.channel, Channel::Sms);
    }

    #[tokio::test]
    async fn send_email_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_email(&sample_email()).await.unwrap();
        assert_eq!(result.channel, Channel::Email);
    }

    #[tokio::test]
    async fn send_sms_no_sms_providers_returns_error() {
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));

        let result = router.send_sms(&sample_sms()).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_email_no_email_providers_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let result = router.send_email(&sample_email()).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_sms_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let result = router.send_sms(&sample_sms()).await.unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn send_email_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let result = router.send_email(&sample_email()).await.unwrap();
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn send_sms_all_fail_returns_error() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(FailSms));

        let result = router.send_sms(&sample_sms()).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }
}