1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::router::WaterfallRouter;
4use crate::sms::SmsSender;
5use crate::template::Template;
6use crate::types::{EmailMessage, SendResult, SmsMessage};
7use std::collections::HashMap;
8use std::sync::Arc;
9
/// Entry point for sending SMS, email, templated email and OTP messages.
///
/// Instances are created through [`Chorus::builder`].
pub struct Chorus {
    /// Router holding the registered SMS and email providers; every send goes through it.
    router: WaterfallRouter,
    /// Email templates keyed by their slug.
    templates: HashMap<String, Template>,
    /// Default `from` for email, set via [`ChorusBuilder::default_from_email`].
    default_from_email: Option<String>,
    /// Default `from` for SMS, set via [`ChorusBuilder::default_from_sms`].
    default_from_sms: Option<String>,
}
17
18impl Chorus {
19 pub fn builder() -> ChorusBuilder {
20 ChorusBuilder::new()
21 }
22
23 pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
24 let msg = if msg.from.is_none() && self.default_from_sms.is_some() {
25 let mut m = msg.clone();
26 m.from = self.default_from_sms.clone();
27 std::borrow::Cow::Owned(m)
28 } else {
29 std::borrow::Cow::Borrowed(msg)
30 };
31 self.router.send_sms(&msg).await
32 }
33
34 pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
35 self.router.send_email(msg).await
36 }
37
38 pub async fn send_email_template(
39 &self,
40 to: &str,
41 template_slug: &str,
42 variables: &HashMap<String, String>,
43 ) -> Result<SendResult, ChorusError> {
44 let tmpl = self
45 .templates
46 .get(template_slug)
47 .ok_or_else(|| ChorusError::TemplateNotFound(template_slug.to_string()))?;
48
49 let rendered = tmpl.render(variables)?;
50
51 let msg = EmailMessage {
52 to: to.to_string(),
53 subject: rendered.subject,
54 html_body: rendered.html_body,
55 text_body: rendered.text_body,
56 from: self.default_from_email.clone(),
57 };
58
59 self.router.send_email(&msg).await
60 }
61
62 pub async fn send_otp(
63 &self,
64 recipient: &str,
65 code: &str,
66 app_name: &str,
67 ) -> Result<SendResult, ChorusError> {
68 self.router.send_otp(recipient, code, app_name).await
69 }
70}
71
72pub struct ChorusBuilder {
73 router: WaterfallRouter,
74 templates: HashMap<String, Template>,
75 default_from_email: Option<String>,
76 default_from_sms: Option<String>,
77}
78
79impl ChorusBuilder {
80 pub fn new() -> Self {
81 Self {
82 router: WaterfallRouter::new(),
83 templates: HashMap::new(),
84 default_from_email: None,
85 default_from_sms: None,
86 }
87 }
88
89 pub fn add_sms_provider(mut self, provider: Arc<dyn SmsSender>) -> Self {
90 self.router = self.router.add_sms(provider);
91 self
92 }
93
94 pub fn add_email_provider(mut self, provider: Arc<dyn EmailSender>) -> Self {
95 self.router = self.router.add_email(provider);
96 self
97 }
98
99 pub fn add_template(mut self, template: Template) -> Self {
100 self.templates.insert(template.slug.clone(), template);
101 self
102 }
103
104 pub fn default_from_email(mut self, from: String) -> Self {
105 self.default_from_email = Some(from);
106 self
107 }
108
109 pub fn default_from_sms(mut self, from: String) -> Self {
110 self.default_from_sms = Some(from);
111 self
112 }
113
114 pub fn build(self) -> Chorus {
115 Chorus {
116 router: self.router,
117 templates: self.templates,
118 default_from_email: self.default_from_email,
119 default_from_sms: self.default_from_sms,
120 }
121 }
122}
123
124impl Default for ChorusBuilder {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Channel, DeliveryStatus};

    /// Successful `SendResult` shared by every fake provider in this module.
    fn sent(
        message_id: &'static str,
        provider: &'static str,
        channel: Channel,
    ) -> Result<SendResult, ChorusError> {
        Ok(SendResult {
            message_id: message_id.into(),
            provider: provider.into(),
            channel,
            status: DeliveryStatus::Sent,
            created_at: chrono::Utc::now(),
        })
    }

    /// The fixed SMS used by these tests; only the sender varies.
    fn sms_from(from: Option<&'static str>) -> SmsMessage {
        SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from: from.map(String::from),
        }
    }

    /// Email provider that accepts every message.
    struct TestEmail;

    #[async_trait::async_trait]
    impl EmailSender for TestEmail {
        fn provider_name(&self) -> &str {
            "test"
        }

        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            sent("e1", "test", Channel::Email)
        }
    }

    /// SMS provider that accepts every message and reports it as delivered.
    struct TestSms;

    #[async_trait::async_trait]
    impl SmsSender for TestSms {
        fn provider_name(&self) -> &str {
            "test"
        }

        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            sent("s1", "test", Channel::Sms)
        }

        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    #[tokio::test]
    async fn chorus_send_email_template() {
        let otp_template = Template {
            slug: "otp".into(),
            name: "OTP".into(),
            subject: "Code: {{code}}".into(),
            html_body: "<p>{{code}}</p>".into(),
            text_body: "{{code}}".into(),
            variables: vec!["code".into()],
        };
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_template(otp_template)
            .build();
        let vars = HashMap::from([("code".to_string(), "123456".to_string())]);

        let outcome = chorus
            .send_email_template("user@test.com", "otp", &vars)
            .await
            .unwrap();

        assert_eq!(outcome.channel, Channel::Email);
    }

    #[tokio::test]
    async fn chorus_template_not_found() {
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .build();
        let vars = HashMap::new();

        let outcome = chorus
            .send_email_template("user@test.com", "nonexistent", &vars)
            .await;

        assert!(matches!(outcome, Err(ChorusError::TemplateNotFound(_))));
    }

    #[tokio::test]
    async fn chorus_send_otp_email() {
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_sms_provider(Arc::new(TestSms))
            .build();

        let outcome = chorus
            .send_otp("user@test.com", "123456", "App")
            .await
            .unwrap();

        assert_eq!(outcome.channel, Channel::Email);
    }

    #[tokio::test]
    async fn chorus_send_otp_sms() {
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_sms_provider(Arc::new(TestSms))
            .build();

        let outcome = chorus
            .send_otp("+66812345678", "123456", "App")
            .await
            .unwrap();

        assert_eq!(outcome.channel, Channel::Sms);
    }

    /// SMS provider that records the `from` of the last message it was asked to send.
    struct CaptureSms {
        /// `None` until `send` is called; then `Some(msg.from)` of the latest message.
        captured_from: std::sync::Mutex<Option<Option<String>>>,
    }

    impl CaptureSms {
        fn new() -> Self {
            Self {
                captured_from: std::sync::Mutex::new(None),
            }
        }

        /// Returns the `from` of the last captured message; panics if nothing was sent.
        fn last_from(&self) -> Option<String> {
            self.captured_from
                .lock()
                .unwrap()
                .clone()
                .expect("CaptureSms::send was never called")
        }
    }

    #[async_trait::async_trait]
    impl SmsSender for CaptureSms {
        fn provider_name(&self) -> &str {
            "capture"
        }

        async fn send(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            *self.captured_from.lock().unwrap() = Some(msg.from.clone());
            sent("c1", "capture", Channel::Sms)
        }

        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    #[tokio::test]
    async fn default_from_sms_applied_when_message_has_none() {
        let capture = Arc::new(CaptureSms::new());
        let chorus = Chorus::builder()
            .add_sms_provider(capture.clone())
            .default_from_sms("+66800000000".into())
            .build();

        chorus.send_sms(&sms_from(None)).await.unwrap();

        assert_eq!(capture.last_from(), Some("+66800000000".to_string()));
    }

    #[tokio::test]
    async fn default_from_sms_not_overridden_when_message_has_from() {
        let capture = Arc::new(CaptureSms::new());
        let chorus = Chorus::builder()
            .add_sms_provider(capture.clone())
            .default_from_sms("+66800000000".into())
            .build();

        chorus
            .send_sms(&sms_from(Some("+66899999999")))
            .await
            .unwrap();

        assert_eq!(capture.last_from(), Some("+66899999999".to_string()));
    }

    #[test]
    fn builder_default_creates_empty_builder() {
        let chorus = ChorusBuilder::default().build();
        assert!(chorus.templates.is_empty());
    }

    #[tokio::test]
    async fn chorus_send_sms_without_providers_fails() {
        let chorus = Chorus::builder().build();

        let outcome = chorus.send_sms(&sms_from(None)).await;

        assert!(matches!(outcome, Err(ChorusError::AllProvidersFailed)));
    }
}