#![allow(dead_code)]
#[cfg(feature = "credssp")]
use std::sync::Arc;
#[cfg(feature = "credssp")]
use base64::Engine;
#[cfg(feature = "credssp")]
use base64::engine::general_purpose::STANDARD as B64;
#[cfg(feature = "credssp")]
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
#[cfg(feature = "credssp")]
use crate::asn1;
use crate::auth::AuthTransport;
#[cfg(feature = "credssp")]
use crate::error::CredSspError;
use crate::error::WinrmError;
#[cfg(feature = "credssp")]
use crate::ntlm;
#[cfg(feature = "credssp")]
use crate::ntlm::NtlmSession;
use crate::tls::CertHandle;
#[cfg(feature = "credssp")]
const CREDSSP_VERSION: u32 = 6;
#[cfg(feature = "credssp")]
fn extract_credssp_token(headers: &http::HeaderMap) -> Option<String> {
let all: String = headers
.get_all("WWW-Authenticate")
.iter()
.filter_map(|v| v.to_str().ok())
.collect::<Vec<_>>()
.join(", ");
let upper = all.to_uppercase();
let pos = upper.find("CREDSSP ")?;
let after = &all[pos + "CREDSSP ".len()..];
let token: String = after
.chars()
.take_while(|c| !c.is_whitespace() && *c != ',')
.collect();
if token.is_empty() { None } else { Some(token) }
}
#[cfg(feature = "credssp")]
fn advertises_credssp(headers: &http::HeaderMap) -> bool {
headers
.get_all("WWW-Authenticate")
.iter()
.filter_map(|v| v.to_str().ok())
.any(|v| v.to_uppercase().contains("CREDSSP"))
}
#[cfg(feature = "credssp")]
struct OpenSslMemTls {
ssl: openssl::ssl::SslStream<MemBio>,
}
#[cfg(feature = "credssp")]
struct MemBio {
incoming: std::collections::VecDeque<u8>,
outgoing: Vec<u8>,
}
#[cfg(feature = "credssp")]
impl std::io::Read for MemBio {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if self.incoming.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"no data",
));
}
let n = std::cmp::min(buf.len(), self.incoming.len());
for slot in buf.iter_mut().take(n) {
*slot = self.incoming.pop_front().unwrap();
}
Ok(n)
}
}
#[cfg(feature = "credssp")]
impl std::io::Write for MemBio {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.outgoing.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
#[cfg(feature = "credssp")]
impl OpenSslMemTls {
fn handshake_step(&mut self) -> Result<bool, WinrmError> {
use openssl::ssl::ErrorCode;
match self.ssl.do_handshake() {
Ok(()) => Ok(true),
Err(e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
Ok(false)
}
Err(e) => Err(WinrmError::AuthFailed(format!("inner TLS handshake: {e}"))),
}
}
fn drain_outgoing(&mut self) -> Vec<u8> {
std::mem::take(&mut self.ssl.get_mut().outgoing)
}
fn feed_incoming(&mut self, data: &[u8]) {
self.ssl.get_mut().incoming.extend(data.iter().copied());
}
fn write_plaintext(&mut self, data: &[u8]) -> Result<(), WinrmError> {
use std::io::Write;
self.ssl
.write_all(data)
.map_err(|e| WinrmError::AuthFailed(format!("inner TLS write: {e}")))?;
self.ssl
.flush()
.map_err(|e| WinrmError::AuthFailed(format!("inner TLS flush: {e}")))?;
Ok(())
}
fn read_plaintext(&mut self) -> Result<Vec<u8>, WinrmError> {
use openssl::ssl::ErrorCode;
let mut out = Vec::new();
let mut buf = [0u8; 4096];
loop {
match self.ssl.ssl_read(&mut buf) {
Ok(0) => break,
Ok(n) => out.extend_from_slice(&buf[..n]),
Err(e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
break;
}
Err(e) => {
return Err(WinrmError::AuthFailed(format!("inner TLS read: {e}")));
}
}
}
Ok(out)
}
fn peer_cert_der(&self) -> Result<Vec<u8>, WinrmError> {
let cert = self
.ssl
.ssl()
.peer_certificate()
.ok_or_else(|| WinrmError::AuthFailed("CredSSP: no peer cert".into()))?;
cert.to_der()
.map_err(|e| WinrmError::AuthFailed(format!("cert DER: {e}")))
}
}
#[cfg(feature = "credssp")]
fn build_inner_openssl_tls() -> Result<OpenSslMemTls, WinrmError> {
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
let mut builder = SslConnector::builder(SslMethod::tls_client())
.map_err(|e| WinrmError::AuthFailed(format!("SslConnector: {e}")))?;
builder.set_verify(SslVerifyMode::NONE);
builder
.set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2))
.map_err(|e| WinrmError::AuthFailed(format!("min ver: {e}")))?;
builder
.set_max_proto_version(Some(openssl::ssl::SslVersion::TLS1_2))
.map_err(|e| WinrmError::AuthFailed(format!("max ver: {e}")))?;
let connector = builder.build();
let mem_bio = MemBio {
incoming: std::collections::VecDeque::new(),
outgoing: Vec::new(),
};
let cfg = connector
.configure()
.map_err(|e| WinrmError::AuthFailed(format!("ssl configure: {e}")))?
.verify_hostname(false)
.use_server_name_indication(false);
let mut mhss = cfg
.into_ssl("credssp")
.map_err(|e| WinrmError::AuthFailed(format!("into_ssl: {e}")))?;
mhss.set_connect_state();
let ssl = openssl::ssl::SslStream::new(mhss, mem_bio)
.map_err(|e| WinrmError::AuthFailed(format!("SslStream::new: {e}")))?;
Ok(OpenSslMemTls { ssl })
}
#[cfg(feature = "credssp")]
struct CredSspConnection {
stream: tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
host: String,
port: u16,
path: String,
}
#[cfg(feature = "credssp")]
impl CredSspConnection {
async fn connect(host: &str, port: u16, path: &str) -> Result<Self, WinrmError> {
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
let _ = rustls::crypto::ring::default_provider().install_default();
let mut outer_config = rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(crate::tls::NoVerifier))
.with_no_client_auth();
if std::env::var("SSLKEYLOGFILE").is_ok() {
outer_config.key_log = Arc::new(rustls::KeyLogFile::new());
}
let connector = TlsConnector::from(Arc::new(outer_config));
let server_name = rustls::pki_types::ServerName::try_from(host.to_string())
.map_err(|_| WinrmError::AuthFailed(format!("CredSSP: invalid SNI: {host}")))?;
let tcp = TcpStream::connect((host, port))
.await
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP TCP connect: {e}")))?;
tcp.set_nodelay(true).ok();
let stream = connector
.connect(server_name, tcp)
.await
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP outer TLS: {e}")))?;
Ok(Self {
stream,
host: host.to_string(),
port,
path: path.to_string(),
})
}
async fn post(
&mut self,
auth_header: Option<&str>,
body: &str,
) -> Result<(u16, std::collections::HashMap<String, String>, Vec<u8>), WinrmError> {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
let mut req = String::with_capacity(512 + body.len() + auth_header.map_or(0, |a| a.len()));
use std::fmt::Write as _;
write!(req, "POST {} HTTP/1.1\r\n", self.path).unwrap();
write!(req, "Host: {}:{}\r\n", self.host, self.port).unwrap();
write!(req, "User-Agent: Python WinRM client\r\n").unwrap();
write!(req, "Accept-Encoding: gzip, deflate, zstd\r\n").unwrap();
write!(req, "Accept: */*\r\n").unwrap();
write!(req, "Connection: Keep-Alive\r\n").unwrap();
if let Some(av) = auth_header {
write!(req, "Authorization: {av}\r\n").unwrap();
}
write!(req, "Content-Type: application/soap+xml;charset=UTF-8\r\n").unwrap();
write!(req, "Content-Length: {}\r\n", body.len()).unwrap();
req.push_str("\r\n");
tracing::trace!(target: "winrm_rs::credssp", head_len = req.len(), "POST {}", self.path);
let mut wire = Vec::with_capacity(req.len() + body.len());
wire.extend_from_slice(req.as_bytes());
wire.extend_from_slice(body.as_bytes());
self.stream
.write_all(&wire)
.await
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP write: {e}")))?;
self.stream
.flush()
.await
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP flush: {e}")))?;
let mut buf = Vec::with_capacity(8192);
let head_end = loop {
let mut chunk = [0u8; 4096];
let n = self
.stream
.read(&mut chunk)
.await
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP read: {e}")))?;
if n == 0 {
return Err(WinrmError::AuthFailed(
"CredSSP: server closed connection mid-response".into(),
));
}
buf.extend_from_slice(&chunk[..n]);
if let Some(pos) = find_double_crlf(&buf) {
break pos + 4;
}
if buf.len() > 1024 * 1024 {
return Err(WinrmError::AuthFailed(
"CredSSP: response head too large".into(),
));
}
};
let head_str = std::str::from_utf8(&buf[..head_end])
.map_err(|_| WinrmError::AuthFailed("CredSSP: non-UTF8 response head".into()))?;
let mut lines = head_str.split("\r\n");
let status_line = lines.next().unwrap_or("");
let status: u16 = status_line
.split_whitespace()
.nth(1)
.and_then(|s| s.parse().ok())
.ok_or_else(|| {
WinrmError::AuthFailed(format!("CredSSP: bad status line: {status_line}"))
})?;
let mut headers: std::collections::HashMap<String, String> =
std::collections::HashMap::new();
for line in lines {
if line.is_empty() {
break;
}
if let Some((k, v)) = line.split_once(':') {
let key = k.trim().to_ascii_lowercase();
let val = v.trim().to_string();
headers
.entry(key)
.and_modify(|e| {
e.push_str(", ");
e.push_str(&val);
})
.or_insert(val);
}
}
let content_length: usize = headers
.get("content-length")
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let mut body_bytes: Vec<u8> = buf[head_end..].to_vec();
while body_bytes.len() < content_length {
let mut chunk = [0u8; 4096];
let n = self
.stream
.read(&mut chunk)
.await
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP read body: {e}")))?;
if n == 0 {
break;
}
body_bytes.extend_from_slice(&chunk[..n]);
}
body_bytes.truncate(content_length);
tracing::trace!(target: "winrm_rs::credssp", %status, body = body_bytes.len(), "credssp response");
Ok((status, headers, body_bytes))
}
}
#[cfg(feature = "credssp")]
fn find_double_crlf(buf: &[u8]) -> Option<usize> {
buf.windows(4).position(|w| w == b"\r\n\r\n")
}
#[cfg(feature = "credssp")]
fn header_get<'a>(
headers: &'a std::collections::HashMap<String, String>,
name: &str,
) -> Option<&'a str> {
headers.get(&name.to_ascii_lowercase()).map(|s| s.as_str())
}
#[cfg(feature = "credssp")]
fn extract_credssp_token_str(www_auth: &str) -> Option<String> {
let upper = www_auth.to_uppercase();
let pos = upper.find("CREDSSP ")?;
let after = &www_auth[pos + "CREDSSP ".len()..];
let token: String = after
.chars()
.take_while(|c| !c.is_whitespace() && *c != ',')
.collect();
if token.is_empty() { None } else { Some(token) }
}
#[cfg(feature = "credssp")]
fn parse_url(url: &str) -> Result<(String, u16, String), WinrmError> {
let after_scheme = url
.strip_prefix("https://")
.or_else(|| url.strip_prefix("http://"))
.ok_or_else(|| WinrmError::AuthFailed(format!("CredSSP: bad URL {url}")))?;
let (authority, path) = match after_scheme.find('/') {
Some(i) => (&after_scheme[..i], &after_scheme[i..]),
None => (after_scheme, "/"),
};
let (host, port) = match authority.rfind(':') {
Some(i) => {
let p: u16 = authority[i + 1..]
.parse()
.map_err(|_| WinrmError::AuthFailed(format!("CredSSP: bad port in {url}")))?;
(authority[..i].to_string(), p)
}
None => (authority.to_string(), 5986),
};
Ok((host, port, path.to_string()))
}
pub(crate) struct CredSspAuth {
pub(crate) username: String,
pub(crate) password: Zeroizing<String>,
pub(crate) domain: String,
pub(crate) cert_handle: Option<CertHandle>,
}
#[cfg(feature = "credssp")]
impl AuthTransport for CredSspAuth {
async fn send_authenticated(
&self,
_http: &reqwest::Client,
url: &str,
body: String,
) -> Result<String, WinrmError> {
let (host, port, path) = parse_url(url)?;
let mut conn = CredSspConnection::connect(&host, port, &path).await?;
let body_for_auth: &str = &body;
let encode_auth = |bytes: &[u8]| format!("CredSSP {}", B64.encode(bytes));
let _ = &header_get;
let mut inner_tls = build_inner_openssl_tls()?;
inner_tls.handshake_step()?;
let mut round = 0;
loop {
round += 1;
if round > 10 {
return Err(WinrmError::AuthFailed(
"CredSSP: TLS handshake did not complete in 10 rounds".into(),
));
}
let outgoing = inner_tls.drain_outgoing();
if outgoing.is_empty() {
if inner_tls.handshake_step()? {
break;
}
continue;
}
let (status, headers, _) = conn
.post(Some(&encode_auth(&outgoing)), body_for_auth)
.await?;
if status != 401 {
return Err(WinrmError::AuthFailed(format!(
"CredSSP: TLS handshake round {round}: expected 401, got {status}"
)));
}
let server_token = header_get(&headers, "www-authenticate")
.and_then(extract_credssp_token_str)
.ok_or_else(|| {
WinrmError::AuthFailed(format!(
"CredSSP: TLS handshake round {round}: no CredSSP token in response"
))
})?;
let server_bytes = B64.decode(server_token.trim_ascii()).map_err(|e| {
WinrmError::AuthFailed(format!("CredSSP: bad base64 in handshake: {e}"))
})?;
inner_tls.feed_incoming(&server_bytes);
if inner_tls.handshake_step()? {
break;
}
}
let inner_cert_der = inner_tls.peer_cert_der()?;
let inner_cert_der: &[u8] = &inner_cert_der;
let subject_public_key = {
use x509_cert::der::Decode;
let cert = x509_cert::Certificate::from_der(inner_cert_der)
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP cert parse: {e}")))?;
cert.tbs_certificate
.subject_public_key_info
.subject_public_key
.raw_bytes()
.to_vec()
};
let type1 = ntlm::create_negotiate_message_credssp();
let spnego_init = asn1::encode_spnego_init(&type1);
let ts_req1 =
asn1::encode_ts_request(CREDSSP_VERSION, Some(&spnego_init), None, None, None);
inner_tls.write_plaintext(&ts_req1)?;
let outgoing = Ok::<_, WinrmError>(inner_tls.drain_outgoing())?;
let (status, headers, _) = conn
.post(Some(&encode_auth(&outgoing)), body_for_auth)
.await?;
if status != 401 {
return Err(WinrmError::AuthFailed(format!(
"CredSSP: NTLM negotiate: expected 401, got {status}"
)));
}
let server_token = header_get(&headers, "www-authenticate")
.and_then(extract_credssp_token_str)
.ok_or_else(|| WinrmError::AuthFailed("CredSSP: NTLM nego: no CredSSP token".into()))?;
let server_bytes = B64
.decode(server_token.trim_ascii())
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP: bad b64 NTLM nego: {e}")))?;
{
inner_tls.feed_incoming(&server_bytes);
Ok::<_, WinrmError>(())
}?;
let plaintext = inner_tls.read_plaintext()?;
let ts_resp = asn1::decode_ts_request(&plaintext).map_err(WinrmError::CredSsp)?;
let negotiated_version = std::cmp::min(ts_resp.version, CREDSSP_VERSION);
eprintln!(
"[CREDSSP] server CredSSP version: {} (negotiated: {})",
ts_resp.version, negotiated_version
);
if let Some(code) = ts_resp.error_code {
return Err(WinrmError::CredSsp(CredSspError::ServerError(code)));
}
let spnego_resp = ts_resp
.nego_token
.ok_or_else(|| WinrmError::AuthFailed("CredSSP: no negoToken from server".into()))?;
let type2 = asn1::decode_spnego_token(&spnego_resp).map_err(WinrmError::CredSsp)?;
let challenge = ntlm::parse_challenge(&type2).map_err(WinrmError::Ntlm)?;
let domain = self.domain.clone();
let host = url
.strip_prefix("https://")
.or_else(|| url.strip_prefix("http://"))
.unwrap_or(url);
let host_part = host.split('/').next().unwrap_or(host);
let host_only = host_part.split(':').next().unwrap_or(host_part);
let spn = format!("HTTP/{host_only}");
let (type3, session_key) = ntlm::create_authenticate_message_credssp(
&challenge,
&self.username,
&self.password,
&domain,
&spn,
&type1,
&type2,
);
let mut ntlm_session = NtlmSession::from_auth(&session_key);
let mech_list_mic = ntlm_session.sign(asn1::MECH_TYPE_LIST_NTLM);
let nonce: [u8; 32] = std::env::var("CREDSSP_FIXED_NONCE")
.ok()
.and_then(|s| {
let bytes = (0..s.len())
.step_by(2)
.map(|i| u8::from_str_radix(&s[i..i + 2], 16).ok())
.collect::<Option<Vec<_>>>()?;
bytes.try_into().ok()
})
.unwrap_or_else(rand::random);
let client_hash = {
let mut hasher = Sha256::new();
hasher.update(b"CredSSP Client-To-Server Binding Hash\0");
hasher.update(nonce);
hasher.update(&subject_public_key);
hasher.finalize().to_vec()
};
let encrypted_pub_key_auth = ntlm_session.seal(&client_hash);
let spnego_authenticate = asn1::encode_spnego_response(&type3, Some(&mech_list_mic));
let ts_req3 = asn1::encode_ts_request(
CREDSSP_VERSION,
Some(&spnego_authenticate),
Some(&encrypted_pub_key_auth),
None,
Some(&nonce),
);
if std::env::var("CREDSSP_DUMP").is_ok() {
let h = |b: &[u8]| b.iter().map(|x| format!("{:02x}", x)).collect::<String>();
eprintln!(
"[CREDSSP_DUMP] ts_req3 ({}B): {}",
ts_req3.len(),
h(&ts_req3)
);
eprintln!("[CREDSSP_DUMP] nonce: {}", h(&nonce));
eprintln!(
"[CREDSSP_DUMP] subject_public_key: {}",
h(&subject_public_key)
);
eprintln!("[CREDSSP_DUMP] client_hash: {}", h(&client_hash));
eprintln!(
"[CREDSSP_DUMP] sealed_pub_key_auth: {}",
h(&encrypted_pub_key_auth)
);
}
inner_tls.write_plaintext(&ts_req3)?;
let outgoing = Ok::<_, WinrmError>(inner_tls.drain_outgoing())?;
let (status, headers, _) = conn
.post(Some(&encode_auth(&outgoing)), body_for_auth)
.await?;
if status != 401 {
return Err(WinrmError::AuthFailed(format!(
"CredSSP: NTLM authenticate: expected 401, got {status}"
)));
}
let server_token = header_get(&headers, "www-authenticate")
.and_then(extract_credssp_token_str)
.ok_or_else(|| WinrmError::AuthFailed("CredSSP: NTLM auth: no CredSSP token".into()))?;
let server_bytes = B64
.decode(server_token.trim_ascii())
.map_err(|e| WinrmError::AuthFailed(format!("CredSSP: bad b64 auth: {e}")))?;
{
inner_tls.feed_incoming(&server_bytes);
Ok::<_, WinrmError>(())
}?;
let plaintext = inner_tls.read_plaintext()?;
let ts_resp = asn1::decode_ts_request(&plaintext).map_err(WinrmError::CredSsp)?;
if let Some(code) = ts_resp.error_code {
return Err(WinrmError::CredSsp(CredSspError::ServerError(code)));
}
let server_pub_key_auth = ts_resp
.pub_key_auth
.ok_or_else(|| WinrmError::AuthFailed("CredSSP: no pubKeyAuth from server".into()))?;
let decrypted_server_hash = ntlm_session
.unseal(&server_pub_key_auth)
.map_err(WinrmError::Ntlm)?;
let expected_server_hash = {
let mut hasher = Sha256::new();
hasher.update(b"CredSSP Server-To-Client Binding Hash\0");
hasher.update(nonce);
hasher.update(&subject_public_key);
hasher.finalize().to_vec()
};
if decrypted_server_hash != expected_server_hash {
return Err(WinrmError::CredSsp(CredSspError::PublicKeyMismatch));
}
let ts_creds = asn1::encode_ts_credentials(&domain, &self.username, &self.password);
let encrypted_creds = ntlm_session.seal(&ts_creds);
let ts_req5 =
asn1::encode_ts_request(CREDSSP_VERSION, None, None, Some(&encrypted_creds), None);
inner_tls.write_plaintext(&ts_req5)?;
let outgoing = Ok::<_, WinrmError>(inner_tls.drain_outgoing())?;
let (status, _headers, resp_body) = conn.post(Some(&encode_auth(&outgoing)), &body).await?;
if status == 401 {
return Err(WinrmError::AuthFailed(
"CredSSP: credentials rejected after delegation".into(),
));
}
if !(200..300).contains(&status) {
let body_text = String::from_utf8_lossy(&resp_body).into_owned();
return Err(WinrmError::AuthFailed(format!(
"CredSSP HTTP {status}: {body_text}"
)));
}
Ok(String::from_utf8_lossy(&resp_body).into_owned())
}
}
#[cfg(not(feature = "credssp"))]
impl AuthTransport for CredSspAuth {
async fn send_authenticated(
&self,
_http: &reqwest::Client,
_url: &str,
_body: String,
) -> Result<String, WinrmError> {
Err(WinrmError::AuthFailed(
"CredSSP authentication requires the 'credssp' feature. \
Enable it with: cargo add winrm-rs --features credssp"
.into(),
))
}
}
#[cfg(all(test, feature = "credssp"))]
mod tests {
use super::*;
use reqwest::header::{HeaderMap, HeaderValue};
fn make_headers(values: &[&str]) -> HeaderMap {
let mut h = HeaderMap::new();
for v in values {
h.append("WWW-Authenticate", HeaderValue::from_str(v).unwrap());
}
h
}
#[test]
fn extract_credssp_token_simple() {
let h = make_headers(&["CredSSP YWJjZGVm"]);
assert_eq!(extract_credssp_token(&h), Some("YWJjZGVm".to_string()));
}
#[test]
fn extract_credssp_token_case_insensitive() {
let h = make_headers(&["credssp YWJjZGVm"]);
assert_eq!(extract_credssp_token(&h), Some("YWJjZGVm".to_string()));
let h = make_headers(&["CREDSSP YWJjZGVm"]);
assert_eq!(extract_credssp_token(&h), Some("YWJjZGVm".to_string()));
}
#[test]
fn extract_credssp_token_multiple_schemes_one_header() {
let h = make_headers(&["Negotiate, CredSSP TOKEN, Basic"]);
assert_eq!(extract_credssp_token(&h), Some("TOKEN".to_string()));
}
#[test]
fn extract_credssp_token_multiple_headers() {
let h = make_headers(&["Negotiate", "Basic realm=\"WSMAN\"", "CredSSP MYTOKEN"]);
assert_eq!(extract_credssp_token(&h), Some("MYTOKEN".to_string()));
}
#[test]
fn extract_credssp_token_no_credssp() {
let h = make_headers(&["Negotiate", "Basic realm=\"WSMAN\""]);
assert_eq!(extract_credssp_token(&h), None);
}
#[test]
fn extract_credssp_token_advertised_no_token() {
let h = make_headers(&["CredSSP"]);
assert_eq!(extract_credssp_token(&h), None);
}
#[test]
fn extract_credssp_token_empty_headers() {
let h = HeaderMap::new();
assert_eq!(extract_credssp_token(&h), None);
}
#[test]
fn advertises_credssp_detects_scheme() {
let h = make_headers(&["Negotiate", "Basic", "CredSSP"]);
assert!(advertises_credssp(&h));
}
#[test]
fn advertises_credssp_case_insensitive() {
let h = make_headers(&["credssp"]);
assert!(advertises_credssp(&h));
}
#[test]
fn advertises_credssp_returns_false_when_absent() {
let h = make_headers(&["Negotiate", "Basic"]);
assert!(!advertises_credssp(&h));
}
#[test]
fn advertises_credssp_works_with_token() {
let h = make_headers(&["CredSSP YWJjZGVm"]);
assert!(advertises_credssp(&h));
}
#[test]
fn parse_url_https_with_port_and_path() {
let (h, p, path) = parse_url("https://host.example:5986/wsman").unwrap();
assert_eq!(h, "host.example");
assert_eq!(p, 5986);
assert_eq!(path, "/wsman");
}
#[test]
fn parse_url_http_scheme_also_accepted() {
let (h, p, path) = parse_url("http://h:80/x").unwrap();
assert_eq!((h.as_str(), p, path.as_str()), ("h", 80, "/x"));
}
#[test]
fn parse_url_defaults_port_5986_when_missing() {
let (h, p, path) = parse_url("https://host/wsman").unwrap();
assert_eq!(h, "host");
assert_eq!(p, 5986);
assert_eq!(path, "/wsman");
}
#[test]
fn parse_url_defaults_path_when_missing() {
let (h, p, path) = parse_url("https://host:5986").unwrap();
assert_eq!((h.as_str(), p, path.as_str()), ("host", 5986, "/"));
}
#[test]
fn parse_url_rejects_bad_scheme() {
assert!(parse_url("ftp://host/path").is_err());
}
#[test]
fn parse_url_rejects_bad_port() {
assert!(parse_url("https://host:notaport/x").is_err());
}
#[test]
fn extract_credssp_token_str_simple() {
assert_eq!(
extract_credssp_token_str("CredSSP YWJjZGVm"),
Some("YWJjZGVm".to_string())
);
}
#[test]
fn extract_credssp_token_str_case_insensitive_and_mixed() {
assert_eq!(
extract_credssp_token_str("Negotiate, credssp ABCDEF, Basic"),
Some("ABCDEF".to_string())
);
}
#[test]
fn extract_credssp_token_str_returns_none_when_no_token() {
assert_eq!(extract_credssp_token_str("Negotiate, Basic"), None);
assert_eq!(extract_credssp_token_str("CredSSP "), None);
}
#[test]
fn find_double_crlf_locates_separator() {
let buf = b"GET / HTTP/1.1\r\nHost: x\r\n\r\nbody";
let pos = find_double_crlf(buf).unwrap();
assert_eq!(&buf[pos..pos + 4], b"\r\n\r\n");
}
#[test]
fn find_double_crlf_returns_none_when_absent() {
assert!(find_double_crlf(b"no separator here").is_none());
assert!(find_double_crlf(b"a\r\nb").is_none());
}
#[test]
fn header_get_is_case_insensitive() {
let mut h = std::collections::HashMap::new();
h.insert("content-length".to_string(), "42".to_string());
assert_eq!(header_get(&h, "Content-Length"), Some("42"));
assert_eq!(header_get(&h, "CONTENT-LENGTH"), Some("42"));
assert_eq!(header_get(&h, "missing"), None);
}
#[test]
fn membio_write_appends_to_outgoing_and_read_drains_incoming() {
use std::io::{Read, Write};
let mut bio = MemBio {
incoming: std::collections::VecDeque::new(),
outgoing: Vec::new(),
};
assert_eq!(bio.write(b"hello").unwrap(), 5);
bio.flush().unwrap();
assert_eq!(bio.outgoing, b"hello");
let mut buf = [0u8; 4];
let err = bio.read(&mut buf).unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::WouldBlock);
bio.incoming.extend(b"abcdef".iter().copied());
let n = bio.read(&mut buf).unwrap();
assert_eq!(n, 4);
assert_eq!(&buf, b"abcd");
assert_eq!(bio.incoming.len(), 2);
}
#[test]
fn openssl_inner_tls_peer_cert_before_handshake_errors() {
let tls = build_inner_openssl_tls().expect("build inner tls");
assert!(tls.peer_cert_der().is_err());
}
#[test]
fn openssl_inner_tls_feed_incoming_then_drain_clears_outgoing() {
let mut tls = build_inner_openssl_tls().expect("build inner tls");
let _ = tls.handshake_step();
let first = tls.drain_outgoing();
assert!(!first.is_empty());
assert!(tls.drain_outgoing().is_empty());
tls.feed_incoming(&[0u8; 8]);
}
#[test]
fn openssl_inner_tls_handshake_starts() {
let mut tls = build_inner_openssl_tls().expect("build inner tls");
let _ = tls.handshake_step();
let bytes = tls.drain_outgoing();
assert!(!bytes.is_empty(), "should produce ClientHello bytes");
assert_eq!(bytes[0], 0x16);
}
}