use crate::core::action::ProtocolAction;
use crate::core::padding::PaddingFactory;
use crate::core::state::State;
use crate::core::string_map::{StringMap, StringMapExt};
use crate::core::{Command, Frame};
use bytes::Bytes;
use std::sync::Arc;
use std::time::Duration;
/// Namespace type for the protocol decision logic.
///
/// `Engine` carries no data of its own: its associated functions take the
/// session `State` as an argument and return the list of `ProtocolAction`s
/// the caller is expected to carry out.
pub struct Engine;
impl Engine {
/// Produces the actions to perform when a session starts.
///
/// On the server side (`is_client == false`) nothing is sent and the returned
/// list is empty. On the client side a single `SendFrame` action is returned,
/// carrying a `Settings` frame on stream id 0 whose payload is a string map
/// with the keys `"v"` (always `"2"`), `"client"` (`client_name`) and
/// `"padding-md5"` (the md5 reported by the current padding in `state`).
///
/// This function never returns `Err`; the `Result` is part of the signature only.
pub fn on_session_start(state: &Arc<State>, is_client: bool, client_name: &str) -> std::io::Result<Vec<ProtocolAction>> {
    if is_client {
        let mut hello = StringMap::new();
        hello.insert(String::from("v"), String::from("2"));
        hello.insert(String::from("client"), client_name.to_owned());
        let digest = state.padding().md5().to_owned();
        hello.insert(String::from("padding-md5"), digest);

        let announce = Frame::with_data(Command::Settings, 0, hello.to_bytes().into());
        return Ok(vec![ProtocolAction::SendFrame(announce)]);
    }

    // Only the client announces settings; the server has nothing to do here.
    Ok(Vec::new())
}
pub fn on_frame(state: &Arc<State>, is_client: bool, frame: &Frame) -> std::io::Result<Vec<ProtocolAction>> {
let mut actions = Vec::new();
match frame.cmd {
Command::Waste | Command::HeartResponse => {}
Command::Psh if !frame.data.is_empty() => {
actions.push(ProtocolAction::PushStreamData {
sid: frame.sid,
data: frame.data.clone(),
});
}
Command::Syn if !is_client => {
if !state.received_settings_from_client() {
actions.push(ProtocolAction::AlertAndFail {
message: "client did not send its settings".to_string(),
});
} else {
actions.push(ProtocolAction::EnsureIncomingStream { sid: frame.sid });
}
}
Command::Fin => {
actions.push(ProtocolAction::CloseLocalStream { sid: frame.sid });
}
Command::Settings if !is_client && !frame.data.is_empty() => {
let settings = StringMap::from_bytes(frame.data.as_ref());
state.mark_received_settings_from_client();
let padding = state.padding();
if settings.get("padding-md5").map(String::as_str) != Some(padding.md5()) {
actions.push(ProtocolAction::SendFrameSync(Frame::with_data(
Command::UpdatePaddingScheme,
0,
Bytes::copy_from_slice(padding.raw_scheme()),
)));
}
if let Some(version) = settings.get("v").and_then(|value| value.parse::<u8>().ok())
&& version >= 2
{
state.set_peer_version(version);
let mut server_settings = StringMap::new();
server_settings.insert("v".to_string(), "2".to_string());
actions.push(ProtocolAction::SendFrameSync(Frame::with_data(
Command::ServerSettings,
0,
server_settings.to_bytes().into(),
)));
}
}
Command::UpdatePaddingScheme if !frame.data.is_empty() && is_client => {
if let Some(factory) = PaddingFactory::new(frame.data.as_ref()) {
state.set_padding(factory);
}
}
Command::HeartRequest => {
actions.push(ProtocolAction::SendFrame(Frame::new(Command::HeartResponse, frame.sid)));
}
Command::ServerSettings if !frame.data.is_empty() && is_client => {
let settings = StringMap::from_bytes(frame.data.as_ref());
if let Some(version) = settings.get("v").and_then(|value| value.parse::<u8>().ok()) {
state.set_peer_version(version);
}
}
Command::SynAck => {
actions.push(ProtocolAction::CancelSynAckTimeout { sid: frame.sid });
if !frame.data.is_empty() {
actions.push(ProtocolAction::CloseRemoteStream {
sid: frame.sid,
message: String::from_utf8_lossy(frame.data.as_ref()).to_string(),
});
}
}
_ => log::warn!(
"Received unexpected frame: cmd={}, sid={}, data_len={}",
frame.cmd,
frame.sid,
frame.data.len()
),
}
Ok(actions)
}
/// Produces the actions for opening the outgoing stream `sid`.
///
/// The returned list always ends with a synchronous `Syn` frame for `sid`
/// followed by `ReleaseWriteBuffering`. When `sid` is at least 2 and the
/// recorded peer version is at least 2, an `ArmSynAckTimeout` action with a
/// three-second timeout is placed in front of them.
pub fn on_open_stream(state: &Arc<State>, sid: u32) -> Vec<ProtocolAction> {
    const SYN_ACK_TIMEOUT: Duration = Duration::from_secs(3);

    // The peer version is only consulted once the stream id check has passed.
    let ack_timer = (sid >= 2 && state.peer_version() >= 2).then(|| {
        ProtocolAction::ArmSynAckTimeout {
            sid,
            timeout: SYN_ACK_TIMEOUT,
        }
    });

    ack_timer
        .into_iter()
        .chain([
            ProtocolAction::SendFrameSync(Frame::new(Command::Syn, sid)),
            ProtocolAction::ReleaseWriteBuffering,
        ])
        .collect()
}
}