use crate::claude::config::ClaudeConfig;
use std::collections::HashMap;
use trustformers_core::{
errors::Result,
layers::{Embedding, LayerNorm, Linear},
ops::activations::silu,
tensor::Tensor,
traits::{Layer, Model},
};
pub struct ClaudeAttention {
config: ClaudeConfig,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
rotary_emb: RotaryEmbedding,
attention_dropout: f32,
scale: f32,
}
impl ClaudeAttention {
pub fn new(config: ClaudeConfig) -> Result<Self> {
let hidden_size = config.hidden_size;
let num_heads = config.num_attention_heads;
let num_kv_heads = config.num_kv_heads();
let head_dim = config.head_dim();
let q_proj = Linear::new(hidden_size, num_heads * head_dim, false);
let k_proj = Linear::new(hidden_size, num_kv_heads * head_dim, false);
let v_proj = Linear::new(hidden_size, num_kv_heads * head_dim, false);
let o_proj = Linear::new(num_heads * head_dim, hidden_size, false);
let rotary_emb =
RotaryEmbedding::new(head_dim, config.max_position_embeddings, config.rope_theta);
let scale = 1.0 / (head_dim as f32).sqrt();
Ok(Self {
attention_dropout: config.attention_dropout,
config,
q_proj,
k_proj,
v_proj,
o_proj,
rotary_emb,
scale,
})
}
}
impl Layer for ClaudeAttention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, hidden_states: Self::Input) -> Result<Self::Output> {
let seq_len = hidden_states.shape()[1];
let batch_size = hidden_states.shape()[0];
let query_states = self.q_proj.forward(hidden_states.clone())?;
let key_states = self.k_proj.forward(hidden_states.clone())?;
let value_states = self.v_proj.forward(hidden_states)?;
let query_states = query_states.reshape(&[
batch_size,
seq_len,
self.config.num_attention_heads,
self.config.head_dim(),
])?;
let key_states = key_states.reshape(&[
batch_size,
seq_len,
self.config.num_kv_heads(),
self.config.head_dim(),
])?;
let value_states = value_states.reshape(&[
batch_size,
seq_len,
self.config.num_kv_heads(),
self.config.head_dim(),
])?;
let position_ids: Vec<usize> = (0..seq_len).collect();
let (query_states, key_states) =
self.rotary_emb.apply_rotary_emb(&query_states, &key_states, &position_ids)?;
let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
let attn_weights = attn_weights.mul_scalar(self.scale)?;
let causal_mask = create_causal_mask(seq_len)?;
let mask_value = Tensor::from_vec(vec![f32::NEG_INFINITY], &[1])?;
let attn_weights = attn_weights.add(&causal_mask.mul(&mask_value)?)?;
let attn_weights = attn_weights.softmax(3)?;
let attn_weights = if self.attention_dropout > 0.0 {
attn_weights.dropout(self.attention_dropout)?
} else {
attn_weights
};
let attn_output = attn_weights.matmul(&value_states)?;
let attn_output = attn_output.reshape(&[batch_size, seq_len, self.config.hidden_size])?;
let attn_output = self.o_proj.forward(attn_output)?;
Ok(attn_output)
}
}
impl ClaudeAttention {
pub fn parameter_count(&self) -> usize {
self.q_proj.parameter_count()
+ self.k_proj.parameter_count()
+ self.v_proj.parameter_count()
+ self.o_proj.parameter_count()
}
}
pub struct ClaudeMLP {
#[allow(dead_code)]
config: ClaudeConfig,
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
dropout: f32,
}
impl ClaudeMLP {
pub fn new(config: ClaudeConfig) -> Result<Self> {
let hidden_size = config.hidden_size;
let intermediate_size = config.intermediate_size;
let gate_proj = Linear::new(hidden_size, intermediate_size, false);
let up_proj = Linear::new(hidden_size, intermediate_size, false);
let down_proj = Linear::new(intermediate_size, hidden_size, false);
Ok(Self {
dropout: config.ffn_dropout,
config,
gate_proj,
up_proj,
down_proj,
})
}
}
impl Layer for ClaudeMLP {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, hidden_states: Self::Input) -> Result<Self::Output> {
let gate_output = self.gate_proj.forward(hidden_states.clone())?;
let up_output = self.up_proj.forward(hidden_states)?;
let gate_output = silu(&gate_output)?;
let intermediate = gate_output.mul(&up_output)?;
let intermediate = if self.dropout > 0.0 {
intermediate.dropout(self.dropout)?
} else {
intermediate
};
let output = self.down_proj.forward(intermediate)?;
Ok(output)
}
}
impl ClaudeMLP {
pub fn parameter_count(&self) -> usize {
self.gate_proj.parameter_count()
+ self.up_proj.parameter_count()
+ self.down_proj.parameter_count()
}
}
pub struct ClaudeDecoderLayer {
#[allow(dead_code)]
config: ClaudeConfig,
self_attn: ClaudeAttention,
mlp: ClaudeMLP,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
#[allow(dead_code)]
constitutional_ai: bool,
}
impl ClaudeDecoderLayer {
pub fn new(config: ClaudeConfig) -> Result<Self> {
let self_attn = ClaudeAttention::new(config.clone())?;
let mlp = ClaudeMLP::new(config.clone())?;
let input_layernorm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
let post_attention_layernorm =
LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
Ok(Self {
constitutional_ai: config.constitutional_ai,
config,
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
}
impl Layer for ClaudeDecoderLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, hidden_states: Self::Input) -> Result<Self::Output> {
let residual = hidden_states.clone();
let hidden_states = self.input_layernorm.forward(hidden_states)?;
let attn_output = self.self_attn.forward(hidden_states)?;
let hidden_states = residual.add(&attn_output)?;
let residual = hidden_states.clone();
let hidden_states = self.post_attention_layernorm.forward(hidden_states)?;
let mlp_output = self.mlp.forward(hidden_states)?;
let hidden_states = residual.add(&mlp_output)?;
Ok(hidden_states)
}
}
impl ClaudeDecoderLayer {
pub fn parameter_count(&self) -> usize {
self.self_attn.parameter_count()
+ self.mlp.parameter_count()
+ self.input_layernorm.parameter_count()
+ self.post_attention_layernorm.parameter_count()
}
}
pub struct ClaudeModel {
config: ClaudeConfig,
embed_tokens: Embedding,
layers: Vec<ClaudeDecoderLayer>,
norm: LayerNorm,
constitutional_weights: Option<HashMap<String, f32>>,
}
impl ClaudeModel {
pub fn new(config: ClaudeConfig) -> Result<Self> {
let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(ClaudeDecoderLayer::new(config.clone())?);
}
let norm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
let constitutional_weights = if config.constitutional_ai {
let mut weights = HashMap::new();
weights.insert("harmlessness".to_string(), config.harmlessness_weight);
weights.insert("helpfulness".to_string(), config.helpfulness_weight);
weights.insert("honesty".to_string(), config.honesty_weight);
Some(weights)
} else {
None
};
Ok(Self {
config,
embed_tokens,
layers,
norm,
constitutional_weights,
})
}
pub fn apply_constitutional_ai(&self, hidden_states: &Tensor) -> Result<Tensor> {
if let Some(weights) = &self.constitutional_weights {
let mut result = hidden_states.clone();
if let Some(&harmlessness_weight) = weights.get("harmlessness") {
result = result.mul_scalar(harmlessness_weight)?;
}
if let Some(&helpfulness_weight) = weights.get("helpfulness") {
result = result.mul_scalar(helpfulness_weight)?;
}
if let Some(&honesty_weight) = weights.get("honesty") {
result = result.mul_scalar(honesty_weight)?;
}
Ok(result)
} else {
Ok(hidden_states.clone())
}
}
}
impl Model for ClaudeModel {
type Config = ClaudeConfig;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input_ids: Self::Input) -> Result<Self::Output> {
let input_ids_vec: Vec<u32> =
input_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
let mut hidden_states = self.embed_tokens.forward(input_ids_vec)?;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states)?;
}
hidden_states = self.norm.forward(hidden_states)?;
hidden_states = self.apply_constitutional_ai(&hidden_states)?;
Ok(hidden_states)
}
fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
use trustformers_core::errors::invalid_input;
let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.map_err(|e| invalid_input(format!("Failed to read Claude weights: {}", e)))?;
if buffer.is_empty() {
return Err(invalid_input("Claude weight file is empty"));
}
if buffer.len() < 1000 {
return Err(invalid_input(
"Weight file appears to be too small or corrupted",
));
}
let expected_layers = self.config.num_hidden_layers;
let expected_hidden_size = self.config.hidden_size;
let vocab_size = self.config.vocab_size;
let embed_weight_size = vocab_size * expected_hidden_size * 4; if buffer.len() < embed_weight_size {
return Err(invalid_input(format!(
"Insufficient weights for embeddings. Expected: {}, Available: {}",
embed_weight_size,
buffer.len()
)));
}
let layer_weight_size_estimate = expected_hidden_size * expected_hidden_size * 4 * 4; let total_layer_weights = expected_layers * layer_weight_size_estimate;
let norm_weight_size = expected_hidden_size * 4;
let total_required = embed_weight_size + total_layer_weights + norm_weight_size;
if buffer.len() < total_required / 10 {
return Err(invalid_input(format!(
"Weight file appears incomplete. Expected roughly: {}, Got: {}",
total_required / 10,
buffer.len()
)));
}
Ok(())
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let embed_params = self.embed_tokens.parameter_count();
let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
let norm_params = self.norm.parameter_count();
embed_params + layers_params + norm_params
}
}
pub struct ClaudeForCausalLM {
model: ClaudeModel,
lm_head: Linear,
config: ClaudeConfig,
}
impl ClaudeForCausalLM {
pub fn new(config: ClaudeConfig) -> Result<Self> {
let model = ClaudeModel::new(config.clone())?;
let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
Ok(Self {
model,
lm_head,
config,
})
}
pub fn generate_with_constitutional_ai(
&self,
input_ids: Tensor,
max_new_tokens: usize,
temperature: f32,
top_p: f32,
) -> Result<Tensor> {
let mut current_ids = input_ids;
for _ in 0..max_new_tokens {
let hidden_states = self.model.forward(current_ids.clone())?;
let logits = self.lm_head.forward(hidden_states)?;
let logits = logits.div_scalar(temperature)?;
let probs = logits.softmax(-1)?;
let next_token = sample_from_distribution(&probs, top_p)?;
current_ids = Tensor::concat(&[current_ids, next_token], 0)?;
}
Ok(current_ids)
}
}
impl Model for ClaudeForCausalLM {
type Config = ClaudeConfig;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input_ids: Self::Input) -> Result<Self::Output> {
let hidden_states = self.model.forward(input_ids)?;
let logits = self.lm_head.forward(hidden_states)?;
Ok(logits)
}
fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
use trustformers_core::errors::invalid_input;
let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.map_err(|e| invalid_input(format!("Failed to read Claude weights: {}", e)))?;
if buffer.is_empty() {
return Err(invalid_input("Claude weight file is empty"));
}
if buffer.len() < 1000 {
return Err(invalid_input(
"Weight file appears to be too small or corrupted",
));
}
let expected_layers = self.config.num_hidden_layers;
let expected_hidden_size = self.config.hidden_size;
let vocab_size = self.config.vocab_size;
let embed_weight_size = vocab_size * expected_hidden_size * 4; if buffer.len() < embed_weight_size {
return Err(invalid_input(format!(
"Insufficient weights for embeddings. Expected: {}, Available: {}",
embed_weight_size,
buffer.len()
)));
}
let layer_weight_size_estimate = expected_hidden_size * expected_hidden_size * 4 * 4; let total_layer_weights = expected_layers * layer_weight_size_estimate;
let norm_weight_size = expected_hidden_size * 4;
let total_required = embed_weight_size + total_layer_weights + norm_weight_size;
if buffer.len() < total_required / 10 {
return Err(invalid_input(format!(
"Weight file appears incomplete. Expected roughly: {}, Got: {}",
total_required / 10,
buffer.len()
)));
}
Ok(())
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
self.model.num_parameters() + self.lm_head.parameter_count()
}
}
pub struct RotaryEmbedding {
#[allow(dead_code)]
dim: usize,
#[allow(dead_code)]
max_seq_len: usize,
#[allow(dead_code)]
base: f32,
}
impl RotaryEmbedding {
pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Self {
Self {
dim,
max_seq_len,
base,
}
}
pub fn apply_rotary_emb(
&self,
q: &Tensor,
k: &Tensor,
_position_ids: &[usize],
) -> Result<(Tensor, Tensor)> {
Ok((q.clone(), k.clone()))
}
}
fn create_causal_mask(seq_len: usize) -> Result<Tensor> {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in (i + 1)..seq_len {
mask_data[i * seq_len + j] = 1.0; }
}
Tensor::from_vec(mask_data, &[seq_len, seq_len])
}
fn sample_from_distribution(_probs: &Tensor, _top_p: f32) -> Result<Tensor> {
Tensor::zeros(&[1, 1])
}
#[cfg(test)]
mod tests {
use super::*;
use crate::claude::config::ClaudeConfig;
use trustformers_core::{
tensor::Tensor,
traits::{Config, Model},
};
#[test]
fn test_default_config_vocab_size() {
let cfg = ClaudeConfig::default();
assert_eq!(
cfg.vocab_size, 100352,
"Claude default vocab_size should be 100352"
);
}
#[test]
fn test_default_config_constitutional_ai_enabled() {
let cfg = ClaudeConfig::default();
assert!(
cfg.constitutional_ai,
"Constitutional AI should be enabled by default"
);
}
#[test]
fn test_default_harmlessness_weight() {
let cfg = ClaudeConfig::default();
assert!(
(cfg.harmlessness_weight - 1.0).abs() < 1e-4,
"default harmlessness_weight = 1.0"
);
}
#[test]
fn test_default_helpfulness_weight() {
let cfg = ClaudeConfig::default();
assert!(
(cfg.helpfulness_weight - 1.0).abs() < 1e-4,
"default helpfulness_weight = 1.0"
);
}
#[test]
fn test_default_honesty_weight() {
let cfg = ClaudeConfig::default();
assert!(
(cfg.honesty_weight - 1.0).abs() < 1e-4,
"default honesty_weight = 1.0"
);
}
#[test]
fn test_small_test_config_valid() {
ClaudeConfig::small_test_config()
.validate()
.expect("small_test_config should be valid");
}
#[test]
fn test_config_validate_negative_weight_fails() {
let mut cfg = ClaudeConfig::small_test_config();
cfg.harmlessness_weight = -0.1;
assert!(
cfg.validate().is_err(),
"negative constitutional weight should fail validation"
);
}
#[test]
fn test_head_dim_computation() {
let cfg = ClaudeConfig::small_test_config();
assert_eq!(
cfg.head_dim(),
cfg.hidden_size / cfg.num_attention_heads,
"head_dim = hidden_size / num_attention_heads"
);
}
#[test]
fn test_num_kv_heads_defaults_to_num_attention_heads() {
let mut cfg = ClaudeConfig::small_test_config();
cfg.num_key_value_heads = None;
assert_eq!(
cfg.num_kv_heads(),
cfg.num_attention_heads,
"when num_kv_heads is None, num_kv_heads() == num_attention_heads"
);
}
#[test]
fn test_num_query_groups_with_gqa() {
let cfg = ClaudeConfig::claude_3_haiku(); let expected = cfg.num_attention_heads / cfg.num_kv_heads();
assert_eq!(
cfg.num_query_groups(),
expected,
"num_query_groups = heads / kv_heads"
);
}
#[test]
fn test_model_creation() {
let cfg = ClaudeConfig::small_test_config();
ClaudeModel::new(cfg).expect("ClaudeModel creation should succeed");
}
#[test]
fn test_model_parameter_count_nonzero() {
let cfg = ClaudeConfig::small_test_config();
let model = ClaudeModel::new(cfg).expect("model creation should succeed");
assert!(
model.num_parameters() > 0,
"model must have non-zero parameters"
);
}
#[test]
fn test_constitutional_ai_apply_preserves_shape() {
let cfg = ClaudeConfig::small_test_config();
let model = ClaudeModel::new(cfg.clone()).expect("model creation should succeed");
let hidden = Tensor::from_vec(vec![0.5_f32; cfg.hidden_size], &[cfg.hidden_size])
.expect("tensor creation should succeed");
let result = model
.apply_constitutional_ai(&hidden)
.expect("constitutional AI should succeed");
assert_eq!(result.shape(), hidden.shape(), "shape must be preserved");
}
#[test]
fn test_constitutional_ai_disabled_returns_unchanged() {
let mut cfg = ClaudeConfig::small_test_config();
cfg.constitutional_ai = false;
let model = ClaudeModel::new(cfg.clone()).expect("model creation should succeed");
let hidden = Tensor::from_vec(vec![1.0_f32; cfg.hidden_size], &[cfg.hidden_size])
.expect("tensor creation should succeed");
let result = model.apply_constitutional_ai(&hidden).expect("should succeed");
let orig_vals = hidden.to_vec_f32().expect("to_vec_f32 should succeed");
let result_vals = result.to_vec_f32().expect("to_vec_f32 should succeed");
for (o, r) in orig_vals.iter().zip(result_vals.iter()) {
assert!(
(o - r).abs() < 1e-5,
"disabled CAI must return input unchanged"
);
}
}
#[test]
fn test_claude_3_haiku_smaller_than_opus() {
let haiku = ClaudeConfig::claude_3_haiku();
let opus = ClaudeConfig::claude_3_opus();
assert!(
haiku.hidden_size < opus.hidden_size,
"Haiku should have smaller hidden_size than Opus"
);
}
#[test]
fn test_from_pretrained_name_valid() {
let cfg = ClaudeConfig::from_pretrained_name("claude-2");
assert!(cfg.is_some(), "claude-2 should be a known pretrained name");
}
#[test]
fn test_from_pretrained_name_unknown_returns_none() {
let cfg = ClaudeConfig::from_pretrained_name("unknown-model-xyz");
assert!(cfg.is_none(), "unknown model name should return None");
}
}