ocsf_codegen/
events.rs

1use std::error::Error;
2
3use codegen::{Field, Function, Impl, Struct};
4use serde::{Deserialize, Deserializer, Serialize};
5use serde_json::Map;
6
7use crate::module::Module;
8use crate::*;
9
10#[derive(Clone, Debug, Default, Serialize, Deserialize)]
11pub struct EventDef {
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub uid: Option<u32>,
14    #[serde(skip_serializing_if = "Option::is_none")]
15    pub class_name: Option<String>,
16    pub name: String,
17    pub description: String,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    pub category: Option<String>,
20    pub attributes: HashMap<String, EventAttribute>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub associations: Option<Map<String, Value>>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub profiles: Option<Vec<String>>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    extends: Option<String>,
27}
28
29#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
30pub enum Group {
31    Classification,
32    Context,
33    Occurrence,
34    Primary,
35}
36
37impl<'de> Deserialize<'de> for Group {
38    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39    where
40        D: Deserializer<'de>,
41    {
42        let stringval = String::deserialize(deserializer)?.to_lowercase();
43        let result: Group = stringval.as_str().into();
44        Ok(result)
45    }
46}
47
48impl From<&str> for Group {
49    fn from(value: &str) -> Self {
50        match value {
51            "classification" => Self::Classification,
52            "context" => Self::Context,
53            "occurrence" => Self::Occurrence,
54            "primary" => Self::Primary,
55            _ => panic!("Invalid enum value {value} - select from classification, context, occurrence, primary")
56        }
57    }
58}
59impl From<Group> for &'static str {
60    fn from(input: Group) -> &'static str {
61        match input {
62            Group::Classification => "classification",
63            Group::Context => "context",
64            Group::Occurrence => "occurrence",
65            Group::Primary => "primary",
66        }
67    }
68}
69
70#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
71pub enum Requirement {
72    Optional,
73    Recommended,
74    Required,
75}
76
77impl<'de> Deserialize<'de> for Requirement {
78    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79    where
80        D: Deserializer<'de>,
81    {
82        let stringval = String::deserialize(deserializer)?.to_lowercase();
83        let result: Requirement = stringval.as_str().into();
84        Ok(result)
85    }
86}
87
88impl From<&str> for Requirement {
89    fn from(value: &str) -> Self {
90        match value {
91            "optional" => Self::Optional,
92            "recommended" => Self::Recommended,
93            "required" => Self::Required,
94            _ => panic!("Invalid enum value '{value}' - select from optional,recommended,required"),
95        }
96    }
97}
98
99impl From<Requirement> for &'static str {
100    fn from(input: Requirement) -> &'static str {
101        match input {
102            Requirement::Optional => "optional",
103            Requirement::Recommended => "recommended",
104            Requirement::Required => "required",
105        }
106    }
107}
108
#[test]
fn test_from_str_requirement() {
    // Every lowercase spelling maps to its variant, and converting back yields the same text.
    for (text, expected) in [
        ("optional", Requirement::Optional),
        ("recommended", Requirement::Recommended),
        ("required", Requirement::Required),
    ] {
        assert_eq!(Requirement::from(text), expected);
        let round_trip: &'static str = expected.into();
        assert_eq!(round_trip, text);
    }
}
114
#[test]
#[should_panic(expected = "Invalid enum value")]
fn test_from_str_invalid_requirement() {
    // A misspelled value must fail with the "Invalid enum value" message,
    // not with some unrelated panic.
    let _ = Requirement::from("requiasdfasdfred");
}
121
122// #[allow(dead_code)]
123#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
124pub struct EventAttribute {
125    #[serde(skip_serializing_if = "Option::is_none")]
126    name: Option<String>,
127    #[serde(skip_serializing_if = "Option::is_none")]
128    profile: Option<String>,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    description: Option<String>,
131    #[serde(skip_serializing_if = "Option::is_none")]
132    caption: Option<String>,
133    #[serde(skip_serializing_if = "Option::is_none")]
134    requirement: Option<Requirement>,
135    #[serde(skip_serializing_if = "Option::is_none")]
136    group: Option<Group>,
137    #[serde(alias = "$include", skip)]
138    include: Option<String>,
139    /// This is the string name of the type, not the enum value
140    enum_name: String,
141    just_includes: Vec<String>,
142}
143
144impl<'de> Deserialize<'de> for EventAttribute {
145    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
146    where
147        D: Deserializer<'de>,
148    {
149        let value = Value::deserialize(deserializer)?;
150
151        trace!("Deserializing to EventAttribute: {:?}", value);
152
153        if value.is_array() {
154            return Ok(EventAttribute {
155                just_includes: value
156                    .as_array()
157                    .unwrap()
158                    .iter()
159                    .map(|v| v.as_str().unwrap().to_string())
160                    .collect(),
161                ..Default::default()
162            });
163        }
164
165        let name = match value.get("name") {
166            Some(n) => n
167                .as_str()
168                .map(|f| f.to_string().replace("type", "type_name")),
169            None => None,
170        };
171
172        let profile = match value.get("profile") {
173            Some(n) => n.as_str().map(|f| f.to_string()),
174            None => None,
175        };
176
177        let description = match value.get("description") {
178            Some(n) => n.as_str().map(|f| f.to_string()),
179            None => None,
180        };
181
182        let caption = match value.get("caption") {
183            Some(n) => n.as_str().map(|f| f.to_string()),
184            None => None,
185        };
186
187        let requirement: Option<Requirement> = match value.get("requirement") {
188            Some(n) => {
189                let strval = n.as_str().unwrap();
190                Some(Requirement::from(strval))
191            }
192            None => None,
193        };
194
195        let group: Option<Group> = match value.get("group") {
196            Some(n) => {
197                let strval = n.as_str().unwrap();
198                Some(Group::from(strval))
199            }
200            None => None,
201        };
202
203        let includes: Vec<String> = match value.get("$include") {
204            Some(include) => {
205                let mut result: Vec<String> = vec![];
206
207                if include.is_array() {
208                    panic!(
209                        "Found include array while deserializing an event attribute: {:?}",
210                        include.as_array().unwrap()
211                    );
212                    // result.extend(include.as_array().unwrap().iter().map(|f| f.as_str().unwrap().to_string()));
213                } else if include.is_string() {
214                    trace!(
215                        "Found include string or eventattribute: {:?}",
216                        include.as_str().unwrap()
217                    );
218                    result.push(include.as_str().unwrap().to_string());
219                } else {
220                    panic!(
221                        "found an include we can't handle in deserializing an attribute! {:?}",
222                        include
223                    );
224                }
225                result
226            }
227            None => vec![],
228        };
229        let mut enum_name = "String".to_string();
230        if !includes.is_empty() {
231            info!("Found includes: {:?}", includes);
232            enum_name = format!(
233                "crate::{}",
234                collapsed_title_case(includes.first().unwrap().split('/').last().unwrap())
235            );
236        }
237
238        Ok(EventAttribute {
239            name,
240            profile,
241            description,
242            caption,
243            requirement,
244            group,
245            enum_name,
246            ..Default::default()
247        })
248    }
249}
250
251impl EventAttribute {
252    pub fn new(name: String) -> Self {
253        EventAttribute {
254            name: Some(name),
255            ..Self::default()
256        }
257    }
258}
259
260impl Default for EventAttribute {
261    fn default() -> Self {
262        Self {
263            name: Some("".to_string()),
264            caption: Default::default(),
265            description: None,
266            group: Default::default(),
267            profile: Default::default(),
268            requirement: Default::default(),
269            include: Default::default(),
270            // because everything's a string at some point.:D how's
271            enum_name: "String".to_string(),
272            just_includes: vec![],
273        }
274    }
275}
276
277fn load_all_event_files(paths: &DirPaths) -> HashMap<String, EventDef> {
278    let target_path = format!("{}events/", paths.schema_path);
279    info!("loading all event files from {}", target_path);
280
281    let mut result: HashMap<String, EventDef> = HashMap::new();
282
283    for file in WalkDir::new(&target_path) {
284        let file = match file {
285            Ok(val) => val,
286            Err(err) => {
287                error!("Failed to walk dir somewhere: {err:?}");
288                continue;
289            }
290        };
291        if !file.clone().into_path().is_file() {
292            debug!("Skipping {:?} as it's not a file...", file);
293            continue;
294        }
295        debug!("Reading {file:?} into EventDef");
296        let file_value =
297            read_file_to_value(file.clone().into_path().as_os_str().to_str().unwrap()).unwrap();
298        let mut file_event: EventDef = serde_json::from_value(file_value).unwrap();
299        // stripping out the include value, because by this point we should have handled it!
300        file_event.attributes.remove("$include");
301
302        result.insert(
303            file.into_path()
304                .as_os_str()
305                .to_str()
306                .unwrap()
307                .to_owned()
308                .replace(&paths.schema_path, ""),
309            file_event,
310        );
311    }
312    result
313}
314
315pub fn generate_events(paths: &DirPaths, root_module: &mut Module) -> Result<(), Box<dyn Error>> {
316    // let event_schema_path = format!("{}events/", paths.schema_path);
317    // let filenames = find_files(&event_schema_path);
318
319    let categories_file = read_file_to_value(&format!("{}categories.json", paths.schema_path))?;
320    let categories_file = categories_file
321        .get("attributes")
322        .expect("Coudln't get categories file attributes");
323    let categories: HashMap<String, Category> = serde_json::from_value(categories_file.to_owned())?;
324
325    let mut all_events = load_all_event_files(paths);
326
327    for (filename, event) in all_events.iter_mut() {
328        if filename.len() <= 1 {
329            warn!("Can't handle file {}", filename);
330            continue;
331        }
332
333        let struct_name = event.name.to_owned();
334        info!("Struct name: {} from {}", struct_name, filename);
335
336        let target_module_path = PathBuf::from(filename.replace("events/", ""))
337            .parent()
338            .unwrap()
339            .to_str()
340            .unwrap()
341            .to_owned();
342        debug!("Putting it into module: {}", target_module_path);
343
344        let mut target_module = root_module
345            .children
346            .get_mut("events")
347            .expect("Couldn't get events module from root?");
348
349        for tm in target_module_path.split('/') {
350            if tm.is_empty() {
351                continue;
352            }
353            if !target_module.has_child(tm) {
354                target_module.add_child(tm.to_owned());
355            }
356            target_module = target_module
357                .children
358                .get_mut(tm)
359                .unwrap_or_else(|| panic!("Couldn't get {tm}"));
360        }
361
362        trace!("Target module: {:#?}", target_module);
363
364        let struct_doc = format!("{}\n\nSourced from: `{}`", &event.description, filename);
365        let mut module_struct = Struct::new(&collapsed_title_case(&event.name));
366        module_struct
367            .doc(&struct_doc)
368            .vis("pub")
369            .derive("serde::Deserialize")
370            .derive("serde::Serialize")
371            .derive("Default")
372            .derive("Debug");
373
374        let mut func_new = Function::new("new");
375        func_new.vis("pub").ret("Self");
376
377        func_new.line("Self {");
378
379        let mut module_impl = Impl::new(&collapsed_title_case(&event.name));
380
381        // yes, we're sorting struct fields.
382        event
383            .attributes
384            .iter()
385            .sorted_by(|a, b| Ord::cmp(a.0, b.0))
386            .for_each(|(attr_name, attr)| {
387                if attr_name == "$include" {
388                    return error!(
389                        "need to handle attribute $include {:#?}",
390                        attr.just_includes
391                    );
392                    //TODO: need to handle  attribute includes
393                } else {
394                    trace!("attr name: {attr_name}");
395                }
396                let attr_name = match attr_name == "type" {
397                    true => "type_name",
398                    false => attr_name,
399                };
400
401                let field_requirement_template: &'static str = match &attr.requirement {
402                    Some(val) => match val {
403                        Requirement::Optional => "Option<{}>",
404                        Requirement::Recommended => "Option<{}>",
405                        Requirement::Required => "{}",
406                    },
407                    None => "Option<{}>",
408                };
409
410                let mut attr_field = Field::new(
411                    attr_name,
412                    field_requirement_template.replace("{}", &attr.enum_name),
413                );
414                attr_field.vis("pub");
415                // documentation is always nice
416                if let Some(description) = &attr.description {
417                    attr_field.doc(fix_docstring(description.to_owned(), None));
418                }
419
420                let mut serde_annotations: Vec<&str> = vec![];
421                if attr_name == "type_name" {
422                    // because when we serialize it out, it needs the right name
423                    serde_annotations.push("alias=\"type\"");
424                }
425                // add the attributes to the new() function
426                if attr.requirement.is_some() && attr.requirement == Some(Requirement::Required) {
427                    func_new.arg(attr_name, &attr.enum_name);
428                    func_new.line(format!("{attr_name},"));
429                } else {
430                    func_new.line(format!("{attr_name}: None,"));
431                    let mut with_func = Function::new(format!("with_{attr_name}"));
432
433                    with_func
434                        .vis("pub")
435                        .doc(format!("Set the value of {}", attr_name))
436                        .arg_self()
437                        .arg(attr_name, &attr.enum_name)
438                        .ret("Self");
439
440                    with_func.line(format!("Self {{ {attr_name}: Some({attr_name}),"));
441                    if event.attributes.len() > 1 {
442                        with_func.line("..self  ");
443                    }
444                    with_func.line("}");
445
446                    module_impl.push_fn(with_func);
447
448                    // if it's optional, then we need to tell serde to ignore it on serialization
449                    serde_annotations.push("skip_serializing_if = \"Option::is_none\"");
450                }
451
452                if !serde_annotations.is_empty() {
453                    attr_field.annotation(&format!("#[serde({})]", serde_annotations.join(",")));
454                }
455
456                // add builders for the not-required fields
457
458                module_struct.push_field(attr_field);
459            });
460
461        // this is the end of the Self::new() function
462        func_new.line("}");
463
464        let mut uid = event.uid.unwrap_or(0);
465        if let Some(category) = event.category.clone() {
466            if categories.contains_key(&category) {
467                uid += 1000 * categories[&category].uid;
468                trace!("Set UID to {uid}");
469            }
470        }
471        if uid != 0 {
472            module_impl.associate_const("UID", "u16", format!("{}", uid), "pub");
473        }
474
475        module_impl.push_fn(func_new);
476
477        target_module.scope.push_struct(module_struct);
478        target_module.scope.push_impl(module_impl);
479    }
480
481    Ok(())
482}