use std::fmt;
use std::sync::{Arc, Mutex};
use rustls::client::{ClientSessionStore, Resumption};
use rustls::client::{Tls12ClientSessionValue, Tls13ClientSessionValue};
use rustls::NamedGroup;
use rustls_pki_types::ServerName;
/// The result of [`TlsAwareClient::execute`]: the HTTP response plus whatever
/// TLS handshake details were observed while the request was in flight.
///
/// Both TLS fields can be `None`, for example when reqwest reuses a pooled
/// connection and no new handshake takes place during the call.
#[derive(Debug)]
pub struct TlsResponse {
    /// The underlying HTTP response.
    pub response: reqwest::Response,
    /// `Debug` rendering of the key-exchange group reported via
    /// `set_kx_hint` (e.g. `X25519`), if one was captured.
    pub group: Option<String>,
    /// `Debug` rendering of the TLS 1.3 cipher suite taken from a session
    /// ticket, if one was received while the request was in flight.
    pub cipher: Option<String>,
}
/// Handshake details recorded by [`CapturingSessionStore`] for a single
/// [`TlsAwareClient::execute`] call. Starts empty via `Default`.
#[derive(Default, Debug)]
struct Captured {
    /// `Debug` rendering of the `NamedGroup` passed to `set_kx_hint`; a later
    /// call overwrites an earlier one.
    group: Option<String>,
    /// `Debug` rendering of the cipher suite of the first TLS 1.3 ticket
    /// inserted while this capture was active; later tickets are ignored.
    cipher: Option<String>,
}
/// A rustls `ClientSessionStore` that stores nothing for resumption. Instead
/// it writes negotiated handshake parameters into whichever [`Captured`] slot
/// is currently installed in `active`.
struct CapturingSessionStore {
    /// Shared with `TlsAwareClient::active_capture`. While a request is being
    /// captured it holds that request's [`Captured`]; otherwise it is `None`.
    active: Arc<Mutex<Option<Arc<Mutex<Captured>>>>>,
}
impl fmt::Debug for CapturingSessionStore {
    /// Renders just the type name; the shared capture slot is not printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("CapturingSessionStore")
    }
}
impl ClientSessionStore for CapturingSessionStore {
fn set_kx_hint(&self, _server_name: ServerName<'static>, group: NamedGroup) {
let slot = self.active.lock().unwrap().clone();
if let Some(captured) = slot {
captured.lock().unwrap().group = Some(format!("{:?}", group));
}
}
fn kx_hint(&self, _server_name: &ServerName<'_>) -> Option<NamedGroup> {
None
}
fn set_tls12_session(&self, _server_name: ServerName<'static>, _value: Tls12ClientSessionValue) {}
fn tls12_session(&self, _server_name: &ServerName<'_>) -> Option<Tls12ClientSessionValue> {
None
}
fn remove_tls12_session(&self, _server_name: &ServerName<'static>) {}
fn insert_tls13_ticket(&self, _server_name: ServerName<'static>, value: Tls13ClientSessionValue) {
let slot = self.active.lock().unwrap().clone();
if let Some(captured) = slot {
let mut c = captured.lock().unwrap();
if c.cipher.is_none() {
c.cipher = Some(format!("{:?}", value.suite().common.suite));
}
}
}
fn take_tls13_ticket(&self, _server_name: &ServerName<'static>) -> Option<Tls13ClientSessionValue> {
None
}
}
/// An HTTP client that reports the TLS key-exchange group and cipher suite
/// observed while each request is executed.
#[derive(Debug)]
pub struct TlsAwareClient {
    /// The wrapped reqwest client, configured with a [`CapturingSessionStore`]
    /// that shares `active_capture`.
    client: reqwest::Client,
    /// Slot shared with the session store. It holds the in-flight request's
    /// [`Captured`] during `execute` and is `None` otherwise.
    active_capture: Arc<Mutex<Option<Arc<Mutex<Captured>>>>>,
}
impl TlsAwareClient {
pub fn new() -> Self {
let active_capture: Arc<Mutex<Option<Arc<Mutex<Captured>>>>> =
Arc::new(Mutex::new(None));
let session_store = Arc::new(CapturingSessionStore {
active: active_capture.clone(),
});
let mut root_store = rustls::RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
tls_config.resumption = Resumption::store(session_store);
let client = reqwest::Client::builder()
.use_preconfigured_tls(tls_config)
.build()
.expect("failed to build reqwest client");
Self { client, active_capture }
}
pub async fn execute(&self, request: reqwest::Request) -> Result<TlsResponse, reqwest::Error> {
let captured = Arc::new(Mutex::new(Captured::default()));
{
let mut active = self.active_capture.lock().unwrap();
*active = Some(captured.clone());
}
let response = self.client.execute(request).await?;
{
let mut active = self.active_capture.lock().unwrap();
*active = None;
}
let state = captured.lock().unwrap();
Ok(TlsResponse {
response,
group: state.group.clone(),
cipher: state.cipher.clone(),
})
}
}