use super::websocket::{WebSocketTransaction, ALPROTO_WEBSOCKET};
use crate::core::{STREAM_TOCLIENT, STREAM_TOSERVER};
use crate::detect::uint::{
detect_parse_uint, detect_parse_uint_enum, DetectUintData, DetectUintMode, SCDetectU32Free,
SCDetectU32Match, SCDetectU32Parse, SCDetectU8Free, SCDetectU8Match,
};
use crate::detect::{helper_keyword_register_sticky_buffer, SigTableElmtStickyBuffer};
use crate::websocket::parser::WebSocketOpcode;
use suricata_sys::sys::{
DetectEngineCtx, DetectEngineThreadCtx, Flow, SCDetectBufferSetActiveList,
SCDetectHelperBufferMpmRegister, SCDetectHelperBufferRegister, SCDetectHelperKeywordRegister,
SCDetectSignatureSetAppProto, SCSigMatchAppendSMToList, SCSigTableAppLiteElmt, SigMatchCtx,
Signature,
};
use nom7::branch::alt;
use nom7::bytes::complete::{is_a, tag};
use nom7::combinator::{opt, value};
use nom7::multi::many1;
use nom7::IResult;
use std::ffi::CStr;
use std::os::raw::{c_int, c_void};
unsafe extern "C" fn websocket_parse_opcode(
ustr: *const std::os::raw::c_char,
) -> *mut DetectUintData<u8> {
let ft_name: &CStr = CStr::from_ptr(ustr); if let Ok(s) = ft_name.to_str() {
if let Some(ctx) = detect_parse_uint_enum::<u8, WebSocketOpcode>(s) {
let boxed = Box::new(ctx);
return Box::into_raw(boxed) as *mut _;
}
}
return std::ptr::null_mut();
}
struct WebSocketFlag {
neg: bool,
value: u8,
}
fn parse_flag_list_item(s: &str) -> IResult<&str, WebSocketFlag> {
let (s, _) = opt(is_a(" "))(s)?;
let (s, neg) = opt(tag("!"))(s)?;
let neg = neg.is_some();
let (s, value) = alt((value(0x80, tag("fin")), value(0x40, tag("comp"))))(s)?;
let (s, _) = opt(is_a(" ,"))(s)?;
Ok((s, WebSocketFlag { neg, value }))
}
fn parse_flag_list(s: &str) -> IResult<&str, Vec<WebSocketFlag>> {
return many1(parse_flag_list_item)(s);
}
fn parse_flags(s: &str) -> Option<DetectUintData<u8>> {
if let Ok((_, ctx)) = detect_parse_uint::<u8>(s) {
return Some(ctx);
}
if let Ok((_, l)) = parse_flag_list(s) {
let mut arg1 = 0;
let mut arg2 = 0;
for elem in l.iter() {
if elem.value & arg1 != 0 {
SCLogWarning!("Repeated bitflag for websocket.flags");
return None;
}
arg1 |= elem.value;
if !elem.neg {
arg2 |= elem.value;
}
}
let ctx = DetectUintData::<u8> {
arg1,
arg2,
mode: DetectUintMode::DetectUintModeBitmask,
};
return Some(ctx);
}
return None;
}
unsafe extern "C" fn websocket_parse_flags(
ustr: *const std::os::raw::c_char,
) -> *mut DetectUintData<u8> {
let ft_name: &CStr = CStr::from_ptr(ustr); if let Ok(s) = ft_name.to_str() {
if let Some(ctx) = parse_flags(s) {
let boxed = Box::new(ctx);
return Box::into_raw(boxed) as *mut _;
}
}
return std::ptr::null_mut();
}
static mut G_WEBSOCKET_OPCODE_KW_ID: u16 = 0;
static mut G_WEBSOCKET_OPCODE_BUFFER_ID: c_int = 0;
static mut G_WEBSOCKET_MASK_KW_ID: u16 = 0;
static mut G_WEBSOCKET_MASK_BUFFER_ID: c_int = 0;
static mut G_WEBSOCKET_FLAGS_KW_ID: u16 = 0;
static mut G_WEBSOCKET_FLAGS_BUFFER_ID: c_int = 0;
static mut G_WEBSOCKET_PAYLOAD_BUFFER_ID: c_int = 0;
unsafe extern "C" fn websocket_detect_opcode_setup(
de: *mut DetectEngineCtx, s: *mut Signature, raw: *const libc::c_char,
) -> c_int {
if SCDetectSignatureSetAppProto(s, ALPROTO_WEBSOCKET) != 0 {
return -1;
}
let ctx = websocket_parse_opcode(raw) as *mut c_void;
if ctx.is_null() {
return -1;
}
if SCSigMatchAppendSMToList(
de,
s,
G_WEBSOCKET_OPCODE_KW_ID,
ctx as *mut SigMatchCtx,
G_WEBSOCKET_OPCODE_BUFFER_ID,
)
.is_null()
{
websocket_detect_opcode_free(std::ptr::null_mut(), ctx);
return -1;
}
return 0;
}
unsafe extern "C" fn websocket_detect_opcode_match(
_de: *mut DetectEngineThreadCtx, _f: *mut Flow, _flags: u8, _state: *mut c_void,
tx: *mut c_void, _sig: *const Signature, ctx: *const SigMatchCtx,
) -> c_int {
let tx = cast_pointer!(tx, WebSocketTransaction);
let ctx = cast_pointer!(ctx, DetectUintData<u8>);
return SCDetectU8Match(tx.pdu.opcode, ctx);
}
unsafe extern "C" fn websocket_detect_opcode_free(_de: *mut DetectEngineCtx, ctx: *mut c_void) {
let ctx = cast_pointer!(ctx, DetectUintData<u8>);
SCDetectU8Free(ctx);
}
unsafe extern "C" fn websocket_detect_mask_setup(
de: *mut DetectEngineCtx, s: *mut Signature, raw: *const libc::c_char,
) -> c_int {
if SCDetectSignatureSetAppProto(s, ALPROTO_WEBSOCKET) != 0 {
return -1;
}
let ctx = SCDetectU32Parse(raw) as *mut c_void;
if ctx.is_null() {
return -1;
}
if SCSigMatchAppendSMToList(
de,
s,
G_WEBSOCKET_MASK_KW_ID,
ctx as *mut SigMatchCtx,
G_WEBSOCKET_MASK_BUFFER_ID,
)
.is_null()
{
websocket_detect_mask_free(std::ptr::null_mut(), ctx);
return -1;
}
return 0;
}
unsafe extern "C" fn websocket_detect_mask_match(
_de: *mut DetectEngineThreadCtx, _f: *mut Flow, _flags: u8, _state: *mut c_void,
tx: *mut c_void, _sig: *const Signature, ctx: *const SigMatchCtx,
) -> c_int {
let tx = cast_pointer!(tx, WebSocketTransaction);
let ctx = cast_pointer!(ctx, DetectUintData<u32>);
if let Some(xorkey) = tx.pdu.mask {
return SCDetectU32Match(xorkey, ctx);
}
return 0;
}
unsafe extern "C" fn websocket_detect_mask_free(_de: *mut DetectEngineCtx, ctx: *mut c_void) {
let ctx = cast_pointer!(ctx, DetectUintData<u32>);
SCDetectU32Free(ctx);
}
unsafe extern "C" fn websocket_detect_flags_setup(
de: *mut DetectEngineCtx, s: *mut Signature, raw: *const libc::c_char,
) -> c_int {
if SCDetectSignatureSetAppProto(s, ALPROTO_WEBSOCKET) != 0 {
return -1;
}
let ctx = websocket_parse_flags(raw) as *mut c_void;
if ctx.is_null() {
return -1;
}
if SCSigMatchAppendSMToList(
de,
s,
G_WEBSOCKET_FLAGS_KW_ID,
ctx as *mut SigMatchCtx,
G_WEBSOCKET_FLAGS_BUFFER_ID,
)
.is_null()
{
websocket_detect_flags_free(std::ptr::null_mut(), ctx);
return -1;
}
return 0;
}
unsafe extern "C" fn websocket_detect_flags_match(
_de: *mut DetectEngineThreadCtx, _f: *mut Flow, _flags: u8, _state: *mut c_void,
tx: *mut c_void, _sig: *const Signature, ctx: *const SigMatchCtx,
) -> c_int {
let tx = cast_pointer!(tx, WebSocketTransaction);
let ctx = cast_pointer!(ctx, DetectUintData<u8>);
return SCDetectU8Match(tx.pdu.flags, ctx);
}
unsafe extern "C" fn websocket_detect_flags_free(_de: *mut DetectEngineCtx, ctx: *mut c_void) {
let ctx = cast_pointer!(ctx, DetectUintData<u8>);
SCDetectU8Free(ctx);
}
pub unsafe extern "C" fn websocket_detect_payload_setup(
de: *mut DetectEngineCtx, s: *mut Signature, _raw: *const std::os::raw::c_char,
) -> c_int {
if SCDetectSignatureSetAppProto(s, ALPROTO_WEBSOCKET) != 0 {
return -1;
}
if SCDetectBufferSetActiveList(de, s, G_WEBSOCKET_PAYLOAD_BUFFER_ID) < 0 {
return -1;
}
return 0;
}
pub unsafe extern "C" fn websocket_detect_payload_get_data(
tx: *const c_void, _flow_flags: u8, buffer: *mut *const u8, buffer_len: *mut u32,
) -> bool {
let tx = cast_pointer!(tx, WebSocketTransaction);
*buffer = tx.pdu.payload.as_ptr();
*buffer_len = tx.pdu.payload.len() as u32;
return true;
}
#[no_mangle]
pub unsafe extern "C" fn SCDetectWebsocketRegister() {
let kw = SCSigTableAppLiteElmt {
name: b"websocket.opcode\0".as_ptr() as *const libc::c_char,
desc: b"match WebSocket opcode\0".as_ptr() as *const libc::c_char,
url: b"/rules/websocket-keywords.html#websocket-opcode\0".as_ptr() as *const libc::c_char,
AppLayerTxMatch: Some(websocket_detect_opcode_match),
Setup: Some(websocket_detect_opcode_setup),
Free: Some(websocket_detect_opcode_free),
flags: 0,
};
G_WEBSOCKET_OPCODE_KW_ID = SCDetectHelperKeywordRegister(&kw);
G_WEBSOCKET_OPCODE_BUFFER_ID = SCDetectHelperBufferRegister(
b"websocket.opcode\0".as_ptr() as *const libc::c_char,
ALPROTO_WEBSOCKET,
STREAM_TOSERVER | STREAM_TOCLIENT,
);
let kw = SCSigTableAppLiteElmt {
name: b"websocket.mask\0".as_ptr() as *const libc::c_char,
desc: b"match WebSocket mask\0".as_ptr() as *const libc::c_char,
url: b"/rules/websocket-keywords.html#websocket-mask\0".as_ptr() as *const libc::c_char,
AppLayerTxMatch: Some(websocket_detect_mask_match),
Setup: Some(websocket_detect_mask_setup),
Free: Some(websocket_detect_mask_free),
flags: 0,
};
G_WEBSOCKET_MASK_KW_ID = SCDetectHelperKeywordRegister(&kw);
G_WEBSOCKET_MASK_BUFFER_ID = SCDetectHelperBufferRegister(
b"websocket.mask\0".as_ptr() as *const libc::c_char,
ALPROTO_WEBSOCKET,
STREAM_TOSERVER | STREAM_TOCLIENT,
);
let kw = SCSigTableAppLiteElmt {
name: b"websocket.flags\0".as_ptr() as *const libc::c_char,
desc: b"match WebSocket flags\0".as_ptr() as *const libc::c_char,
url: b"/rules/websocket-keywords.html#websocket-flags\0".as_ptr() as *const libc::c_char,
AppLayerTxMatch: Some(websocket_detect_flags_match),
Setup: Some(websocket_detect_flags_setup),
Free: Some(websocket_detect_flags_free),
flags: 0,
};
G_WEBSOCKET_FLAGS_KW_ID = SCDetectHelperKeywordRegister(&kw);
G_WEBSOCKET_FLAGS_BUFFER_ID = SCDetectHelperBufferRegister(
b"websocket.flags\0".as_ptr() as *const libc::c_char,
ALPROTO_WEBSOCKET,
STREAM_TOSERVER | STREAM_TOCLIENT,
);
let kw = SigTableElmtStickyBuffer {
name: String::from("websocket.payload"),
desc: String::from("match WebSocket payload"),
url: String::from("/rules/websocket-keywords.html#websocket-payload"),
setup: websocket_detect_payload_setup,
};
let _g_ws_payload_kw_id = helper_keyword_register_sticky_buffer(&kw);
G_WEBSOCKET_PAYLOAD_BUFFER_ID = SCDetectHelperBufferMpmRegister(
b"websocket.payload\0".as_ptr() as *const libc::c_char,
b"WebSocket payload\0".as_ptr() as *const libc::c_char,
ALPROTO_WEBSOCKET,
STREAM_TOSERVER | STREAM_TOCLIENT,
Some(websocket_detect_payload_get_data),
);
}