use crate::envelope::{self, SoapVersion};
use crate::error::SoapError;
use reqwest::Client as HttpClient;
use serde::{de::DeserializeOwned, Serialize};
use std::collections::HashMap;
/// A single SOAP operation.
///
/// Describes how a request payload is serialized into a SOAP body element and how
/// a raw response envelope is turned back into [`SoapOperation::Response`].
pub trait SoapOperation {
    /// Payload type serialized into the SOAP body.
    type Request;
    /// Type deserialized from the SOAP response body.
    type Response;
    /// SOAP action identifier; returned as the first element of
    /// [`SoapOperation::build_request_body`]'s result.
    const ACTION: &'static str;
    /// Endpoint associated with this operation.
    ///
    /// NOTE(review): not read by `SoapClient::call` in this module, which posts to the
    /// client's own endpoint — confirm whether other callers use it.
    const ENDPOINT: &'static str;
    /// Root element name used when serializing the request body.
    const BODY_ELEMENT: &'static str;
    /// SOAP protocol version used for this operation; defaults to SOAP 1.1.
    const VERSION: SoapVersion = SoapVersion::V11;

    /// Serializes `request` under a root element named [`SoapOperation::BODY_ELEMENT`].
    ///
    /// Returns `(action, body_xml)`, where `action` is an owned copy of
    /// [`SoapOperation::ACTION`] and `body_xml` is the serialized request.
    ///
    /// # Errors
    ///
    /// Propagates any serialization error raised by `quick_xml`.
    fn build_request_body(
        &self,
        request: &Self::Request,
    ) -> Result<(String, String), quick_xml::se::SeError>
    where
        Self::Request: Serialize,
    {
        let body_xml = quick_xml::se::to_string_with_root(Self::BODY_ELEMENT, request)?;
        Ok((String::from(Self::ACTION), body_xml))
    }

    /// Converts a raw SOAP response envelope into [`SoapOperation::Response`].
    ///
    /// # Errors
    ///
    /// - `SoapError::SoapFault` when `envelope::is_soap_fault` reports a fault and the
    ///   fault's code and message can be extracted.
    /// - `SoapError::DeserializeResponse` when a detected fault cannot be parsed, or when
    ///   the body cannot be deserialized into `Self::Response`.
    fn parse_response(&self, response_xml: &str) -> Result<Self::Response, SoapError>
    where
        Self::Response: DeserializeOwned,
    {
        // Common path first: no fault, so deserialize the body directly.
        if !envelope::is_soap_fault(response_xml) {
            return envelope::deserialize_response::<Self::Response>(response_xml)
                .map_err(|e| SoapError::DeserializeResponse(Box::new(e)));
        }

        let (code, message) = envelope::parse_soap_fault(response_xml)
            .map_err(|e| SoapError::DeserializeResponse(Box::new(e)))?;
        Err(SoapError::SoapFault { code, message })
    }
}
/// HTTP client that invokes [`SoapOperation`]s against a single SOAP endpoint.
///
/// Construct with [`SoapClient::new`]; extra headers can be attached with
/// [`SoapClient::with_header`] / [`SoapClient::with_headers`].
pub struct SoapClient {
    /// Target URL; checked with `reqwest::Url::parse` in [`SoapClient::new`].
    endpoint: String,
    /// Extra headers applied to requests sent by [`SoapClient::call`].
    default_headers: HashMap<String, String>,
    /// Underlying `reqwest` client used to send requests.
    http: HttpClient,
}
impl SoapClient {
    /// Creates a client that posts to `endpoint` using a fresh `reqwest::Client` and
    /// no default headers.
    ///
    /// # Errors
    ///
    /// Returns `SoapError::http_status(StatusCode::BAD_REQUEST)` when `endpoint` is not
    /// accepted by `reqwest::Url::parse`.
    pub fn new(endpoint: impl Into<String>) -> Result<Self, SoapError> {
        let endpoint = endpoint.into();
        // NOTE(review): an unparseable URL is reported as a synthetic HTTP 400 even though
        // no request was sent; a dedicated `SoapError` variant would be clearer.
        reqwest::Url::parse(&endpoint)
            .map_err(|_| SoapError::http_status(reqwest::StatusCode::BAD_REQUEST))?;
        Ok(Self {
            http: HttpClient::new(),
            endpoint,
            default_headers: HashMap::new(),
        })
    }

    /// Adds a default header to send with every [`SoapClient::call`].
    ///
    /// Replaces any value previously stored under the exact same (case-sensitive) key.
    #[must_use]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.default_headers.insert(key.into(), value.into());
        self
    }

    /// Adds several default headers at once.
    ///
    /// Entries in `headers` replace existing entries with the exact same key.
    #[must_use]
    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.default_headers.extend(headers);
        self
    }

    /// Sends `request` for `operation` and returns the parsed response.
    ///
    /// The request is wrapped in a SOAP envelope for `O::VERSION`.
    /// - SOAP 1.1: sent as `text/xml` with a quoted `SOAPAction` header.
    /// - SOAP 1.2: sent as `application/soap+xml` with the action in the content type.
    ///
    /// Default headers are appended after these protocol headers. A default header whose
    /// name matches one of them (ASCII case-insensitively) is skipped, so the
    /// protocol-derived value is the only one sent.
    ///
    /// # Errors
    ///
    /// - Serialization failures from [`SoapOperation::build_request_body`].
    /// - Transport errors from `reqwest` (sending or reading the body).
    /// - `SoapError::SoapFault` when the response body is a parseable SOAP fault; this is
    ///   checked before the HTTP status.
    /// - `SoapError::http_status` for any other non-success status.
    /// - Whatever [`SoapOperation::parse_response`] returns for a successful response.
    pub async fn call<O>(
        &self,
        operation: &O,
        request: &O::Request,
    ) -> Result<O::Response, SoapError>
    where
        O: SoapOperation + Send,
        O::Request: Serialize + Sync,
        O::Response: DeserializeOwned,
    {
        let (action, body_xml) = operation
            .build_request_body(request)
            .map_err(SoapError::serialize_request)?;
        let xml_body = envelope::build_envelope(O::VERSION, &action, &body_xml);

        let is_v11 = O::VERSION == SoapVersion::V11;
        let content_type = match O::VERSION {
            SoapVersion::V11 => "text/xml; charset=utf-8".to_string(),
            SoapVersion::V12 => {
                format!("application/soap+xml; charset=utf-8; action=\"{action}\"")
            }
        };
        let mut request_builder = self
            .http
            .post(&self.endpoint)
            .header("Content-Type", content_type);
        if is_v11 {
            request_builder = request_builder.header("SOAPAction", format!("\"{action}\""));
        }

        // `RequestBuilder::header` appends rather than replaces. A colliding default header
        // would therefore be sent as a duplicate with a conflicting value, so the
        // protocol-derived headers set above take precedence.
        let is_protocol_header = |name: &str| {
            name.eq_ignore_ascii_case("Content-Type")
                || (is_v11 && name.eq_ignore_ascii_case("SOAPAction"))
        };
        for (key, value) in &self.default_headers {
            if is_protocol_header(key) {
                continue;
            }
            request_builder = request_builder.header(key.as_str(), value.as_str());
        }

        let response = request_builder.body(xml_body).send().await?;
        let status = response.status();
        let text = response.text().await?;

        // SOAP services commonly return faults with a non-2xx status. Prefer the structured
        // fault over a bare status error when it can be parsed.
        if envelope::is_soap_fault(&text) {
            if let Ok((code, message)) = envelope::parse_soap_fault(&text) {
                return Err(SoapError::SoapFault { code, message });
            }
        }
        if !status.is_success() {
            return Err(SoapError::http_status(status));
        }
        operation.parse_response(&text)
    }

    /// Returns the endpoint URL this client posts to.
    #[must_use]
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }
}
impl std::fmt::Debug for SoapClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut ds = f.debug_struct("SoapClient");
ds.field("endpoint", &self.endpoint);
if self.default_headers.is_empty() {
ds.field("default_headers", &"<empty>");
} else {
let pairs: Vec<_> = self
.default_headers
.iter()
.map(|(k, v)| (k.as_str(), v.as_str()))
.collect();
ds.field("default_headers", &pairs);
}
ds.finish()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn creates_client_with_valid_url() {
        let url = "https://example.com/soap";
        let client = SoapClient::new(url).expect("a well-formed URL must be accepted");
        assert_eq!(client.endpoint(), url);
    }

    #[test]
    fn rejects_invalid_url() {
        assert!(SoapClient::new("not-a-url").is_err());
    }

    #[test]
    fn builds_envelope_for_operation() {
        let xml = envelope::build_envelope(
            SoapVersion::V11,
            "GetTemperature",
            "<req:GetTemperature><lat>40</lat></req:GetTemperature>",
        );
        let expected_fragments = [
            "<soap:Envelope",
            "<Action",
            ">GetTemperature</Action>",
            "<soap:Body",
            "<req:GetTemperature>",
        ];
        for fragment in expected_fragments {
            assert!(xml.contains(fragment), "envelope is missing {fragment:?}");
        }
    }

    #[test]
    fn client_is_debuggable() {
        let client = SoapClient::new("https://example.com")
            .expect("a well-formed URL must be accepted")
            .with_header("X-Custom", "header");
        let rendered = format!("{client:?}");
        assert!(rendered.contains("SoapClient"));
    }

    #[test]
    fn is_soap_fault_detects_fault() {
        let faults = [
            "<soap:Fault>code</soap:Fault>",
            "<env:Fault>code</env:Fault>",
            "<Fault xmlns=\"...\">msg</Fault>",
        ];
        for doc in faults {
            assert!(envelope::is_soap_fault(doc), "expected a fault: {doc}");
        }
        let not_a_fault = "<GetTempResponse><temp>72</temp></GetTempResponse>";
        assert!(!envelope::is_soap_fault(not_a_fault));
    }

    #[test]
    fn parse_soap_error() {
        let xml = r#"<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<soap:Body>
<soap:Fault xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<faultcode>Server</faultcode>
<faultstring>Invalid credentials</faultstring>
</soap:Fault>
</soap:Body>
</soap:Envelope>"#;
        let (code, message) =
            envelope::parse_soap_fault(xml).expect("a well-formed fault must parse");
        assert_eq!(code, "Server");
        assert_eq!(message, "Invalid credentials");
    }

    #[test]
    fn extract_body_from_soap_envelope() {
        let xml = r#"<soap:Envelope xmlns:soap="http://schemas.xmlsoap.org/soap/envelope/">
<soap:Body>
<GetWeatherResponse>
<temperature>72</temperature>
</GetWeatherResponse>
</soap:Body>
</soap:Envelope>"#;
        let body = envelope::extract_body(xml).expect("envelope must contain a body");
        for needle in ["GetWeatherResponse", "72"] {
            assert!(body.contains(needle), "body is missing {needle:?}");
        }
    }
}