use s2n_tls_sys::*;
use crate::{
callbacks::with_context,
config,
connection::Connection,
enums::CallbackResult,
error::{Error, ErrorType, Fallible, Pollable},
};
use std::task::Poll::{self, Pending, Ready};
/// How the application answers a peer's request to renegotiate.
///
/// Returned from [`RenegotiateCallback::on_renegotiate_request`]; a bare
/// `RenegotiateResponse` can itself be installed as a callback that always
/// gives the same answer.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum RenegotiateResponse {
    /// Ignore the request (s2n-tls receives `RENEGOTIATE_IGNORE`).
    Ignore,
    /// Reject the request (s2n-tls receives `RENEGOTIATE_REJECT`).
    Reject,
    /// Accept the request (s2n-tls receives `RENEGOTIATE_ACCEPT`) without
    /// these bindings scheduling a renegotiation handshake.
    Accept,
    /// Accept the request (s2n-tls receives `RENEGOTIATE_ACCEPT`) and mark the
    /// connection as needing a renegotiation handshake, which the receive path
    /// then drives.
    Schedule,
}
impl From<RenegotiateResponse> for s2n_renegotiate_response::Type {
    /// Translates the binding-level answer into the value handed to s2n-tls.
    ///
    /// `Accept` and `Schedule` are indistinguishable to s2n-tls: both are sent
    /// as `RENEGOTIATE_ACCEPT`. The only difference is that the renegotiate
    /// callback in this crate additionally schedules a handshake for
    /// `Schedule`.
    fn from(input: RenegotiateResponse) -> s2n_renegotiate_response::Type {
        match input {
            RenegotiateResponse::Accept | RenegotiateResponse::Schedule => {
                s2n_renegotiate_response::RENEGOTIATE_ACCEPT
            }
            RenegotiateResponse::Reject => s2n_renegotiate_response::RENEGOTIATE_REJECT,
            RenegotiateResponse::Ignore => s2n_renegotiate_response::RENEGOTIATE_IGNORE,
        }
    }
}
/// Application hook consulted when the peer asks to renegotiate.
pub trait RenegotiateCallback: Send + Sync + 'static {
    /// Chooses how to answer a renegotiation request on `connection`.
    ///
    /// Returning `None` makes the underlying s2n-tls request callback report
    /// failure.
    fn on_renegotiate_request(&self, connection: &mut Connection) -> Option<RenegotiateResponse>;

    /// Invoked by [`Connection::wipe_for_renegotiate`] once the connection has
    /// been wiped and its renegotiate state, waker and server name restored.
    ///
    /// An error returned here is propagated out of `wipe_for_renegotiate`.
    /// The default implementation does nothing and succeeds.
    fn on_renegotiate_wipe(&self, _connection: &mut Connection) -> Result<(), Error> {
        Ok(())
    }
}
/// Lets a plain [`RenegotiateResponse`] serve as a callback that gives the
/// same answer to every renegotiation request.
impl RenegotiateCallback for RenegotiateResponse {
    fn on_renegotiate_request(&self, _conn: &mut Connection) -> Option<RenegotiateResponse> {
        let fixed_answer: RenegotiateResponse = *self;
        Some(fixed_answer)
    }
}
/// Per-connection bookkeeping for a scheduled renegotiation.
///
/// All flags start out `false` (the derived `Default`).
#[derive(Debug, Default, Clone, PartialEq)]
pub(crate) struct RenegotiateState {
    // Set when a renegotiation is scheduled; cleared once the renegotiation
    // handshake completes successfully.
    needs_handshake: bool,
    // Set together with `needs_handshake`; cleared at the end of a successful
    // `wipe_for_renegotiate`.
    needs_wipe: bool,
    // Whether the most recent `s2n_send` issued by `poll_send` was pending.
    send_pending: bool,
}
impl Connection {
fn schedule_renegotiate(&mut self) {
let state = self.renegotiate_state_mut();
if !state.needs_handshake {
state.needs_handshake = true;
state.needs_wipe = true;
}
}
fn is_renegotiating(&self) -> bool {
self.renegotiate_state().needs_handshake
}
pub fn wipe_for_renegotiate(&mut self) -> Result<(), Error> {
if self.renegotiate_state().send_pending {
return Err(Error::bindings(
ErrorType::UsageError,
"RenegotiateError",
"Unexpected buffered send data during renegotiate",
));
}
let renegotiate_state = self.renegotiate_state().clone();
let waker = self.waker().cloned();
let server_name = self.server_name().map(|name| name.to_owned());
self.wipe_method(|conn| unsafe { s2n_renegotiate_wipe(conn.as_ptr()).into_result() })?;
*self.renegotiate_state_mut() = renegotiate_state;
self.set_waker(waker.as_ref())?;
if let Some(server_name) = server_name {
self.set_server_name(&server_name)?;
}
if let Some(config) = self.config() {
if let Some(callback) = config.context().renegotiate.as_ref() {
callback.on_renegotiate_wipe(self)?;
}
}
self.renegotiate_state_mut().needs_wipe = false;
Ok(())
}
fn poll_renegotiate_raw(
&mut self,
buf_ptr: *mut libc::c_void,
buf_len: isize,
) -> (Poll<Result<(), Error>>, usize) {
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
let mut read: isize = 0;
let r = self.poll_negotiate_method(|conn| {
unsafe {
s2n_renegotiate(
conn.as_ptr(),
buf_ptr as *mut u8,
buf_len,
&mut read,
&mut blocked,
)
}
.into_poll()
});
if let Ready(Ok(())) = r {
self.renegotiate_state_mut().needs_handshake = false;
}
(r, read.try_into().unwrap())
}
pub fn poll_renegotiate(&mut self, buf: &mut [u8]) -> (Poll<Result<(), Error>>, usize) {
let buf_len: isize = buf.len().try_into().unwrap_or(0);
let buf_ptr = buf.as_ptr() as *mut ::libc::c_void;
self.poll_renegotiate_raw(buf_ptr, buf_len)
}
pub fn poll_send(&mut self, buf: &[u8]) -> Poll<Result<usize, Error>> {
if self.is_renegotiating() {
return Ready(Err(Error::bindings(
ErrorType::Blocked,
"RenegotiateError",
"Cannot send application data while renegotiating",
)));
}
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
let buf_len: isize = buf.len().try_into().map_err(|_| Error::INVALID_INPUT)?;
let buf_ptr = buf.as_ptr() as *const libc::c_void;
let result = unsafe { s2n_send(self.as_ptr(), buf_ptr, buf_len, &mut blocked) }.into_poll();
self.renegotiate_state_mut().send_pending = result.is_pending();
result
}
pub(crate) fn poll_recv_raw(
&mut self,
buf_ptr: *mut libc::c_void,
buf_len: isize,
) -> Poll<Result<usize, Error>> {
if !self.is_renegotiating() {
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
let result =
unsafe { s2n_recv(self.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() };
return if self.is_renegotiating() && result.is_pending() {
self.poll_recv_raw(buf_ptr, buf_len)
} else {
result
};
}
if self.peek_len() > 0 {
let buf_len = std::cmp::min(self.peek_len() as isize, buf_len);
let mut blocked = s2n_blocked_status::NOT_BLOCKED;
return unsafe { s2n_recv(self.as_ptr(), buf_ptr, buf_len, &mut blocked).into_poll() };
}
if self.renegotiate_state().needs_wipe {
self.wipe_for_renegotiate()?;
}
match self.poll_renegotiate_raw(buf_ptr, buf_len) {
(Ready(Err(err)), _) => Ready(Err(err)),
(Ready(Ok(())), 0) => self.poll_recv_raw(buf_ptr, buf_len),
(Pending, 0) => Pending,
(_, bytes) => Ready(Ok(bytes)),
}
}
}
impl config::Builder {
    /// Installs `handler` to answer renegotiation requests on connections
    /// created from this config.
    ///
    /// The handler is stored in the config context, and an s2n-tls renegotiate
    /// request callback that forwards to it is registered. When the handler
    /// answers [`RenegotiateResponse::Schedule`], the connection is also marked
    /// as needing a renegotiation handshake. If no handler is stored, or it
    /// answers `None`, the s2n-tls callback reports failure.
    ///
    /// # Errors
    ///
    /// Propagates any error from `s2n_config_set_renegotiate_request_cb`.
    pub fn set_renegotiate_callback<T: 'static + RenegotiateCallback>(
        &mut self,
        handler: T,
    ) -> Result<&mut Self, Error> {
        unsafe extern "C" fn renegotiate_cb(
            conn_ptr: *mut s2n_connection,
            _context: *mut libc::c_void,
            response_ptr: *mut s2n_renegotiate_response::Type,
        ) -> libc::c_int {
            with_context(conn_ptr, |conn, context| {
                let answer = context
                    .renegotiate
                    .as_ref()
                    .and_then(|callback| callback.on_renegotiate_request(conn));
                match answer {
                    None => CallbackResult::Failure.into(),
                    Some(answer) => {
                        if answer == RenegotiateResponse::Schedule {
                            conn.schedule_renegotiate();
                        }
                        // SAFETY: NOTE(review) — assumes s2n-tls always passes a
                        // valid out-pointer for the response.
                        *response_ptr = answer.into();
                        CallbackResult::Success.into()
                    }
                }
            })
        }

        // SAFETY: NOTE(review) — relies on `context_mut`'s contract (defined
        // elsewhere); this builder is borrowed mutably here.
        let context = unsafe { self.config.context_mut() };
        context.renegotiate = Some(Box::new(handler));

        // SAFETY: the config pointer belongs to this builder, and the callback
        // is a plain `extern "C"` function that takes no context argument.
        unsafe {
            s2n_config_set_renegotiate_request_cb(
                self.as_mut_ptr(),
                Some(renegotiate_cb),
                std::ptr::null_mut(),
            )
            .into_result()?;
        }
        Ok(self)
    }
}
#[cfg(test)]
mod tests {
    // Renegotiation tests: an s2n-tls client talking to an OpenSSL TLS 1.2
    // server over in-memory IO.
    use super::*;
    use crate::{
        callbacks::{
            ConnectionFuture, ConnectionFutureResult, PrivateKeyCallback, PrivateKeyOperation,
        },
        config::ConnectionInitializer,
        error::{ErrorSource, ErrorType},
        testing::{CertKeyPair, InsecureAcceptAllCertificatesHandler, TestPair, TestPairIO},
    };
    use foreign_types::ForeignTypeRef;
    use futures_test::task::new_count_waker;
    use openssl::ssl::{
        ErrorCode, Ssl, SslContext, SslFiletype, SslMethod, SslStream, SslVerifyMode, SslVersion,
    };
    use std::{
        error::Error,
        io::{Read, Write},
        pin::Pin,
        task::Poll::{Pending, Ready},
    };

    /// Substring shared by the bindings' renegotiation error messages.
    const RENEG_ERR_MARKER: &str = "renegotiat";

    // Declared by hand (presumably not exposed by `openssl-sys`). OpenSSL
    // declares all three as returning a C `int`, so they must be declared as
    // `c_int`. Declaring them as `size_t` reads a wider return register than
    // the callee sets, which leaves the upper bits undefined on common 64-bit
    // ABIs.
    extern "C" {
        fn SSL_renegotiate(s: *mut openssl_sys::SSL) -> libc::c_int;
        fn SSL_renegotiate_pending(s: *mut openssl_sys::SSL) -> libc::c_int;
        fn SSL_in_init(s: *mut openssl_sys::SSL) -> libc::c_int;
    }

    /// Unwraps a `Ready` poll, panicking if it is `Pending`.
    fn unwrap_poll<T>(
        poll: Poll<Result<T, crate::error::Error>>,
    ) -> Result<T, crate::error::Error> {
        if let Ready(value) = poll {
            return value;
        }
        panic!("Poll not Ready");
    }

    /// Server-side view of the test IO: reads what the client sent and writes
    /// to the client. An empty read is reported as `WouldBlock`.
    #[derive(Debug)]
    struct ServerTestStream(TestPairIO);

    impl Read for ServerTestStream {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
            let result = self.0.client_tx_stream.borrow_mut().read(buf);
            if let Ok(0) = result {
                Err(std::io::Error::new(
                    std::io::ErrorKind::WouldBlock,
                    "blocking",
                ))
            } else {
                result
            }
        }
    }

    impl Write for ServerTestStream {
        fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
            self.0.server_tx_stream.borrow_mut().write(buf)
        }

        fn flush(&mut self) -> Result<(), std::io::Error> {
            self.0.server_tx_stream.borrow_mut().flush()
        }
    }

    /// An s2n-tls client connected to an OpenSSL server.
    struct RenegotiateTestPair {
        client: Connection,
        server: SslStream<ServerTestStream>,
    }

    impl RenegotiateTestPair {
        /// Builds the client from `builder` and a TLS 1.2-only OpenSSL server
        /// that shares the same test certificates.
        fn from(mut builder: config::Builder) -> Result<Self, Box<dyn Error>> {
            let certs = CertKeyPair::from_path(
                "permutations/rsae_pkcs_4096_sha384/",
                "server-chain",
                "server-key",
                "ca-cert",
            );

            builder.load_pem(certs.cert(), certs.key())?;
            builder.trust_pem(certs.cert())?;
            builder.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?;
            let config = builder.build()?;
            let s2n_pair = TestPair::from_config(&config);
            let client = s2n_pair.client;

            let mut ctx_builder = SslContext::builder(SslMethod::tls_server())?;
            ctx_builder.set_max_proto_version(Some(SslVersion::TLS1_2))?;
            ctx_builder.set_min_proto_version(Some(SslVersion::TLS1_2))?;
            ctx_builder.set_certificate_chain_file(certs.cert_path())?;
            ctx_builder.set_private_key_file(certs.key_path(), SslFiletype::PEM)?;
            ctx_builder.set_ca_file(certs.ca_path())?;
            ctx_builder.set_verify(SslVerifyMode::PEER);
            let openssl_ctx = ctx_builder.build();
            let openssl_ssl = Ssl::new(&openssl_ctx)?;

            let server_stream = ServerTestStream(s2n_pair.io);
            let server = SslStream::new(openssl_ssl, server_stream)?;

            Ok(Self { client, server })
        }

        /// Advances the OpenSSL server handshake; `WANT_READ` maps to `Pending`.
        fn poll_openssl_negotiate(
            server: &mut SslStream<ServerTestStream>,
        ) -> Poll<Result<(), Box<dyn Error>>> {
            match server.accept() {
                Ok(_) => Ready(Ok(())),
                Err(err) if err.code() == ErrorCode::WANT_READ => Pending,
                Err(err) => Ready(Err(err.into())),
            }
        }

        /// Runs the initial handshake until both peers complete or one fails.
        fn handshake(&mut self) -> Result<(), Box<dyn Error>> {
            loop {
                match (
                    self.client.poll_negotiate(),
                    Self::poll_openssl_negotiate(&mut self.server),
                ) {
                    (Poll::Ready(Ok(_)), Poll::Ready(Ok(_))) => return Ok(()),
                    (_, Poll::Ready(Err(e))) => return Err(e),
                    (Poll::Ready(Err(e)), _) => return Err(Box::new(e)),
                    _ => continue,
                }
            }
        }

        /// Has the OpenSSL server request a renegotiation.
        fn send_renegotiate_request(&mut self) -> Result<(), crate::error::Error> {
            let openssl_ptr = self.server.ssl().as_ptr();

            let requested = unsafe { SSL_renegotiate_pending(openssl_ptr) };
            assert_eq!(requested, 0, "Renegotiation should not be pending");

            let scheduled = unsafe { SSL_renegotiate(openssl_ptr) };
            assert_eq!(scheduled, 1, "SSL_renegotiate failed");

            let requested = unsafe { SSL_renegotiate_pending(openssl_ptr) };
            assert_eq!(requested, 1, "Renegotiation should be pending");

            // Zero-length write used to push the hello request out to the client.
            assert_eq!(
                self.server
                    .write(&[0; 0])
                    .expect("Failed to write hello request"),
                0
            );
            Ok(())
        }

        /// Exchanges one byte of application data in each direction.
        fn send_and_receive(&mut self) -> Result<(), Box<dyn Error>> {
            let to_send = [0; 1];
            let mut recv_buffer = [0; 1];

            self.server.write_all(&to_send)?;
            unwrap_poll(self.client.poll_recv(&mut recv_buffer))?;

            unwrap_poll(self.client.poll_send(&to_send))?;
            self.server.read_exact(&mut recv_buffer)?;

            Ok(())
        }

        /// Whether OpenSSL reports that it is inside a handshake.
        fn openssl_is_handshaking(&self) -> bool {
            unsafe { SSL_in_init(self.server.ssl().as_ptr()) != 0 }
        }

        /// Drives both peers (up to 20 rounds) until the client receives the
        /// server's post-renegotiation data, then asserts that neither side is
        /// still handshaking.
        fn assert_renegotiate(&mut self) -> Result<(), Box<dyn Error>> {
            const APP_DATA: &[u8] = b"Renegotiation complete";
            let mut buffer = [0; APP_DATA.len()];

            for _ in 0..20 {
                let client_read_poll = self.client.poll_recv(&mut buffer);
                match client_read_poll {
                    Pending => {
                        assert!(self.client.is_renegotiating(), "s2n-tls not renegotiating");
                    }
                    Ready(Ok(bytes_read)) => {
                        assert_eq!(bytes_read, APP_DATA.len());
                        assert_eq!(&buffer, APP_DATA);
                        break;
                    }
                    Ready(err) => err.map(|_| ())?,
                };

                if !self.openssl_is_handshaking() {
                    let _ = self.server.read(&mut [0; 0]);
                } else {
                    let server_write_result = self.server.write(APP_DATA);
                    println!(
                        "openssl result: {:?}, state: {:?}",
                        server_write_result,
                        self.server.ssl().state_string_long()
                    );
                    match server_write_result {
                        Ok(bytes_written) => assert_eq!(bytes_written, APP_DATA.len()),
                        Err(_) => {
                            assert!(self.openssl_is_handshaking(), "openssl not renegotiating");
                        }
                    }
                }
            }

            assert!(
                !self.client.is_renegotiating(),
                "s2n-tls renegotiation not complete"
            );
            assert!(
                !self.openssl_is_handshaking(),
                "openssl renegotiation not complete"
            );
            Ok(())
        }
    }

    #[test]
    fn ignore_callback() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Ignore)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        pair.handshake().expect("Initial handshake");
        pair.send_renegotiate_request()
            .expect("Server sends request");
        pair.send_and_receive().expect("Application data");
        assert!(!pair.client.is_renegotiating(), "Unexpected renegotiation");
        Ok(())
    }

    #[test]
    fn accept_callback() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Accept)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        pair.handshake().expect("Initial handshake");
        pair.send_renegotiate_request()
            .expect("Server sends request");
        pair.send_and_receive().expect("Application data");
        assert!(!pair.client.is_renegotiating(), "Unexpected renegotiation");
        Ok(())
    }

    #[test]
    fn error_callback() -> Result<(), Box<dyn Error>> {
        struct ErrorRenegotiateCallback {}
        impl RenegotiateCallback for ErrorRenegotiateCallback {
            fn on_renegotiate_request(&self, _: &mut Connection) -> Option<RenegotiateResponse> {
                None
            }
        }

        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(ErrorRenegotiateCallback {})?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        pair.handshake().expect("Initial handshake");
        pair.send_renegotiate_request()
            .expect("Server sends request");
        let error = unwrap_poll(pair.client.poll_recv(&mut [0; 1])).unwrap_err();
        assert_eq!(error.name(), "S2N_ERR_CANCELLED");
        Ok(())
    }

    #[test]
    fn reject_callback() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Reject)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        pair.handshake().expect("Initial handshake");
        pair.send_renegotiate_request()
            .expect("Server sends request");
        let openssl_error = pair.send_and_receive().unwrap_err();
        assert!(openssl_error.to_string().contains("no renegotiation"));
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_basic() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        pair.handshake().expect("Initial handshake");
        pair.send_and_receive()
            .expect("Application data before renegotiate");
        pair.send_renegotiate_request()
            .expect("Server sends request");
        pair.assert_renegotiate().expect("Renegotiate");
        pair.send_and_receive()
            .expect("Application data after renegotiate");
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_repeatedly() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        pair.handshake().expect("Initial handshake");
        for _ in 0..10 {
            pair.send_and_receive()
                .expect("Application data before renegotiate");
            pair.send_renegotiate_request()
                .expect("Server sends request");
            pair.assert_renegotiate().expect("Renegotiate");
            pair.send_and_receive()
                .expect("Application data after renegotiate");
        }
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_immediate_app_data() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;
        pair.handshake().expect("Initial handshake");

        let server_data = b"server_data";
        pair.send_renegotiate_request()
            .expect("server hello request");
        pair.server
            .write_all(server_data)
            .expect("server app data after hello request");

        let mut buffer = [0; 100];
        let read = unwrap_poll(pair.client.poll_recv(&mut buffer))?;
        assert_eq!(read, server_data.len());
        assert_eq!(&buffer[0..read], server_data);
        assert!(pair.client.is_renegotiating());

        pair.assert_renegotiate().expect("Renegotiate");
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_delayed_app_data() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;
        pair.handshake().expect("Initial handshake");

        pair.send_renegotiate_request()
            .expect("server hello request");
        let mut buffer = [0; 100];
        let poll = pair.client.poll_recv(&mut buffer);
        assert!(poll.is_pending());
        assert!(pair.client.is_renegotiating());

        let server_data = b"server_data";
        pair.server
            .write_all(server_data)
            .expect("server app data after hello request");

        let mut buffer = [0; 100];
        let read = unwrap_poll(pair.client.poll_recv(&mut buffer))?;
        assert_eq!(read, server_data.len());
        assert_eq!(&buffer[0..read], server_data);
        assert!(pair.client.is_renegotiating());

        pair.assert_renegotiate().expect("Renegotiate");
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_without_final_app_data() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;
        pair.handshake().expect("Initial handshake");

        pair.send_renegotiate_request()
            .expect("server hello request");
        assert!(pair.client.poll_recv(&mut [0; 1]).is_pending());
        assert!(pair.client.is_renegotiating());

        loop {
            let _ = pair.server.read(&mut [0; 0]);
            assert!(pair.client.poll_recv(&mut [0; 1]).is_pending());
            if !pair.client.is_renegotiating() {
                break;
            }
        }

        pair.send_and_receive()
            .expect("Application data after renegotiate");
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_buffered_recv() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;
        pair.handshake().expect("Initial handshake");

        pair.send_renegotiate_request()
            .expect("Server sends request");
        let server_data = b"server_data";
        assert_eq!(
            pair.server.write(server_data).expect("server app data"),
            server_data.len()
        );

        let mut buffer = [0; 100];
        let read = unwrap_poll(pair.client.poll_recv(&mut buffer[..1]))
            .expect("Read first byte of server data");
        assert_eq!(read, 1);
        assert_eq!(buffer[0], server_data[0]);
        assert!(pair.client.is_renegotiating());

        let read = unwrap_poll(pair.client.poll_recv(&mut buffer[1..]))
            .expect("Drain buffered receive data");
        assert_eq!(read, server_data.len() - 1);
        assert_eq!(&buffer[..server_data.len()], server_data);
        assert!(pair.client.is_renegotiating());

        pair.assert_renegotiate().expect("Renegotiate");
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_buffered_send() -> Result<(), Box<dyn Error>> {
        unsafe extern "C" fn blocking_send_cb(
            _: *mut libc::c_void,
            _: *const u8,
            _: u32,
        ) -> libc::c_int {
            errno::set_errno(errno::Errno(libc::EWOULDBLOCK));
            -1
        }

        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;
        pair.handshake().expect("Initial handshake");

        let client_data = b"client data";
        pair.client.set_send_callback(Some(blocking_send_cb))?;
        assert!(pair.client.poll_send(client_data).is_pending());
        assert!(pair.client.renegotiate_state().send_pending);

        pair.send_renegotiate_request()
            .expect("Server sends request");
        let error = unwrap_poll(pair.client.poll_recv(&mut [0; 1])).unwrap_err();
        assert_eq!(error.kind(), ErrorType::UsageError);
        assert!(error.message().contains(RENEG_ERR_MARKER));
        assert!(error.message().contains("buffered send data"));
        assert!(pair.client.is_renegotiating());
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_poll_send() -> Result<(), Box<dyn Error>> {
        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        let mut pair = RenegotiateTestPair::from(builder)?;
        pair.handshake().expect("Initial handshake");

        pair.send_renegotiate_request()
            .expect("server HELLO_REQUEST");
        assert!(pair.client.poll_recv(&mut [0; 1]).is_pending());
        assert!(pair.client.is_renegotiating());

        let error = unwrap_poll(pair.client.poll_send(&[0; 1])).unwrap_err();
        assert_eq!(error.kind(), ErrorType::Blocked);
        assert!(error.message().contains(RENEG_ERR_MARKER));
        assert!(error.message().contains("send application data"));
        assert!(pair.client.is_renegotiating());
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_async_callback() -> Result<(), Box<dyn Error>> {
        struct TestAsyncCallback {
            count: usize,
            op: Option<PrivateKeyOperation>,
        }

        impl PrivateKeyCallback for TestAsyncCallback {
            fn handle_operation(
                &self,
                _: &mut Connection,
                operation: PrivateKeyOperation,
            ) -> ConnectionFutureResult {
                Ok(Some(Box::pin(TestAsyncCallback {
                    count: self.count,
                    op: Some(operation),
                })))
            }
        }

        impl ConnectionFuture for TestAsyncCallback {
            fn poll(
                self: Pin<&mut Self>,
                conn: &mut Connection,
                ctx: &mut core::task::Context,
            ) -> Poll<Result<(), crate::error::Error>> {
                ctx.waker().wake_by_ref();
                let this = self.get_mut();
                if this.count > 1 {
                    this.count -= 1;
                    Pending
                } else {
                    let op = this.op.take().unwrap();
                    let opt_ptr = op.as_ptr();
                    let chain_ptr = conn.selected_cert().unwrap().as_ptr();
                    unsafe {
                        let key_ptr = s2n_cert_chain_and_key_get_private_key(chain_ptr as *mut _)
                            .into_result()?
                            .as_ptr();
                        s2n_async_pkey_op_perform(opt_ptr, key_ptr).into_result()?;
                        s2n_async_pkey_op_apply(opt_ptr, conn.as_ptr()).into_result()?;
                    }
                    Ready(Ok(()))
                }
            }
        }

        let count_per_handshake = 10;
        let async_callback = TestAsyncCallback {
            count: count_per_handshake,
            op: None,
        };

        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        builder.set_private_key_callback(async_callback)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        let (waker, wake_count) = new_count_waker();
        pair.client.set_waker(Some(&waker))?;
        pair.handshake().expect("Initial handshake");
        assert_eq!(wake_count, count_per_handshake);

        pair.send_renegotiate_request()
            .expect("Server sends request");
        pair.assert_renegotiate()?;
        assert_eq!(wake_count, count_per_handshake * 2);
        Ok(())
    }

    #[test]
    fn scheduled_renegotiate_with_async_init() -> Result<(), Box<dyn Error>> {
        #[derive(Clone)]
        struct TestInitializer {
            count: usize,
            context: String,
        }

        impl ConnectionInitializer for TestInitializer {
            fn initialize_connection(
                &self,
                _: &mut crate::connection::Connection,
            ) -> ConnectionFutureResult {
                Ok(Some(Box::pin(self.clone())))
            }
        }

        impl ConnectionFuture for TestInitializer {
            fn poll(
                self: Pin<&mut Self>,
                conn: &mut Connection,
                ctx: &mut core::task::Context,
            ) -> Poll<Result<(), crate::error::Error>> {
                ctx.waker().wake_by_ref();
                let this = self.get_mut();
                assert!(conn.application_context::<String>().is_none());
                if this.count > 1 {
                    this.count -= 1;
                    Pending
                } else {
                    conn.set_application_context(this.context.clone());
                    Ready(Ok(()))
                }
            }
        }

        let count_per_handshake = 10;
        let expected_context = "helloworld".to_owned();
        let initializer = TestInitializer {
            count: count_per_handshake,
            context: expected_context.clone(),
        };

        let mut builder = config::Builder::new();
        builder.set_renegotiate_callback(RenegotiateResponse::Schedule)?;
        builder.set_connection_initializer(initializer)?;
        let mut pair = RenegotiateTestPair::from(builder)?;

        let (waker, wake_count) = new_count_waker();
        pair.client.set_waker(Some(&waker))?;
        pair.handshake().expect("Initial handshake");
        assert_eq!(wake_count, count_per_handshake);

        pair.send_renegotiate_request()
            .expect("Server sends request");
        pair.assert_renegotiate()?;
        assert_eq!(wake_count, count_per_handshake * 2);

        let context: Option<&String> = pair.client.application_context();
        assert_eq!(Some(&expected_context), context);
        Ok(())
    }

    #[test]
    fn wipe_for_renegotiate_failure() -> Result<(), Box<dyn Error>> {
        let mut connection = Connection::new_server();
        // Servers cannot renegotiate, so the wipe itself is rejected by s2n-tls.
        let error = connection.wipe_for_renegotiate().unwrap_err();
        assert_eq!(error.source(), ErrorSource::Library);
        assert_eq!(error.name(), "S2N_ERR_NO_RENEGOTIATION");
        Ok(())
    }
}