forestfire-core 0.4.0

Core tree-learning algorithms for the ForestFire project.
Documentation
use super::*;

pub(crate) const COMPILED_ARTIFACT_MAGIC: [u8; 4] = *b"FFCA";
pub(crate) const COMPILED_ARTIFACT_VERSION: u16 = 1;
pub(crate) const COMPILED_ARTIFACT_BACKEND_CPU: u16 = 1;
pub(crate) const COMPILED_ARTIFACT_HEADER_LEN: usize = 8;

#[derive(Debug)]
pub enum CompiledArtifactError {
    ArtifactTooShort { actual: usize, minimum: usize },
    InvalidMagic([u8; 4]),
    UnsupportedVersion(u16),
    UnsupportedBackend(u16),
    Encode(String),
    Decode(String),
    InvalidSemanticModel(IrError),
    InvalidRuntime(OptimizeError),
}

impl Display for CompiledArtifactError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            CompiledArtifactError::ArtifactTooShort { actual, minimum } => write!(
                f,
                "Compiled artifact is too short: expected at least {} bytes, found {}.",
                minimum, actual
            ),
            CompiledArtifactError::InvalidMagic(magic) => {
                write!(f, "Compiled artifact has invalid magic bytes: {:?}.", magic)
            }
            CompiledArtifactError::UnsupportedVersion(version) => {
                write!(f, "Unsupported compiled artifact version: {}.", version)
            }
            CompiledArtifactError::UnsupportedBackend(backend) => {
                write!(f, "Unsupported compiled artifact backend: {}.", backend)
            }
            CompiledArtifactError::Encode(message) => {
                write!(f, "Failed to encode compiled artifact: {}.", message)
            }
            CompiledArtifactError::Decode(message) => {
                write!(f, "Failed to decode compiled artifact: {}.", message)
            }
            CompiledArtifactError::InvalidSemanticModel(err) => err.fmt(f),
            CompiledArtifactError::InvalidRuntime(err) => err.fmt(f),
        }
    }
}

impl Error for CompiledArtifactError {}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct CompiledArtifactPayload {
    pub(crate) semantic_ir: ModelPackageIr,
    pub(crate) runtime: OptimizedRuntime,
    #[serde(default)]
    pub(crate) feature_projection: Option<Vec<usize>>,
}

impl OptimizedModel {
    pub fn serialize_compiled(&self) -> Result<Vec<u8>, CompiledArtifactError> {
        let payload = CompiledArtifactPayload {
            semantic_ir: self.source_model.to_ir(),
            runtime: self.runtime.clone(),
            feature_projection: Some(self.feature_projection.clone()),
        };
        let mut payload_bytes = Vec::new();
        ciborium::into_writer(&payload, &mut payload_bytes)
            .map_err(|err| CompiledArtifactError::Encode(err.to_string()))?;
        let mut bytes = Vec::with_capacity(COMPILED_ARTIFACT_HEADER_LEN + payload_bytes.len());
        bytes.extend_from_slice(&COMPILED_ARTIFACT_MAGIC);
        bytes.extend_from_slice(&COMPILED_ARTIFACT_VERSION.to_le_bytes());
        bytes.extend_from_slice(&COMPILED_ARTIFACT_BACKEND_CPU.to_le_bytes());
        bytes.extend_from_slice(&payload_bytes);
        Ok(bytes)
    }

    pub fn deserialize_compiled(
        serialized: &[u8],
        physical_cores: Option<usize>,
    ) -> Result<Self, CompiledArtifactError> {
        if serialized.len() < COMPILED_ARTIFACT_HEADER_LEN {
            return Err(CompiledArtifactError::ArtifactTooShort {
                actual: serialized.len(),
                minimum: COMPILED_ARTIFACT_HEADER_LEN,
            });
        }

        let magic = [serialized[0], serialized[1], serialized[2], serialized[3]];
        if magic != COMPILED_ARTIFACT_MAGIC {
            return Err(CompiledArtifactError::InvalidMagic(magic));
        }

        let version = u16::from_le_bytes([serialized[4], serialized[5]]);
        if version != COMPILED_ARTIFACT_VERSION {
            return Err(CompiledArtifactError::UnsupportedVersion(version));
        }

        let backend = u16::from_le_bytes([serialized[6], serialized[7]]);
        if backend != COMPILED_ARTIFACT_BACKEND_CPU {
            return Err(CompiledArtifactError::UnsupportedBackend(backend));
        }

        let payload: CompiledArtifactPayload = ciborium::from_reader(std::io::Cursor::new(
            &serialized[COMPILED_ARTIFACT_HEADER_LEN..],
        ))
        .map_err(|err| CompiledArtifactError::Decode(err.to_string()))?;
        let source_model = ir::model_from_ir(payload.semantic_ir)
            .map_err(CompiledArtifactError::InvalidSemanticModel)?;
        let feature_projection = payload
            .feature_projection
            .unwrap_or_else(|| (0..source_model.num_features()).collect());
        let thread_count = resolve_inference_thread_count(physical_cores)
            .map_err(CompiledArtifactError::InvalidRuntime)?;
        let executor =
            InferenceExecutor::new(thread_count).map_err(CompiledArtifactError::InvalidRuntime)?;

        Ok(Self {
            source_model,
            runtime: payload.runtime,
            executor,
            feature_projection,
        })
    }
}