use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use uuid::Uuid;
use std::collections::HashMap;
use std::sync::Arc;
/// Errors produced by [`SmsClient`] / [`InboundWebhook`] implementations and by the
/// routing helpers in this module ([`SmsRouter`], [`FallbackClient`]).
#[derive(Debug, thiserror::Error)]
pub enum SmsError {
/// HTTP/transport failure. NOTE(review): meaning inferred from the name and message;
/// nothing in this file constructs it outside the tests.
#[error("http error: {0}")]
Http(String),
/// Authentication/credential failure. NOTE(review): no producer visible in this file.
#[error("authentication error: {0}")]
Auth(String),
/// The request could not be dispatched as asked; `SmsRouter` returns this for an unknown
/// provider name and when no default provider is configured.
#[error("invalid request: {0}")]
Invalid(String),
/// A provider reported failure; `FallbackClient` also returns it to summarize the case
/// where every provider in its chain failed.
#[error("provider error: {0}")]
Provider(String),
/// Catch-all for failures not covered above. NOTE(review): no producer visible in this file.
#[error("unexpected: {0}")]
Unexpected(String),
}
/// Errors from handling an inbound webhook request.
///
/// NOTE(review): outside the tests nothing in this file constructs it; the variant meanings
/// below are inferred from their names and display messages — confirm against the handler.
#[derive(Debug, thiserror::Error)]
pub enum WebhookError {
/// No handler is registered for the requested provider name.
#[error("provider not found: {0}")]
ProviderNotFound(String),
/// The webhook's authenticity/signature check failed.
#[error("signature verification failed: {0}")]
VerificationFailed(String),
/// The request payload could not be parsed.
#[error("parsing failed: {0}")]
ParseError(String),
/// Wraps an [`SmsError`]; `#[from]` generates `From<SmsError>`, so `?` / `.into()` convert
/// automatically.
#[error("SMS processing error: {0}")]
SmsError(#[from] SmsError),
}
/// The subset of HTTP status codes used by webhook responses.
///
/// `#[repr(u16)]` makes each discriminant the literal status code and guarantees at compile
/// time that every code fits in a `u16`, so the cast in [`HttpStatus::as_u16`] is lossless.
#[repr(u16)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpStatus {
    /// `200 OK`
    Ok = 200,
    /// `400 Bad Request`
    BadRequest = 400,
    /// `401 Unauthorized`
    Unauthorized = 401,
    /// `404 Not Found`
    NotFound = 404,
    /// `500 Internal Server Error`
    InternalServerError = 500,
}

impl HttpStatus {
    /// Returns the numeric status code, e.g. `404` for [`HttpStatus::NotFound`].
    #[must_use]
    pub fn as_u16(self) -> u16 {
        // Lossless: the enum is `#[repr(u16)]` and discriminants equal the status codes.
        self as u16
    }
}

impl From<HttpStatus> for u16 {
    /// Converts a status into its numeric code; same as [`HttpStatus::as_u16`].
    fn from(status: HttpStatus) -> Self {
        status.as_u16()
    }
}
/// A borrowed outbound SMS: destination, sender and message body.
///
/// See [`OwnedSendRequest`] for the owning counterpart.
///
/// NOTE(review): the `Deserialize` derive on `&'a str` fields borrows from the input buffer;
/// with serde_json that fails for JSON strings containing escape sequences. Prefer
/// deserializing `OwnedSendRequest` for external input — confirm before relying on this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendRequest<'a> {
/// Destination number (the tests use E.164-style values such as `+14155551234`).
pub to: &'a str,
/// Sender identity / originating number.
pub from: &'a str,
/// Message body.
pub text: &'a str,
}
/// Owning counterpart of [`SendRequest`], suitable for storing, queueing or deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OwnedSendRequest {
    /// Destination number.
    pub to: String,
    /// Sender identity / originating number.
    pub from: String,
    /// Message body.
    pub text: String,
}

impl OwnedSendRequest {
    /// Builds a request from any values convertible into `String` (`&str`, `String`, ...).
    pub fn new(
        to: impl Into<String>,
        from: impl Into<String>,
        text: impl Into<String>,
    ) -> Self {
        let (to, from, text): (String, String, String) = (to.into(), from.into(), text.into());
        Self { to, from, text }
    }

    /// Borrows this request as a [`SendRequest`] without copying any string data.
    ///
    /// This is an inherent method, unrelated to [`std::convert::AsRef`]: that trait must
    /// return a reference, whereas this returns a `SendRequest` view by value.
    pub fn as_ref(&self) -> SendRequest<'_> {
        let Self { to, from, text } = self;
        SendRequest {
            to: to.as_str(),
            from: from.as_str(),
            text: text.as_str(),
        }
    }
}
impl<'a> From<SendRequest<'a>> for OwnedSendRequest {
fn from(req: SendRequest<'a>) -> Self {
Self {
to: req.to.to_owned(),
from: req.from.to_owned(),
text: req.text.to_owned(),
}
}
}
impl<'a> From<&'a OwnedSendRequest> for SendRequest<'a> {
fn from(req: &'a OwnedSendRequest) -> Self {
req.as_ref()
}
}
/// Outcome of a successful [`SmsClient::send`].
///
/// NOTE(review): because `provider` is `&'static str`, the derived `Deserialize` can only
/// borrow from `'static` input — confirm whether deserializing this type is actually needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendResponse {
/// Message identifier. NOTE(review): presumably the provider-assigned id (the test mock
/// returns `"mock-id"`); confirm per provider.
pub id: String,
/// Name of the provider that handled the send; the router/fallback tests use it to tell
/// which client ran.
pub provider: &'static str,
/// Raw provider payload as JSON. NOTE(review): presumably the provider's response body.
pub raw: serde_json::Value,
}
/// An SMS received through a provider webhook, as returned by
/// [`InboundWebhook::parse_inbound`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InboundMessage {
/// Message id, if the provider payload carried one.
pub id: Option<String>,
/// Sender number.
pub from: String,
/// Recipient number.
pub to: String,
/// Message body.
pub text: String,
/// Message time, if available. NOTE(review): whether this is sent/received time and its
/// offset handling depend on each provider implementation — confirm.
pub timestamp: Option<OffsetDateTime>,
/// Name of the provider that delivered the message.
pub provider: &'static str,
/// Raw webhook payload as JSON. NOTE(review): presumably the unmodified provider body.
pub raw: serde_json::Value,
}
/// A parsed inbound message paired with a numeric HTTP status code.
///
/// NOTE(review): `status` is a raw `u16` while [`WebhookResponse`] uses [`HttpStatus`];
/// no constructor or consumer is visible in this file.
#[derive(Debug, Clone)]
pub struct WebhookResult {
/// The parsed message.
pub message: InboundMessage,
/// HTTP status code to report.
pub status: u16,
}
/// An HTTP response to hand back to a webhook caller.
#[derive(Debug, Clone)]
pub struct WebhookResponse {
/// Status code to send.
pub status: HttpStatus,
/// Response body text.
pub body: String,
/// Value for the `Content-Type` header; [`WebhookResponse::success`] and
/// [`WebhookResponse::error`] set it to `application/json`.
pub content_type: String,
}
impl WebhookResponse {
    /// Builds a `200 OK` JSON response whose body is the serialized `message`.
    ///
    /// If serialization fails the body falls back to `{}`; the status stays `200`.
    pub fn success(message: InboundMessage) -> Self {
        Self {
            status: HttpStatus::Ok,
            body: serde_json::to_string(&message).unwrap_or_else(|_| "{}".to_string()),
            content_type: "application/json".to_string(),
        }
    }

    /// Builds a JSON error response of the form `{"error": "<message>"}` with `status`.
    ///
    /// `message` is encoded with a real JSON string encoder. Quotes, backslashes, newlines
    /// and other control characters are escaped, so the body is always valid JSON.
    pub fn error(status: HttpStatus, message: &str) -> Self {
        // Escaping only `"` by hand is not enough: a `\` or a raw control character in
        // `message` would otherwise produce an invalid JSON document.
        // `to_string` on a `str` writes into an in-memory buffer and cannot fail.
        let encoded = serde_json::to_string(message)
            .expect("serializing a string slice to JSON is infallible");
        Self {
            status,
            // `encoded` already includes the surrounding quotes.
            body: format!(r#"{{"error": {}}}"#, encoded),
            content_type: "application/json".to_string(),
        }
    }
}
/// Something that can send one outbound SMS: a concrete provider, or a composite such as
/// [`SmsRouter`] or [`FallbackClient`] (both implement this trait).
///
/// `Send + Sync` is required so clients can be shared as `Arc<dyn SmsClient>`.
#[async_trait]
pub trait SmsClient: Send + Sync {
/// Sends `req`, returning the provider's [`SendResponse`] or an [`SmsError`].
async fn send(&self, req: SendRequest<'_>) -> Result<SendResponse, SmsError>;
}
/// Returns a freshly generated random (version 4) UUID rendered as a string.
///
/// NOTE(review): the name suggests it stands in for a missing provider message id; no caller
/// is visible in this file — confirm.
pub fn fallback_id() -> String {
    let id = Uuid::new_v4();
    id.to_string()
}
/// HTTP request headers as an ordered list of `(name, value)` pairs; duplicates are allowed.
///
/// NOTE(review): no case normalization happens at this type; implementors looking up a
/// header presumably need case-insensitive name comparison — confirm.
pub type Headers = Vec<(String, String)>;
/// Parses (and optionally authenticates) inbound-message webhooks for one SMS provider.
// NOTE(review): this trait declares no async methods, so `#[async_trait]` appears unnecessary.
#[async_trait]
pub trait InboundWebhook: Send + Sync {
/// Provider name; [`InboundRegistry::with`] uses it as the registry lookup key.
fn provider(&self) -> &'static str;
/// Converts a raw webhook request (`headers` plus raw `body` bytes) into an
/// [`InboundMessage`], or returns an [`SmsError`] if it cannot.
fn parse_inbound(&self, headers: &Headers, body: &[u8]) -> Result<InboundMessage, SmsError>;
/// Checks that a webhook request is authentic.
///
/// The default implementation accepts every request without checking anything. Providers
/// that sign their webhooks should override it. NOTE(review): presumably callers invoke this
/// before `parse_inbound`; no caller is visible here.
fn verify(&self, _headers: &Headers, _body: &[u8]) -> Result<(), SmsError> {
Ok(())
}
}
/// Lookup table of [`InboundWebhook`] handlers keyed by [`InboundWebhook::provider`].
///
/// Cloning is cheap: the table is shared behind an [`Arc`] and copied only when a clone
/// that shares it is modified.
#[derive(Default, Clone)]
pub struct InboundRegistry {
    map: Arc<HashMap<&'static str, Arc<dyn InboundWebhook>>>,
}

impl InboundRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            map: Arc::new(HashMap::new()),
        }
    }

    /// Registers `hook` under its [`InboundWebhook::provider`] name and returns the registry.
    ///
    /// A hook already registered under the same name is replaced.
    #[must_use]
    pub fn with(mut self, hook: Arc<dyn InboundWebhook>) -> Self {
        let provider = hook.provider();
        // Copy-on-write: mutate in place when this registry is the table's only owner,
        // otherwise clone the table first so other registry clones are unaffected.
        Arc::make_mut(&mut self.map).insert(provider, hook);
        self
    }

    /// Returns the handler registered for `provider`, or `None` if there is none.
    pub fn get(&self, provider: &str) -> Option<Arc<dyn InboundWebhook>> {
        self.map.get(provider).cloned()
    }
}
/// Dispatches outbound SMS to one of several named [`SmsClient`] providers.
///
/// Cloning a router is cheap: the provider table is shared behind an [`Arc`]. It is copied
/// only when a builder method modifies a table that another clone still references.
#[derive(Clone)]
pub struct SmsRouter {
    /// Registered providers, keyed by the name given at registration.
    providers: Arc<HashMap<String, Arc<dyn SmsClient>>>,
    /// Provider used by [`SmsClient::send`]; `None` until one is registered or chosen.
    default: Option<String>,
}

impl SmsRouter {
    /// Creates a router with no providers and no default provider.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(HashMap::new()),
            default: None,
        }
    }

    /// Registers `client` under `name`, replacing any provider already using that name.
    ///
    /// If no default provider has been chosen yet, `name` becomes the default.
    #[must_use]
    pub fn with(self, name: impl Into<String>, client: impl SmsClient + 'static) -> Self {
        self.with_arc(name, Arc::new(client))
    }

    /// Like [`SmsRouter::with`], but for a client that is already behind an [`Arc`]
    /// (e.g. one also shared with another router or a [`FallbackClient`]).
    #[must_use]
    pub fn with_arc(mut self, name: impl Into<String>, client: Arc<dyn SmsClient>) -> Self {
        let name = name.into();
        // Adopt an automatic default only when none exists yet. Keying this on "the table
        // was empty" would silently overwrite a `default_provider(..)` call made before the
        // first registration.
        if self.default.is_none() {
            self.default = Some(name.clone());
        }
        // Copy-on-write: mutate in place when this router is the table's only owner,
        // otherwise clone the table first so other router clones are unaffected.
        Arc::make_mut(&mut self.providers).insert(name, client);
        self
    }

    /// Makes `name` the provider used by [`SmsClient::send`].
    ///
    /// The name is not validated here; the provider may be registered before or after this
    /// call. Sending with an unregistered default fails with [`SmsError::Invalid`].
    #[must_use]
    pub fn default_provider(mut self, name: impl Into<String>) -> Self {
        self.default = Some(name.into());
        self
    }

    /// Sends `req` through the provider registered under `provider`.
    ///
    /// # Errors
    /// Returns [`SmsError::Invalid`] if no provider is registered under that name. Otherwise
    /// it returns the selected client's result unchanged.
    pub async fn send_via(
        &self,
        provider: &str,
        req: SendRequest<'_>,
    ) -> Result<SendResponse, SmsError> {
        let client = self
            .providers
            .get(provider)
            .ok_or_else(|| SmsError::Invalid(format!("unknown provider: {}", provider)))?;
        client.send(req).await
    }

    /// Returns `true` if a provider is registered under `name`.
    pub fn has_provider(&self, name: &str) -> bool {
        self.providers.contains_key(name)
    }

    /// Returns the name of the default provider, if one has been set.
    pub fn default_provider_name(&self) -> Option<&str> {
        self.default.as_deref()
    }
}
impl Default for SmsRouter {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SmsClient for SmsRouter {
    /// Sends `req` through the router's default provider.
    ///
    /// # Errors
    /// Returns [`SmsError::Invalid`] when no default provider is configured. Otherwise it
    /// returns whatever [`SmsRouter::send_via`] returns for the default provider's name.
    async fn send(&self, req: SendRequest<'_>) -> Result<SendResponse, SmsError> {
        match self.default.as_deref() {
            Some(name) => self.send_via(name, req).await,
            None => Err(SmsError::Invalid("no default provider configured".into())),
        }
    }
}
/// An [`SmsClient`] that tries each wrapped provider in order and returns the first success.
pub struct FallbackClient {
    /// Providers in priority order. Both constructors reject an empty list.
    providers: Vec<Arc<dyn SmsClient>>,
}

impl FallbackClient {
    /// Creates a fallback chain from `providers`, tried in the given order.
    ///
    /// # Panics
    /// Panics if `providers` is empty.
    pub fn new(providers: Vec<Arc<dyn SmsClient>>) -> Self {
        assert!(!providers.is_empty(), "FallbackClient requires at least one provider");
        Self { providers }
    }

    /// Creates a fallback chain from boxed clients, tried in the given order.
    ///
    /// Delegates to [`FallbackClient::new`], so both constructors enforce the same
    /// non-empty invariant. Without the check, an empty chain would fail every send.
    ///
    /// # Panics
    /// Panics if `clients` is empty.
    pub fn from_clients(clients: Vec<Box<dyn SmsClient>>) -> Self {
        let providers: Vec<Arc<dyn SmsClient>> = clients.into_iter().map(Arc::from).collect();
        Self::new(providers)
    }

    /// Number of providers in the chain.
    pub fn len(&self) -> usize {
        self.providers.len()
    }

    /// Returns `true` if the chain has no providers. This cannot happen through the public
    /// constructors, which both reject an empty list.
    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}
#[async_trait]
impl SmsClient for FallbackClient {
    /// Tries each provider in order and returns the first successful response.
    ///
    /// # Errors
    /// If every provider fails, returns [`SmsError::Provider`]. Its message includes the
    /// provider count and each provider's error text, in order, joined by `"; "`.
    async fn send(&self, req: SendRequest<'_>) -> Result<SendResponse, SmsError> {
        let mut failures: Vec<String> = Vec::with_capacity(self.providers.len());
        for client in self.providers.iter() {
            let outcome = client.send(req.clone()).await;
            match outcome {
                Ok(resp) => return Ok(resp),
                Err(err) => failures.push(err.to_string()),
            }
        }
        let summary = failures.join("; ");
        Err(SmsError::Provider(format!(
            "all {} providers failed: [{}]",
            self.providers.len(),
            summary
        )))
    }
}
/// Unit tests for the request types, webhook response helpers, router and fallback client.
#[cfg(test)]
mod tests {
use super::*;
// --- OwnedSendRequest / SendRequest construction and conversions ---
#[test]
fn owned_send_request_new() {
let req = OwnedSendRequest::new("+14155551234", "+10005551234", "Hello");
assert_eq!(req.to, "+14155551234");
assert_eq!(req.from, "+10005551234");
assert_eq!(req.text, "Hello");
}
// `new` accepts owned `String`s as well as `&str` via `impl Into<String>`.
#[test]
fn owned_send_request_from_string_values() {
let to = String::from("+14155551234");
let from = String::from("+10005551234");
let text = String::from("Hello");
let req = OwnedSendRequest::new(to, from, text);
assert_eq!(req.to, "+14155551234");
}
#[test]
fn owned_send_request_as_ref_roundtrip() {
let owned = OwnedSendRequest::new("+1", "+2", "hi");
let borrowed = owned.as_ref();
assert_eq!(borrowed.to, "+1");
assert_eq!(borrowed.from, "+2");
assert_eq!(borrowed.text, "hi");
}
#[test]
fn owned_send_request_from_send_request() {
let borrowed = SendRequest {
to: "+1",
from: "+2",
text: "msg",
};
let owned: OwnedSendRequest = borrowed.into();
assert_eq!(owned.to, "+1");
assert_eq!(owned.text, "msg");
}
#[test]
fn send_request_from_owned_ref() {
let owned = OwnedSendRequest::new("+1", "+2", "hi");
let borrowed: SendRequest<'_> = (&owned).into();
assert_eq!(borrowed.to, "+1");
}
#[test]
fn owned_send_request_serde_roundtrip() {
let req = OwnedSendRequest::new("+1", "+2", "test");
let json = serde_json::to_string(&req).unwrap();
let deser: OwnedSendRequest = serde_json::from_str(&json).unwrap();
assert_eq!(req, deser);
}
// --- HTTP status and webhook response helpers ---
#[test]
fn http_status_values() {
assert_eq!(HttpStatus::Ok.as_u16(), 200);
assert_eq!(HttpStatus::BadRequest.as_u16(), 400);
assert_eq!(HttpStatus::Unauthorized.as_u16(), 401);
assert_eq!(HttpStatus::NotFound.as_u16(), 404);
assert_eq!(HttpStatus::InternalServerError.as_u16(), 500);
}
#[test]
fn webhook_response_success_serializes_message() {
let msg = InboundMessage {
id: Some("msg-1".into()),
from: "+1111".into(),
to: "+2222".into(),
text: "hi".into(),
timestamp: None,
provider: "test",
raw: serde_json::json!({}),
};
let resp = WebhookResponse::success(msg);
assert_eq!(resp.status, HttpStatus::Ok);
assert!(resp.body.contains("msg-1"));
assert_eq!(resp.content_type, "application/json");
}
#[test]
fn webhook_response_error_escapes_quotes() {
let resp = WebhookResponse::error(HttpStatus::BadRequest, r#"bad "input""#);
assert!(resp.body.contains(r#"bad \"input\""#));
}
// --- Registry, error types and id helper ---
#[test]
fn inbound_registry_get_returns_none_for_unknown() {
let reg = InboundRegistry::new();
assert!(reg.get("nonexistent").is_none());
}
#[test]
fn sms_error_display() {
let e = SmsError::Http("timeout".into());
assert_eq!(e.to_string(), "http error: timeout");
let e = SmsError::Auth("bad token".into());
assert_eq!(e.to_string(), "authentication error: bad token");
}
// Exercises the `#[from]` conversion on `WebhookError::SmsError`.
#[test]
fn webhook_error_from_sms_error() {
let sms_err = SmsError::Provider("oops".into());
let wh_err: WebhookError = sms_err.into();
assert!(wh_err.to_string().contains("oops"));
}
#[test]
fn fallback_id_is_valid_uuid() {
let id = fallback_id();
assert!(uuid::Uuid::parse_str(&id).is_ok());
}
// --- Test doubles ---
/// Always succeeds, reporting `provider_name` as the responding provider.
struct MockClient {
provider_name: &'static str,
}
#[async_trait]
impl SmsClient for MockClient {
async fn send(&self, _req: SendRequest<'_>) -> Result<SendResponse, SmsError> {
Ok(SendResponse {
id: "mock-id".into(),
provider: self.provider_name,
raw: serde_json::json!({"mock": true}),
})
}
}
/// Always fails with `SmsError::Provider(message)`.
struct FailingClient {
message: String,
}
#[async_trait]
impl SmsClient for FailingClient {
async fn send(&self, _req: SendRequest<'_>) -> Result<SendResponse, SmsError> {
Err(SmsError::Provider(self.message.clone()))
}
}
/// A fixed request with `'static` borrows, shared by the async tests below.
fn test_request() -> SendRequest<'static> {
SendRequest {
to: "+14155551234",
from: "+10005551234",
text: "test",
}
}
// --- SmsRouter ---
#[tokio::test]
async fn router_send_via_dispatches_correctly() {
let router = SmsRouter::new()
.with("alpha", MockClient { provider_name: "alpha" })
.with("beta", MockClient { provider_name: "beta" });
let resp = router.send_via("beta", test_request()).await.unwrap();
assert_eq!(resp.provider, "beta");
}
#[tokio::test]
async fn router_send_via_unknown_provider_errors() {
let router = SmsRouter::new()
.with("alpha", MockClient { provider_name: "alpha" });
let err = router.send_via("nope", test_request()).await.unwrap_err();
assert!(err.to_string().contains("unknown provider"));
}
#[tokio::test]
async fn router_default_is_first_registered() {
let router = SmsRouter::new()
.with("first", MockClient { provider_name: "first" })
.with("second", MockClient { provider_name: "second" });
assert_eq!(router.default_provider_name(), Some("first"));
let resp = router.send(test_request()).await.unwrap();
assert_eq!(resp.provider, "first");
}
#[tokio::test]
async fn router_explicit_default_override() {
let router = SmsRouter::new()
.with("first", MockClient { provider_name: "first" })
.with("second", MockClient { provider_name: "second" })
.default_provider("second");
let resp = router.send(test_request()).await.unwrap();
assert_eq!(resp.provider, "second");
}
#[tokio::test]
async fn router_no_default_errors() {
let router = SmsRouter::new();
let err = router.send(test_request()).await.unwrap_err();
assert!(err.to_string().contains("no default provider"));
}
#[test]
fn router_has_provider() {
let router = SmsRouter::new()
.with("plivo", MockClient { provider_name: "plivo" });
assert!(router.has_provider("plivo"));
assert!(!router.has_provider("twilio"));
}
// --- FallbackClient ---
#[tokio::test]
async fn fallback_returns_first_success() {
let client = FallbackClient::new(vec![
Arc::new(MockClient { provider_name: "primary" }),
Arc::new(MockClient { provider_name: "backup" }),
]);
let resp = client.send(test_request()).await.unwrap();
assert_eq!(resp.provider, "primary");
}
#[tokio::test]
async fn fallback_skips_failing_provider() {
let client = FallbackClient::new(vec![
Arc::new(FailingClient { message: "down".into() }),
Arc::new(MockClient { provider_name: "backup" }),
]);
let resp = client.send(test_request()).await.unwrap();
assert_eq!(resp.provider, "backup");
}
// The aggregated error must mention the provider count and every individual failure.
#[tokio::test]
async fn fallback_all_fail_returns_summary() {
let client = FallbackClient::new(vec![
Arc::new(FailingClient { message: "err-a".into() }),
Arc::new(FailingClient { message: "err-b".into() }),
]);
let err = client.send(test_request()).await.unwrap_err();
let msg = err.to_string();
assert!(msg.contains("all 2 providers failed"));
assert!(msg.contains("err-a"));
assert!(msg.contains("err-b"));
}
#[test]
fn fallback_len() {
let client = FallbackClient::new(vec![
Arc::new(MockClient { provider_name: "a" }),
Arc::new(MockClient { provider_name: "b" }),
]);
assert_eq!(client.len(), 2);
assert!(!client.is_empty());
}
#[test]
#[should_panic(expected = "at least one provider")]
fn fallback_empty_panics() {
FallbackClient::new(vec![]);
}
}