Skip to main content

gobby_code/vector/code_symbols/
embedding.rs

1use serde_json::{Value, json};
2use std::collections::{HashMap, hash_map::Entry};
3use std::sync::{Mutex, OnceLock};
4
5use crate::config::{Context, EmbeddingConfig};
6use crate::db;
7use crate::models::Symbol;
8use crate::secrets;
9use gobby_core::ai::{daemon, effective_route};
10use gobby_core::ai_context::{
11    AiConfigSource, AiContext, NoPrimaryAiConfigSource, PostgresAiConfigSource,
12};
13use gobby_core::config::AiCapability;
14
15use super::types::VectorLifecycleError;
16
17const DIMENSION_PROBE_TEXT: &str = "dimension_probe";
18static EMBEDDING_CLIENTS: OnceLock<Mutex<HashMap<u64, reqwest::blocking::Client>>> =
19    OnceLock::new();
20
21pub(super) fn dimension_probe_text() -> &'static str {
22    DIMENSION_PROBE_TEXT
23}
24
25#[derive(Debug, Clone)]
26pub enum EmbeddingSource {
27    Daemon(Box<AiContext>),
28    Direct(EmbeddingConfig),
29}
30
31impl From<EmbeddingConfig> for EmbeddingSource {
32    fn from(config: EmbeddingConfig) -> Self {
33        Self::Direct(config)
34    }
35}
36
37impl From<AiContext> for EmbeddingSource {
38    fn from(context: AiContext) -> Self {
39        Self::Daemon(Box::new(context))
40    }
41}
42
43#[derive(Debug, Clone)]
44pub struct EmbeddingBackend {
45    source: EmbeddingSource,
46    direct_client: Option<reqwest::blocking::Client>,
47}
48
49impl EmbeddingBackend {
50    pub fn new(source: EmbeddingSource) -> Result<Self, VectorLifecycleError> {
51        let direct_client = match &source {
52            EmbeddingSource::Direct(config) => {
53                if config.api_base.trim().is_empty() {
54                    return Err(VectorLifecycleError::MissingEmbeddingConfig);
55                }
56                Some(embedding_client(config)?)
57            }
58            EmbeddingSource::Daemon(_) => None,
59        };
60        Ok(Self {
61            source,
62            direct_client,
63        })
64    }
65
66    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>, VectorLifecycleError> {
67        let texts = vec![text.to_string()];
68        let mut embeddings = self.embed_text_batch(&texts)?;
69        embeddings.pop().ok_or_else(|| {
70            VectorLifecycleError::EmbeddingResponse("embedding response was empty".to_string())
71        })
72    }
73
74    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>, VectorLifecycleError> {
75        match &self.source {
76            EmbeddingSource::Direct(config) => {
77                let prefix = config.query_prefix.as_deref().unwrap_or("").trim();
78                let input = if prefix.is_empty() {
79                    text.to_string()
80                } else {
81                    format!("{prefix} {text}")
82                };
83                let client = self.direct_client.as_ref().ok_or_else(|| {
84                    VectorLifecycleError::EmbeddingResponse(
85                        "direct embedding client is not initialized".to_string(),
86                    )
87                })?;
88                embed_text(client, config, &input)
89            }
90            EmbeddingSource::Daemon(context) => {
91                let texts = vec![text.to_string()];
92                let result = daemon::embed_via_daemon(context, &texts, true)
93                    .map_err(|error| VectorLifecycleError::EmbeddingResponse(error.to_string()))?;
94                result.embeddings.into_iter().next().ok_or_else(|| {
95                    VectorLifecycleError::EmbeddingResponse(
96                        "daemon embedding response was empty".to_string(),
97                    )
98                })
99            }
100        }
101    }
102
103    pub fn embed_text_batch(
104        &self,
105        texts: &[String],
106    ) -> Result<Vec<Vec<f32>>, VectorLifecycleError> {
107        match &self.source {
108            EmbeddingSource::Direct(config) => {
109                let client = self.direct_client.as_ref().ok_or_else(|| {
110                    VectorLifecycleError::EmbeddingResponse(
111                        "direct embedding client is not initialized".to_string(),
112                    )
113                })?;
114                embed_text_batch(client, config, texts)
115            }
116            EmbeddingSource::Daemon(context) => daemon::embed_via_daemon(context, texts, false)
117                .map(|result| result.embeddings)
118                .map_err(|error| VectorLifecycleError::EmbeddingResponse(error.to_string())),
119        }
120    }
121}
122
123pub fn embedding_source_from_context(ctx: &Context) -> Option<EmbeddingSource> {
124    let resolved = resolve_embedding_ai_context(ctx);
125    embedding_source_from_resolved_ai_context(resolved.context, resolved.direct_config)
126}
127
128fn embedding_source_from_resolved_ai_context(
129    ai_context: AiContext,
130    direct_config: Option<EmbeddingConfig>,
131) -> Option<EmbeddingSource> {
132    match effective_route(&ai_context, AiCapability::Embed) {
133        gobby_core::config::AiRouting::Off => None,
134        gobby_core::config::AiRouting::Daemon => {
135            Some(EmbeddingSource::Daemon(Box::new(ai_context)))
136        }
137        gobby_core::config::AiRouting::Direct => direct_config.map(EmbeddingSource::Direct),
138        gobby_core::config::AiRouting::Auto => None,
139    }
140}
141
142struct ResolvedEmbeddingAiContext {
143    context: AiContext,
144    direct_config: Option<EmbeddingConfig>,
145}
146
147fn resolve_embedding_ai_context(ctx: &Context) -> ResolvedEmbeddingAiContext {
148    let standalone = crate::config::read_standalone_config_optional();
149    if let Ok(mut conn) = db::connect_readonly(&ctx.database_url) {
150        let primary = PostgresAiConfigSource::new(&mut conn, secrets::resolve_config_value);
151        let mut source = AiConfigSource::with_primary(primary, standalone);
152        let context = AiContext::resolve(Some(ctx.project_id.clone()), &mut source);
153        let direct_config = gobby_core::config::resolve_embedding_config_from_binding(
154            &mut source,
155            context.binding(AiCapability::Embed),
156        );
157        return ResolvedEmbeddingAiContext {
158            context,
159            direct_config,
160        };
161    }
162
163    let mut source = AiConfigSource::with_primary(NoPrimaryAiConfigSource, standalone);
164    let mut context = AiContext::resolve(Some(ctx.project_id.clone()), &mut source);
165    if let Some(embedding) = &ctx.embedding {
166        context.bindings.embed.api_base = Some(embedding.api_base.clone());
167        context.bindings.embed.model = Some(embedding.model.clone());
168        context.bindings.embed.api_key = embedding.api_key.clone();
169    }
170    let direct_config = gobby_core::config::resolve_embedding_config_from_binding(
171        &mut source,
172        context.binding(AiCapability::Embed),
173    )
174    .or_else(|| ctx.embedding.clone());
175    ResolvedEmbeddingAiContext {
176        context,
177        direct_config,
178    }
179}
180
181pub fn embedding_client(
182    config: &EmbeddingConfig,
183) -> Result<reqwest::blocking::Client, VectorLifecycleError> {
184    let mut clients = match EMBEDDING_CLIENTS
185        .get_or_init(|| Mutex::new(HashMap::new()))
186        .lock()
187    {
188        Ok(guard) => guard,
189        Err(poisoned) => poisoned.into_inner(),
190    };
191    // The blocking HTTP client is keyed only by timeout because request-specific
192    // embedding endpoint, model, and auth details are applied per request.
193    match clients.entry(config.timeout_seconds) {
194        Entry::Occupied(entry) => Ok(entry.get().clone()),
195        Entry::Vacant(entry) => {
196            let client = reqwest::blocking::Client::builder()
197                .timeout(std::time::Duration::from_secs(config.timeout_seconds))
198                .build()
199                .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))?;
200            Ok(entry.insert(client).clone())
201        }
202    }
203}
204
205pub fn embed_text(
206    client: &reqwest::blocking::Client,
207    config: &EmbeddingConfig,
208    text: &str,
209) -> Result<Vec<f32>, VectorLifecycleError> {
210    let body = json!({
211        "model": config.model,
212        "input": text,
213    });
214
215    let data = send_embedding_request(client, config, body)?;
216    data.get("data")
217        .and_then(Value::as_array)
218        .and_then(|values| values.first())
219        .ok_or_else(|| {
220            VectorLifecycleError::EmbeddingResponse("missing data[0] object".to_string())
221        })
222        .and_then(parse_embedding)
223}
224
225pub fn probe_embedding_dim(config: &EmbeddingConfig) -> Result<usize, VectorLifecycleError> {
226    let client = embedding_client(config)?;
227    Ok(embed_text(&client, config, dimension_probe_text())?.len())
228}
229
230pub fn embed_text_batch(
231    client: &reqwest::blocking::Client,
232    config: &EmbeddingConfig,
233    texts: &[String],
234) -> Result<Vec<Vec<f32>>, VectorLifecycleError> {
235    if texts.is_empty() {
236        return Ok(Vec::new());
237    }
238    let body = json!({
239        "model": config.model,
240        "input": texts,
241    });
242
243    let data = send_embedding_request(client, config, body)?;
244    let data = data
245        .get("data")
246        .and_then(Value::as_array)
247        .ok_or_else(|| VectorLifecycleError::EmbeddingResponse("missing data array".to_string()))?;
248    if data.len() != texts.len() {
249        return Err(VectorLifecycleError::EmbeddingResponse(format!(
250            "embedding response returned {} vector(s) for {} input(s)",
251            data.len(),
252            texts.len()
253        )));
254    }
255
256    let mut ordered = vec![None; texts.len()];
257    for (position, item) in data.iter().enumerate() {
258        let index = item
259            .get("index")
260            .and_then(Value::as_u64)
261            .and_then(|index| usize::try_from(index).ok())
262            .unwrap_or(position);
263        if index >= texts.len() || ordered[index].is_some() {
264            return Err(VectorLifecycleError::EmbeddingResponse(
265                "embedding response contained an invalid index".to_string(),
266            ));
267        }
268        ordered[index] = Some(parse_embedding(item)?);
269    }
270
271    ordered
272        .into_iter()
273        .map(|embedding| {
274            embedding.ok_or_else(|| {
275                VectorLifecycleError::EmbeddingResponse(
276                    "embedding response omitted an input index".to_string(),
277                )
278            })
279        })
280        .collect()
281}
282
283fn send_embedding_request(
284    client: &reqwest::blocking::Client,
285    config: &EmbeddingConfig,
286    body: Value,
287) -> Result<Value, VectorLifecycleError> {
288    let url = format!("{}/embeddings", config.api_base.trim_end_matches('/'));
289    let mut req = client.post(&url).json(&body);
290
291    if let Some(key) = &config.api_key {
292        req = req.header("Authorization", format!("Bearer {key}"));
293    }
294
295    let resp = req
296        .send()
297        .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))?;
298    if !resp.status().is_success() {
299        let status = resp.status().as_u16();
300        let body = resp.text().unwrap_or_default();
301        return Err(VectorLifecycleError::EmbeddingHttp { status, body });
302    }
303
304    resp.json()
305        .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))
306}
307
308fn parse_embedding(value: &Value) -> Result<Vec<f32>, VectorLifecycleError> {
309    let embedding = value
310        .get("embedding")
311        .and_then(Value::as_array)
312        .ok_or_else(|| {
313            VectorLifecycleError::EmbeddingResponse("missing embedding array".to_string())
314        })?
315        .iter()
316        .map(|value| {
317            let f = value.as_f64().ok_or_else(|| {
318                VectorLifecycleError::EmbeddingResponse(
319                    "embedding array contains a non-number".to_string(),
320                )
321            })?;
322            let converted = f as f32;
323            if !f.is_finite() || converted.is_infinite() {
324                return Err(VectorLifecycleError::EmbeddingResponse(
325                    "embedding contains value outside f32 range".to_string(),
326                ));
327            }
328            Ok(converted)
329        })
330        .collect::<Result<Vec<_>, _>>()?;
331
332    if embedding.is_empty() {
333        return Err(VectorLifecycleError::EmbeddingResponse(
334            "embedding vector was empty".to_string(),
335        ));
336    }
337    Ok(embedding)
338}
339
340pub fn embed_query(config: &EmbeddingConfig, text: &str) -> Option<Vec<f32>> {
341    let prefix = config.query_prefix.as_deref().unwrap_or("").trim();
342    let input = if prefix.is_empty() {
343        text.to_string()
344    } else {
345        format!("{prefix} {text}")
346    };
347    let client = match embedding_client(config) {
348        Ok(client) => client,
349        Err(error) => {
350            eprintln!("gcode: query embedding failed: {error}");
351            return None;
352        }
353    };
354    match embed_text(&client, config, &input) {
355        Ok(embedding) => Some(embedding),
356        Err(error) => {
357            eprintln!("gcode: query embedding failed: {error}");
358            None
359        }
360    }
361}
362
363pub fn embed_query_with_source(source: &EmbeddingSource, text: &str) -> Option<Vec<f32>> {
364    let backend = match EmbeddingBackend::new(source.clone()) {
365        Ok(backend) => backend,
366        Err(error) => {
367            eprintln!("gcode: query embedding failed: {error}");
368            return None;
369        }
370    };
371    match backend.embed_query(text) {
372        Ok(embedding) => Some(embedding),
373        Err(error) => {
374            eprintln!("gcode: query embedding failed: {error}");
375            None
376        }
377    }
378}
379
380pub fn vector_text_for_symbol(symbol: &Symbol) -> String {
381    let mut lines = vec![
382        format!("name: {}", symbol.name),
383        format!("qualified_name: {}", symbol.qualified_name),
384        format!("kind: {}", symbol.kind),
385        format!("language: {}", symbol.language),
386        format!("file_path: {}", symbol.file_path),
387        format!("range: {}-{}", symbol.line_start, symbol.line_end),
388    ];
389    if let Some(signature) = symbol
390        .signature
391        .as_deref()
392        .filter(|value| !value.trim().is_empty())
393    {
394        lines.push(format!("signature: {signature}"));
395    }
396    if let Some(docstring) = symbol
397        .docstring
398        .as_deref()
399        .filter(|value| !value.trim().is_empty())
400    {
401        lines.push(format!("docstring: {docstring}"));
402    }
403    if let Some(summary) = symbol
404        .summary
405        .as_deref()
406        .filter(|value| !value.trim().is_empty())
407    {
408        lines.push(format!("summary: {summary}"));
409    }
410    lines.join("\n")
411}
412
#[cfg(test)]
mod tests {
    use super::{EmbeddingSource, embedding_source_from_resolved_ai_context};
    use crate::config::EmbeddingConfig;
    use gobby_core::ai_context::AiContext;
    use gobby_core::config::{ConfigSource, ai_keys, embedding_keys};
    use std::collections::HashMap;

    /// In-memory `ConfigSource` backed by a fixed key/value table.
    #[derive(Default)]
    struct TestSource {
        values: HashMap<&'static str, &'static str>,
    }

    impl TestSource {
        fn with_values(values: impl IntoIterator<Item = (&'static str, &'static str)>) -> Self {
            let values = values.into_iter().collect();
            Self { values }
        }
    }

    impl ConfigSource for TestSource {
        fn config_value(&mut self, key: &str) -> Option<String> {
            self.values.get(key).map(|value| String::from(*value))
        }

        /// Resolves the one known secret reference; every other value passes through.
        fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
            let resolved = if value == "$secret:EMBEDDING_KEY" {
                "resolved-embedding-key"
            } else {
                value
            };
            Ok(resolved.to_string())
        }
    }

    #[test]
    fn resolves_via_shared_routing() {
        let mut auto_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "auto"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://embeddings.local:11434/v1",
            ),
        ]);
        let auto_config =
            crate::config::resolve_embedding_config_from_source(None, &mut auto_source)
                .expect("auto route with endpoint should use direct embeddings");
        assert_eq!(auto_config.api_base, "http://embeddings.local:11434/v1");

        // Non-direct routes must ignore a configured endpoint entirely.
        for (routing, ignored_base) in [
            ("daemon", "http://daemon-should-not-be-used:11434/v1"),
            ("off", "http://off-should-not-be-used:11434/v1"),
        ] {
            let mut source = TestSource::with_values([
                (ai_keys::EMBEDDINGS_ROUTING, routing),
                (ai_keys::EMBEDDINGS_API_BASE, ignored_base),
            ]);
            assert!(
                crate::config::resolve_embedding_config_from_source(None, &mut source).is_none()
            );
        }
    }

    #[test]
    fn reads_endpoint_from_shared_binding() {
        let mut source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "direct"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://shared-binding.local:11434/v1",
            ),
            (ai_keys::EMBEDDINGS_MODEL, "shared-embed-model"),
            (ai_keys::EMBEDDINGS_API_KEY, "$secret:EMBEDDING_KEY"),
            (embedding_keys::AI_QUERY_PREFIX, "query:"),
            (embedding_keys::AI_TIMEOUT_SECONDS, "12"),
        ]);

        let config = crate::config::resolve_embedding_config_from_source(None, &mut source)
            .expect("embedding config from shared binding");

        assert_eq!(
            (
                config.api_base.as_str(),
                config.model.as_str(),
                config.api_key.as_deref(),
                config.query_prefix.as_deref(),
                config.timeout_seconds,
            ),
            (
                "http://shared-binding.local:11434/v1",
                "shared-embed-model",
                Some("resolved-embedding-key"),
                Some("query:"),
                12,
            )
        );
    }

    #[test]
    fn direct_source_uses_resolved_embedding_config() {
        let mut config_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "direct"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (ai_keys::EMBEDDINGS_API_BASE, "http://resolved.local/v1"),
            (ai_keys::EMBEDDINGS_MODEL, "resolved-embed-model"),
        ]);
        let context = AiContext::resolve(None, &mut config_source);
        let expected = EmbeddingConfig {
            api_base: "http://resolved.local/v1".to_string(),
            model: "resolved-embed-model".to_string(),
            api_key: None,
            query_prefix: None,
            timeout_seconds: 10,
        };

        match embedding_source_from_resolved_ai_context(context, Some(expected.clone())) {
            Some(EmbeddingSource::Direct(actual)) => assert_eq!(actual, expected),
            other => panic!("expected direct embedding source, got {other:?}"),
        }
    }
}