#[macro_use]
extern crate log;
use std::cmp;
use std::mem;
use std::time;
use std::collections::hash_map;
use std::collections::HashMap;
/// QUIC wire version implemented by this crate (draft versions are encoded
/// as `0xff0000NN`, where `NN` is the draft number).
pub const VERSION_DRAFT17: u32 = 0xff00_0000 | 17;

/// Maximum connection ID length, in bytes.
pub const MAX_CONN_ID_LEN: usize = 18;

/// Client Initial packets are padded up to at least this many bytes.
const CLIENT_INITIAL_MIN_LEN: usize = 1_200;

/// Packets with a smaller payload are padded up to this many bytes.
const PAYLOAD_MIN_LEN: usize = 4;

/// How long a closing/closed connection stays in the draining state.
const DRAINING_TIMEOUT: time::Duration = time::Duration::from_millis(200);
/// Result type used throughout this crate, with [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;

/// Errors returned by the QUIC implementation.
///
/// `Connection::to_wire`-style conversion to transport error codes and to C
/// error codes is done by the `impl Error` block in this file.
#[derive(Clone, Debug, PartialEq)]
pub enum Error {
    /// There is no more work to do (e.g. `send` has nothing to send, or the
    /// connection is draining/closing).
    Done,
    /// The provided buffer is empty or too short for the data to be
    /// read/written.
    BufferTooShort,
    /// The packet's version does not match the negotiated one, or version
    /// negotiation offered no supported version.
    UnknownVersion,
    /// A frame could not be processed (not raised directly in this file;
    /// presumably by the frame parser — confirm).
    InvalidFrame,
    /// The packet is malformed or not acceptable in the current context.
    InvalidPacket,
    /// The operation is not valid in the connection's current state (e.g. no
    /// packet protection keys are available for sending).
    InvalidState,
    /// The operation is not valid for the stream's type or state.
    InvalidStreamState,
    /// The peer's transport parameters are invalid.
    InvalidTransportParam,
    /// A cryptographic operation failed (not raised directly in this file;
    /// presumably by the crypto module — confirm).
    CryptoFail,
    /// The TLS layer reported an error.
    TlsFail,
    /// The peer exceeded the connection-level flow-control limit.
    FlowControl,
    /// A stream could not be opened because of the stream-count limit.
    StreamLimit,
}
impl Error {
pub fn to_wire(&self) -> u16 {
match self {
Error::Done => 0x0,
Error::InvalidFrame => 0x7,
Error::InvalidStreamState => 0x5,
Error::InvalidTransportParam => 0x8,
Error::CryptoFail => 0x100,
Error::TlsFail => 0x100,
Error::FlowControl => 0x3,
Error::StreamLimit => 0x4,
_ => 0xa,
}
}
fn to_c(&self) -> libc::ssize_t {
match self {
Error::Done => -1,
Error::BufferTooShort => -2,
Error::UnknownVersion => -3,
Error::InvalidFrame => -4,
Error::InvalidPacket => -5,
Error::InvalidState => -6,
Error::InvalidStreamState => -7,
Error::InvalidTransportParam => -8,
Error::CryptoFail => -9,
Error::TlsFail => -10,
Error::FlowControl => -11,
Error::StreamLimit => -12,
}
}
}
/// Configuration shared by the connections created from it.
pub struct Config {
    /// Transport parameters copied into every new connection.
    local_transport_params: TransportParams,

    /// QUIC version new connections start with.
    version: u32,

    /// TLS context from which each connection's handshake is created.
    tls_ctx: tls::Context,

    /// ALPN protocols, as set by `set_application_protos`.
    application_protos: Vec<Vec<u8>>,
}
impl Config {
#[allow(clippy::new_ret_no_self)]
pub fn new(version: u32) -> Result<Config> {
let tls_ctx = tls::Context::new().map_err(|_| Error::TlsFail)?;
Ok(Config {
local_transport_params: TransportParams::default(),
version,
tls_ctx,
application_protos: Vec::new(),
})
}
pub fn load_cert_chain_from_pem_file(&mut self, file: &str) -> Result<()> {
self.tls_ctx.use_certificate_chain_file(file)
.map_err(|_| Error::TlsFail)
}
pub fn load_priv_key_from_pem_file(&mut self, file: &str) -> Result<()> {
self.tls_ctx.use_privkey_file(file)
.map_err(|_| Error::TlsFail)
}
pub fn verify_peer(&mut self, verify: bool) {
self.tls_ctx.set_verify(verify);
}
pub fn log_keys(&mut self) {
self.tls_ctx.enable_keylog();
}
pub fn set_application_protos(&mut self, protos: &[&[u8]]) ->Result<()> {
self.application_protos = protos.iter().map(|p| p.to_vec()).collect();
self.tls_ctx.set_alpn(&self.application_protos)
.map_err(|_| Error::TlsFail)
}
pub fn set_idle_timeout(&mut self, v: u64) {
self.local_transport_params.idle_timeout = v;
}
pub fn set_stateless_reset_token(&mut self, v: &[u8; 16]) {
self.local_transport_params.stateless_reset_token = Some(v.to_vec());
}
pub fn set_max_packet_size(&mut self, v: u64) {
self.local_transport_params.max_packet_size = v;
}
pub fn set_initial_max_data(&mut self, v: u64) {
self.local_transport_params.initial_max_data = v;
}
pub fn set_initial_max_stream_data_bidi_local(&mut self, v: u64) {
self.local_transport_params.initial_max_stream_data_bidi_local = v;
}
pub fn set_initial_max_stream_data_bidi_remote(&mut self, v: u64) {
self.local_transport_params.initial_max_stream_data_bidi_remote = v;
}
pub fn set_initial_max_stream_data_uni(&mut self, v: u64) {
self.local_transport_params.initial_max_stream_data_uni = v;
}
pub fn set_initial_max_streams_bidi(&mut self, v: u64) {
self.local_transport_params.initial_max_streams_bidi = v;
}
pub fn set_initial_max_streams_uni(&mut self, v: u64) {
self.local_transport_params.initial_max_streams_uni = v;
}
pub fn set_ack_delay_exponent(&mut self, v: u64) {
self.local_transport_params.ack_delay_exponent = v;
}
pub fn set_max_ack_delay(&mut self, v: u64) {
self.local_transport_params.max_ack_delay = v;
}
pub fn set_disable_migration(&mut self, v: bool) {
self.local_transport_params.disable_migration = v;
}
}
/// A QUIC connection.
pub struct Connection {
    /// QUIC version written into outgoing headers; may be replaced during
    /// version negotiation.
    version: u32,

    /// Destination connection ID (the peer's). Random on a new client, then
    /// replaced by the ID chosen by the server (first packet or Retry).
    dcid: Vec<u8>,

    /// Source (local) connection ID.
    scid: Vec<u8>,

    /// Hex-encoded `scid`, used as a prefix in trace logs.
    trace_id: String,

    /// Initial packet number space.
    initial: packet::PktNumSpace,

    /// Handshake packet number space.
    handshake: packet::PktNumSpace,

    /// Application (1-RTT) packet number space.
    application: packet::PktNumSpace,

    /// Transport parameters received from the peer (defaults until the
    /// handshake completes).
    peer_transport_params: TransportParams,

    /// Transport parameters advertised to the peer.
    local_transport_params: TransportParams,

    /// TLS handshake state.
    tls_state: tls::Handshake,

    /// Loss detection and congestion control state.
    recovery: recovery::Recovery,

    /// ALPN protocols copied from the `Config`.
    application_protos: Vec<Vec<u8>>,

    /// Number of packets sent.
    sent_count: usize,

    /// Lost count accumulated from the packet number spaces' flights.
    lost_count: usize,

    /// Total bytes of stream data received.
    rx_data: usize,

    /// Connection-level receive limit currently advertised to the peer.
    max_rx_data: usize,

    /// Receive limit to advertise in the next MAX_DATA frame.
    new_max_rx_data: usize,

    /// Bytes of stream data currently counted as sent (lost data is
    /// subtracted again).
    tx_data: usize,

    /// Connection-level send limit granted by the peer.
    max_tx_data: usize,

    /// Streams, keyed by stream ID.
    streams: HashMap<u64, stream::Stream>,

    /// Bidirectional stream limit advertised to the peer.
    local_max_streams_bidi: usize,

    /// Unidirectional stream limit advertised to the peer.
    local_max_streams_uni: usize,

    /// Bidirectional stream limit granted by the peer.
    peer_max_streams_bidi: usize,

    /// Unidirectional stream limit granted by the peer.
    peer_max_streams_uni: usize,

    /// Original destination connection ID, recorded by a client after a
    /// Retry and checked against the peer's `original_connection_id`.
    odcid: Option<Vec<u8>>,

    /// Token received in a Retry packet, echoed in outgoing headers.
    token: Option<Vec<u8>>,

    /// Transport error code to send in CONNECTION_CLOSE.
    error: Option<u16>,

    /// Application error code to send in APPLICATION_CLOSE.
    app_error: Option<u16>,

    /// Reason phrase for APPLICATION_CLOSE.
    app_reason: Vec<u8>,

    /// PATH_CHALLENGE data waiting to be answered with PATH_RESPONSE.
    challenge: Option<Vec<u8>>,

    /// Deadline after which the connection is considered idle.
    idle_timer: Option<time::Instant>,

    /// Deadline at which the draining period ends.
    draining_timer: Option<time::Instant>,

    /// Whether this is the server side of the connection.
    is_server: bool,

    /// Whether Initial packet protection keys have been derived.
    derived_initial_secrets: bool,

    /// Whether a Version Negotiation packet has been processed.
    did_version_negotiation: bool,

    /// Whether a Retry packet has been processed.
    did_retry: bool,

    /// Whether the client has adopted the server-chosen connection ID.
    got_peer_conn_id: bool,

    /// Whether the TLS handshake has completed.
    handshake_completed: bool,

    /// Whether the connection is draining (a close frame was sent or
    /// received).
    draining: bool,

    /// Whether the connection is closed (draining or idle timeout expired).
    closed: bool,
}
/// Creates a server-side connection with local connection ID `scid`.
///
/// `odcid` is the original destination connection ID, if the client was
/// previously sent a Retry.
pub fn accept(scid: &[u8], odcid: Option<&[u8]>, config: &mut Config) -> Result<Box<Connection>> {
    Connection::new(scid, odcid, config, true)
}
pub fn connect(server_name: Option<&str>, scid: &[u8], config: &mut Config)
-> Result<Box<Connection>> {
let conn = Connection::new(scid, None, config, false)?;
if server_name.is_some() {
conn.tls_state.set_host_name(server_name.unwrap())
.map_err(|_| Error::TlsFail)?;
}
Ok(conn)
}
/// Writes a Version Negotiation packet into `out`, returning the number of
/// bytes written.
pub fn negotiate_version(scid: &[u8], dcid: &[u8], out: &mut [u8]) -> Result<usize> {
    let written = packet::negotiate_version(scid, dcid, out)?;

    Ok(written)
}
/// Writes a Retry packet into `out`, carrying `token` and the new source
/// connection ID `new_scid`, and returns the number of bytes written.
pub fn retry(scid: &[u8], dcid: &[u8], new_scid: &[u8], token: &[u8], out: &mut [u8]) -> Result<usize> {
    let written = packet::retry(scid, dcid, new_scid, token, out)?;

    Ok(written)
}
impl Connection {
    /// Creates a new connection with a fresh TLS handshake taken from
    /// `config`'s TLS context.
    ///
    /// Fails with `Error::TlsFail` if the handshake cannot be created.
    #[allow(clippy::new_ret_no_self)]
    fn new(scid: &[u8], odcid: Option<&[u8]>, config: &mut Config,
           is_server: bool) -> Result<Box<Connection>> {
        let tls = config.tls_ctx.new_handshake().map_err(|_| Error::TlsFail)?;

        Connection::with_tls(scid, odcid, config, tls, is_server)
    }

    /// Creates a new connection that uses the given TLS handshake.
    ///
    /// `scid` becomes the local connection ID; its hex encoding is the trace
    /// ID. If `odcid` is given, it is stored as the local
    /// `original_connection_id` transport parameter. On the client, a random
    /// 16-byte destination connection ID is generated, and the Initial keys
    /// are derived from it right away.
    #[doc(hidden)]
    pub fn with_tls(scid: &[u8], odcid: Option<&[u8]>, config: &mut Config,
                    tls: tls::Handshake, is_server: bool)
                                                -> Result<Box<Connection>> {
        let max_rx_data = config.local_transport_params.initial_max_data;

        let scid_as_hex: Vec<String> = scid.iter()
                                           .map(|b| format!("{:02x}", b))
                                           .collect();

        let mut conn = Box::new(Connection {
            version: config.version,

            dcid: Vec::new(),
            scid: scid.to_vec(),

            trace_id: scid_as_hex.join(""),

            initial: packet::PktNumSpace::new(crypto::Level::Initial),
            handshake: packet::PktNumSpace::new(crypto::Level::Handshake),
            application: packet::PktNumSpace::new(crypto::Level::Application),

            peer_transport_params: TransportParams::default(),

            local_transport_params: config.local_transport_params.clone(),

            tls_state: tls,

            recovery: recovery::Recovery::default(),

            application_protos: config.application_protos.clone(),

            sent_count: 0,
            lost_count: 0,

            rx_data: 0,
            max_rx_data: max_rx_data as usize,
            new_max_rx_data: max_rx_data as usize,

            tx_data: 0,
            max_tx_data: 0,

            streams: HashMap::new(),

            local_max_streams_bidi:
                config.local_transport_params.initial_max_streams_bidi as usize,
            local_max_streams_uni:
                config.local_transport_params.initial_max_streams_uni as usize,

            peer_max_streams_bidi: 0,
            peer_max_streams_uni: 0,

            odcid: None,

            token: None,

            error: None,

            app_error: None,
            app_reason: Vec::new(),

            challenge: None,

            idle_timer: None,

            draining_timer: None,

            is_server,

            derived_initial_secrets: false,

            did_version_negotiation: false,

            did_retry: false,

            got_peer_conn_id: false,

            handshake_completed: false,

            draining: false,

            closed: false,
        });

        if let Some(odcid) = odcid {
            conn.local_transport_params.original_connection_id =
                Some(odcid.to_vec());
        }

        conn.tls_state.init(&conn).map_err(|_| Error::TlsFail)?;

        // The client picks a random destination connection ID and derives
        // its Initial keys from it immediately.
        if !is_server {
            let mut dcid: [u8; 16] = [0; 16];
            rand::rand_bytes(&mut dcid[..]);

            let (aead_open, aead_seal) =
                crypto::derive_initial_key_material(&dcid, conn.is_server)?;

            conn.dcid.extend_from_slice(&dcid);

            conn.initial.crypto_open = Some(aead_open);
            conn.initial.crypto_seal = Some(aead_seal);

            conn.derived_initial_secrets = true;
        }

        Ok(conn)
    }

    /// Processes every packet in `buf` (which may contain several coalesced
    /// packets), returning the number of bytes processed.
    ///
    /// Stops at and returns the first error reported for any packet.
    pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
        let len = buf.len();

        let mut done = 0;
        let mut left = len;

        while left > 0 {
            let read = self.recv_single(&mut buf[len - left..len])?;

            done += read;
            left -= read;
        }

        Ok(done)
    }

    /// Processes the single packet at the start of `buf`, returning its
    /// length.
    ///
    /// Version Negotiation and Retry packets reset the client's handshake
    /// and return `Error::Done`, as do duplicates and packets received while
    /// closing or draining.
    fn recv_single(&mut self, buf: &mut [u8]) -> Result<usize> {
        let now = time::Instant::now();

        if buf.is_empty() {
            return Err(Error::BufferTooShort);
        }

        if self.draining {
            return Err(Error::Done);
        }

        self.do_handshake()?;

        let is_closing = self.error.is_some() || self.app_error.is_some();

        if is_closing {
            return Err(Error::Done);
        }

        let mut b = octets::Octets::with_slice(buf);

        let mut hdr = Header::from_bytes(&mut b, self.scid.len())?;

        if hdr.ty == packet::Type::VersionNegotiation {
            // Only clients act on Version Negotiation, and only once.
            if self.is_server {
                return Err(Error::Done);
            }

            if self.did_version_negotiation {
                return Err(Error::Done);
            }

            // Ignore packets that don't echo our connection IDs.
            if hdr.dcid != self.scid {
                return Err(Error::Done);
            }

            if hdr.scid != self.dcid {
                return Err(Error::Done);
            }

            trace!("{} rx pkt {:?}", self.trace_id, hdr);

            let versions = match hdr.versions {
                Some(ref v) => v,
                None => return Err(Error::InvalidPacket),
            };

            let mut new_version = 0;
            for v in versions.iter() {
                if *v == VERSION_DRAFT17 {
                    new_version = *v;
                }
            }

            if new_version == 0 {
                return Err(Error::UnknownVersion);
            }

            self.version = new_version;
            self.did_version_negotiation = true;

            // Restart the handshake from scratch with the new version.
            self.got_peer_conn_id = false;
            self.recovery.drop_unacked_data(&mut self.initial.flight);
            self.initial.clear();
            self.tls_state.clear()
                          .map_err(|_| Error::TlsFail)?;

            return Err(Error::Done);
        }

        if hdr.ty == packet::Type::Retry {
            // Only clients act on Retry, and only once.
            if self.is_server {
                return Err(Error::Done);
            }

            if self.did_retry {
                return Err(Error::Done);
            }

            // The Retry must reference the destination ID we used.
            if hdr.odcid.as_ref() != Some(&self.dcid) {
                return Err(Error::Done);
            }

            trace!("{} rx pkt {:?}", self.trace_id, hdr);

            self.token = hdr.token;
            self.did_retry = true;

            // Remember the original destination ID, to check it against the
            // server's transport parameters later.
            self.odcid = Some(self.dcid.clone());

            self.dcid.resize(hdr.scid.len(), 0);
            self.dcid.copy_from_slice(&hdr.scid);

            // Initial keys are derived from the server-chosen ID now.
            let (aead_open, aead_seal) =
                crypto::derive_initial_key_material(&hdr.scid, self.is_server)?;

            self.initial.crypto_open = Some(aead_open);
            self.initial.crypto_seal = Some(aead_seal);

            // Restart the handshake.
            self.got_peer_conn_id = false;
            self.recovery.drop_unacked_data(&mut self.initial.flight);
            self.initial.clear();
            self.tls_state.clear()
                          .map_err(|_| Error::TlsFail)?;

            return Err(Error::Done);
        }

        if hdr.ty != packet::Type::Application && hdr.version != self.version {
            return Err(Error::UnknownVersion);
        }

        // Short header packets extend to the end of the buffer; long header
        // packets carry an explicit length.
        let payload_len = if hdr.ty == packet::Type::Application {
            b.cap()
        } else {
            b.get_varint()? as usize
        };

        if b.cap() < payload_len {
            return Err(Error::BufferTooShort);
        }

        // The client adopts the connection ID chosen by the server.
        if !self.is_server && !self.got_peer_conn_id {
            self.dcid.resize(hdr.scid.len(), 0);
            self.dcid.copy_from_slice(&hdr.scid);

            self.got_peer_conn_id = true;
        }

        // The server derives its Initial keys from the client's first packet.
        if !self.derived_initial_secrets {
            let (aead_open, aead_seal) =
                crypto::derive_initial_key_material(&hdr.dcid, self.is_server)?;

            self.initial.crypto_open = Some(aead_open);
            self.initial.crypto_seal = Some(aead_seal);

            self.derived_initial_secrets = true;

            self.dcid.extend_from_slice(&hdr.scid);
            self.got_peer_conn_id = true;
        }

        let space = match hdr.ty {
            packet::Type::Initial => &mut self.initial,

            packet::Type::Handshake => &mut self.handshake,

            packet::Type::Application => &mut self.application,

            _ => return Err(Error::InvalidPacket),
        };

        // Packets for which no keys are available yet are skipped.
        let aead = match space.crypto_open {
            Some(ref v) => v,

            None => {
                trace!("{} dropped undecryptable packet type={:?} len={}",
                       self.trace_id, hdr.ty, payload_len);

                return Ok(b.off() + payload_len)
            },
        };

        packet::decrypt_hdr(&mut b, &mut hdr, &aead)?;

        let pn = packet::decode_pkt_num(space.largest_rx_pkt_num,
                                        hdr.pkt_num, hdr.pkt_num_len);

        trace!("{} rx pkt {:?} len={} pn={}", self.trace_id, hdr,
               payload_len, pn);

        let mut payload = packet::decrypt_pkt(&mut b, pn, hdr.pkt_num_len,
                                              payload_len, &aead)?;

        if space.recv_pkt_num.contains(pn) {
            trace!("{} ignored duplicate packet {}", self.trace_id, pn);
            return Err(Error::Done);
        }

        let mut do_ack = false;

        while payload.cap() > 0 {
            let frame = frame::Frame::from_bytes(&mut payload, hdr.ty)?;

            trace!("{} rx frm {:?}", self.trace_id, frame);

            match frame {
                frame::Frame::Padding { .. } => (),

                frame::Frame::Ping => {
                    do_ack = true;
                },

                frame::Frame::ACK { ranges, ack_delay } => {
                    // The exponent comes from the peer, so scale without
                    // risking an arithmetic overflow.
                    let ack_delay = Connection::scale_ack_delay(
                        ack_delay, self.peer_transport_params.ack_delay_exponent);

                    self.recovery.on_ack_received(&ranges, ack_delay,
                                                  &mut space.flight,
                                                  now, &self.trace_id);
                },

                frame::Frame::StopSending { stream_id, .. } => {
                    // STOP_SENDING is invalid on a receive-only stream.
                    if !stream::is_local(stream_id, self.is_server) &&
                       !stream::is_bidi(stream_id) {
                        return Err(Error::InvalidPacket);
                    }

                    do_ack = true;
                },

                frame::Frame::Crypto { data } => {
                    space.crypto_stream.recv_push(data)?;

                    // Hand any contiguous crypto data to TLS.
                    if space.crypto_stream.readable() {
                        let buf = space.crypto_stream.recv_pop(std::usize::MAX)?;
                        let level = space.crypto_level;

                        self.tls_state.provide_data(level, &buf)
                                      .map_err(|_| Error::TlsFail)?;
                    }

                    do_ack = true;
                },

                frame::Frame::NewToken { .. } => {
                    do_ack = true;
                },

                frame::Frame::Stream { stream_id, data } => {
                    // Peers can't send on our unidirectional streams.
                    if !stream::is_bidi(stream_id) &&
                       stream::is_local(stream_id, self.is_server) {
                        return Err(Error::InvalidStreamState);
                    }

                    let stream = Connection::get_or_create_peer_stream(
                        &mut self.streams, stream_id, self.is_server,
                        &self.local_transport_params,
                        &self.peer_transport_params,
                        self.local_max_streams_bidi,
                        self.local_max_streams_uni)?;

                    self.rx_data += data.len();

                    if self.rx_data > self.max_rx_data {
                        return Err(Error::FlowControl);
                    }

                    stream.recv_push(data)?;

                    do_ack = true;
                },

                frame::Frame::MaxData { max } => {
                    self.max_tx_data = cmp::max(self.max_tx_data,
                                                max as usize);

                    do_ack = true;
                },

                frame::Frame::MaxStreamData { stream_id, max } => {
                    let stream = Connection::get_or_create_peer_stream(
                        &mut self.streams, stream_id, self.is_server,
                        &self.local_transport_params,
                        &self.peer_transport_params,
                        self.local_max_streams_bidi,
                        self.local_max_streams_uni)?;

                    stream.send_max_data(max as usize);

                    do_ack = true;
                },

                frame::Frame::MaxStreamsBidi { max } => {
                    self.peer_max_streams_bidi =
                        cmp::max(self.peer_max_streams_bidi, max as usize);

                    do_ack = true;
                },

                frame::Frame::MaxStreamsUni { max } => {
                    self.peer_max_streams_uni =
                        cmp::max(self.peer_max_streams_uni, max as usize);

                    do_ack = true;
                },

                frame::Frame::NewConnectionId { .. } => {
                    do_ack = true;
                },

                frame::Frame::RetireConnectionId { .. } => {
                    do_ack = true;
                },

                frame::Frame::PathChallenge { data } => {
                    // Answered with a PATH_RESPONSE by the next `send`.
                    self.challenge = Some(data);

                    do_ack = true;
                },

                frame::Frame::PathResponse { .. } => {
                    do_ack = true;
                },

                frame::Frame::ConnectionClose { .. } => {
                    self.draining = true;
                    self.draining_timer = Some(now + DRAINING_TIMEOUT);
                },

                frame::Frame::ApplicationClose { .. } => {
                    self.draining = true;
                    self.draining_timer = Some(now + DRAINING_TIMEOUT);
                },
            }
        }

        // Once one of our ACK frames is acknowledged, the packets it covered
        // no longer need to be acknowledged again.
        for acked in space.flight.acked.drain(..) {
            match acked {
                frame::Frame::ACK { ranges, .. } => {
                    let largest_acked = ranges.largest().unwrap();
                    space.recv_pkt_need_ack.remove_until(largest_acked);
                },

                frame::Frame::Ping => (),

                _ => (),
            }
        }

        // Track when the largest packet was received, for the ACK delay.
        if space.recv_pkt_need_ack.largest() < Some(pn) {
            space.largest_rx_pkt_time = now;
        }

        space.recv_pkt_num.insert(pn);

        space.recv_pkt_need_ack.push_item(pn);
        space.do_ack = cmp::max(space.do_ack, do_ack);

        space.largest_rx_pkt_num = cmp::max(space.largest_rx_pkt_num, pn);

        self.idle_timer =
            Some(now + time::Duration::from_secs(
                        self.local_transport_params.idle_timeout));

        let read = b.off() + aead.alg().tag_len();

        // A server drops its Initial state once it gets a Handshake packet.
        if self.is_server && hdr.ty == packet::Type::Handshake {
            self.drop_initial_state();
        }

        Ok(read)
    }

    /// Writes at most one QUIC packet into `out`, returning its length.
    ///
    /// Returns `Error::Done` when there is nothing to send, or when the
    /// available room is too small for a packet.
    pub fn send(&mut self, out: &mut [u8]) -> Result<usize> {
        let now = time::Instant::now();

        if out.is_empty() {
            return Err(Error::BufferTooShort);
        }

        if self.draining {
            return Err(Error::Done);
        }

        let is_closing = self.error.is_some() || self.app_error.is_some();

        if !is_closing {
            self.do_handshake()?;
        }

        let max_pkt_len = self.peer_transport_params.max_packet_size as usize;

        // Never exceed the peer's maximum packet size.
        let avail = cmp::min(max_pkt_len, out.len());

        let mut b = octets::Octets::with_slice(&mut out[..avail]);

        let pkt_type = self.select_egress_pkt_type()?;

        let space = match pkt_type {
            packet::Type::Initial => &mut self.initial,

            packet::Type::Handshake => &mut self.handshake,

            packet::Type::Application => &mut self.application,

            _ => unreachable!(),
        };

        // Re-queue data from frames that were declared lost.
        for lost in space.flight.lost.drain(..) {
            match lost {
                frame::Frame::Crypto { data } => {
                    space.crypto_stream.send_push_front(data)?;
                },

                frame::Frame::Stream { stream_id, data } => {
                    let stream = match self.streams.get_mut(&stream_id) {
                        Some(v) => v,
                        None => continue,
                    };

                    self.tx_data -= data.len();

                    stream.send_push_front(data)?;
                },

                frame::Frame::ACK { .. } => {
                    space.do_ack = true;
                },

                _ => (),
            }
        }

        self.lost_count += space.flight.lost_count;
        space.flight.lost_count = 0;

        // The congestion window also bounds the packet size.
        let mut left = cmp::min(self.recovery.cwnd(), b.cap());

        let pn = space.next_pkt_num;
        let pn_len = packet::pkt_num_len(pn)?;

        let hdr = Header {
            ty: pkt_type,
            version: self.version,
            dcid: self.dcid.clone(),
            scid: self.scid.clone(),
            pkt_num: 0,
            pkt_num_len: pn_len,
            odcid: None,
            token: self.token.clone(),
            versions: None,
            key_phase: false,
        };

        hdr.to_bytes(&mut b)?;

        // Room taken by the header, the packet number and the AEAD overhead.
        // The extra 4 bytes presumably reserve room for the length field
        // (written later) — NOTE(review): confirm.
        let overhead = b.off() + 4 + pn_len + space.overhead();

        if left < overhead {
            return Err(Error::Done);
        }

        left -= overhead;

        let mut frames: Vec<frame::Frame> = Vec::new();

        let mut ack_eliciting = false;
        let mut is_crypto = false;

        let mut payload_len = 0;

        // ACK frame.
        if space.do_ack {
            let ack_delay = space.largest_rx_pkt_time.elapsed();

            // ACK delay in microseconds, scaled down by our exponent.
            let ack_delay = ack_delay.as_secs() * 1_000_000 +
                            u64::from(ack_delay.subsec_micros());

            let ack_delay = ack_delay
                .checked_shr(cmp::min(
                    self.local_transport_params.ack_delay_exponent, 64) as u32)
                .unwrap_or(0);

            let frame = frame::Frame::ACK {
                ack_delay,
                ranges: space.recv_pkt_need_ack.clone(),
            };

            if frame.wire_len() <= left {
                space.do_ack = false;

                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);
            }
        }

        // MAX_DATA frame, once the peer has used up more than half of the
        // new window.
        if pkt_type == packet::Type::Application && !is_closing
            && (self.new_max_rx_data != self.max_rx_data &&
                self.new_max_rx_data / 2 > self.max_rx_data - self.rx_data)
        {
            let frame = frame::Frame::MaxData {
                max: self.new_max_rx_data as u64,
            };

            if frame.wire_len() <= left {
                self.max_rx_data = self.new_max_rx_data;

                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);

                ack_eliciting = true;
            }
        }

        // MAX_STREAM_DATA frames.
        if pkt_type == packet::Type::Application && !is_closing {
            for (id, stream) in self.streams.iter_mut()
                                            .filter(|(_, s)| s.more_credit()) {
                // NOTE(review): `recv_update_max_data()` runs before the
                // size check below, so a frame that doesn't fit is dropped
                // after the stream's limit was already updated.
                let frame = frame::Frame::MaxStreamData {
                    stream_id: *id,
                    max: stream.recv_update_max_data() as u64,
                };

                if frame.wire_len() > left {
                    break;
                }

                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);

                ack_eliciting = true;
            }
        }

        // PING frame, when recovery asks for probes.
        if self.recovery.probes > 0 && left >= 1 {
            let frame = frame::Frame::Ping;

            payload_len += frame.wire_len();
            left -= frame.wire_len();

            frames.push(frame);

            self.recovery.probes -= 1;

            ack_eliciting = true;
        }

        // CONNECTION_CLOSE frame. Draining starts only once the frame is
        // actually written; if it doesn't fit, the next call retries.
        if let Some(err) = self.error {
            let frame = frame::Frame::ConnectionClose {
                error_code: err,
                frame_type: 0,
                reason: Vec::new(),
            };

            if frame.wire_len() <= left {
                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);

                self.draining = true;
                self.draining_timer = Some(now + DRAINING_TIMEOUT);
            }
        }

        // APPLICATION_CLOSE frame, with the same retry behaviour.
        if let Some(err) = self.app_error {
            let frame = frame::Frame::ApplicationClose {
                error_code: err,
                reason: self.app_reason.clone(),
            };

            if frame.wire_len() <= left {
                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);

                self.draining = true;
                self.draining_timer = Some(now + DRAINING_TIMEOUT);
            }
        }

        // PATH_RESPONSE frame. The challenge is kept pending until the
        // response fits.
        if let Some(ref challenge) = self.challenge {
            let frame = frame::Frame::PathResponse {
                data: challenge.clone(),
            };

            if frame.wire_len() <= left {
                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);

                self.challenge = None;
            }
        }

        // CRYPTO frame, only if there is room beyond the frame overhead.
        if space.crypto_stream.writable() && !is_closing &&
           left > frame::MAX_CRYPTO_OVERHEAD {
            let crypto_len = left - frame::MAX_CRYPTO_OVERHEAD;
            let crypto_buf = space.crypto_stream.send_pop(crypto_len)?;

            let frame = frame::Frame::Crypto {
                data: crypto_buf,
            };

            payload_len += frame.wire_len();
            left -= frame.wire_len();

            frames.push(frame);

            ack_eliciting = true;
            is_crypto = true;
        }

        // STREAM frame, from the first writable stream with data, bounded by
        // the connection-level send window.
        if pkt_type == packet::Type::Application && !is_closing
            && self.max_tx_data > self.tx_data
            && left > frame::MAX_STREAM_OVERHEAD
        {
            for (id, stream) in self.streams.iter_mut()
                                            .filter(|(_, s)| s.writable()) {
                let stream_len = cmp::min(left - frame::MAX_STREAM_OVERHEAD,
                                          self.max_tx_data - self.tx_data);

                let stream_buf = stream.send_pop(stream_len)?;

                if stream_buf.is_empty() {
                    continue;
                }

                self.tx_data += stream_buf.len();

                let frame = frame::Frame::Stream {
                    stream_id: *id,
                    data: stream_buf,
                };

                payload_len += frame.wire_len();
                left -= frame.wire_len();

                frames.push(frame);

                ack_eliciting = true;
                break;
            }
        }

        if frames.is_empty() {
            return Err(Error::Done);
        }

        // Pad client Initial packets up to the minimum size. `saturating_sub`
        // avoids underflow when the packet is already large enough.
        if !self.is_server && pkt_type == packet::Type::Initial {
            let pkt_len = pn_len + payload_len + space.overhead();

            let frame = frame::Frame::Padding {
                len: cmp::min(CLIENT_INITIAL_MIN_LEN.saturating_sub(pkt_len),
                              left),
            };

            payload_len += frame.wire_len();

            frames.push(frame);
        }

        // Ensure a minimum payload size.
        if payload_len < PAYLOAD_MIN_LEN {
            let frame = frame::Frame::Padding {
                len: PAYLOAD_MIN_LEN - payload_len,
            };

            payload_len += frame.wire_len();

            frames.push(frame);
        }

        payload_len += space.overhead();

        // Only long header packets carry a length field.
        if pkt_type != packet::Type::Application {
            let len = pn_len + payload_len;
            b.put_varint(len as u64)?;
        }

        packet::encode_pkt_num(pn, &mut b)?;

        let payload_offset = b.off();

        trace!("{} tx pkt {:?} len={} pn={}", self.trace_id, hdr,
               payload_len, pn);

        for frame in &frames {
            trace!("{} tx frm {:?}", self.trace_id, frame);

            frame.to_bytes(&mut b)?;
        }

        let aead = match space.crypto_seal {
            Some(ref v) => v,
            None => return Err(Error::InvalidState),
        };

        let written = packet::encrypt_pkt(&mut b, pn, pn_len, payload_len,
                                          payload_offset, aead)?;

        let sent_pkt = recovery::Sent::new(pn, frames, written, ack_eliciting,
                                           is_crypto, now);

        self.recovery.on_packet_sent(sent_pkt, &mut space.flight, now,
                                     &self.trace_id);

        space.next_pkt_num += 1;

        self.sent_count += 1;

        // A client drops its Initial state once it sends a Handshake packet.
        if !self.is_server && hdr.ty == packet::Type::Handshake {
            self.drop_initial_state();
        }

        Ok(written)
    }

    /// Reads up to `max_len` bytes of contiguous data from stream
    /// `stream_id`.
    ///
    /// Fails with `Error::InvalidStreamState` for an unknown stream, and with
    /// `Error::Done` if nothing is readable. The bytes read are added to the
    /// connection-level receive limit to be advertised next.
    pub fn stream_recv(&mut self, stream_id: u64, max_len: usize) -> Result<RangeBuf> {
        let stream = match self.streams.get_mut(&stream_id) {
            Some(v) => v,
            None => return Err(Error::InvalidStreamState),
        };

        if !stream.readable() {
            return Err(Error::Done);
        }

        let buf = stream.recv_pop(max_len)?;

        // Accumulate, so credit from several reads made before the next
        // MAX_DATA frame is not lost.
        self.new_max_rx_data += buf.len();

        Ok(buf)
    }

    /// Queues `buf` for sending on stream `stream_id`, marking the end of
    /// the stream if `fin` is set, and returns `buf.len()`.
    ///
    /// A locally initiated stream is created on first use. This fails with
    /// `Error::StreamLimit` if it exceeds the stream count granted by the
    /// peer. Sending on a peer-initiated unidirectional stream, or creating a
    /// peer-initiated stream, fails with `Error::InvalidStreamState`.
    pub fn stream_send(&mut self, stream_id: u64, buf: &[u8], fin: bool)
                                                            -> Result<usize> {
        if !stream::is_bidi(stream_id) &&
           !stream::is_local(stream_id, self.is_server) {
            return Err(Error::InvalidStreamState);
        }

        let stream = match self.streams.entry(stream_id) {
            hash_map::Entry::Vacant(v) => {
                if !stream::is_local(stream_id, self.is_server) {
                    return Err(Error::InvalidStreamState);
                }

                // Unidirectional streams use the peer's uni stream limits.
                let (max_streams, max_tx_data) = if stream::is_bidi(stream_id) {
                    (self.peer_max_streams_bidi,
                     self.peer_transport_params
                         .initial_max_stream_data_bidi_remote)
                } else {
                    (self.peer_max_streams_uni,
                     self.peer_transport_params.initial_max_stream_data_uni)
                };

                Connection::check_stream_limit(stream_id, max_streams)?;

                let max_rx_data = self.local_transport_params
                                      .initial_max_stream_data_bidi_local as usize;

                let s = stream::Stream::new(max_rx_data, max_tx_data as usize);
                v.insert(s)
            },

            hash_map::Entry::Occupied(v) => v.into_mut(),
        };

        stream.send_push(buf, fin)?;

        Ok(buf.len())
    }

    /// Returns a `Readable` built from the connection's streams (presumably
    /// iterating over the streams with data to read — confirm in `stream`).
    pub fn readable(&mut self) -> Readable {
        stream::Readable::new(&self.streams)
    }

    /// Returns how long until `on_timeout` should be called, or `None` if no
    /// timer is armed or the connection is closed.
    ///
    /// While draining, only the draining timer counts. Otherwise the loss
    /// detection timer takes precedence over the idle timer.
    pub fn timeout(&self) -> Option<std::time::Duration> {
        if self.closed {
            return None;
        }

        let timeout = if self.draining {
            self.draining_timer
        } else {
            self.recovery.loss_detection_timer().or(self.idle_timer)
        };

        let timeout = timeout?;

        let now = time::Instant::now();

        if timeout <= now {
            return Some(std::time::Duration::new(0, 0));
        }

        Some(timeout.duration_since(now))
    }

    /// Handles expired timers: ends the draining period, closes an idle
    /// connection, or runs loss detection.
    pub fn on_timeout(&mut self) {
        let now = time::Instant::now();

        if self.draining {
            if let Some(timer) = self.draining_timer {
                if timer <= now {
                    trace!("{} draining timeout expired", self.trace_id);

                    self.closed = true;
                }
            }

            return;
        }

        if let Some(timer) = self.idle_timer {
            if timer <= now {
                trace!("{} idle timeout expired", self.trace_id);

                self.closed = true;
                return;
            }
        }

        if let Some(timer) = self.recovery.loss_detection_timer() {
            if timer <= now {
                trace!("{} loss detection timeout expired", self.trace_id);

                self.recovery.on_loss_detection_timer(&mut self.initial.flight,
                                                      &mut self.handshake.flight,
                                                      &mut self.application.flight,
                                                      now, &self.trace_id);
            }
        }
    }

    /// Starts closing the connection with error code `err`.
    ///
    /// If `app` is true, `err` and `reason` are sent in an APPLICATION_CLOSE
    /// frame; otherwise `err` is sent in a CONNECTION_CLOSE frame. Fails with
    /// `Error::Done` if the connection is already draining or closing.
    pub fn close(&mut self, app: bool, err: u16, reason: &[u8]) -> Result<()> {
        if self.draining {
            return Err(Error::Done);
        }

        if self.error.is_some() || self.app_error.is_some() {
            return Err(Error::Done);
        }

        if app {
            self.app_error = Some(err);
            self.app_reason.extend_from_slice(reason);
        } else {
            self.error = Some(err);
        }

        Ok(())
    }

    /// Returns the hex-encoded local connection ID used in trace logs.
    pub fn trace_id(&self) -> &str {
        &self.trace_id
    }

    /// Returns the ALPN protocol negotiated by TLS.
    pub fn application_proto(&self) -> &[u8] {
        self.tls_state.get_alpn_protocol()
    }

    /// Returns whether the handshake has completed.
    pub fn is_established(&self) -> bool {
        self.handshake_completed
    }

    /// Returns whether TLS reports the session as resumed.
    pub fn is_resumed(&self) -> bool {
        self.tls_state.is_resumed()
    }

    /// Returns whether the connection is closed.
    pub fn is_closed(&self) -> bool {
        self.closed
    }

    /// Returns the connection's statistics.
    pub fn stats(&self) -> Stats {
        Stats {
            sent: self.sent_count,
            lost: self.lost_count,
            rtt: self.recovery.rtt(),
        }
    }

    /// Advances the TLS handshake and, once it completes, applies the peer's
    /// transport parameters.
    ///
    /// A TLS failure is reported as `Error::TlsFail` unless a close error is
    /// already pending. Other TLS errors are ignored (presumably "would
    /// block" — confirm in `tls`).
    fn do_handshake(&mut self) -> Result<()> {
        if !self.handshake_completed {
            match self.tls_state.do_handshake() {
                Ok(_) => {
                    self.handshake_completed = true;

                    let mut raw_params =
                        self.tls_state.get_quic_transport_params().to_vec();

                    let peer_params = TransportParams::decode(&mut raw_params,
                                                              self.version,
                                                              self.is_server)?;

                    if peer_params.original_connection_id != self.odcid {
                        return Err(Error::InvalidTransportParam);
                    }

                    self.max_tx_data = peer_params.initial_max_data as usize;

                    self.peer_max_streams_bidi =
                        peer_params.initial_max_streams_bidi as usize;
                    self.peer_max_streams_uni =
                        peer_params.initial_max_streams_uni as usize;

                    self.recovery.max_ack_delay =
                        time::Duration::from_millis(peer_params.max_ack_delay);

                    self.peer_transport_params = peer_params;

                    trace!("{} connection established: cipher={:?} proto={:?} resumed={}",
                           &self.trace_id,
                           self.tls_state.cipher(),
                           std::str::from_utf8(self.application_proto()),
                           self.is_resumed());
                },

                Err(tls::Error::TlsFail) => {
                    // Don't mask the error of a connection that is already
                    // closing.
                    if self.error.is_none() {
                        return Err(Error::TlsFail);
                    }
                },

                Err(tls::Error::SyscallFail) => return Err(Error::TlsFail),

                Err(_) => (),
            }
        }

        Ok(())
    }

    /// Chooses the packet type for the next outgoing packet.
    ///
    /// When closing or probing, the current TLS write level decides.
    /// Otherwise the first space with pending data wins, in the order
    /// Initial, Handshake, Application (the last only after the handshake).
    /// Returns `Error::Done` when nothing needs sending.
    fn select_egress_pkt_type(&self) -> Result<Type> {
        let ty =
            if self.error.is_some() || self.recovery.probes > 0 {
                match self.tls_state.get_write_level() {
                    crypto::Level::Initial     => Type::Initial,
                    crypto::Level::ZeroRTT     => unreachable!(),
                    crypto::Level::Handshake   => Type::Handshake,
                    crypto::Level::Application => Type::Application,
                }
            } else if self.initial.ready() {
                Type::Initial
            } else if self.handshake.ready() {
                Type::Handshake
            } else if self.handshake_completed &&
                      (self.application.ready() ||
                       self.streams.values().any(|s| s.writable()) ||
                       self.streams.values().any(|s| s.more_credit())) {
                Type::Application
            } else {
                return Err(Error::Done);
            };

        Ok(ty)
    }

    /// Discards the Initial keys and in-flight Initial data (only the first
    /// call has an effect).
    fn drop_initial_state(&mut self) {
        if self.initial.crypto_open.is_none() {
            return;
        }

        self.recovery.drop_unacked_data(&mut self.initial.flight);

        self.initial.crypto_open = None;
        self.initial.crypto_seal = None;

        self.initial.clear();

        trace!("{} dropped initial state", self.trace_id);
    }

    /// Returns the peer-initiated stream `stream_id`, creating it first if it
    /// doesn't exist yet.
    ///
    /// Creating a locally initiated stream this way fails with
    /// `Error::InvalidStreamState`. Creating a stream beyond the limit we
    /// advertised fails with `Error::StreamLimit`. Unidirectional streams get
    /// the unidirectional receive window.
    ///
    /// Takes individual fields instead of `&mut self`, so callers can use it
    /// while a packet number space is borrowed.
    fn get_or_create_peer_stream<'a>(
        streams: &'a mut HashMap<u64, stream::Stream>, stream_id: u64,
        is_server: bool, local_params: &TransportParams,
        peer_params: &TransportParams, max_streams_bidi: usize,
        max_streams_uni: usize,
    ) -> Result<&'a mut stream::Stream> {
        match streams.entry(stream_id) {
            hash_map::Entry::Occupied(v) => Ok(v.into_mut()),

            hash_map::Entry::Vacant(v) => {
                if stream::is_local(stream_id, is_server) {
                    return Err(Error::InvalidStreamState);
                }

                let (max_streams, max_rx_data) = if stream::is_bidi(stream_id) {
                    (max_streams_bidi,
                     local_params.initial_max_stream_data_bidi_remote)
                } else {
                    (max_streams_uni, local_params.initial_max_stream_data_uni)
                };

                Connection::check_stream_limit(stream_id, max_streams)?;

                let max_tx_data =
                    peer_params.initial_max_stream_data_bidi_local as usize;

                let s = stream::Stream::new(max_rx_data as usize, max_tx_data);
                Ok(v.insert(s))
            },
        }
    }

    /// Fails with `Error::StreamLimit` if `stream_id` lies beyond a
    /// cumulative limit of `max_streams` streams of its type.
    fn check_stream_limit(stream_id: u64, max_streams: usize) -> Result<()> {
        // The two low bits of a stream ID encode initiator and direction; the
        // rest is the stream's 0-based index among streams of its type.
        if (stream_id >> 2) >= max_streams as u64 {
            return Err(Error::StreamLimit);
        }

        Ok(())
    }

    /// Converts an encoded ACK delay to microseconds by multiplying it by
    /// `2^exponent`, saturating instead of overflowing on oversized values.
    fn scale_ack_delay(ack_delay: u64, exponent: u64) -> u64 {
        if ack_delay == 0 {
            return 0;
        }

        if exponent >= 64 {
            return std::u64::MAX;
        }

        ack_delay.checked_mul(1_u64 << exponent).unwrap_or(std::u64::MAX)
    }
}
/// Connection statistics, as returned by `Connection::stats()`.
#[derive(Clone)]
pub struct Stats {
    /// Number of packets sent.
    pub sent: usize,

    /// Lost count accumulated from loss recovery.
    pub lost: usize,

    /// Round-trip time estimate reported by loss recovery.
    pub rtt: time::Duration,
}

impl std::fmt::Debug for Stats {
    /// Formats as `sent=N lost=N rtt=DURATION`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Stats { sent, lost, rtt } = self;

        write!(f, "sent={} lost={} rtt={:?}", sent, lost, rtt)
    }
}
/// QUIC transport parameters, as exchanged in the TLS handshake.
///
/// The wire IDs below are the ones used by `TransportParams::encode` and
/// `TransportParams::decode`.
#[derive(Clone, Debug, PartialEq)]
struct TransportParams {
    /// `0x0000`: original destination connection ID (server only).
    pub original_connection_id: Option<Vec<u8>>,
    /// `0x0001`: idle timeout (used locally as seconds).
    pub idle_timeout: u64,
    /// `0x0002`: stateless reset token (server only; 16 bytes when decoded).
    pub stateless_reset_token: Option<Vec<u8>>,
    /// `0x0003`: maximum packet size.
    pub max_packet_size: u64,
    /// `0x0004`: initial connection-level flow-control limit.
    pub initial_max_data: u64,
    /// `0x0005`: initial limit for locally initiated bidirectional streams.
    pub initial_max_stream_data_bidi_local: u64,
    /// `0x0006`: initial limit for remotely initiated bidirectional streams.
    pub initial_max_stream_data_bidi_remote: u64,
    /// `0x0007`: initial limit for unidirectional streams.
    pub initial_max_stream_data_uni: u64,
    /// `0x0008`: initial number of bidirectional streams allowed.
    pub initial_max_streams_bidi: u64,
    /// `0x0009`: initial number of unidirectional streams allowed.
    pub initial_max_streams_uni: u64,
    /// `0x000a`: exponent used to scale ACK delays.
    pub ack_delay_exponent: u64,
    /// `0x000b`: maximum ACK delay (interpreted as milliseconds for the peer).
    pub max_ack_delay: u64,
    /// `0x000c`: whether connection migration is disabled (no value on the
    /// wire).
    pub disable_migration: bool,
}
impl Default for TransportParams {
    /// Returns the parameters in effect when nothing else is configured.
    ///
    /// All flow-control and stream limits are zero, no IDs or tokens are
    /// set, the maximum packet size is 65527, the ACK delay exponent is 3 and
    /// the maximum ACK delay is 25.
    fn default() -> Self {
        TransportParams {
            // Limits: nothing is granted until configured.
            initial_max_data: 0,
            initial_max_stream_data_bidi_local: 0,
            initial_max_stream_data_bidi_remote: 0,
            initial_max_stream_data_uni: 0,
            initial_max_streams_bidi: 0,
            initial_max_streams_uni: 0,

            // Optional values.
            original_connection_id: None,
            stateless_reset_token: None,
            idle_timeout: 0,
            disable_migration: false,

            // Non-zero defaults.
            max_packet_size: 65_527,
            ack_delay_exponent: 3,
            max_ack_delay: 25,
        }
    }
}
impl TransportParams {
    /// Parses transport parameters from the raw TLS extension payload in
    /// `buf`.
    ///
    /// `is_server` is true when a server decodes a client's parameters. In
    /// that case the server-only parameters (`0x0000`, `0x0002`, `0x000d`)
    /// are rejected with `Error::InvalidTransportParam`. Unknown parameter
    /// IDs are ignored.
    fn decode(buf: &mut [u8], _version: u32, is_server: bool)
                                                -> Result<TransportParams> {
        let mut b = octets::Octets::with_slice(buf);

        // The leading version is read but not validated.
        let _tp_version = b.get_u32()?;

        // Parameters sent by a server also carry a u8-length-prefixed version
        // list (see `encode`), which is skipped.
        if !is_server {
            b.get_bytes_with_u8_length()?;
        }

        let mut tp = TransportParams::default();

        let mut params = b.get_bytes_with_u16_length()?;

        while params.cap() > 0 {
            let id = params.get_u16()?;

            let mut val = params.get_bytes_with_u16_length()?;

            match id {
                0x0000 => {
                    if is_server {
                        return Err(Error::InvalidTransportParam);
                    }

                    tp.original_connection_id = Some(val.to_vec());
                },

                0x0001 => {
                    tp.idle_timeout = val.get_varint()?;
                },

                0x0002 => {
                    if is_server {
                        return Err(Error::InvalidTransportParam);
                    }

                    tp.stateless_reset_token = Some(val.get_bytes(16)?.to_vec());
                },

                0x0003 => {
                    tp.max_packet_size = val.get_varint()?;
                },

                0x0004 => {
                    tp.initial_max_data = val.get_varint()?;
                },

                0x0005 => {
                    tp.initial_max_stream_data_bidi_local = val.get_varint()?;
                },

                0x0006 => {
                    tp.initial_max_stream_data_bidi_remote = val.get_varint()?;
                },

                0x0007 => {
                    tp.initial_max_stream_data_uni = val.get_varint()?;
                },

                0x0008 => {
                    tp.initial_max_streams_bidi = val.get_varint()?;
                },

                0x0009 => {
                    tp.initial_max_streams_uni = val.get_varint()?;
                },

                0x000a => {
                    tp.ack_delay_exponent = val.get_varint()?;
                },

                0x000b => {
                    tp.max_ack_delay = val.get_varint()?;
                },

                0x000c => {
                    tp.disable_migration = true;
                },

                0x000d => {
                    if is_server {
                        return Err(Error::InvalidTransportParam);
                    }

                    // The value itself is not parsed.
                },

                // Unknown parameters are ignored.
                _ => (),
            }
        }

        Ok(tp)
    }

    /// Serializes `tp` into `out` and returns the written prefix of `out`.
    ///
    /// Server-only parameters (`original_connection_id`,
    /// `stateless_reset_token`) and the supported-versions list are only
    /// written when `is_server` is true. Zero-valued numeric parameters are
    /// omitted.
    fn encode<'a>(tp: &TransportParams, version: u32, is_server: bool,
                  out: &'a mut [u8]) -> Result<&'a mut [u8]> {
        // Scratch space for the parameter list, large enough for every
        // parameter written below even at maximum varint sizes.
        const MAX_PARAMS_LEN: usize = 256;

        let mut params: [u8; MAX_PARAMS_LEN] = [0; MAX_PARAMS_LEN];

        let params_len = {
            let mut b = octets::Octets::with_slice(&mut params);

            if is_server {
                if let Some(ref odcid) = tp.original_connection_id {
                    b.put_u16(0x0000)?;
                    b.put_u16(odcid.len() as u16)?;
                    b.put_bytes(&odcid)?;
                }
            }

            if tp.idle_timeout != 0 {
                TransportParams::put_varint_param(&mut b, 0x0001,
                                                  tp.idle_timeout)?;
            }

            if let Some(ref token) = tp.stateless_reset_token {
                if is_server {
                    b.put_u16(0x0002)?;
                    b.put_u16(token.len() as u16)?;
                    b.put_bytes(&token)?;
                }
            }

            if tp.max_packet_size != 0 {
                TransportParams::put_varint_param(&mut b, 0x0003,
                                                  tp.max_packet_size)?;
            }

            if tp.initial_max_data != 0 {
                TransportParams::put_varint_param(&mut b, 0x0004,
                                                  tp.initial_max_data)?;
            }

            if tp.initial_max_stream_data_bidi_local != 0 {
                TransportParams::put_varint_param(
                    &mut b, 0x0005, tp.initial_max_stream_data_bidi_local)?;
            }

            if tp.initial_max_stream_data_bidi_remote != 0 {
                TransportParams::put_varint_param(
                    &mut b, 0x0006, tp.initial_max_stream_data_bidi_remote)?;
            }

            if tp.initial_max_stream_data_uni != 0 {
                TransportParams::put_varint_param(
                    &mut b, 0x0007, tp.initial_max_stream_data_uni)?;
            }

            if tp.initial_max_streams_bidi != 0 {
                TransportParams::put_varint_param(&mut b, 0x0008,
                                                  tp.initial_max_streams_bidi)?;
            }

            if tp.initial_max_streams_uni != 0 {
                TransportParams::put_varint_param(&mut b, 0x0009,
                                                  tp.initial_max_streams_uni)?;
            }

            if tp.ack_delay_exponent != 0 {
                TransportParams::put_varint_param(&mut b, 0x000a,
                                                  tp.ack_delay_exponent)?;
            }

            if tp.max_ack_delay != 0 {
                TransportParams::put_varint_param(&mut b, 0x000b,
                                                  tp.max_ack_delay)?;
            }

            // A flag parameter with an empty value.
            if tp.disable_migration {
                b.put_u16(0x000c)?;
                b.put_u16(0)?;
            }

            b.off()
        };

        let out_len = {
            let mut b = octets::Octets::with_slice(out);

            b.put_u32(version)?;

            // Servers also send a one-entry supported-versions list.
            if is_server {
                b.put_u8(mem::size_of::<u32>() as u8)?;
                b.put_u32(version)?;
            }

            b.put_u16(params_len as u16)?;
            b.put_bytes(&params[..params_len])?;

            b.off()
        };

        Ok(&mut out[..out_len])
    }

    /// Writes one parameter with ID `id` whose value is the varint `val`,
    /// preceded by its u16 length.
    fn put_varint_param(b: &mut octets::Octets, id: u16, val: u64)
                                                            -> Result<()> {
        b.put_u16(id)?;
        b.put_u16(octets::varint_len(val) as u16)?;
        b.put_varint(val)?;

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Transport parameters encoded by a server must decode back
    /// unchanged on the client.
    #[test]
    fn transport_params() {
        let tp = TransportParams {
            original_connection_id: None,
            idle_timeout: 30,
            stateless_reset_token: Some(vec![0xba; 16]),
            max_packet_size: 23_421,
            initial_max_data: 424_645_563,
            initial_max_stream_data_bidi_local: 154_323_123,
            initial_max_stream_data_bidi_remote: 6_587_456,
            initial_max_stream_data_uni: 2_461_234,
            initial_max_streams_bidi: 12_231,
            initial_max_streams_uni: 18_473,
            ack_delay_exponent: 123,
            max_ack_delay: 1234,
            disable_migration: true,
        };

        let mut raw_params: [u8; 256] = [42; 256];
        let mut raw_params = TransportParams::encode(&tp, VERSION_DRAFT17, true,
                                                     &mut raw_params).unwrap();
        assert_eq!(raw_params.len(), 106);

        let new_tp = TransportParams::decode(&mut raw_params, VERSION_DRAFT17,
                                             false).unwrap();

        assert_eq!(new_tp, tp);
    }

    /// Creates a connection with a random source connection ID, using the
    /// example certificate and key, with peer verification disabled.
    fn create_conn(is_server: bool) -> Box<Connection> {
        let mut scid: [u8; 16] = [0; 16];
        rand::rand_bytes(&mut scid[..]);

        let mut config = Config::new(VERSION_DRAFT17).unwrap();
        config.load_cert_chain_from_pem_file("examples/cert.crt").unwrap();
        config.load_priv_key_from_pem_file("examples/cert.key").unwrap();
        config.verify_peer(false);

        Connection::new(&scid, None, &mut config, is_server).unwrap()
    }

    /// Feeds the first `len` bytes of `buf` to `conn`, then writes as many
    /// packets as `conn` produces back into `buf`, returning the number of
    /// bytes written.
    fn recv_send(conn: &mut Connection, buf: &mut [u8], len: usize) -> usize {
        let mut left = len;

        while left > 0 {
            let read = conn.recv(&mut buf[len - left..len]).unwrap();

            left -= read;
        }

        let mut off = 0;

        while off < buf.len() {
            let write = match conn.send(&mut buf[off..]) {
                Ok(v) => v,

                Err(Error::Done) => { break; },

                Err(e) => panic!("SEND FAILED: {:?}", e),
            };

            off += write;
        }

        off
    }

    /// A client and a server exchanging flights through one buffer must
    /// both complete the handshake.
    #[test]
    fn self_handshake() {
        let mut buf = [0; 65535];

        let mut cln = create_conn(false);
        let mut srv = create_conn(true);

        let mut len = cln.send(&mut buf).unwrap();

        // Continue until *both* sides are established, with a round cap so a
        // stalled handshake fails the test instead of hanging it.
        let mut rounds = 0;

        while !cln.is_established() || !srv.is_established() {
            assert!(rounds < 10, "handshake did not complete");
            rounds += 1;

            len = recv_send(&mut srv, &mut buf, len);
            len = recv_send(&mut cln, &mut buf, len);
        }

        assert!(cln.is_established());
        assert!(srv.is_established());
    }
}
pub use crate::stream::RangeBuf;
pub use crate::stream::Readable;
pub use crate::packet::Header;
pub use crate::packet::Type;
mod crypto;
mod ffi;
mod frame;
mod octets;
mod packet;
mod rand;
mod ranges;
mod recovery;
mod stream;
mod tls;