use std::sync::Arc;
use std::time::Duration;
use reqwest::header::CONTENT_TYPE;
use secrecy::ExposeSecret;
use tokio::sync::Mutex;
use tracing::{debug, trace};
use zeroize::Zeroizing;
use crate::auth::AuthTransport;
use crate::auth::basic::BasicAuth;
use crate::auth::certificate::CertificateAuth;
#[cfg(feature = "credssp")]
use crate::auth::credssp::CredSspAuth;
use crate::auth::kerberos::KerberosAuth;
use crate::auth::ntlm::{self as ntlm_auth, NtlmAuth};
use crate::config::{AuthMethod, EncryptionMode, WinrmConfig, WinrmCredentials};
use crate::error::WinrmError;
use crate::ntlm::NtlmSession;
use crate::soap;
use crate::tls::CertHandle;
struct NtlmSessionCache {
host: String,
session: NtlmSession,
}
pub(crate) struct HttpTransport {
http: reqwest::Client,
config: WinrmConfig,
credentials: WinrmCredentials,
cert_handle: Option<CertHandle>,
ntlm_cache: Mutex<Option<NtlmSessionCache>>,
}
impl HttpTransport {
#[tracing::instrument(level = "debug", skip(credentials))]
pub(crate) fn new(
config: WinrmConfig,
credentials: WinrmCredentials,
) -> Result<Self, WinrmError> {
let mut builder = reqwest::Client::builder()
.danger_accept_invalid_certs(config.accept_invalid_certs)
.connect_timeout(Duration::from_secs(config.connect_timeout_secs))
.timeout(Duration::from_secs(config.operation_timeout_secs + 10))
.http1_only()
.tcp_keepalive(Duration::from_secs(60))
.pool_max_idle_per_host(10)
.pool_idle_timeout(Duration::from_secs(90))
.user_agent(
config
.user_agent
.as_deref()
.unwrap_or(concat!("winrm-rs/", env!("CARGO_PKG_VERSION"))),
);
if matches!(config.auth_method, AuthMethod::Certificate) {
let cert_path = config.client_cert_pem.as_deref().ok_or_else(|| {
WinrmError::AuthFailed("Certificate auth requires client_cert_pem".into())
})?;
let key_path = config.client_key_pem.as_deref().ok_or_else(|| {
WinrmError::AuthFailed("Certificate auth requires client_key_pem".into())
})?;
let cert_pem = std::fs::read(cert_path).map_err(|e| {
WinrmError::AuthFailed(format!("failed to read client cert {cert_path}: {e}"))
})?;
let key_pem = std::fs::read(key_path).map_err(|e| {
WinrmError::AuthFailed(format!("failed to read client key {key_path}: {e}"))
})?;
let mut combined = cert_pem;
combined.extend_from_slice(b"\n");
combined.extend_from_slice(&key_pem);
let identity = reqwest::Identity::from_pem(&combined).map_err(WinrmError::Http)?;
builder = builder.identity(identity);
}
if let Some(ref proxy_url) = config.proxy {
let proxy = reqwest::Proxy::all(proxy_url).map_err(WinrmError::Http)?;
builder = builder.proxy(proxy);
}
if matches!(config.auth_method, AuthMethod::Basic) && !config.use_tls {
tracing::warn!(
"Basic auth over HTTP transmits credentials in cleartext — use HTTPS in production"
);
}
if matches!(config.auth_method, AuthMethod::CredSsp) && !config.use_tls {
return Err(WinrmError::AuthFailed(
"CredSSP requires HTTPS (set use_tls = true)".into(),
));
}
let cert_handle = if config.use_tls {
let _ = rustls::crypto::ring::default_provider().install_default();
let root_store =
rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let inner_verifier: Arc<dyn rustls::client::danger::ServerCertVerifier> =
if config.accept_invalid_certs {
Arc::new(crate::tls::NoVerifier)
} else {
rustls::client::WebPkiServerVerifier::builder(Arc::new(root_store))
.build()
.map_err(|e| WinrmError::AuthFailed(format!("TLS verifier error: {e}")))?
};
let capturing_verifier = crate::tls::CertCapturingVerifier::new(inner_verifier);
let handle = capturing_verifier.cert_handle();
let tls_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(capturing_verifier))
.with_no_client_auth();
builder = builder.tls_backend_preconfigured(tls_config);
Some(handle)
} else {
None
};
let http = builder.build().map_err(WinrmError::Http)?;
Ok(Self {
http,
config,
credentials,
cert_handle,
ntlm_cache: Mutex::new(None),
})
}
pub(crate) fn endpoint(&self, host: &str) -> String {
let scheme = if self.config.use_tls { "https" } else { "http" };
format!("{scheme}://{host}:{}/wsman", self.config.port)
}
pub(crate) fn config(&self) -> &WinrmConfig {
&self.config
}
#[tracing::instrument(level = "debug", skip(self, body))]
async fn send_soap(&self, host: &str, body: String) -> Result<String, WinrmError> {
let url = self.endpoint(host);
debug!(url = %url, "sending WinRM SOAP request");
trace!(body = %body, "SOAP request body");
let response_text: String = match &self.config.auth_method {
AuthMethod::Basic => {
let auth = BasicAuth::new(
&self.credentials.username,
self.credentials.password.expose_secret(),
);
auth.send_authenticated(&self.http, &url, body).await?
}
AuthMethod::Ntlm => {
let encrypt = match self.config.encryption {
EncryptionMode::Always => true,
EncryptionMode::Never => false,
EncryptionMode::Auto => !self.config.use_tls,
};
if encrypt {
self.send_ntlm_sealed(host, &url, body).await?
} else {
let auth = NtlmAuth {
username: self.credentials.username.clone(),
password: Zeroizing::new(
self.credentials.password.expose_secret().to_string(),
),
domain: self.credentials.domain.clone(),
cert_handle: self.cert_handle.clone(),
};
auth.send_authenticated(&self.http, &url, body).await?
}
}
AuthMethod::Kerberos => {
let host_part = host.split(':').next().unwrap_or(host);
let auth = KerberosAuth {
service_principal: format!("HTTP/{host_part}"),
};
auth.send_authenticated(&self.http, &url, body).await?
}
AuthMethod::Certificate => {
let auth = CertificateAuth;
auth.send_authenticated(&self.http, &url, body).await?
}
#[cfg(feature = "credssp")]
AuthMethod::CredSsp => {
let auth = CredSspAuth {
username: self.credentials.username.clone(),
password: Zeroizing::new(self.credentials.password.expose_secret().to_string()),
domain: self.credentials.domain.clone(),
cert_handle: self.cert_handle.clone(),
};
auth.send_authenticated(&self.http, &url, body).await?
}
#[cfg(not(feature = "credssp"))]
AuthMethod::CredSsp => {
return Err(WinrmError::AuthFailed(
"CredSSP authentication requires the `credssp` cargo feature".into(),
));
}
};
trace!(response = %response_text, "SOAP response body");
soap::check_soap_fault(&response_text).map_err(WinrmError::Soap)?;
Ok(response_text)
}
async fn send_ntlm_sealed(
&self,
host: &str,
url: &str,
body: String,
) -> Result<String, WinrmError> {
{
let mut cache = self.ntlm_cache.lock().await;
if let Some(ref mut c) = *cache {
if c.host == host {
match self.send_sealed_body(&mut c.session, url, &body).await {
Ok(resp) => return Ok(resp),
Err(WinrmError::AuthFailed(_)) => {
debug!("NTLM sealed session expired, re-authenticating");
*cache = None;
}
Err(e) => return Err(e),
}
} else {
*cache = None;
}
}
}
let auth = NtlmAuth {
username: self.credentials.username.clone(),
password: Zeroizing::new(self.credentials.password.expose_secret().to_string()),
domain: self.credentials.domain.clone(),
cert_handle: self.cert_handle.clone(),
};
let (response, session_key) = auth.handshake_and_send(&self.http, url, &body).await?;
let session = NtlmSession::from_auth(&session_key);
*self.ntlm_cache.lock().await = Some(NtlmSessionCache {
host: host.to_string(),
session,
});
Ok(response)
}
async fn send_sealed_body(
&self,
session: &mut NtlmSession,
url: &str,
body: &str,
) -> Result<String, WinrmError> {
let (content_type, sealed) = ntlm_auth::seal_body(session, body);
let resp = self
.http
.post(url)
.header(CONTENT_TYPE, &content_type)
.body(sealed)
.send()
.await
.map_err(WinrmError::Http)?;
if resp.status().as_u16() == 401 {
return Err(WinrmError::AuthFailed("NTLM session expired".into()));
}
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
if status.as_u16() == 500
&& let Err(soap_err) = crate::soap::parser::check_soap_fault(&text)
{
return Err(WinrmError::Soap(soap_err));
}
return Err(WinrmError::AuthFailed(format!("HTTP {status}: {text}")));
}
let resp_bytes = resp.bytes().await.map_err(WinrmError::Http)?;
ntlm_auth::unseal_body(session, &resp_bytes)
}
pub(crate) async fn send_soap_with_retry(
&self,
host: &str,
body: String,
) -> Result<String, WinrmError> {
let max = self.config.max_retries;
for attempt in 0..=max {
match self.send_soap(host, body.clone()).await {
Ok(r) => return Ok(r),
Err(WinrmError::Http(_)) if attempt < max => {
let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt));
tracing::warn!(
attempt = attempt + 1,
max_retries = max,
delay_ms = delay.as_millis() as u64,
"retrying after transient HTTP error"
);
tokio::time::sleep(delay).await;
}
Err(e) => return Err(e),
}
}
unreachable!()
}
pub(crate) async fn send_soap_raw(
&self,
host: &str,
body: String,
) -> Result<String, WinrmError> {
self.send_soap_with_retry(host, body).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use wiremock::matchers::method;
use wiremock::{Mock, MockServer, ResponseTemplate};
fn basic_transport(port: u16) -> HttpTransport {
let config = WinrmConfig {
auth_method: AuthMethod::Basic,
port,
use_tls: false,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
HttpTransport::new(config, creds).unwrap()
}
#[test]
fn endpoint_http() {
let transport = basic_transport(5985);
assert_eq!(transport.endpoint("10.0.0.1"), "http://10.0.0.1:5985/wsman");
}
#[test]
fn endpoint_https() {
let config = WinrmConfig {
use_tls: true,
port: 5986,
accept_invalid_certs: true,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
let transport = HttpTransport::new(config, creds).unwrap();
assert_eq!(
transport.endpoint("win.local"),
"https://win.local:5986/wsman"
);
}
#[test]
fn config_accessor() {
let transport = basic_transport(5985);
assert_eq!(transport.config().port, 5985);
assert!(matches!(transport.config().auth_method, AuthMethod::Basic));
}
#[test]
fn credssp_without_tls_returns_error() {
let config = WinrmConfig {
auth_method: AuthMethod::CredSsp,
use_tls: false,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
match HttpTransport::new(config, creds) {
Err(WinrmError::AuthFailed(_)) => {} Err(e) => panic!("expected AuthFailed, got: {e}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
#[test]
fn certificate_auth_without_cert_path_returns_error() {
let config = WinrmConfig {
auth_method: AuthMethod::Certificate,
client_cert_pem: None,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
match HttpTransport::new(config, creds) {
Err(WinrmError::AuthFailed(_)) => {}
Err(e) => panic!("expected AuthFailed, got: {e}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
#[test]
fn certificate_auth_without_key_path_returns_error() {
let config = WinrmConfig {
auth_method: AuthMethod::Certificate,
client_cert_pem: Some("/tmp/cert.pem".into()),
client_key_pem: None,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
match HttpTransport::new(config, creds) {
Err(WinrmError::AuthFailed(_)) => {}
Err(e) => panic!("expected AuthFailed, got: {e}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
#[test]
fn certificate_auth_nonexistent_cert_file_returns_error() {
let config = WinrmConfig {
auth_method: AuthMethod::Certificate,
client_cert_pem: Some("/nonexistent/cert.pem".into()),
client_key_pem: Some("/nonexistent/key.pem".into()),
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
match HttpTransport::new(config, creds) {
Err(WinrmError::AuthFailed(_)) => {}
Err(e) => panic!("expected AuthFailed, got: {e}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
#[tokio::test]
async fn send_basic_success() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(200).set_body_string("<ok/>"))
.mount(&server)
.await;
let addr = server.address();
let transport = basic_transport(addr.port());
let result = transport
.send_soap_with_retry(&addr.ip().to_string(), "<soap/>".into())
.await;
assert_eq!(result.unwrap(), "<ok/>");
}
#[tokio::test]
async fn send_basic_soap_fault_returns_soap_error() {
let server = MockServer::start().await;
let fault = r#"<s:Envelope xmlns:s="http://www.w3.org/2003/05/soap-envelope">
<s:Body><s:Fault><s:Code><s:Value>s:Receiver</s:Value></s:Code>
<s:Reason><s:Text>boom</s:Text></s:Reason></s:Fault></s:Body></s:Envelope>"#;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(200).set_body_string(fault))
.mount(&server)
.await;
let addr = server.address();
let transport = basic_transport(addr.port());
let err = transport
.send_soap_with_retry(&addr.ip().to_string(), "<soap/>".into())
.await
.unwrap_err();
assert!(
matches!(err, WinrmError::Soap(_)),
"expected Soap error, got: {err}"
);
}
#[tokio::test]
async fn send_basic_auth_failed_returns_error() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(403).set_body_string("Forbidden"))
.mount(&server)
.await;
let addr = server.address();
let transport = basic_transport(addr.port());
let err = transport
.send_soap_with_retry(&addr.ip().to_string(), "<soap/>".into())
.await
.unwrap_err();
assert!(
matches!(err, WinrmError::AuthFailed(_)),
"expected AuthFailed, got: {err}"
);
}
#[tokio::test]
async fn retry_on_transient_http_error() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(200).set_body_string("<ok/>"))
.expect(1)
.mount(&server)
.await;
let addr = server.address();
let config = WinrmConfig {
auth_method: AuthMethod::Basic,
port: addr.port(),
use_tls: false,
max_retries: 3,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
let transport = HttpTransport::new(config, creds).unwrap();
let result = transport
.send_soap_with_retry(&addr.ip().to_string(), "<soap/>".into())
.await;
assert!(result.is_ok());
}
#[tokio::test]
async fn no_retry_on_auth_failure() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.respond_with(ResponseTemplate::new(403).set_body_string("Forbidden"))
.expect(1) .mount(&server)
.await;
let addr = server.address();
let config = WinrmConfig {
auth_method: AuthMethod::Basic,
port: addr.port(),
use_tls: false,
max_retries: 3,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
let transport = HttpTransport::new(config, creds).unwrap();
let _ = transport
.send_soap_with_retry(&addr.ip().to_string(), "<soap/>".into())
.await;
}
#[test]
fn proxy_config_applied() {
let config = WinrmConfig {
auth_method: AuthMethod::Basic,
use_tls: false,
proxy: Some("http://proxy:8080".into()),
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
assert!(HttpTransport::new(config, creds).is_ok());
}
#[test]
fn tls_with_accept_invalid_certs() {
let config = WinrmConfig {
use_tls: true,
accept_invalid_certs: true,
port: 5986,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
let transport = HttpTransport::new(config, creds).unwrap();
assert!(transport.cert_handle.is_some());
}
#[test]
fn tls_with_valid_certs() {
let config = WinrmConfig {
use_tls: true,
accept_invalid_certs: false,
port: 5986,
..Default::default()
};
let creds = WinrmCredentials::new("admin", "pass", "");
let transport = HttpTransport::new(config, creds).unwrap();
assert!(transport.cert_handle.is_some());
}
#[test]
fn no_tls_has_no_cert_handle() {
let transport = basic_transport(5985);
assert!(transport.cert_handle.is_none());
}
}