burn_dragon_language 0.5.0

Language modeling components for burn_dragon
Documentation
use std::path::PathBuf;

use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};

/// Eight-byte prefix that identifies the compact binary BitNet artifact format.
///
/// Spells the ASCII string `BDBITN01`.
pub const BITNET_ARTIFACT_BINARY_MAGIC: &[u8; 8] = &[b'B', b'D', b'B', b'I', b'T', b'N', b'0', b'1'];

/// An exported BitNet artifact bundle for a language model.
///
/// The bundle holds provenance fields, low-bit quantization settings, an optional deploy base
/// burnpack, and the packed static weights. [`serialize_bitnet_artifact_binary`] encodes it to
/// bytes and [`deserialize_bitnet_artifact_binary`] decodes it.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LanguageBitNetArtifactBundle {
    /// Schema version recorded with the bundle. The binary (de)serializers in this file copy it
    /// through without validating it.
    pub schema_version: u32,
    /// Epoch of the source checkpoint. Presumably this is the training checkpoint the weights
    /// were exported from (TODO confirm).
    pub source_checkpoint_epoch: usize,
    /// SHA-256 of the source training config, stored as a string. Presumably hex-encoded
    /// (TODO confirm).
    pub source_training_config_sha256: String,
    /// Source run directory, if recorded. Deserializes to `None` when the field is absent.
    #[serde(default)]
    pub source_run_dir: Option<PathBuf>,
    /// Optional kernel ABI version. Deserializes to `None` when absent. Its meaning is defined by
    /// the consumer and is not checked here.
    #[serde(default)]
    pub kernel_abi_version: Option<u32>,
    /// Low-bit quantization settings.
    pub quant: burn_dragon_core::LowBitQuantizationConfig,
    /// Low-bit `rho` settings. The semantics are defined by `burn_dragon_core::LowBitRhoConfig`.
    pub rho: burn_dragon_core::LowBitRhoConfig,
    /// Raw deploy base burnpack bytes, if any. Deserializes to `None` when absent. The binary
    /// format stores these bytes in the payload section, not inside the JSON metadata.
    #[serde(default)]
    pub deploy_base_burnpack: Option<Vec<u8>>,
    /// Packed `decoder_x`, `decoder_y`, and `encoder` weights. Each of them may be absent.
    pub static_weights: burn_dragon_core::experimental::bitnet_reference::BdhBitNetStaticArtifacts,
}

/// Metadata for one packed weight artifact, without its packed bytes.
///
/// In the binary format this record lives in the JSON header. The `packed_len` bytes it
/// describes are stored separately in the payload section.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
struct PackedWeightArtifactMetadata {
    /// Packing encoding, copied from the artifact.
    encoding: burn_dragon_core::experimental::bitnet_reference::PackedWeightEncoding,
    /// Logical (unpacked) tensor shape, copied from the artifact.
    logical_shape: Vec<usize>,
    /// Scale factor, copied from the artifact.
    scale: f32,
    /// The artifact's own `len` field, copied verbatim. NOTE(review): presumably the logical
    /// element count; confirm against `PackedWeightArtifact`.
    len: usize,
    /// Number of packed bytes this artifact occupies in the binary payload.
    packed_len: usize,
}

/// Metadata counterpart of the static weight set, with one optional entry per weight.
///
/// A `Some` entry means the corresponding packed bytes are present in the binary payload. The
/// binary writer and reader handle the sections in the order `decoder_x`, `decoder_y`, `encoder`.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
struct BdhBitNetStaticArtifactsMetadata {
    decoder_x: Option<PackedWeightArtifactMetadata>,
    decoder_y: Option<PackedWeightArtifactMetadata>,
    encoder: Option<PackedWeightArtifactMetadata>,
}

/// Metadata for the deploy base burnpack section of the binary payload.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Default)]
struct DeployBaseBurnpackMetadata {
    /// Length in bytes of the burnpack payload section.
    len: usize,
}

/// JSON header of the compact binary artifact format.
///
/// It mirrors `LanguageBitNetArtifactBundle` field for field, with one difference: each byte blob
/// (the deploy base burnpack and every packed weight) is replaced by a metadata record carrying
/// its length. The blobs themselves follow this JSON as raw payload bytes.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
struct LanguageBitNetArtifactBundleMetadata {
    schema_version: u32,
    source_checkpoint_epoch: usize,
    source_training_config_sha256: String,
    // Missing in the JSON → `None`.
    #[serde(default)]
    source_run_dir: Option<PathBuf>,
    // Missing in the JSON → `None`.
    #[serde(default)]
    kernel_abi_version: Option<u32>,
    quant: burn_dragon_core::LowBitQuantizationConfig,
    rho: burn_dragon_core::LowBitRhoConfig,
    // Missing in the JSON → `None`, i.e. no burnpack section in the payload.
    #[serde(default)]
    deploy_base_burnpack: Option<DeployBaseBurnpackMetadata>,
    static_weights: BdhBitNetStaticArtifactsMetadata,
}

/// Describes `artifact` without its packed bytes.
///
/// Every descriptive field is copied. The packed byte buffer is reduced to its length
/// (`packed_len`).
fn packed_weight_artifact_metadata(
    artifact: &burn_dragon_core::experimental::bitnet_reference::PackedWeightArtifact,
) -> PackedWeightArtifactMetadata {
    let packed_len = artifact.packed.len();
    PackedWeightArtifactMetadata {
        packed_len,
        len: artifact.len,
        scale: artifact.scale,
        logical_shape: artifact.logical_shape.to_vec(),
        encoding: artifact.encoding.clone(),
    }
}

/// Builds the JSON-serializable metadata header for `bundle`.
///
/// Scalar and config fields are copied. Each byte blob (the deploy base burnpack and the packed
/// bytes of every static weight) is replaced by a record of its length, so the blobs can be
/// written separately as raw payload.
fn bitnet_artifact_bundle_metadata(
    bundle: &LanguageBitNetArtifactBundle,
) -> LanguageBitNetArtifactBundleMetadata {
    let weights = &bundle.static_weights;
    let static_weights = BdhBitNetStaticArtifactsMetadata {
        decoder_x: weights.decoder_x.as_ref().map(packed_weight_artifact_metadata),
        decoder_y: weights.decoder_y.as_ref().map(packed_weight_artifact_metadata),
        encoder: weights.encoder.as_ref().map(packed_weight_artifact_metadata),
    };
    let deploy_base_burnpack = bundle
        .deploy_base_burnpack
        .as_deref()
        .map(|burnpack| DeployBaseBurnpackMetadata { len: burnpack.len() });

    LanguageBitNetArtifactBundleMetadata {
        schema_version: bundle.schema_version,
        source_checkpoint_epoch: bundle.source_checkpoint_epoch,
        source_training_config_sha256: bundle.source_training_config_sha256.clone(),
        source_run_dir: bundle.source_run_dir.clone(),
        kernel_abi_version: bundle.kernel_abi_version,
        quant: bundle.quant.clone(),
        rho: bundle.rho.clone(),
        deploy_base_burnpack,
        static_weights,
    }
}

/// Reconstructs one packed weight artifact from its metadata and the binary payload.
///
/// Reads `metadata.packed_len` bytes of `payload` starting at `*offset` and advances `*offset`
/// past them. When `metadata` is `None`, nothing is read, `*offset` is left unchanged, and
/// `Ok(None)` is returned. `label` and `name` are used only in the error message.
///
/// # Errors
///
/// Returns an error if fewer than `packed_len` bytes remain after `*offset`. In that case
/// `*offset` is not modified.
fn rebuild_packed_weight_artifact(
    metadata: Option<PackedWeightArtifactMetadata>,
    payload: &[u8],
    offset: &mut usize,
    label: &str,
    name: &str,
) -> Result<Option<burn_dragon_core::experimental::bitnet_reference::PackedWeightArtifact>> {
    match metadata {
        None => Ok(None),
        Some(meta) => {
            let start = *offset;
            // `checked_add` catches an end index that overflows, and `get` rejects an end past
            // the payload. Together they cover every truncation case.
            let section = start
                .checked_add(meta.packed_len)
                .and_then(|stop| payload.get(start..stop))
                .ok_or_else(|| {
                    anyhow!("bitnet artifact {label} truncated while reading `{name}` payload")
                })?;
            *offset = start + section.len();
            Ok(Some(
                burn_dragon_core::experimental::bitnet_reference::PackedWeightArtifact {
                    encoding: meta.encoding,
                    logical_shape: meta.logical_shape,
                    scale: meta.scale,
                    packed: section.to_vec(),
                    len: meta.len,
                },
            ))
        }
    }
}

/// Copies the deploy base burnpack section out of the binary payload.
///
/// Reads `metadata.len` bytes of `payload` starting at `*offset` and advances `*offset` past
/// them. When `metadata` is `None`, nothing is read, `*offset` is left unchanged, and `Ok(None)`
/// is returned. `label` is used only in the error message.
///
/// # Errors
///
/// Returns an error if fewer than `metadata.len` bytes remain after `*offset`. In that case
/// `*offset` is not modified.
fn rebuild_deploy_base_burnpack(
    metadata: Option<DeployBaseBurnpackMetadata>,
    payload: &[u8],
    offset: &mut usize,
    label: &str,
) -> Result<Option<Vec<u8>>> {
    metadata
        .map(|meta| -> Result<Vec<u8>> {
            let start = *offset;
            // An overflowing end index or an end past the payload both mean truncation.
            let section = start
                .checked_add(meta.len)
                .and_then(|stop| payload.get(start..stop))
                .ok_or_else(|| {
                    anyhow!(
                        "bitnet artifact {label} truncated while reading deploy base burnpack payload"
                    )
                })?;
            *offset = start + section.len();
            Ok(section.to_vec())
        })
        .transpose()
}

/// Encodes `bundle` in the compact binary artifact format.
///
/// Layout:
/// 1. [`BITNET_ARTIFACT_BINARY_MAGIC`] (8 bytes);
/// 2. the length of the JSON metadata, as a little-endian `u64`;
/// 3. the JSON metadata, in which every byte blob is replaced by its length;
/// 4. the raw payload sections, back to back, in this order: deploy base burnpack, `decoder_x`,
///    `decoder_y`, `encoder`. Absent sections contribute no bytes.
///
/// # Errors
///
/// Returns an error if the metadata cannot be serialized to JSON.
pub fn serialize_bitnet_artifact_binary(bundle: &LanguageBitNetArtifactBundle) -> Result<Vec<u8>> {
    let metadata_json = serde_json::to_vec(&bitnet_artifact_bundle_metadata(bundle))
        .context("serialize compact bitnet artifact metadata")?;

    // Raw payload sections, listed in the exact order the reader consumes them.
    let weights = &bundle.static_weights;
    let sections: [Option<&[u8]>; 4] = [
        bundle.deploy_base_burnpack.as_deref(),
        weights.decoder_x.as_ref().map(|artifact| artifact.packed.as_slice()),
        weights.decoder_y.as_ref().map(|artifact| artifact.packed.as_slice()),
        weights.encoder.as_ref().map(|artifact| artifact.packed.as_slice()),
    ];
    let payload_len: usize = sections.iter().flatten().map(|section| section.len()).sum();

    let header_len = BITNET_ARTIFACT_BINARY_MAGIC.len() + core::mem::size_of::<u64>();
    let mut out = Vec::with_capacity(header_len + metadata_json.len() + payload_len);
    out.extend_from_slice(BITNET_ARTIFACT_BINARY_MAGIC);
    out.extend_from_slice(&(metadata_json.len() as u64).to_le_bytes());
    out.extend_from_slice(&metadata_json);
    for section in sections.iter().flatten() {
        out.extend_from_slice(section);
    }
    Ok(out)
}

/// Decodes a bundle from the compact binary artifact format written by
/// [`serialize_bitnet_artifact_binary`].
///
/// Expected layout: [`BITNET_ARTIFACT_BINARY_MAGIC`], a little-endian `u64` metadata length, the
/// JSON metadata, then the raw payload sections (deploy base burnpack, `decoder_x`, `decoder_y`,
/// `encoder`, each only if the metadata declares it). `label` is used only to identify the
/// artifact in error messages.
///
/// # Errors
///
/// Returns an error if `bytes`:
/// - does not start with the magic prefix;
/// - is too short to hold the fixed-size header;
/// - declares more metadata or payload bytes than it contains;
/// - holds metadata that does not parse as the expected JSON schema;
/// - has bytes left over after the last payload section.
pub fn deserialize_bitnet_artifact_binary(
    bytes: &[u8],
    label: &str,
) -> Result<LanguageBitNetArtifactBundle> {
    if !bytes.starts_with(BITNET_ARTIFACT_BINARY_MAGIC) {
        return Err(anyhow!(
            "BitNet artifact {label} is not in the supported binary format; re-export it as `.bitnet_artifact.bin` or `.bitnet_artifact.bin.gz`"
        ));
    }
    let metadata_len_offset = BITNET_ARTIFACT_BINARY_MAGIC.len();
    let metadata_len_end = metadata_len_offset + core::mem::size_of::<u64>();
    if bytes.len() < metadata_len_end {
        return Err(anyhow!(
            "bitnet artifact {label} is too small to contain binary header"
        ));
    }
    let declared_metadata_len = u64::from_le_bytes(
        bytes[metadata_len_offset..metadata_len_end]
            .try_into()
            .expect("fixed-size metadata header"),
    );
    // `as usize` would silently drop the high bits on 32-bit targets (e.g. wasm32), and we would
    // then parse the wrong span as JSON. No in-memory buffer can satisfy a length that does not
    // fit in `usize`, so saturate and let the bounds check below report truncated metadata.
    let metadata_len = usize::try_from(declared_metadata_len).unwrap_or(usize::MAX);
    let metadata_start = metadata_len_end;
    let metadata_end = metadata_start.saturating_add(metadata_len);
    if metadata_end > bytes.len() {
        return Err(anyhow!(
            "bitnet artifact {label} has truncated binary metadata"
        ));
    }
    let metadata: LanguageBitNetArtifactBundleMetadata =
        serde_json::from_slice(&bytes[metadata_start..metadata_end]).with_context(|| {
            format!("failed to parse compact bitnet artifact metadata from {label}")
        })?;

    // Payload sections are consumed in the same order the writer emits them. `offset` advances
    // past each section that is present.
    let payload = &bytes[metadata_end..];
    let mut offset = 0usize;
    let deploy_base_burnpack =
        rebuild_deploy_base_burnpack(metadata.deploy_base_burnpack, payload, &mut offset, label)?;
    let static_weights =
        burn_dragon_core::experimental::bitnet_reference::BdhBitNetStaticArtifacts {
            decoder_x: rebuild_packed_weight_artifact(
                metadata.static_weights.decoder_x,
                payload,
                &mut offset,
                label,
                "decoder_x",
            )?,
            decoder_y: rebuild_packed_weight_artifact(
                metadata.static_weights.decoder_y,
                payload,
                &mut offset,
                label,
                "decoder_y",
            )?,
            encoder: rebuild_packed_weight_artifact(
                metadata.static_weights.encoder,
                payload,
                &mut offset,
                label,
                "encoder",
            )?,
        };
    // Bytes the metadata does not account for indicate a corrupt or mismatched artifact.
    if offset != payload.len() {
        return Err(anyhow!(
            "bitnet artifact {label} has {} trailing payload bytes",
            payload.len().saturating_sub(offset)
        ));
    }
    Ok(LanguageBitNetArtifactBundle {
        schema_version: metadata.schema_version,
        source_checkpoint_epoch: metadata.source_checkpoint_epoch,
        source_training_config_sha256: metadata.source_training_config_sha256,
        source_run_dir: metadata.source_run_dir,
        kernel_abi_version: metadata.kernel_abi_version,
        quant: metadata.quant,
        rho: metadata.rho,
        deploy_base_burnpack,
        static_weights,
    })
}