// Private submodules. Their contents are not visible from this file; only the
// names re-exported below form this module's public surface.
mod config;
mod model;
mod tasks;
// Configuration struct and its error type, re-exported from `config`.
pub use config::{Sd3Config, Sd3ConfigError};
// Encoder components, the text-embeddings container, the encoder pipeline and
// the model error type, re-exported from `model`.
pub use model::{
ClipTextEncoder, Sd3Error, Sd3TextEmbeddings, Sd3TextEncoderPipeline, T5Attention, T5Encoder,
T5EncoderLayer, T5FeedForward, T5RelativePositionBias,
};
// Task-level encoder and its error type, re-exported from `tasks`.
pub use tasks::{Sd3TaskError, Sd3TextEncoder};
#[cfg(test)]
mod tests {
use super::*;
/// Checks bucket assignment for small relative positions with 32 buckets and
/// a max distance of 128.
///
/// The second argument is passed as `true` throughout (presumably the
/// bidirectional flag; confirm against `T5RelativePositionBias`). The expected
/// values encode that a positive offset lands in the upper half of the bucket
/// range (16 + distance) while a negative offset maps to its magnitude.
#[test]
fn test_relative_position_bucket_exact_region() {
    let num_buckets = 32;
    let max_distance = 128;

    // (relative position, expected bucket, failure message), checked in order.
    let cases = [
        (-3, 3, "Query 3 steps after key: bucket should be 3"),
        (
            3,
            19,
            "Query 3 steps before key: bucket should be 19 (16 + 3)",
        ),
        (0, 0, "Zero relative position should map to bucket 0"),
    ];

    for &(relative_position, expected_bucket, message) in &cases {
        let bucket = T5RelativePositionBias::relative_position_bucket(
            relative_position,
            true,
            num_buckets,
            max_distance,
        );
        assert_eq!(bucket, expected_bucket, "{}", message);
    }
}
/// Checks that large negative relative positions stay inside the expected
/// bucket ranges with 32 buckets and a max distance of 128.
///
/// A distance of 64 must fall in `[8, 16)`, and a distance far beyond
/// `max_distance` must still produce a bucket below `num_buckets`.
#[test]
fn test_relative_position_bucket_log_region() {
    let (num_buckets, max_distance) = (32, 128);

    // Every lookup in this test shares the same flag and bucket settings.
    let bucket_for = |relative_position| {
        T5RelativePositionBias::relative_position_bucket(
            relative_position,
            true,
            num_buckets,
            max_distance,
        )
    };

    let far = bucket_for(-64);
    assert!(
        far >= 8,
        "Far q-after-k position should be in log-spaced region (>= 8)"
    );
    assert!(
        far < 16,
        "Far q-after-k bucket should be < 16 (effective half)"
    );
    assert!(
        far < num_buckets,
        "Bucket must be within num_buckets range"
    );

    // A distance well past `max_distance` must not overflow the bucket range.
    let beyond_max = bucket_for(-10000);
    assert!(
        beyond_max < num_buckets,
        "Very far position must not exceed num_buckets"
    );
}
/// Checks the shape of the bias returned by `compute_bias`: a flat collection
/// of `seq_len * seq_len` entries, each holding one value per head.
#[test]
fn test_relative_position_bias_shape() {
    let (num_heads, num_buckets, max_distance) = (8, 32, 128);
    let seq_len = 5;

    let position_bias = T5RelativePositionBias::new(num_heads, num_buckets, max_distance);
    let bias = position_bias.compute_bias(seq_len, true);

    let expected_entries = seq_len * seq_len;
    assert_eq!(
        bias.len(),
        expected_entries,
        "Bias should have seq_len^2 entries"
    );

    // Every (query, key) entry carries exactly one value per attention head.
    for per_head in &bias {
        assert_eq!(
            per_head.len(),
            num_heads,
            "Each bias entry should have num_heads values"
        );
    }

    // Diagonal entries sit at flat index `row * seq_len + row`.
    for row in 0..seq_len {
        let flat_index = row * seq_len + row;
        assert_eq!(
            bias[flat_index].len(),
            num_heads,
            "Diagonal entries must have num_heads values"
        );
    }
}
/// Checks that the default config reports 64 T5 heads and a per-head
/// dimension of 64.
#[test]
fn test_t5_attention_num_heads() {
    let defaults = Sd3Config::default();

    let head_count = defaults.t5_num_heads;
    assert_eq!(head_count, 64);

    let head_dim = defaults.t5_head_dim();
    assert_eq!(head_dim, 64);
}
/// Pins every default value of `Sd3Config` that this test knows about, so an
/// accidental change to a default is caught here.
#[test]
fn test_config_defaults() {
    let defaults = Sd3Config::default();

    // `t5_*` fields.
    assert_eq!(defaults.t5_vocab_size, 32128);
    assert_eq!(defaults.t5_hidden_size, 4096);
    assert_eq!(defaults.t5_num_layers, 24);
    assert_eq!(defaults.t5_num_heads, 64);
    assert_eq!(defaults.t5_intermediate_size, 10240);
    assert_eq!(defaults.t5_relative_attn_buckets, 32);
    assert_eq!(defaults.t5_max_distance, 128);

    // `clip_*` fields.
    assert_eq!(defaults.clip_vocab_size, 49408);
    assert_eq!(defaults.clip_hidden_size, 768);
    assert_eq!(defaults.clip_num_layers, 12);
    assert_eq!(defaults.clip_num_heads, 12);
    assert_eq!(defaults.clip_intermediate_size, 3072);

    // `clip_g_*` fields.
    assert_eq!(defaults.clip_g_hidden_size, 1280);
    assert_eq!(defaults.clip_g_num_layers, 32);
    assert_eq!(defaults.clip_g_num_heads, 20);

    // Output embedding dimensions and sequence-length limits.
    assert_eq!(defaults.text_embedding_dim, 4096);
    assert_eq!(defaults.pooled_embedding_dim, 2048);
    assert_eq!(defaults.max_sequence_length, 77);
    assert_eq!(defaults.max_t5_sequence_length, 256);
}
/// Checks that the default `pooled_embedding_dim` equals the sum of
/// `clip_hidden_size` and `clip_g_hidden_size`, and that this sum is 2048.
#[test]
fn test_pooled_embedding_dim() {
    let defaults = Sd3Config::default();

    let combined_clip_width = defaults.clip_hidden_size + defaults.clip_g_hidden_size;
    assert_eq!(combined_clip_width, 2048);
    assert_eq!(defaults.pooled_embedding_dim, combined_clip_width);
}
/// Builds an `Sd3TextEmbeddings` by hand and checks that its fields and the
/// `t5_embedding_dim` / `pooled_dim` accessors report the sizes it was
/// constructed with.
#[test]
fn test_text_embeddings_struct() {
    let hidden_width = 4096usize;
    let padded_len = 16usize;
    let pooled_width = 2048usize;

    // Zero-filled buffers: only the lengths matter for this test.
    let t5_embeddings = vec![0.0f32; padded_len * hidden_width];
    let pooled_embeddings = vec![0.0f32; pooled_width];

    // `seq_len` (10) is deliberately different from the buffer's row count (16).
    let embeddings = Sd3TextEmbeddings {
        t5_embeddings,
        pooled_embeddings,
        seq_len: 10,
    };

    assert_eq!(embeddings.t5_embeddings.len(), padded_len * hidden_width);
    assert_eq!(embeddings.pooled_embeddings.len(), pooled_width);
    assert_eq!(embeddings.seq_len, 10);
    assert_eq!(embeddings.t5_embedding_dim(padded_len), hidden_width);
    assert_eq!(embeddings.pooled_dim(), pooled_width);
}
/// Checks the relative sizes in the default config: the T5 values exceed the
/// corresponding CLIP / CLIP-G values, and the config passes `validate()`.
#[test]
fn test_clip_vs_t5_params() {
    let defaults = Sd3Config::default();

    // T5 hidden size is larger than both CLIP hidden sizes.
    assert!(defaults.t5_hidden_size > defaults.clip_hidden_size);
    assert!(defaults.t5_hidden_size > defaults.clip_g_hidden_size);

    // T5 also has more layers and more heads than the `clip_*` settings.
    assert!(defaults.t5_num_layers > defaults.clip_num_layers);
    assert!(defaults.t5_num_heads > defaults.clip_num_heads);

    assert!(defaults.validate().is_ok());
}
}