1use super::{
4 CanonicalId, ComponentExtensions, ComponentIdentifiers, ComponentType, CryptoProperties,
5 DependencyScope, DependencyType, DocumentMetadata, Ecosystem, ExternalReference,
6 FormatExtensions, Hash, LicenseInfo, Organization, VexStatus, VulnerabilityRef,
7};
8use indexmap::IndexMap;
9use serde::{Deserialize, Serialize};
10use xxhash_rust::xxh3::xxh3_64;
11
/// Bit pattern of the single quiet NaN that every NaN is collapsed to before
/// hashing floating-point fields, so NaNs with different payloads hash equally.
const CANONICAL_NAN_BITS: u64 = 0x7FF8_0000_0000_0000;
13
/// Normalized in-memory SBOM: document metadata, components keyed by canonical
/// ID, and the dependency edges between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalizedSbom {
    /// Document-level metadata (creators, name, CRA fields, …).
    pub document: DocumentMetadata,
    /// Components keyed by canonical ID; `IndexMap` keeps insertion order.
    pub components: IndexMap<CanonicalId, Component>,
    /// Dependency edges; stored as a flat list and scanned linearly by the
    /// non-indexed lookup methods.
    pub edges: Vec<DependencyEdge>,
    /// Format-specific data that has no normalized counterpart.
    pub extensions: FormatExtensions,
    /// Whole-SBOM hash written by `calculate_content_hash`; `0` until then.
    pub content_hash: u64,
    /// Root component used by `primary_component` and as the root in
    /// `direct_dependency_ids`, if one is known.
    pub primary_component_id: Option<CanonicalId>,
    /// Number of `add_component` calls that replaced a component with a
    /// different format ID or name under the same canonical ID.
    /// Runtime-only diagnostic; not serialized.
    #[serde(skip)]
    pub collision_count: usize,
}
37
38impl NormalizedSbom {
39 #[must_use]
41 pub fn new(document: DocumentMetadata) -> Self {
42 Self {
43 document,
44 components: IndexMap::new(),
45 edges: Vec::new(),
46 extensions: FormatExtensions::default(),
47 content_hash: 0,
48 primary_component_id: None,
49 collision_count: 0,
50 }
51 }
52
53 #[must_use]
63 pub fn direct_dependency_ids(&self) -> std::collections::HashSet<CanonicalId> {
64 use std::collections::HashSet;
65 if let Some(root) = &self.primary_component_id {
66 return self
67 .edges
68 .iter()
69 .filter(|e| &e.from == root)
70 .map(|e| e.to.clone())
71 .collect();
72 }
73 let incoming: HashSet<&CanonicalId> = self.edges.iter().map(|e| &e.to).collect();
75 let roots: HashSet<&CanonicalId> = self
76 .components
77 .keys()
78 .filter(|id| !incoming.contains(id))
79 .collect();
80 self.edges
81 .iter()
82 .filter(|e| roots.contains(&e.from))
83 .map(|e| e.to.clone())
84 .collect()
85 }
86
87 pub fn add_component(&mut self, component: Component) -> bool {
92 let id = component.canonical_id.clone();
93 if let Some(existing) = self.components.get(&id) {
94 if existing.identifiers.format_id != component.identifiers.format_id
96 || existing.name != component.name
97 {
98 self.collision_count += 1;
99 }
100 self.components.insert(id, component);
101 true
102 } else {
103 self.components.insert(id, component);
104 false
105 }
106 }
107
108 pub fn log_collision_summary(&self) {
110 if self.collision_count > 0 {
111 tracing::info!(
112 collision_count = self.collision_count,
113 "Canonical ID collisions: {} distinct components resolved to the same ID \
114 and were overwritten. Consider adding PURL identifiers to disambiguate.",
115 self.collision_count
116 );
117 }
118 }
119
120 pub fn add_edge(&mut self, edge: DependencyEdge) {
122 self.edges.push(edge);
123 }
124
125 #[must_use]
127 pub fn get_component(&self, id: &CanonicalId) -> Option<&Component> {
128 self.components.get(id)
129 }
130
131 #[must_use]
133 pub fn get_dependencies(&self, id: &CanonicalId) -> Vec<&DependencyEdge> {
134 self.edges.iter().filter(|e| &e.from == id).collect()
135 }
136
137 #[must_use]
139 pub fn get_dependents(&self, id: &CanonicalId) -> Vec<&DependencyEdge> {
140 self.edges.iter().filter(|e| &e.to == id).collect()
141 }
142
143 pub fn calculate_content_hash(&mut self) {
145 let mut hasher_input = Vec::new();
146
147 if let Ok(meta_json) = serde_json::to_vec(&self.document) {
149 hasher_input.extend(meta_json);
150 }
151
152 let mut component_ids: Vec<_> = self.components.keys().collect();
154 component_ids.sort_by(|a, b| a.value().cmp(b.value()));
155
156 for id in component_ids {
157 if let Some(comp) = self.components.get(id) {
158 hasher_input.extend(comp.content_hash.to_le_bytes());
159 }
160 }
161
162 let mut edge_keys: Vec<_> = self
164 .edges
165 .iter()
166 .map(|edge| {
167 (
168 edge.from.value(),
169 edge.to.value(),
170 edge.relationship.to_string(),
171 edge.scope
172 .as_ref()
173 .map_or(String::new(), std::string::ToString::to_string),
174 )
175 })
176 .collect();
177 edge_keys.sort();
178 for (from, to, relationship, scope) in &edge_keys {
179 hasher_input.extend(from.as_bytes());
180 hasher_input.extend(to.as_bytes());
181 hasher_input.extend(relationship.as_bytes());
182 hasher_input.extend(scope.as_bytes());
183 }
184
185 self.content_hash = xxh3_64(&hasher_input);
186 }
187
188 #[must_use]
190 pub fn component_count(&self) -> usize {
191 self.components.len()
192 }
193
194 #[must_use]
196 pub fn primary_component(&self) -> Option<&Component> {
197 self.primary_component_id
198 .as_ref()
199 .and_then(|id| self.components.get(id))
200 }
201
202 pub fn set_primary_component(&mut self, id: CanonicalId) {
204 self.primary_component_id = Some(id);
205 }
206
207 pub fn ecosystems(&self) -> Vec<&Ecosystem> {
209 let mut ecosystems: Vec<_> = self
210 .components
211 .values()
212 .filter_map(|c| c.ecosystem.as_ref())
213 .collect();
214 ecosystems.sort_by_key(std::string::ToString::to_string);
215 ecosystems.dedup();
216 ecosystems
217 }
218
219 #[must_use]
221 pub fn all_vulnerabilities(&self) -> Vec<(&Component, &VulnerabilityRef)> {
222 self.components
223 .values()
224 .flat_map(|c| c.vulnerabilities.iter().map(move |v| (c, v)))
225 .collect()
226 }
227
228 #[must_use]
230 pub fn vulnerability_counts(&self) -> VulnerabilityCounts {
231 let mut counts = VulnerabilityCounts::default();
232 for (_, vuln) in self.all_vulnerabilities() {
233 match vuln.severity {
234 Some(super::Severity::Critical) => counts.critical += 1,
235 Some(super::Severity::High) => counts.high += 1,
236 Some(super::Severity::Medium) => counts.medium += 1,
237 Some(super::Severity::Low) => counts.low += 1,
238 _ => counts.unknown += 1,
239 }
240 }
241 counts
242 }
243
244 pub fn build_index(&self) -> super::NormalizedSbomIndex {
259 super::NormalizedSbomIndex::build(self)
260 }
261
262 #[must_use]
266 pub fn get_dependencies_indexed<'a>(
267 &'a self,
268 id: &CanonicalId,
269 index: &super::NormalizedSbomIndex,
270 ) -> Vec<&'a DependencyEdge> {
271 index.dependencies_of(id, &self.edges)
272 }
273
274 #[must_use]
278 pub fn get_dependents_indexed<'a>(
279 &'a self,
280 id: &CanonicalId,
281 index: &super::NormalizedSbomIndex,
282 ) -> Vec<&'a DependencyEdge> {
283 index.dependents_of(id, &self.edges)
284 }
285
286 #[must_use]
290 pub fn find_by_name_indexed(
291 &self,
292 name: &str,
293 index: &super::NormalizedSbomIndex,
294 ) -> Vec<&Component> {
295 let name_lower = name.to_lowercase();
296 index
297 .find_by_name_lower(&name_lower)
298 .iter()
299 .filter_map(|id| self.components.get(id))
300 .collect()
301 }
302
303 #[must_use]
307 pub fn search_by_name_indexed(
308 &self,
309 query: &str,
310 index: &super::NormalizedSbomIndex,
311 ) -> Vec<&Component> {
312 let query_lower = query.to_lowercase();
313 index
314 .search_by_name(&query_lower)
315 .iter()
316 .filter_map(|id| self.components.get(id))
317 .collect()
318 }
319
320 pub fn apply_cra_sidecar(&mut self, sidecar: &super::CraSidecarMetadata) {
325 if self.document.security_contact.is_none() {
327 self.document
328 .security_contact
329 .clone_from(&sidecar.security_contact);
330 }
331
332 if self.document.vulnerability_disclosure_url.is_none() {
333 self.document
334 .vulnerability_disclosure_url
335 .clone_from(&sidecar.vulnerability_disclosure_url);
336 }
337
338 if self.document.support_end_date.is_none() {
339 self.document.support_end_date = sidecar.support_end_date;
340 }
341
342 if self.document.name.is_none() {
343 self.document.name.clone_from(&sidecar.product_name);
344 }
345
346 if let Some(manufacturer) = &sidecar.manufacturer_name {
348 let has_org = self
349 .document
350 .creators
351 .iter()
352 .any(|c| c.creator_type == super::CreatorType::Organization);
353
354 if !has_org {
355 self.document.creators.push(super::Creator {
356 creator_type: super::CreatorType::Organization,
357 name: manufacturer.clone(),
358 email: sidecar.manufacturer_email.clone(),
359 });
360 }
361 }
362 }
363}
364
365impl Default for NormalizedSbom {
366 fn default() -> Self {
367 Self::new(DocumentMetadata::default())
368 }
369}
370
/// Per-severity vulnerability tally, as produced by
/// `NormalizedSbom::vulnerability_counts`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VulnerabilityCounts {
    /// Vulnerabilities rated Critical.
    pub critical: usize,
    /// Vulnerabilities rated High.
    pub high: usize,
    /// Vulnerabilities rated Medium.
    pub medium: usize,
    /// Vulnerabilities rated Low.
    pub low: usize,
    /// Vulnerabilities with no severity, or a severity outside the four buckets above.
    pub unknown: usize,
}
380
381impl VulnerabilityCounts {
382 #[must_use]
383 pub const fn total(&self) -> usize {
384 self.critical + self.high + self.medium + self.low + self.unknown
385 }
386}
387
/// Freshness classification of a component, based on days since its last
/// update or on an explicit deprecated/archived state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum StalenessLevel {
    /// 0–182 days since the last update (see `from_days`).
    Fresh,
    /// 183–365 days since the last update.
    Aging,
    /// 366–730 days since the last update.
    Stale,
    /// More than 730 days since the last update.
    Abandoned,
    /// Marked deprecated; never produced by `from_days`.
    Deprecated,
    /// Marked archived; never produced by `from_days`.
    Archived,
}
405
406impl StalenessLevel {
407 #[must_use]
409 pub const fn from_days(days: u32) -> Self {
410 match days {
411 0..=182 => Self::Fresh, 183..=365 => Self::Aging, 366..=730 => Self::Stale, _ => Self::Abandoned, }
416 }
417
418 #[must_use]
420 pub const fn label(&self) -> &'static str {
421 match self {
422 Self::Fresh => "Fresh",
423 Self::Aging => "Aging",
424 Self::Stale => "Stale",
425 Self::Abandoned => "Abandoned",
426 Self::Deprecated => "Deprecated",
427 Self::Archived => "Archived",
428 }
429 }
430
431 #[must_use]
433 pub const fn icon(&self) -> &'static str {
434 match self {
435 Self::Fresh => "✓",
436 Self::Aging => "⏳",
437 Self::Stale => "⚠",
438 Self::Abandoned => "⛔",
439 Self::Deprecated => "⊘",
440 Self::Archived => "📦",
441 }
442 }
443
444 #[must_use]
446 pub const fn severity(&self) -> u8 {
447 match self {
448 Self::Fresh => 0,
449 Self::Aging => 1,
450 Self::Stale => 2,
451 Self::Abandoned => 3,
452 Self::Deprecated | Self::Archived => 4,
453 }
454 }
455}
456
457impl std::fmt::Display for StalenessLevel {
458 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459 write!(f, "{}", self.label())
460 }
461}
462
/// Staleness assessment for a component, plus the data it was derived from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StalenessInfo {
    /// Resulting classification.
    pub level: StalenessLevel,
    /// When the component was last published, if known.
    pub last_published: Option<chrono::DateTime<chrono::Utc>>,
    /// Whether the component is flagged as deprecated.
    pub is_deprecated: bool,
    /// Whether the component is flagged as archived.
    pub is_archived: bool,
    /// Deprecation notice text, if any.
    pub deprecation_message: Option<String>,
    /// Whole days between `last_published` and assessment time; `from_date`
    /// clamps future dates to 0.
    pub days_since_update: Option<u32>,
    /// Latest available version, if known.
    pub latest_version: Option<String>,
}
481
482impl StalenessInfo {
483 #[must_use]
485 pub const fn new(level: StalenessLevel) -> Self {
486 Self {
487 level,
488 last_published: None,
489 is_deprecated: false,
490 is_archived: false,
491 deprecation_message: None,
492 days_since_update: None,
493 latest_version: None,
494 }
495 }
496
497 #[must_use]
499 pub fn from_date(last_published: chrono::DateTime<chrono::Utc>) -> Self {
500 let days = (chrono::Utc::now() - last_published).num_days().max(0) as u32;
501 let level = StalenessLevel::from_days(days);
502 Self {
503 level,
504 last_published: Some(last_published),
505 is_deprecated: false,
506 is_archived: false,
507 deprecation_message: None,
508 days_since_update: Some(days),
509 latest_version: None,
510 }
511 }
512
513 #[must_use]
515 pub const fn needs_attention(&self) -> bool {
516 self.level.severity() >= 2
517 }
518}
519
/// End-of-life status of a component's release cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EolStatus {
    /// Actively supported.
    Supported,
    /// Receiving security fixes only.
    SecurityOnly,
    /// End of life is approaching.
    ApproachingEol,
    /// Past end of life.
    EndOfLife,
    /// Status could not be determined (severity 0, same as `Supported`).
    Unknown,
}
535
536impl EolStatus {
537 #[must_use]
539 pub const fn label(&self) -> &'static str {
540 match self {
541 Self::Supported => "Supported",
542 Self::SecurityOnly => "Security Only",
543 Self::ApproachingEol => "Approaching EOL",
544 Self::EndOfLife => "End of Life",
545 Self::Unknown => "Unknown",
546 }
547 }
548
549 #[must_use]
551 pub const fn icon(&self) -> &'static str {
552 match self {
553 Self::Supported => "✓",
554 Self::SecurityOnly => "🔒",
555 Self::ApproachingEol => "⚠",
556 Self::EndOfLife => "⛔",
557 Self::Unknown => "?",
558 }
559 }
560
561 #[must_use]
563 pub const fn severity(&self) -> u8 {
564 match self {
565 Self::Supported => 0,
566 Self::SecurityOnly => 1,
567 Self::ApproachingEol => 2,
568 Self::EndOfLife => 3,
569 Self::Unknown => 0,
570 }
571 }
572}
573
574impl std::fmt::Display for EolStatus {
575 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
576 write!(f, "{}", self.label())
577 }
578}
579
/// End-of-life details for the release cycle a component belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EolInfo {
    /// Derived EOL status.
    pub status: EolStatus,
    /// Product name the cycle belongs to.
    pub product: String,
    /// Release cycle identifier (e.g. a major/minor line).
    pub cycle: String,
    /// Date the cycle reaches end of life, if known.
    pub eol_date: Option<chrono::NaiveDate>,
    /// Date regular support ends, if known.
    pub support_end_date: Option<chrono::NaiveDate>,
    /// Whether the cycle is a long-term-support release.
    pub is_lts: bool,
    /// Latest version released in this cycle, if known.
    pub latest_in_cycle: Option<String>,
    /// Release date of `latest_in_cycle`, if known.
    pub latest_release_date: Option<chrono::NaiveDate>,
    /// Days remaining until `eol_date`; signed, presumably negative once EOL
    /// has passed (confirm with the producer of this value).
    pub days_until_eol: Option<i64>,
}
602
603impl EolInfo {
604 #[must_use]
606 pub const fn needs_attention(&self) -> bool {
607 self.status.severity() >= 2
608 }
609}
610
/// A single component (library, file, model, dataset, …) in normalized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Component {
    /// Key under which the component is stored in `NormalizedSbom::components`;
    /// derived from `identifiers` and recomputed when they change.
    pub canonical_id: CanonicalId,
    /// Identifiers: the format-specific ID, optional PURL, SWHIDs, ….
    pub identifiers: ComponentIdentifiers,
    /// Component name as given by the source document.
    pub name: String,
    /// Version string exactly as given, if any.
    pub version: Option<String>,
    /// `version` parsed as SemVer; `None` when absent or not valid SemVer.
    pub semver: Option<semver::Version>,
    /// Component kind; `Component::new` defaults it to `ComponentType::Library`.
    pub component_type: ComponentType,
    /// Package ecosystem; `set_purl` infers it from the PURL type.
    pub ecosystem: Option<Ecosystem>,
    /// License data; the `declared` list feeds `content_hash` and `is_oss`.
    pub licenses: LicenseInfo,
    /// Supplying organization, if known.
    pub supplier: Option<Organization>,
    /// Hashes recorded for the component; only each `value` feeds `content_hash`.
    pub hashes: Vec<Hash>,
    /// External references attached to the component.
    pub external_refs: Vec<ExternalReference>,
    /// Vulnerabilities attached to the component; their IDs feed `content_hash`.
    pub vulnerabilities: Vec<VulnerabilityRef>,
    /// Component-level VEX status, if provided.
    pub vex_status: Option<VexStatus>,
    /// Hash written by `calculate_content_hash`; `0` until it is called.
    pub content_hash: u64,
    /// Format-specific extension data.
    pub extensions: ComponentExtensions,
    /// Free-text description.
    pub description: Option<String>,
    /// Copyright statement.
    pub copyright: Option<String>,
    /// Author string.
    pub author: Option<String>,
    /// Group/namespace (presumably e.g. a Maven groupId; confirm against parsers).
    pub group: Option<String>,
    /// Marks the component as external; contributes a marker byte to `content_hash`.
    pub is_external: bool,
    /// Version range, if the source gives a range instead of a pinned version.
    pub version_range: Option<String>,
    /// Staleness assessment, if enriched.
    pub staleness: Option<StalenessInfo>,
    /// End-of-life information, if enriched.
    pub eol: Option<EolInfo>,
    /// Machine-learning model metadata, for ML model components.
    pub ml_model: Option<crate::model::MlModelInfo>,
    /// Dataset metadata, for dataset components.
    pub dataset: Option<crate::model::DatasetInfo>,
    /// Cryptographic-asset properties. Omitted from serialized output when `None`
    /// and defaulted to `None` when missing on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub crypto_properties: Option<CryptoProperties>,
}
668
669impl Component {
670 #[must_use]
672 pub fn new(name: String, format_id: String) -> Self {
673 let identifiers = ComponentIdentifiers::new(format_id);
674 let canonical_id = identifiers.canonical_id();
675
676 Self {
677 canonical_id,
678 identifiers,
679 name,
680 version: None,
681 semver: None,
682 component_type: ComponentType::Library,
683 ecosystem: None,
684 licenses: LicenseInfo::default(),
685 supplier: None,
686 hashes: Vec::new(),
687 external_refs: Vec::new(),
688 vulnerabilities: Vec::new(),
689 vex_status: None,
690 content_hash: 0,
691 extensions: ComponentExtensions::default(),
692 description: None,
693 copyright: None,
694 author: None,
695 group: None,
696 is_external: false,
697 version_range: None,
698 staleness: None,
699 eol: None,
700 ml_model: None,
701 dataset: None,
702 crypto_properties: None,
703 }
704 }
705
706 #[must_use]
708 pub fn with_purl(mut self, purl: String) -> Self {
709 self.set_purl(purl);
710 self
711 }
712
713 pub fn set_purl(&mut self, purl: String) {
717 self.identifiers.purl = Some(purl);
718 self.canonical_id = self.identifiers.canonical_id();
719
720 if let Some(purl_str) = &self.identifiers.purl
722 && let Some(purl_type) = purl_str
723 .strip_prefix("pkg:")
724 .and_then(|s| s.split('/').next())
725 {
726 self.ecosystem = Some(Ecosystem::from_purl_type(purl_type));
727 }
728 }
729
730 #[must_use]
732 pub fn with_version(mut self, version: String) -> Self {
733 self.semver = semver::Version::parse(&version).ok();
734 self.version = Some(version);
735 self
736 }
737
738 #[must_use]
749 pub fn with_swhid(mut self, swhid: String) -> Self {
750 if let Ok(obj) = crate::model::SwhidObject::parse(&swhid) {
751 self.identifiers.swhid.push(obj);
752 self.canonical_id = self.identifiers.canonical_id();
753 }
754 self
755 }
756
757 #[must_use]
759 pub fn with_swhid_object(mut self, swhid: crate::model::SwhidObject) -> Self {
760 self.identifiers.swhid.push(swhid);
761 self.canonical_id = self.identifiers.canonical_id();
762 self
763 }
764
765 #[must_use]
767 pub fn with_ml_model(mut self, ml_model: crate::model::MlModelInfo) -> Self {
768 self.ml_model = Some(ml_model);
769 self
770 }
771
772 #[must_use]
774 pub fn with_dataset(mut self, dataset: crate::model::DatasetInfo) -> Self {
775 self.dataset = Some(dataset);
776 self
777 }
778
779 fn extend_with_optional_str(hasher_input: &mut Vec<u8>, value: &Option<String>) {
780 if let Some(value) = value {
781 hasher_input.extend(value.as_bytes());
782 }
783 }
784
785 fn extend_with_string_list(hasher_input: &mut Vec<u8>, values: &[String]) {
786 for value in values {
787 hasher_input.extend(value.as_bytes());
788 }
789 }
790
791 fn extend_with_optional_f64(hasher_input: &mut Vec<u8>, value: Option<f64>) {
792 if let Some(value) = value {
793 let normalized = if value == 0.0 {
794 0.0
795 } else if value.is_nan() {
796 f64::from_bits(CANONICAL_NAN_BITS)
797 } else {
798 value
799 };
800 hasher_input.extend(normalized.to_bits().to_le_bytes());
801 }
802 }
803
804 fn extend_with_ml_model(
805 hasher_input: &mut Vec<u8>,
806 ml_model: &Option<crate::model::MlModelInfo>,
807 ) {
808 if let Some(ml_model) = ml_model {
809 Self::extend_with_optional_str(hasher_input, &ml_model.approach);
810 Self::extend_with_optional_str(hasher_input, &ml_model.architecture_family);
811 Self::extend_with_optional_str(hasher_input, &ml_model.architecture_name);
812 Self::extend_with_optional_str(hasher_input, &ml_model.task);
813 Self::extend_with_optional_str(hasher_input, &ml_model.quantization);
814 Self::extend_with_optional_str(hasher_input, &ml_model.limitations);
815 Self::extend_with_optional_str(hasher_input, &ml_model.model_card_url);
816 Self::extend_with_optional_f64(hasher_input, ml_model.energy_kwh_training);
817
818 for dataset in &ml_model.training_datasets {
819 Self::extend_with_optional_str(hasher_input, &dataset.name);
820 Self::extend_with_optional_str(hasher_input, &dataset.purl);
821 }
822 }
823 }
824
825 fn extend_with_dataset(
826 hasher_input: &mut Vec<u8>,
827 dataset: &Option<crate::model::DatasetInfo>,
828 ) {
829 if let Some(dataset) = dataset {
830 Self::extend_with_optional_str(hasher_input, &dataset.dataset_type);
831 Self::extend_with_string_list(hasher_input, &dataset.sensitivity_classifications);
832 Self::extend_with_string_list(hasher_input, &dataset.governance_owners);
833 }
834 }
835 pub fn calculate_content_hash(&mut self) {
837 let mut hasher_input = Vec::new();
838
839 hasher_input.extend(self.name.as_bytes());
840 Self::extend_with_optional_str(&mut hasher_input, &self.version);
841 Self::extend_with_optional_str(&mut hasher_input, &self.identifiers.purl);
842 for license in &self.licenses.declared {
843 hasher_input.extend(license.expression.as_bytes());
844 }
845 if let Some(supplier) = &self.supplier {
846 hasher_input.extend(supplier.name.as_bytes());
847 }
848 for hash in &self.hashes {
849 hasher_input.extend(hash.value.as_bytes());
850 }
851 for vuln in &self.vulnerabilities {
852 hasher_input.extend(vuln.id.as_bytes());
853 }
854 if self.is_external {
855 hasher_input.push(b'E');
856 }
857 if let Some(vr) = &self.version_range {
858 hasher_input.extend(vr.as_bytes());
859 }
860 Self::extend_with_ml_model(&mut hasher_input, &self.ml_model);
861 Self::extend_with_dataset(&mut hasher_input, &self.dataset);
862
863 if let Some(cp) = &self.crypto_properties {
865 hasher_input.extend(cp.asset_type.to_string().as_bytes());
866 if let Some(oid) = &cp.oid {
867 hasher_input.extend(oid.as_bytes());
868 }
869 if let Some(algo) = &cp.algorithm_properties {
870 if let Some(family) = &algo.algorithm_family {
871 hasher_input.extend(family.as_bytes());
872 }
873 if let Some(level) = algo.nist_quantum_security_level {
874 hasher_input.push(level);
875 }
876 }
877 if let Some(mat) = &cp.related_crypto_material_properties
878 && let Some(state) = &mat.state
879 {
880 hasher_input.extend(state.to_string().as_bytes());
881 }
882 if let Some(cert) = &cp.certificate_properties
883 && let Some(expiry) = &cert.not_valid_after
884 {
885 hasher_input.extend(expiry.to_rfc3339().as_bytes());
886 }
887 }
888
889 self.content_hash = xxh3_64(&hasher_input);
890 }
891
892 #[must_use]
894 pub fn is_oss(&self) -> bool {
895 self.licenses.declared.iter().any(|l| l.is_valid_spdx) || self.identifiers.purl.is_some()
897 }
898
899 #[must_use]
901 pub fn display_name(&self) -> String {
902 self.version
903 .as_ref()
904 .map_or_else(|| self.name.clone(), |v| format!("{}@{}", self.name, v))
905 }
906}
907
/// A directed relationship `from -> to` between two components.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct DependencyEdge {
    /// Canonical ID of the source component.
    pub from: CanonicalId,
    /// Canonical ID of the target component.
    pub to: CanonicalId,
    /// Kind of relationship.
    pub relationship: DependencyType,
    /// Dependency scope, if the source format provides one.
    pub scope: Option<DependencyScope>,
}
920
921impl DependencyEdge {
922 #[must_use]
924 pub const fn new(from: CanonicalId, to: CanonicalId, relationship: DependencyType) -> Self {
925 Self {
926 from,
927 to,
928 relationship,
929 scope: None,
930 }
931 }
932
933 #[must_use]
935 pub const fn with_scope(mut self, scope: DependencyScope) -> Self {
936 self.scope = Some(scope);
937 self
938 }
939
940 #[must_use]
942 pub const fn is_direct(&self) -> bool {
943 matches!(
944 self.relationship,
945 DependencyType::DependsOn
946 | DependencyType::DevDependsOn
947 | DependencyType::BuildDependsOn
948 | DependencyType::TestDependsOn
949 | DependencyType::RuntimeDependsOn
950 )
951 }
952}
953
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::MlModelInfo;

    /// Hashes a fixed component whose only varying input is the ML
    /// training-energy figure, and returns the resulting content hash.
    fn hash_with_training_energy(energy: f64) -> u64 {
        let mut component = Component::new("model".to_string(), "model@1".to_string());
        component.ml_model = Some(MlModelInfo {
            energy_kwh_training: Some(energy),
            ..MlModelInfo::default()
        });
        component.calculate_content_hash();
        component.content_hash
    }

    #[test]
    fn test_content_hash_normalizes_ml_energy_zero_and_nan() {
        // Signed zeros compare equal, so they must hash equally.
        assert_eq!(hash_with_training_energy(0.0), hash_with_training_energy(-0.0));

        // NaNs with different payload bits must collapse to one canonical NaN.
        let payload_nan = f64::from_bits(CANONICAL_NAN_BITS + 1);
        assert_eq!(
            hash_with_training_energy(f64::NAN),
            hash_with_training_energy(payload_nan)
        );
    }
}