use std::num::NonZeroU32;
use std::thread;
use rig::embeddings::{Embedding, EmbeddingError, EmbeddingModel as _};
use tokio::sync::{mpsc, oneshot};
use crate::error::LoadError;
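/// Messages accepted by the worker thread: either an embedding request or an
/// explicit shutdown (sent from `Drop`).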
enum EmbeddingCommand {
Request(EmbeddingRequest),
Shutdown,
}
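/// A batch of texts to embed, plus a oneshot channel for the reply.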
struct EmbeddingRequest {
texts: Vec<String>,
response_tx: oneshot::Sender<Result<Vec<Vec<f32>>, String>>,
}
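/// Client for a local GGUF embedding model served by llama.cpp.
///
/// The model lives on a dedicated worker thread; this struct holds the
/// request channel and the thread's join handle. Dropping the client asks
/// the worker to shut down and joins it.
///
/// A minimal usage sketch (the path and model name are placeholders, and a
/// Tokio runtime is assumed):
///
/// ```ignore
/// use rig::embeddings::EmbeddingModel as _;
///
/// let client = EmbeddingClient::from_gguf("models/embedder.gguf", 0, 2048)?;
/// let model = client.embedding_model("local-embedder");
/// let embeddings = model.embed_texts(vec!["hello world".to_string()]).await?;
/// assert_eq!(embeddings[0].vec.len(), model.ndims());
/// ```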
pub struct EmbeddingClient {
request_tx: mpsc::UnboundedSender<EmbeddingCommand>,
ndims: usize,
worker_handle: Option<thread::JoinHandle<()>>,
}
impl EmbeddingClient {
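    /// Spawns the worker thread, loads the model at `model_path`, and blocks
    /// until the worker reports its embedding dimension (or an init error)
    /// over a std channel. `n_gpu_layers` controls GPU offload and `n_ctx`
    /// sets both the context size and the encode batch capacity.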
pub fn from_gguf(
model_path: impl Into<String>,
n_gpu_layers: u32,
n_ctx: u32,
) -> Result<Self, LoadError> {
let model_path = model_path.into();
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<EmbeddingCommand>();
let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<usize, LoadError>>();
let worker_handle = thread::spawn(move || {
embedding_worker(&model_path, n_gpu_layers, n_ctx, init_tx, &mut request_rx);
});
let ndims = init_rx
.recv()
.map_err(|_| LoadError::WorkerInitDisconnected)??;
Ok(Self {
request_tx,
ndims,
worker_handle: Some(worker_handle),
})
}
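    /// Returns a cloneable handle that reports the model's native dimension.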
pub fn embedding_model(&self, model: impl Into<String>) -> EmbeddingModelHandle {
EmbeddingModelHandle::make(self, model, None)
}
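    /// Like [`Self::embedding_model`], but overrides the dimension the handle
    /// reports via `ndims()`. The vectors themselves are returned as-is.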
pub fn embedding_model_with_ndims(
&self,
model: impl Into<String>,
ndims: usize,
) -> EmbeddingModelHandle {
EmbeddingModelHandle::make(self, model, Some(ndims))
}
}
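// Dropping the client requests a worker shutdown and joins the thread so the
// model is released deterministically.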
impl Drop for EmbeddingClient {
fn drop(&mut self) {
let _ = self.request_tx.send(EmbeddingCommand::Shutdown);
if let Some(worker_handle) = self.worker_handle.take() {
let _ = worker_handle.join();
}
}
}
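/// A cheap, cloneable handle implementing rig's `EmbeddingModel`; clones share
/// the same worker via the request channel.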
#[derive(Clone)]
pub struct EmbeddingModelHandle {
request_tx: mpsc::UnboundedSender<EmbeddingCommand>,
ndims: usize,
#[allow(dead_code)]
model_id: String,
}
impl rig::embeddings::EmbeddingModel for EmbeddingModelHandle {
const MAX_DOCUMENTS: usize = 256;
type Client = EmbeddingClient;
fn make(client: &EmbeddingClient, model: impl Into<String>, dims: Option<usize>) -> Self {
Self {
request_tx: client.request_tx.clone(),
ndims: dims.unwrap_or(client.ndims),
model_id: model.into(),
}
}
fn ndims(&self) -> usize {
self.ndims
}
async fn embed_texts(
&self,
texts: impl IntoIterator<Item = String> + Send,
) -> Result<Vec<Embedding>, EmbeddingError> {
let texts: Vec<String> = texts.into_iter().collect();
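        // Keep copies of the inputs: `texts` is moved into the request, but
        // rig's `Embedding` pairs each document string with its vector.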
let documents = texts.clone();
let (tx, rx) = oneshot::channel();
self.request_tx
.send(EmbeddingCommand::Request(EmbeddingRequest {
texts,
response_tx: tx,
}))
.map_err(|_| EmbeddingError::ProviderError("Embedding worker shut down".into()))?;
let raw_embeddings = rx
.await
.map_err(|_| EmbeddingError::ProviderError("Response channel closed".into()))?
.map_err(EmbeddingError::ProviderError)?;
Ok(documents
.into_iter()
.zip(raw_embeddings)
            .map(|(document, values)| Embedding {
                document,
                vec: values.into_iter().map(f64::from).collect(),
            })
.collect())
}
}
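/// Worker thread body: initialize the backend, load the model, report the
/// embedding dimension through `init_tx`, then serve requests until the
/// channel closes or a `Shutdown` command arrives. Keeping all llama.cpp
/// work on one OS thread keeps it off the async runtime.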
fn embedding_worker(
model_path: &str,
n_gpu_layers: u32,
n_ctx: u32,
init_tx: std::sync::mpsc::Sender<Result<usize, LoadError>>,
rx: &mut mpsc::UnboundedReceiver<EmbeddingCommand>,
) {
use llama_cpp_2::list_llama_ggml_backend_devices;
use llama_cpp_2::model::LlamaModel as LlamaCppModel;
use llama_cpp_2::model::params::LlamaModelParams;
let backend = match crate::shared_backend() {
Ok(b) => b,
Err(e) => {
let _ = init_tx.send(Err(LoadError::BackendInit(e)));
return;
}
};
let mut model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
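    // Prefer explicit Vulkan devices when GPU offload is available; otherwise
    // keep llama.cpp's default device selection.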
if backend.supports_gpu_offload() {
let vulkan_devices: Vec<usize> = list_llama_ggml_backend_devices()
.into_iter()
.filter(|device| device.backend.eq_ignore_ascii_case("vulkan"))
.map(|device| device.index)
.collect();
if !vulkan_devices.is_empty() {
model_params = match model_params.with_devices(&vulkan_devices) {
Ok(params) => {
log::info!("Using Vulkan backend devices: {vulkan_devices:?}");
params
}
Err(e) => {
let _ = init_tx.send(Err(LoadError::ConfigureDevices(e.to_string())));
return;
}
};
}
}
log::info!("Loading embedding model from {model_path}...");
let model = match LlamaCppModel::load_from_file(backend, model_path, &model_params) {
Ok(m) => m,
Err(e) => {
let _ = init_tx.send(Err(LoadError::ModelLoad(e.to_string())));
return;
}
};
let ndims = model.n_embd() as usize;
log::info!("Embedding model loaded (ndims={ndims}).");
let _ = init_tx.send(Ok(ndims));
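    // Serve requests until all senders are dropped or Shutdown is received.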
while let Some(command) = rx.blocking_recv() {
let req = match command {
EmbeddingCommand::Request(req) => req,
EmbeddingCommand::Shutdown => break,
};
let result = run_embedding(backend, &model, n_ctx, &req.texts);
let _ = req.response_tx.send(result);
}
}
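/// Embeds `texts` using a fresh context. Texts are greedily packed into
/// llama.cpp batches (one sequence id per text) so that short inputs share a
/// single `encode` call; results come back in input order.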
fn run_embedding(
backend: &llama_cpp_2::llama_backend::LlamaBackend,
model: &llama_cpp_2::model::LlamaModel,
n_ctx: u32,
texts: &[String],
) -> Result<Vec<Vec<f32>>, String> {
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::AddBos;
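    // Size n_batch/n_ubatch to the full context so one packed batch can be
    // encoded in a single call; n_seq_max bounds the parallel sequences.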
let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(n_ctx))
.with_n_batch(n_ctx)
.with_n_ubatch(n_ctx)
.with_n_seq_max((texts.len() as u32).max(1))
.with_embeddings(true);
let mut ctx = model
.new_context(backend, ctx_params)
.map_err(|e| format!("Embedding context creation failed: {e}"))?;
let batch_limit = ctx.n_batch().max(1) as usize;
let tokenized: Vec<Vec<_>> = texts
.iter()
.map(|text| model.str_to_token(text, AddBos::Always))
.collect::<Result<_, _>>()
.map_err(|e| format!("Tokenization failed: {e}"))?;
let mut results = Vec::with_capacity(texts.len());
let mut text_idx = 0;
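    // Greedy packing: each outer iteration fills one batch and encodes it.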
while text_idx < texts.len() {
let mut batch = LlamaBatch::new(batch_limit, texts.len().min(batch_limit) as i32);
let mut total_tokens = 0;
let mut batch_seq_ids = Vec::new();
let batch_start = text_idx;
while text_idx < texts.len() {
let tokens = &tokenized[text_idx];
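            // Flush the current batch before a text that would overflow it.
            // A single text longer than the batch limit is still attempted
            // alone and surfaces as a batch-add error below.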
if total_tokens + tokens.len() > batch_limit && !batch_seq_ids.is_empty() {
break;
}
let seq_id = (text_idx - batch_start) as i32;
for (pos, &token) in tokens.iter().enumerate() {
batch
.add(token, pos as i32, &[seq_id], true)
.map_err(|e| format!("Batch add failed: {e}"))?;
}
batch_seq_ids.push(seq_id);
total_tokens += tokens.len();
text_idx += 1;
}
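        // `encode` runs the model over every sequence in the batch at once.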
ctx.encode(&mut batch)
.map_err(|e| format!("Embedding encode failed: {e}"))?;
for &seq_id in &batch_seq_ids {
let emb = ctx
.embeddings_seq_ith(seq_id)
.map_err(|e| format!("Failed to get embedding for seq {seq_id}: {e}"))?;
results.push(emb.to_vec());
}
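        // Reset the context so sequence ids can be reused by the next batch.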
ctx.clear_kv_cache();
}
Ok(results)
}