use std::num::NonZeroU32;
use anyhow::{Context, Result};
use crate::tensor::Tensor;
/// Dense feature block produced by the extractors in this module.
///
/// `data` is laid out as `[n_layers, feature_dim, n_timesteps]` in row-major
/// order, i.e. the time axis is contiguous (see the packing loops in
/// `extract_llama_features` / `resample_features`).
#[derive(Debug, Clone)]
pub struct ExtractedFeatures {
    // Flattened feature tensor with shape [n_layers, feature_dim, n_timesteps].
    pub data: Tensor,
    // Number of layer groups (first tensor dimension).
    pub n_layers: usize,
    // Embedding dimensionality (second tensor dimension).
    pub feature_dim: usize,
    // Number of timesteps (third tensor dimension).
    pub n_timesteps: usize,
}
/// Configuration for LLaMA feature extraction.
#[derive(Debug, Clone)]
pub struct LlamaFeatureConfig {
    // Filesystem path to the GGUF model file loaded via `LlamaModel::load_from_file`.
    pub model_path: String,
    // Fractional positions within the layer stack (0.0 = first layer,
    // 1.0 = last); mapped to concrete indices by `compute_layer_indices`.
    pub layer_positions: Vec<f64>,
    // NOTE(review): not read by any function visible in this file — the
    // actual layer count comes from `model.n_layer()`. Presumably a hint
    // for callers; confirm before removing.
    pub n_layers: usize,
    // Minimum context size; the extractors grow it to fit the prompt
    // (`n_ctx.max(n_tokens + 16)`).
    pub n_ctx: u32,
    // Output sampling frequency in Hz, used by `extract_llama_features_timed`
    // to size the time grid (`total_duration * frequency` timesteps).
    pub frequency: f64,
}
impl Default for LlamaFeatureConfig {
fn default() -> Self {
Self {
model_path: String::new(),
layer_positions: vec![0.5, 0.75, 1.0],
n_layers: 28, n_ctx: 2048,
frequency: 2.0,
}
}
}
/// Map fractional layer positions (e.g. `0.5` = halfway through the stack)
/// to concrete layer indices in `[0, n_total_layers - 1]`.
///
/// Each position `f` is scaled by `n_total_layers - 1` and floored, then
/// clamped to the last valid index so any `f >= 1.0` selects the final layer.
///
/// Returns an empty vector when `n_total_layers == 0`; the previous
/// implementation underflowed on `n_total_layers - 1` in that case
/// (a panic in debug builds).
pub fn compute_layer_indices(layer_positions: &[f64], n_total_layers: usize) -> Vec<usize> {
    if n_total_layers == 0 {
        return Vec::new();
    }
    let last = n_total_layers - 1;
    layer_positions
        .iter()
        .map(|&f| {
            // `as usize` saturates at 0 for negative floats, so f < 0.0 maps
            // to the first layer; `.min(last)` clamps the high end.
            let idx = (f * last as f64).floor() as usize;
            idx.min(last)
        })
        .collect()
}
/// Extract LLaMA hidden-state features for a whole `prompt` in one pass.
///
/// Loads the model at `config.model_path`, tokenizes the prompt (with BOS),
/// decodes it in a single batch with embeddings enabled, and reads back one
/// embedding vector per token via `embeddings_ith`. The result is packed as
/// a `[n_layer_groups, hidden_dim, n_tokens]` tensor (time axis contiguous).
///
/// NOTE(review): `layer_indices` is computed from `config.layer_positions`
/// and logged, but the packing loop below writes the SAME per-token embedding
/// into every layer group — per-layer activations are never fetched.
/// Presumably `embeddings_ith` only exposes one (final) embedding per token;
/// confirm whether the layer-group duplication is intentional.
///
/// # Errors
/// Fails if the model cannot be loaded, tokenization fails, the context
/// cannot be created, `llama_decode` fails, or an embedding cannot be read.
pub fn extract_llama_features(
    config: &LlamaFeatureConfig,
    prompt: &str,
    verbose: bool,
) -> Result<ExtractedFeatures> {
    use llama_cpp_4::context::params::LlamaContextParams;
    use llama_cpp_4::llama_backend::LlamaBackend;
    use llama_cpp_4::llama_batch::LlamaBatch;
    use llama_cpp_4::model::params::LlamaModelParams;
    use llama_cpp_4::model::{AddBos, LlamaModel};
    let backend = LlamaBackend::init()?;
    // Offload all layers to the GPU when any GPU backend feature is enabled;
    // 1000 is simply "more layers than any model has".
    let model_params = {
        #[cfg(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal"))]
        { LlamaModelParams::default().with_n_gpu_layers(1000) }
        #[cfg(not(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal")))]
        { LlamaModelParams::default() }
    };
    let model = LlamaModel::load_from_file(&backend, &config.model_path, &model_params)
        .with_context(|| format!("unable to load LLaMA model: {}", config.model_path))?;
    let n_total_layers = model.n_layer() as usize;
    let hidden_dim = model.n_embd() as usize;
    let layer_indices = compute_layer_indices(&config.layer_positions, n_total_layers);
    let n_layer_groups = layer_indices.len();
    if verbose {
        eprintln!("LLaMA: {} layers, hidden_dim={}", n_total_layers, hidden_dim);
        eprintln!("Extracting layer groups: {:?} (from positions {:?})",
            layer_indices, config.layer_positions);
    }
    let tokens = model
        .str_to_token(prompt, AddBos::Always)
        .with_context(|| "failed to tokenize prompt")?;
    let n_tokens = tokens.len();
    if verbose {
        eprintln!("Tokens: {}", n_tokens);
    }
    // Grow the context beyond config.n_ctx if the prompt needs it (+16 slack),
    // and enable embedding output on the context.
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(Some(NonZeroU32::new(config.n_ctx.max(n_tokens as u32 + 16)).unwrap()))
        .with_embeddings(true);
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .with_context(|| "unable to create LLaMA context")?;
    let mut batch = LlamaBatch::new(n_tokens + 16, 1);
    // The `true` flag requests output for every position — assumed necessary
    // so `embeddings_ith` is valid for each token; confirm against the
    // llama_cpp_4 docs.
    for (i, token) in tokens.iter().enumerate() {
        batch.add(*token, i as i32, &[0], true)?;
    }
    ctx.decode(&mut batch)
        .with_context(|| "llama_decode() failed")?;
    let mut token_embeddings: Vec<Vec<f32>> = Vec::with_capacity(n_tokens);
    for i in 0..n_tokens {
        let emb = ctx.embeddings_ith(i as i32)
            .with_context(|| format!("failed to get embedding for token {}", i))?;
        token_embeddings.push(emb.to_vec());
    }
    if verbose {
        eprintln!("Got {} token embeddings of dim {}", token_embeddings.len(), hidden_dim);
    }
    // One timestep per token; pack [layer][dim][time] row-major (time contiguous).
    let n_timesteps = n_tokens;
    let total = n_layer_groups * hidden_dim * n_timesteps;
    let mut data = vec![0.0f32; total];
    for li in 0..n_layer_groups {
        for ti in 0..n_timesteps {
            let emb = &token_embeddings[ti];
            for di in 0..hidden_dim {
                // Same embedding duplicated into every layer group (see note above).
                data[li * hidden_dim * n_timesteps + di * n_timesteps + ti] = emb[di];
            }
        }
    }
    Ok(ExtractedFeatures {
        data: Tensor::from_vec(data, vec![n_layer_groups, hidden_dim, n_timesteps]),
        n_layers: n_layer_groups,
        feature_dim: hidden_dim,
        n_timesteps,
    })
}
/// Extract LLaMA features aligned to word start times.
///
/// `words` is a list of `(word, start_time_seconds)` pairs. The words are
/// joined with spaces, tokenized (with BOS) and decoded once; each word is
/// then given an embedding by averaging a proportional slice of the token
/// embeddings (token 0 — the BOS — is skipped). Finally the word embeddings
/// are sampled onto a uniform grid of `ceil(total_duration * config.frequency)`
/// timesteps: each timestep takes the embedding of the last word whose start
/// time is <= that timestep's time (falling back to the first word before any
/// word has started). Output shape is `[n_layer_groups, hidden_dim, n_timesteps]`.
///
/// NOTE(review): as in `extract_llama_features`, the same embedding is
/// written into every layer group; per-layer activations are not fetched.
///
/// # Errors
/// Fails on model load, tokenization, context creation, decode, or
/// embedding-readback failure.
pub fn extract_llama_features_timed(
    config: &LlamaFeatureConfig,
    words: &[(String, f64)],
    total_duration: f64,
    verbose: bool,
) -> Result<ExtractedFeatures> {
    use llama_cpp_4::context::params::LlamaContextParams;
    use llama_cpp_4::llama_backend::LlamaBackend;
    use llama_cpp_4::llama_batch::LlamaBatch;
    use llama_cpp_4::model::params::LlamaModelParams;
    use llama_cpp_4::model::{AddBos, LlamaModel};
    let backend = LlamaBackend::init()?;
    // Offload all layers to the GPU when any GPU backend feature is enabled.
    let model_params = {
        #[cfg(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal"))]
        { LlamaModelParams::default().with_n_gpu_layers(1000) }
        #[cfg(not(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal")))]
        { LlamaModelParams::default() }
    };
    let model = LlamaModel::load_from_file(&backend, &config.model_path, &model_params)
        .with_context(|| format!("unable to load LLaMA model: {}", config.model_path))?;
    let n_total_layers = model.n_layer() as usize;
    let hidden_dim = model.n_embd() as usize;
    let layer_indices = compute_layer_indices(&config.layer_positions, n_total_layers);
    let n_layer_groups = layer_indices.len();
    // Reconstruct the prompt; word timings are kept separately for alignment.
    let full_text: String = words.iter().map(|(w, _)| w.as_str()).collect::<Vec<_>>().join(" ");
    let tokens = model
        .str_to_token(&full_text, AddBos::Always)
        .with_context(|| "failed to tokenize")?;
    let n_tokens = tokens.len();
    if verbose {
        eprintln!("LLaMA timed: {} words, {} tokens, {} layers, dim={}",
            words.len(), n_tokens, n_total_layers, hidden_dim);
    }
    // Grow the context beyond config.n_ctx if the text needs it (+16 slack).
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(Some(NonZeroU32::new(config.n_ctx.max(n_tokens as u32 + 16)).unwrap()))
        .with_embeddings(true);
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .with_context(|| "unable to create LLaMA context")?;
    let mut batch = LlamaBatch::new(n_tokens + 16, 1);
    for (i, token) in tokens.iter().enumerate() {
        batch.add(*token, i as i32, &[0], true)?;
    }
    ctx.decode(&mut batch)
        .with_context(|| "llama_decode() failed")?;
    let mut token_embeddings: Vec<Vec<f32>> = Vec::with_capacity(n_tokens);
    for i in 0..n_tokens {
        let emb = ctx.embeddings_ith(i as i32)
            .with_context(|| format!("failed to get embedding for token {}", i))?;
        token_embeddings.push(emb.to_vec());
    }
    // Approximate word→token alignment: assume the n_tokens - 1 non-BOS
    // tokens are spread evenly across the words (no exact per-word offsets).
    let n_words = words.len();
    let tokens_per_word = if n_words > 0 { (n_tokens - 1).max(1) as f64 / n_words as f64 } else { 1.0 };
    let mut word_embeddings: Vec<(Vec<f32>, f64)> = Vec::with_capacity(n_words);
    for (wi, (_, start_time)) in words.iter().enumerate() {
        // +1 skips the BOS token at index 0; clamps keep the slice in-bounds
        // and (where possible) non-empty.
        let tok_start = 1 + (wi as f64 * tokens_per_word).floor() as usize;
        let tok_end = (1 + ((wi + 1) as f64 * tokens_per_word).floor() as usize).min(n_tokens);
        let tok_end = tok_end.max(tok_start + 1).min(n_tokens);
        // Mean of this word's token embeddings (zeros if the slice is empty).
        let mut avg = vec![0.0f32; hidden_dim];
        let count = (tok_end - tok_start) as f32;
        for ti in tok_start..tok_end {
            for di in 0..hidden_dim {
                avg[di] += token_embeddings[ti][di];
            }
        }
        if count > 0.0 {
            for v in avg.iter_mut() {
                *v /= count;
            }
        }
        word_embeddings.push((avg, *start_time));
    }
    // Sample word embeddings onto a uniform time grid at config.frequency Hz.
    let n_timesteps = (total_duration * config.frequency).ceil() as usize;
    let dt = 1.0 / config.frequency;
    let total = n_layer_groups * hidden_dim * n_timesteps;
    let mut data = vec![0.0f32; total];
    for ti in 0..n_timesteps {
        let t = ti as f64 * dt;
        // Last word already started at time t; first word before any start;
        // leave the timestep as zeros if there are no words at all.
        let emb = if let Some(pos) = word_embeddings.iter().rposition(|(_, st)| *st <= t) {
            &word_embeddings[pos].0
        } else if !word_embeddings.is_empty() {
            &word_embeddings[0].0
        } else {
            continue;
        };
        for li in 0..n_layer_groups {
            for di in 0..hidden_dim {
                // Same embedding duplicated into every layer group (see note above).
                data[li * hidden_dim * n_timesteps + di * n_timesteps + ti] = emb[di];
            }
        }
    }
    Ok(ExtractedFeatures {
        data: Tensor::from_vec(data, vec![n_layer_groups, hidden_dim, n_timesteps]),
        n_layers: n_layer_groups,
        feature_dim: hidden_dim,
        n_timesteps,
    })
}
/// Build an all-zero feature block with the given dimensions
/// (`[n_layers, feature_dim, n_timesteps]`).
pub fn zero_features(n_layers: usize, feature_dim: usize, n_timesteps: usize) -> ExtractedFeatures {
    let data = Tensor::zeros(&[n_layers, feature_dim, n_timesteps]);
    ExtractedFeatures { data, n_layers, feature_dim, n_timesteps }
}
/// Nearest-neighbor resampling of the time axis to `n_timesteps_out` steps.
///
/// Layout is `[n_layers, feature_dim, n_timesteps]` row-major (time axis
/// contiguous). Each output step `to` takes the input step at
/// `floor(to * n_in / n_out)`, clamped to the last input step — so
/// upsampling repeats samples and downsampling drops them; no interpolation.
/// Returns a clone when the length already matches, and an all-zero block
/// when the input has no timesteps (the previous code underflowed on
/// `n_in - 1` in that case).
pub fn resample_features(features: &ExtractedFeatures, n_timesteps_out: usize) -> ExtractedFeatures {
    let n_layers = features.n_layers;
    let feature_dim = features.feature_dim;
    let n_in = features.n_timesteps;
    if n_in == n_timesteps_out {
        return features.clone();
    }
    let mut data = vec![0.0f32; n_layers * feature_dim * n_timesteps_out];
    if n_in == 0 {
        // Nothing to sample from: return zeros instead of panicking below.
        return ExtractedFeatures {
            data: Tensor::from_vec(data, vec![n_layers, feature_dim, n_timesteps_out]),
            n_layers,
            feature_dim,
            n_timesteps: n_timesteps_out,
        };
    }
    let scale = n_in as f64 / n_timesteps_out as f64;
    for li in 0..n_layers {
        for di in 0..feature_dim {
            // Hoist the per-row base offsets out of the inner time loop.
            let in_base = li * feature_dim * n_in + di * n_in;
            let out_base = li * feature_dim * n_timesteps_out + di * n_timesteps_out;
            for to in 0..n_timesteps_out {
                let ti = ((to as f64 * scale).floor() as usize).min(n_in - 1);
                data[out_base + to] = features.data.data[in_base + ti];
            }
        }
    }
    ExtractedFeatures {
        data: Tensor::from_vec(data, vec![n_layers, feature_dim, n_timesteps_out]),
        n_layers,
        feature_dim,
        n_timesteps: n_timesteps_out,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Positions scale to floor(f * (n - 1)), clamped to the last layer.
    #[test]
    fn test_compute_layer_indices() {
        assert_eq!(compute_layer_indices(&[0.5, 0.75, 1.0], 28), [13, 20, 27]);
    }

    #[test]
    fn test_compute_layer_indices_small() {
        assert_eq!(compute_layer_indices(&[0.5, 0.75, 1.0], 4), [1, 2, 3]);
    }

    #[test]
    fn test_zero_features() {
        let f = zero_features(3, 1024, 100);
        assert_eq!(f.data.shape, [3, 1024, 100]);
        assert!(f.data.data.iter().all(|&v| v == 0.0));
    }

    // Matching lengths must short-circuit to a clone.
    #[test]
    fn test_resample_features_identity() {
        let f = zero_features(2, 4, 10);
        let r = resample_features(&f, 10);
        assert_eq!(r.n_timesteps, 10);
    }

    // Doubling the time axis repeats each sample twice (nearest neighbor).
    #[test]
    fn test_resample_features_upsample() {
        let mut f = zero_features(1, 2, 4);
        f.data.data = (1..=8).map(|v| v as f32).collect();
        let r = resample_features(&f, 8);
        assert_eq!(r.n_timesteps, 8);
        assert_eq!(r.data.shape, [1, 2, 8]);
        assert_eq!(&r.data.data[..3], &[1.0, 1.0, 2.0]);
    }
}