use crate::wallet::wire::{WalletCall, WalletWire};
use crate::Error;
use reqwest::Client;
/// A `WalletWire` transport that forwards binary wallet-wire frames to a wallet over HTTP.
///
/// Each frame's payload is POSTed to `{base_url}/{method_name}`, where `method_name` is
/// derived from the frame's call code.
#[derive(Debug, Clone)]
pub struct HttpWalletWire {
    /// HTTP client used to issue every request.
    client: Client,
    /// Base URL that wallet method names are appended to when building request URLs.
    base_url: String,
    /// Fallback originator for the `Origin` header, used when a frame carries none.
    originator: Option<String>,
}
impl HttpWalletWire {
    /// Creates a wire backed by a freshly constructed default [`Client`].
    ///
    /// `originator` is used for the `Origin` header whenever a frame does not carry its own.
    /// `base_url` falls back to `super::DEFAULT_WIRE_URL` when `None`.
    pub fn new(originator: Option<String>, base_url: Option<String>) -> Self {
        // Delegate so the default-URL logic lives in exactly one place.
        Self::with_client(Client::new(), originator, base_url)
    }

    /// Creates a wire that issues its requests through the caller-supplied `client`.
    ///
    /// `originator` and `base_url` behave exactly as in [`HttpWalletWire::new`]: `base_url`
    /// falls back to `super::DEFAULT_WIRE_URL` when `None`.
    pub fn with_client(
        client: Client,
        originator: Option<String>,
        base_url: Option<String>,
    ) -> Self {
        Self {
            client,
            base_url: base_url.unwrap_or_else(|| super::DEFAULT_WIRE_URL.to_string()),
            originator,
        }
    }

    /// Returns the base URL exactly as it was configured.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Returns the fallback originator configured for this wire, if any.
    pub fn originator(&self) -> Option<&str> {
        self.originator.as_deref()
    }
}
#[async_trait::async_trait]
impl WalletWire for HttpWalletWire {
async fn transmit_to_wallet(&self, message: &[u8]) -> Result<Vec<u8>, Error> {
if message.is_empty() {
return Err(Error::WalletError("empty message".to_string()));
}
let call_code = message[0];
let call = WalletCall::try_from(call_code)?;
let call_name = call.method_name();
if message.len() < 2 {
return Err(Error::WalletError(
"message too short for originator length".to_string(),
));
}
let originator_len = message[1] as usize;
if message.len() < 2 + originator_len {
return Err(Error::WalletError(
"message too short for originator".to_string(),
));
}
let originator = if originator_len > 0 {
String::from_utf8(message[2..2 + originator_len].to_vec())
.map_err(|_| Error::WalletError("invalid originator UTF-8".to_string()))?
} else {
String::new()
};
let payload = &message[2 + originator_len..];
let url = format!("{}/{}", self.base_url, call_name);
let mut request = self
.client
.post(&url)
.header("Content-Type", "application/octet-stream")
.body(payload.to_vec());
let origin = if !originator.is_empty() {
Some(to_origin_header(&originator))
} else {
self.originator.as_ref().map(|o| to_origin_header(o))
};
if let Some(origin) = origin {
request = request.header("Origin", origin);
}
let response = request
.send()
.await
.map_err(|e| Error::WalletError(format!("HTTP request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Error::WalletError(format!(
"HTTP {} for {}: {}",
status, call_name, body
)));
}
let bytes = response
.bytes()
.await
.map_err(|e| Error::WalletError(format!("failed to read response body: {}", e)))?;
Ok(bytes.to_vec())
}
}
/// Converts an originator into a value suitable for the HTTP `Origin` header.
///
/// An originator that already starts with an `http://` or `https://` scheme is returned
/// unchanged. The scheme is matched case-insensitively, because URL schemes are
/// case-insensitive. Anything else is treated as a bare host and gets an `http://` prefix.
fn to_origin_header(originator: &str) -> String {
    const SCHEMES: [&str; 2] = ["http://", "https://"];
    // `str::get` returns None for too-short input or a non-char-boundary cut, and neither
    // case can be an ASCII scheme prefix anyway.
    let has_scheme = SCHEMES.iter().any(|scheme| {
        matches!(
            originator.get(..scheme.len()),
            Some(prefix) if prefix.eq_ignore_ascii_case(scheme)
        )
    });
    if has_scheme {
        originator.to_string()
    } else {
        format!("http://{}", originator)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Bare hosts gain an http:// prefix; explicit http/https originators pass through.
    #[test]
    fn test_to_origin_header() {
        let cases = [
            ("example.com", "http://example.com"),
            ("http://example.com", "http://example.com"),
            ("https://example.com", "https://example.com"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_origin_header(input), expected);
        }
    }

    // With no base URL supplied, the wire falls back to the module default.
    #[test]
    fn test_default_url() {
        let default_wire = HttpWalletWire::new(None, None);
        assert_eq!(default_wire.base_url(), "http://localhost:3301");
    }

    // An explicit base URL is kept exactly as given.
    #[test]
    fn test_custom_url() {
        let custom = String::from("https://wallet.example.com");
        let custom_wire = HttpWalletWire::new(None, Some(custom));
        assert_eq!(custom_wire.base_url(), "https://wallet.example.com");
    }

    // The configured originator is exposed as a borrowed &str, or None when unset.
    #[test]
    fn test_originator() {
        let with_originator =
            HttpWalletWire::new(Some(String::from("myapp.example.com")), None);
        assert_eq!(with_originator.originator(), Some("myapp.example.com"));

        let without_originator = HttpWalletWire::new(None, None);
        assert!(without_originator.originator().is_none());
    }
}