//! Per-layer hidden-state feature extraction from a LLaMA model via the
//! `llama_cpp_4` tensor-capture hook, plus simple resampling utilities.

use std::num::NonZeroU32;
use anyhow::{Context, Result};
use crate::tensor::Tensor;
use llama_cpp_4::context::tensor_capture::TensorCapture;
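
/// Hidden-state features extracted from a LLaMA forward pass, stored as a
/// `[n_layers, feature_dim, n_timesteps]` tensor with time fastest-varying.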
#[derive(Debug, Clone)]
pub struct ExtractedFeatures {
pub data: Tensor,
pub n_layers: usize,
pub feature_dim: usize,
pub n_timesteps: usize,
}
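
/// Configuration for LLaMA feature extraction. `layer_positions` are
/// fractional depths in `(0.0, 1.0]`; `frequency` is the output sample rate
/// in Hz used by [`extract_llama_features_timed`].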
#[derive(Debug, Clone)]
pub struct LlamaFeatureConfig {
pub model_path: String,
pub layer_positions: Vec<f64>,
pub n_layers: usize,
pub n_ctx: u32,
pub frequency: f64,
}
impl Default for LlamaFeatureConfig {
fn default() -> Self {
Self {
model_path: String::new(),
layer_positions: vec![0.5, 0.75, 1.0],
            n_layers: 28,
            n_ctx: 2048,
frequency: 2.0,
}
}
}
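
/// Maps fractional depth positions to concrete layer indices: each position
/// `f` becomes `floor(f * (n_total_layers - 1))`, clamped to the last layer.
/// Assumes `n_total_layers >= 1`.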
pub fn compute_layer_indices(layer_positions: &[f64], n_total_layers: usize) -> Vec<usize> {
layer_positions
.iter()
.map(|&f| {
let idx = (f * (n_total_layers as f64 - 1.0)).floor() as usize;
idx.min(n_total_layers - 1)
})
.collect()
}
fn fill_layer_data(
data: &mut [f32],
layer_indices: &[usize],
capture: &TensorCapture,
fallback_embs: Option<&[Vec<f32>]>,
hidden_dim: usize,
n_timesteps: usize,
) {
    for (li, &layer_idx) in layer_indices.iter().enumerate() {
        if let Some(ct) = capture.get_layer(layer_idx) {
            // Clamp in case the capture saw fewer tokens or a narrower
            // embedding than the caller expects.
            let tokens_to_copy = ct.n_tokens().min(n_timesteps);
            let dims_to_copy = ct.n_embd().min(hidden_dim);
            for ti in 0..tokens_to_copy {
                // Safe: ti < tokens_to_copy <= ct.n_tokens().
                let emb = ct.token_embedding(ti).unwrap();
                for di in 0..dims_to_copy {
                    data[li * hidden_dim * n_timesteps + di * n_timesteps + ti] = emb[di];
                }
            }
        } else if let Some(embs) = fallback_embs {
            // Layer missing from the capture: substitute final-layer embeddings.
            for ti in 0..n_timesteps {
                for di in 0..hidden_dim {
                    data[li * hidden_dim * n_timesteps + di * n_timesteps + ti] = embs[ti][di];
                }
            }
        }
    }
}
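
/// Runs one forward pass over `prompt` and returns per-token features for
/// each configured layer, so `n_timesteps` equals the token count.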
pub fn extract_llama_features(
config: &LlamaFeatureConfig,
prompt: &str,
verbose: bool,
) -> Result<ExtractedFeatures> {
use llama_cpp_4::context::params::LlamaContextParams;
use llama_cpp_4::llama_backend::LlamaBackend;
use llama_cpp_4::llama_batch::LlamaBatch;
use llama_cpp_4::model::params::LlamaModelParams;
use llama_cpp_4::model::{AddBos, LlamaModel};
let backend = LlamaBackend::init()?;
    // Offload all layers to the GPU when a GPU backend feature is enabled.
    let model_params = {
#[cfg(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal"))]
{ LlamaModelParams::default().with_n_gpu_layers(1000) }
#[cfg(not(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal")))]
{ LlamaModelParams::default() }
};
let model = LlamaModel::load_from_file(&backend, &config.model_path, &model_params)
.with_context(|| format!("unable to load LLaMA model: {}", config.model_path))?;
let n_total_layers = model.n_layer() as usize;
let hidden_dim = model.n_embd() as usize;
let layer_indices = compute_layer_indices(&config.layer_positions, n_total_layers);
let n_layer_groups = layer_indices.len();
if verbose {
eprintln!("LLaMA: {} layers, hidden_dim={}", n_total_layers, hidden_dim);
eprintln!("Extracting layers: {:?} (from positions {:?})",
layer_indices, config.layer_positions);
}
let tokens = model
.str_to_token(prompt, AddBos::Always)
.with_context(|| "failed to tokenize prompt")?;
let n_tokens = tokens.len();
if verbose {
eprintln!("Tokens: {}", n_tokens);
}
    let mut capture = TensorCapture::for_layers(&layer_indices);
    let ctx_params = LlamaContextParams::default()
        // The context must fit the prompt; the value is always > 0, so the
        // NonZeroU32 construction cannot fail.
        .with_n_ctx(Some(NonZeroU32::new(config.n_ctx.max(n_tokens as u32 + 16)).unwrap()))
        .with_embeddings(true)
        .with_tensor_capture(&mut capture);
let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create LLaMA context")?;
    let mut batch = LlamaBatch::new(n_tokens + 16, 1);
    for (i, token) in tokens.iter().enumerate() {
        // Request output for every token so per-token embeddings are available.
        batch.add(*token, i as i32, &[0], true)?;
    }
ctx.decode(&mut batch)
.with_context(|| "llama_decode() failed")?;
let captured = capture.captured_layers();
if verbose {
eprintln!("Captured {}/{} layers: {:?}", captured.len(), n_layer_groups, captured);
}
    // One feature timestep per prompt token.
    let n_timesteps = n_tokens;
let total = n_layer_groups * hidden_dim * n_timesteps;
let mut data = vec![0.0f32; total];
    // If any requested layer was not captured, fetch the context's final-layer
    // embeddings once so `fill_layer_data` can fall back to them.
    let fallback_embs: Option<Vec<Vec<f32>>> = if captured.len() < n_layer_groups {
if verbose {
eprintln!("WARNING: using final-layer fallback for {}/{} missing layers",
n_layer_groups - captured.len(), n_layer_groups);
}
let mut embs = Vec::with_capacity(n_tokens);
for i in 0..n_tokens {
let emb = ctx.embeddings_ith(i as i32)
.with_context(|| format!("failed to get embedding for token {}", i))?;
embs.push(emb.to_vec());
}
Some(embs)
} else {
None
};
fill_layer_data(
&mut data,
&layer_indices,
&capture,
fallback_embs.as_deref(),
hidden_dim,
n_timesteps,
);
Ok(ExtractedFeatures {
data: Tensor::from_vec(data, vec![n_layer_groups, hidden_dim, n_timesteps]),
n_layers: n_layer_groups,
feature_dim: hidden_dim,
n_timesteps,
})
}
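
/// Like [`extract_llama_features`], but aligns features to a uniform time
/// grid: tokens are apportioned to `words` by a proportional heuristic, each
/// word's token embeddings are averaged, and each output timestep holds the
/// most recent word embedding (zero-order hold) at `config.frequency` Hz.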
pub fn extract_llama_features_timed(
config: &LlamaFeatureConfig,
words: &[(String, f64)],
total_duration: f64,
verbose: bool,
) -> Result<ExtractedFeatures> {
use llama_cpp_4::context::params::LlamaContextParams;
use llama_cpp_4::llama_backend::LlamaBackend;
use llama_cpp_4::llama_batch::LlamaBatch;
use llama_cpp_4::model::params::LlamaModelParams;
use llama_cpp_4::model::{AddBos, LlamaModel};
let backend = LlamaBackend::init()?;
    // Offload all layers to the GPU when a GPU backend feature is enabled.
    let model_params = {
#[cfg(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal"))]
{ LlamaModelParams::default().with_n_gpu_layers(1000) }
#[cfg(not(any(feature = "llama-cuda", feature = "llama-vulkan", feature = "llama-metal")))]
{ LlamaModelParams::default() }
};
let model = LlamaModel::load_from_file(&backend, &config.model_path, &model_params)
.with_context(|| format!("unable to load LLaMA model: {}", config.model_path))?;
let n_total_layers = model.n_layer() as usize;
let hidden_dim = model.n_embd() as usize;
let layer_indices = compute_layer_indices(&config.layer_positions, n_total_layers);
let n_layer_groups = layer_indices.len();
let full_text: String = words.iter().map(|(w, _)| w.as_str()).collect::<Vec<_>>().join(" ");
let tokens = model
.str_to_token(&full_text, AddBos::Always)
.with_context(|| "failed to tokenize")?;
let n_tokens = tokens.len();
if verbose {
eprintln!("LLaMA timed: {} words, {} tokens, {} layers, dim={}",
words.len(), n_tokens, n_total_layers, hidden_dim);
eprintln!("Extracting layers: {:?}", layer_indices);
}
let mut capture = TensorCapture::for_layers(&layer_indices);
let ctx_params = LlamaContextParams::default()
.with_n_ctx(Some(NonZeroU32::new(config.n_ctx.max(n_tokens as u32 + 16)).unwrap()))
.with_embeddings(true)
.with_tensor_capture(&mut capture);
let mut ctx = model
.new_context(&backend, ctx_params)
.with_context(|| "unable to create LLaMA context")?;
let mut batch = LlamaBatch::new(n_tokens + 16, 1);
for (i, token) in tokens.iter().enumerate() {
batch.add(*token, i as i32, &[0], true)?;
}
ctx.decode(&mut batch)
.with_context(|| "llama_decode() failed")?;
let captured = capture.captured_layers();
let all_captured = captured.len() == n_layer_groups;
if verbose {
eprintln!("Captured {}/{} layers: {:?}", captured.len(), n_layer_groups, captured);
}
    // Final-layer embeddings as a fallback for any layer the capture missed.
    let final_embeddings: Option<Vec<Vec<f32>>> = if !all_captured {
let mut embs = Vec::with_capacity(n_tokens);
for i in 0..n_tokens {
let emb = ctx.embeddings_ith(i as i32)
.with_context(|| format!("failed to get embedding for token {}", i))?;
embs.push(emb.to_vec());
}
Some(embs)
} else {
None
};
    let n_words = words.len();
    // Apportion the non-BOS tokens evenly across words (token 0 is BOS).
    let tokens_per_word = if n_words > 0 {
        (n_tokens - 1).max(1) as f64 / n_words as f64
    } else {
        1.0
    };
let mut layer_word_embeddings: Vec<Vec<(Vec<f32>, f64)>> = Vec::with_capacity(n_layer_groups);
for &layer_idx in &layer_indices {
let ct = capture.get_layer(layer_idx);
let mut word_embs: Vec<(Vec<f32>, f64)> = Vec::with_capacity(n_words);
        for (wi, (_, start_time)) in words.iter().enumerate() {
            // Proportional token span for word `wi`, skipping BOS and
            // guaranteed to cover at least one token.
            let tok_start = 1 + (wi as f64 * tokens_per_word).floor() as usize;
            let tok_end = (1 + ((wi + 1) as f64 * tokens_per_word).floor() as usize).min(n_tokens);
            let tok_end = tok_end.max(tok_start + 1).min(n_tokens);
let mut avg = vec![0.0f32; hidden_dim];
let count = (tok_end - tok_start) as f32;
for ti in tok_start..tok_end {
                if let Some(ct) = ct {
                    // `saturating_sub` guards against an empty capture.
                    let ti_clamped = ti.min(ct.n_tokens().saturating_sub(1));
                    if let Some(emb) = ct.token_embedding(ti_clamped) {
for di in 0..hidden_dim.min(ct.n_embd()) {
avg[di] += emb[di];
}
}
} else if let Some(ref embs) = final_embeddings {
for di in 0..hidden_dim {
avg[di] += embs[ti][di];
}
}
}
if count > 0.0 {
for v in avg.iter_mut() {
*v /= count;
}
}
word_embs.push((avg, *start_time));
}
layer_word_embeddings.push(word_embs);
}
    // Uniform output grid: `config.frequency` samples per second of audio.
    let n_timesteps = (total_duration * config.frequency).ceil() as usize;
    let dt = 1.0 / config.frequency;
let total = n_layer_groups * hidden_dim * n_timesteps;
let mut data = vec![0.0f32; total];
for ti in 0..n_timesteps {
let t = ti as f64 * dt;
for li in 0..n_layer_groups {
let word_embs = &layer_word_embeddings[li];
            // Zero-order hold: the most recent word starting at or before `t`,
            // or the first word if `t` precedes all of them.
            let emb = if let Some(pos) = word_embs.iter().rposition(|(_, st)| *st <= t) {
&word_embs[pos].0
} else if !word_embs.is_empty() {
&word_embs[0].0
} else {
continue;
};
for di in 0..hidden_dim {
data[li * hidden_dim * n_timesteps + di * n_timesteps + ti] = emb[di];
}
}
}
Ok(ExtractedFeatures {
data: Tensor::from_vec(data, vec![n_layer_groups, hidden_dim, n_timesteps]),
n_layers: n_layer_groups,
feature_dim: hidden_dim,
n_timesteps,
})
}
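
/// All-zero features of the given shape, usable as a silent fallback input.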
pub fn zero_features(n_layers: usize, feature_dim: usize, n_timesteps: usize) -> ExtractedFeatures {
ExtractedFeatures {
data: Tensor::zeros(&[n_layers, feature_dim, n_timesteps]),
n_layers,
feature_dim,
n_timesteps,
}
}
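
/// Resamples features along the time axis to `n_timesteps_out` steps using
/// nearest-lower-neighbor indexing.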
pub fn resample_features(features: &ExtractedFeatures, n_timesteps_out: usize) -> ExtractedFeatures {
let n_layers = features.n_layers;
let feature_dim = features.feature_dim;
    let n_in = features.n_timesteps;
    if n_in == n_timesteps_out {
        return features.clone();
    }
    // An empty input has nothing to sample; return zeros of the target size
    // rather than underflowing in the index clamp below.
    if n_in == 0 {
        return zero_features(n_layers, feature_dim, n_timesteps_out);
    }
let mut data = vec![0.0f32; n_layers * feature_dim * n_timesteps_out];
for li in 0..n_layers {
for di in 0..feature_dim {
            for to in 0..n_timesteps_out {
                // Nearest-lower-neighbor: output step `to` reads input step
                // floor(to * n_in / n_out), clamped to the last input step.
                let ti = (to as f64 * n_in as f64 / n_timesteps_out as f64).floor() as usize;
                let ti = ti.min(n_in - 1);
data[li * feature_dim * n_timesteps_out + di * n_timesteps_out + to] =
features.data.data[li * feature_dim * n_in + di * n_in + ti];
}
}
}
ExtractedFeatures {
data: Tensor::from_vec(data, vec![n_layers, feature_dim, n_timesteps_out]),
n_layers,
feature_dim,
n_timesteps: n_timesteps_out,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_compute_layer_indices() {
let indices = compute_layer_indices(&[0.5, 0.75, 1.0], 28);
assert_eq!(indices, vec![13, 20, 27]);
}
#[test]
fn test_compute_layer_indices_small() {
let indices = compute_layer_indices(&[0.5, 0.75, 1.0], 4);
assert_eq!(indices, vec![1, 2, 3]);
}
#[test]
fn test_zero_features() {
let f = zero_features(3, 1024, 100);
assert_eq!(f.data.shape, vec![3, 1024, 100]);
assert!(f.data.data.iter().all(|&v| v == 0.0));
}
#[test]
fn test_resample_features_identity() {
let f = zero_features(2, 4, 10);
let r = resample_features(&f, 10);
assert_eq!(r.n_timesteps, 10);
}
#[test]
fn test_resample_features_upsample() {
let mut f = zero_features(1, 2, 4);
        f.data.data = vec![
            1.0, 2.0, 3.0, 4.0, // dim 0 over time
            5.0, 6.0, 7.0, 8.0, // dim 1 over time
        ];
let r = resample_features(&f, 8);
assert_eq!(r.n_timesteps, 8);
assert_eq!(r.data.shape, vec![1, 2, 8]);
        assert_eq!(r.data.data[0], 1.0);
        assert_eq!(r.data.data[1], 1.0);
        assert_eq!(r.data.data[2], 2.0);
    }
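
    #[test]
    fn test_compute_layer_indices_clamps() {
        // Position 0.0 maps to the first layer; 1.0 clamps to the last.
        assert_eq!(compute_layer_indices(&[0.0, 1.0], 10), vec![0, 9]);
    }

    // A sketch of the fallback path in `fill_layer_data`, assuming a
    // `TensorCapture` built via `for_layers` reports `None` from `get_layer`
    // until a decode has populated it.
    #[test]
    fn test_fill_layer_data_fallback() {
        let capture = TensorCapture::for_layers(&[7]);
        let embs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let mut data = vec![0.0f32; 2 * 3];
        fill_layer_data(&mut data, &[7], &capture, Some(embs.as_slice()), 2, 3);
        // Layout is [layer, dim, time]: dim 0 over all timesteps, then dim 1.
        assert_eq!(data, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_resample_features_downsample() {
        let mut f = zero_features(1, 1, 4);
        f.data.data = vec![1.0, 2.0, 3.0, 4.0];
        let r = resample_features(&f, 2);
        // Output step `to` reads input step floor(to * n_in / n_out).
        assert_eq!(r.data.data, vec![1.0, 3.0]);
    }

    // End-to-end usage sketch; ignored since it needs a local GGUF model.
    // The model path here is hypothetical.
    #[test]
    #[ignore = "requires a local GGUF model file"]
    fn test_extract_llama_features_end_to_end() {
        let config = LlamaFeatureConfig {
            model_path: "models/llama.gguf".to_string(),
            ..Default::default()
        };
        let f = extract_llama_features(&config, "hello world", false).unwrap();
        assert_eq!(f.data.shape, vec![f.n_layers, f.feature_dim, f.n_timesteps]);
    }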
}