1use serde::de::{MapAccess, Visitor};
2use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize};
3use std::collections::BTreeMap;
4use std::fmt::Debug;
5use std::marker::PhantomData;
6
7use super::task::{SiemTask, SiemTaskResult};
8use super::{
9 command_types::{
10 FilterDomain, FilterEmail, FilterIp, IsolateEndpoint, IsolateIp, LoggedOnUser, LoginUser,
11 ParserDefinition, RuleDefinition, TaskDefinition, UseCaseDefinition,
12 },
13 common::{DatasetDefinition, UserRole},
14};
15use crate::events::field::SiemField;
16use crate::prelude::types::LogString;
17
/// Identifies the kind of function a command performs, without its arguments.
///
/// Each variant mirrors the [`SiemCommandCall`] variant of the same name;
/// [`SiemCommandCall::get_type`] performs that mapping. It is also the `class`
/// of a [`CommandDefinition`].
///
/// Variant order is significant: it determines the derived `PartialOrd`/`Ord`
/// and, for serde formats that encode variants by index, the wire format.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
// Variant names are SCREAMING_CASE (presumably to match external names), so the lint is silenced.
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemFunctionType {
    STOP_COMPONENT,
    START_COMPONENT,
    LOG_QUERY,
    ISOLATE_IP,
    ISOLATE_ENDPOINT,
    FILTER_IP,
    FILTER_DOMAIN,
    FILTER_EMAIL_SENDER,
    LIST_USE_CASES,
    GET_USE_CASE,
    LIST_RULES,
    GET_RULE,
    LIST_TASKS,
    LIST_DATASETS,
    DOWNLOAD_QUERY,
    LOGIN_USER,
    LIST_PARSERS,
    START_TASK,
    GET_TASK_RESULT,
    /// A custom function identified by name (the name carried by `SiemCommandCall::OTHER`).
    OTHER(LogString),
}
45
/// Describes a command: its function kind, name, description and the
/// minimum user role it is declared to require.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CommandDefinition {
    /// Kind of function this command performs.
    class: SiemFunctionType,
    /// Name of the command.
    name: LogString,
    /// Free-text description of the command.
    description: LogString,
    /// Minimum role declared for the command (enforcement is not visible here).
    min_permission: UserRole,
}
53impl CommandDefinition {
54 pub fn new(
55 class: SiemFunctionType,
56 name: LogString,
57 description: LogString,
58 min_permission: UserRole,
59 ) -> CommandDefinition {
60 CommandDefinition {
61 class,
62 name,
63 description,
64 min_permission,
65 }
66 }
67
68 pub fn class(&self) -> &SiemFunctionType {
69 &self.class
70 }
71 pub fn name(&self) -> &LogString {
72 &self.name
73 }
74 pub fn description(&self) -> &LogString {
75 &self.description
76 }
77 pub fn min_permission(&self) -> &UserRole {
78 &self.min_permission
79 }
80}
81
/// Metadata accompanying a command: who issued it and the ids it refers to.
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct SiemCommandHeader {
    /// User that issued the command.
    pub user: String,
    /// Component id (the `component` argument of `SiemCommandHeader::new`).
    pub comp_id: u64,
    /// Command id (the `command` argument of `SiemCommandHeader::new`).
    pub comm_id: u64,
}
96
97impl SiemCommandHeader {
98 pub fn new<S: Into<String>>(user: S, component: u64, command: u64) -> Self {
99 Self {
100 user: user.into(),
101 comp_id: component,
102 comm_id: command,
103 }
104 }
105 pub fn for_user<S: Into<String>>(user: S) -> Self {
106 Self {
107 user: user.into(),
108 ..Default::default()
109 }
110 }
111}
112
/// A command invocation together with its arguments.
///
/// Each variant maps to the [`SiemFunctionType`] of the same name via
/// [`SiemCommandCall::get_type`].
#[derive(Serialize, Deserialize, Debug, Clone)]
// Variant names are SCREAMING_CASE to line up with `SiemFunctionType`, so the lint is silenced.
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemCommandCall {
    /// Start a component (the string presumably identifies it — confirm with callers).
    START_COMPONENT(String),
    /// Stop a component (the string presumably identifies it — confirm with callers).
    STOP_COMPONENT(String),
    /// Run a log query described by [`QueryInfo`].
    LOG_QUERY(QueryInfo),
    ISOLATE_IP(IsolateIp),
    ISOLATE_ENDPOINT(IsolateEndpoint),
    FILTER_IP(FilterIp),
    FILTER_DOMAIN(FilterDomain),
    FILTER_EMAIL_SENDER(FilterEmail),
    LIST_USE_CASES(Pagination),
    /// Fetch one use case (the string presumably is its name/id).
    GET_USE_CASE(String),
    LIST_RULES(Pagination),
    /// Fetch one rule (the string presumably is its name/id).
    GET_RULE(String),
    LIST_DATASETS(Pagination),
    LIST_TASKS(Pagination),
    /// Declared as an empty tuple variant `()`, not a unit variant; serde encodes
    /// the two differently, so changing it would change the wire format.
    DOWNLOAD_QUERY(),
    LIST_PARSERS(Pagination),
    LOGIN_USER(LoginUser),
    START_TASK(SiemTask),
    /// Fetch a task result by id (presumably the id returned in
    /// `SiemCommandResponse::START_TASK` — confirm).
    GET_TASK_RESULT(u64),
    /// Custom function: its name plus string parameters.
    OTHER(LogString, BTreeMap<LogString, LogString>),
}
153
154impl SiemCommandCall {
155 pub fn get_type(&self) -> SiemFunctionType {
156 match self {
157 SiemCommandCall::START_COMPONENT(_) => SiemFunctionType::START_COMPONENT,
158 SiemCommandCall::STOP_COMPONENT(_) => SiemFunctionType::STOP_COMPONENT,
159 SiemCommandCall::LOG_QUERY(_) => SiemFunctionType::LOG_QUERY,
160 SiemCommandCall::ISOLATE_IP(_) => SiemFunctionType::ISOLATE_IP,
161 SiemCommandCall::ISOLATE_ENDPOINT(_) => SiemFunctionType::ISOLATE_ENDPOINT,
162 SiemCommandCall::FILTER_IP(_) => SiemFunctionType::FILTER_IP,
163 SiemCommandCall::FILTER_DOMAIN(_) => SiemFunctionType::FILTER_DOMAIN,
164 SiemCommandCall::FILTER_EMAIL_SENDER(_) => SiemFunctionType::FILTER_EMAIL_SENDER,
165 SiemCommandCall::LIST_USE_CASES(_) => SiemFunctionType::LIST_USE_CASES,
166 SiemCommandCall::GET_USE_CASE(_) => SiemFunctionType::GET_USE_CASE,
167 SiemCommandCall::LIST_RULES(_) => SiemFunctionType::LIST_RULES,
168 SiemCommandCall::GET_RULE(_) => SiemFunctionType::GET_RULE,
169 SiemCommandCall::LIST_DATASETS(_) => SiemFunctionType::LIST_DATASETS,
170 SiemCommandCall::LIST_TASKS(_) => SiemFunctionType::LIST_TASKS,
171 SiemCommandCall::DOWNLOAD_QUERY() => SiemFunctionType::DOWNLOAD_QUERY,
172 SiemCommandCall::LIST_PARSERS(_) => SiemFunctionType::LIST_PARSERS,
173 SiemCommandCall::LOGIN_USER(_) => SiemFunctionType::LOGIN_USER,
174 SiemCommandCall::START_TASK(_) => SiemFunctionType::START_TASK,
175 SiemCommandCall::GET_TASK_RESULT(_) => SiemFunctionType::GET_TASK_RESULT,
176 SiemCommandCall::OTHER(v, _) => SiemFunctionType::OTHER(v.clone()),
177 }
178 }
179}
180
/// Paging window used by the `LIST_*` commands of `SiemCommandCall`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Pagination {
    /// Position of the first item to return (assumed: number of items to skip — confirm).
    pub offset: u32,
    /// Maximum number of items to return (assumed semantics — confirm with handlers).
    pub limit: u32,
}
186
/// Reason a command failed; carried by `CommandResult::Err`.
///
/// Each variant holds a message describing the failure.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[non_exhaustive]
pub enum CommandError {
    /// The command's arguments were not acceptable.
    BadParameters(LogString),
    /// The command (e.g. a query) could not be parsed.
    SyntaxError(LogString),
    /// The requested object does not exist.
    NotFound(LogString),
}
194
/// Response to a `SiemCommandCall`; variant names match the call they answer.
///
/// Every variant wraps a `CommandResult` holding either the payload or a
/// `CommandError`.
/// NOTE(review): there is no `DOWNLOAD_QUERY` response variant, although
/// `SiemCommandCall::DOWNLOAD_QUERY` exists — confirm this is intentional.
#[derive(Serialize, Deserialize, Debug, Clone)]
// Variant names are SCREAMING_CASE to line up with `SiemCommandCall`, so the lint is silenced.
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemCommandResponse {
    START_COMPONENT(CommandResult<String>),
    STOP_COMPONENT(CommandResult<String>),
    /// The query description together with the resulting rows (field name -> value).
    LOG_QUERY(QueryInfo, CommandResult<Vec<BTreeMap<String, SiemField>>>),
    ISOLATE_IP(CommandResult<String>),
    ISOLATE_ENDPOINT(CommandResult<String>),
    FILTER_IP(CommandResult<String>),
    FILTER_DOMAIN(CommandResult<String>),
    FILTER_EMAIL_SENDER(CommandResult<String>),
    LIST_USE_CASES(CommandResult<Vec<UseCaseDefinition>>),
    GET_USE_CASE(CommandResult<UseCaseDefinition>),
    LIST_RULES(CommandResult<Vec<RuleDefinition>>),
    GET_RULE(CommandResult<RuleDefinition>),
    LIST_DATASETS(CommandResult<Vec<DatasetDefinition>>),
    LIST_TASKS(CommandResult<Vec<TaskDefinition>>),
    LIST_PARSERS(CommandResult<Vec<ParserDefinition>>),
    LOGIN_USER(CommandResult<LoggedOnUser>),
    /// Result of starting a task (presumably the task id used by `GET_TASK_RESULT` — confirm).
    START_TASK(CommandResult<u64>),
    GET_TASK_RESULT(CommandResult<SiemTaskResult>),
    /// Custom function response: its name plus string key/value results.
    OTHER(LogString, CommandResult<BTreeMap<LogString, LogString>>),
}
226
227#[derive(Serialize, Debug, Clone)]
228pub enum CommandResult<T>
229where
230 T: Serialize + DeserializeOwned + std::fmt::Debug + Clone,
231{
232 #[serde(rename = "ok")]
233 Ok(T),
234 #[serde(rename = "err")]
235 Err(CommandError),
236}
237
238impl<'de, T: Serialize + Clone + Debug + ?Sized + DeserializeOwned> Deserialize<'de>
239 for CommandResult<T>
240{
241 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
242 where
243 D: Deserializer<'de>,
244 {
245 deserializer.deserialize_map(CommandResultVisitor::new())
246 }
247}
248
/// Serde visitor that builds a `CommandResult<T>` from a map.
///
/// `PhantomData<fn() -> T>` ties the visitor to `T` without storing a `T`.
/// A function-pointer phantom is always `Send`/`Sync`, whatever `T` is.
struct CommandResultVisitor<T> {
    marker: PhantomData<fn() -> T>,
}

impl<T> CommandResultVisitor<T> {
    /// Creates a stateless visitor for `CommandResult<T>`.
    fn new() -> Self {
        Self { marker: PhantomData }
    }
}
260
261impl<'de, T> Visitor<'de> for CommandResultVisitor<T>
262where
263 T: DeserializeOwned + Debug + Serialize + Clone,
264{
265 type Value = CommandResult<T>;
267
268 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
272 where
273 M: MapAccess<'de>,
274 {
275 if let Some(key) = access.next_key::<&str>()? {
278 if key == "ok" {
279 let val: T = access.next_value()?;
280 Ok(CommandResult::Ok(val))
281 } else if key == "err" {
282 let val: CommandError = access.next_value()?;
283 Ok(CommandResult::Err(val))
284 } else {
285 Err(serde::de::Error::missing_field("No ok/err field available"))
286 }
287 } else {
288 Err(serde::de::Error::missing_field("No ok/err field available"))
289 }
290 }
291
292 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
293 write!(formatter, "A valid command result")
294 }
295}
/// Description of a log query, sent in `SiemCommandCall::LOG_QUERY` and
/// carried back in `SiemCommandResponse::LOG_QUERY`.
///
/// Note: the derived `Default` differs from `QueryInfo::new`. `Default` sets
/// every field to zero/empty/false (so `to == 0`, `limit == 0`, `ttl == 0`,
/// `is_native == false`), whereas `new` fills in non-zero values for those fields.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct QueryInfo {
    /// User issuing the query.
    pub user: String,
    /// Presumably: `query` is in the backend's native language rather than a
    /// generic syntax — TODO confirm. `QueryInfo::new` sets it to `true`.
    pub is_native: bool,
    /// Optional query identifier (semantics not visible here).
    pub query_id: Option<String>,
    /// Lower bound of the queried range (units not shown; presumably a timestamp).
    pub from: i64,
    /// Upper bound of the queried range; `QueryInfo::new` uses `i64::MAX`, i.e. unbounded.
    pub to: i64,
    /// Maximum number of rows to return (assumed semantics).
    pub limit: usize,
    /// Number of rows to skip (assumed semantics).
    pub offset: usize,
    /// Time-to-live; `QueryInfo::new` uses 3_600_000 (presumably milliseconds = 1 hour — confirm).
    pub ttl: i64,
    /// Query text.
    pub query: String,
    /// Field names requested (presumably the columns to return — confirm).
    pub fields: Vec<String>,
}
320
321impl QueryInfo {
322 pub fn new<S: Into<String>>(query: S) -> Self {
323 Self {
324 query: query.into(),
325 is_native: true,
326 from: 0,
327 to: i64::MAX,
328 limit: 128,
329 offset: 0,
330 ttl: 3_600_000,
331 ..Default::default()
332 }
333 }
334}
335
/// JSON round-trip tests for `SiemCommandResponse` / `CommandResult`.
#[cfg(test)]
mod de_ser {
    use std::fmt::Debug;

    use serde::{de::DeserializeOwned, Serialize};

    use super::{CommandError, CommandResult, SiemCommandResponse};
    use crate::prelude::{types::LogString, DatasetDefinition};

    /// Serializes `res` to JSON and parses it back, panicking on any serde error.
    fn json_round_trip(res: &SiemCommandResponse) -> SiemCommandResponse {
        let json = serde_json::to_string(res).unwrap();
        serde_json::from_str(&json).unwrap()
    }

    /// Asserts both errors are the same variant with equal messages.
    fn assert_same_error(e1: CommandError, e2: CommandError) {
        match (e1, e2) {
            (CommandError::BadParameters(v1), CommandError::BadParameters(v2))
            | (CommandError::SyntaxError(v1), CommandError::SyntaxError(v2))
            | (CommandError::NotFound(v1), CommandError::NotFound(v2)) => assert_eq!(v1, v2),
            _ => panic!("Error must be the same"),
        }
    }

    /// Asserts both results are `Ok` with equal payloads, or `Err` with the same error.
    fn assert_same_result<T>(r1: CommandResult<T>, r2: CommandResult<T>)
    where
        T: Serialize + DeserializeOwned + Debug + Clone + PartialEq,
    {
        match (r1, r2) {
            (CommandResult::Ok(v1), CommandResult::Ok(v2)) => assert_eq!(v1, v2),
            (CommandResult::Err(e1), CommandResult::Err(e2)) => assert_same_error(e1, e2),
            _ => panic!("Both responses must be the same"),
        }
    }

    #[test]
    fn should_serialize_and_deserialize_command_response() {
        let res = SiemCommandResponse::FILTER_IP(CommandResult::Ok("Ip was filtered".to_string()));
        let res2 = json_round_trip(&res);
        match (res, res2) {
            (SiemCommandResponse::FILTER_IP(r1), SiemCommandResponse::FILTER_IP(r2)) => {
                assert_same_result(r1, r2)
            }
            _ => panic!("Must not happen"),
        }

        let res = SiemCommandResponse::LIST_DATASETS(CommandResult::Ok(vec![
            DatasetDefinition::new(
                crate::prelude::SiemDatasetType::CustomIpMap(LogString::Borrowed("")),
                LogString::Borrowed("Description"),
                crate::prelude::UserRole::Administrator,
            ),
        ]));
        let res2 = json_round_trip(&res);
        match (res, res2) {
            (SiemCommandResponse::LIST_DATASETS(r1), SiemCommandResponse::LIST_DATASETS(r2)) => {
                assert_same_result(r1, r2)
            }
            _ => panic!("Must not happen"),
        }
    }

    #[test]
    fn should_serialize_and_deserialize_error_response() {
        let res = SiemCommandResponse::FILTER_IP(CommandResult::Err(CommandError::NotFound(
            LogString::Borrowed("Ip not found"),
        )));
        let res2 = json_round_trip(&res);
        match (res, res2) {
            (SiemCommandResponse::FILTER_IP(r1), SiemCommandResponse::FILTER_IP(r2)) => {
                assert_same_result(r1, r2)
            }
            _ => panic!("Must not happen"),
        }
    }
}