use crate::llama3::config::LLaMA3Config;
use crate::llama3::model::LLaMA3ForCausalLM;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
/// Result of a causal-LM forward pass through [`LLaMA3ChatModel`].
pub struct LLaMA3CausalLMOutput {
    /// Raw next-token scores from the underlying model. For `Tensor::F32`
    /// values, `greedy_next_token` treats the last axis as the vocabulary
    /// dimension and decodes from the final row.
    pub logits: Tensor,
}
/// Renders a conversation into the LLaMA 3 chat prompt format.
///
/// The output begins with `<|begin_of_text|>`, contains one
/// `<|start_header_id|>role<|end_header_id|>\n…<|eot_id|>\n` block per turn
/// (the system turn is emitted only when `system` is non-empty), and ends
/// with an open assistant header so the model continues as the assistant.
pub fn format_llama3_chat(system: &str, messages: &[(String, String)]) -> String {
    const BEGIN: &str = "<|begin_of_text|>";
    const START_HDR: &str = "<|start_header_id|>";
    const END_HDR: &str = "<|end_header_id|>";
    const EOT: &str = "<|eot_id|>";

    let mut prompt = String::from(BEGIN);

    // One closed turn: header, newline, content, end-of-turn marker, newline.
    let mut push_turn = |role: &str, content: &str| {
        prompt.push_str(&format!(
            "{}{}{}\n{}{}\n",
            START_HDR, role, END_HDR, content, EOT
        ));
    };

    if !system.is_empty() {
        push_turn("system", system);
    }
    for (role, content) in messages {
        push_turn(role, content);
    }

    // Open assistant header with no content: the generation point.
    prompt.push_str(&format!("{}assistant{}\n", START_HDR, END_HDR));
    prompt
}
/// Chat-oriented wrapper around [`LLaMA3ForCausalLM`] that pairs the model
/// with LLaMA 3 prompt formatting and a greedy-decoding helper.
pub struct LLaMA3ChatModel {
    /// The wrapped causal language model that produces the logits.
    inner: LLaMA3ForCausalLM,
}
impl LLaMA3ChatModel {
    /// Builds a chat model around a freshly constructed `LLaMA3ForCausalLM`.
    ///
    /// # Errors
    /// Propagates any error from the underlying model constructor.
    pub fn new(config: LLaMA3Config) -> Result<Self> {
        let inner = LLaMA3ForCausalLM::new(config)?;
        Ok(Self { inner })
    }

    /// Returns the configuration of the wrapped model.
    pub fn config(&self) -> &LLaMA3Config {
        self.inner.config()
    }

    /// Total parameter count of the wrapped model.
    pub fn parameter_count(&self) -> usize {
        self.inner.parameter_count()
    }

    /// Runs a forward pass over `input_ids`, wrapping the raw logits in
    /// [`LLaMA3CausalLMOutput`].
    ///
    /// # Errors
    /// Propagates any error from the underlying model's forward pass.
    pub fn forward(&self, input_ids: Vec<u32>) -> Result<LLaMA3CausalLMOutput> {
        let logits = self.inner.forward(input_ids)?;
        Ok(LLaMA3CausalLMOutput { logits })
    }

    /// Renders a chat prompt; thin delegate to [`format_llama3_chat`].
    pub fn build_prompt(&self, system_prompt: &str, messages: &[(String, String)]) -> String {
        format_llama3_chat(system_prompt, messages)
    }

    /// Picks the argmax token id from the final row of `logits`.
    ///
    /// For a tensor of rank >= 2 the last axis is taken as the vocabulary;
    /// a 1-D tensor is treated as a single row of logits. NaN logits compare
    /// as equal, so they never displace a finite maximum.
    pub fn greedy_next_token(&self, logits: &Tensor) -> Result<u32> {
        match logits {
            Tensor::F32(arr) => {
                let shape = arr.shape();
                // Vocabulary size: last axis for rank >= 2, whole tensor
                // otherwise. (The old `unwrap_or(&arr.len())` fallback was
                // dead code — `last()` is always `Some` when len >= 2.)
                let vocab_size = match shape.last() {
                    Some(&last) if shape.len() >= 2 => last,
                    _ => arr.len(),
                };
                // Scan only the final row lazily instead of copying the
                // entire tensor into a Vec just to slice off its tail.
                let start = arr.len().saturating_sub(vocab_size);
                let best = arr
                    .iter()
                    .copied()
                    .skip(start)
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .map(|(idx, _)| idx as u32)
                    .unwrap_or(0);
                Ok(best)
            }
            // NOTE(review): non-F32 tensors silently decode to token 0.
            // Kept for backward compatibility, but callers probably want an
            // error here — consider returning Err for unsupported dtypes.
            _ => Ok(0),
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests covering the prompt formatter (pure string checks) and the
    // chat-model wrapper (construction, forward pass, greedy decoding).
    // Model-backed tests use `LLaMA3Config::small_test()`, which is assumed
    // to be cheap enough to build per-test — TODO confirm in config module.
    use super::*;

    // --- Prompt-format tests: pin the exact LLaMA 3 chat template shape ---

    #[test]
    fn test_format_begins_with_begin_token() {
        let out = format_llama3_chat("", &[]);
        assert!(
            out.starts_with("<|begin_of_text|>"),
            "must start with begin token"
        );
    }

    #[test]
    fn test_format_no_system_block_when_empty() {
        let out = format_llama3_chat("", &[]);
        assert!(!out.contains("system"), "no system block when empty");
    }

    #[test]
    fn test_format_includes_system_block() {
        let out = format_llama3_chat("You are helpful.", &[]);
        assert!(out.contains("system"), "system role must appear");
        assert!(
            out.contains("You are helpful."),
            "system content must appear"
        );
        assert!(out.contains("<|eot_id|>"), "system block must end with eot");
    }

    #[test]
    fn test_format_user_message_present() {
        let msgs = vec![("user".to_string(), "Hello world".to_string())];
        let out = format_llama3_chat("", &msgs);
        assert!(out.contains("user"), "user role present");
        assert!(out.contains("Hello world"), "user content present");
    }

    #[test]
    fn test_format_ends_with_open_assistant_turn() {
        let msgs = vec![("user".to_string(), "Hi".to_string())];
        let out = format_llama3_chat("sys", &msgs);
        assert!(
            out.ends_with("<|end_header_id|>\n"),
            "must end with open assistant header"
        );
    }

    #[test]
    fn test_format_assistant_message() {
        let msgs = vec![
            ("user".to_string(), "question".to_string()),
            ("assistant".to_string(), "answer".to_string()),
        ];
        let out = format_llama3_chat("", &msgs);
        assert!(out.contains("question"), "user message in output");
        assert!(out.contains("answer"), "assistant message in output");
    }

    #[test]
    fn test_format_multiple_rounds() {
        let msgs = vec![
            ("user".to_string(), "turn 1".to_string()),
            ("assistant".to_string(), "reply 1".to_string()),
            ("user".to_string(), "turn 2".to_string()),
        ];
        let out = format_llama3_chat("sys", &msgs);
        // 5 headers = 1 system + 3 messages + 1 trailing assistant header.
        let count = out.matches("<|start_header_id|>").count();
        assert_eq!(count, 5, "expected 5 header openings, got {count}");
    }

    #[test]
    fn test_format_eot_tokens_count() {
        let msgs = vec![
            ("user".to_string(), "hello".to_string()),
            ("assistant".to_string(), "hi".to_string()),
        ];
        let out = format_llama3_chat("sys", &msgs);
        // 3 eot markers = system turn + 2 messages; the open assistant
        // header at the end deliberately has none.
        let eot_count = out.matches("<|eot_id|>").count();
        assert_eq!(eot_count, 3, "expected 3 eot tokens, got {eot_count}");
    }

    #[test]
    fn test_format_empty_messages_only_system_and_assistant() {
        let out = format_llama3_chat("Sys", &[]);
        assert!(out.contains("Sys"), "system content present");
        assert!(
            out.ends_with("<|end_header_id|>\n"),
            "trailing assistant header"
        );
    }

    #[test]
    fn test_format_deterministic() {
        let msgs = vec![("user".to_string(), "deterministic".to_string())];
        let a = format_llama3_chat("sys", &msgs);
        let b = format_llama3_chat("sys", &msgs);
        assert_eq!(a, b, "format must be deterministic");
    }

    // --- Config tests ---

    #[test]
    fn test_small_test_config_valid() {
        use trustformers_core::traits::Config;
        let cfg = LLaMA3Config::small_test();
        assert!(cfg.validate().is_ok(), "small_test config must be valid");
    }

    #[test]
    fn test_small_test_config_fields() {
        let cfg = LLaMA3Config::small_test();
        assert_eq!(cfg.hidden_size, 64);
        assert_eq!(cfg.num_hidden_layers, 2);
        assert!(cfg.vocab_size > 0);
    }

    // --- Model-wrapper tests ---

    #[test]
    fn test_chat_model_creation_small_config() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg);
        assert!(model.is_ok(), "model construction must succeed");
    }

    #[test]
    fn test_chat_model_parameter_count_nonzero() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg).unwrap_or_else(|_| panic!("model init failed"));
        assert!(model.parameter_count() > 0, "parameter count must be > 0");
    }

    #[test]
    fn test_chat_model_config_accessor() {
        let cfg = LLaMA3Config::small_test();
        let hidden = cfg.hidden_size;
        let model = LLaMA3ChatModel::new(cfg).unwrap_or_else(|_| panic!("model init failed"));
        assert_eq!(model.config().hidden_size, hidden);
    }

    #[test]
    fn test_build_prompt_matches_format_fn() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg).unwrap_or_else(|_| panic!("model init failed"));
        let msgs = vec![("user".to_string(), "test".to_string())];
        let via_model = model.build_prompt("system", &msgs);
        let direct = format_llama3_chat("system", &msgs);
        assert_eq!(
            via_model, direct,
            "build_prompt must match format_llama3_chat"
        );
    }

    #[test]
    fn test_llama3_8b_head_dim() {
        let cfg = LLaMA3Config::llama3_8b();
        let expected = cfg.hidden_size / cfg.num_attention_heads;
        assert_eq!(cfg.head_dim(), expected);
    }

    #[test]
    fn test_llama3_8b_uses_gqa() {
        let cfg = LLaMA3Config::llama3_8b();
        assert!(
            cfg.uses_gqa(),
            "8B model uses GQA (8 KV heads < 32 Q heads)"
        );
    }

    #[test]
    fn test_llama3_70b_query_groups() {
        let cfg = LLaMA3Config::llama3_70b();
        let expected = cfg.num_attention_heads / cfg.num_key_value_heads;
        assert_eq!(cfg.num_query_groups(), expected);
    }

    // --- Forward / decoding tests ---

    #[test]
    fn test_chat_model_forward_output_shape() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg.clone()).unwrap_or_else(|_| panic!("init failed"));
        let input_ids = vec![1u32, 2, 3];
        let output = model.forward(input_ids);
        assert!(output.is_ok(), "forward must succeed");
        let out = output.unwrap_or_else(|_| panic!("forward failed"));
        if let Tensor::F32(arr) = &out.logits {
            assert!(!arr.is_empty(), "logits must be non-empty");
        }
    }

    #[test]
    fn test_greedy_next_token_within_vocab() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg.clone()).unwrap_or_else(|_| panic!("init failed"));
        let input_ids = vec![1u32, 2];
        if let Ok(out) = model.forward(input_ids) {
            if let Ok(tok) = model.greedy_next_token(&out.logits) {
                assert!(
                    (tok as usize) < cfg.vocab_size,
                    "token must be within vocab"
                );
            }
        }
    }

    #[test]
    fn test_greedy_next_token_picks_max() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg).unwrap_or_else(|_| panic!("init failed"));
        // 1-D tensor: the whole vector is treated as one row of logits.
        let logits_vec = vec![0.1f32, 0.2, 0.1, 5.0, 0.1];
        let tensor =
            Tensor::from_vec(logits_vec, &[5]).unwrap_or_else(|_| panic!("tensor creation failed"));
        let tok = model.greedy_next_token(&tensor).unwrap_or(0);
        assert_eq!(tok, 3u32, "greedy must pick index 3 (highest logit)");
    }

    #[test]
    fn test_format_system_special_chars() {
        let sys = "Role: AI <&> entity\nLine2";
        let out = format_llama3_chat(sys, &[]);
        assert!(out.contains(sys), "system content must be verbatim");
    }

    #[test]
    fn test_forward_output_finite() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg).unwrap_or_else(|_| panic!("init failed"));
        let input_ids = vec![5u32];
        if let Ok(out) = model.forward(input_ids) {
            if let Tensor::F32(arr) = &out.logits {
                for &v in arr.iter() {
                    assert!(v.is_finite(), "logit {v} is not finite");
                }
            }
        }
    }

    #[test]
    fn test_causal_lm_output_logits_accessible() {
        let cfg = LLaMA3Config::small_test();
        let model = LLaMA3ChatModel::new(cfg).unwrap_or_else(|_| panic!("init failed"));
        if let Ok(out) = model.forward(vec![1u32]) {
            let _ = &out.logits;
        }
    }
}