// rig_llama_cpp/embedding.rs

use std::num::NonZeroU32;
use std::thread;

use rig::embeddings::{Embedding, EmbeddingError, EmbeddingModel as _};
use tokio::sync::{mpsc, oneshot};

use crate::error::LoadError;

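/// Commands sent from client handles to the dedicated embedding worker thread.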
enum EmbeddingCommand {
    Request(EmbeddingRequest),
    Shutdown,
}

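/// A batch of texts to embed, plus a oneshot channel for returning the raw vectors.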
struct EmbeddingRequest {
    texts: Vec<String>,
    response_tx: oneshot::Sender<Result<Vec<Vec<f32>>, String>>,
}

/// The llama.cpp embedding client.
///
/// `EmbeddingClient` loads a GGUF embedding model on a dedicated worker thread
/// and exposes it through Rig's [`rig::embeddings::EmbeddingModel`] trait.
/// Create one with [`EmbeddingClient::from_gguf`].
///
/// ```rust,no_run
/// use rig::embeddings::EmbeddingModel;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let client = rig_llama_cpp::EmbeddingClient::from_gguf(
///     "path/to/embedding-model.gguf",
///     99,   // n_gpu_layers
///     8192, // n_ctx
/// )?;
/// let model = client.embedding_model("local");
/// let embedding = model.embed_text("Hello, world!").await?;
/// println!("dims: {}", embedding.vec.len());
/// # Ok(())
/// # }
/// ```
pub struct EmbeddingClient {
    request_tx: mpsc::UnboundedSender<EmbeddingCommand>,
    ndims: usize,
    worker_handle: Option<thread::JoinHandle<()>>,
}

impl EmbeddingClient {
    /// Load a GGUF embedding model and start the embedding worker thread.
    ///
    /// # Arguments
    ///
    /// * `model_path` — Path to a `.gguf` embedding model file.
    /// * `n_gpu_layers` — Number of layers to offload to the GPU (`u32::MAX` for all).
    /// * `n_ctx` — Context window size in tokens.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise or the
    /// model cannot be loaded.
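    ///
    /// # Example
    ///
    /// A minimal load that offloads every layer (`u32::MAX`) and uses an
    /// 8192-token context; the model path is illustrative.
    ///
    /// ```rust,no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = rig_llama_cpp::EmbeddingClient::from_gguf(
    ///     "models/embedding.gguf",
    ///     u32::MAX,
    ///     8192,
    /// )?;
    /// # Ok(())
    /// # }
    /// ```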
    pub fn from_gguf(
        model_path: impl Into<String>,
        n_gpu_layers: u32,
        n_ctx: u32,
    ) -> Result<Self, LoadError> {
        let model_path = model_path.into();
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<EmbeddingCommand>();
        let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<usize, LoadError>>();

        let worker_handle = thread::spawn(move || {
            embedding_worker(&model_path, n_gpu_layers, n_ctx, init_tx, &mut request_rx);
        });

        let ndims = init_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)??;

        Ok(Self {
            request_tx,
            ndims,
            worker_handle: Some(worker_handle),
        })
    }

    /// Create an embedding model handle from this client.
    pub fn embedding_model(&self, model: impl Into<String>) -> EmbeddingModelHandle {
        EmbeddingModelHandle::make(self, model, None)
    }

    /// Create an embedding model handle with explicit dimensions.
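    ///
    /// Only the dimension count reported by the handle's `ndims()` is
    /// overridden; the vectors produced by the model are unchanged. A short
    /// sketch (the `768` here is an arbitrary illustrative value):
    ///
    /// ```rust,no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # let client = rig_llama_cpp::EmbeddingClient::from_gguf("path/to/embedding-model.gguf", 99, 8192)?;
    /// let model = client.embedding_model_with_ndims("local", 768);
    /// # Ok(())
    /// # }
    /// ```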
    pub fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        ndims: usize,
    ) -> EmbeddingModelHandle {
        EmbeddingModelHandle::make(self, model, Some(ndims))
    }
}

impl Drop for EmbeddingClient {
    fn drop(&mut self) {
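        // Best-effort shutdown: the worker may already have exited, so both
        // the send and the join below ignore errors.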
        let _ = self.request_tx.send(EmbeddingCommand::Shutdown);

        if let Some(worker_handle) = self.worker_handle.take() {
            let _ = worker_handle.join();
        }
    }
}

/// A handle to a loaded embedding model that implements Rig's [`rig::embeddings::EmbeddingModel`] trait.
///
/// Obtained via [`EmbeddingClient::embedding_model`].
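///
/// Embeddings are returned in the same order as the input texts. A small
/// sketch of batched embedding through the Rig trait (the path and strings
/// are illustrative):
///
/// ```rust,no_run
/// use rig::embeddings::EmbeddingModel;
///
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let client = rig_llama_cpp::EmbeddingClient::from_gguf("path/to/embedding-model.gguf", 99, 8192)?;
/// let model = client.embedding_model("local");
/// let embeddings = model
///     .embed_texts(vec!["first document".to_string(), "second document".to_string()])
///     .await?;
/// assert_eq!(embeddings.len(), 2);
/// # Ok(())
/// # }
/// ```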
#[derive(Clone)]
pub struct EmbeddingModelHandle {
    request_tx: mpsc::UnboundedSender<EmbeddingCommand>,
    ndims: usize,
    #[allow(dead_code)]
    model_id: String,
}

impl rig::embeddings::EmbeddingModel for EmbeddingModelHandle {
    const MAX_DOCUMENTS: usize = 256;
    type Client = EmbeddingClient;

    fn make(client: &EmbeddingClient, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self {
            request_tx: client.request_tx.clone(),
            ndims: dims.unwrap_or(client.ndims),
            model_id: model.into(),
        }
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        texts: impl IntoIterator<Item = String> + Send,
    ) -> Result<Vec<Embedding>, EmbeddingError> {
        let texts: Vec<String> = texts.into_iter().collect();
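        // Keep the original strings so each returned `Embedding` can carry
        // its source document alongside the vector.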
        let documents = texts.clone();

        let (tx, rx) = oneshot::channel();
        self.request_tx
            .send(EmbeddingCommand::Request(EmbeddingRequest {
                texts,
                response_tx: tx,
            }))
            .map_err(|_| EmbeddingError::ProviderError("Embedding worker shut down".into()))?;

        let raw_embeddings = rx
            .await
            .map_err(|_| EmbeddingError::ProviderError("Response channel closed".into()))?
            .map_err(EmbeddingError::ProviderError)?;

        Ok(documents
            .into_iter()
            .zip(raw_embeddings)
            .map(|(doc, vec)| Embedding {
                document: doc,
                vec: vec.into_iter().map(|v| v as f64).collect(),
            })
            .collect())
    }
}

// === Embedding worker (runs on dedicated thread) ===
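//
// The worker holds the loaded model for its entire lifetime. It reports the
// embedding dimension (or a `LoadError`) exactly once over `init_tx`, then
// blocks on the command channel until it receives `Shutdown` or every sender
// has been dropped.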

fn embedding_worker(
    model_path: &str,
    n_gpu_layers: u32,
    n_ctx: u32,
    init_tx: std::sync::mpsc::Sender<Result<usize, LoadError>>,
    rx: &mut mpsc::UnboundedReceiver<EmbeddingCommand>,
) {
    use llama_cpp_2::list_llama_ggml_backend_devices;
    use llama_cpp_2::model::LlamaModel as LlamaCppModel;
    use llama_cpp_2::model::params::LlamaModelParams;

    let backend = match crate::shared_backend() {
        Ok(b) => b,
        Err(e) => {
            let _ = init_tx.send(Err(LoadError::BackendInit(e)));
            return;
        }
    };
    let mut model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);

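    // When GPU offload is available, restrict the model to Vulkan devices if
    // the backend reports any; otherwise keep the default device selection.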
    if backend.supports_gpu_offload() {
        let vulkan_devices: Vec<usize> = list_llama_ggml_backend_devices()
            .into_iter()
            .filter(|device| device.backend.eq_ignore_ascii_case("vulkan"))
            .map(|device| device.index)
            .collect();

        if !vulkan_devices.is_empty() {
            model_params = match model_params.with_devices(&vulkan_devices) {
                Ok(params) => {
                    log::info!("Using Vulkan backend devices: {vulkan_devices:?}");
                    params
                }
                Err(e) => {
                    let _ = init_tx.send(Err(LoadError::ConfigureDevices(e.to_string())));
                    return;
                }
            };
        }
    }

    log::info!("Loading embedding model from {model_path}...");

    let model = match LlamaCppModel::load_from_file(backend, model_path, &model_params) {
        Ok(m) => m,
        Err(e) => {
            let _ = init_tx.send(Err(LoadError::ModelLoad(e.to_string())));
            return;
        }
    };

    let ndims = model.n_embd() as usize;
    log::info!("Embedding model loaded (ndims={ndims}).");

    let _ = init_tx.send(Ok(ndims));

    while let Some(command) = rx.blocking_recv() {
        let req = match command {
            EmbeddingCommand::Request(req) => req,
            EmbeddingCommand::Shutdown => break,
        };

        let result = run_embedding(backend, &model, n_ctx, &req.texts);
        let _ = req.response_tx.send(result);
    }
}

fn run_embedding(
    backend: &llama_cpp_2::llama_backend::LlamaBackend,
    model: &llama_cpp_2::model::LlamaModel,
    n_ctx: u32,
    texts: &[String],
) -> Result<Vec<Vec<f32>>, String> {
    use llama_cpp_2::context::params::LlamaContextParams;
    use llama_cpp_2::llama_batch::LlamaBatch;
    use llama_cpp_2::model::AddBos;

    // Encoder-only embedding models require `n_ubatch >= n_tokens` for
    // every single sequence in the batch (llama.cpp asserts on this).
    // The default `n_ubatch` is 512, which is smaller than a typical
    // chunk's token count, so we widen both the batch and micro-batch to
    // match `n_ctx`. This lets any chunk that fits the context window also
    // fit in one micro-batch step.
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(n_ctx))
        .with_n_batch(n_ctx)
        .with_n_ubatch(n_ctx)
        .with_n_seq_max((texts.len() as u32).max(1))
        .with_embeddings(true);

    let mut ctx = model
        .new_context(backend, ctx_params)
        .map_err(|e| format!("Embedding context creation failed: {e}"))?;

    let batch_limit = ctx.n_batch().max(1) as usize;

    // Tokenize all texts
    let tokenized: Vec<Vec<_>> = texts
        .iter()
        .map(|text| model.str_to_token(text, AddBos::Always))
        .collect::<Result<_, _>>()
        .map_err(|e| format!("Tokenization failed: {e}"))?;

    let mut results = Vec::with_capacity(texts.len());
    let mut text_idx = 0;

    while text_idx < texts.len() {
        let mut batch = LlamaBatch::new(batch_limit, texts.len().min(batch_limit) as i32);
        let mut total_tokens = 0;
        let mut batch_seq_ids = Vec::new();
        let batch_start = text_idx;

        // Pack as many texts as fit in one batch
        while text_idx < texts.len() {
            let tokens = &tokenized[text_idx];
            if total_tokens + tokens.len() > batch_limit && !batch_seq_ids.is_empty() {
                break;
            }
            let seq_id = (text_idx - batch_start) as i32;
            for (pos, &token) in tokens.iter().enumerate() {
                batch
                    .add(token, pos as i32, &[seq_id], true)
                    .map_err(|e| format!("Batch add failed: {e}"))?;
            }
            batch_seq_ids.push(seq_id);
            total_tokens += tokens.len();
            text_idx += 1;
        }

        ctx.encode(&mut batch)
            .map_err(|e| format!("Embedding encode failed: {e}"))?;

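        // Collect one embedding vector per sequence packed into this batch.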
        for &seq_id in &batch_seq_ids {
            let emb = ctx
                .embeddings_seq_ith(seq_id)
                .map_err(|e| format!("Failed to get embedding for seq {seq_id}: {e}"))?;
            results.push(emb.to_vec());
        }

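        // Clear the KV cache so the next batch can reuse sequence IDs from zero.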
        ctx.clear_kv_cache();
    }

    Ok(results)
}