use crate::{CanonicalType, FieldDescriptor, LayoutManifest, SegmentFieldDescriptor};
use std::fmt::Write;
/// Returns the TypeScript type used to represent a field of canonical type `ct`.
///
/// Integers of 32 bits or fewer become `number`. 64- and 128-bit integers become
/// `bigint`, matching the `getBigUint64`/`getBigInt64` reads the generated decoder uses.
fn ts_type(ct: &CanonicalType) -> &'static str {
    match *ct {
        CanonicalType::Bool => "boolean",
        CanonicalType::Pubkey => "PublicKey",
        CanonicalType::Header => "JiminyHeader",
        CanonicalType::Bytes(_) => "Uint8Array",
        CanonicalType::U64 | CanonicalType::I64 | CanonicalType::U128 | CanonicalType::I128 => {
            "bigint"
        }
        CanonicalType::U8
        | CanonicalType::I8
        | CanonicalType::U16
        | CanonicalType::I16
        | CanonicalType::U32
        | CanonicalType::I32 => "number",
    }
}
/// Builds the TypeScript expression that reads `field` stored at byte `offset`.
///
/// The expression is spliced into the generated `decode<Name>` function. There,
/// `data` is the account `Uint8Array` and `view` is a `DataView` over it.
/// Multi-byte integers are read little-endian. 128-bit values are assembled from
/// two 64-bit halves: the low half is always read unsigned, and for `I128` only
/// the high half is read signed.
fn ts_read(field: &FieldDescriptor, offset: usize) -> String {
    // Little-endian DataView getter call at `offset`.
    let le = |getter: &str| format!("view.{}({}, true)", getter, offset);
    match field.canonical_type {
        CanonicalType::U8 => format!("data[{}]", offset),
        CanonicalType::I8 => format!(
            "new Int8Array(data.buffer, data.byteOffset + {}, 1)[0]",
            offset
        ),
        CanonicalType::U16 => le("getUint16"),
        CanonicalType::I16 => le("getInt16"),
        CanonicalType::U32 => le("getUint32"),
        CanonicalType::I32 => le("getInt32"),
        CanonicalType::U64 => le("getBigUint64"),
        CanonicalType::I64 => le("getBigInt64"),
        CanonicalType::U128 => format!(
            "{} | (view.getBigUint64({}, true) << 64n)",
            le("getBigUint64"),
            offset + 8
        ),
        CanonicalType::I128 => format!(
            "{} | (view.getBigInt64({}, true) << 64n)",
            le("getBigUint64"),
            offset + 8
        ),
        CanonicalType::Bool => format!("data[{}] !== 0", offset),
        CanonicalType::Pubkey => {
            let end = offset + 32;
            format!("new PublicKey(data.slice({}, {}))", offset, end)
        }
        CanonicalType::Header => format!("decodeJiminyHeader(data, {})", offset),
        CanonicalType::Bytes(len) => format!("data.slice({}, {})", offset, offset + len),
    }
}
/// Generates a TypeScript module that decodes accounts laid out according to `manifest`.
///
/// The emitted module contains, in order:
/// - a banner comment and a `PublicKey` import from `@solana/web3.js`;
/// - the `JiminyHeader` interface and `decodeJiminyHeader`, only when some field
///   has `CanonicalType::Header`;
/// - `<NAME>_LAYOUT_ID` and `<NAME>_SIZE` constants, where `<NAME>` is the
///   upper-cased manifest name;
/// - an interface named after the manifest with one property per field;
/// - `decode<Name>(data)`, which throws on data shorter than the fixed size, on
///   a discriminator mismatch at byte 0, or on a layout-id mismatch at bytes 4..12.
///   Otherwise it reads each field at the running sum of the preceding field sizes;
/// - segment types and `decode<Name>Segments`, only when the manifest declares segments.
pub fn ts_decoder(manifest: &LayoutManifest) -> String {
    let mut out = String::with_capacity(2048);
    let upper = manifest.name.to_uppercase();
    let fixed_size = manifest.total_size();

    // Banner and imports.
    writeln!(out, "// Auto-generated by jiminy-schema codegen - do not edit").unwrap();
    writeln!(
        out,
        "// Layout: {name} v{version} (disc={disc})",
        name = manifest.name,
        version = manifest.version,
        disc = manifest.discriminator
    )
    .unwrap();
    writeln!(out, "import {{ PublicKey }} from '@solana/web3.js';").unwrap();
    writeln!(out).unwrap();

    let needs_header = manifest
        .fields
        .iter()
        .any(|f| matches!(f.canonical_type, CanonicalType::Header));
    if needs_header {
        emit_jiminy_header_support(&mut out);
    }

    // Exported constants.
    let id_bytes: Vec<String> = manifest.layout_id.iter().map(|b| b.to_string()).collect();
    writeln!(
        out,
        "export const {}_LAYOUT_ID = new Uint8Array([{}]);",
        upper,
        id_bytes.join(", ")
    )
    .unwrap();
    writeln!(out, "export const {}_SIZE = {};", upper, fixed_size).unwrap();
    writeln!(out).unwrap();

    // Typed view of the fixed region.
    writeln!(out, "export interface {} {{", manifest.name).unwrap();
    for f in manifest.fields.iter() {
        writeln!(out, " {}: {};", f.name, ts_type(&f.canonical_type)).unwrap();
    }
    writeln!(out, "}}").unwrap();
    writeln!(out).unwrap();

    // Decoder: validate length, discriminator and layout id, then read fields.
    writeln!(
        out,
        "export function decode{name}(data: Uint8Array): {name} {{",
        name = manifest.name
    )
    .unwrap();
    writeln!(out, " if (data.length < {}) {{", fixed_size).unwrap();
    writeln!(
        out,
        " throw new Error('Account data too short: expected {} bytes, got ' + data.length);",
        fixed_size
    )
    .unwrap();
    writeln!(out, " }}").unwrap();
    writeln!(out, " if (data[0] !== {}) {{", manifest.discriminator).unwrap();
    writeln!(
        out,
        " throw new Error('Invalid discriminator: expected {}, got ' + data[0]);",
        manifest.discriminator
    )
    .unwrap();
    writeln!(out, " }}").unwrap();
    writeln!(out, " const layoutId = data.slice(4, 12);").unwrap();
    let id_guard = manifest
        .layout_id
        .iter()
        .enumerate()
        .map(|(i, b)| format!("layoutId[{}] !== {}", i, b))
        .collect::<Vec<_>>()
        .join(" || ");
    writeln!(out, " if ({}) {{", id_guard).unwrap();
    writeln!(out, " throw new Error('Layout ID mismatch');").unwrap();
    writeln!(out, " }}").unwrap();
    writeln!(out, " const view = new DataView(data.buffer, data.byteOffset, data.length);").unwrap();
    writeln!(out, " return {{").unwrap();
    let mut cursor = 0;
    for f in manifest.fields.iter() {
        writeln!(out, " {}: {},", f.name, ts_read(f, cursor)).unwrap();
        cursor += f.size;
    }
    writeln!(out, " }};").unwrap();
    writeln!(out, "}}").unwrap();

    if !manifest.segments.is_empty() {
        emit_segment_types(&mut out, manifest);
    }
    out
}

/// Appends the `JiminyHeader` interface and its `decodeJiminyHeader` reader to `out`.
///
/// The reader takes the discriminator from byte 0 and the version from byte 1.
/// It reads flags as a little-endian u16 at byte 2, the layout id from bytes
/// 4..12, and the reserved bytes from 12..16. All positions are relative to `offset`.
fn emit_jiminy_header_support(out: &mut String) {
    const LINES: &[&str] = &[
        "export interface JiminyHeader {",
        " discriminator: number;",
        " version: number;",
        " flags: number;",
        " layoutId: Uint8Array;",
        " reserved: Uint8Array;",
        "}",
        "",
        "function decodeJiminyHeader(data: Uint8Array, offset: number): JiminyHeader {",
        " const view = new DataView(data.buffer, data.byteOffset + offset, 16);",
        " return {",
        " discriminator: data[offset],",
        " version: data[offset + 1],",
        " flags: view.getUint16(2, true),",
        " layoutId: data.slice(offset + 4, offset + 12),",
        " reserved: data.slice(offset + 12, offset + 16),",
        " };",
        "}",
        "",
    ];
    for line in LINES {
        out.push_str(line);
        out.push('\n');
    }
}
/// Appends TypeScript support for a manifest's variable-length segments to `out`.
///
/// Emits, in order:
/// - the `SegmentDescriptor` interface and the `readSegmentDescriptor` reader;
/// - a `<Name>Segments` interface with one `Uint8Array[]` per segment;
/// - the `<NAME>_SEGMENT_COUNT` and `<NAME>_TABLE_OFFSET` constants;
/// - a `decode<Name>Segments(data)` function.
///
/// The descriptor table starts right after the fixed region
/// (`manifest.total_size()`). It holds one 8-byte descriptor per segment: a
/// little-endian u32 byte offset into `data`, a u16 element count and a u16
/// element size.
///
/// The generated decoder throws in three cases:
/// - `data` is too short to hold the table;
/// - a descriptor's element size differs from the manifest's;
/// - a segment's elements would extend past the end of `data`.
///
/// The last check is needed because `Uint8Array.prototype.slice` clamps
/// out-of-range indices instead of throwing. Without it, a truncated or corrupt
/// account would decode into short or empty element buffers.
fn emit_segment_types(out: &mut String, manifest: &LayoutManifest) {
    let fixed_size = manifest.total_size();
    let seg_count = manifest.segments.len();
    let upper = manifest.name.to_uppercase();

    writeln!(out).unwrap();
    writeln!(out, "export interface SegmentDescriptor {{").unwrap();
    writeln!(out, " offset: number;").unwrap();
    writeln!(out, " count: number;").unwrap();
    writeln!(out, " elementSize: number;").unwrap();
    writeln!(out, "}}").unwrap();
    writeln!(out).unwrap();
    writeln!(out, "function readSegmentDescriptor(view: DataView, pos: number): SegmentDescriptor {{").unwrap();
    writeln!(out, " return {{").unwrap();
    writeln!(out, " offset: view.getUint32(pos, true),").unwrap();
    writeln!(out, " count: view.getUint16(pos + 4, true),").unwrap();
    writeln!(out, " elementSize: view.getUint16(pos + 6, true),").unwrap();
    writeln!(out, " }};").unwrap();
    writeln!(out, "}}").unwrap();
    writeln!(out).unwrap();

    writeln!(out, "export interface {}Segments {{", manifest.name).unwrap();
    for seg in manifest.segments {
        writeln!(out, " {}: Uint8Array[];", seg.name).unwrap();
    }
    writeln!(out, "}}").unwrap();
    writeln!(out).unwrap();

    writeln!(out, "export const {}_SEGMENT_COUNT = {};", upper, seg_count).unwrap();
    writeln!(out, "export const {}_TABLE_OFFSET = {};", upper, fixed_size).unwrap();
    writeln!(out).unwrap();

    writeln!(out, "export function decode{}Segments(data: Uint8Array): {}Segments {{", manifest.name, manifest.name).unwrap();
    // Each descriptor is 8 bytes (u32 offset + u16 count + u16 element size).
    let table_end = fixed_size + seg_count * 8;
    writeln!(out, " if (data.length < {}) {{", table_end).unwrap();
    writeln!(out, " throw new Error('Account data too short for segment table: expected at least {} bytes, got ' + data.length);", table_end).unwrap();
    writeln!(out, " }}").unwrap();
    writeln!(out, " const view = new DataView(data.buffer, data.byteOffset, data.length);").unwrap();
    for (i, seg) in manifest.segments.iter().enumerate() {
        let desc_off = fixed_size + i * 8;
        writeln!(out, " const {}_desc = readSegmentDescriptor(view, {});", seg.name, desc_off).unwrap();
    }
    writeln!(out).unwrap();

    for seg in manifest.segments {
        writeln!(out, " if ({}_desc.elementSize !== {}) {{", seg.name, seg.element_size).unwrap();
        writeln!(out, " throw new Error('Segment \"{}\" element size mismatch: expected {}, got ' + {}_desc.elementSize);", seg.name, seg.element_size, seg.name).unwrap();
        writeln!(out, " }}").unwrap();
    }
    // `slice` clamps instead of throwing, so reject any segment whose elements
    // would run past the end of the account data rather than returning
    // truncated element buffers.
    for seg in manifest.segments {
        writeln!(
            out,
            " if ({name}_desc.offset + {name}_desc.count * {size} > data.length) {{",
            name = seg.name,
            size = seg.element_size
        )
        .unwrap();
        writeln!(
            out,
            " throw new Error('Segment \"{name}\" out of bounds: ends at ' + ({name}_desc.offset + {name}_desc.count * {size}) + ', data length is ' + data.length);",
            name = seg.name,
            size = seg.element_size
        )
        .unwrap();
        writeln!(out, " }}").unwrap();
    }
    writeln!(out).unwrap();

    writeln!(out, " return {{").unwrap();
    for seg in manifest.segments {
        writeln!(out, " {name}: Array.from({{ length: {name}_desc.count }}, (_, i) => {{", name = seg.name).unwrap();
        writeln!(out, " const start = {name}_desc.offset + i * {size};", name = seg.name, size = seg.element_size).unwrap();
        writeln!(out, " return data.slice(start, start + {});", seg.element_size).unwrap();
        writeln!(out, " }}),").unwrap();
    }
    writeln!(out, " }};").unwrap();
    writeln!(out, "}}").unwrap();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-size layout with a header, a u64 and a pubkey (16 + 8 + 32 = 56 bytes).
    fn vault_manifest() -> LayoutManifest {
        LayoutManifest {
            name: "Vault",
            version: 1,
            discriminator: 1,
            layout_id: [0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89],
            fields: &[
                FieldDescriptor { name: "header", canonical_type: CanonicalType::Header, size: 16 },
                FieldDescriptor { name: "balance", canonical_type: CanonicalType::U64, size: 8 },
                FieldDescriptor { name: "authority", canonical_type: CanonicalType::Pubkey, size: 32 },
            ],
            segments: &[],
        }
    }

    #[test]
    fn generates_interface() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("export interface Vault {"));
        assert!(ts.contains("balance: bigint;"));
        assert!(ts.contains("authority: PublicKey;"));
    }

    #[test]
    fn generates_decoder_function() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("export function decodeVault(data: Uint8Array): Vault {"));
    }

    #[test]
    fn includes_disc_check() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("data[0] !== 1"));
    }

    #[test]
    fn includes_layout_id_constant() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("VAULT_LAYOUT_ID"));
        assert!(ts.contains("171, 205, 239, 1, 35, 69, 103, 137"));
    }

    #[test]
    fn includes_size_constant() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("VAULT_SIZE = 56"));
    }

    #[test]
    fn includes_header_decoder() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("export interface JiminyHeader"));
        assert!(ts.contains("function decodeJiminyHeader"));
    }

    #[test]
    fn field_reads_use_cumulative_offsets() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("header: decodeJiminyHeader(data, 0),"));
        assert!(ts.contains("balance: view.getBigUint64(16, true),"));
        assert!(ts.contains("authority: new PublicKey(data.slice(24, 56)),"));
    }

    #[test]
    fn layout_id_guard_checks_every_byte() {
        let ts = ts_decoder(&vault_manifest());
        assert!(ts.contains("layoutId[0] !== 171 || layoutId[1] !== 205"));
        assert!(ts.contains("layoutId[7] !== 137) {"));
    }

    #[test]
    fn wide_integer_reads_combine_two_halves() {
        let unsigned = FieldDescriptor { name: "amount", canonical_type: CanonicalType::U128, size: 16 };
        assert_eq!(
            ts_read(&unsigned, 16),
            "view.getBigUint64(16, true) | (view.getBigUint64(24, true) << 64n)"
        );
        // Only the high half is signed for I128.
        let signed = FieldDescriptor { name: "delta", canonical_type: CanonicalType::I128, size: 16 };
        assert_eq!(
            ts_read(&signed, 0),
            "view.getBigUint64(0, true) | (view.getBigInt64(8, true) << 64n)"
        );
    }

    #[test]
    fn byte_and_bool_reads() {
        let bytes = FieldDescriptor { name: "tag", canonical_type: CanonicalType::Bytes(4), size: 4 };
        assert_eq!(ts_read(&bytes, 8), "data.slice(8, 12)");
        let flag = FieldDescriptor { name: "frozen", canonical_type: CanonicalType::Bool, size: 1 };
        assert_eq!(ts_read(&flag, 3), "data[3] !== 0");
    }

    #[test]
    fn omits_header_support_without_header_field() {
        let manifest = LayoutManifest {
            name: "Counter",
            version: 1,
            discriminator: 2,
            layout_id: [0; 8],
            fields: &[FieldDescriptor { name: "count", canonical_type: CanonicalType::U32, size: 4 }],
            segments: &[],
        };
        let ts = ts_decoder(&manifest);
        assert!(!ts.contains("JiminyHeader"));
        assert!(ts.contains("count: view.getUint32(0, true),"));
    }

    /// Header + pubkey fixed region (48 bytes) followed by two 48-byte-element segments.
    fn segmented_manifest() -> LayoutManifest {
        LayoutManifest {
            name: "OrderBook",
            version: 1,
            discriminator: 5,
            layout_id: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88],
            fields: &[
                FieldDescriptor { name: "header", canonical_type: CanonicalType::Header, size: 16 },
                FieldDescriptor { name: "base_mint", canonical_type: CanonicalType::Pubkey, size: 32 },
            ],
            segments: &[
                SegmentFieldDescriptor { name: "bids", element_type: "Order", element_size: 48 },
                SegmentFieldDescriptor { name: "asks", element_type: "Order", element_size: 48 },
            ],
        }
    }

    #[test]
    fn no_segments_for_fixed_layout() {
        let ts = ts_decoder(&vault_manifest());
        assert!(!ts.contains("SegmentDescriptor"));
        assert!(!ts.contains("decodeVaultSegments"));
    }

    #[test]
    fn emits_segment_descriptor_type() {
        let ts = ts_decoder(&segmented_manifest());
        assert!(ts.contains("export interface SegmentDescriptor {"));
        assert!(ts.contains("readSegmentDescriptor"));
    }

    #[test]
    fn emits_segments_interface() {
        let ts = ts_decoder(&segmented_manifest());
        assert!(ts.contains("export interface OrderBookSegments {"));
        assert!(ts.contains("bids: Uint8Array[];"));
        assert!(ts.contains("asks: Uint8Array[];"));
    }

    #[test]
    fn emits_segment_constants() {
        let ts = ts_decoder(&segmented_manifest());
        assert!(ts.contains("ORDERBOOK_SEGMENT_COUNT = 2"));
        assert!(ts.contains("ORDERBOOK_TABLE_OFFSET = 48"));
    }

    #[test]
    fn emits_segment_decoder_function() {
        let ts = ts_decoder(&segmented_manifest());
        assert!(ts.contains("export function decodeOrderBookSegments(data: Uint8Array): OrderBookSegments {"));
    }

    #[test]
    fn segment_decoder_validates_element_size() {
        let ts = ts_decoder(&segmented_manifest());
        assert!(ts.contains("bids_desc.elementSize !== 48"));
        assert!(ts.contains("asks_desc.elementSize !== 48"));
    }
}