use std::future::Future;
use std::net::SocketAddr;
use std::str::from_utf8;
use std::sync::{Arc, Weak};
use async_trait::async_trait;
use cfg_if::cfg_if;
use log::{debug, warn};
use crate::bytes::{ByteBuffer, DynamicByteBuffer, StaticByteBuffer};
pub use crate::flow::decoy::{DecoyFactory, decoy_factory, random_decoy_factory};
pub use crate::flow::probe::{ActiveProbeHandler, ProbeFactory, ProbeFlowSender, probe_factory};
use crate::settings::Settings;
use crate::settings::consts::DEFAULT_TYPHOON_ID_LENGTH;
pub use crate::tailer::{ClientConnectionHandler, ServerConnectionHandler};
use crate::tailer::{IdentityType, Tailer};
use crate::utils::random::{SupportRng, get_rng};
pub use crate::utils::sync::AsyncExecutor;
cfg_if! {
if #[cfg(feature = "tokio")] {
use tokio::spawn;
use tokio::runtime::Handle;
use tokio::task::block_in_place;
} else if #[cfg(feature = "async-std")] {
use async_io::block_on as async_io_block_on;
}
}
/// Parses a `(major, minor, patch)` version triple from raw bytes.
///
/// The input is treated as a possibly NUL-padded UTF-8 string (the shape produced by
/// `DefaultClientConnectionHandler::version`). The steps are:
/// - everything from the first `0` byte onward is ignored;
/// - surrounding whitespace is trimmed;
/// - any SemVer pre-release (`-…`) or build-metadata (`+…`) suffix is discarded;
/// - the remainder is split on `.`.
///
/// This never fails. Missing or non-numeric components default to `0`, and input that is
/// not valid UTF-8 yields `(0, 0, 0)`. Components beyond the third are ignored.
fn parse_version(bytes: &[u8]) -> (u64, u64, u64) {
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    let text = from_utf8(&bytes[..end]).unwrap_or("").trim();
    // Both `-pre` and `+build` suffixes must be cut off. If only `-` were handled,
    // `1.2.3+meta` would leave "3+meta" as the patch component, which fails to parse.
    let core = text
        .split(|c: char| c == '-' || c == '+')
        .next()
        .unwrap_or(text);
    let mut parts = core.split('.').map(|part| part.parse::<u64>().unwrap_or(0));
    let major = parts.next().unwrap_or(0);
    let minor = parts.next().unwrap_or(0);
    let patch = parts.next().unwrap_or(0);
    (major, minor, patch)
}
impl IdentityType for StaticByteBuffer {
fn from_bytes(bytes: &[u8]) -> Self {
assert_eq!(bytes.len(), DEFAULT_TYPHOON_ID_LENGTH, "invalid bytes identity length: expected {}, got {}", DEFAULT_TYPHOON_ID_LENGTH, bytes.len());
Self::from_slice(bytes)
}
fn to_bytes(&self) -> &[u8] {
self.slice()
}
fn length() -> usize {
DEFAULT_TYPHOON_ID_LENGTH
}
}
/// [`AsyncExecutor`] that delegates to the ambient Tokio runtime.
///
/// This is a zero-sized marker. Each method looks up the runtime from the calling context,
/// so it must be used from within a Tokio runtime.
#[cfg(feature = "tokio")]
#[derive(Clone)]
pub struct TokioExecutor;
#[cfg(feature = "tokio")]
impl AsyncExecutor for TokioExecutor {
    fn new() -> Self {
        TokioExecutor
    }

    /// Spawns `future` on the current Tokio runtime. The join handle is dropped, so the
    /// task runs detached.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F) {
        spawn(future);
    }

    /// Blocks the calling thread until `future` completes.
    ///
    /// This goes through `block_in_place`. Per Tokio's docs, that panics on a
    /// `current_thread` runtime, and `Handle::current()` panics when no runtime context
    /// is entered.
    fn block_on<F: Future<Output = ()>>(&self, future: F) {
        block_in_place(move || {
            let runtime = Handle::current();
            runtime.block_on(future)
        });
    }
}
/// [`AsyncExecutor`] backed by an `async_executor::Executor`, used with the `async-std` feature.
///
/// Clones share the same underlying executor through an `Arc`.
#[cfg(feature = "async-std")]
#[derive(Clone)]
pub struct AsyncStdExecutor {
    executor: Arc<async_executor::Executor<'static>>,
}
#[cfg(feature = "async-std")]
impl AsyncExecutor for AsyncStdExecutor {
    /// Creates an executor with its own, freshly allocated task queue.
    fn new() -> Self {
        Self {
            executor: Arc::new(async_executor::Executor::new()),
        }
    }

    /// Queues `future` on the shared executor and detaches the task handle.
    ///
    /// The task makes progress only while some thread is running the executor. Examples are
    /// a call to `block_on` on any clone of this value, or an external driver of an executor
    /// supplied through `From`.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F) {
        self.executor.spawn(future).detach();
    }

    /// Blocks the current thread until `future` completes, running spawned tasks meanwhile.
    fn block_on<F: Future<Output = ()>>(&self, future: F) {
        // `async_executor::Executor` does not poll tasks on its own. Queued tasks run only
        // while something awaits `run` (or calls `tick`). Wrapping `future` in `run` drives
        // the executor while we block. Otherwise tasks spawned on a `new()` executor are
        // never polled.
        async_io_block_on(self.executor.run(future));
    }
}
#[cfg(feature = "async-std")]
impl From<Arc<async_executor::Executor<'static>>> for AsyncStdExecutor {
    /// Wraps an existing, possibly externally driven, executor.
    fn from(executor: Arc<async_executor::Executor<'static>>) -> Self {
        Self {
            executor,
        }
    }
}
cfg_if! {
    if #[cfg(feature = "tokio")] {
        /// Executor used by [`DefaultSettings`]. Tokio takes precedence when both runtime
        /// features are enabled.
        pub type DefaultExecutor = TokioExecutor;
    } else if #[cfg(feature = "async-std")] {
        /// Executor used by [`DefaultSettings`] when only the `async-std` feature is enabled.
        pub type DefaultExecutor = AsyncStdExecutor;
    } else {
        // Without a runtime feature `DefaultExecutor` is undefined, and the build would fail
        // below with an obscure unresolved-type error. Report the real cause instead.
        compile_error!("enable either the `tokio` or the `async-std` feature to select an async runtime");
    }
}
/// [`Settings`] specialised to the feature-selected [`DefaultExecutor`].
pub type DefaultSettings = Settings<DefaultExecutor>;
/// [`Tailer`] parameterised with [`StaticByteBuffer`]. The `IdentityType` impl for
/// `StaticByteBuffer` is defined in this module.
pub type DefaultTailer = Tailer<StaticByteBuffer>;
/// Server-side handshake handler. It issues random identities, sends no initial data, and
/// accepts any client whose major version matches this crate's.
pub struct DefaultServerConnectionHandler;
impl ServerConnectionHandler<StaticByteBuffer> for DefaultServerConnectionHandler {
    /// Returns a fresh random `DEFAULT_TYPHOON_ID_LENGTH`-byte identity. The client's
    /// initial data is ignored.
    fn generate(&self, _initial_data: &[u8]) -> StaticByteBuffer {
        StaticByteBuffer::from_slice(
            get_rng()
                .random_byte_buffer::<DEFAULT_TYPHOON_ID_LENGTH>()
                .slice(),
        )
    }

    /// Always an empty buffer: no server-side initial data is sent.
    fn initial_data(&self, _identity: &StaticByteBuffer) -> StaticByteBuffer {
        StaticByteBuffer::from_slice(&[])
    }

    /// Compares the client's advertised version with `CARGO_PKG_VERSION`.
    ///
    /// A major mismatch rejects the handshake and logs a warning. A minor mismatch is
    /// accepted with a warning. A patch mismatch is accepted with a debug message.
    fn verify_version(&self, version_bytes: &[u8]) -> bool {
        let (cli_major, cli_minor, cli_patch) = parse_version(version_bytes);
        let (srv_major, srv_minor, srv_patch) = parse_version(env!("CARGO_PKG_VERSION").as_bytes());

        if cli_major != srv_major {
            warn!("client version major mismatch (client={cli_major}.{cli_minor}.{cli_patch}, server={srv_major}.{srv_minor}.{srv_patch}), rejecting handshake");
            return false;
        }

        if cli_minor != srv_minor {
            warn!("client version minor mismatch (client={cli_major}.{cli_minor}.{cli_patch}, server={srv_major}.{srv_minor}.{srv_patch})");
        } else if cli_patch != srv_patch {
            debug!("client version patch mismatch (client={cli_major}.{cli_minor}.{cli_patch}, server={srv_major}.{srv_minor}.{srv_patch})");
        }
        true
    }
}
/// Active-probe handler that does nothing.
///
/// `start` and `process` return immediately and never inspect their arguments, so probe
/// traffic handed to this handler is discarded.
#[derive(Default)]
pub struct NoopProbeHandler;
#[async_trait]
impl<AE> ActiveProbeHandler<AE> for NoopProbeHandler
where
    AE: AsyncExecutor + 'static,
{
    async fn start(&mut self, _sender: Weak<dyn ProbeFlowSender>, _settings: Arc<Settings<AE>>) {}

    async fn process(&mut self, _packet: DynamicByteBuffer, _source: Option<SocketAddr>) {}
}
/// Client-side handshake handler. It sends no initial data and advertises this crate's
/// `CARGO_PKG_VERSION`.
pub struct DefaultClientConnectionHandler;
impl ClientConnectionHandler for DefaultClientConnectionHandler {
    /// Always an empty buffer.
    fn initial_data(&self) -> StaticByteBuffer {
        StaticByteBuffer::from_slice(&[])
    }

    /// Encodes the crate version into exactly `length` bytes.
    ///
    /// Shorter versions are right-padded with `0` bytes; longer ones are truncated to
    /// `length` bytes.
    fn version(&self, length: usize) -> StaticByteBuffer {
        let version = env!("CARGO_PKG_VERSION").as_bytes();
        let kept = version.len().min(length);
        let mut padded = Vec::with_capacity(length);
        padded.extend_from_slice(&version[..kept]);
        padded.resize(length, 0u8);
        StaticByteBuffer::from(padded)
    }
}