rig_llama_cpp/embedding.rs

use std::num::NonZeroU32;
use std::thread;

use rig::embeddings::{Embedding, EmbeddingError, EmbeddingModel as _};
use tokio::sync::{mpsc, oneshot};

use crate::error::LoadError;
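
/// Commands sent from async callers to the blocking embedding worker thread.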
enum EmbeddingCommand {
    Request(EmbeddingRequest),
    Shutdown,
}
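
/// A single embedding request: the texts to embed and a channel on which the
/// worker sends back one embedding vector per text (or an error message).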
struct EmbeddingRequest {
    texts: Vec<String>,
    response_tx: oneshot::Sender<Result<Vec<Vec<f32>>, String>>,
}
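
/// Owns the dedicated worker thread that holds the llama.cpp embedding model
/// and hands out lightweight [`EmbeddingModelHandle`]s that talk to it over a
/// channel.
///
/// A minimal usage sketch (the GGUF path and model name are placeholders, and
/// a crate-root re-export plus an async runtime are assumed):
///
/// ```ignore
/// use rig::embeddings::EmbeddingModel;
///
/// let client = rig_llama_cpp::EmbeddingClient::from_gguf("models/embedding.gguf", 0, 2048)?;
/// let model = client.embedding_model("local-embedding");
/// let embeddings = model.embed_texts(vec!["hello world".to_string()]).await?;
/// assert_eq!(embeddings.len(), 1);
/// ```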
pub struct EmbeddingClient {
    request_tx: mpsc::UnboundedSender<EmbeddingCommand>,
    ndims: usize,
    worker_handle: Option<thread::JoinHandle<()>>,
}

impl EmbeddingClient {
    /// Loads a GGUF embedding model on a dedicated worker thread and blocks
    /// until the worker reports the model's embedding dimensionality.
    pub fn from_gguf(
        model_path: impl Into<String>,
        n_gpu_layers: u32,
        n_ctx: u32,
    ) -> Result<Self, LoadError> {
        let model_path = model_path.into();
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<EmbeddingCommand>();
        let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<usize, LoadError>>();

        let worker_handle = thread::spawn(move || {
            embedding_worker(&model_path, n_gpu_layers, n_ctx, init_tx, &mut request_rx);
        });

        let ndims = init_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)??;

        Ok(Self {
            request_tx,
            ndims,
            worker_handle: Some(worker_handle),
        })
    }

    /// Creates a handle that reports the dimensionality detected from the model.
    pub fn embedding_model(&self, model: impl Into<String>) -> EmbeddingModelHandle {
        EmbeddingModelHandle::make(self, model, None)
    }

    /// Creates a handle that overrides the reported dimensionality with `ndims`.
    pub fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        ndims: usize,
    ) -> EmbeddingModelHandle {
        EmbeddingModelHandle::make(self, model, Some(ndims))
    }
}
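
// Dropping the client asks the worker thread to shut down and joins it, so the
// model is unloaded before the client goes away.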
impl Drop for EmbeddingClient {
    fn drop(&mut self) {
        let _ = self.request_tx.send(EmbeddingCommand::Shutdown);

        if let Some(worker_handle) = self.worker_handle.take() {
            let _ = worker_handle.join();
        }
    }
}
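
/// Cheap, cloneable handle implementing rig's `EmbeddingModel` trait; every
/// request is forwarded over a channel to the worker owned by `EmbeddingClient`.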
#[derive(Clone)]
pub struct EmbeddingModelHandle {
    request_tx: mpsc::UnboundedSender<EmbeddingCommand>,
    ndims: usize,
    #[allow(dead_code)]
    model_id: String,
}

impl rig::embeddings::EmbeddingModel for EmbeddingModelHandle {
    const MAX_DOCUMENTS: usize = 256;
    type Client = EmbeddingClient;

    fn make(client: &EmbeddingClient, model: impl Into<String>, dims: Option<usize>) -> Self {
        Self {
            request_tx: client.request_tx.clone(),
            ndims: dims.unwrap_or(client.ndims),
            model_id: model.into(),
        }
    }

    fn ndims(&self) -> usize {
        self.ndims
    }

    async fn embed_texts(
        &self,
        texts: impl IntoIterator<Item = String> + Send,
    ) -> Result<Vec<Embedding>, EmbeddingError> {
        let texts: Vec<String> = texts.into_iter().collect();
        let documents = texts.clone();

        // Forward the request to the worker thread and await its reply.
        let (tx, rx) = oneshot::channel();
        self.request_tx
            .send(EmbeddingCommand::Request(EmbeddingRequest {
                texts,
                response_tx: tx,
            }))
            .map_err(|_| EmbeddingError::ProviderError("Embedding worker shut down".into()))?;

        let raw_embeddings = rx
            .await
            .map_err(|_| EmbeddingError::ProviderError("Response channel closed".into()))?
            .map_err(EmbeddingError::ProviderError)?;

        // Pair each input document with its vector, widening f32 -> f64 as
        // rig's `Embedding` type expects.
        Ok(documents
            .into_iter()
            .zip(raw_embeddings)
            .map(|(doc, vec)| Embedding {
                document: doc,
                vec: vec.into_iter().map(|v| v as f64).collect(),
            })
            .collect())
    }
}
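
/// Worker-thread entry point: initializes the backend, loads the GGUF model
/// (offloading to Vulkan devices when available), reports the embedding
/// dimensionality over `init_tx`, then serves requests until shutdown.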
fn embedding_worker(
    model_path: &str,
    n_gpu_layers: u32,
    n_ctx: u32,
    init_tx: std::sync::mpsc::Sender<Result<usize, LoadError>>,
    rx: &mut mpsc::UnboundedReceiver<EmbeddingCommand>,
) {
    use llama_cpp_2::list_llama_ggml_backend_devices;
    use llama_cpp_2::model::LlamaModel as LlamaCppModel;
    use llama_cpp_2::model::params::LlamaModelParams;

    let backend = match crate::shared_backend() {
        Ok(b) => b,
        Err(e) => {
            let _ = init_tx.send(Err(LoadError::BackendInit(e)));
            return;
        }
    };
    let mut model_params = LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);

    if backend.supports_gpu_offload() {
        let vulkan_devices: Vec<usize> = list_llama_ggml_backend_devices()
            .into_iter()
            .filter(|device| device.backend.eq_ignore_ascii_case("vulkan"))
            .map(|device| device.index)
            .collect();

        if !vulkan_devices.is_empty() {
            model_params = match model_params.with_devices(&vulkan_devices) {
                Ok(params) => {
                    log::info!("Using Vulkan backend devices: {vulkan_devices:?}");
                    params
                }
                Err(e) => {
                    let _ = init_tx.send(Err(LoadError::ConfigureDevices(e.to_string())));
                    return;
                }
            };
        }
    }

    log::info!("Loading embedding model from {model_path}...");

    let model = match LlamaCppModel::load_from_file(backend, model_path, &model_params) {
        Ok(m) => m,
        Err(e) => {
            let _ = init_tx.send(Err(LoadError::ModelLoad(e.to_string())));
            return;
        }
    };

    let ndims = model.n_embd() as usize;
    log::info!("Embedding model loaded (ndims={ndims}).");

    let _ = init_tx.send(Ok(ndims));

    while let Some(command) = rx.blocking_recv() {
        let req = match command {
            EmbeddingCommand::Request(req) => req,
            EmbeddingCommand::Shutdown => break,
        };

        let result = run_embedding(backend, &model, n_ctx, &req.texts);
        let _ = req.response_tx.send(result);
    }
}
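
/// Embeds a slice of texts with a fresh embedding context: texts are packed
/// greedily into batches bounded by the batch token limit, each text gets its
/// own sequence id, and one pooled embedding vector is returned per text.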
fn run_embedding(
    backend: &llama_cpp_2::llama_backend::LlamaBackend,
    model: &llama_cpp_2::model::LlamaModel,
    n_ctx: u32,
    texts: &[String],
) -> Result<Vec<Vec<f32>>, String> {
    use llama_cpp_2::context::params::LlamaContextParams;
    use llama_cpp_2::llama_batch::LlamaBatch;
    use llama_cpp_2::model::AddBos;

    // Context sized to n_ctx, with matching batch limits and embeddings enabled.
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(NonZeroU32::new(n_ctx))
        .with_n_batch(n_ctx)
        .with_n_ubatch(n_ctx)
        .with_n_seq_max((texts.len() as u32).max(1))
        .with_embeddings(true);

    let mut ctx = model
        .new_context(backend, ctx_params)
        .map_err(|e| format!("Embedding context creation failed: {e}"))?;

    let batch_limit = ctx.n_batch().max(1) as usize;

    let tokenized: Vec<Vec<_>> = texts
        .iter()
        .map(|text| model.str_to_token(text, AddBos::Always))
        .collect::<Result<_, _>>()
        .map_err(|e| format!("Tokenization failed: {e}"))?;

    let mut results = Vec::with_capacity(texts.len());
    let mut text_idx = 0;

    while text_idx < texts.len() {
        let mut batch = LlamaBatch::new(batch_limit, texts.len().min(batch_limit) as i32);
        let mut total_tokens = 0;
        let mut batch_seq_ids = Vec::new();
        let batch_start = text_idx;

        // Pack as many texts as fit within the batch token limit, one sequence per text.
        while text_idx < texts.len() {
            let tokens = &tokenized[text_idx];
            if total_tokens + tokens.len() > batch_limit && !batch_seq_ids.is_empty() {
                break;
            }
            let seq_id = (text_idx - batch_start) as i32;
            for (pos, &token) in tokens.iter().enumerate() {
                batch
                    .add(token, pos as i32, &[seq_id], true)
                    .map_err(|e| format!("Batch add failed: {e}"))?;
            }
            batch_seq_ids.push(seq_id);
            total_tokens += tokens.len();
            text_idx += 1;
        }

        ctx.encode(&mut batch)
            .map_err(|e| format!("Embedding encode failed: {e}"))?;

        // Pull the pooled embedding for each sequence in this batch.
        for &seq_id in &batch_seq_ids {
            let emb = ctx
                .embeddings_seq_ith(seq_id)
                .map_err(|e| format!("Failed to get embedding for seq {seq_id}: {e}"))?;
            results.push(emb.to_vec());
        }

        ctx.clear_kv_cache();
    }

    Ok(results)
}