use oxibonsai_core::gguf::tensor_info::keys;
use oxibonsai_core::gguf::writer::{GgufWriter, MetadataWriteValue};
/// Architecture parameters that [`write_dit_metadata`] records in GGUF metadata.
///
/// Every field maps one-to-one onto a key in [`arch_keys`]. The type is plain
/// `Copy` data and derives `PartialEq` so configurations can be compared
/// directly (`Eq` is not available because of the `f32` fields).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DitArch {
    /// Stored under [`arch_keys::NUM_LAYERS`].
    pub num_layers: u32,
    /// Stored under [`arch_keys::NUM_SINGLE_LAYERS`].
    pub num_single_layers: u32,
    /// Stored under [`arch_keys::ATTENTION_HEAD_COUNT`].
    pub num_attention_heads: u32,
    /// Stored under [`arch_keys::ATTENTION_HEAD_DIM`].
    pub attention_head_dim: u32,
    /// Stored under [`arch_keys::JOINT_ATTENTION_DIM`].
    pub joint_attention_dim: u32,
    /// Stored under [`arch_keys::IN_CHANNELS`].
    pub in_channels: u32,
    /// Stored under [`arch_keys::MLP_RATIO`].
    pub mlp_ratio: f32,
    /// Stored as a `u32` array under [`arch_keys::AXES_DIMS_ROPE`].
    pub axes_dims_rope: [u32; 4],
    /// Stored under [`arch_keys::ROPE_THETA`].
    pub rope_theta: f32,
    /// Stored as a boolean under [`arch_keys::GUIDANCE_EMBEDS`].
    pub guidance_embeds: bool,
}
impl Default for DitArch {
    /// Returns the stock configuration.
    ///
    /// The exact values are pinned by the `default_arch_matches_design_doc`
    /// unit test in this file.
    fn default() -> Self {
        // All four rotary axes share the same width in the default setup.
        let axes_dims_rope = [32; 4];

        Self {
            // Block counts.
            num_layers: 5,
            num_single_layers: 20,
            // Attention geometry.
            num_attention_heads: 24,
            attention_head_dim: 128,
            joint_attention_dim: 7680,
            // Input and feed-forward sizing.
            in_channels: 128,
            mlp_ratio: 3.0,
            // Rotary position embedding.
            axes_dims_rope,
            rope_theta: 2000.0,
            // Conditioning.
            guidance_embeds: false,
        }
    }
}
/// GGUF metadata key names used for the `bonsai-image` architecture.
///
/// Every per-architecture key is prefixed with the architecture string
/// itself, i.e. [`ARCHITECTURE`](arch_keys::ARCHITECTURE) followed by a dot.
pub mod arch_keys {
    /// Architecture identifier; also the prefix of every key below.
    pub const ARCHITECTURE: &str = "bonsai-image";

    // --- block counts ---
    /// Key for `DitArch::num_layers`.
    pub const NUM_LAYERS: &str = "bonsai-image.num_layers";
    /// Key for `DitArch::num_single_layers`.
    pub const NUM_SINGLE_LAYERS: &str = "bonsai-image.num_single_layers";

    // --- attention ---
    /// Key for `DitArch::num_attention_heads`.
    pub const ATTENTION_HEAD_COUNT: &str = "bonsai-image.attention.head_count";
    /// Key for `DitArch::attention_head_dim`.
    pub const ATTENTION_HEAD_DIM: &str = "bonsai-image.attention.head_dim";
    /// Key for `DitArch::joint_attention_dim`.
    pub const JOINT_ATTENTION_DIM: &str = "bonsai-image.joint_attention_dim";

    // --- rotary position embedding ---
    /// Key for `DitArch::axes_dims_rope`.
    pub const AXES_DIMS_ROPE: &str = "bonsai-image.rope.axes_dims";
    /// Key for `DitArch::rope_theta`.
    pub const ROPE_THETA: &str = "bonsai-image.rope.theta";

    // --- remaining scalars ---
    /// Key for `DitArch::in_channels`.
    pub const IN_CHANNELS: &str = "bonsai-image.in_channels";
    /// Key for `DitArch::mlp_ratio`.
    pub const MLP_RATIO: &str = "bonsai-image.mlp_ratio";
    /// Key for `DitArch::guidance_embeds`.
    pub const GUIDANCE_EMBEDS: &str = "bonsai-image.guidance_embeds";
}
/// Adds the model identity and every [`DitArch`] field to `writer` as GGUF metadata.
///
/// Entries are added in a fixed order: the three `general.*` keys
/// (architecture, name, quantization version) followed by one
/// `bonsai-image.*` key per `arch` field, in field declaration order.
///
/// * `writer` - destination that receives the metadata entries.
/// * `arch` - architecture parameters to record.
/// * `model_name` - value stored under `keys::GENERAL_NAME`.
///
/// NOTE(review): the public GGUF specification types
/// `general.quantization_version` as a uint32, whereas a string is stored
/// here. Confirm that the readers of these files expect the string form.
pub fn write_dit_metadata(writer: &mut GgufWriter, arch: &DitArch, model_name: &str) {
    // Table of (key, value) pairs; array order is the order they are added.
    let entries = [
        (
            keys::GENERAL_ARCHITECTURE,
            MetadataWriteValue::Str(String::from(arch_keys::ARCHITECTURE)),
        ),
        (
            keys::GENERAL_NAME,
            MetadataWriteValue::Str(String::from(model_name)),
        ),
        (
            "general.quantization_version",
            MetadataWriteValue::Str(String::from("TQ2_0_G128")),
        ),
        (
            arch_keys::NUM_LAYERS,
            MetadataWriteValue::U32(arch.num_layers),
        ),
        (
            arch_keys::NUM_SINGLE_LAYERS,
            MetadataWriteValue::U32(arch.num_single_layers),
        ),
        (
            arch_keys::ATTENTION_HEAD_COUNT,
            MetadataWriteValue::U32(arch.num_attention_heads),
        ),
        (
            arch_keys::ATTENTION_HEAD_DIM,
            MetadataWriteValue::U32(arch.attention_head_dim),
        ),
        (
            arch_keys::JOINT_ATTENTION_DIM,
            MetadataWriteValue::U32(arch.joint_attention_dim),
        ),
        (
            arch_keys::IN_CHANNELS,
            MetadataWriteValue::U32(arch.in_channels),
        ),
        (
            arch_keys::MLP_RATIO,
            MetadataWriteValue::F32(arch.mlp_ratio),
        ),
        (
            arch_keys::AXES_DIMS_ROPE,
            MetadataWriteValue::ArrayU32(arch.axes_dims_rope.to_vec()),
        ),
        (
            arch_keys::ROPE_THETA,
            MetadataWriteValue::F32(arch.rope_theta),
        ),
        (
            arch_keys::GUIDANCE_EMBEDS,
            MetadataWriteValue::Bool(arch.guidance_embeds),
        ),
    ];

    for (key, value) in entries {
        writer.add_metadata(key, value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default value; the exhaustive destructuring makes the test
    /// stop compiling if a field is added to `DitArch` without being checked.
    #[test]
    fn default_arch_matches_design_doc() {
        let DitArch {
            num_layers,
            num_single_layers,
            num_attention_heads,
            attention_head_dim,
            joint_attention_dim,
            in_channels,
            mlp_ratio,
            axes_dims_rope,
            rope_theta,
            guidance_embeds,
        } = DitArch::default();

        assert_eq!(num_layers, 5);
        assert_eq!(num_single_layers, 20);
        assert_eq!(num_attention_heads, 24);
        assert_eq!(attention_head_dim, 128);
        assert_eq!(joint_attention_dim, 7680);
        assert_eq!(in_channels, 128);
        assert_eq!(mlp_ratio, 3.0);
        assert_eq!(axes_dims_rope, [32, 32, 32, 32]);
        assert_eq!(rope_theta, 2000.0);
        assert!(!guidance_embeds);
    }

    /// Serialises a metadata-only file and checks that it starts with the
    /// GGUF magic.
    #[test]
    fn metadata_writes_without_panicking() {
        let arch = DitArch::default();
        let mut writer = GgufWriter::new();
        write_dit_metadata(&mut writer, &arch, "bonsai-image-4B");

        let bytes = writer.to_bytes().expect("serialise metadata-only file");

        // b"GGUF" is the byte sequence of the little-endian u32 0x4655_4747.
        assert_eq!(&bytes[..4], b"GGUF");
    }
}