use std::{borrow::Cow, io::{Cursor, Write}, marker::PhantomData, mem::offset_of};
use byteorder::{BigEndian, WriteBytesExt};
use crc32fast::Hasher;
use crate::{common, error::{HaProxErr, HaProxRes}, return_error};
use super::{protocol::{self, HdrV2Command, PP2TlvClient, PP2Tlvs, ProxyTransportFam, ProxyV2Addr}, protocol_raw, PP2TlvDump, PP2TlvUniqId, ProxyV2OpCode};
/// Type-level marker selecting the LOCAL command for a `ProxyHdrV2` header.
#[derive(Debug, Clone)]
pub struct HdrV2OpLocal;

/// Type-level marker selecting the PROXY command for a `ProxyHdrV2` header.
#[derive(Debug, Clone)]
pub struct HdrV2OpProxy;
/// LOCAL-command opcode, taken from `HdrV2Command::LOCAL`.
impl ProxyV2OpCode for HdrV2OpLocal
{
const OPCODE: u8 = HdrV2Command::LOCAL as u8;
}
/// PROXY-command opcode, taken from `HdrV2Command::PROXY`; `ProxyHdrV2::new` ORs it
/// into the version/command byte after the signature.
impl ProxyV2OpCode for HdrV2OpProxy
{
const OPCODE: u8 = HdrV2Command::PROXY as u8;
}
/// Builder for the sub-TLVs nested inside a `PP2Tlvs::TypeSsl` TLV.
/// Call `done` to patch the SSL TLV length and get the top-level builder back.
#[derive(Debug)]
pub struct TlvSubTypeSsl<'s>
{
// Buffer offset of the SSL TLV's type byte; used to locate its length field.
start: u64,
// Top-level TLV builder the SSL TLV and its sub-TLVs are written through.
hdr: TlType<'s>,
// Type ranges accepted for sub-TLVs added through this builder.
constraints: &'static [std::ops::RangeInclusive<u8>]
}
impl<'s> TlvSubTypeSsl<'s>
{
fn new(mut main_tp: TlType<'s>, constraints: &'static [std::ops::RangeInclusive<u8>],
client: PP2TlvClient, verify: u32) -> HaProxRes<Self>
{
let start = main_tp.hdr.buffer.position();
main_tp.add_tlv(PP2Tlvs::TypeSsl{client, verify}, None)?;
return Ok(
Self
{
start: start,
hdr: main_tp,
constraints: constraints
}
);
}
#[inline]
fn done_int(&mut self) -> HaProxRes<()>
{
let cur_pos = self.hdr.hdr.buffer.position();
self.hdr.hdr.buffer.set_position(self.start + offset_of!(protocol_raw::PP2Tlv, length_hi) as u64);
self.hdr.hdr.buffer.write_u16::<BigEndian>((cur_pos - self.start - 3) as u16).map_err(common::map_io_err)?;
self.hdr.hdr.buffer.set_position(cur_pos);
return Ok(());
}
#[inline]
pub
fn done(mut self) -> HaProxRes<TlType<'s>>
{
self.done_int()?;
return Ok(self.hdr);
}
pub
fn add_ssl_sub_version(&mut self, ver: impl Into<String>) -> HaProxRes<()>
{
return self.hdr.add_tlv(PP2Tlvs::TypeSubtypeSslVersion(Cow::Owned(ver.into())), Some(self.constraints));
}
pub
fn add_ssl_sub_cn(&mut self, cn: impl Into<String>) -> HaProxRes<()>
{
return self.hdr.add_tlv(PP2Tlvs::TypeSubtypeSslCn(Cow::Owned(cn.into())), Some(self.constraints));
}
pub
fn add_ssl_sub_cipher(&mut self, ver: impl Into<String>) -> HaProxRes<()>
{
return self.hdr.add_tlv(PP2Tlvs::TypeSubtypeSslCipher(Cow::Owned(ver.into())), Some(self.constraints));
}
pub
fn add_ssl_sub_sigalg(&mut self, ver: impl Into<String>) -> HaProxRes<()>
{
return self.hdr.add_tlv(PP2Tlvs::TypeSubtypeSslSigAlg(Cow::Owned(ver.into())), Some(self.constraints));
}
pub
fn add_ssl_sub_keyalg(&mut self, ver: impl Into<String>) -> HaProxRes<()>
{
return self.hdr.add_tlv(PP2Tlvs::TypeSubtypeSslKeyAlg(Cow::Owned(ver.into())), Some(self.constraints));
}
pub
fn add_ssl_sub_netns(&mut self, ver: impl Into<String>) -> HaProxRes<()>
{
return self.hdr.add_tlv(PP2Tlvs::TypeNetNs(ver.into()), Some(self.constraints));
}
}
/// Builder that appends top-level TLVs to a PROXY-command header.
#[derive(Debug)]
pub struct TlType<'s>
{
// Header being extended; mutably borrowed for the builder's lifetime.
hdr: &'s mut ProxyHdrV2<HdrV2OpProxy>,
// Type ranges accepted when `add_tlv` is called without explicit constraints.
constraints: &'static [std::ops::RangeInclusive<u8>]
}
impl<'s> TlType<'s>
{
    /// Wraps `hdr` for TLV appending.
    ///
    /// `constraints` are the type ranges accepted when `add_tlv` is called with `None`.
    fn new(hdr: &'s mut ProxyHdrV2<HdrV2OpProxy>, constraints: &'static [std::ops::RangeInclusive<u8>]) -> Self
    {
        return
            Self
            {
                hdr,
                constraints
            };
    }

    /// Appends a `PP2Tlvs::TypeAlpn` TLV built from the given protocol names.
    pub
    fn add_alpn<'a>(&mut self, alpns: impl Iterator<Item = &'a [u8]>) -> HaProxRes<()>
    {
        return self.add_tlv(PP2Tlvs::TypeAlpn( alpns.map(|v| v.to_vec()).collect()), None);
    }

    /// Appends an empty `PP2Tlvs::TypeNoop` TLV.
    pub
    fn add_noop(&mut self) -> HaProxRes<()>
    {
        return self.add_tlv(PP2Tlvs::TypeNoop, None);
    }

    /// Appends a `PP2Tlvs::TypeNetNs` TLV carrying the namespace name.
    pub
    fn add_netns(&mut self, ns: impl Into<String>) -> HaProxRes<()>
    {
        return self.add_tlv(PP2Tlvs::TypeNetNs(ns.into()), None);
    }

    /// Appends a `PP2Tlvs::TypeCrc32c` TLV with a zero placeholder value.
    ///
    /// The checksum itself is filled in when the header is converted into bytes.
    /// Only one CRC TLV may be added per header.
    pub
    fn add_crc32(&mut self) -> HaProxRes<()>
    {
        return self.add_tlv(PP2Tlvs::TypeCrc32c(0), None);
    }

    /// Appends a `PP2Tlvs::TypeUniqId` TLV holding `au.into_bytes()`.
    pub
    fn add_uniq_id<ID: PP2TlvUniqId>(&mut self, au: ID) -> HaProxRes<()>
    {
        let uniq_id = au.into_bytes();
        return self.add_tlv(PP2Tlvs::TypeUniqId(uniq_id), None);
    }

    /// Appends a `PP2Tlvs::TypeAuthority` TLV.
    pub
    fn add_authority(&mut self, authority: impl Into<String>) -> HaProxRes<()>
    {
        return self.add_tlv(PP2Tlvs::TypeAuthority(authority.into()), None);
    }

    /// Opens a `PP2Tlvs::TypeSsl` TLV and returns a builder for its sub-TLVs.
    ///
    /// This consumes the builder; `TlvSubTypeSsl::done` hands it back.
    pub
    fn add_ssl(self, client: PP2TlvClient, verify: u32) -> HaProxRes<TlvSubTypeSsl<'s>>
    {
        return TlvSubTypeSsl::new(self, PP2Tlvs::TLV_TYPE_SSL_SUB_RANGE, client, verify);
    }

    /// Appends one TLV: the type byte, a big-endian 16-bit length, then the `tlv.dump` output.
    ///
    /// `opt_constr` overrides the builder's accepted type ranges for this call. It is used
    /// for SSL sub-TLVs and for caller-defined types.
    ///
    /// # Errors
    /// Returns `ArgumentEinval` when:
    /// - the type is outside the accepted ranges,
    /// - a second CRC32c TLV is added, or
    /// - the payload exceeds 65535 bytes.
    ///
    /// Errors from `tlv.dump` are propagated. On any write failure the partially written
    /// TLV is removed, leaving the header as it was before the call.
    pub
    fn add_tlv<TLV: PP2TlvDump>(&mut self, tlv: TLV,
        opt_constr: Option<&'static [std::ops::RangeInclusive<u8>]>) -> HaProxRes<()>
    {
        let tlv_id: u8 = tlv.get_type();
        let constr = opt_constr.unwrap_or(self.constraints);

        if !constr.iter().any(|idr| idr.contains(&tlv_id))
        {
            return_error!(ArgumentEinval, "TLV: {} is subtype of other type or type!", tlv);
        }

        let is_crc = tlv_id == PP2Tlvs::TypeCrc32c(0).into();

        if is_crc && self.hdr.crc_tlv_offset != 0
        {
            return_error!(ArgumentEinval, "duplicate PLT CRC!");
        }

        // The cursor sits at the end of the written data (every writer restores it there),
        // so truncating back to this offset removes exactly the failed TLV.
        let tlv_start = self.hdr.buffer.position();

        if let Err(err) = Self::write_tlv(&mut self.hdr.buffer, tlv_id, &tlv)
        {
            self.hdr.buffer.get_mut().truncate(tlv_start as usize);
            self.hdr.buffer.set_position(tlv_start);
            return Err(err);
        }

        if is_crc
        {
            // `finalize` locates the CRC value relative to this type-byte offset.
            self.hdr.crc_tlv_offset = tlv_start;
        }

        return Ok(());
    }

    /// Writes type, placeholder length and payload at the cursor, then back-patches the length.
    ///
    /// # Errors
    /// Propagates `tlv.dump` errors. Returns `ArgumentEinval` when the payload does not fit
    /// the 16-bit TLV length field.
    fn write_tlv<TLV: PP2TlvDump>(buffer: &mut Cursor<Vec<u8>>, tlv_id: u8, tlv: &TLV) -> HaProxRes<()>
    {
        buffer.write_u8(tlv_id).map_err(common::map_io_err)?;

        let size_pos = buffer.position();
        buffer.write_u16::<BigEndian>(0).map_err(common::map_io_err)?;

        tlv.dump(buffer)?;

        let cur_pos = buffer.position();
        let payload_len = cur_pos - size_pos - 2;

        if payload_len > u16::MAX as u64
        {
            return_error!(ArgumentEinval, "TLV: 0x{:02X} payload of {} bytes exceeds 65535!", tlv_id, payload_len);
        }

        buffer.set_position(size_pos);
        buffer.write_u16::<BigEndian>(payload_len as u16).map_err(common::map_io_err)?;
        buffer.set_position(cur_pos);

        return Ok(());
    }
}
/// PROXY protocol v2 header under construction; `OPC` selects the command at the type level.
#[derive(Clone, Debug)]
pub struct ProxyHdrV2<OPC: ProxyV2OpCode>
{
// Serialized header bytes; TLVs are appended at the cursor position.
buffer: Cursor<Vec<u8>>,
// Offset of the CRC32c TLV's type byte, or 0 when no CRC TLV has been added.
crc_tlv_offset: u64,
// Carries the opcode marker type without storing a value.
_p: PhantomData<OPC>,
}
impl<OPC: ProxyV2OpCode> ProxyHdrV2<OPC>
{
/// Byte offset of the `len` field inside `protocol_raw::ProxyHdrV2`; `finalize`
/// writes the big-endian address+TLV length there.
pub const HDR_MSG_LEN_OFFSET: u64 = offset_of!(protocol_raw::ProxyHdrV2, len) as u64;
}
impl ProxyHdrV2<HdrV2OpLocal>
{
    /// Returns an owned copy of the predefined LOCAL-command header bytes.
    ///
    /// No builder state is created: a LOCAL header carries no addresses or TLVs here.
    pub
    fn new() -> Vec<u8>
    {
        let local_hdr: Vec<u8> = protocol_raw::MSG_HEADER_LOCAL_V2.iter().copied().collect();
        return local_hdr;
    }
}
impl ProxyHdrV2<HdrV2OpProxy>
{
pub
fn new(transport: ProxyTransportFam, address: ProxyV2Addr) -> HaProxRes<Self>
{
let buf: Vec<u8> =
Vec::with_capacity(size_of::<protocol_raw::ProxyHdrV2>());
let mut cur = Cursor::new(buf);
cur.write_all(protocol_raw::HEADER_MAGIC_V2).map_err(common::map_io_err)?;
cur.write_u8(0x20 | HdrV2OpProxy::OPCODE).map_err(common::map_io_err)?;
cur.write_u8(((address.as_addr_family() as u8) << 4) | transport as u8).map_err(common::map_io_err)?;
cur.write_u16::<BigEndian>(address.get_len()).map_err(common::map_io_err)?;
address.write(&mut cur)?;
return Ok(
Self
{
buffer: cur,
crc_tlv_offset: 0,
_p: PhantomData
}
);
}
pub
fn set_plts<'s>(&'s mut self) -> TlType<'s>
{
return TlType::new(self, PP2Tlvs::TLV_TYPE_MAIN_RANGES);
}
fn finalize(&mut self) -> HaProxRes<()>
{
let last_off = self.buffer.position();
self.buffer.set_position(Self::HDR_MSG_LEN_OFFSET);
let tlv_len = last_off - size_of::<protocol_raw::ProxyHdrV2>() as u64;
self.buffer.write_u16::<BigEndian>(tlv_len as u16).map_err(common::map_io_err)?;
if self.crc_tlv_offset > 0
{
let mut hasher = Hasher::new();
hasher.update(self.buffer.get_ref());
self.buffer.set_position(self.crc_tlv_offset + protocol::TLV_HEADER_LEN as u64);
self.buffer.write_u32::<BigEndian>(hasher.finalize()).map_err(common::map_io_err)?;
}
return Ok(());
}
}
impl TryFrom<ProxyHdrV2<HdrV2OpProxy>> for Vec<u8>
{
type Error = HaProxErr;
fn try_from(mut value: ProxyHdrV2<HdrV2OpProxy>) -> Result<Self, Self::Error>
{
value.finalize()?;
return Ok(value.buffer.into_inner());
}
}
impl PP2TlvDump for PP2Tlvs
{
fn get_type(&self) -> u8
{
return self.into();
}
fn dump(&self, cur: &mut Cursor<Vec<u8>>) -> HaProxRes<()>
{
match self
{
PP2Tlvs::TypeAlpn(items) =>
{
for alpn in items.iter() {
cur.write_u16::<BigEndian>(alpn.len() as u16).map_err(common::map_io_err)?;
cur.write_all(alpn).map_err(common::map_io_err)?;
}
},
PP2Tlvs::TypeAuthority(auth) =>
{
cur.write_all(auth.as_bytes()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeCrc32c(crc) =>
{
cur.write_u32::<BigEndian>(*crc).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeNoop =>
{
},
PP2Tlvs::TypeUniqId(items) =>
{
cur.write_all(items.as_slice()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeSsl{client, verify} =>
{
cur.write_u8(client.bits()).map_err(common::map_io_err)?;
cur.write_u32::<BigEndian>(*verify).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeSubtypeSslVersion(v) =>
{
cur.write_all(v.as_bytes()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeSubtypeSslCn(cn) =>
{
cur.write_all(cn.as_bytes()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeSubtypeSslCipher(c) =>
{
cur.write_all(c.as_bytes()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeSubtypeSslSigAlg(sa) =>
{
cur.write_all(sa.as_bytes()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeSubtypeSslKeyAlg(ka) =>
{
cur.write_all(ka.as_bytes()).map_err(common::map_io_err)?;
},
PP2Tlvs::TypeNetNs(ns) =>
{
cur.write_all(ns.as_bytes()).map_err(common::map_io_err)?;
},
}
return Ok(());
}
}
// Unit tests for the PROXY v2 header composer and the per-TLV value serializers.
#[cfg(test)]
mod tests
{
use std::{fmt, io::Cursor};
use byteorder::{BigEndian, WriteBytesExt};
use crate::{common::map_io_err, protocol::{protocol::{PP2TlvClient, PP2Tlvs}, PP2TlvDump}, HaProxRes, ProxyTransportFam, ProxyV2Addr};
use super::{HdrV2OpProxy, ProxyHdrV2};
// Builds an IPv4/STREAM header with one SSL TLV holding a single version sub-TLV
// and compares the result with the expected wire bytes.
#[test]
fn test_comp0()
{
let addr = ProxyV2Addr::try_from(("127.0.0.1:39754", "127.0.0.67:11883")).unwrap();
let mut comp =
ProxyHdrV2::<HdrV2OpProxy>::new(ProxyTransportFam::STREAM, addr).unwrap();
let plts = comp.set_plts();
let mut ssl = plts.add_ssl(PP2TlvClient::PP2_CLIENT_SSL, 0).unwrap();
ssl.add_ssl_sub_version("TLSv1.2").unwrap();
ssl.done().unwrap();
let pkt: Vec<u8> = comp.try_into().unwrap();
// Expected: signature, 0x21, 0x11, length 0x1e; addresses and ports (39754 = 0x9b4a,
// 11883 = 0x2e6b); SSL TLV 0x20 with length 15 ending in the 0x21 "TLSv1.2" sub-TLV.
let ctrl =
b"\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a\x21\x11\x00\x1e\
\x7f\x00\x00\x01\x7f\x00\x00\x43\x9b\x4a\x2e\x6b\x20\x00\x0f\x01\
\x00\x00\x00\x00\x21\x00\x07\x54\x4c\x53\x76\x31\x2e\x32";
assert_eq!(pkt.as_slice(), ctrl.as_slice());
}
// Same header as test_comp0 followed by a caller-defined TLV (type 0xE0) that is added
// with explicit constraints, exercising the PP2TlvDump extension point.
#[test]
fn test_comp1()
{
// Caller-defined TLV whose value is two big-endian u32 fields.
#[derive(Clone, Debug)]
pub enum ProxyV2Dummy2
{
SomeTlvName(u32, u32),
}
impl fmt::Display for ProxyV2Dummy2
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
{
write!(f, "DUMMY external reader")
}
}
impl PP2TlvDump for ProxyV2Dummy2
{
fn get_type(&self) -> u8
{
let Self::SomeTlvName(..) = self else { panic!("wrong") };
return 0xE0;
}
fn dump(&self, cur: &mut Cursor<Vec<u8>>) -> HaProxRes<()>
{
match self
{
Self::SomeTlvName(arg0, arg1) =>
{
cur.write_u32::<BigEndian>(*arg0).map_err(map_io_err)?;
cur.write_u32::<BigEndian>(*arg1).map_err(map_io_err)?;
}
}
return Ok(());
}
}
let addr = ProxyV2Addr::try_from(("127.0.0.1:39754", "127.0.0.67:11883")).unwrap();
let mut comp =
ProxyHdrV2::<HdrV2OpProxy>::new(ProxyTransportFam::STREAM, addr).unwrap();
let plts = comp.set_plts();
let mut ssl = plts.add_ssl(PP2TlvClient::PP2_CLIENT_SSL, 0).unwrap();
ssl.add_ssl_sub_version("TLSv1.2").unwrap();
let mut plts = ssl.done().unwrap();
let cust_plt = ProxyV2Dummy2::SomeTlvName(0x01020304, 0x05060708);
// 0xE0 is outside the default ranges, so the accepted range is passed explicitly.
plts.add_tlv(cust_plt, Some(&[0xE0..=0xE0])).unwrap();
// Release the builder's mutable borrow of `comp` before converting it.
drop(plts);
let pkt: Vec<u8> = comp.try_into().unwrap();
// Expected: as test_comp0 but with length 0x29 and a trailing 0xE0 TLV of length 8.
let ctrl =
b"\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a\x21\x11\x00\x29\
\x7f\x00\x00\x01\x7f\x00\x00\x43\x9b\x4a\x2e\x6b\x20\x00\x0f\x01\
\x00\x00\x00\x00\x21\x00\x07\x54\x4c\x53\x76\x31\x2e\x32\xE0\x00\
\x08\x01\x02\x03\x04\x05\x06\x07\x08";
assert_eq!(pkt.as_slice(), ctrl.as_slice());
}
// ALPN entries are written with a big-endian u16 length prefix.
#[test]
fn test_alpns()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeAlpn(vec![b"test".as_slice().to_vec()]);
alpn.dump(&mut cur).unwrap();
let reference = b"\x00\x04test";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
// Authority is written as raw bytes with no prefix.
#[test]
fn test_authority()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeAuthority("tset".into());
alpn.dump(&mut cur).unwrap();
let reference = b"tset";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
// The CRC value is written big-endian.
#[test]
fn test_crc32()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeCrc32c(0xABCDEF01);
alpn.dump(&mut cur).unwrap();
let reference = b"\xab\xcd\xef\x01";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
// Network namespace is written as raw bytes.
#[test]
fn test_netns()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeNetNs("tstt".into());
alpn.dump(&mut cur).unwrap();
let reference = b"tstt";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
// NOOP has an empty value.
#[test]
fn test_noop()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeNoop;
alpn.dump(&mut cur).unwrap();
let reference = b"";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
// SSL fixed fields: client flags byte, then big-endian u32 verify.
#[test]
fn test_ssl()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeSsl{ client: PP2TlvClient::PP2_CLIENT_SSL, verify: 0x00003210};
alpn.dump(&mut cur).unwrap();
let reference = b"\x01\x00\x00\x32\x10";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
// Unique id bytes are copied verbatim.
#[test]
fn test_uniqid()
{
const ID: &'static [u8] = b"ABCD12345678901234567890";
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let alpn = PP2Tlvs::TypeUniqId(ID.into());
alpn.dump(&mut cur).unwrap();
let reference = ID;
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference);
}
// The SSL sub-TLV string values below are all written as raw bytes.
#[test]
fn test_ssl_cipher()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let ciph = PP2Tlvs::TypeSubtypeSslCipher("ECDHE-RSA-AES128-GCM-SHA256".into());
ciph.dump(&mut cur).unwrap();
let reference = b"ECDHE-RSA-AES128-GCM-SHA256";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
#[test]
fn test_ssl_cn()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let cn = PP2Tlvs::TypeSubtypeSslCn("example.com".into());
cn.dump(&mut cur).unwrap();
let reference = b"example.com";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
#[test]
fn test_ssl_keyalg()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let keyalg = PP2Tlvs::TypeSubtypeSslKeyAlg("RSA2048".into());
keyalg.dump(&mut cur).unwrap();
let reference = b"RSA2048";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
#[test]
fn test_ssl_sigalg()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let sigalg = PP2Tlvs::TypeSubtypeSslSigAlg("SHA256".into());
sigalg.dump(&mut cur).unwrap();
let reference = b"SHA256";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
#[test]
fn test_ssl_version()
{
let mut cur = Cursor::new(Vec::<u8>::with_capacity(64));
let vers = PP2Tlvs::TypeSubtypeSslVersion("TLSv1_3".into());
vers.dump(&mut cur).unwrap();
let reference = b"TLSv1_3";
let generated = cur.into_inner();
assert_eq!(generated.as_slice(), reference.as_slice());
}
}