use async_trait::async_trait;
#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::sync::Arc;
use crate::error::{Error, Result};
use crate::storage::{
models::{DeliveryStatus, DeliveryType},
Storage,
};
/// Transport abstraction for delivering an already-packed message to one or
/// more recipients identified by DID.
///
/// NOTE(review): declared with plain `#[async_trait]` (Send futures), while the
/// wasm32 impls in this module use `#[async_trait(?Send)]`; async_trait requires
/// the trait and its impls to agree, so the wasm32 build likely needs a
/// `cfg_attr`'d `?Send` declaration here — confirm.
#[async_trait]
pub trait PlainMessageSender: Send + Sync + Debug {
    /// Sends `packed_message` to every DID in `recipient_dids`.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()>;
}
/// [`PlainMessageSender`] that hands each message to a caller-supplied callback
/// instead of doing network I/O itself (presumably a bridge to a host runtime
/// such as Node.js — confirm against callers).
pub struct NodePlainMessageSender {
    /// Called synchronously with the packed message and the recipient DIDs.
    send_callback: Arc<dyn Fn(String, Vec<String>) -> Result<()> + Send + Sync>,
}
impl Debug for NodePlainMessageSender {
    /// Formats the sender with a `"<function>"` placeholder, since the boxed
    /// callback itself has no `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("NodePlainMessageSender");
        builder.field("send_callback", &"<function>");
        builder.finish()
    }
}
impl NodePlainMessageSender {
    /// Wraps `callback` as a [`PlainMessageSender`].
    ///
    /// The callback receives the packed message and the recipient DIDs and is
    /// invoked once per `send` call.
    pub fn new<F>(callback: F) -> Self
    where
        F: Fn(String, Vec<String>) -> Result<()> + Send + Sync + 'static,
    {
        let send_callback: Arc<dyn Fn(String, Vec<String>) -> Result<()> + Send + Sync> =
            Arc::new(callback);
        Self { send_callback }
    }
}
#[async_trait]
impl PlainMessageSender for NodePlainMessageSender {
async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
(self.send_callback)(packed_message, recipient_dids)
.map_err(|e| Error::Dispatch(format!("Failed to send message: {}", e)))
}
}
/// [`PlainMessageSender`] that POSTs each message to
/// `{base_url}/api/messages/{percent-encoded DID}`.
#[derive(Debug)]
pub struct HttpPlainMessageSender {
    /// Base URL of the receiving service; trailing `/` characters are ignored.
    base_url: String,
    /// HTTP client built with `timeout_ms` and a fixed user agent.
    #[cfg(feature = "reqwest")]
    client: reqwest::Client,
    // Not read after construction; retained so it shows up in `Debug` output.
    #[allow(dead_code)] timeout_ms: u64,
    /// Total number of delivery attempts per recipient (not additional retries).
    max_retries: u32,
}
impl HttpPlainMessageSender {
    /// Creates a sender with a 30 s request timeout and 3 delivery attempts.
    pub fn new(base_url: String) -> Self {
        Self::with_options(base_url, 30000, 3)
    }

    /// Creates a sender with an explicit request timeout (milliseconds) and
    /// number of delivery attempts per recipient.
    ///
    /// If the configured `reqwest` client cannot be built, a warning is logged
    /// and a default client (no custom timeout or user agent) is used instead.
    pub fn with_options(base_url: String, timeout_ms: u64, max_retries: u32) -> Self {
        #[cfg(feature = "reqwest")]
        {
            let client = reqwest::Client::builder()
                .timeout(std::time::Duration::from_millis(timeout_ms))
                .user_agent("TAP-Node/0.1")
                .build()
                .unwrap_or_else(|e| {
                    // Falling back keeps the sender usable, but the timeout is
                    // lost, so make the degradation visible instead of silent.
                    log::warn!(
                        "Failed to build configured HTTP client ({}); falling back to a default client without timeout or user agent",
                        e
                    );
                    reqwest::Client::new()
                });
            Self {
                base_url,
                client,
                timeout_ms,
                max_retries,
            }
        }
        #[cfg(not(feature = "reqwest"))]
        {
            Self {
                base_url,
                timeout_ms,
                max_retries,
            }
        }
    }

    /// Returns `{base_url}/api/messages/{encoded DID}` for `recipient_did`.
    ///
    /// Trailing `/` characters on the base URL are dropped, and the DID is
    /// percent-encoded so characters such as `:` cannot alter the path.
    pub fn get_endpoint_url(&self, recipient_did: &str) -> String {
        let encoded_did = self.url_encode(recipient_did);
        format!(
            "{}/api/messages/{}",
            self.base_url.trim_end_matches('/'),
            encoded_did
        )
    }

    /// Percent-encodes every non-alphanumeric byte of `text`'s UTF-8 encoding.
    fn url_encode(&self, text: &str) -> String {
        use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
        utf8_percent_encode(text, NON_ALPHANUMERIC).to_string()
    }
}
/// [`PlainMessageSender`] that delivers messages over WebSocket connections to
/// `{base_url}/ws/messages/{percent-encoded DID}`.
///
/// On native targets with the `websocket` feature, one connection per recipient
/// is cached and driven by a background tokio task.
#[derive(Debug)]
pub struct WebSocketPlainMessageSender {
    /// Base URL; an `http(s)://` prefix is mapped to `ws(s)://` when endpoints are built.
    base_url: String,
    /// Per-recipient channel into the task that owns that recipient's socket.
    #[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
    connections: std::sync::Mutex<HashMap<String, tokio::sync::mpsc::Sender<String>>>,
    /// Join handles of the per-recipient connection tasks, keyed by recipient DID.
    /// NOTE(review): only inserted into in the code visible here; never awaited or aborted.
    #[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
    task_handles: std::sync::Mutex<HashMap<String, tokio::task::JoinHandle<()>>>,
}
impl WebSocketPlainMessageSender {
    /// Creates a sender that connects to `{base_url}/ws/messages/{DID}`.
    ///
    /// An `http://` or `https://` prefix on `base_url` is rewritten to `ws://`
    /// or `wss://` when endpoint URLs are built.
    pub fn new(base_url: String) -> Self {
        Self::with_options(base_url)
    }

    /// Creates a sender; takes no options beyond `base_url` and is currently
    /// equivalent to [`Self::new`].
    pub fn with_options(base_url: String) -> Self {
        #[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
        {
            Self {
                base_url,
                connections: std::sync::Mutex::new(HashMap::new()),
                task_handles: std::sync::Mutex::new(HashMap::new()),
            }
        }
        #[cfg(not(all(not(target_arch = "wasm32"), feature = "websocket")))]
        {
            Self { base_url }
        }
    }

    /// Builds the WebSocket endpoint for `recipient_did`.
    ///
    /// Only a leading `https://` / `http://` scheme is mapped to `wss://` /
    /// `ws://`. Occurrences later in the URL (e.g. inside a path) are left
    /// untouched. Trailing `/` characters are dropped and the DID is
    /// percent-encoded.
    fn get_endpoint_url(&self, recipient_did: &str) -> String {
        let ws_base_url = if let Some(rest) = self.base_url.strip_prefix("https://") {
            format!("wss://{}", rest)
        } else if let Some(rest) = self.base_url.strip_prefix("http://") {
            format!("ws://{}", rest)
        } else {
            self.base_url.clone()
        };
        let encoded_did = self.url_encode(recipient_did);
        format!(
            "{}/ws/messages/{}",
            ws_base_url.trim_end_matches('/'),
            encoded_did
        )
    }

    /// Percent-encodes every non-alphanumeric byte of `text`'s UTF-8 encoding.
    fn url_encode(&self, text: &str) -> String {
        use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
        utf8_percent_encode(text, NON_ALPHANUMERIC).to_string()
    }

    /// Returns a channel into the background task that owns the WebSocket
    /// connection to `recipient`, connecting first (30 s timeout) if needed.
    ///
    /// A cached channel whose task has already exited is discarded and a fresh
    /// connection is opened, so one dropped socket does not make every later
    /// send to that recipient fail.
    ///
    /// The background task forwards queued messages as text frames and logs
    /// inbound text frames. It exits when the socket reports an error or end of
    /// stream, when a write fails, or when every sender for its channel has
    /// been dropped (it then closes the socket). Concurrent first calls for the
    /// same recipient may each connect; the superseded connection closes once
    /// its callers drop their sender.
    ///
    /// # Errors
    /// `Error::Dispatch` if the connection fails or times out.
    #[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
    async fn ensure_connection(
        &self,
        recipient: &str,
    ) -> Result<tokio::sync::mpsc::Sender<String>> {
        use futures::sink::SinkExt;
        use futures::stream::StreamExt;
        use tokio_tungstenite::connect_async;
        use tokio_tungstenite::tungstenite::protocol::Message;

        // Reuse a live connection. A closed channel means the owning task has
        // exited, so evict the stale entry and reconnect below.
        {
            let mut connections = self.connections.lock().unwrap();
            if let Some(connection) = connections.get(recipient) {
                if !connection.is_closed() {
                    return Ok(connection.clone());
                }
                log::debug!("Discarding closed WebSocket connection to {}", recipient);
                connections.remove(recipient);
            }
        }

        let endpoint = self.get_endpoint_url(recipient);
        log::info!(
            "Creating new WebSocket connection to {} at {}",
            recipient,
            endpoint
        );
        let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(100);
        let (ws_stream, _) = match tokio::time::timeout(
            std::time::Duration::from_millis(30000),
            connect_async(&endpoint),
        )
        .await
        {
            Ok(Ok(stream)) => stream,
            Ok(Err(e)) => {
                return Err(Error::Dispatch(format!(
                    "Failed to connect to WebSocket endpoint {}: {}",
                    endpoint, e
                )));
            }
            Err(_) => {
                return Err(Error::Dispatch(format!(
                    "Connection to WebSocket endpoint {} timed out",
                    endpoint
                )));
            }
        };
        log::debug!("WebSocket connection established to {}", recipient);

        let (mut write, mut read) = ws_stream.split();
        let recipient_clone = recipient.to_string();
        let handle = tokio::spawn(async move {
            loop {
                tokio::select! {
                    outgoing = rx.recv() => {
                        match outgoing {
                            Some(message) => {
                                log::debug!("Sending message to {} via WebSocket", recipient_clone);
                                if let Err(e) = write.send(Message::Text(message)).await {
                                    log::error!("Failed to send WebSocket message to {}: {}", recipient_clone, e);
                                    // The sink is unusable after a failed write; exit so
                                    // the closed channel triggers a reconnect next time.
                                    break;
                                }
                            }
                            None => {
                                // No sender can reach this task any more; close the
                                // socket instead of keeping it open indefinitely.
                                let _ = write.close().await;
                                break;
                            }
                        }
                    }
                    result = read.next() => {
                        match result {
                            Some(Ok(message)) => {
                                if let Message::Text(text) = message {
                                    log::debug!("Received WebSocket message from {}: {}", recipient_clone, text);
                                }
                            }
                            Some(Err(e)) => {
                                log::error!("WebSocket error from {}: {}", recipient_clone, e);
                                break;
                            }
                            None => {
                                log::info!("WebSocket connection to {} closed", recipient_clone);
                                break;
                            }
                        }
                    }
                }
            }
            log::info!("WebSocket connection to {} terminated", recipient_clone);
        });

        {
            let mut connections = self.connections.lock().unwrap();
            connections.insert(recipient.to_string(), tx.clone());
        }
        {
            let mut task_handles = self.task_handles.lock().unwrap();
            task_handles.insert(recipient.to_string(), handle);
        }
        Ok(tx)
    }
}
#[cfg(all(not(target_arch = "wasm32"), feature = "websocket"))]
#[async_trait]
impl PlainMessageSender for WebSocketPlainMessageSender {
    /// Queues `packed_message` on each recipient's cached WebSocket connection,
    /// opening connections as needed.
    ///
    /// Recipients are processed in order; a failure for one does not stop the
    /// others. Success here means the message was queued for the connection
    /// task, not that it reached the peer.
    ///
    /// # Errors
    /// `Error::Dispatch` if `recipient_dids` is empty, or one combined
    /// `Error::Dispatch` listing every recipient that could not be reached.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        if recipient_dids.is_empty() {
            return Err(Error::Dispatch("No recipients specified".to_string()));
        }

        let mut errors: Vec<String> = Vec::new();
        for did in recipient_dids.iter() {
            log::info!("Sending message to {} via WebSocket", did);
            let outcome = match self.ensure_connection(did).await {
                Err(e) => Err(format!("Failed to establish WebSocket connection: {}", e)),
                Ok(channel) => channel
                    .send(packed_message.clone())
                    .await
                    .map_err(|e| format!("Failed to send message to WebSocket task: {}", e)),
            };
            if let Err(reason) = outcome {
                log::error!("{}", reason);
                errors.push(format!("{}: {}", did, reason));
            }
        }

        if errors.is_empty() {
            Ok(())
        } else {
            Err(Error::Dispatch(format!(
                "Failed to send message to some recipients via WebSocket: {}",
                errors.join("; ")
            )))
        }
    }
}
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[async_trait(?Send)]
impl PlainMessageSender for WebSocketPlainMessageSender {
    /// Opens a browser WebSocket per recipient, waits up to 30 s for it to open,
    /// then sends `packed_message` as a single text frame.
    ///
    /// The socket is left open after sending, and inbound text frames are
    /// logged at debug level. A socket that fails to open (error event or
    /// timeout) is closed.
    ///
    /// NOTE(review): the trait is declared with plain `#[async_trait]`; this
    /// impl uses `?Send`, and async_trait requires both to agree — confirm how
    /// the wasm32 build resolves this.
    ///
    /// # Errors
    /// `Error::Dispatch` if `recipient_dids` is empty or no `window` exists,
    /// or one combined `Error::Dispatch` listing every recipient that failed.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        use wasm_bindgen::prelude::*;
        use wasm_bindgen_futures::JsFuture;
        use web_sys::{MessageEvent, WebSocket};

        // How long to wait for the socket to open before giving up.
        const CONNECT_TIMEOUT_MS: i32 = 30000;

        if recipient_dids.is_empty() {
            return Err(Error::Dispatch("No recipients specified".to_string()));
        }
        let mut failures = Vec::new();
        let window = web_sys::window().ok_or_else(|| {
            Error::Dispatch("Could not get window object in WASM environment".to_string())
        })?;
        for recipient in &recipient_dids {
            let endpoint = self.get_endpoint_url(recipient);
            log::info!(
                "Sending message to {} via WebSocket at {} (WASM)",
                recipient,
                endpoint
            );
            let ws = match WebSocket::new(&endpoint) {
                Ok(ws) => ws,
                Err(err) => {
                    let err_msg = format!("Failed to create WebSocket: {:?}", err);
                    log::error!("{}", err_msg);
                    failures.push((recipient.clone(), err_msg));
                    continue;
                }
            };

            // Settles on whichever happens first: `open` (resolve), `error`
            // (reject) or the connect timeout (reject). The executor runs
            // synchronously, so the handlers are installed before we await.
            let opened = js_sys::Promise::new(
                &mut |resolve: js_sys::Function, reject: js_sys::Function| {
                    let on_open = Closure::once_into_js(move |_: web_sys::Event| {
                        let _ = resolve.call0(&JsValue::NULL);
                    });
                    ws.set_onopen(Some(on_open.unchecked_ref()));

                    let reject_on_error = reject.clone();
                    let on_error = Closure::once_into_js(move |e: web_sys::Event| {
                        let err_msg = format!("WebSocket error: {:?}", e);
                        let _ = reject_on_error.call1(&JsValue::NULL, &JsValue::from_str(&err_msg));
                    });
                    ws.set_onerror(Some(on_error.unchecked_ref()));

                    // Rejecting an already-settled promise is a no-op, so the
                    // timer firing after a successful open is harmless.
                    let on_timeout = Closure::once_into_js(move || {
                        let _ = reject.call1(
                            &JsValue::NULL,
                            &JsValue::from_str("WebSocket connection timed out"),
                        );
                    });
                    if let Err(err) = window.set_timeout_with_callback_and_timeout_and_arguments_0(
                        on_timeout.unchecked_ref(),
                        CONNECT_TIMEOUT_MS,
                    ) {
                        log::warn!("Failed to schedule WebSocket connect timeout: {:?}", err);
                    }
                },
            );

            let on_message = Closure::wrap(Box::new(move |e: MessageEvent| {
                if let Ok(txt) = e.data().dyn_into::<js_sys::JsString>() {
                    log::debug!("Received message: {}", String::from(txt));
                }
            }) as Box<dyn FnMut(MessageEvent)>);
            ws.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
            // The handler must outlive this function because the socket stays open.
            on_message.forget();

            match JsFuture::from(opened).await {
                Ok(_) => {
                    if let Err(err) = ws.send_with_str(&packed_message) {
                        let err_msg = format!("Failed to send WebSocket message: {:?}", err);
                        log::error!("{}", err_msg);
                        failures.push((recipient.clone(), err_msg));
                    }
                }
                Err(err) => {
                    let err_msg = format!("WebSocket connection failed: {:?}", err);
                    log::error!("{}", err_msg);
                    failures.push((recipient.clone(), err_msg));
                    // Do not leave a half-open socket behind.
                    let _ = ws.close();
                }
            }
        }
        if !failures.is_empty() {
            let failure_messages = failures
                .iter()
                .map(|(did, err)| format!("{}: {}", did, err))
                .collect::<Vec<_>>()
                .join("; ");
            return Err(Error::Dispatch(format!(
                "Failed to send message to some recipients via WebSocket: {}",
                failure_messages
            )));
        }
        Ok(())
    }
}
#[cfg(not(any(
    all(not(target_arch = "wasm32"), feature = "websocket"),
    all(target_arch = "wasm32", feature = "wasm")
)))]
#[async_trait]
impl PlainMessageSender for WebSocketPlainMessageSender {
    /// Fallback used when no WebSocket backend is compiled in.
    ///
    /// Only logs the endpoint and payload each recipient would have received,
    /// then warns and returns `Ok(())` without opening any connection.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        let targets = recipient_dids
            .iter()
            .map(|did| (did, self.get_endpoint_url(did)));
        for (did, endpoint) in targets {
            log::info!(
                "Would send message to {} via WebSocket at {} (WebSocket not available)",
                did,
                endpoint
            );
            log::debug!("PlainMessage content: {}", packed_message);
        }
        log::warn!("WebSocket sender is running without WebSocket features enabled. No actual WebSocket connections will be made.");
        Ok(())
    }
}
#[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))]
#[async_trait]
impl PlainMessageSender for HttpPlainMessageSender {
    /// POSTs `packed_message` (as `application/didcomm-message+json`) to each
    /// recipient's endpoint from [`HttpPlainMessageSender::get_endpoint_url`].
    ///
    /// Each recipient gets up to `max_retries` attempts in total, and always at
    /// least one even when `max_retries` is 0. Attempts are separated by
    /// exponential backoff (200 ms, 400 ms, 800 ms, … capped at 30 s). HTTP 400
    /// and 404 responses are treated as permanent and are not retried.
    /// Recipients are processed in order; one failure does not stop the others.
    ///
    /// # Errors
    /// `Error::Dispatch` if `recipient_dids` is empty, or one combined
    /// `Error::Dispatch` naming every failed recipient with its last error.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        // Upper bound on the sleep between two attempts.
        const MAX_BACKOFF_MS: u64 = 30_000;

        if recipient_dids.is_empty() {
            return Err(Error::Dispatch("No recipients specified".to_string()));
        }
        // `max_retries` counts total attempts; without the floor a value of 0
        // would skip sending entirely and report an "Unknown error".
        let max_attempts = self.max_retries.max(1);
        let mut failures = Vec::new();
        for recipient in &recipient_dids {
            let endpoint = self.get_endpoint_url(recipient);
            log::info!("Sending message to {} via HTTP at {}", recipient, endpoint);
            let mut attempt: u32 = 0;
            let mut success = false;
            let mut last_error = None;
            while attempt < max_attempts && !success {
                attempt += 1;
                if attempt > 1 {
                    // 100 ms * 2^(attempt - 1), saturating so a large retry count
                    // cannot overflow (which panics in debug builds).
                    let backoff_ms = 2_u64
                        .saturating_pow(attempt - 1)
                        .saturating_mul(100)
                        .min(MAX_BACKOFF_MS);
                    tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                }
                match self
                    .client
                    .post(&endpoint)
                    .header("Content-Type", "application/didcomm-message+json")
                    .body(packed_message.clone())
                    .send()
                    .await
                {
                    Ok(response) => {
                        if response.status().is_success() {
                            log::debug!("Successfully sent message to {}", recipient);
                            success = true;
                        } else {
                            let status = response.status();
                            let body = response.text().await.unwrap_or_default();
                            log::warn!(
                                "Failed to send message to {} (attempt {}/{}): HTTP {} - {}",
                                recipient,
                                attempt,
                                max_attempts,
                                status,
                                body
                            );
                            last_error = Some(format!("HTTP error: {} - {}", status, body));
                            // Malformed request / unknown recipient: retrying cannot help.
                            if status.as_u16() == 404 || status.as_u16() == 400 {
                                break;
                            }
                        }
                    }
                    Err(err) => {
                        log::warn!(
                            "Failed to send message to {} (attempt {}/{}): {}",
                            recipient,
                            attempt,
                            max_attempts,
                            err
                        );
                        last_error = Some(format!("Request error: {}", err));
                    }
                }
            }
            if !success {
                failures.push((
                    recipient.clone(),
                    last_error.unwrap_or_else(|| "Unknown error".to_string()),
                ));
            }
        }
        if !failures.is_empty() {
            let failure_messages = failures
                .iter()
                .map(|(did, err)| format!("{}: {}", did, err))
                .collect::<Vec<_>>()
                .join("; ");
            return Err(Error::Dispatch(format!(
                "Failed to send message to some recipients: {}",
                failure_messages
            )));
        }
        Ok(())
    }
}
#[cfg(all(not(target_arch = "wasm32"), not(feature = "reqwest")))]
#[async_trait]
impl PlainMessageSender for HttpPlainMessageSender {
    /// Native fallback used when the `reqwest` feature is disabled.
    ///
    /// Logs the endpoint and payload for each recipient, warns, and returns
    /// `Ok(())` without issuing any HTTP request.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        let targets = recipient_dids
            .iter()
            .map(|did| (did, self.get_endpoint_url(did)));
        for (did, endpoint) in targets {
            log::info!(
                "Would send message to {} via HTTP at {} (reqwest not available)",
                did,
                endpoint
            );
            log::debug!("PlainMessage content: {}", packed_message);
        }
        log::warn!("HTTP sender is running without reqwest feature enabled. No actual HTTP requests will be made.");
        Ok(())
    }
}
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
#[async_trait(?Send)]
impl PlainMessageSender for HttpPlainMessageSender {
    /// Browser `fetch` implementation: POSTs `packed_message` in CORS mode with
    /// `Content-Type: application/didcomm-message+json` to each recipient.
    ///
    /// Each recipient gets up to `max_retries` attempts in total, and always at
    /// least one even when `max_retries` is 0. Attempts are separated by
    /// exponential backoff (200 ms, 400 ms, … capped at 30 s) implemented via
    /// `window.setTimeout`. HTTP 400 and 404 responses are not retried.
    ///
    /// NOTE(review): the trait is declared with plain `#[async_trait]`; this
    /// impl uses `?Send`, and async_trait requires both to agree — confirm how
    /// the wasm32 build resolves this.
    ///
    /// # Errors
    /// `Error::Dispatch` if `recipient_dids` is empty or no `window` exists,
    /// or one combined `Error::Dispatch` naming every failed recipient.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        use wasm_bindgen::prelude::*;
        use wasm_bindgen_futures::JsFuture;
        use web_sys::{Request, RequestInit, RequestMode, Response};

        // Upper bound on the sleep between two attempts; also keeps the value
        // well inside the `i32` range that `setTimeout` accepts.
        const MAX_BACKOFF_MS: u64 = 30_000;

        if recipient_dids.is_empty() {
            return Err(Error::Dispatch("No recipients specified".to_string()));
        }
        // `max_retries` counts total attempts; without the floor a value of 0
        // would skip sending entirely and report an "Unknown error".
        let max_attempts = self.max_retries.max(1);
        let mut failures = Vec::new();
        let window = web_sys::window().ok_or_else(|| {
            Error::Dispatch("Could not get window object in WASM environment".to_string())
        })?;
        for recipient in &recipient_dids {
            let endpoint = self.get_endpoint_url(recipient);
            log::info!(
                "Sending message to {} via HTTP at {} (WASM)",
                recipient,
                endpoint
            );
            let mut attempt: u32 = 0;
            let mut success = false;
            let mut last_error = None;
            while attempt < max_attempts && !success {
                attempt += 1;
                if attempt > 1 {
                    // 100 ms * 2^(attempt - 1), saturating so a large retry count
                    // cannot overflow; the cap makes the `as i32` below lossless.
                    let backoff_ms = 2_u64
                        .saturating_pow(attempt - 1)
                        .saturating_mul(100)
                        .min(MAX_BACKOFF_MS);
                    let promise = js_sys::Promise::new(&mut |resolve, _| {
                        let closure = Closure::once_into_js(move || {
                            resolve.call0(&JsValue::NULL).unwrap();
                        });
                        window
                            .set_timeout_with_callback_and_timeout_and_arguments_0(
                                closure.as_ref().unchecked_ref(),
                                backoff_ms as i32,
                            )
                            .unwrap();
                    });
                    let _ = JsFuture::from(promise).await;
                }
                let mut opts = RequestInit::new();
                opts.method("POST");
                opts.mode(RequestMode::Cors);
                opts.body(Some(&JsValue::from_str(&packed_message)));
                let request = match Request::new_with_str_and_init(&endpoint, &opts) {
                    Ok(req) => req,
                    Err(err) => {
                        let err_msg = format!("Failed to create request: {:?}", err);
                        log::warn!("{}", err_msg);
                        last_error = Some(err_msg);
                        continue;
                    }
                };
                if let Err(err) = request
                    .headers()
                    .set("Content-Type", "application/didcomm-message+json")
                {
                    let err_msg = format!("Failed to set headers: {:?}", err);
                    log::warn!("{}", err_msg);
                    last_error = Some(err_msg);
                    continue;
                }
                let resp_promise = window.fetch_with_request(&request);
                let resp_jsvalue = match JsFuture::from(resp_promise).await {
                    Ok(val) => val,
                    Err(err) => {
                        let err_msg = format!("Fetch error: {:?}", err);
                        log::warn!(
                            "Failed to send message to {} (attempt {}/{}): {}",
                            recipient,
                            attempt,
                            max_attempts,
                            err_msg
                        );
                        last_error = Some(err_msg);
                        continue;
                    }
                };
                let response: Response = match resp_jsvalue.dyn_into() {
                    Ok(resp) => resp,
                    Err(err) => {
                        let err_msg = format!("Failed to convert response: {:?}", err);
                        log::warn!("{}", err_msg);
                        last_error = Some(err_msg);
                        continue;
                    }
                };
                if response.ok() {
                    log::debug!("Successfully sent message to {}", recipient);
                    success = true;
                } else {
                    let status = response.status();
                    let body_promise = response.text();
                    let body = match JsFuture::from(body_promise).await {
                        Ok(text_jsval) => text_jsval.as_string().unwrap_or_default(),
                        Err(_) => String::from("[Could not read response body]"),
                    };
                    let err_msg = format!("HTTP error: {} - {}", status, body);
                    log::warn!(
                        "Failed to send message to {} (attempt {}/{}): {}",
                        recipient,
                        attempt,
                        max_attempts,
                        err_msg
                    );
                    last_error = Some(err_msg);
                    // Malformed request / unknown recipient: retrying cannot help.
                    if status == 404 || status == 400 {
                        break;
                    }
                }
            }
            if !success {
                failures.push((
                    recipient.clone(),
                    last_error.unwrap_or_else(|| "Unknown error".to_string()),
                ));
            }
        }
        if !failures.is_empty() {
            let failure_messages = failures
                .iter()
                .map(|(did, err)| format!("{}: {}", did, err))
                .collect::<Vec<_>>()
                .join("; ");
            return Err(Error::Dispatch(format!(
                "Failed to send message to some recipients: {}",
                failure_messages
            )));
        }
        Ok(())
    }
}
#[cfg(all(target_arch = "wasm32", not(feature = "wasm")))]
// The trait is declared with plain `#[async_trait]` (Send futures), and
// async_trait requires impls to use the same mode; a `?Send` impl here fails to
// compile. This body holds nothing non-Send, so it matches the trait.
#[async_trait]
impl PlainMessageSender for HttpPlainMessageSender {
    /// WASM fallback used when the `wasm` feature is disabled.
    ///
    /// Logs the endpoint and payload for each recipient, warns, and returns
    /// `Ok(())` without issuing any HTTP request.
    async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
        for recipient in &recipient_dids {
            let endpoint = self.get_endpoint_url(recipient);
            log::info!(
                "Would send message to {} via HTTP at {} (WASM without web-sys)",
                recipient,
                endpoint
            );
            log::debug!("PlainMessage content: {}", packed_message);
        }
        log::warn!("HTTP sender is running in WASM without the web-sys feature enabled. No actual HTTP requests will be made.");
        Ok(())
    }
}
/// [`HttpPlainMessageSender`] wrapper that records every delivery in [`Storage`]
/// (creating a record per recipient and updating its status after sending).
#[derive(Debug)]
pub struct HttpPlainMessageSenderWithTracking {
    /// Performs the actual HTTP delivery.
    http_sender: HttpPlainMessageSender,
    /// Where delivery records are created and updated.
    storage: Arc<Storage>,
}
impl HttpPlainMessageSenderWithTracking {
    /// Creates a tracking sender using [`HttpPlainMessageSender::new`]'s
    /// default timeout and attempt count.
    pub fn new(base_url: String, storage: Arc<Storage>) -> Self {
        let http_sender = HttpPlainMessageSender::new(base_url);
        Self {
            http_sender,
            storage,
        }
    }

    /// Creates a tracking sender with an explicit request timeout
    /// (milliseconds) and number of delivery attempts per recipient.
    pub fn with_options(
        base_url: String,
        timeout_ms: u64,
        max_retries: u32,
        storage: Arc<Storage>,
    ) -> Self {
        let http_sender = HttpPlainMessageSender::with_options(base_url, timeout_ms, max_retries);
        Self {
            http_sender,
            storage,
        }
    }
}
#[async_trait]
impl PlainMessageSender for HttpPlainMessageSenderWithTracking {
async fn send(&self, packed_message: String, recipient_dids: Vec<String>) -> Result<()> {
if recipient_dids.is_empty() {
return Err(Error::Dispatch("No recipients specified".to_string()));
}
let message_id = format!("msg_{}", uuid::Uuid::new_v4());
let mut delivery_ids = Vec::new();
for recipient in &recipient_dids {
let delivery_url = Some(self.http_sender.get_endpoint_url(recipient));
match self
.storage
.create_delivery(
&message_id,
&packed_message,
recipient,
delivery_url.as_deref(),
DeliveryType::Https,
)
.await
{
Ok(delivery_id) => {
delivery_ids.push((recipient.clone(), delivery_id));
log::debug!(
"Created delivery record {} for message {} to {}",
delivery_id,
message_id,
recipient
);
}
Err(e) => {
log::error!("Failed to create delivery record for {}: {}", recipient, e);
delivery_ids.push((recipient.clone(), -1)); }
}
}
let delivery_result = self
.http_sender
.send(packed_message, recipient_dids.clone())
.await;
for (_recipient, delivery_id) in delivery_ids {
if delivery_id == -1 {
continue; }
match &delivery_result {
Ok(_) => {
if let Err(e) = self
.storage
.update_delivery_status(
delivery_id,
DeliveryStatus::Success,
Some(200), None,
)
.await
{
log::error!(
"Failed to update delivery record {} to success: {}",
delivery_id,
e
);
} else {
log::debug!("Updated delivery record {} to success", delivery_id);
}
}
Err(e) => {
let error_msg = e.to_string();
let http_status_code = if error_msg.contains("HTTP error: ") {
error_msg
.split("HTTP error: ")
.nth(1)
.and_then(|s| s.split(' ').next())
.and_then(|s| s.parse::<i32>().ok())
} else {
None
};
if let Err(e) = self
.storage
.update_delivery_status(
delivery_id,
DeliveryStatus::Failed,
http_status_code,
Some(&error_msg),
)
.await
{
log::error!(
"Failed to update delivery record {} to failed: {}",
delivery_id,
e
);
} else {
log::debug!(
"Updated delivery record {} to failed with error: {}",
delivery_id,
error_msg
);
}
if let Err(e) = self
.storage
.increment_delivery_retry_count(delivery_id)
.await
{
log::error!(
"Failed to increment retry count for delivery record {}: {}",
delivery_id,
e
);
}
}
}
}
delivery_result
}
}