forestfire_core/
compiled_artifact.rs1use super::*;
2
/// Four-byte tag that opens every compiled artifact (ASCII `FFCA`).
pub(crate) const COMPILED_ARTIFACT_MAGIC: [u8; 4] = [b'F', b'F', b'C', b'A'];
/// Artifact format version written into, and required by, the header.
pub(crate) const COMPILED_ARTIFACT_VERSION: u16 = 1;
/// Backend identifier stored in the header for CPU-runtime artifacts.
pub(crate) const COMPILED_ARTIFACT_BACKEND_CPU: u16 = 1;
/// Header size in bytes: the magic tag followed by the version and backend `u16` fields.
pub(crate) const COMPILED_ARTIFACT_HEADER_LEN: usize =
    COMPILED_ARTIFACT_MAGIC.len() + 2 * std::mem::size_of::<u16>();
7
/// Errors produced while encoding or decoding a compiled model artifact.
#[derive(Debug)]
pub enum CompiledArtifactError {
    /// The input is shorter than the fixed-size artifact header (`actual` < `minimum` bytes).
    ArtifactTooShort { actual: usize, minimum: usize },
    /// The first four bytes did not match the expected artifact magic tag.
    InvalidMagic([u8; 4]),
    /// The header carries a format version this build does not accept.
    UnsupportedVersion(u16),
    /// The header names a backend other than the CPU backend.
    UnsupportedBackend(u16),
    /// CBOR encoding of the payload failed; holds the encoder's error message.
    Encode(String),
    /// CBOR decoding of the payload failed; holds the decoder's error message.
    Decode(String),
    /// The decoded semantic IR could not be rebuilt into a model.
    InvalidSemanticModel(IrError),
    /// Resolving the inference thread count or creating the inference executor failed.
    InvalidRuntime(OptimizeError),
}
19
20impl Display for CompiledArtifactError {
21 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22 match self {
23 CompiledArtifactError::ArtifactTooShort { actual, minimum } => write!(
24 f,
25 "Compiled artifact is too short: expected at least {} bytes, found {}.",
26 minimum, actual
27 ),
28 CompiledArtifactError::InvalidMagic(magic) => {
29 write!(f, "Compiled artifact has invalid magic bytes: {:?}.", magic)
30 }
31 CompiledArtifactError::UnsupportedVersion(version) => {
32 write!(f, "Unsupported compiled artifact version: {}.", version)
33 }
34 CompiledArtifactError::UnsupportedBackend(backend) => {
35 write!(f, "Unsupported compiled artifact backend: {}.", backend)
36 }
37 CompiledArtifactError::Encode(message) => {
38 write!(f, "Failed to encode compiled artifact: {}.", message)
39 }
40 CompiledArtifactError::Decode(message) => {
41 write!(f, "Failed to decode compiled artifact: {}.", message)
42 }
43 CompiledArtifactError::InvalidSemanticModel(err) => err.fmt(f),
44 CompiledArtifactError::InvalidRuntime(err) => err.fmt(f),
45 }
46 }
47}
48
49impl Error for CompiledArtifactError {}
50
/// Body of a compiled artifact, CBOR-encoded immediately after the fixed-size header.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CompiledArtifactPayload {
    /// Semantic IR of the source model; rebuilt via `ir::model_from_ir` when loading.
    pub(crate) semantic_ir: ModelPackageIr,
    /// Optimized runtime, restored verbatim into the loaded model.
    pub(crate) runtime: OptimizedRuntime,
    /// Feature projection of the optimized model. `#[serde(default)]` lets payloads that omit
    /// this field decode as `None`; the loader then substitutes the identity projection
    /// `0..num_features()`.
    #[serde(default)]
    pub(crate) feature_projection: Option<Vec<usize>>,
}
57}
58
59impl OptimizedModel {
60 pub fn serialize_compiled(&self) -> Result<Vec<u8>, CompiledArtifactError> {
61 let payload = CompiledArtifactPayload {
62 semantic_ir: self.source_model.to_ir(),
63 runtime: self.runtime.clone(),
64 feature_projection: Some(self.feature_projection.clone()),
65 };
66 let mut payload_bytes = Vec::new();
67 ciborium::into_writer(&payload, &mut payload_bytes)
68 .map_err(|err| CompiledArtifactError::Encode(err.to_string()))?;
69 let mut bytes = Vec::with_capacity(COMPILED_ARTIFACT_HEADER_LEN + payload_bytes.len());
70 bytes.extend_from_slice(&COMPILED_ARTIFACT_MAGIC);
71 bytes.extend_from_slice(&COMPILED_ARTIFACT_VERSION.to_le_bytes());
72 bytes.extend_from_slice(&COMPILED_ARTIFACT_BACKEND_CPU.to_le_bytes());
73 bytes.extend_from_slice(&payload_bytes);
74 Ok(bytes)
75 }
76
77 pub fn deserialize_compiled(
78 serialized: &[u8],
79 physical_cores: Option<usize>,
80 ) -> Result<Self, CompiledArtifactError> {
81 if serialized.len() < COMPILED_ARTIFACT_HEADER_LEN {
82 return Err(CompiledArtifactError::ArtifactTooShort {
83 actual: serialized.len(),
84 minimum: COMPILED_ARTIFACT_HEADER_LEN,
85 });
86 }
87
88 let magic = [serialized[0], serialized[1], serialized[2], serialized[3]];
89 if magic != COMPILED_ARTIFACT_MAGIC {
90 return Err(CompiledArtifactError::InvalidMagic(magic));
91 }
92
93 let version = u16::from_le_bytes([serialized[4], serialized[5]]);
94 if version != COMPILED_ARTIFACT_VERSION {
95 return Err(CompiledArtifactError::UnsupportedVersion(version));
96 }
97
98 let backend = u16::from_le_bytes([serialized[6], serialized[7]]);
99 if backend != COMPILED_ARTIFACT_BACKEND_CPU {
100 return Err(CompiledArtifactError::UnsupportedBackend(backend));
101 }
102
103 let payload: CompiledArtifactPayload = ciborium::from_reader(std::io::Cursor::new(
104 &serialized[COMPILED_ARTIFACT_HEADER_LEN..],
105 ))
106 .map_err(|err| CompiledArtifactError::Decode(err.to_string()))?;
107 let source_model = ir::model_from_ir(payload.semantic_ir)
108 .map_err(CompiledArtifactError::InvalidSemanticModel)?;
109 let feature_projection = payload
110 .feature_projection
111 .unwrap_or_else(|| (0..source_model.num_features()).collect());
112 let thread_count = resolve_inference_thread_count(physical_cores)
113 .map_err(CompiledArtifactError::InvalidRuntime)?;
114 let executor =
115 InferenceExecutor::new(thread_count).map_err(CompiledArtifactError::InvalidRuntime)?;
116
117 Ok(Self {
118 source_model,
119 runtime: payload.runtime,
120 executor,
121 feature_projection,
122 })
123 }
124}