Skip to main content

chorus_core/
client.rs

1use crate::email::EmailSender;
2use crate::error::ChorusError;
3use crate::router::WaterfallRouter;
4use crate::sms::SmsSender;
5use crate::template::Template;
6use crate::types::{EmailMessage, SendResult, SmsMessage};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10/// The main Chorus client — high-level API for sending messages.
11pub struct Chorus {
12    router: WaterfallRouter,
13    templates: HashMap<String, Template>,
14    default_from_email: Option<String>,
15    default_from_sms: Option<String>,
16}
17
18impl Chorus {
19    pub fn builder() -> ChorusBuilder {
20        ChorusBuilder::new()
21    }
22
23    pub async fn send_sms(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
24        let msg = if msg.from.is_none() && self.default_from_sms.is_some() {
25            let mut m = msg.clone();
26            m.from = self.default_from_sms.clone();
27            std::borrow::Cow::Owned(m)
28        } else {
29            std::borrow::Cow::Borrowed(msg)
30        };
31        self.router.send_sms(&msg).await
32    }
33
34    pub async fn send_email(&self, msg: &EmailMessage) -> Result<SendResult, ChorusError> {
35        self.router.send_email(msg).await
36    }
37
38    pub async fn send_email_template(
39        &self,
40        to: &str,
41        template_slug: &str,
42        variables: &HashMap<String, String>,
43    ) -> Result<SendResult, ChorusError> {
44        let tmpl = self
45            .templates
46            .get(template_slug)
47            .ok_or_else(|| ChorusError::TemplateNotFound(template_slug.to_string()))?;
48
49        let rendered = tmpl.render(variables)?;
50
51        let msg = EmailMessage {
52            to: to.to_string(),
53            subject: rendered.subject,
54            html_body: rendered.html_body,
55            text_body: rendered.text_body,
56            from: self.default_from_email.clone(),
57        };
58
59        self.router.send_email(&msg).await
60    }
61
62    pub async fn send_otp(
63        &self,
64        recipient: &str,
65        code: &str,
66        app_name: &str,
67    ) -> Result<SendResult, ChorusError> {
68        self.router.send_otp(recipient, code, app_name).await
69    }
70}
71
72pub struct ChorusBuilder {
73    router: WaterfallRouter,
74    templates: HashMap<String, Template>,
75    default_from_email: Option<String>,
76    default_from_sms: Option<String>,
77}
78
79impl ChorusBuilder {
80    pub fn new() -> Self {
81        Self {
82            router: WaterfallRouter::new(),
83            templates: HashMap::new(),
84            default_from_email: None,
85            default_from_sms: None,
86        }
87    }
88
89    pub fn add_sms_provider(mut self, provider: Arc<dyn SmsSender>) -> Self {
90        self.router = self.router.add_sms(provider);
91        self
92    }
93
94    pub fn add_email_provider(mut self, provider: Arc<dyn EmailSender>) -> Self {
95        self.router = self.router.add_email(provider);
96        self
97    }
98
99    pub fn add_template(mut self, template: Template) -> Self {
100        self.templates.insert(template.slug.clone(), template);
101        self
102    }
103
104    pub fn default_from_email(mut self, from: String) -> Self {
105        self.default_from_email = Some(from);
106        self
107    }
108
109    pub fn default_from_sms(mut self, from: String) -> Self {
110        self.default_from_sms = Some(from);
111        self
112    }
113
114    pub fn build(self) -> Chorus {
115        Chorus {
116            router: self.router,
117            templates: self.templates,
118            default_from_email: self.default_from_email,
119            default_from_sms: self.default_from_sms,
120        }
121    }
122}
123
124impl Default for ChorusBuilder {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Channel, DeliveryStatus};

    /// Successful `SendResult` returned by every stub provider below.
    fn sent(message_id: &'static str, provider: &'static str, channel: Channel) -> SendResult {
        SendResult {
            message_id: message_id.into(),
            provider: provider.into(),
            channel,
            status: DeliveryStatus::Sent,
            created_at: chrono::Utc::now(),
        }
    }

    /// SMS to the fixed test number with the given sender.
    fn sms_to_test_number(from: Option<String>) -> SmsMessage {
        SmsMessage {
            to: "+66812345678".into(),
            body: "Hi".into(),
            from,
        }
    }

    struct TestEmail;

    #[async_trait::async_trait]
    impl EmailSender for TestEmail {
        fn provider_name(&self) -> &str {
            "test"
        }

        async fn send(&self, _msg: &EmailMessage) -> Result<SendResult, ChorusError> {
            Ok(sent("e1", "test", Channel::Email))
        }
    }

    struct TestSms;

    #[async_trait::async_trait]
    impl SmsSender for TestSms {
        fn provider_name(&self) -> &str {
            "test"
        }

        async fn send(&self, _msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            Ok(sent("s1", "test", Channel::Sms))
        }

        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// SMS stub that records the `from` of the last message it was asked to send.
    struct CaptureSms {
        captured_from: std::sync::Mutex<Option<Option<String>>>,
    }

    impl CaptureSms {
        fn new() -> Self {
            Self {
                captured_from: std::sync::Mutex::new(None),
            }
        }

        /// The `from` seen by the most recent `send`; panics if nothing was sent.
        fn last_from(&self) -> Option<String> {
            self.captured_from.lock().unwrap().clone().unwrap()
        }
    }

    #[async_trait::async_trait]
    impl SmsSender for CaptureSms {
        fn provider_name(&self) -> &str {
            "capture"
        }

        async fn send(&self, msg: &SmsMessage) -> Result<SendResult, ChorusError> {
            let seen_from = msg.from.clone();
            *self.captured_from.lock().unwrap() = Some(seen_from);
            Ok(sent("c1", "capture", Channel::Sms))
        }

        async fn check_status(&self, _id: &str) -> Result<DeliveryStatus, ChorusError> {
            Ok(DeliveryStatus::Delivered)
        }
    }

    /// Client with one stub email provider and one stub SMS provider.
    fn email_and_sms_client() -> Chorus {
        Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_sms_provider(Arc::new(TestSms))
            .build()
    }

    /// Client whose only SMS provider is a fresh `CaptureSms`, with a default SMS sender set.
    fn capturing_client() -> (Arc<CaptureSms>, Chorus) {
        let capture = Arc::new(CaptureSms::new());
        let chorus = Chorus::builder()
            .add_sms_provider(capture.clone())
            .default_from_sms("+66800000000".into())
            .build();
        (capture, chorus)
    }

    #[tokio::test]
    async fn chorus_send_email_template() {
        let otp_template = Template {
            slug: "otp".into(),
            name: "OTP".into(),
            subject: "Code: {{code}}".into(),
            html_body: "<p>{{code}}</p>".into(),
            text_body: "{{code}}".into(),
            variables: vec!["code".into()],
        };
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .add_template(otp_template)
            .build();

        let vars = HashMap::from([("code".to_string(), "123456".to_string())]);

        let outcome = chorus
            .send_email_template("user@test.com", "otp", &vars)
            .await;
        assert_eq!(outcome.unwrap().channel, Channel::Email);
    }

    #[tokio::test]
    async fn chorus_template_not_found() {
        let chorus = Chorus::builder()
            .add_email_provider(Arc::new(TestEmail))
            .build();

        let no_vars = HashMap::new();
        let outcome = chorus
            .send_email_template("user@test.com", "nonexistent", &no_vars)
            .await;
        assert!(matches!(outcome, Err(ChorusError::TemplateNotFound(_))));
    }

    #[tokio::test]
    async fn chorus_send_otp_email() {
        let chorus = email_and_sms_client();
        let outcome = chorus.send_otp("user@test.com", "123456", "App").await;
        assert_eq!(outcome.unwrap().channel, Channel::Email);
    }

    #[tokio::test]
    async fn chorus_send_otp_sms() {
        let chorus = email_and_sms_client();
        let outcome = chorus.send_otp("+66812345678", "123456", "App").await;
        assert_eq!(outcome.unwrap().channel, Channel::Sms);
    }

    #[tokio::test]
    async fn default_from_sms_applied_when_message_has_none() {
        let (capture, chorus) = capturing_client();
        let msg = sms_to_test_number(None);
        chorus.send_sms(&msg).await.unwrap();
        assert_eq!(capture.last_from(), Some("+66800000000".to_string()));
    }

    #[tokio::test]
    async fn default_from_sms_not_overridden_when_message_has_from() {
        let (capture, chorus) = capturing_client();
        let msg = sms_to_test_number(Some("+66899999999".into()));
        chorus.send_sms(&msg).await.unwrap();
        assert_eq!(capture.last_from(), Some("+66899999999".to_string()));
    }

    #[test]
    fn builder_default_creates_empty_builder() {
        // An empty builder must still produce a valid (if useless) client.
        let chorus = ChorusBuilder::default().build();
        assert!(chorus.templates.is_empty());
    }

    #[tokio::test]
    async fn chorus_send_sms_without_providers_fails() {
        let chorus = Chorus::builder().build();
        let msg = sms_to_test_number(None);
        let outcome = chorus.send_sms(&msg).await;
        assert!(matches!(outcome, Err(ChorusError::AllProvidersFailed)));
    }
}