use anyhow::Result;
use ordered_float::OrderedFloat;
use sapient_hub::model_info::ModelInfo;
use sapient_ir::{graph::Graph, op::OpType};
/// Builds the IR graph for a LLaMA-style causal language model described by `info`.
///
/// Nodes are added in this order:
/// 1. graph inputs `input_ids`, `attention_mask` and `position_ids`,
/// 2. the `embed_tokens` embedding fed by `input_ids`,
/// 3. `info.num_hidden_layers` decoder layers, chained one after another,
/// 4. the final `norm` RMS normalisation,
/// 5. the `lm_head` MatMul, whose result is marked as the `logits` output.
///
/// `attention_mask` and `position_ids` are registered as inputs but are not
/// consumed by any node created here.
///
/// # Errors
/// No error path is visible in this function; it returns `Ok` once the graph
/// has been assembled.
pub fn build(info: &ModelInfo) -> Result<Graph> {
    let mut graph = Graph::new(format!("llama_{}", info.model_type));

    // Register the inputs in a fixed order; only the token ids are wired up.
    let token_ids = graph.add_input("input_ids", None, None);
    let _attention_mask = graph.add_input("attention_mask", None, None);
    let _position_ids = graph.add_input("position_ids", None, None);

    let embedding = OpType::Embedding {
        vocab_size: info.vocab_size,
        dim: info.hidden_size,
    };
    let embedded = graph.add_op(embedding, vec![token_ids], 1, Some("embed_tokens".into()));

    // Thread the hidden state through every decoder layer, first to last.
    let hidden = (0..info.num_hidden_layers).fold(embedded, |state, layer_idx| {
        build_decoder_layer(&mut graph, state, info, layer_idx)
    });

    let final_norm = OpType::RmsNorm {
        epsilon: OrderedFloat(info.rms_norm_eps),
    };
    let normalized = graph.add_op(final_norm, vec![hidden], 1, Some("norm".into()));

    let lm_head = graph.add_op(OpType::MatMul, vec![normalized], 1, Some("lm_head".into()));
    graph.mark_output(lm_head, "logits");

    Ok(graph)
}
/// Appends one decoder layer to `g` and returns the node carrying its output.
///
/// Every node is named under the prefix `layers.{idx}`. The wiring is:
///
/// * attention half: `x` -> RMS norm -> q/k/v MatMuls; q and k each go through
///   a rotary embedding; the causal grouped-query attention consumes
///   (rotated q, rotated k, v); its result goes through `o_proj` and is added
///   back onto `x`.
/// * feed-forward half: the attention residual -> RMS norm -> `gate_proj` and
///   `up_proj`; SiLU is applied to the gate branch, multiplied with the up
///   branch, passed through `down_proj`, and added back onto the attention
///   residual.
///
/// Both RMS norms use `info.rms_norm_eps`.
fn build_decoder_layer(
    g: &mut Graph,
    x: sapient_ir::node::NodeId,
    info: &ModelInfo,
    idx: usize,
) -> sapient_ir::node::NodeId {
    // Adds a single-input MatMul node with the given name.
    fn project(
        graph: &mut Graph,
        src: sapient_ir::node::NodeId,
        name: String,
    ) -> sapient_ir::node::NodeId {
        graph.add_op(OpType::MatMul, vec![src], 1, Some(name))
    }

    let pfx = format!("layers.{idx}");
    let rms_norm = OpType::RmsNorm {
        epsilon: OrderedFloat(info.rms_norm_eps),
    };
    // Rotary op shared by the query and key branches.
    let rotary = || OpType::RotaryEmbedding {
        base: OrderedFloat(info.rope_theta),
        dim: info.head_dim,
    };

    // ---- self-attention half -------------------------------------------
    let normed_input = g.add_op(
        rms_norm,
        vec![x],
        1,
        Some(format!("{pfx}.input_layernorm")),
    );

    let query = project(g, normed_input, format!("{pfx}.self_attn.q_proj"));
    let key = project(g, normed_input, format!("{pfx}.self_attn.k_proj"));
    let value = project(g, normed_input, format!("{pfx}.self_attn.v_proj"));

    let query_rot = g.add_op(
        rotary(),
        vec![query],
        1,
        Some(format!("{pfx}.self_attn.q_rope")),
    );
    let key_rot = g.add_op(
        rotary(),
        vec![key],
        1,
        Some(format!("{pfx}.self_attn.k_rope")),
    );

    let attention = g.add_op(
        OpType::GroupedQueryAttention {
            n_heads: info.num_attention_heads,
            n_kv_heads: info.num_key_value_heads,
            head_dim: info.head_dim,
            causal: true,
        },
        vec![query_rot, key_rot, value],
        1,
        Some(format!("{pfx}.self_attn.gqa")),
    );
    let attention_proj = project(g, attention, format!("{pfx}.self_attn.o_proj"));

    let after_attn = g.add_op(
        OpType::Add,
        vec![x, attention_proj],
        1,
        Some(format!("{pfx}.attn_residual")),
    );

    // ---- feed-forward half ---------------------------------------------
    let normed_hidden = g.add_op(
        rms_norm,
        vec![after_attn],
        1,
        Some(format!("{pfx}.post_attention_layernorm")),
    );

    let gate_branch = project(g, normed_hidden, format!("{pfx}.mlp.gate_proj"));
    let up_branch = project(g, normed_hidden, format!("{pfx}.mlp.up_proj"));

    let gate_activated = g.add_op(
        OpType::Silu,
        vec![gate_branch],
        1,
        Some(format!("{pfx}.mlp.silu")),
    );
    let gated = g.add_op(
        OpType::Mul,
        vec![gate_activated, up_branch],
        1,
        Some(format!("{pfx}.mlp.gate_mul")),
    );
    let ffn_out = project(g, gated, format!("{pfx}.mlp.down_proj"));

    g.add_op(
        OpType::Add,
        vec![after_attn, ffn_out],
        1,
        Some(format!("{pfx}.ffn_residual")),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use sapient_hub::model_info::ModelInfo;

    // Small LLaMA configuration (2 layers, 4 attention heads, 2 KV heads) used
    // as a fixture for the graph-construction test below.
    const TINY_LLAMA_CFG: &str = r#"{
"architectures": ["LlamaForCausalLM"],
"model_type": "llama",
"vocab_size": 1000,
"hidden_size": 64,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"intermediate_size": 128,
"max_position_embeddings": 512,
"rms_norm_eps": 1e-5,
"hidden_act": "silu",
"rope_theta": 10000.0
}"#;

    /// Parsing the fixture and building its graph succeeds, yields more than
    /// ten nodes, and exposes exactly one output.
    #[test]
    fn tiny_llama_builds() {
        let model_info = ModelInfo::from_json_str(TINY_LLAMA_CFG).unwrap();
        let graph = build(&model_info).unwrap();

        let node_total = graph.node_count();
        assert!(node_total > 10, "graph should have many nodes");

        let output_total = graph.outputs.len();
        assert_eq!(output_total, 1, "should have one output (logits)");
    }
}