use crate::envelope;
use crate::error::SoapError;
use reqwest::Client as HttpClient;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
/// A single SOAP operation: its request/response payload types plus the
/// constants needed to address it.
///
/// Implementors normally supply only the associated types and constants; the
/// provided methods serialize the request body and decode the response,
/// including SOAP fault detection.
pub trait SoapOperation {
    /// Payload serialized into the SOAP body.
    type Request;
    /// Payload decoded from the SOAP body of a non-fault response.
    type Response;
    /// Action value returned by [`SoapOperation::build_request_body`]; the
    /// client places it in the envelope's `<Action>` header.
    const ACTION: &'static str;
    /// Endpoint associated with this operation.
    /// NOTE(review): not read by `SoapClient::call` in this file, which posts to
    /// the client's own endpoint — confirm whether other callers use it.
    const ENDPOINT: &'static str;
    /// Name of the root element the request is serialized under.
    const BODY_ELEMENT: &'static str;

    /// Serializes `request` under a [`SoapOperation::BODY_ELEMENT`] root element.
    ///
    /// Returns `(action, body_xml)`, where `action` is an owned copy of
    /// [`SoapOperation::ACTION`].
    ///
    /// # Errors
    ///
    /// Returns the `quick_xml` serializer error when `request` cannot be
    /// serialized.
    fn build_request_body(
        &self,
        request: &Self::Request,
    ) -> Result<(String, String), quick_xml::se::SeError>
    where
        Self::Request: Serialize,
    {
        let body_xml = quick_xml::se::to_string_with_root(Self::BODY_ELEMENT, request)?;
        Ok((String::from(Self::ACTION), body_xml))
    }

    /// Decodes a raw SOAP response document into [`SoapOperation::Response`].
    ///
    /// # Errors
    ///
    /// * `SoapError::SoapFault` when the document is recognised as a SOAP fault.
    /// * `SoapError::DeserializeResponse` when the fault or the response body
    ///   cannot be deserialized.
    fn parse_response(&self, response_xml: &str) -> Result<Self::Response, SoapError>
    where
        Self::Response: DeserializeOwned,
    {
        if !is_soap_fault(response_xml) {
            return envelope::deserialize_response::<Self::Response>(response_xml)
                .map_err(|err| SoapError::DeserializeResponse(Box::new(err)));
        }
        let (code, message) = parse_soap_error(response_xml)?;
        Err(SoapError::SoapFault { code, message })
    }
}
/// Namespace URIs of the SOAP 1.1 and SOAP 1.2 envelope schemas.
const SOAP_ENVELOPE_NAMESPACES: [&str; 2] = [
    "http://schemas.xmlsoap.org/soap/envelope/",
    "http://www.w3.org/2003/05/soap-envelope",
];

/// Heuristically reports whether `xml` contains a SOAP `Fault` element.
///
/// An element counts as a fault when its local name is exactly `Fault` and
/// either:
/// * it carries a namespace prefix that is `soap` or is declared in the
///   document (`xmlns:p="…"` / `xmlns:p='…'`) as a SOAP envelope namespace —
///   so `s:`, `soapenv:`, `SOAP-ENV:` etc. are recognised, not just `soap:`; or
/// * it is unprefixed and its first attribute is a default-namespace
///   declaration (`<Fault xmlns=…`).
///
/// This is a text scan, not a full XML parse: `Fault` tags inside comments or
/// CDATA sections are also matched.
fn is_soap_fault(xml: &str) -> bool {
    // Each chunk after a '<' starts with a tag name (or '/', '?', '!' for
    // closing tags, declarations and comments, which never match below).
    xml.split('<').skip(1).any(|tag| {
        let name_len = tag
            .find(|c: char| c.is_whitespace() || c == '>' || c == '/')
            .unwrap_or(tag.len());
        let (name, rest) = tag.split_at(name_len);
        match name.split_once(':') {
            Some((prefix, "Fault")) => prefix_is_soap_envelope(xml, prefix),
            Some(_) => false,
            None => name == "Fault" && rest.trim_start().starts_with("xmlns="),
        }
    })
}

/// Returns whether `prefix` is `soap` or is bound in `xml` to a SOAP envelope
/// namespace.
fn prefix_is_soap_envelope(xml: &str, prefix: &str) -> bool {
    prefix == "soap"
        || SOAP_ENVELOPE_NAMESPACES.iter().any(|ns| {
            xml.contains(&format!("xmlns:{prefix}=\"{ns}\""))
                || xml.contains(&format!("xmlns:{prefix}='{ns}'"))
        })
}
/// Extracts `(faultcode, faultstring)` from a SOAP 1.1 fault document.
///
/// Missing elements fall back to `"unknown"` (code) and `"no details"`
/// (message).
///
/// # Errors
///
/// Returns `SoapError::DeserializeResponse` when `envelope::deserialize_response`
/// cannot decode the document into the fault shape below.
fn parse_soap_error(xml: &str) -> Result<(String, String), SoapError> {
    #[derive(Debug, serde::Deserialize)]
    struct Fault {
        faultcode: Option<String>,
        faultstring: Option<FaultString>,
    }

    // Separate struct so the text is read via `$text` even when the element
    // carries attributes.
    #[derive(Debug, serde::Deserialize)]
    struct FaultString {
        #[serde(rename = "$text")]
        value: String,
    }

    let Fault {
        faultcode,
        faultstring,
    } = match envelope::deserialize_response::<Fault>(xml) {
        Ok(parsed) => parsed,
        Err(err) => return Err(SoapError::DeserializeResponse(Box::new(err))),
    };

    let code = faultcode.unwrap_or_else(|| String::from("unknown"));
    let message = match faultstring {
        Some(FaultString { value }) => value,
        None => String::from("no details"),
    };
    Ok((code, message))
}
pub struct SoapClient {
http: HttpClient,
endpoint: String,
default_headers: HashMap<String, String>,
}
impl SoapClient {
pub fn new(endpoint: impl Into<String>) -> Result<Self, SoapError> {
let endpoint = endpoint.into();
if !endpoint.starts_with("http://") && !endpoint.starts_with("https://") {
return Err(SoapError::http(reqwest::StatusCode::BAD_REQUEST));
}
Ok(Self {
http: HttpClient::new(),
endpoint,
default_headers: HashMap::new(),
})
}
pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.default_headers.insert(key.into(), value.into());
self
}
pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
self.default_headers.extend(headers);
self
}
pub async fn call<O>(
&self,
operation: &O,
request: &O::Request,
) -> Result<O::Response, SoapError>
where
O: SoapOperation + Send,
O::Request: Serialize + Sync,
O::Response: DeserializeOwned,
{
let (action, body_xml) = operation
.build_request_body(request)
.map_err(SoapError::serialize_request)?;
let xml_body = build_envelope(&action, &body_xml);
let mut request_builder = self
.http
.post(&self.endpoint)
.header("Content-Type", "text/xml; charset=utf-8");
for (key, value) in &self.default_headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
let response = request_builder.body(xml_body).send().await?;
if !response.status().is_success() {
let status = response.status();
return Err(SoapError::http(status));
}
let text = response.text().await?;
operation.parse_response(&text)
}
pub fn endpoint(&self) -> &str {
&self.endpoint
}
}
/// Shows the endpoint and default headers without leaking credentials.
///
/// Values of headers whose names look credential-bearing (containing e.g.
/// `auth`, `token`, `key`, `cookie`) are printed as `<redacted>`. Headers are
/// sorted by name so the output does not depend on `HashMap` iteration order.
impl std::fmt::Debug for SoapClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Case-insensitive substrings that mark a header value as secret.
        const SENSITIVE_MARKERS: [&str; 8] = [
            "auth",
            "token",
            "secret",
            "password",
            "cookie",
            "key",
            "session",
            "credential",
        ];

        let mut ds = f.debug_struct("SoapClient");
        ds.field("endpoint", &self.endpoint);
        if self.default_headers.is_empty() {
            ds.field("default_headers", &"<empty>");
        } else {
            let mut pairs: Vec<(&str, &str)> = self
                .default_headers
                .iter()
                .map(|(key, value)| {
                    let lower = key.to_ascii_lowercase();
                    let shown = if SENSITIVE_MARKERS.iter().any(|m| lower.contains(m)) {
                        "<redacted>"
                    } else {
                        value.as_str()
                    };
                    (key.as_str(), shown)
                })
                .collect();
            pairs.sort_unstable_by(|a, b| a.0.cmp(b.0));
            ds.field("default_headers", &pairs);
        }
        ds.finish()
    }
}
/// Wraps an already-serialized body in a SOAP 1.1 envelope.
///
/// `action` becomes the text of a WS-Addressing `<Action>` header marked
/// `soap:mustUnderstand="true"`. It is XML-escaped, so action URIs containing
/// `&`, `<` or `>` still yield well-formed XML. `body_xml` is inserted verbatim
/// and must already be well-formed XML.
fn build_envelope(action: &str, body_xml: &str) -> String {
    // Escape `&` first so the entities introduced below are not re-escaped.
    let action = action
        .replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;");
    format!(
        r#"<?xml version="1.0" encoding="UTF-8"?>
<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<soap:Header>
<Action soap:mustUnderstand="true" xmlns="http://schemas.xmlsoap.org/ws/2004/08/addressing">{action}</Action>
</soap:Header>
<soap:Body>
{body}
</soap:Body>
</soap:Envelope>"#,
        body = body_xml,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn creates_client_with_valid_url() {
        let url = "https://example.com/soap";
        let client = SoapClient::new(url).expect("an https URL should be accepted");
        assert_eq!(client.endpoint(), url);
    }

    #[test]
    fn rejects_invalid_url() {
        assert!(SoapClient::new("not-a-url").is_err());
    }

    #[test]
    fn builds_envelope_for_operation() {
        let envelope_xml = build_envelope(
            "GetTemperature",
            "<req:GetTemperature><lat>40</lat></req:GetTemperature>",
        );
        let expected_fragments = [
            "<soap:Envelope",
            "<Action",
            ">GetTemperature</Action",
            "<soap:Body",
            "<req:GetTemperature>",
        ];
        for fragment in expected_fragments {
            assert!(
                envelope_xml.contains(fragment),
                "envelope is missing {fragment:?}:\n{envelope_xml}"
            );
        }
    }

    #[test]
    fn client_is_debuggable() {
        let client = SoapClient::new("https://example.com")
            .unwrap()
            .with_header("X-Custom", "header");
        let rendered = format!("{:?}", client);
        assert!(rendered.contains("SoapClient"));
    }

    #[test]
    fn is_soap_fault_detects_fault() {
        for fault_xml in ["<soap:Fault>code</soap:Fault>", "<Fault xmlns=\"...\">msg</Fault>"] {
            assert!(is_soap_fault(fault_xml), "not detected: {fault_xml}");
        }
        let ok_xml = "<GetTempResponse><temp>72</temp></GetTempResponse>";
        assert!(!is_soap_fault(ok_xml));
    }

    // Named like the function under test, so the call goes through `super::`.
    #[test]
    fn parse_soap_error() {
        let fault_xml = r#"<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<soap:Body>
<soap:Fault xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<faultcode>Server</faultcode>
<faultstring>Invalid credentials</faultstring>
</soap:Fault>
</soap:Body>
</soap:Envelope>"#;
        let (code, message) = super::parse_soap_error(fault_xml).unwrap();
        assert_eq!(
            (code.as_str(), message.as_str()),
            ("Server", "Invalid credentials")
        );
    }

    #[test]
    fn extract_body_from_soap_envelope() {
        let response_xml = r#"<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<soap:Body>
<GetWeatherResponse>
<temperature>72</temperature>
</GetWeatherResponse>
</soap:Body>
</soap:Envelope>"#;
        let body = envelope::extract_body(response_xml).unwrap();
        for needle in ["GetWeatherResponse", "72"] {
            assert!(body.contains(needle));
        }
    }
}