use std::fmt;
/// Error type shared by all DeepSeek task heads in this module.
#[derive(Debug)]
pub enum DeepSeekTaskError {
    /// The supplied configuration was rejected; payload describes why.
    InvalidConfig(String),
    /// Construction of the underlying model failed.
    ModelBuildError(String),
    /// The wrapped model's forward pass failed.
    ForwardError(String),
    /// An empty token sequence or hidden state was supplied.
    EmptyInput,
    /// `num_labels` was below the required minimum of 2.
    InvalidNumLabels(usize),
}

impl fmt::Display for DeepSeekTaskError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One stable, human-readable message per variant; tests match on
        // substrings of these, so the wording must not drift.
        match self {
            Self::InvalidConfig(msg) => write!(f, "DeepSeek invalid config: {}", msg),
            Self::ModelBuildError(msg) => write!(f, "DeepSeek model build error: {}", msg),
            Self::ForwardError(msg) => write!(f, "DeepSeek forward error: {}", msg),
            Self::EmptyInput => write!(f, "DeepSeek error: empty input"),
            Self::InvalidNumLabels(n) => {
                write!(f, "DeepSeek error: num_labels must be >= 2, got {}", n)
            },
        }
    }
}

impl std::error::Error for DeepSeekTaskError {}
/// DeepSeek model with a causal language-modeling head.
///
/// Thin wrapper that validates input and maps core-model errors into
/// `DeepSeekTaskError`.
pub struct DeepSeekForCausalLM {
    // Configuration the model was built from; exposed via `config()`.
    config: crate::deepseek::DeepSeekConfig,
    // Core model that performs the actual forward computation.
    inner: crate::deepseek::DeepSeekForCausalLM,
}
impl DeepSeekForCausalLM {
    /// Builds the wrapper, constructing the inner core model from `config`.
    ///
    /// # Errors
    /// Returns `ModelBuildError` when the core model rejects the config.
    pub fn new(config: crate::deepseek::DeepSeekConfig) -> Result<Self, DeepSeekTaskError> {
        match crate::deepseek::DeepSeekForCausalLM::new(config.clone()) {
            Ok(inner) => Ok(Self { config, inner }),
            Err(e) => Err(DeepSeekTaskError::ModelBuildError(e.to_string())),
        }
    }

    /// Read-only access to the configuration the model was built with.
    pub fn config(&self) -> &crate::deepseek::DeepSeekConfig {
        &self.config
    }

    /// Runs the wrapped model on `input_ids`.
    ///
    /// # Errors
    /// `EmptyInput` for an empty sequence; `ForwardError` when the inner
    /// forward pass fails.
    pub fn forward(
        &self,
        input_ids: Vec<u32>,
    ) -> Result<trustformers_core::tensor::Tensor, DeepSeekTaskError> {
        if input_ids.is_empty() {
            Err(DeepSeekTaskError::EmptyInput)
        } else {
            self.inner
                .forward(input_ids)
                .map_err(|e| DeepSeekTaskError::ForwardError(e.to_string()))
        }
    }

    /// Greedy decoding helper: index of the largest logit, `None` for an
    /// empty slice. NaN comparisons fall back to `Ordering::Equal`, so with
    /// ties the last maximal index wins (a `max_by` property).
    pub fn greedy_next_token(logits: &[f32]) -> Option<u32> {
        logits
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(pos, _)| pos as u32)
    }
}
/// DeepSeek-backed sequence classification head.
///
/// Holds a deterministic, pseudo-randomly initialized linear classifier
/// (`num_labels` x `hidden_size`) applied to a pooled hidden state.
pub struct DeepSeekForSequenceClassification {
    // Configuration; `hidden_size` fixes the expected input width.
    config: crate::deepseek::DeepSeekConfig,
    // Number of output classes (validated to be >= 2 at construction).
    num_labels: usize,
    // Row-major classifier weights: num_labels rows of hidden_size floats.
    classifier_weight: Vec<Vec<f32>>,
}
impl DeepSeekForSequenceClassification {
    /// Builds a classification head with `num_labels` output classes.
    ///
    /// Weights are generated by a fixed-seed LCG so construction is
    /// deterministic; each weight lies roughly in [-0.01, 0.01].
    ///
    /// # Errors
    /// Returns `InvalidNumLabels` when `num_labels < 2`.
    pub fn new(
        config: crate::deepseek::DeepSeekConfig,
        num_labels: usize,
    ) -> Result<Self, DeepSeekTaskError> {
        if num_labels < 2 {
            return Err(DeepSeekTaskError::InvalidNumLabels(num_labels));
        }
        let hidden = config.hidden_size;
        // Fixed-seed linear congruential generator (Knuth's constants) keeps
        // initialization reproducible without pulling in an RNG crate.
        let mut state: u64 = 0xdeadbeef_cafebabe;
        let mut classifier_weight = Vec::with_capacity(num_labels);
        for _ in 0..num_labels {
            let mut row = Vec::with_capacity(hidden);
            for _ in 0..hidden {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                // Scale the 64-bit state into roughly [-0.01, 0.01].
                row.push((state as f32 / u64::MAX as f32) * 0.02 - 0.01);
            }
            classifier_weight.push(row);
        }
        Ok(Self {
            config,
            num_labels,
            classifier_weight,
        })
    }

    /// The configuration this head was built with.
    pub fn config(&self) -> &crate::deepseek::DeepSeekConfig {
        &self.config
    }

    /// Number of output classes.
    pub fn num_labels(&self) -> usize {
        self.num_labels
    }

    /// Projects a pooled hidden state to one logit per label.
    ///
    /// Inputs longer than `hidden_size` are truncated; shorter inputs are
    /// zero-padded, so any non-empty slice is accepted.
    ///
    /// # Errors
    /// Returns `EmptyInput` for an empty slice.
    pub fn forward(&self, hidden_state: &[f32]) -> Result<Vec<f32>, DeepSeekTaskError> {
        if hidden_state.is_empty() {
            return Err(DeepSeekTaskError::EmptyInput);
        }
        // Normalize the input to exactly `hidden_size` elements:
        // `resize` both truncates overlong input and zero-pads short input.
        let mut input = hidden_state.to_vec();
        input.resize(self.config.hidden_size, 0.0);
        let mut logits = Vec::with_capacity(self.num_labels);
        for row in &self.classifier_weight {
            // Sequential accumulation preserves the original summation order.
            let mut acc = 0.0f32;
            for (&w, &x) in row.iter().zip(input.iter()) {
                acc += w * x;
            }
            logits.push(acc);
        }
        Ok(logits)
    }
}
/// DeepSeek-backed per-token classification head.
///
/// Applies a deterministic, pseudo-randomly initialized linear classifier
/// (`num_labels` x `hidden_size`) independently to each token position.
pub struct DeepSeekForTokenClassification {
    // Configuration; `hidden_size` fixes the per-token feature width.
    config: crate::deepseek::DeepSeekConfig,
    // Number of output classes per token (validated to be >= 2).
    num_labels: usize,
    // Row-major classifier weights: num_labels rows of hidden_size floats.
    classifier_weight: Vec<Vec<f32>>,
}
impl DeepSeekForTokenClassification {
    /// Builds a per-token classification head with `num_labels` classes.
    ///
    /// Weight initialization mirrors the sequence-classification head but
    /// uses a different fixed LCG seed; values lie roughly in [-0.01, 0.01].
    ///
    /// # Errors
    /// Returns `InvalidNumLabels` when `num_labels < 2`.
    pub fn new(
        config: crate::deepseek::DeepSeekConfig,
        num_labels: usize,
    ) -> Result<Self, DeepSeekTaskError> {
        if num_labels < 2 {
            return Err(DeepSeekTaskError::InvalidNumLabels(num_labels));
        }
        let hidden = config.hidden_size;
        // Deterministic LCG (Knuth's constants) with a head-specific seed.
        let mut state: u64 = 0x1234567890abcdef;
        let mut next_weight = || {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (state as f32 / u64::MAX as f32) * 0.02 - 0.01
        };
        let mut classifier_weight: Vec<Vec<f32>> = Vec::with_capacity(num_labels);
        for _ in 0..num_labels {
            classifier_weight.push((0..hidden).map(|_| next_weight()).collect());
        }
        Ok(Self {
            config,
            num_labels,
            classifier_weight,
        })
    }

    /// The configuration this head was built with.
    pub fn config(&self) -> &crate::deepseek::DeepSeekConfig {
        &self.config
    }

    /// Number of output classes per token.
    pub fn num_labels(&self) -> usize {
        self.num_labels
    }

    /// Classifies each of `seq_len` tokens, returning a flat row-major
    /// vector of `seq_len * num_labels` logits.
    ///
    /// Tokens whose features run past the end of `hidden_states` are
    /// zero-padded (fully missing tokens become all-zero feature vectors).
    ///
    /// # Errors
    /// Returns `EmptyInput` when `hidden_states` is empty or `seq_len == 0`.
    pub fn forward(
        &self,
        hidden_states: &[f32],
        seq_len: usize,
    ) -> Result<Vec<f32>, DeepSeekTaskError> {
        if hidden_states.is_empty() || seq_len == 0 {
            return Err(DeepSeekTaskError::EmptyInput);
        }
        let hidden = self.config.hidden_size;
        let mut output = Vec::with_capacity(seq_len * self.num_labels);
        for tok in 0..seq_len {
            let start = tok * hidden;
            // Start from zeros and copy in whatever features are available,
            // which covers the full, partial, and out-of-range cases alike.
            let mut features = vec![0.0f32; hidden];
            if start < hidden_states.len() {
                let avail = (hidden_states.len() - start).min(hidden);
                features[..avail].copy_from_slice(&hidden_states[start..start + avail]);
            }
            for row in &self.classifier_weight {
                let logit: f32 = row.iter().zip(features.iter()).map(|(&w, &x)| w * x).sum();
                output.push(logit);
            }
        }
        Ok(output)
    }
}
/// Returns the indices of the `k` largest routing logits, highest first.
///
/// `k` is capped at `logits.len()`, so over-asking is safe. The sort is
/// stable, so tied logits keep their original relative order (lower index
/// first); NaN comparisons fall back to `Ordering::Equal`.
pub fn moe_topk_indices(logits: &[f32], k: usize) -> Vec<usize> {
    // Pair each logit with its expert index. `copied` replaces the original
    // `cloned`, which is the idiomatic form for `Copy` element types.
    let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
    // Descending by logit value; the stable sort is kept deliberately so
    // tie ordering is unchanged from the original implementation.
    indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    // Keep only the winners and drop the scores, consuming the Vec instead
    // of slicing and re-dereferencing each element.
    indexed.truncate(k.min(logits.len()));
    indexed.into_iter().map(|(i, _)| i).collect()
}
/// Fraction of routing decisions assigned to each expert.
///
/// Indices outside `0..n_experts` are ignored in the counts but still
/// contribute to the denominator, so the fractions need not sum to 1.
/// Empty routing (or zero experts) yields an all-zero vector.
pub fn moe_load_balance(routing_indices: &[usize], n_experts: usize) -> Vec<f32> {
    if n_experts == 0 || routing_indices.is_empty() {
        return vec![0.0f32; n_experts];
    }
    // Tally picks per expert; `get_mut` bounds-checks out-of-range indices.
    let mut counts = vec![0usize; n_experts];
    for &expert in routing_indices {
        if let Some(slot) = counts.get_mut(expert) {
            *slot += 1;
        }
    }
    // Count-then-divide keeps results bit-identical to summing integers first.
    let total = routing_indices.len() as f32;
    counts.into_iter().map(|c| c as f32 / total).collect()
}
/// Numerically stable softmax over a slice of logits.
///
/// Subtracts the maximum before exponentiating to avoid overflow. An empty
/// slice yields an empty vector; a degenerate all-zero exponent sum falls
/// back to the uniform distribution.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let n = logits.len();
    if n == 0 {
        return Vec::new();
    }
    // f32::max ignores NaN operands, matching the original fold.
    let mut max_v = f32::NEG_INFINITY;
    for &x in logits {
        max_v = max_v.max(x);
    }
    // Exponentiate shifted logits, accumulating the normalizer in the same
    // left-to-right order as the original separate `sum()` pass.
    let mut exps = Vec::with_capacity(n);
    let mut sum = 0.0f32;
    for &x in logits {
        let e = (x - max_v).exp();
        exps.push(e);
        sum += e;
    }
    if sum == 0.0 {
        // Degenerate case: hand back a uniform distribution.
        return vec![1.0 / n as f32; n];
    }
    exps.into_iter().map(|e| e / sum).collect()
}
// Unit tests for the DeepSeek task heads and the MoE / softmax helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deepseek::DeepSeekConfig;

    // Shared fixture: a small, cheap-to-build config for unit tests.
    // NOTE(review): accessor tests below assume hidden_size == 64 and
    // vocab_size == 1024 — confirm against DeepSeekConfig::small_test().
    fn small_cfg() -> DeepSeekConfig {
        DeepSeekConfig::small_test()
    }

    // A valid config must produce a working causal-LM wrapper.
    #[test]
    fn test_causal_lm_construction() {
        let result = DeepSeekForCausalLM::new(small_cfg());
        assert!(
            result.is_ok(),
            "DeepSeekForCausalLM must construct: {:?}",
            result.err()
        );
    }

    // config() must hand back the configuration passed to new().
    #[test]
    fn test_causal_lm_config_accessor() {
        let model = DeepSeekForCausalLM::new(small_cfg()).expect("construction");
        assert_eq!(model.config().hidden_size, 64);
        assert_eq!(model.config().vocab_size, 1024);
    }

    // Forward may fail (the inner model is exercised elsewhere), but when it
    // succeeds the logits tensor must be non-empty.
    #[test]
    fn test_causal_lm_forward_safe() {
        let model = DeepSeekForCausalLM::new(small_cfg()).expect("construction");
        match model.forward(vec![1u32, 2, 3]) {
            Ok(out) => {
                use trustformers_core::tensor::Tensor;
                if let Tensor::F32(arr) = &out {
                    assert!(!arr.is_empty(), "logits must be non-empty");
                }
            },
            Err(_) => {
                // Acceptable: inner-model failures surface as ForwardError.
            },
        }
    }

    // An empty token sequence must be rejected before reaching the model.
    #[test]
    fn test_causal_lm_empty_input_error() {
        let model = DeepSeekForCausalLM::new(small_cfg()).expect("construction");
        let result = model.forward(vec![]);
        assert!(
            matches!(result, Err(DeepSeekTaskError::EmptyInput)),
            "empty input must return EmptyInput error"
        );
    }

    // Greedy decoding picks the argmax of the logits.
    #[test]
    fn test_greedy_next_token_picks_max() {
        let logits = vec![0.1f32, 0.9, 0.2, 0.5];
        let tok = DeepSeekForCausalLM::greedy_next_token(&logits);
        assert_eq!(tok, Some(1u32), "argmax of [0.1,0.9,0.2,0.5] must be 1");
    }

    // No logits -> no token.
    #[test]
    fn test_greedy_next_token_empty_returns_none() {
        assert_eq!(DeepSeekForCausalLM::greedy_next_token(&[]), None);
    }

    #[test]
    fn test_seq_cls_construction() {
        let result = DeepSeekForSequenceClassification::new(small_cfg(), 4);
        assert!(result.is_ok(), "SequenceClassification must construct");
    }

    // num_labels < 2 is invalid for any classifier head.
    #[test]
    fn test_seq_cls_invalid_labels() {
        let result = DeepSeekForSequenceClassification::new(small_cfg(), 1);
        assert!(
            matches!(result, Err(DeepSeekTaskError::InvalidNumLabels(1))),
            "num_labels=1 must be rejected"
        );
    }

    // One logit per label.
    #[test]
    fn test_seq_cls_forward_output_length() {
        let model = DeepSeekForSequenceClassification::new(small_cfg(), 3).expect("construction");
        let hidden = vec![0.1f32; small_cfg().hidden_size];
        let logits = model.forward(&hidden).expect("forward");
        assert_eq!(logits.len(), 3, "must produce 3 logits");
    }

    #[test]
    fn test_seq_cls_empty_input_error() {
        let model = DeepSeekForSequenceClassification::new(small_cfg(), 2).expect("construction");
        let result = model.forward(&[]);
        assert!(matches!(result, Err(DeepSeekTaskError::EmptyInput)));
    }

    #[test]
    fn test_tok_cls_construction() {
        let result = DeepSeekForTokenClassification::new(small_cfg(), 5);
        assert!(result.is_ok(), "TokenClassification must construct");
    }

    // Token classification emits seq_len * num_labels logits (row-major).
    #[test]
    fn test_tok_cls_forward_output_shape() {
        let cfg = small_cfg();
        let hidden = cfg.hidden_size;
        let model = DeepSeekForTokenClassification::new(cfg, 4).expect("construction");
        let seq_len = 3usize;
        let states = vec![0.1f32; seq_len * hidden];
        let logits = model.forward(&states, seq_len).expect("forward");
        assert_eq!(
            logits.len(),
            seq_len * 4,
            "output shape must be seq_len * num_labels"
        );
    }

    #[test]
    fn test_tok_cls_empty_input_error() {
        let model = DeepSeekForTokenClassification::new(small_cfg(), 2).expect("construction");
        let result = model.forward(&[], 0);
        assert!(matches!(result, Err(DeepSeekTaskError::EmptyInput)));
    }

    // Top-k routing must return the k highest-scoring expert indices.
    #[test]
    fn test_moe_topk_indices_basic() {
        let logits = vec![0.1f32, 0.8, 0.3, 0.9, 0.05];
        let topk = moe_topk_indices(&logits, 2);
        assert_eq!(topk.len(), 2);
        assert!(topk.contains(&3), "index 3 must be top-1");
        assert!(topk.contains(&1), "index 1 must be top-2");
    }

    // Over-asking (k > len) is safe and returns everything.
    #[test]
    fn test_moe_topk_k_exceeds_len() {
        let logits = vec![0.5f32, 0.3];
        let topk = moe_topk_indices(&logits, 10);
        assert_eq!(topk.len(), 2, "k capped at logits.len()");
    }

    // Evenly alternating routing yields a 50/50 load split.
    #[test]
    fn test_moe_load_balance_uniform() {
        let routing = vec![0usize, 1, 0, 1, 0, 1];
        let balance = moe_load_balance(&routing, 2);
        assert_eq!(balance.len(), 2);
        assert!((balance[0] - 0.5).abs() < 1e-5);
        assert!((balance[1] - 0.5).abs() < 1e-5);
    }

    // No routing decisions -> all-zero fractions, but still n_experts slots.
    #[test]
    fn test_moe_load_balance_empty() {
        let balance = moe_load_balance(&[], 4);
        assert_eq!(balance.len(), 4);
        for &v in &balance {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0f32, 2.0, 3.0, 4.0];
        let probs = softmax(&logits);
        let sum: f32 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax must sum to 1, got {}",
            sum
        );
    }

    // Softmax is monotone: larger logits map to larger probabilities.
    #[test]
    fn test_softmax_ordering_preserved() {
        let logits = vec![0.0f32, 1.0, 2.0];
        let probs = softmax(&logits);
        assert!(probs[0] < probs[1] && probs[1] < probs[2]);
    }

    // Display output must carry the variant payloads for diagnostics.
    #[test]
    fn test_error_display() {
        let e1 = DeepSeekTaskError::InvalidConfig("bad param".to_string());
        assert!(e1.to_string().contains("bad param"));
        let e2 = DeepSeekTaskError::EmptyInput;
        assert!(e2.to_string().contains("empty"));
        let e3 = DeepSeekTaskError::InvalidNumLabels(1);
        assert!(e3.to_string().contains("1"));
    }

    // The LCG used for classifier init must stay inside [-0.01, 0.01].
    // (Mirrors the generator constants used in the head constructors.)
    #[test]
    fn test_lcg_value_range() {
        let mut state: u64 = 42;
        for _ in 0..100 {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            let v = (state as f32 / u64::MAX as f32) * 0.02 - 0.01;
            assert!((-0.01..=0.01).contains(&v), "value out of range: {v}");
        }
    }

    #[test]
    fn test_num_labels_accessor() {
        let model = DeepSeekForSequenceClassification::new(small_cfg(), 7).expect("construction");
        assert_eq!(model.num_labels(), 7);
        let tok_model = DeepSeekForTokenClassification::new(small_cfg(), 9).expect("construction");
        assert_eq!(tok_model.num_labels(), 9);
    }
}