use crate::error::{Result, RuvLLMError};
use crate::kernels::rope::{precompute_rope_tables_with_config, RopeConfig, RopeTables};
use crate::kernels::{apply_rope_neon, flash_attention_neon, rms_norm_neon, AttentionConfig};
use crate::sona::{SonaConfig, SonaIntegration, Trajectory};
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Strategy for splitting work between the Apple Neural Engine and the GPU.
///
/// Uses the derived `Default` with a `#[default]` variant, matching the
/// style already used by `MemoryLayout`, instead of a manual `impl Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum AneOptimization {
    /// Never use the ANE or GPU dispatch heuristics.
    Disabled,
    /// Route all eligible work to the ANE.
    AneOnly,
    /// Route all work to the GPU.
    GpuOnly,
    /// Static split: ANE for matmul-like ops, GPU for the rest (default).
    #[default]
    HybridDispatch,
    /// Choose per-op at runtime based on workload size.
    Adaptive,
}
impl AneOptimization {
    /// True when this strategy routes any work to the Apple Neural Engine.
    pub fn uses_ane(&self) -> bool {
        match self {
            Self::Disabled | Self::GpuOnly => false,
            Self::AneOnly | Self::HybridDispatch | Self::Adaptive => true,
        }
    }

    /// True when this strategy routes any work to the GPU.
    pub fn uses_gpu(&self) -> bool {
        match self {
            Self::Disabled | Self::AneOnly => false,
            Self::GpuOnly | Self::HybridDispatch | Self::Adaptive => true,
        }
    }

    /// Tile width used for ANE matmuls.
    ///
    /// Currently a single constant regardless of mode; kept as a method so
    /// callers never hard-code the value.
    pub fn ane_tile_size(&self) -> usize {
        768
    }
}
/// Weight quantization scheme applied to the model parameters.
///
/// Uses the derived `Default` with a `#[default]` variant, matching the
/// style already used by `MemoryLayout`, instead of a manual `impl Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum QuantizationType {
    /// Full-precision f32 weights.
    None,
    /// IEEE half precision.
    Fp16,
    /// bfloat16.
    Bf16,
    /// 8-bit integer quantization.
    Int8,
    /// 4-bit integer quantization (default).
    #[default]
    Int4,
    /// llama.cpp-style Q4_K_M mixed 4-bit blocks.
    Q4KM,
    /// Per-layer mixed precision; accounted as ~1 byte/weight on average.
    MixedPrecision,
}
impl QuantizationType {
    /// Average storage cost per weight, in bytes.
    pub fn bytes_per_weight(&self) -> f32 {
        match self {
            Self::None => 4.0,
            Self::Fp16 => 2.0,
            Self::Bf16 => 2.0,
            Self::Int8 => 1.0,
            Self::MixedPrecision => 1.0,
            Self::Int4 => 0.5,
            Self::Q4KM => 0.5,
        }
    }

    /// Approximate weight-storage footprint in MiB for `num_params` weights.
    pub fn estimate_memory_mb(&self, num_params: usize) -> f32 {
        const BYTES_PER_MIB: f32 = 1024.0 * 1024.0;
        num_params as f32 * self.bytes_per_weight() / BYTES_PER_MIB
    }
}
/// Preferred in-memory ordering for weight/activation tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MemoryLayout {
    /// C-style row-major ordering.
    RowMajor,
    /// Fortran-style column-major ordering.
    ColumnMajor,
    /// Channels-last (NHWC) layout; the default here — presumably because it
    /// is the ANE-preferred layout. TODO confirm.
    #[default]
    Nhwc,
    /// Tiled/blocked layout.
    Blocked,
}
/// Architecture and runtime configuration for a RuvLTRA (Qwen-style) model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvLtraConfig {
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Width of the MLP inner (gate/up) projections.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads (GQA when smaller than `num_attention_heads`).
    pub num_kv_heads: usize,
    /// Token vocabulary size.
    pub vocab_size: usize,
    /// Maximum supported context length.
    pub max_position_embeddings: usize,
    /// RoPE frequency base (theta).
    pub rope_theta: f32,
    /// Epsilon used inside RMSNorm.
    pub rms_norm_eps: f32,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Whether to use the flash-attention kernel path.
    pub use_flash_attention: bool,
    /// Optional sliding-attention window length; `None` = full causal attention.
    pub sliding_window: Option<usize>,
    /// Beginning-of-sequence token id.
    pub bos_token_id: u32,
    /// End-of-sequence token id.
    pub eos_token_id: u32,
    /// Padding token id.
    pub pad_token_id: u32,
    /// ANE/GPU dispatch strategy.
    pub ane_optimization: AneOptimization,
    /// Weight quantization scheme.
    pub quantization: QuantizationType,
    /// Tensor memory layout preference.
    pub memory_layout: MemoryLayout,
    /// Whether matmuls are laid out for the ANE (see `is_ane_optimized`).
    pub ane_matmul_optimized: bool,
    /// Matmul tile size.
    pub tile_size: usize,
    /// Whether SONA self-learning is enabled for this model.
    pub sona_enabled: bool,
    /// SONA subsystem configuration (used when `sona_enabled`).
    pub sona_config: SonaConfig,
}
impl Default for RuvLtraConfig {
    /// The Qwen2-0.5B preset is the default configuration.
    fn default() -> Self {
        Self::qwen_0_5b()
    }
}
impl RuvLtraConfig {
    /// Preset matching the Qwen2-0.5B architecture (GQA: 14 query / 2 KV heads).
    pub fn qwen_0_5b() -> Self {
        Self {
            hidden_size: 896,
            intermediate_size: 4864,
            num_hidden_layers: 24,
            num_attention_heads: 14,
            num_kv_heads: 2,
            vocab_size: 151936,
            max_position_embeddings: 32768,
            rope_theta: 1000000.0,
            rms_norm_eps: 1e-6,
            head_dim: 64,
            use_flash_attention: true,
            sliding_window: None,
            bos_token_id: 151643,
            eos_token_id: 151645,
            pad_token_id: 151643,
            ane_optimization: AneOptimization::HybridDispatch,
            quantization: QuantizationType::Int4,
            memory_layout: MemoryLayout::Nhwc,
            ane_matmul_optimized: true,
            tile_size: 64,
            sona_enabled: true,
            sona_config: SonaConfig {
                hidden_dim: 896,
                embedding_dim: 896,
                micro_lora_rank: 2,
                base_lora_rank: 4,
                instant_learning_rate: 0.01,
                background_learning_rate: 0.001,
                ewc_lambda: 0.1,
                pattern_capacity: 10000,
                background_interval_secs: 3600,
                deep_interval_secs: 604800,
                quality_threshold: 0.5,
            },
        }
    }

    /// Preset matching the Qwen-1.8B architecture (MHA: 16 query and 16 KV heads).
    pub fn qwen_1_8b() -> Self {
        Self {
            hidden_size: 2048,
            intermediate_size: 5504,
            num_hidden_layers: 24,
            num_attention_heads: 16,
            num_kv_heads: 16,
            vocab_size: 151936,
            max_position_embeddings: 32768,
            rope_theta: 1000000.0,
            rms_norm_eps: 1e-6,
            head_dim: 128,
            use_flash_attention: true,
            sliding_window: None,
            bos_token_id: 151643,
            eos_token_id: 151645,
            pad_token_id: 151643,
            ane_optimization: AneOptimization::HybridDispatch,
            quantization: QuantizationType::Int4,
            memory_layout: MemoryLayout::Nhwc,
            ane_matmul_optimized: true,
            tile_size: 64,
            sona_enabled: true,
            sona_config: SonaConfig {
                hidden_dim: 2048,
                embedding_dim: 2048,
                ..SonaConfig::default()
            },
        }
    }

    /// Small 4-layer configuration for tests and smoke runs; SONA disabled.
    pub fn tiny() -> Self {
        Self {
            hidden_size: 768,
            intermediate_size: 2048,
            num_hidden_layers: 4,
            num_attention_heads: 12,
            num_kv_heads: 2,
            vocab_size: 32000,
            max_position_embeddings: 2048,
            rope_theta: 10000.0,
            rms_norm_eps: 1e-5,
            head_dim: 64,
            use_flash_attention: true,
            sliding_window: None,
            bos_token_id: 1,
            eos_token_id: 2,
            pad_token_id: 0,
            ane_optimization: AneOptimization::AneOnly,
            quantization: QuantizationType::Fp16,
            memory_layout: MemoryLayout::Nhwc,
            ane_matmul_optimized: true,
            tile_size: 64,
            sona_enabled: false,
            sona_config: SonaConfig::default(),
        }
    }

    /// Builder-style setter for the ANE dispatch strategy.
    pub fn with_ane_optimization(mut self, opt: AneOptimization) -> Self {
        self.ane_optimization = opt;
        self
    }

    /// Builder-style setter for the quantization scheme.
    pub fn with_quantization(mut self, quant: QuantizationType) -> Self {
        self.quantization = quant;
        self
    }

    /// Builder-style toggle for SONA self-learning.
    pub fn with_sona(mut self, enabled: bool) -> Self {
        self.sona_enabled = enabled;
        self
    }

    /// Builder-style setter for the tensor memory layout.
    pub fn with_memory_layout(mut self, layout: MemoryLayout) -> Self {
        self.memory_layout = layout;
        self
    }

    /// Query-heads-per-KV-head ratio for grouped-query attention.
    ///
    /// Integer division; assumes `num_attention_heads` is a multiple of
    /// `num_kv_heads` (true for all presets here).
    pub fn gqa_ratio(&self) -> usize {
        self.num_attention_heads / self.num_kv_heads
    }

    /// Builds the kernel-level attention configuration for this model.
    pub fn attention_config(&self) -> AttentionConfig {
        AttentionConfig {
            num_heads: self.num_attention_heads,
            num_kv_heads: self.num_kv_heads,
            head_dim: self.head_dim,
            max_seq_len: self.max_position_embeddings,
            causal: true,
            // Standard 1/sqrt(d) attention scaling.
            scale: 1.0 / (self.head_dim as f32).sqrt(),
        }
    }

    /// Builds the RoPE table configuration (no NTK/context-extension scaling).
    pub fn rope_config(&self) -> RopeConfig {
        RopeConfig {
            base: self.rope_theta,
            head_dim: self.head_dim,
            max_seq_len: self.max_position_embeddings,
            scaling_factor: 1.0,
            ntk_aware: false,
            original_max_len: self.max_position_embeddings,
        }
    }

    /// Whether ANE-friendly matmuls apply: requested AND the hidden width is
    /// at least 768 (narrower models presumably don't benefit — see
    /// `AneOptimization::ane_tile_size`).
    pub fn is_ane_optimized(&self) -> bool {
        self.ane_matmul_optimized && self.hidden_size >= 768
    }

    /// Estimates the total number of weight parameters.
    ///
    /// Accounts for grouped-query attention: K/V projections are
    /// `hidden_size x (num_kv_heads * head_dim)` — matching the buffers
    /// allocated in `RuvLtraAttention::new` — not full
    /// `hidden_size x hidden_size` matrices. The previous formula counted all
    /// four attention projections at full width, overestimating GQA models.
    pub fn estimate_params(&self) -> usize {
        let embed_params = self.vocab_size * self.hidden_size;
        let kv_dim = self.num_kv_heads * self.head_dim;
        // q_proj + o_proj at full width, k_proj + v_proj at GQA width.
        let attn_params = self.num_hidden_layers
            * (2 * self.hidden_size * self.hidden_size + 2 * self.hidden_size * kv_dim);
        // gate_proj + up_proj + down_proj, each hidden x intermediate.
        let mlp_params = self.num_hidden_layers * (3 * self.hidden_size * self.intermediate_size);
        // Two RMSNorm weight vectors per layer plus the final norm.
        let norm_params = (self.num_hidden_layers * 2 + 1) * self.hidden_size;
        embed_params + attn_params + mlp_params + norm_params
    }

    /// Estimated weight-storage footprint in MiB under the configured
    /// quantization scheme.
    pub fn estimate_memory_mb(&self) -> f32 {
        self.quantization.estimate_memory_mb(self.estimate_params())
    }
}
/// Multi-head attention block with GQA-shaped K/V projections and
/// precomputed RoPE tables.
#[derive(Debug)]
pub struct RuvLtraAttention {
    /// Query projection, `hidden_size x hidden_size`, row-major.
    pub q_proj: Vec<f32>,
    /// Key projection, `hidden_size x (num_kv_heads * head_dim)`.
    pub k_proj: Vec<f32>,
    /// Value projection, `hidden_size x (num_kv_heads * head_dim)`.
    pub v_proj: Vec<f32>,
    /// Output projection, `hidden_size x hidden_size`.
    pub o_proj: Vec<f32>,
    /// Owned copy of the model configuration.
    pub config: RuvLtraConfig,
    /// Precomputed RoPE sin/cos tables for this config.
    pub rope_tables: RopeTables,
}
impl RuvLtraAttention {
    /// Creates a zero-initialized attention module sized for `config`,
    /// precomputing the RoPE tables up front.
    pub fn new(config: &RuvLtraConfig) -> Self {
        let hidden_size = config.hidden_size;
        let kv_dim = config.num_kv_heads * config.head_dim;
        Self {
            q_proj: vec![0.0; hidden_size * hidden_size],
            k_proj: vec![0.0; hidden_size * kv_dim],
            v_proj: vec![0.0; hidden_size * kv_dim],
            o_proj: vec![0.0; hidden_size * hidden_size],
            config: config.clone(),
            rope_tables: precompute_rope_tables_with_config(&config.rope_config()),
        }
    }

    /// Copies projection weights into the module.
    ///
    /// # Errors
    /// Returns `RuvLLMError::Model` when any slice length does not match the
    /// buffer sizes allocated in [`Self::new`]. (Previously `o_proj` was not
    /// validated, so a wrong length panicked in `copy_from_slice`.)
    pub fn load_weights(
        &mut self,
        q_proj: &[f32],
        k_proj: &[f32],
        v_proj: &[f32],
        o_proj: &[f32],
    ) -> Result<()> {
        let hidden_size = self.config.hidden_size;
        let kv_dim = self.config.num_kv_heads * self.config.head_dim;
        if q_proj.len() != hidden_size * hidden_size {
            return Err(RuvLLMError::Model(format!(
                "Invalid q_proj size: expected {}, got {}",
                hidden_size * hidden_size,
                q_proj.len()
            )));
        }
        if k_proj.len() != hidden_size * kv_dim || v_proj.len() != hidden_size * kv_dim {
            return Err(RuvLLMError::Model(format!(
                "Invalid KV proj size: expected {}, got k={}, v={}",
                hidden_size * kv_dim,
                k_proj.len(),
                v_proj.len()
            )));
        }
        // Validate o_proj as well; copy_from_slice panics on length mismatch.
        if o_proj.len() != hidden_size * hidden_size {
            return Err(RuvLLMError::Model(format!(
                "Invalid o_proj size: expected {}, got {}",
                hidden_size * hidden_size,
                o_proj.len()
            )));
        }
        self.q_proj.copy_from_slice(q_proj);
        self.k_proj.copy_from_slice(k_proj);
        self.v_proj.copy_from_slice(v_proj);
        self.o_proj.copy_from_slice(o_proj);
        Ok(())
    }

    /// Runs grouped-query attention over `hidden_states`.
    ///
    /// `hidden_states` is `[seq_len * hidden_size]` row-major; `positions`
    /// gives the absolute RoPE position of each token. When `kv_cache` is
    /// provided, new K/V rows are appended to it and attention runs over the
    /// full cached history; otherwise only the current sequence is attended.
    ///
    /// # Errors
    /// Returns `RuvLLMError::InvalidOperation` when `hidden_states` is not
    /// `seq_len * hidden_size` long.
    pub fn forward(
        &self,
        hidden_states: &[f32],
        positions: &[usize],
        kv_cache: Option<(&mut Vec<f32>, &mut Vec<f32>)>,
    ) -> Result<Vec<f32>> {
        let seq_len = positions.len();
        let hidden_size = self.config.hidden_size;
        let num_heads = self.config.num_attention_heads;
        let num_kv_heads = self.config.num_kv_heads;
        let head_dim = self.config.head_dim;
        let gqa_ratio = num_heads / num_kv_heads;
        if hidden_states.len() != seq_len * hidden_size {
            return Err(RuvLLMError::InvalidOperation(format!(
                "Invalid hidden_states shape: expected {}, got {}",
                seq_len * hidden_size,
                hidden_states.len()
            )));
        }
        // Project to Q/K/V; K and V are narrower than Q under GQA.
        // NOTE(review): the output layout assumes num_heads * head_dim ==
        // hidden_size (true for all presets) — confirm for custom configs.
        let mut query =
            self.linear_transform(hidden_states, &self.q_proj, hidden_size, hidden_size);
        let mut key = self.linear_transform(
            hidden_states,
            &self.k_proj,
            hidden_size,
            num_kv_heads * head_dim,
        );
        let value = self.linear_transform(
            hidden_states,
            &self.v_proj,
            hidden_size,
            num_kv_heads * head_dim,
        );
        // Rotary position embedding on queries and keys (values untouched).
        self.apply_rope(&mut query, positions, num_heads, head_dim);
        self.apply_rope(&mut key, positions, num_kv_heads, head_dim);
        // Append to the KV cache when present and attend over the history.
        let (key_states, value_states) = if let Some((k_cache, v_cache)) = kv_cache {
            k_cache.extend_from_slice(&key);
            v_cache.extend_from_slice(&value);
            (k_cache.as_slice(), v_cache.as_slice())
        } else {
            (key.as_slice(), value.as_slice())
        };
        let kv_len = key_states.len() / (num_kv_heads * head_dim);
        let scale = 1.0 / (head_dim as f32).sqrt();
        let mut output = vec![0.0; seq_len * hidden_size];
        for h in 0..num_heads {
            // Each group of gqa_ratio query heads shares one KV head.
            let kv_head = h / gqa_ratio;
            for t in 0..seq_len {
                let q_offset = (t * num_heads + h) * head_dim;
                let q_slice = &query[q_offset..q_offset + head_dim];
                // Gather this KV head's rows into contiguous buffers for the kernel.
                let mut k_slice = Vec::with_capacity(kv_len * head_dim);
                let mut v_slice = Vec::with_capacity(kv_len * head_dim);
                for kv_t in 0..kv_len {
                    let kv_offset = (kv_t * num_kv_heads + kv_head) * head_dim;
                    k_slice.extend_from_slice(&key_states[kv_offset..kv_offset + head_dim]);
                    v_slice.extend_from_slice(&value_states[kv_offset..kv_offset + head_dim]);
                }
                // Sliding-window attention: drop KV rows older than the window.
                let (k_slice, v_slice, _effective_kv_len) =
                    if let Some(window) = self.config.sliding_window {
                        let pos = positions[t];
                        let start = pos.saturating_sub(window);
                        if start > 0 {
                            let start_offset = start * head_dim;
                            (
                                k_slice[start_offset..].to_vec(),
                                v_slice[start_offset..].to_vec(),
                                kv_len - start,
                            )
                        } else {
                            (k_slice, v_slice, kv_len)
                        }
                    } else {
                        (k_slice, v_slice, kv_len)
                    };
                // Causal single-query attention over the gathered KV rows.
                let head_output = flash_attention_neon(q_slice, &k_slice, &v_slice, scale, true);
                let out_offset = (t * num_heads + h) * head_dim;
                output[out_offset..out_offset + head_dim].copy_from_slice(&head_output);
            }
        }
        // Final output projection back into the residual stream.
        let output = self.linear_transform(&output, &self.o_proj, hidden_size, hidden_size);
        Ok(output)
    }

    /// Applies RoPE in place, one head at a time.
    ///
    /// `x` is `[seq_len, num_heads, head_dim]` flattened.
    /// NOTE(review): the precomputed `rope_tables` are not consulted here —
    /// `apply_rope_neon` recomputes angles from `rope_theta`; confirm whether
    /// the tables were meant to be used.
    fn apply_rope(&self, x: &mut [f32], positions: &[usize], num_heads: usize, head_dim: usize) {
        let seq_len = positions.len();
        for h in 0..num_heads {
            for t in 0..seq_len {
                let offset = (t * num_heads + h) * head_dim;
                // Copy the head out since the kernel expects a contiguous slice.
                let mut head_vec = x[offset..offset + head_dim].to_vec();
                apply_rope_neon(
                    &mut head_vec,
                    &[positions[t]],
                    head_dim,
                    self.config.rope_theta,
                );
                x[offset..offset + head_dim].copy_from_slice(&head_vec);
            }
        }
    }

    /// Row-major matrix multiply: `output[b, o] = sum_i input[b, i] * weights[o, i]`.
    ///
    /// `weights` must be `out_dim x in_dim`. Dispatches to NEON on aarch64,
    /// scalar loops elsewhere.
    fn linear_transform(
        &self,
        input: &[f32],
        weights: &[f32],
        in_dim: usize,
        out_dim: usize,
    ) -> Vec<f32> {
        let batch_size = input.len() / in_dim;
        let mut output = vec![0.0; batch_size * out_dim];
        #[cfg(target_arch = "aarch64")]
        // SAFETY: `output` is sized batch_size * out_dim above; `linear_neon`
        // reads within input/weights bounds given the dims passed here.
        unsafe {
            self.linear_neon(input, weights, &mut output, batch_size, in_dim, out_dim);
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            for b in 0..batch_size {
                for o in 0..out_dim {
                    let mut sum = 0.0;
                    for i in 0..in_dim {
                        sum += input[b * in_dim + i] * weights[o * in_dim + i];
                    }
                    output[b * out_dim + o] = sum;
                }
            }
        }
        output
    }

    /// NEON inner kernel for [`Self::linear_transform`]: 4-wide FMA over
    /// `in_dim`, scalar tail for the remainder.
    ///
    /// # Safety
    /// Caller must guarantee `input.len() >= batch_size * in_dim`,
    /// `weights.len() >= out_dim * in_dim`, and
    /// `output.len() >= batch_size * out_dim`; all pointer offsets below stay
    /// within those bounds.
    #[cfg(target_arch = "aarch64")]
    unsafe fn linear_neon(
        &self,
        input: &[f32],
        weights: &[f32],
        output: &mut [f32],
        batch_size: usize,
        in_dim: usize,
        out_dim: usize,
    ) {
        let in_ptr: *const f32 = input.as_ptr();
        let w_ptr: *const f32 = weights.as_ptr();
        let out_ptr: *mut f32 = output.as_mut_ptr();
        for b in 0..batch_size {
            for o in 0..out_dim {
                let mut acc = vdupq_n_f32(0.0);
                let mut i = 0;
                while i + 4 <= in_dim {
                    let x = vld1q_f32(in_ptr.add(b * in_dim + i));
                    let w = vld1q_f32(w_ptr.add(o * in_dim + i));
                    acc = vfmaq_f32(acc, x, w);
                    i += 4;
                }
                // Horizontal sum of the 4 lanes, then the scalar tail.
                let mut sum = vaddvq_f32(acc);
                while i < in_dim {
                    sum += *in_ptr.add(b * in_dim + i) * *w_ptr.add(o * in_dim + i);
                    i += 1;
                }
                *out_ptr.add(b * out_dim + o) = sum;
            }
        }
    }
}
/// SwiGLU feed-forward block: `down(silu(gate(x)) * up(x))`.
#[derive(Debug)]
pub struct RuvLtraMLP {
    /// Gate projection, `intermediate_size x hidden_size`, row-major.
    pub gate_proj: Vec<f32>,
    /// Up projection, `intermediate_size x hidden_size`.
    pub up_proj: Vec<f32>,
    /// Down projection, `hidden_size x intermediate_size`.
    pub down_proj: Vec<f32>,
    /// Residual-stream width.
    pub hidden_size: usize,
    /// Inner projection width.
    pub intermediate_size: usize,
    /// Whether this MLP is eligible for ANE dispatch.
    pub use_ane: bool,
}
impl RuvLtraMLP {
    /// Creates a zero-initialized SwiGLU MLP sized for `config`.
    pub fn new(config: &RuvLtraConfig) -> Self {
        Self {
            gate_proj: vec![0.0; config.intermediate_size * config.hidden_size],
            up_proj: vec![0.0; config.intermediate_size * config.hidden_size],
            down_proj: vec![0.0; config.hidden_size * config.intermediate_size],
            hidden_size: config.hidden_size,
            intermediate_size: config.intermediate_size,
            // NOTE(review): `use_ane` is recorded but never consulted in
            // `forward` below — presumably a hook for a future dispatch path;
            // confirm.
            use_ane: config.ane_optimization.uses_ane() && config.is_ane_optimized(),
        }
    }

    /// Copies the three projection matrices into the module.
    ///
    /// # Errors
    /// Returns `RuvLLMError::Model` when any slice length does not match the
    /// sizes allocated in [`Self::new`].
    pub fn load_weights(
        &mut self,
        gate_proj: &[f32],
        up_proj: &[f32],
        down_proj: &[f32],
    ) -> Result<()> {
        let gate_up_size = self.intermediate_size * self.hidden_size;
        let down_size = self.hidden_size * self.intermediate_size;
        if gate_proj.len() != gate_up_size
            || up_proj.len() != gate_up_size
            || down_proj.len() != down_size
        {
            return Err(RuvLLMError::Model(format!(
                "Invalid MLP weight dimensions: expected gate/up={}, down={}; got gate={}, up={}, down={}",
                gate_up_size, down_size, gate_proj.len(), up_proj.len(), down_proj.len()
            )));
        }
        self.gate_proj.copy_from_slice(gate_proj);
        self.up_proj.copy_from_slice(up_proj);
        self.down_proj.copy_from_slice(down_proj);
        Ok(())
    }

    /// SwiGLU forward pass: `down(silu(gate(x)) * up(x))`.
    ///
    /// `hidden_states` is `[tokens * hidden_size]` row-major; the result has
    /// the same shape.
    pub fn forward(&self, hidden_states: &[f32]) -> Result<Vec<f32>> {
        let gate = self.linear(
            hidden_states,
            &self.gate_proj,
            self.hidden_size,
            self.intermediate_size,
        );
        let gate_activated = self.silu(&gate);
        let up = self.linear(
            hidden_states,
            &self.up_proj,
            self.hidden_size,
            self.intermediate_size,
        );
        // Element-wise gating of the up projection.
        let hidden: Vec<f32> = gate_activated
            .iter()
            .zip(up.iter())
            .map(|(g, u)| g * u)
            .collect();
        let output = self.linear(
            &hidden,
            &self.down_proj,
            self.intermediate_size,
            self.hidden_size,
        );
        Ok(output)
    }

    /// Row-major matmul `output[b, o] = sum_i input[b, i] * weights[o, i]`;
    /// NEON 4-wide FMA on aarch64, scalar loops elsewhere.
    fn linear(&self, input: &[f32], weights: &[f32], in_dim: usize, out_dim: usize) -> Vec<f32> {
        let batch_size = input.len() / in_dim;
        let mut output = vec![0.0; batch_size * out_dim];
        #[cfg(target_arch = "aarch64")]
        // SAFETY: all pointer offsets stay within `input` (batch_size*in_dim),
        // `weights` (assumed out_dim*in_dim by contract), and `output`
        // (batch_size*out_dim, allocated just above).
        unsafe {
            let in_ptr: *const f32 = input.as_ptr();
            let w_ptr: *const f32 = weights.as_ptr();
            let out_ptr: *mut f32 = output.as_mut_ptr();
            for b in 0..batch_size {
                for o in 0..out_dim {
                    let mut acc = vdupq_n_f32(0.0);
                    let mut i = 0;
                    while i + 4 <= in_dim {
                        let x = vld1q_f32(in_ptr.add(b * in_dim + i));
                        let w = vld1q_f32(w_ptr.add(o * in_dim + i));
                        acc = vfmaq_f32(acc, x, w);
                        i += 4;
                    }
                    // Horizontal sum of the vector lanes, then the scalar tail.
                    let mut sum = vaddvq_f32(acc);
                    while i < in_dim {
                        sum += *in_ptr.add(b * in_dim + i) * *w_ptr.add(o * in_dim + i);
                        i += 1;
                    }
                    *out_ptr.add(b * out_dim + o) = sum;
                }
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            for b in 0..batch_size {
                for o in 0..out_dim {
                    let mut sum = 0.0;
                    for i in 0..in_dim {
                        sum += input[b * in_dim + i] * weights[o * in_dim + i];
                    }
                    output[b * out_dim + o] = sum;
                }
            }
        }
        output
    }

    /// SiLU activation applied element-wise (delegates to the shared kernel).
    fn silu(&self, x: &[f32]) -> Vec<f32> {
        crate::kernels::silu_vec(x)
    }
}
/// One pre-norm transformer decoder layer: RMSNorm → attention → residual,
/// RMSNorm → MLP → residual.
#[derive(Debug)]
pub struct RuvLtraDecoderLayer {
    /// Grouped-query self-attention block.
    pub self_attn: RuvLtraAttention,
    /// SwiGLU feed-forward block.
    pub mlp: RuvLtraMLP,
    /// RMSNorm weights applied before attention.
    pub input_layernorm: Vec<f32>,
    /// RMSNorm weights applied before the MLP.
    pub post_attention_layernorm: Vec<f32>,
    /// RMSNorm epsilon.
    pub rms_norm_eps: f32,
    /// Residual-stream width.
    pub hidden_size: usize,
    /// Index of this layer within the model (0-based).
    pub layer_idx: usize,
}
impl RuvLtraDecoderLayer {
    /// Builds a decoder layer with unit RMSNorm weights and zero-initialized
    /// attention/MLP projections.
    pub fn new(config: &RuvLtraConfig, layer_idx: usize) -> Self {
        let hidden_size = config.hidden_size;
        Self {
            self_attn: RuvLtraAttention::new(config),
            mlp: RuvLtraMLP::new(config),
            input_layernorm: vec![1.0; hidden_size],
            post_attention_layernorm: vec![1.0; hidden_size],
            rms_norm_eps: config.rms_norm_eps,
            hidden_size,
            layer_idx,
        }
    }

    /// Pre-norm transformer step over `hidden_states`
    /// (`[seq_len * hidden_size]` row-major): RMSNorm → attention → residual,
    /// then RMSNorm → MLP → residual.
    pub fn forward(
        &self,
        hidden_states: &[f32],
        positions: &[usize],
        kv_cache: Option<(&mut Vec<f32>, &mut Vec<f32>)>,
    ) -> Result<Vec<f32>> {
        let seq_len = positions.len();
        // Per-token RMSNorm with the input-layernorm weights.
        let mut attn_input = hidden_states.to_vec();
        for token in 0..seq_len {
            let start = token * self.hidden_size;
            rms_norm_neon(
                &mut attn_input[start..start + self.hidden_size],
                &self.input_layernorm,
                self.rms_norm_eps,
            );
        }
        let attn_output = self.self_attn.forward(&attn_input, positions, kv_cache)?;
        // First residual connection.
        let mut residual: Vec<f32> = hidden_states
            .iter()
            .zip(attn_output.iter())
            .map(|(h, a)| h + a)
            .collect();
        // Per-token RMSNorm with the post-attention weights.
        let mut mlp_input = residual.clone();
        for token in 0..seq_len {
            let start = token * self.hidden_size;
            rms_norm_neon(
                &mut mlp_input[start..start + self.hidden_size],
                &self.post_attention_layernorm,
                self.rms_norm_eps,
            );
        }
        // Second residual connection.
        let mlp_output = self.mlp.forward(&mlp_input)?;
        for (r, m) in residual.iter_mut().zip(mlp_output.iter()) {
            *r += m;
        }
        Ok(residual)
    }
}
/// Full RuvLTRA decoder-only model: embeddings, decoder layers, final norm,
/// LM head, and an optional SONA self-learning subsystem.
#[derive(Debug)]
pub struct RuvLtraModel {
    /// Model configuration (owned copy).
    pub config: RuvLtraConfig,
    /// Token embedding table, `vocab_size x hidden_size`, row-major.
    pub embed_tokens: Vec<f32>,
    /// Decoder layers, in order.
    pub layers: Vec<RuvLtraDecoderLayer>,
    /// Final RMSNorm weights.
    pub norm: Vec<f32>,
    /// Untied LM head weights; `None` when `tie_word_embeddings` is set.
    pub lm_head: Option<Vec<f32>>,
    /// When true, logits are computed against `embed_tokens`.
    pub tie_word_embeddings: bool,
    // Optional SONA integration, shared and locked for concurrent access.
    sona: Option<Arc<RwLock<SonaIntegration>>>,
}
impl RuvLtraModel {
    /// Builds a zero-initialized model for `config`; SONA is constructed only
    /// when `config.sona_enabled` is set.
    pub fn new(config: &RuvLtraConfig) -> Result<Self> {
        let mut layers = Vec::with_capacity(config.num_hidden_layers);
        for i in 0..config.num_hidden_layers {
            layers.push(RuvLtraDecoderLayer::new(config, i));
        }
        let sona = if config.sona_enabled {
            Some(Arc::new(RwLock::new(SonaIntegration::new(
                config.sona_config.clone(),
            ))))
        } else {
            None
        };
        Ok(Self {
            config: config.clone(),
            embed_tokens: vec![0.0; config.vocab_size * config.hidden_size],
            layers,
            norm: vec![1.0; config.hidden_size],
            // Weights are tied to the embedding table by default.
            lm_head: None,
            tie_word_embeddings: true,
            sona,
        })
    }

    /// Lazily creates the SONA subsystem if it was disabled at construction.
    /// Idempotent; never fails today (Result kept for interface stability).
    pub fn enable_sona_pretraining(&mut self) -> Result<()> {
        if self.sona.is_none() {
            self.sona = Some(Arc::new(RwLock::new(SonaIntegration::new(
                self.config.sona_config.clone(),
            ))));
        }
        Ok(())
    }

    /// Shared handle to the SONA subsystem, when enabled.
    pub fn sona(&self) -> Option<&Arc<RwLock<SonaIntegration>>> {
        self.sona.as_ref()
    }

    /// Full forward pass: embedding lookup, all decoder layers, final norm,
    /// and LM-head projection to logits (`[seq_len * vocab_size]`).
    ///
    /// `kv_caches` holds one `(K, V)` pair per layer and is grown on demand,
    /// so an empty Vec may be passed for the first call.
    ///
    /// # Errors
    /// `RuvLLMError::InvalidOperation` when `input_ids` and `positions`
    /// disagree in length, when a token id is out of vocabulary bounds, or
    /// when embeddings are untied and no LM head is loaded.
    pub fn forward(
        &self,
        input_ids: &[u32],
        positions: &[usize],
        mut kv_caches: Option<&mut Vec<(Vec<f32>, Vec<f32>)>>,
    ) -> Result<Vec<f32>> {
        let seq_len = positions.len();
        if input_ids.len() != seq_len {
            return Err(RuvLLMError::InvalidOperation(format!(
                "input_ids length {} != positions length {}",
                input_ids.len(),
                seq_len
            )));
        }
        // Embedding lookup with an explicit bounds check per token.
        let mut hidden_states = Vec::with_capacity(seq_len * self.config.hidden_size);
        for &token_id in input_ids {
            let offset = (token_id as usize) * self.config.hidden_size;
            if offset + self.config.hidden_size > self.embed_tokens.len() {
                return Err(RuvLLMError::InvalidOperation(format!(
                    "Token ID {} out of vocabulary bounds",
                    token_id
                )));
            }
            hidden_states
                .extend_from_slice(&self.embed_tokens[offset..offset + self.config.hidden_size]);
        }
        // Run each decoder layer, growing the per-layer KV cache list on demand.
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let kv_cache = kv_caches.as_mut().map(|caches| {
                while caches.len() <= layer_idx {
                    caches.push((Vec::new(), Vec::new()));
                }
                let (k, v) = &mut caches[layer_idx];
                (k, v)
            });
            hidden_states = layer.forward(&hidden_states, positions, kv_cache)?;
        }
        // Final per-token RMSNorm.
        for t in 0..seq_len {
            let offset = t * self.config.hidden_size;
            let slice = &mut hidden_states[offset..offset + self.config.hidden_size];
            rms_norm_neon(slice, &self.norm, self.config.rms_norm_eps);
        }
        // Tied embeddings reuse the embedding table as the LM head.
        let lm_weights = if self.tie_word_embeddings {
            &self.embed_tokens
        } else {
            self.lm_head
                .as_ref()
                .ok_or_else(|| RuvLLMError::InvalidOperation("No LM head weights".to_string()))?
        };
        // Naive logits matmul: [seq_len, hidden] x [vocab, hidden]^T.
        let mut logits = vec![0.0; seq_len * self.config.vocab_size];
        for t in 0..seq_len {
            for v in 0..self.config.vocab_size {
                let mut sum = 0.0;
                for h in 0..self.config.hidden_size {
                    sum += hidden_states[t * self.config.hidden_size + h]
                        * lm_weights[v * self.config.hidden_size + h];
                }
                logits[t * self.config.vocab_size + v] = sum;
            }
        }
        Ok(logits)
    }

    /// Forwards a trajectory to SONA for learning; a no-op when SONA is off.
    pub fn record_trajectory(&self, trajectory: Trajectory) -> Result<()> {
        if let Some(sona) = &self.sona {
            sona.write().record_trajectory(trajectory)?;
        }
        Ok(())
    }

    /// Asks SONA for a routing recommendation; `None` when SONA is disabled.
    pub fn get_routing_recommendation(
        &self,
        query_embedding: &[f32],
    ) -> Option<crate::sona::RoutingRecommendation> {
        self.sona
            .as_ref()
            .map(|sona| sona.read().get_routing_recommendation(query_embedding))
    }

    /// Summarizes the model's architecture and resource estimates.
    pub fn info(&self) -> RuvLtraModelInfo {
        RuvLtraModelInfo {
            name: "RuvLTRA".to_string(),
            architecture: "Qwen".to_string(),
            num_params: self.config.estimate_params(),
            hidden_size: self.config.hidden_size,
            num_layers: self.config.num_hidden_layers,
            vocab_size: self.config.vocab_size,
            max_context: self.config.max_position_embeddings,
            quantization: self.config.quantization,
            ane_optimized: self.config.is_ane_optimized(),
            sona_enabled: self.sona.is_some(),
            estimated_memory_mb: self.config.estimate_memory_mb(),
        }
    }

    /// Renders messages into the Qwen ChatML format, ending with an open
    /// assistant turn ready for generation.
    pub fn apply_chat_template(messages: &[(String, String)], system: Option<&str>) -> String {
        let mut result = String::new();
        if let Some(sys) = system {
            result.push_str("<|im_start|>system\n");
            result.push_str(sys);
            result.push_str("<|im_end|>\n");
        }
        for (role, content) in messages {
            result.push_str(&format!("<|im_start|>{}\n{}<|im_end|>\n", role, content));
        }
        result.push_str("<|im_start|>assistant\n");
        result
    }
}
/// Snapshot of a model's architecture and resource estimates, as returned by
/// `RuvLtraModel::info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuvLtraModelInfo {
    /// Display name of the model family.
    pub name: String,
    /// Underlying architecture name.
    pub architecture: String,
    /// Estimated parameter count.
    pub num_params: usize,
    /// Residual-stream width.
    pub hidden_size: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Maximum context length.
    pub max_context: usize,
    /// Configured quantization scheme.
    pub quantization: QuantizationType,
    /// Whether the ANE matmul path applies.
    pub ane_optimized: bool,
    /// Whether SONA is active on this model instance.
    pub sona_enabled: bool,
    /// Estimated weight-storage footprint in MiB.
    pub estimated_memory_mb: f32,
}
/// Runtime dispatcher that decides per-op whether to run on the ANE or GPU,
/// and keeps relaxed atomic counters of how each was used.
#[derive(Debug)]
pub struct AneDispatcher {
    // Dispatch strategy this dispatcher enforces.
    mode: AneOptimization,
    // Workload size (batch * seq) below which Adaptive mode always picks the ANE.
    adaptive_threshold: usize,
    // Count of ops routed to the ANE (statistics only).
    ane_ops: std::sync::atomic::AtomicU64,
    // Count of ops routed to the GPU (statistics only).
    gpu_ops: std::sync::atomic::AtomicU64,
}
impl AneDispatcher {
    /// Builds a dispatcher with zeroed counters and the default adaptive
    /// workload threshold of 512 elements.
    pub fn new(mode: AneOptimization) -> Self {
        Self {
            mode,
            adaptive_threshold: 512,
            ane_ops: std::sync::atomic::AtomicU64::new(0),
            gpu_ops: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Decides whether `op_type` should run on the ANE for a workload of
    /// `batch_size * seq_len` elements, according to the configured mode.
    pub fn should_use_ane(&self, op_type: &str, batch_size: usize, seq_len: usize) -> bool {
        match self.mode {
            AneOptimization::Disabled | AneOptimization::GpuOnly => false,
            AneOptimization::AneOnly => true,
            AneOptimization::HybridDispatch => {
                matches!(op_type, "mlp" | "linear" | "matmul" | "activation")
            }
            AneOptimization::Adaptive => {
                // Small workloads always go to the ANE; larger ones only for
                // the op kinds listed here.
                batch_size * seq_len < self.adaptive_threshold
                    || matches!(op_type, "mlp" | "linear")
            }
        }
    }

    /// Bumps the ANE counter (relaxed ordering; statistics only).
    pub fn record_ane_op(&self) {
        self.ane_ops
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Bumps the GPU counter (relaxed ordering; statistics only).
    pub fn record_gpu_op(&self) {
        self.gpu_ops
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    }

    /// Returns the `(ane_ops, gpu_ops)` counts observed so far.
    pub fn stats(&self) -> (u64, u64) {
        let read = |counter: &std::sync::atomic::AtomicU64| {
            counter.load(std::sync::atomic::Ordering::Relaxed)
        };
        (read(&self.ane_ops), read(&self.gpu_ops))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The Qwen2-0.5B preset has the published architecture dimensions.
    #[test]
    fn test_ruvltra_config_qwen() {
        let config = RuvLtraConfig::qwen_0_5b();
        assert_eq!(config.hidden_size, 896);
        assert_eq!(config.num_hidden_layers, 24);
        assert_eq!(config.num_attention_heads, 14);
        assert_eq!(config.intermediate_size, 4864);
        assert_eq!(config.vocab_size, 151936);
        assert!(config.is_ane_optimized());
    }

    /// The tiny preset sits exactly at the 768 ANE width threshold.
    #[test]
    fn test_ruvltra_config_tiny() {
        let config = RuvLtraConfig::tiny();
        assert_eq!(config.hidden_size, 768);
        assert!(config.is_ane_optimized());
    }

    /// HybridDispatch reports both ANE and GPU usage.
    #[test]
    fn test_ane_optimization() {
        let config = RuvLtraConfig::qwen_0_5b();
        assert!(config.ane_optimization.uses_ane());
        assert!(config.ane_optimization.uses_gpu());
    }

    /// Fp16 storage should be ~4x Int4 storage for the same parameter count.
    #[test]
    fn test_quantization_memory() {
        let config = RuvLtraConfig::qwen_0_5b();
        let params = config.estimate_params();
        let memory_int4 = QuantizationType::Int4.estimate_memory_mb(params);
        let memory_fp16 = QuantizationType::Fp16.estimate_memory_mb(params);
        assert!(memory_fp16 > memory_int4 * 3.5);
        assert!(memory_fp16 < memory_int4 * 4.5);
    }

    /// Model construction allocates the expected layers and embedding table.
    #[test]
    fn test_ruvltra_model_creation() {
        let config = RuvLtraConfig::tiny();
        let model = RuvLtraModel::new(&config).unwrap();
        assert_eq!(model.layers.len(), 4);
        assert_eq!(
            model.embed_tokens.len(),
            config.vocab_size * config.hidden_size
        );
    }

    /// 14 query heads over 2 KV heads yields a GQA group size of 7.
    #[test]
    fn test_gqa_ratio() {
        let config = RuvLtraConfig::qwen_0_5b();
        assert_eq!(config.gqa_ratio(), 7);
    }

    /// HybridDispatch routes matmul-like ops to the ANE, attention elsewhere.
    #[test]
    fn test_ane_dispatcher() {
        let dispatcher = AneDispatcher::new(AneOptimization::HybridDispatch);
        assert!(dispatcher.should_use_ane("mlp", 1, 128));
        assert!(dispatcher.should_use_ane("linear", 1, 128));
        assert!(!dispatcher.should_use_ane("attention", 1, 128));
    }

    /// ChatML rendering includes all turns and ends with an open assistant turn.
    #[test]
    fn test_chat_template() {
        let messages = vec![
            ("user".to_string(), "Hello!".to_string()),
            ("assistant".to_string(), "Hi there!".to_string()),
            ("user".to_string(), "How are you?".to_string()),
        ];
        let template =
            RuvLtraModel::apply_chat_template(&messages, Some("You are a helpful assistant."));
        assert!(template.contains("<|im_start|>system"));
        assert!(template.contains("<|im_start|>user"));
        assert!(template.contains("<|im_start|>assistant"));
        assert!(template.contains("<|im_end|>"));
        assert!(template.ends_with("<|im_start|>assistant\n"));
    }

    /// `info()` reflects the configuration and SONA state.
    #[test]
    fn test_model_info() {
        let config = RuvLtraConfig::qwen_0_5b();
        let model = RuvLtraModel::new(&config).unwrap();
        let info = model.info();
        assert_eq!(info.name, "RuvLTRA");
        assert_eq!(info.architecture, "Qwen");
        assert_eq!(info.hidden_size, 896);
        assert!(info.ane_optimized);
        assert!(info.sona_enabled);
    }
}