Skip to main content

forestfire_core/
compiled_artifact.rs

1use super::*;
2
/// Four-byte tag written at the very start of every compiled artifact (ASCII `FFCA`).
pub(crate) const COMPILED_ARTIFACT_MAGIC: [u8; 4] = [b'F', b'F', b'C', b'A'];
/// Format version recorded in the artifact header.
pub(crate) const COMPILED_ARTIFACT_VERSION: u16 = 1;
/// Backend identifier recorded in the artifact header for CPU artifacts.
pub(crate) const COMPILED_ARTIFACT_BACKEND_CPU: u16 = 1;
/// Header size in bytes: the magic tag followed by the `u16` version and `u16` backend fields.
pub(crate) const COMPILED_ARTIFACT_HEADER_LEN: usize =
    COMPILED_ARTIFACT_MAGIC.len() + 2 * std::mem::size_of::<u16>();
7
8#[derive(Debug)]
9pub enum CompiledArtifactError {
10    ArtifactTooShort { actual: usize, minimum: usize },
11    InvalidMagic([u8; 4]),
12    UnsupportedVersion(u16),
13    UnsupportedBackend(u16),
14    Encode(String),
15    Decode(String),
16    InvalidSemanticModel(IrError),
17    InvalidRuntime(OptimizeError),
18}
19
20impl Display for CompiledArtifactError {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        match self {
23            CompiledArtifactError::ArtifactTooShort { actual, minimum } => write!(
24                f,
25                "Compiled artifact is too short: expected at least {} bytes, found {}.",
26                minimum, actual
27            ),
28            CompiledArtifactError::InvalidMagic(magic) => {
29                write!(f, "Compiled artifact has invalid magic bytes: {:?}.", magic)
30            }
31            CompiledArtifactError::UnsupportedVersion(version) => {
32                write!(f, "Unsupported compiled artifact version: {}.", version)
33            }
34            CompiledArtifactError::UnsupportedBackend(backend) => {
35                write!(f, "Unsupported compiled artifact backend: {}.", backend)
36            }
37            CompiledArtifactError::Encode(message) => {
38                write!(f, "Failed to encode compiled artifact: {}.", message)
39            }
40            CompiledArtifactError::Decode(message) => {
41                write!(f, "Failed to decode compiled artifact: {}.", message)
42            }
43            CompiledArtifactError::InvalidSemanticModel(err) => err.fmt(f),
44            CompiledArtifactError::InvalidRuntime(err) => err.fmt(f),
45        }
46    }
47}
48
49impl Error for CompiledArtifactError {}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub(crate) struct CompiledArtifactPayload {
53    pub(crate) semantic_ir: ModelPackageIr,
54    pub(crate) runtime: OptimizedRuntime,
55    #[serde(default)]
56    pub(crate) feature_projection: Option<Vec<usize>>,
57}
58
59impl OptimizedModel {
60    pub fn serialize_compiled(&self) -> Result<Vec<u8>, CompiledArtifactError> {
61        let payload = CompiledArtifactPayload {
62            semantic_ir: self.source_model.to_ir(),
63            runtime: self.runtime.clone(),
64            feature_projection: Some(self.feature_projection.clone()),
65        };
66        let mut payload_bytes = Vec::new();
67        ciborium::into_writer(&payload, &mut payload_bytes)
68            .map_err(|err| CompiledArtifactError::Encode(err.to_string()))?;
69        let mut bytes = Vec::with_capacity(COMPILED_ARTIFACT_HEADER_LEN + payload_bytes.len());
70        bytes.extend_from_slice(&COMPILED_ARTIFACT_MAGIC);
71        bytes.extend_from_slice(&COMPILED_ARTIFACT_VERSION.to_le_bytes());
72        bytes.extend_from_slice(&COMPILED_ARTIFACT_BACKEND_CPU.to_le_bytes());
73        bytes.extend_from_slice(&payload_bytes);
74        Ok(bytes)
75    }
76
77    pub fn deserialize_compiled(
78        serialized: &[u8],
79        physical_cores: Option<usize>,
80    ) -> Result<Self, CompiledArtifactError> {
81        if serialized.len() < COMPILED_ARTIFACT_HEADER_LEN {
82            return Err(CompiledArtifactError::ArtifactTooShort {
83                actual: serialized.len(),
84                minimum: COMPILED_ARTIFACT_HEADER_LEN,
85            });
86        }
87
88        let magic = [serialized[0], serialized[1], serialized[2], serialized[3]];
89        if magic != COMPILED_ARTIFACT_MAGIC {
90            return Err(CompiledArtifactError::InvalidMagic(magic));
91        }
92
93        let version = u16::from_le_bytes([serialized[4], serialized[5]]);
94        if version != COMPILED_ARTIFACT_VERSION {
95            return Err(CompiledArtifactError::UnsupportedVersion(version));
96        }
97
98        let backend = u16::from_le_bytes([serialized[6], serialized[7]]);
99        if backend != COMPILED_ARTIFACT_BACKEND_CPU {
100            return Err(CompiledArtifactError::UnsupportedBackend(backend));
101        }
102
103        let payload: CompiledArtifactPayload = ciborium::from_reader(std::io::Cursor::new(
104            &serialized[COMPILED_ARTIFACT_HEADER_LEN..],
105        ))
106        .map_err(|err| CompiledArtifactError::Decode(err.to_string()))?;
107        let source_model = ir::model_from_ir(payload.semantic_ir)
108            .map_err(CompiledArtifactError::InvalidSemanticModel)?;
109        let feature_projection = payload
110            .feature_projection
111            .unwrap_or_else(|| (0..source_model.num_features()).collect());
112        let thread_count = resolve_inference_thread_count(physical_cores)
113            .map_err(CompiledArtifactError::InvalidRuntime)?;
114        let executor =
115            InferenceExecutor::new(thread_count).map_err(CompiledArtifactError::InvalidRuntime)?;
116
117        Ok(Self {
118            source_model,
119            runtime: payload.runtime,
120            executor,
121            feature_projection,
122        })
123    }
124}