use crate::error::WSError;
#[cfg(not(target_os = "wasi"))]
use crate::signature::keyless::cert_pinning::{create_pinned_rustls_config, PinningConfig};
#[cfg(not(target_os = "wasi"))]
use rustls::{ClientConfig, ClientConnection, StreamOwned};
#[cfg(not(target_os = "wasi"))]
use rustls_pki_types::ServerName;
#[cfg(not(target_os = "wasi"))]
use std::convert::TryInto;
#[cfg(not(target_os = "wasi"))]
use std::fmt;
#[cfg(not(target_os = "wasi"))]
use std::sync::Arc;
#[cfg(not(target_os = "wasi"))]
use ureq::unversioned::transport::{
Buffers, ConnectionDetails, Connector, Either, LazyBuffers, NextTimeout, TcpConnector,
Transport, TransportAdapter,
};
/// `ureq` connector that wraps a chained transport in TLS using the rustls
/// client configuration built by `create_pinned_rustls_config`.
#[cfg(not(target_os = "wasi"))]
pub struct PinnedRustlsConnector {
    // Shared rustls config; each new TLS connection takes its own `Arc` handle.
    config: Arc<ClientConfig>,
}

#[cfg(not(target_os = "wasi"))]
impl PinnedRustlsConnector {
    /// Creates a connector whose TLS configuration is derived from `pinning`.
    ///
    /// # Errors
    /// Returns any error produced by `create_pinned_rustls_config`.
    pub fn new(pinning: PinningConfig) -> Result<Self, WSError> {
        let shared_config = create_pinned_rustls_config(pinning)?;
        log::info!("Created PinnedRustlsConnector with certificate pinning");
        Ok(PinnedRustlsConnector {
            config: shared_config,
        })
    }
}

#[cfg(not(target_os = "wasi"))]
impl fmt::Debug for PinnedRustlsConnector {
    // The rustls config is summarised by a fixed label instead of being dumped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("PinnedRustlsConnector");
        out.field("config", &"ClientConfig with PinnedCertVerifier");
        out.finish()
    }
}
/// TLS transport produced by [`PinnedRustlsConnector`]: a rustls client
/// session layered over the chained `ureq` transport.
#[cfg(not(target_os = "wasi"))]
pub struct PinnedRustlsTransport {
    // Buffers handed to ureq through `Transport::buffers`.
    buffers: LazyBuffers,
    // rustls session plus the adapted inner transport it reads from / writes to.
    stream: StreamOwned<ClientConnection, TransportAdapter<Box<dyn Transport>>>,
}

#[cfg(not(target_os = "wasi"))]
impl PinnedRustlsTransport {
    /// Assembles a transport from pre-built buffers and an established TLS stream.
    pub fn new(
        buffers: LazyBuffers,
        stream: StreamOwned<ClientConnection, TransportAdapter<Box<dyn Transport>>>,
    ) -> Self {
        PinnedRustlsTransport { stream, buffers }
    }
}

#[cfg(not(target_os = "wasi"))]
impl fmt::Debug for PinnedRustlsTransport {
    // Fields are intentionally omitted; only the type name is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Formatter::debug_struct(f, "PinnedRustlsTransport").finish()
    }
}
#[cfg(not(target_os = "wasi"))]
impl<In: Transport> Connector<In> for PinnedRustlsConnector {
type Out = Either<In, PinnedRustlsTransport>;
fn connect(
&self,
details: &ConnectionDetails,
chained: Option<In>,
) -> Result<Option<Self::Out>, ureq::Error> {
let Some(transport) = chained else {
panic!("PinnedRustlsConnector requires a chained transport");
};
if !details.needs_tls() || transport.is_tls() {
log::trace!("PinnedRustlsConnector: Skip (not HTTPS or already TLS)");
return Ok(Some(Either::A(transport)));
}
log::trace!("PinnedRustlsConnector: Wrapping connection in pinned TLS");
let name_borrowed: ServerName<'_> = details
.uri
.authority()
.expect("uri authority for tls")
.host()
.try_into()
.map_err(|e| {
log::debug!("PinnedRustlsConnector: invalid dns name: {}", e);
ureq::Error::Tls("Invalid DNS name for TLS")
})?;
let name = name_borrowed.to_owned();
let conn = ClientConnection::new(self.config.clone(), name)?;
let stream = StreamOwned {
conn,
sock: TransportAdapter::new(transport.boxed()),
};
let buffers = LazyBuffers::new(
details.config.input_buffer_size(),
details.config.output_buffer_size(),
);
let transport = PinnedRustlsTransport { buffers, stream };
log::debug!("PinnedRustlsConnector: Wrapped TLS with certificate pinning");
Ok(Some(Either::B(transport)))
}
}
#[cfg(not(target_os = "wasi"))]
impl Transport for PinnedRustlsTransport {
    fn buffers(&mut self) -> &mut dyn Buffers {
        &mut self.buffers
    }

    /// Encrypts and sends the first `amount` bytes of the output buffer.
    ///
    /// The timeout is installed on the adapter before any I/O so that rustls'
    /// reads/writes on the inner transport honour it. The explicit `flush`
    /// asks rustls to push out any TLS records still queued after `write_all`.
    fn transmit_output(&mut self, amount: usize, timeout: NextTimeout) -> Result<(), ureq::Error> {
        use std::io::Write;
        self.stream.sock.set_timeout(timeout);
        let output = self.buffers.output();
        self.stream.write_all(&output[..amount])?;
        self.stream.flush()?;
        Ok(())
    }

    /// Reads and decrypts available data into the input buffer.
    ///
    /// Returns `Ok(false)` when the read yields zero bytes.
    fn await_input(&mut self, timeout: NextTimeout) -> Result<bool, ureq::Error> {
        use std::io::Read;
        self.stream.sock.set_timeout(timeout);
        let input = self.buffers.input_append_buf();
        let amount = self.stream.read(input)?;
        self.buffers.input_appended(amount);
        Ok(amount > 0)
    }

    /// Whether this connection can still carry more requests.
    ///
    /// Delegates to the wrapped transport. Looking only at the TLS handshake
    /// state is wrong: once the handshake has completed, `is_handshaking()`
    /// stays `false` forever, so a socket the server had already closed would
    /// still be reported as open (and presumably be reused by ureq's pool).
    fn is_open(&mut self) -> bool {
        self.stream.sock.get_mut().is_open()
    }

    fn is_tls(&self) -> bool {
        true
    }
}
/// Builds a `ureq::Agent` whose HTTPS connections use the pinned rustls
/// configuration derived from `pinning`.
///
/// The connector chain is a `TcpConnector` followed by
/// [`PinnedRustlsConnector`]. HTTP error statuses are returned as responses
/// rather than `Err` (`http_status_as_error(false)`).
///
/// NOTE(review): `PinnedRustlsConnector` passes non-TLS connections through
/// unwrapped, so only `https://` URLs are pinned. No proxy connector is in
/// this chain, so proxy settings are presumably not honoured — confirm.
///
/// # Errors
/// Propagates the error from `PinnedRustlsConnector::new`.
#[cfg(not(target_os = "wasi"))]
pub fn create_pinned_agent(pinning: PinningConfig) -> Result<ureq::Agent, WSError> {
    use ureq::unversioned::resolver::DefaultResolver;

    let pinned_tls = PinnedRustlsConnector::new(pinning)?;
    let connector = ().chain(TcpConnector::default()).chain(pinned_tls);

    let agent = ureq::Agent::with_parts(
        ureq::config::Config::builder()
            .http_status_as_error(false)
            .build(),
        connector,
        DefaultResolver::default(),
    );
    log::info!("Created HTTP agent with certificate pinning enabled");
    Ok(agent)
}
/// Builds a `ureq::Agent` without the pinned connector (no certificate pinning).
///
/// HTTP error statuses are returned as responses rather than `Err`
/// (`http_status_as_error(false)`).
#[cfg(not(target_os = "wasi"))]
pub fn create_standard_agent() -> ureq::Agent {
    let config = ureq::Agent::config_builder()
        .http_status_as_error(false)
        .build();
    Into::<ureq::Agent>::into(config)
}
/// Builds an agent that pins certificates only when `pinning` is `Some` and
/// reports itself as enabled; otherwise falls back to `create_standard_agent`.
///
/// # Errors
/// - Propagates errors from `create_pinned_agent` on the pinned path.
/// - Returns `WSError::CertificatePinningError` on the fallback path when the
///   `WSC_REQUIRE_CERT_PINNING` environment variable is exactly `"1"`, so a
///   deployment that demands pinning fails instead of silently using plain TLS.
#[cfg(not(target_os = "wasi"))]
pub fn create_agent_with_optional_pinning(
    pinning: Option<PinningConfig>,
) -> Result<ureq::Agent, WSError> {
    if let Some(enabled) = pinning.filter(|candidate| candidate.is_enabled()) {
        return create_pinned_agent(enabled);
    }

    // Unset or non-UTF-8 values do not count as "required".
    let strict = matches!(
        std::env::var("WSC_REQUIRE_CERT_PINNING").as_deref(),
        Ok("1")
    );
    if strict {
        return Err(WSError::CertificatePinningError(
            "Certificate pinning required (WSC_REQUIRE_CERT_PINNING=1) but no pins configured".to_string(),
        ));
    }

    log::debug!("Certificate pinning disabled, using standard TLS");
    Ok(create_standard_agent())
}
// Gated like the items under test: everything in this file is
// `cfg(not(target_os = "wasi"))`, so a plain `cfg(test)` module would fail to
// compile when running tests on wasi.
#[cfg(all(test, not(target_os = "wasi")))]
mod tests {
    use super::*;

    /// Mirrors the `WSC_REQUIRE_CERT_PINNING` check in
    /// `create_agent_with_optional_pinning`, so tests that hit the
    /// standard-agent fallback stay correct when the variable is set.
    fn pinning_required_by_env() -> bool {
        std::env::var("WSC_REQUIRE_CERT_PINNING").unwrap_or_default() == "1"
    }

    #[test]
    fn test_create_standard_agent() {
        let agent = create_standard_agent();
        assert!(format!("{:?}", agent).contains("Agent"));
    }

    #[test]
    fn test_create_pinned_agent() {
        let pins = vec!["a".repeat(64), "b".repeat(64)];
        let config = PinningConfig::custom(pins, "test-service".to_string());
        let result = create_pinned_agent(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_agent_with_optional_pinning_none() {
        let result = create_agent_with_optional_pinning(None);
        // Without pins, the env override turns the fallback into an error.
        assert_eq!(result.is_ok(), !pinning_required_by_env());
    }

    #[test]
    fn test_create_agent_with_optional_pinning_some() {
        let pins = vec!["a".repeat(64)];
        let config = PinningConfig::custom(pins, "test".to_string());
        let result = create_agent_with_optional_pinning(Some(config));
        assert!(result.is_ok());
    }

    #[test]
    fn test_create_agent_with_empty_pinning_config() {
        let config = PinningConfig::custom(vec![], "test".to_string());
        // An enabled config takes the pinned path (env var irrelevant); a
        // disabled one falls back and is rejected only when pinning is required.
        let expect_ok = config.is_enabled() || !pinning_required_by_env();
        let result = create_agent_with_optional_pinning(Some(config));
        assert_eq!(result.is_ok(), expect_ok);
    }

    #[test]
    fn test_pinned_connector_creation() {
        let pins = vec!["a".repeat(64)];
        let config = PinningConfig::custom(pins, "test".to_string());
        let connector = PinnedRustlsConnector::new(config);
        assert!(connector.is_ok());
        let connector = connector.unwrap();
        assert!(format!("{:?}", connector).contains("PinnedRustlsConnector"));
    }
}