1use serde_json::{Value, json};
2use std::collections::{HashMap, hash_map::Entry};
3use std::sync::{Mutex, OnceLock};
4
5use crate::config::{Context, EmbeddingConfig};
6use crate::db;
7use crate::models::Symbol;
8use crate::secrets;
9use gobby_core::ai::{daemon, effective_route};
10use gobby_core::ai_context::{
11 AiConfigSource, AiContext, NoPrimaryAiConfigSource, PostgresAiConfigSource,
12};
13use gobby_core::config::AiCapability;
14
15use super::types::VectorLifecycleError;
16
/// Fixed input text exposed through [`dimension_probe_text`] and embedded by
/// [`probe_embedding_dim`] to measure the length of the returned vector.
const DIMENSION_PROBE_TEXT: &str = "dimension_probe";
/// Lazily initialised, mutex-guarded cache of blocking HTTP clients.
///
/// [`embedding_client`] keys the map by `EmbeddingConfig::timeout_seconds`.
static EMBEDDING_CLIENTS: OnceLock<Mutex<HashMap<u64, reqwest::blocking::Client>>> =
    OnceLock::new();
20
/// Returns the fixed text that is embedded when probing for the vector dimension.
pub(super) fn dimension_probe_text() -> &'static str {
    DIMENSION_PROBE_TEXT
}
24
/// Where embeddings are obtained from.
#[derive(Debug, Clone)]
pub enum EmbeddingSource {
    /// Embeddings are requested through `daemon::embed_via_daemon` using this AI context.
    /// The context is boxed; presumably to keep the enum small — confirm against `AiContext`.
    Daemon(Box<AiContext>),
    /// Embeddings are requested over HTTP from the endpoint described by this config.
    Direct(EmbeddingConfig),
}
30
31impl From<EmbeddingConfig> for EmbeddingSource {
32 fn from(config: EmbeddingConfig) -> Self {
33 Self::Direct(config)
34 }
35}
36
37impl From<AiContext> for EmbeddingSource {
38 fn from(context: AiContext) -> Self {
39 Self::Daemon(Box::new(context))
40 }
41}
42
/// An embedding source paired with the HTTP client it needs, if any.
#[derive(Debug, Clone)]
pub struct EmbeddingBackend {
    /// Where embeddings are requested from.
    source: EmbeddingSource,
    /// HTTP client for a direct source; `None` for a daemon source.
    direct_client: Option<reqwest::blocking::Client>,
}
48
49impl EmbeddingBackend {
50 pub fn new(source: EmbeddingSource) -> Result<Self, VectorLifecycleError> {
51 let direct_client = match &source {
52 EmbeddingSource::Direct(config) => {
53 if config.api_base.trim().is_empty() {
54 return Err(VectorLifecycleError::MissingEmbeddingConfig);
55 }
56 Some(embedding_client(config)?)
57 }
58 EmbeddingSource::Daemon(_) => None,
59 };
60 Ok(Self {
61 source,
62 direct_client,
63 })
64 }
65
66 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>, VectorLifecycleError> {
67 let texts = vec![text.to_string()];
68 let mut embeddings = self.embed_text_batch(&texts)?;
69 embeddings.pop().ok_or_else(|| {
70 VectorLifecycleError::EmbeddingResponse("embedding response was empty".to_string())
71 })
72 }
73
74 pub fn embed_query(&self, text: &str) -> Result<Vec<f32>, VectorLifecycleError> {
75 match &self.source {
76 EmbeddingSource::Direct(config) => {
77 let prefix = config.query_prefix.as_deref().unwrap_or("").trim();
78 let input = if prefix.is_empty() {
79 text.to_string()
80 } else {
81 format!("{prefix} {text}")
82 };
83 let client = self.direct_client.as_ref().ok_or_else(|| {
84 VectorLifecycleError::EmbeddingResponse(
85 "direct embedding client is not initialized".to_string(),
86 )
87 })?;
88 embed_text(client, config, &input)
89 }
90 EmbeddingSource::Daemon(context) => {
91 let texts = vec![text.to_string()];
92 let result = daemon::embed_via_daemon(context, &texts, true)
93 .map_err(|error| VectorLifecycleError::EmbeddingResponse(error.to_string()))?;
94 result.embeddings.into_iter().next().ok_or_else(|| {
95 VectorLifecycleError::EmbeddingResponse(
96 "daemon embedding response was empty".to_string(),
97 )
98 })
99 }
100 }
101 }
102
103 pub fn embed_text_batch(
104 &self,
105 texts: &[String],
106 ) -> Result<Vec<Vec<f32>>, VectorLifecycleError> {
107 match &self.source {
108 EmbeddingSource::Direct(config) => {
109 let client = self.direct_client.as_ref().ok_or_else(|| {
110 VectorLifecycleError::EmbeddingResponse(
111 "direct embedding client is not initialized".to_string(),
112 )
113 })?;
114 embed_text_batch(client, config, texts)
115 }
116 EmbeddingSource::Daemon(context) => daemon::embed_via_daemon(context, texts, false)
117 .map(|result| result.embeddings)
118 .map_err(|error| VectorLifecycleError::EmbeddingResponse(error.to_string())),
119 }
120 }
121}
122
123pub fn embedding_source_from_context(ctx: &Context) -> Option<EmbeddingSource> {
124 let resolved = resolve_embedding_ai_context(ctx);
125 embedding_source_from_resolved_ai_context(resolved.context, resolved.direct_config)
126}
127
128fn embedding_source_from_resolved_ai_context(
129 ai_context: AiContext,
130 direct_config: Option<EmbeddingConfig>,
131) -> Option<EmbeddingSource> {
132 match effective_route(&ai_context, AiCapability::Embed) {
133 gobby_core::config::AiRouting::Off => None,
134 gobby_core::config::AiRouting::Daemon => {
135 Some(EmbeddingSource::Daemon(Box::new(ai_context)))
136 }
137 gobby_core::config::AiRouting::Direct => direct_config.map(EmbeddingSource::Direct),
138 gobby_core::config::AiRouting::Auto => None,
139 }
140}
141
/// Result of [`resolve_embedding_ai_context`]: the resolved AI context together with the
/// direct embedding config derived from its embed binding, when one could be resolved.
struct ResolvedEmbeddingAiContext {
    /// AI context resolved for the project.
    context: AiContext,
    /// Direct embedding endpoint config, if any.
    direct_config: Option<EmbeddingConfig>,
}
146
147fn resolve_embedding_ai_context(ctx: &Context) -> ResolvedEmbeddingAiContext {
148 let standalone = crate::config::read_standalone_config_optional();
149 if let Ok(mut conn) = db::connect_readonly(&ctx.database_url) {
150 let primary = PostgresAiConfigSource::new(&mut conn, secrets::resolve_config_value);
151 let mut source = AiConfigSource::with_primary(primary, standalone);
152 let context = AiContext::resolve(Some(ctx.project_id.clone()), &mut source);
153 let direct_config = gobby_core::config::resolve_embedding_config_from_binding(
154 &mut source,
155 context.binding(AiCapability::Embed),
156 );
157 return ResolvedEmbeddingAiContext {
158 context,
159 direct_config,
160 };
161 }
162
163 let mut source = AiConfigSource::with_primary(NoPrimaryAiConfigSource, standalone);
164 let mut context = AiContext::resolve(Some(ctx.project_id.clone()), &mut source);
165 if let Some(embedding) = &ctx.embedding {
166 context.bindings.embed.api_base = Some(embedding.api_base.clone());
167 context.bindings.embed.model = Some(embedding.model.clone());
168 context.bindings.embed.api_key = embedding.api_key.clone();
169 }
170 let direct_config = gobby_core::config::resolve_embedding_config_from_binding(
171 &mut source,
172 context.binding(AiCapability::Embed),
173 )
174 .or_else(|| ctx.embedding.clone());
175 ResolvedEmbeddingAiContext {
176 context,
177 direct_config,
178 }
179}
180
181pub fn embedding_client(
182 config: &EmbeddingConfig,
183) -> Result<reqwest::blocking::Client, VectorLifecycleError> {
184 let mut clients = match EMBEDDING_CLIENTS
185 .get_or_init(|| Mutex::new(HashMap::new()))
186 .lock()
187 {
188 Ok(guard) => guard,
189 Err(poisoned) => poisoned.into_inner(),
190 };
191 match clients.entry(config.timeout_seconds) {
194 Entry::Occupied(entry) => Ok(entry.get().clone()),
195 Entry::Vacant(entry) => {
196 let client = reqwest::blocking::Client::builder()
197 .timeout(std::time::Duration::from_secs(config.timeout_seconds))
198 .build()
199 .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))?;
200 Ok(entry.insert(client).clone())
201 }
202 }
203}
204
205pub fn embed_text(
206 client: &reqwest::blocking::Client,
207 config: &EmbeddingConfig,
208 text: &str,
209) -> Result<Vec<f32>, VectorLifecycleError> {
210 let body = json!({
211 "model": config.model,
212 "input": text,
213 });
214
215 let data = send_embedding_request(client, config, body)?;
216 data.get("data")
217 .and_then(Value::as_array)
218 .and_then(|values| values.first())
219 .ok_or_else(|| {
220 VectorLifecycleError::EmbeddingResponse("missing data[0] object".to_string())
221 })
222 .and_then(parse_embedding)
223}
224
225pub fn probe_embedding_dim(config: &EmbeddingConfig) -> Result<usize, VectorLifecycleError> {
226 let client = embedding_client(config)?;
227 Ok(embed_text(&client, config, dimension_probe_text())?.len())
228}
229
230pub fn embed_text_batch(
231 client: &reqwest::blocking::Client,
232 config: &EmbeddingConfig,
233 texts: &[String],
234) -> Result<Vec<Vec<f32>>, VectorLifecycleError> {
235 if texts.is_empty() {
236 return Ok(Vec::new());
237 }
238 let body = json!({
239 "model": config.model,
240 "input": texts,
241 });
242
243 let data = send_embedding_request(client, config, body)?;
244 let data = data
245 .get("data")
246 .and_then(Value::as_array)
247 .ok_or_else(|| VectorLifecycleError::EmbeddingResponse("missing data array".to_string()))?;
248 if data.len() != texts.len() {
249 return Err(VectorLifecycleError::EmbeddingResponse(format!(
250 "embedding response returned {} vector(s) for {} input(s)",
251 data.len(),
252 texts.len()
253 )));
254 }
255
256 let mut ordered = vec![None; texts.len()];
257 for (position, item) in data.iter().enumerate() {
258 let index = item
259 .get("index")
260 .and_then(Value::as_u64)
261 .and_then(|index| usize::try_from(index).ok())
262 .unwrap_or(position);
263 if index >= texts.len() || ordered[index].is_some() {
264 return Err(VectorLifecycleError::EmbeddingResponse(
265 "embedding response contained an invalid index".to_string(),
266 ));
267 }
268 ordered[index] = Some(parse_embedding(item)?);
269 }
270
271 ordered
272 .into_iter()
273 .map(|embedding| {
274 embedding.ok_or_else(|| {
275 VectorLifecycleError::EmbeddingResponse(
276 "embedding response omitted an input index".to_string(),
277 )
278 })
279 })
280 .collect()
281}
282
283fn send_embedding_request(
284 client: &reqwest::blocking::Client,
285 config: &EmbeddingConfig,
286 body: Value,
287) -> Result<Value, VectorLifecycleError> {
288 let url = format!("{}/embeddings", config.api_base.trim_end_matches('/'));
289 let mut req = client.post(&url).json(&body);
290
291 if let Some(key) = &config.api_key {
292 req = req.header("Authorization", format!("Bearer {key}"));
293 }
294
295 let resp = req
296 .send()
297 .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))?;
298 if !resp.status().is_success() {
299 let status = resp.status().as_u16();
300 let body = resp.text().unwrap_or_default();
301 return Err(VectorLifecycleError::EmbeddingHttp { status, body });
302 }
303
304 resp.json()
305 .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))
306}
307
308fn parse_embedding(value: &Value) -> Result<Vec<f32>, VectorLifecycleError> {
309 let embedding = value
310 .get("embedding")
311 .and_then(Value::as_array)
312 .ok_or_else(|| {
313 VectorLifecycleError::EmbeddingResponse("missing embedding array".to_string())
314 })?
315 .iter()
316 .map(|value| {
317 let f = value.as_f64().ok_or_else(|| {
318 VectorLifecycleError::EmbeddingResponse(
319 "embedding array contains a non-number".to_string(),
320 )
321 })?;
322 let converted = f as f32;
323 if !f.is_finite() || converted.is_infinite() {
324 return Err(VectorLifecycleError::EmbeddingResponse(
325 "embedding contains value outside f32 range".to_string(),
326 ));
327 }
328 Ok(converted)
329 })
330 .collect::<Result<Vec<_>, _>>()?;
331
332 if embedding.is_empty() {
333 return Err(VectorLifecycleError::EmbeddingResponse(
334 "embedding vector was empty".to_string(),
335 ));
336 }
337 Ok(embedding)
338}
339
340pub fn embed_query(config: &EmbeddingConfig, text: &str) -> Option<Vec<f32>> {
341 let prefix = config.query_prefix.as_deref().unwrap_or("").trim();
342 let input = if prefix.is_empty() {
343 text.to_string()
344 } else {
345 format!("{prefix} {text}")
346 };
347 let client = match embedding_client(config) {
348 Ok(client) => client,
349 Err(error) => {
350 eprintln!("gcode: query embedding failed: {error}");
351 return None;
352 }
353 };
354 match embed_text(&client, config, &input) {
355 Ok(embedding) => Some(embedding),
356 Err(error) => {
357 eprintln!("gcode: query embedding failed: {error}");
358 None
359 }
360 }
361}
362
363pub fn embed_query_with_source(source: &EmbeddingSource, text: &str) -> Option<Vec<f32>> {
364 let backend = match EmbeddingBackend::new(source.clone()) {
365 Ok(backend) => backend,
366 Err(error) => {
367 eprintln!("gcode: query embedding failed: {error}");
368 return None;
369 }
370 };
371 match backend.embed_query(text) {
372 Ok(embedding) => Some(embedding),
373 Err(error) => {
374 eprintln!("gcode: query embedding failed: {error}");
375 None
376 }
377 }
378}
379
380pub fn vector_text_for_symbol(symbol: &Symbol) -> String {
381 let mut lines = vec![
382 format!("name: {}", symbol.name),
383 format!("qualified_name: {}", symbol.qualified_name),
384 format!("kind: {}", symbol.kind),
385 format!("language: {}", symbol.language),
386 format!("file_path: {}", symbol.file_path),
387 format!("range: {}-{}", symbol.line_start, symbol.line_end),
388 ];
389 if let Some(signature) = symbol
390 .signature
391 .as_deref()
392 .filter(|value| !value.trim().is_empty())
393 {
394 lines.push(format!("signature: {signature}"));
395 }
396 if let Some(docstring) = symbol
397 .docstring
398 .as_deref()
399 .filter(|value| !value.trim().is_empty())
400 {
401 lines.push(format!("docstring: {docstring}"));
402 }
403 if let Some(summary) = symbol
404 .summary
405 .as_deref()
406 .filter(|value| !value.trim().is_empty())
407 {
408 lines.push(format!("summary: {summary}"));
409 }
410 lines.join("\n")
411}
412
/// Tests for embedding config resolution and source selection.
#[cfg(test)]
mod tests {
    use super::{EmbeddingSource, embedding_source_from_resolved_ai_context};
    use crate::config::EmbeddingConfig;
    use gobby_core::ai_context::AiContext;
    use gobby_core::config::{ConfigSource, ai_keys, embedding_keys};
    use std::collections::HashMap;

    /// In-memory [`ConfigSource`] backed by a fixed key/value map.
    #[derive(Default)]
    struct TestSource {
        values: HashMap<&'static str, &'static str>,
    }

    impl TestSource {
        /// Builds a source holding exactly the given key/value pairs.
        fn with_values(values: impl IntoIterator<Item = (&'static str, &'static str)>) -> Self {
            Self {
                values: values.into_iter().collect(),
            }
        }
    }

    impl ConfigSource for TestSource {
        /// Looks `key` up in the fixed map.
        fn config_value(&mut self, key: &str) -> Option<String> {
            self.values.get(key).map(|value| (*value).to_string())
        }

        /// Resolves the one known secret reference; every other value is returned unchanged.
        fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
            match value {
                "$secret:EMBEDDING_KEY" => Ok("resolved-embedding-key".to_string()),
                value => Ok(value.to_string()),
            }
        }
    }

    /// `auto` routing with an endpoint resolves to a direct config, while `daemon` and `off`
    /// routing resolve to none even when an endpoint is configured.
    #[test]
    fn resolves_via_shared_routing() {
        let mut auto_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "auto"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://embeddings.local:11434/v1",
            ),
        ]);
        let config = crate::config::resolve_embedding_config_from_source(None, &mut auto_source)
            .expect("auto route with endpoint should use direct embeddings");
        assert_eq!(config.api_base, "http://embeddings.local:11434/v1");

        // Daemon routing: the configured endpoint must not produce a direct config.
        let mut daemon_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "daemon"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://daemon-should-not-be-used:11434/v1",
            ),
        ]);
        assert!(
            crate::config::resolve_embedding_config_from_source(None, &mut daemon_source).is_none()
        );

        // Off routing: the configured endpoint must not produce a direct config.
        let mut off_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "off"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://off-should-not-be-used:11434/v1",
            ),
        ]);
        assert!(
            crate::config::resolve_embedding_config_from_source(None, &mut off_source).is_none()
        );
    }

    /// With `direct` routing, every field of the resolved config comes from the shared
    /// binding keys, and the API key secret reference is resolved.
    #[test]
    fn reads_endpoint_from_shared_binding() {
        let mut source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "direct"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://shared-binding.local:11434/v1",
            ),
            (ai_keys::EMBEDDINGS_MODEL, "shared-embed-model"),
            (ai_keys::EMBEDDINGS_API_KEY, "$secret:EMBEDDING_KEY"),
            (embedding_keys::AI_QUERY_PREFIX, "query:"),
            (embedding_keys::AI_TIMEOUT_SECONDS, "12"),
        ]);

        let config = crate::config::resolve_embedding_config_from_source(None, &mut source)
            .expect("embedding config from shared binding");

        assert_eq!(config.api_base, "http://shared-binding.local:11434/v1");
        assert_eq!(config.model, "shared-embed-model");
        assert_eq!(config.api_key.as_deref(), Some("resolved-embedding-key"));
        assert_eq!(config.query_prefix.as_deref(), Some("query:"));
        assert_eq!(config.timeout_seconds, 12);
    }

    /// With `direct` routing, the selected source is `Direct` and carries the supplied config
    /// unchanged.
    #[test]
    fn direct_source_uses_resolved_embedding_config() {
        let mut source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "direct"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (ai_keys::EMBEDDINGS_API_BASE, "http://resolved.local/v1"),
            (ai_keys::EMBEDDINGS_MODEL, "resolved-embed-model"),
        ]);
        let context = AiContext::resolve(None, &mut source);
        let direct_config = EmbeddingConfig {
            api_base: "http://resolved.local/v1".to_string(),
            model: "resolved-embed-model".to_string(),
            api_key: None,
            query_prefix: None,
            timeout_seconds: 10,
        };

        let source =
            embedding_source_from_resolved_ai_context(context, Some(direct_config.clone()));

        match source {
            Some(EmbeddingSource::Direct(config)) => assert_eq!(config, direct_config),
            other => panic!("expected direct embedding source, got {other:?}"),
        }
    }
}