use crate::falcon::config::FalconConfig;
use scirs2_core::ndarray::{s, ArrayD, IxDyn}; use std::io::Read;
use trustformers_core::{
device::Device,
errors::{tensor_op_error, Result, TrustformersError},
layers::{Embedding, LayerNorm, Linear},
ops::activations::{gelu, silu},
tensor::Tensor,
traits::{Config, Layer, Model},
};
/// ALiBi (Attention with Linear Biases) positional bias module, used by
/// Falcon variants that enable `config.alibi`.
pub struct ALiBi {
    // Per-head bias slopes (one f32 per attention head).
    slopes: Tensor,
    // Number of attention heads the slopes were generated for.
    num_heads: usize,
    // Device this module targets.
    device: Device,
}
impl ALiBi {
    /// Creates ALiBi slopes for `num_heads` heads on the CPU.
    pub fn new(num_heads: usize) -> Result<Self> {
        Self::new_with_device(num_heads, Device::CPU)
    }

    /// Creates ALiBi slopes for `num_heads` heads on the given device.
    ///
    /// Each head gets a geometric slope derived from `2^(-8/num_heads)`.
    /// For an even head count the slopes use odd exponents (1, 3, 5, ...)
    /// for the first half of the heads and even exponents (2, 4, 6, ...)
    /// for the second half; for an odd count, consecutive exponents 1..=n.
    pub fn new_with_device(num_heads: usize, device: Device) -> Result<Self> {
        let ratio = 2.0_f32.powf(-8.0 / num_heads as f32);
        let mut slopes = Vec::with_capacity(num_heads);
        if num_heads.is_multiple_of(2) {
            for i in 0..num_heads / 2 {
                slopes.push(ratio.powf((2 * i + 1) as f32));
            }
            for i in 0..num_heads / 2 {
                slopes.push(ratio.powf((2 * i + 2) as f32));
            }
        } else {
            for i in 0..num_heads {
                slopes.push(ratio.powf((i + 1) as f32));
            }
        }
        let slopes_tensor = Tensor::new(slopes)?;
        Ok(Self {
            slopes: slopes_tensor,
            num_heads,
            device,
        })
    }

    /// Device this module was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Adds the per-head linear distance bias to `attention_scores`.
    ///
    /// For each head, future positions (j > i) receive a large negative
    /// bias and past positions receive `-(i - j) * slope`. The bias tensor
    /// has shape `[num_heads, seq_len, seq_len]`.
    pub fn apply_bias(&self, attention_scores: &Tensor, seq_len: usize) -> Result<Tensor> {
        // Fetch the slope data once instead of once per (head, i, j) triple.
        let slopes = self.slopes.data().unwrap_or_default();
        let mut bias_data = Vec::with_capacity(self.num_heads * seq_len * seq_len);
        for head_idx in 0..self.num_heads {
            // Neutral slope of 1.0 if the slope data is missing/short.
            let slope = slopes.get(head_idx).copied().unwrap_or(1.0);
            for i in 0..seq_len {
                for j in 0..seq_len {
                    if j > i {
                        bias_data.push(-10000.0);
                    } else {
                        bias_data.push(-((i - j) as f32) * slope);
                    }
                }
            }
        }
        // BUG FIX: bias_data holds num_heads * seq_len * seq_len values, so
        // the tensor shape must be [num_heads, seq_len, seq_len]; the
        // previous [seq_len, seq_len] shape could not hold per-head data.
        let bias_tensor = Tensor::from_vec(bias_data, &[self.num_heads, seq_len, seq_len])?;
        attention_scores.add(&bias_tensor)
    }
}
/// Falcon multi-head self-attention with optional multi-query/grouped K/V
/// heads and optional ALiBi positional bias.
pub struct FalconAttention {
    // Query projection: hidden -> num_heads * head_dim.
    q_proj: Linear,
    // Key projection: hidden -> num_kv_heads * head_dim.
    k_proj: Linear,
    // Value projection: hidden -> num_kv_heads * head_dim.
    v_proj: Linear,
    // Output projection back to the hidden size.
    dense: Linear,
    // ALiBi bias module; present only when the config enables ALiBi.
    alibi: Option<ALiBi>,
    // Number of query heads.
    num_heads: usize,
    // Number of key/value heads (1 for multi-query attention).
    num_kv_heads: usize,
    // Dimension of each attention head.
    head_dim: usize,
    // Dropout probability from the config; not used in forward() yet.
    #[allow(dead_code)]
    attention_dropout: f32,
    // Flash-attention flag from the config; not used in forward() yet.
    #[allow(dead_code)]
    use_flash_attention: bool,
    // Device this module targets.
    device: Device,
}
impl FalconAttention {
    /// Builds the attention module on the CPU.
    pub fn new(config: &FalconConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds the attention module on `device`.
    ///
    /// K and V projections are sized for the (possibly smaller) number of
    /// key/value heads; an ALiBi module is created only when enabled.
    pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
        let head_dim = config.head_dim();
        let num_kv_heads = config.num_kv_heads();
        let q_out = config.num_attention_heads * head_dim;
        let kv_out = num_kv_heads * head_dim;
        let alibi = if config.alibi {
            Some(ALiBi::new_with_device(config.num_attention_heads, device)?)
        } else {
            None
        };
        Ok(Self {
            q_proj: Linear::new(config.hidden_size, q_out, config.bias),
            k_proj: Linear::new(config.hidden_size, kv_out, config.bias),
            v_proj: Linear::new(config.hidden_size, kv_out, config.bias),
            dense: Linear::new(q_out, config.hidden_size, config.bias),
            alibi,
            num_heads: config.num_attention_heads,
            num_kv_heads,
            head_dim,
            attention_dropout: config.attention_dropout,
            use_flash_attention: config.use_flash_attention.unwrap_or(false),
            device,
        })
    }

    /// Device this module was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Builds a `[seq_len, seq_len]` additive causal mask: 0 on and below
    /// the diagonal, -inf strictly above it.
    fn create_causal_mask(&self, seq_len: usize) -> Result<Tensor> {
        let mask_data: Vec<f32> = (0..seq_len)
            .flat_map(|i| {
                (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 })
            })
            .collect();
        Tensor::from_vec(mask_data, &[seq_len, seq_len])
    }

    /// Total trainable parameters across the four projections.
    pub fn parameter_count(&self) -> usize {
        [
            self.q_proj.parameter_count(),
            self.k_proj.parameter_count(),
            self.v_proj.parameter_count(),
            self.dense.parameter_count(),
        ]
        .iter()
        .sum()
    }
}
impl Layer for FalconAttention {
    type Input = Tensor;
    type Output = Tensor;

    /// Multi-head self-attention with a causal mask, optional grouped-query
    /// K/V head repetition, and optional ALiBi position bias.
    ///
    /// Expects `input` shaped [batch, seq, hidden]; returns the same shape.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let batch_size = input.shape()[0];
        let seq_len = input.shape()[1];

        // Project to Q/K/V and split into heads: [batch, heads, seq, head_dim].
        let q = self.q_proj.forward(input.clone())?;
        let k = self.k_proj.forward(input.clone())?;
        let v = self.v_proj.forward(input)?;
        let q = q.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim])?;
        let k = k.reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])?;
        let v = v.reshape(&[batch_size, seq_len, self.num_kv_heads, self.head_dim])?;
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;

        // Grouped-query attention: repeat each K/V head `repeats` times so
        // every query head has a matching key/value head.
        let (k, v) = if self.num_kv_heads < self.num_heads {
            let repeats = self.num_heads / self.num_kv_heads;
            let mut k_heads = Vec::new();
            let mut v_heads = Vec::new();
            for head_idx in 0..self.num_kv_heads {
                let k_head = k.slice_multi(&[
                    (0, batch_size),
                    (head_idx, head_idx + 1),
                    (0, seq_len),
                    (0, self.head_dim),
                ])?;
                let v_head = v.slice_multi(&[
                    (0, batch_size),
                    (head_idx, head_idx + 1),
                    (0, seq_len),
                    (0, self.head_dim),
                ])?;
                for _ in 0..repeats {
                    k_heads.push(k_head.clone());
                    v_heads.push(v_head.clone());
                }
            }
            let k_repeated = Tensor::concat(&k_heads, 1)?;
            let v_repeated = Tensor::concat(&v_heads, 1)?;
            (k_repeated, v_repeated)
        } else {
            (k, v)
        };

        // Scaled dot-product scores: [batch, heads, seq, seq].
        let k_transposed = k.transpose(2, 3)?;
        let scores = q.matmul(&k_transposed)?;
        let scale = (self.head_dim as f32).sqrt();
        let mut scaled_scores = scores.div_scalar(scale)?;

        // BUG FIX: ALiBi is a position bias on the *pre-softmax attention
        // scores* (see `ALiBi::apply_bias`'s `attention_scores` parameter);
        // it was previously applied to the post-attention hidden states,
        // which is mathematically wrong and dimensionally incompatible.
        if let Some(alibi) = &self.alibi {
            scaled_scores = alibi.apply_bias(&scaled_scores, seq_len)?;
        }

        let causal_mask = self.create_causal_mask(seq_len)?;
        let masked_scores = scaled_scores.add(&causal_mask)?;
        let attention_weights = masked_scores.softmax(-1)?;

        // Weighted sum of values, then merge heads back to [batch, seq, hidden].
        let attention_output = attention_weights.matmul(&v)?;
        let attention_output = attention_output.transpose(1, 2)?;
        let attention_output =
            attention_output.reshape(&[batch_size, seq_len, self.num_heads * self.head_dim])?;
        self.dense.forward(attention_output)
    }
}
/// Falcon feed-forward block: hidden -> 4*hidden -> hidden with a
/// configurable activation in between.
pub struct FalconMLP {
    // Up-projection to the 4x intermediate size.
    dense_h_to_4h: Linear,
    // Down-projection back to the hidden size.
    dense_4h_to_h: Linear,
    // Activation name from the config ("gelu", "relu", "silu"/"swish").
    activation: String,
    // Device this module targets.
    device: Device,
}
impl FalconMLP {
    /// Builds the MLP on the CPU.
    pub fn new(config: &FalconConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds the MLP on `device`, with a 4x-hidden intermediate size.
    pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
        let inner = config.hidden_size * 4;
        Ok(Self {
            dense_h_to_4h: Linear::new(config.hidden_size, inner, config.bias),
            dense_4h_to_h: Linear::new(inner, config.hidden_size, config.bias),
            activation: config.hidden_act.clone(),
            device,
        })
    }

    /// Device this module was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Total trainable parameters of both projections.
    pub fn parameter_count(&self) -> usize {
        self.dense_4h_to_h.parameter_count() + self.dense_h_to_4h.parameter_count()
    }
}
impl Layer for FalconMLP {
    type Input = Tensor;
    type Output = Tensor;

    /// Up-projects, applies the configured activation, and down-projects.
    ///
    /// An unrecognized activation name passes the hidden state through
    /// unchanged (identity), preserving the existing best-effort behavior.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let hidden = self.dense_h_to_4h.forward(input)?;
        let name = self.activation.as_str();
        let activated = if name == "gelu" {
            gelu(&hidden)?
        } else if name == "relu" {
            hidden.relu()?
        } else if name == "silu" || name == "swish" {
            silu(&hidden)?
        } else {
            hidden
        };
        self.dense_4h_to_h.forward(activated)
    }
}
/// One Falcon decoder layer: layernorm, self-attention, and MLP, wired
/// either in parallel (Falcon-style) or sequentially.
pub struct FalconDecoderLayer {
    // Layernorm applied before attention (also reused before the MLP in the
    // sequential path — see Layer::forward).
    input_layernorm: LayerNorm,
    self_attention: FalconAttention,
    mlp: FalconMLP,
    // When true, attention and MLP both consume the same normed input and
    // share one residual connection.
    parallel_attn: bool,
    // When true, the residual uses the layernormed input rather than the
    // raw input (parallel path only).
    apply_residual_connection_post_layernorm: bool,
    // Device this layer targets.
    device: Device,
}
impl FalconDecoderLayer {
    /// Builds a decoder layer on the CPU.
    pub fn new(config: &FalconConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds a decoder layer (layernorm + attention + MLP) on `device`.
    pub fn new_with_device(config: &FalconConfig, device: Device) -> Result<Self> {
        Ok(Self {
            input_layernorm: LayerNorm::new(vec![config.hidden_size], config.layer_norm_epsilon)?,
            self_attention: FalconAttention::new_with_device(config, device)?,
            mlp: FalconMLP::new_with_device(config, device)?,
            parallel_attn: config.parallel_attn,
            apply_residual_connection_post_layernorm: config
                .apply_residual_connection_post_layernorm,
            device,
        })
    }

    /// Device this layer was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Total parameters across the layernorm, attention, and MLP sub-modules.
    pub fn parameter_count(&self) -> usize {
        let ln = self.input_layernorm.parameter_count();
        let attn = self.self_attention.parameter_count();
        let mlp = self.mlp.parameter_count();
        ln + attn + mlp
    }
}
impl Layer for FalconDecoderLayer {
    type Input = Tensor;
    type Output = Tensor;

    /// Runs the decoder layer.
    ///
    /// Parallel mode (Falcon-style): attention and MLP both consume the same
    /// layernormed input and are summed onto a single residual. Sequential
    /// mode: classic post-attention residual followed by the MLP block.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let normed = self.input_layernorm.forward(input.clone())?;
        if self.parallel_attn {
            let attn_out = self.self_attention.forward(normed.clone())?;
            let mlp_out = self.mlp.forward(normed.clone())?;
            let residual = if self.apply_residual_connection_post_layernorm {
                normed
            } else {
                input
            };
            residual.add(&attn_out)?.add(&mlp_out)
        } else {
            let attn_out = self.self_attention.forward(normed)?;
            let after_attn = input.add(&attn_out)?;
            // NOTE(review): the sequential path reuses `input_layernorm` for
            // the MLP branch; reference Falcon uses a separate
            // post-attention layernorm here — confirm this is intended.
            let normed_again = self.input_layernorm.forward(after_attn.clone())?;
            let mlp_out = self.mlp.forward(normed_again)?;
            after_attn.add(&mlp_out)
        }
    }
}
/// Base Falcon transformer: token embeddings, decoder stack, and a final
/// layernorm. Produces hidden states, not logits.
pub struct FalconModel {
    word_embeddings: Embedding,
    layers: Vec<FalconDecoderLayer>,
    // Final layernorm applied after the decoder stack.
    ln_f: LayerNorm,
    config: FalconConfig,
    device: Device,
}
impl FalconModel {
    /// Builds the full Falcon transformer on the CPU.
    pub fn new(config: FalconConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Validates `config`, then builds embeddings, the decoder stack, and
    /// the final layernorm on `device`.
    pub fn new_with_device(config: FalconConfig, device: Device) -> Result<Self> {
        config.validate()?;
        let word_embeddings = Embedding::new(
            config.vocab_size,
            config.hidden_size,
            config.pad_token_id.map(|id| id as usize),
        )?;
        let layers = (0..config.num_hidden_layers)
            .map(|_| FalconDecoderLayer::new_with_device(&config, device))
            .collect::<Result<Vec<_>>>()?;
        let ln_f = LayerNorm::new(vec![config.hidden_size], config.layer_norm_epsilon)?;
        Ok(Self {
            word_embeddings,
            layers,
            ln_f,
            config,
            device,
        })
    }

    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Model configuration.
    pub fn config(&self) -> &FalconConfig {
        &self.config
    }
}
impl Model for FalconModel {
type Config = FalconConfig;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
Layer::forward(self, input)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Err(TrustformersError::not_implemented(
"Use load_from_path or load_from_huggingface for enhanced weight loading".to_string(),
))
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let embeddings_params = self.word_embeddings.parameter_count();
let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
let norm_params = self.ln_f.parameter_count();
embeddings_params + layers_params + norm_params
}
}
impl Layer for FalconModel {
    type Input = Tensor;
    type Output = Tensor;

    /// Embeds the token ids in `input`, runs the decoder stack, and applies
    /// the final layernorm.
    ///
    /// The input tensor carries token ids as f32 values (truncated to u32
    /// for the embedding lookup); any other tensor variant is rejected.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let Tensor::F32(arr) = &input else {
            return Err(tensor_op_error(
                "tensor_operation",
                "Input must be F32 tensor",
            ));
        };
        let token_ids: Vec<u32> = arr.iter().map(|&x| x as u32).collect();
        if token_ids.is_empty() {
            return Err(TrustformersError::model_error(
                "Empty token_ids provided".to_string(),
            ));
        }
        let mut hidden_states = self.word_embeddings.forward(token_ids)?;
        for layer in &self.layers {
            hidden_states = layer.forward(hidden_states)?;
        }
        self.ln_f.forward(hidden_states)
    }
}
/// Falcon with a language-modeling head for next-token prediction.
pub struct FalconForCausalLM {
    // Base transformer producing hidden states.
    transformer: FalconModel,
    // Projection from hidden states to vocabulary logits (no bias).
    lm_head: Linear,
    // Device this model targets.
    device: Device,
}
impl FalconForCausalLM {
    /// Builds the causal-LM model (transformer + LM head) on the CPU.
    pub fn new(config: FalconConfig) -> Result<Self> {
        Self::new_with_device(config, Device::CPU)
    }

    /// Builds the causal-LM model on `device`. The LM head projects hidden
    /// states to vocabulary logits and carries no bias.
    pub fn new_with_device(config: FalconConfig, device: Device) -> Result<Self> {
        // Size the head first so `config` can be moved into the transformer
        // without a clone.
        let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
        let transformer = FalconModel::new_with_device(config, device)?;
        Ok(Self {
            transformer,
            lm_head,
            device,
        })
    }

    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
pub fn load_from_path(&mut self, model_path: impl AsRef<std::path::Path>) -> Result<()> {
use crate::weight_loading::{auto_create_loader, WeightLoadingConfig};
let config = WeightLoadingConfig {
lazy_loading: true,
memory_mapped: false,
..Default::default()
};
let mut loader = auto_create_loader(model_path, Some(config))?;
if let Ok(embed_weights) = loader.load_tensor("transformer.word_embeddings.weight") {
self.transformer.word_embeddings.set_weight(embed_weights)?;
}
for (i, layer) in self.transformer.layers.iter_mut().enumerate() {
let attn_prefix = format!("transformer.h.{}.self_attention", i);
if let Ok(qkv_weight) =
loader.load_tensor(&format!("{}.query_key_value.weight", attn_prefix))
{
match &qkv_weight {
Tensor::F32(arr) => {
let shape = arr.shape();
let combined_size = shape[0];
let _hidden_size = shape[1];
let head_dim = combined_size / 3;
let q_slice = arr.slice(s![0..head_dim, ..]).to_owned();
let k_slice = arr.slice(s![head_dim..2 * head_dim, ..]).to_owned();
let v_slice = arr.slice(s![2 * head_dim..3 * head_dim, ..]).to_owned();
let q_dyn = q_slice.into_dyn();
let k_dyn = k_slice.into_dyn();
let v_dyn = v_slice.into_dyn();
layer.self_attention.q_proj.set_weight(Tensor::F32(q_dyn))?;
layer.self_attention.k_proj.set_weight(Tensor::F32(k_dyn))?;
layer.self_attention.v_proj.set_weight(Tensor::F32(v_dyn))?;
},
_ => {
layer.self_attention.q_proj.set_weight(qkv_weight.clone())?;
},
}
}
if let Ok(o_weight) = loader.load_tensor(&format!("{}.dense.weight", attn_prefix)) {
layer.self_attention.dense.set_weight(o_weight)?;
}
let mlp_prefix = format!("transformer.h.{}.mlp", i);
if let Ok(up_weight) =
loader.load_tensor(&format!("{}.dense_h_to_4h.weight", mlp_prefix))
{
layer.mlp.dense_h_to_4h.set_weight(up_weight)?;
}
if let Ok(down_weight) =
loader.load_tensor(&format!("{}.dense_4h_to_h.weight", mlp_prefix))
{
layer.mlp.dense_4h_to_h.set_weight(down_weight)?;
}
if let Ok(ln_weight) =
loader.load_tensor(&format!("transformer.h.{}.input_layernorm.weight", i))
{
layer.input_layernorm.set_weight(ln_weight)?;
}
if let Ok(ln_bias) =
loader.load_tensor(&format!("transformer.h.{}.input_layernorm.bias", i))
{
layer.input_layernorm.set_bias(ln_bias)?;
}
}
if let Ok(norm_weight) = loader.load_tensor("transformer.ln_f.weight") {
self.transformer.ln_f.set_weight(norm_weight)?;
}
if let Ok(norm_bias) = loader.load_tensor("transformer.ln_f.bias") {
self.transformer.ln_f.set_bias(norm_bias)?;
}
if let Ok(lm_head_weight) = loader.load_tensor("lm_head.weight") {
self.lm_head.set_weight(lm_head_weight)?;
}
Ok(())
}
pub fn load_from_huggingface(&mut self, model_name: &str) -> Result<()> {
let cache_dir = std::env::var("HF_HOME")
.or_else(|_| std::env::var("HUGGINGFACE_HUB_CACHE"))
.unwrap_or_else(|_| {
std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
+ "/.cache/huggingface/hub"
});
let model_path = std::path::Path::new(&cache_dir)
.join(format!("models--{}", model_name.replace("/", "--")));
if model_path.exists() {
self.load_from_path(&model_path)
} else {
self.download_from_huggingface_hub(model_name, &model_path)?;
self.load_from_path(&model_path)
}
}
fn download_from_huggingface_hub(
&self,
model_name: &str,
model_path: &std::path::Path,
) -> Result<()> {
use std::process::Command;
println!(
"Downloading model {} from HuggingFace Hub to {:?}",
model_name, model_path
);
std::fs::create_dir_all(model_path).map_err(|e| {
TrustformersError::io_error(format!("Failed to create model directory: {}", e))
})?;
let essential_files = vec![
"config.json",
"tokenizer.json",
"tokenizer_config.json",
"pytorch_model.bin", "model.safetensors", ];
let base_url = format!("https://huggingface.co/{}/resolve/main", model_name);
for file_name in &essential_files {
let file_url = format!("{}/{}", base_url, file_name);
let file_path = model_path.join(file_name);
println!("Attempting to download {}", file_url);
let curl_result = Command::new("curl")
.args([
"-L", "-f", "-o",
file_path.to_str().expect("operation failed"),
&file_url,
])
.output();
match curl_result {
Ok(output) if output.status.success() => {
println!("Successfully downloaded {}", file_name);
continue;
},
Ok(output) => {
eprintln!(
"Failed to download {} with curl: {}",
file_name,
String::from_utf8_lossy(&output.stderr)
);
},
Err(e) => {
println!("curl not available: {}", e);
},
}
let wget_result = Command::new("wget")
.args([
"-O",
file_path.to_str().expect("operation failed"),
&file_url,
])
.output();
match wget_result {
Ok(output) if output.status.success() => {
println!("Successfully downloaded {} with wget", file_name);
continue;
},
Ok(output) => {
eprintln!(
"Failed to download {} with wget: {}",
file_name,
String::from_utf8_lossy(&output.stderr)
);
},
Err(e) => {
println!("wget not available: {}", e);
},
}
if matches!(file_name, &"config.json" | &"pytorch_model.bin") {
return Err(TrustformersError::io_error(format!(
"Failed to download essential file {} for model {}. Please ensure curl or wget is installed and you have internet access.",
file_name, model_name
)));
}
}
println!(
"Successfully downloaded model {} from HuggingFace Hub",
model_name
);
Ok(())
}
    /// Convenience alias for [`Self::load_from_huggingface`].
    pub fn load_from_hub(&mut self, model_name: &str) -> Result<()> {
        self.load_from_huggingface(model_name)
    }
    /// Greedy (argmax) autoregressive generation.
    ///
    /// Repeatedly runs a full forward pass over the growing sequence, picks
    /// the highest-logit token, and appends it until `max_length` tokens are
    /// reached or the stop token is produced. Returns the extended id tensor.
    ///
    /// NOTE(review): token id 2 is hard-coded as the stop token — confirm
    /// against the tokenizer/config EOS id for the loaded checkpoint.
    pub fn generate(&self, input_ids: Tensor, max_length: usize) -> Result<Tensor> {
        let mut current_ids = input_ids;
        // Current sequence length is the last dimension of the id tensor.
        let current_length = current_ids.shape()[current_ids.shape().len() - 1];
        for _ in current_length..max_length {
            // Full forward pass over the whole sequence each step (no KV cache).
            let logits = <Self as Model>::forward(self, current_ids.clone())?;
            // Logits of the last position, for both [batch, seq, vocab]
            // and [seq, vocab] layouts.
            let last_logits = match &logits {
                Tensor::F32(arr) => {
                    let shape = arr.shape();
                    let seq_len = shape[shape.len() - 2];
                    let _vocab_size = shape[shape.len() - 1];
                    let last_token_slice = if shape.len() == 3 {
                        arr.slice(s![0, seq_len - 1, ..])
                    } else {
                        arr.slice(s![seq_len - 1, ..])
                    };
                    last_token_slice.to_owned()
                },
                _ => {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Logits must be F32 tensor",
                    ))
                },
            };
            // Greedy decoding: argmax over the vocabulary (NaNs compare Equal).
            let next_token_id = last_logits
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .map(|(idx, _)| idx as u32)
                .ok_or_else(|| {
                    TrustformersError::model_error("Failed to find next token".to_string())
                })?;
            // Stop on the (assumed) EOS token.
            if next_token_id == 2 {
                break;
            }
            // Append the new token by copying into a tensor one element
            // longer along the last axis (O(seq) copy per generated token).
            current_ids = match &current_ids {
                Tensor::F32(arr) => {
                    let mut new_shape = arr.shape().to_vec();
                    let last_idx = new_shape.len() - 1;
                    new_shape[last_idx] += 1;
                    let mut new_arr = ArrayD::<f32>::zeros(IxDyn(&new_shape));
                    if arr.ndim() == 2 {
                        for i in 0..arr.shape()[0] {
                            for j in 0..arr.shape()[1] {
                                new_arr[[i, j]] = arr[[i, j]];
                            }
                            new_arr[[i, arr.shape()[1]]] = next_token_id as f32;
                        }
                    } else if arr.ndim() == 1 {
                        for i in 0..arr.shape()[0] {
                            new_arr[[i]] = arr[[i]];
                        }
                        new_arr[[arr.shape()[0]]] = next_token_id as f32;
                    }
                    Tensor::F32(new_arr)
                },
                _ => {
                    return Err(tensor_op_error(
                        "tensor_operation",
                        "Input must be F32 tensor",
                    ))
                },
            };
        }
        Ok(current_ids)
    }
    /// Borrows the underlying base transformer.
    pub fn model(&self) -> &FalconModel {
        &self.transformer
    }
}
impl Model for FalconForCausalLM {
type Config = FalconConfig;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
Layer::forward(self, input)
}
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
self.transformer.load_pretrained(reader)
}
fn get_config(&self) -> &Self::Config {
self.transformer.get_config()
}
fn num_parameters(&self) -> usize {
self.transformer.num_parameters() + self.lm_head.parameter_count()
}
}
impl Layer for FalconForCausalLM {
    type Input = Tensor;
    type Output = Tensor;

    /// Runs the base transformer, then projects its hidden states to
    /// vocabulary logits with the LM head.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        let hidden = Layer::forward(&self.transformer, input)?;
        self.lm_head.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config for fast construction tests: tiny dims, multi-query
    /// attention (1 KV head), no ALiBi, parallel attention path.
    fn tiny_falcon_config() -> FalconConfig {
        FalconConfig {
            vocab_size: 64,
            hidden_size: 64,
            num_hidden_layers: 1,
            num_attention_heads: 4,
            num_kv_heads: Some(1),
            max_position_embeddings: 32,
            alibi: false,
            parallel_attn: true,
            ..FalconConfig::default()
        }
    }

    // Full-size 7B construction; ignored by default — presumably too
    // slow/memory-heavy for routine CI runs.
    #[test]
    #[ignore]
    fn test_falcon_model_creation() {
        let config = FalconConfig::falcon_7b();
        let model = FalconModel::new(config);
        assert!(model.is_ok());
    }

    // Full-size 7B causal-LM construction; ignored for the same reason.
    #[test]
    #[ignore]
    fn test_falcon_causal_lm_creation() {
        let config = FalconConfig::falcon_7b();
        let model = FalconForCausalLM::new(config);
        assert!(model.is_ok());
    }

    // Spot-checks the preset config constructors against published sizes.
    #[test]
    fn test_falcon_config_variants() {
        let config_7b = FalconConfig::falcon_7b();
        assert_eq!(config_7b.hidden_size, 4544);
        assert_eq!(config_7b.num_hidden_layers, 32);
        assert!(config_7b.uses_alibi());
        let config_40b = FalconConfig::falcon_40b();
        assert_eq!(config_40b.hidden_size, 8192);
        assert_eq!(config_40b.num_hidden_layers, 60);
        assert!(config_40b.uses_alibi());
        let config_180b = FalconConfig::falcon_180b();
        assert_eq!(config_180b.hidden_size, 14848);
        assert_eq!(config_180b.num_hidden_layers, 80);
        assert!(!config_180b.uses_alibi());
        assert!(config_180b.uses_new_architecture());
    }

    #[test]
    fn test_alibi_creation() {
        let alibi = ALiBi::new(8);
        assert!(alibi.is_ok());
        let alibi = alibi.expect("operation failed");
        assert_eq!(alibi.num_heads, 8);
    }

    #[test]
    fn test_falcon_attention_creation() {
        let config = FalconConfig::falcon_7b();
        let attention = FalconAttention::new(&config);
        assert!(attention.is_ok());
    }

    #[test]
    fn test_falcon_mlp_creation() {
        let config = FalconConfig::falcon_7b();
        let mlp = FalconMLP::new(&config);
        assert!(mlp.is_ok());
    }

    // Even head counts take the interleaved-exponent slope path.
    #[test]
    fn test_alibi_even_heads() {
        let alibi = ALiBi::new(8).expect("ALiBi with 8 heads");
        assert_eq!(alibi.num_heads, 8);
    }

    // Odd head counts take the consecutive-exponent slope path.
    #[test]
    fn test_alibi_odd_heads() {
        let alibi = ALiBi::new(7).expect("ALiBi with 7 heads");
        assert_eq!(alibi.num_heads, 7);
    }

    #[test]
    fn test_alibi_device_cpu() {
        let alibi = ALiBi::new(4).expect("ALiBi with 4 heads");
        assert_eq!(alibi.device(), Device::CPU);
    }

    // Falcon-7B uses multi-query attention: exactly one shared KV head.
    #[test]
    fn test_falcon_7b_num_kv_heads_is_one() {
        let config = FalconConfig::falcon_7b();
        assert_eq!(
            config.num_kv_heads(),
            1,
            "Falcon-7B must use 1 KV head (multi-query)"
        );
    }

    #[test]
    fn test_falcon_attention_tiny_creation() {
        let config = tiny_falcon_config();
        let attn = FalconAttention::new(&config);
        assert!(
            attn.is_ok(),
            "FalconAttention construction with tiny config failed"
        );
    }

    #[test]
    fn test_falcon_attention_parameter_count_positive() {
        let config = tiny_falcon_config();
        let attn = FalconAttention::new(&config).expect("FalconAttention construction");
        assert!(attn.parameter_count() > 0);
    }

    #[test]
    fn test_falcon_decoder_layer_parallel_attn_flag() {
        let config = tiny_falcon_config();
        let layer = FalconDecoderLayer::new(&config).expect("FalconDecoderLayer construction");
        assert!(layer.parallel_attn, "Tiny config sets parallel_attn=true");
    }

    #[test]
    fn test_falcon_decoder_layer_sequential_attn() {
        let mut config = tiny_falcon_config();
        config.parallel_attn = false;
        let layer = FalconDecoderLayer::new(&config).expect("FalconDecoderLayer construction");
        assert!(!layer.parallel_attn);
    }

    #[test]
    fn test_falcon_180b_new_decoder_architecture() {
        let config = FalconConfig::falcon_180b();
        assert!(config.new_decoder_architecture);
    }

    #[test]
    fn test_falcon_7b_old_decoder_architecture() {
        let config = FalconConfig::falcon_7b();
        assert!(!config.new_decoder_architecture);
    }

    #[test]
    fn test_falcon_mlp_tiny_creation() {
        let config = tiny_falcon_config();
        let mlp = FalconMLP::new(&config).expect("FalconMLP tiny creation");
        assert!(mlp.parameter_count() > 0);
    }

    #[test]
    fn test_falcon_mlp_device_cpu() {
        let config = tiny_falcon_config();
        let mlp = FalconMLP::new(&config).expect("FalconMLP tiny creation");
        assert_eq!(mlp.device(), Device::CPU);
    }

    #[test]
    fn test_falcon_model_tiny_creation() {
        let config = tiny_falcon_config();
        let model = FalconModel::new(config);
        assert!(model.is_ok(), "FalconModel with tiny config must succeed");
    }

    #[test]
    fn test_falcon_model_num_parameters_positive() {
        let config = tiny_falcon_config();
        let model = FalconModel::new(config).expect("FalconModel tiny");
        assert!(model.num_parameters() > 0);
    }

    #[test]
    fn test_falcon_causal_lm_tiny_creation() {
        let config = tiny_falcon_config();
        let model = FalconForCausalLM::new(config);
        assert!(
            model.is_ok(),
            "FalconForCausalLM with tiny config must succeed"
        );
    }

    // The LM head adds parameters on top of the base transformer.
    #[test]
    fn test_falcon_causal_lm_parameter_count_exceeds_base() {
        let config = tiny_falcon_config();
        let base = FalconModel::new(config.clone()).expect("FalconModel");
        let lm_head_model = FalconForCausalLM::new(config).expect("FalconForCausalLM");
        assert!(lm_head_model.num_parameters() > base.num_parameters());
    }

    // ALiBi slopes are powers of a positive ratio, so all must be > 0.
    #[test]
    fn test_alibi_slopes_positive() {
        let alibi = ALiBi::new(4).expect("ALiBi with 4 heads");
        let data = alibi.slopes.data().expect("slope data");
        for (i, &s) in data.iter().enumerate() {
            assert!(s > 0.0, "Slope[{}] = {} must be positive", i, s);
        }
    }

    // The causal mask must be -inf strictly above the diagonal, 0 elsewhere.
    #[test]
    fn test_causal_mask_upper_triangle_is_neg_inf() {
        let config = tiny_falcon_config();
        let attn = FalconAttention::new(&config).expect("FalconAttention");
        let mask = attn.create_causal_mask(4).expect("causal mask");
        match &mask {
            Tensor::F32(arr) => {
                assert!(arr[[0, 1]].is_infinite() && arr[[0, 1]] < 0.0);
                assert!(arr[[0, 2]].is_infinite() && arr[[0, 2]] < 0.0);
                assert!(arr[[1, 2]].is_infinite() && arr[[1, 2]] < 0.0);
                assert_eq!(arr[[0, 0]], 0.0);
                assert_eq!(arr[[1, 1]], 0.0);
                assert_eq!(arr[[2, 1]], 0.0);
            },
            _ => panic!("Expected F32 mask"),
        }
    }
}