1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::sms::SmsSender;
4use crate::types::{Channel, EmailMessage, SendResult, SmsMessage};
5use std::sync::Arc;
6
/// One entry in a [`WaterfallRouter`]: a provider together with the channel it delivers on.
pub struct RouteStep {
    /// Delivery channel of this step. `WaterfallRouter::add_sms` / `add_email` always set it
    /// to match the variant of `sender`.
    pub channel: Channel,
    /// The provider used to send messages for this step.
    sender: RouteSender,
}
12
/// Shared, type-erased provider handle for a single waterfall step, tagged by channel.
enum RouteSender {
    /// A provider that delivers [`SmsMessage`]s.
    Sms(Arc<dyn SmsSender>),
    /// A provider that delivers [`EmailMessage`]s.
    Email(Arc<dyn EmailSender>),
}
17
/// Failover router: tries configured providers one after another until one succeeds.
pub struct WaterfallRouter {
    /// Providers in insertion order. Sends iterate this front to back and stop at the
    /// first successful delivery.
    steps: Vec<RouteStep>,
}
26
27impl WaterfallRouter {
28 pub fn new() -> Self {
30 Self { steps: Vec::new() }
31 }
32
33 pub fn add_sms(mut self, provider: Arc<dyn SmsSender>) -> Self {
35 self.steps.push(RouteStep {
36 channel: Channel::Sms,
37 sender: RouteSender::Sms(provider),
38 });
39 self
40 }
41
42 pub fn add_email(mut self, provider: Arc<dyn EmailSender>) -> Self {
44 self.steps.push(RouteStep {
45 channel: Channel::Email,
46 sender: RouteSender::Email(provider),
47 });
48 self
49 }
50
51 pub async fn send_otp(
53 &self,
54 recipient: &str,
55 code: &str,
56 app_name: &str,
57 ) -> Result<SendResult, ChorusError> {
58 let mut errors = Vec::new();
59
60 for step in &self.steps {
61 let result = match &step.sender {
62 RouteSender::Email(sender) => {
63 if !recipient.contains('@') {
64 continue;
65 }
66 let msg = EmailMessage {
67 to: recipient.to_string(),
68 subject: format!("Your {} verification code", app_name),
69 html_body: format!(
70 "<p>Your verification code is: <strong>{}</strong>. It expires in 5 minutes.</p>",
71 code
72 ),
73 text_body: format!(
74 "Your verification code is: {}. It expires in 5 minutes.",
75 code
76 ),
77 from: None,
78 };
79 sender.send(&msg).await
80 }
81 RouteSender::Sms(sender) => {
82 if recipient.contains('@') {
83 continue;
84 }
85 let msg = SmsMessage {
86 to: recipient.to_string(),
87 body: format!("Your {} code: {} (expires in 5 min)", app_name, code),
88 from: None,
89 };
90 sender.send(&msg).await
91 }
92 };
93
94 match result {
95 Ok(send_result) => {
96 tracing::info!(
97 provider = %send_result.provider,
98 channel = %send_result.channel,
99 "Message sent successfully via waterfall"
100 );
101 return Ok(send_result);
102 }
103 Err(e) => {
104 tracing::warn!(
105 channel = %step.channel,
106 error = %e,
107 "Waterfall step failed, trying next"
108 );
109 errors.push(e);
110 }
111 }
112 }
113
114 Err(ChorusError::AllProvidersFailed)
115 }
116
117 pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
119 for step in &self.steps {
120 if let RouteSender::Sms(sender) = &step.sender {
121 match sender.send(msg).await {
122 Ok(result) => return Ok(result),
123 Err(e) => {
124 tracing::warn!(provider = sender.provider_name(), error = %e, "SMS provider failed, trying next");
125 continue;
126 }
127 }
128 }
129 }
130 Err(ChorusError::AllProvidersFailed)
131 }
132
133 pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
135 for step in &self.steps {
136 if let RouteSender::Email(sender) = &step.sender {
137 match sender.send(msg).await {
138 Ok(result) => return Ok(result),
139 Err(e) => {
140 tracing::warn!(provider = sender.provider_name(), error = %e, "Email provider failed, trying next");
141 continue;
142 }
143 }
144 }
145 }
146 Err(ChorusError::AllProvidersFailed)
147 }
148}
149
150impl Default for WaterfallRouter {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
#[cfg(test)]
mod tests {
    // Unit tests for the waterfall router, using in-memory fake providers that
    // always succeed, always fail, or record what they were asked to send.
    use super::*;
    use crate::types::DeliveryStatus;

    /// SMS provider that always succeeds.
    struct SuccessSms;
    #[async_trait::async_trait]
    impl SmsSender for SuccessSms {
        fn provider_name(&self) -> &str {
            "test-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "sms-1".to_string(),
                provider: "test-sms".to_string(),
                channel: Channel::Sms,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// SMS provider that always fails with a provider error.
    struct FailSms;
    #[async_trait::async_trait]
    impl SmsSender for FailSms {
        fn provider_name(&self) -> &str {
            "fail-sms"
        }
        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-sms".into(),
                message: "timeout".into(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Failed {
                reason: "timeout".into(),
            })
        }
    }

    /// SMS provider that succeeds and records the body of the last message it sent.
    #[derive(Default)]
    struct CaptureSms {
        last_body: std::sync::Mutex<Option<String>>,
    }
    #[async_trait::async_trait]
    impl SmsSender for CaptureSms {
        fn provider_name(&self) -> &str {
            "capture-sms"
        }
        async fn send(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            *self.last_body.lock().unwrap() = Some(msg.body.clone());
            Ok(SendResult {
                message_id: "sms-capture".to_string(),
                provider: "capture-sms".to_string(),
                channel: Channel::Sms,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// Email provider that always succeeds.
    struct SuccessEmail;
    #[async_trait::async_trait]
    impl EmailSender for SuccessEmail {
        fn provider_name(&self) -> &str {
            "test-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Ok(SendResult {
                message_id: "email-1".to_string(),
                provider: "test-email".to_string(),
                channel: Channel::Email,
                status: DeliveryStatus::Sent,
                created_at: chrono::Utc::now(),
            })
        }
    }

    /// Email provider that always fails with a provider error.
    struct FailEmail;
    #[async_trait::async_trait]
    impl EmailSender for FailEmail {
        fn provider_name(&self) -> &str {
            "fail-email"
        }
        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Err(ChorusError::Provider {
                provider: "fail-email".into(),
                message: "rejected".into(),
            })
        }
    }

    #[tokio::test]
    async fn waterfall_sends_email_for_email_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("user@test.com", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_sends_sms_for_phone_recipient() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Sms);
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let result = router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn waterfall_email_fallback_on_failure() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(FailEmail))
            .add_email(Arc::new(SuccessEmail));

        let result = router
            .send_otp("user@test.com", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(result.channel, Channel::Email);
        assert_eq!(result.provider, "test-email");
    }

    #[tokio::test]
    async fn waterfall_all_fail_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(FailSms));

        let result = router.send_otp("+66812345678", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_empty_router_returns_error() {
        let router = WaterfallRouter::new();
        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_email_recipient_without_email_provider_returns_error() {
        // An SMS provider must never be used for an email-shaped recipient.
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let result = router.send_otp("user@test.com", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_phone_recipient_skips_email_providers() {
        // A working email provider must not rescue a failed SMS delivery.
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(FailSms));

        let result = router.send_otp("+66812345678", "123456", "TestApp").await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn waterfall_otp_sms_body_contains_app_and_code() {
        let capture = Arc::new(CaptureSms::default());
        let router = WaterfallRouter::new().add_sms(capture.clone());

        router
            .send_otp("+66812345678", "123456", "TestApp")
            .await
            .unwrap();
        assert_eq!(
            capture.last_body.lock().unwrap().as_deref(),
            Some("Your TestApp code: 123456 (expires in 5 min)")
        );
    }

    #[tokio::test]
    async fn send_sms_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await.unwrap();
        assert_eq!(result.channel, Channel::Sms);
    }

    #[tokio::test]
    async fn send_email_directly() {
        let router = WaterfallRouter::new()
            .add_email(Arc::new(SuccessEmail))
            .add_sms(Arc::new(SuccessSms));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await.unwrap();
        assert_eq!(result.channel, Channel::Email);
    }

    #[tokio::test]
    async fn send_sms_no_sms_providers_returns_error() {
        let router = WaterfallRouter::new().add_email(Arc::new(SuccessEmail));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_email_no_email_providers_returns_error() {
        let router = WaterfallRouter::new().add_sms(Arc::new(SuccessSms));

        let msg = EmailMessage {
            to: "user@test.com".into(),
            subject: "Hi".into(),
            html_body: "<p>Hi</p>".into(),
            text_body: "Hi".into(),
            from: None,
        };
        let result = router.send_email(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }

    #[tokio::test]
    async fn send_sms_failover_across_providers() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(SuccessSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await.unwrap();
        assert_eq!(result.provider, "test-sms");
    }

    #[tokio::test]
    async fn send_sms_all_fail_returns_error() {
        let router = WaterfallRouter::new()
            .add_sms(Arc::new(FailSms))
            .add_sms(Arc::new(FailSms));

        let msg = SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: None,
        };
        let result = router.send_sms(&msg).await;
        assert!(matches!(result, Err(ChorusError::AllProvidersFailed)));
    }
}