1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::router::WaterfallRouter;
4use crate::sms::SmsSender;
5use crate::template::Template;
6use crate::types::{EmailMessage, SendResult, SmsMessage};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10pub struct Chorus {
12 router: WaterfallRouter,
13 templates: HashMap<String, Template>,
14 default_from_email: Option<String>,
15 default_from_sms: Option<String>,
16}
17
18impl Chorus {
19 pub fn builder() -> ChorusBuilder {
21 ChorusBuilder::new()
22 }
23
24 pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
26 let msg = if msg.from.is_none() && self.default_from_sms.is_some() {
27 let mut m = msg.clone();
28 m.from = self.default_from_sms.clone();
29 std::borrow::Cow::Owned(m)
30 } else {
31 std::borrow::Cow::Borrowed(msg)
32 };
33 self.router.send_sms(&msg).await
34 }
35
36 pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
38 self.router.send_email(msg).await
39 }
40
41 pub async fn send_email_template(
43 &self,
44 to: &str,
45 template_slug: &str,
46 variables: &HashMap<String, String>,
47 ) -> Result<SendResult, ChorusError> {
48 let tmpl = self
49 .templates
50 .get(template_slug)
51 .ok_or_else(|| ChorusError::TemplateNotFound(template_slug.to_string()))?;
52
53 let rendered = tmpl.render(variables)?;
54
55 let msg = EmailMessage {
56 to: to.to_string(),
57 subject: rendered.subject,
58 html_body: rendered.html_body,
59 text_body: rendered.text_body,
60 from: self.default_from_email.clone(),
61 };
62
63 self.router.send_email(&msg).await
64 }
65
66 pub async fn send_otp(
68 &self,
69 recipient: &str,
70 code: &str,
71 app_name: &str,
72 ) -> Result<SendResult, ChorusError> {
73 self.router.send_otp(recipient, code, app_name).await
74 }
75}
76
77pub struct ChorusBuilder {
79 router: WaterfallRouter,
80 templates: HashMap<String, Template>,
81 default_from_email: Option<String>,
82 default_from_sms: Option<String>,
83}
84
85impl ChorusBuilder {
86 pub fn new() -> Self {
87 Self {
88 router: WaterfallRouter::new(),
89 templates: HashMap::new(),
90 default_from_email: None,
91 default_from_sms: None,
92 }
93 }
94
95 pub fn add_sms_provider(mut self, provider: Arc<dyn SmsSender>) -> Self {
97 self.router = self.router.add_sms(provider);
98 self
99 }
100
101 pub fn add_email_provider(mut self, provider: Arc<dyn EmailSender>) -> Self {
103 self.router = self.router.add_email(provider);
104 self
105 }
106
107 pub fn add_template(mut self, template: Template) -> Self {
109 self.templates.insert(template.slug.clone(), template);
110 self
111 }
112
113 pub fn default_from_email(mut self, from: String) -> Self {
115 self.default_from_email = Some(from);
116 self
117 }
118
119 pub fn default_from_sms(mut self, from: String) -> Self {
121 self.default_from_sms = Some(from);
122 self
123 }
124
125 pub fn build(self) -> Chorus {
127 Chorus {
128 router: self.router,
129 templates: self.templates,
130 default_from_email: self.default_from_email,
131 default_from_sms: self.default_from_sms,
132 }
133 }
134}
135
136impl Default for ChorusBuilder {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Channel, DeliveryStatus};

    /// Builds the successful `SendResult` the stub providers hand back.
    fn sent(message_id: &str, provider: &str, channel: Channel) -> SendResult {
        SendResult {
            message_id: message_id.into(),
            provider: provider.into(),
            channel,
            status: DeliveryStatus::Sent,
            created_at: chrono::Utc::now(),
        }
    }

    /// Builds the SMS used across the tests, with the given sender.
    fn sms_from(from: Option<&str>) -> SmsMessage {
        SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: from.map(String::from),
        }
    }

    /// A `Chorus` with one stub email provider and one stub SMS provider.
    fn chorus_with_both_providers() -> Chorus {
        Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_sms_provider(Arc::new(TestSms))
            .build()
    }

    /// Email provider stub that always succeeds.
    struct TestEmail;

    #[async_trait::async_trait]
    impl EmailSender for TestEmail {
        fn provider_name(&self) -> &str {
            "test"
        }

        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Ok(sent("e1", "test", Channel::Email))
        }
    }

    /// SMS provider stub that always succeeds.
    struct TestSms;

    #[async_trait::async_trait]
    impl SmsSender for TestSms {
        fn provider_name(&self) -> &str {
            "test"
        }

        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Ok(sent("s1", "test", Channel::Sms))
        }

        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// SMS provider stub that records the `from` of the last message it saw.
    ///
    /// The outer `Option` is `None` until a message has been sent; the inner
    /// one is the message's own `from`.
    struct CaptureSms {
        captured_from: std::sync::Mutex<Option<Option<String>>>,
    }

    impl CaptureSms {
        fn new() -> Self {
            let captured_from = std::sync::Mutex::new(None);
            Self { captured_from }
        }

        /// Returns the recorded `from`; panics if nothing was sent yet.
        fn last_from(&self) -> Option<String> {
            self.captured_from.lock().unwrap().clone().unwrap()
        }
    }

    #[async_trait::async_trait]
    impl SmsSender for CaptureSms {
        fn provider_name(&self) -> &str {
            "capture"
        }

        async fn send(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            *self.captured_from.lock().unwrap() = Some(msg.from.clone());
            Ok(sent("c1", "capture", Channel::Sms))
        }

        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    #[tokio::test]
    async fn chorus_send_email_template() {
        let otp_template = Template {
            slug: "otp".into(),
            name: "OTP".into(),
            subject: "Code: {{code}}".into(),
            html_body: "<p>{{code}}</p>".into(),
            text_body: "{{code}}".into(),
            variables: vec!["code".into()],
        };
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_template(otp_template)
            .build();

        let vars: HashMap<String, String> =
            HashMap::from([("code".to_string(), "123456".to_string())]);

        let outcome = chorus
            .send_email_template("user@test.com", "otp", &vars)
            .await
            .unwrap();
        assert_eq!(outcome.channel, Channel::Email);
    }

    #[tokio::test]
    async fn chorus_template_not_found() {
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .build();

        let outcome = chorus
            .send_email_template("user@test.com", "nonexistent", &HashMap::new())
            .await;
        assert!(matches!(outcome, Err(ChorusError::TemplateNotFound(_))));
    }

    #[tokio::test]
    async fn chorus_send_otp_email() {
        let chorus = chorus_with_both_providers();

        let outcome = chorus
            .send_otp("user@test.com", "123456", "App")
            .await
            .unwrap();
        assert_eq!(outcome.channel, Channel::Email);
    }

    #[tokio::test]
    async fn chorus_send_otp_sms() {
        let chorus = chorus_with_both_providers();

        let outcome = chorus
            .send_otp("+66812345678", "123456", "App")
            .await
            .unwrap();
        assert_eq!(outcome.channel, Channel::Sms);
    }

    #[tokio::test]
    async fn default_from_sms_applied_when_message_has_none() {
        let capture = Arc::new(CaptureSms::new());
        let chorus = Chorus::builder()
            .add_sms_provider(capture.clone())
            .default_from_sms("+66800000000".into())
            .build();

        chorus.send_sms(&sms_from(None)).await.unwrap();

        assert_eq!(capture.last_from(), Some("+66800000000".to_string()));
    }

    #[tokio::test]
    async fn default_from_sms_not_overridden_when_message_has_from() {
        let capture = Arc::new(CaptureSms::new());
        let chorus = Chorus::builder()
            .add_sms_provider(capture.clone())
            .default_from_sms("+66800000000".into())
            .build();

        chorus
            .send_sms(&sms_from(Some("+66899999999")))
            .await
            .unwrap();

        assert_eq!(capture.last_from(), Some("+66899999999".to_string()));
    }

    #[test]
    fn builder_default_creates_empty_builder() {
        let chorus = ChorusBuilder::default().build();
        assert!(chorus.templates.is_empty());
    }

    #[tokio::test]
    async fn chorus_send_sms_without_providers_fails() {
        let chorus = Chorus::builder().build();

        let outcome = chorus.send_sms(&sms_from(None)).await;
        assert!(matches!(outcome, Err(ChorusError::AllProvidersFailed)));
    }
}