Skip to main content

gobby_code/vector/code_symbols/
embedding.rs

1use std::collections::{HashMap, hash_map::Entry};
2use std::sync::{Mutex, OnceLock};
3
4use crate::config::{Context, EmbeddingConfig};
5use crate::db;
6use crate::models::Symbol;
7use crate::secrets;
8use gobby_core::ai::{daemon, effective_route};
9use gobby_core::ai_context::{
10    AiConfigSource, AiContext, NoPrimaryAiConfigSource, PostgresAiConfigSource,
11};
12use gobby_core::ai_types::AiError;
13use gobby_core::config::AiCapability;
14
15use super::types::VectorLifecycleError;
16
17const DIMENSION_PROBE_TEXT: &str = "dimension_probe";
18static EMBEDDING_CLIENTS: OnceLock<Mutex<HashMap<u64, reqwest::blocking::Client>>> =
19    OnceLock::new();
20
21pub(super) fn dimension_probe_text() -> &'static str {
22    DIMENSION_PROBE_TEXT
23}
24
25#[derive(Debug, Clone)]
26pub enum EmbeddingSource {
27    Daemon(Box<AiContext>),
28    Direct(EmbeddingConfig),
29}
30
31impl From<EmbeddingConfig> for EmbeddingSource {
32    fn from(config: EmbeddingConfig) -> Self {
33        Self::Direct(config)
34    }
35}
36
37impl From<AiContext> for EmbeddingSource {
38    fn from(context: AiContext) -> Self {
39        Self::Daemon(Box::new(context))
40    }
41}
42
43#[derive(Debug, Clone)]
44pub struct EmbeddingBackend {
45    source: EmbeddingSource,
46    direct_client: Option<reqwest::blocking::Client>,
47}
48
49impl EmbeddingBackend {
50    pub fn new(source: EmbeddingSource) -> Result<Self, VectorLifecycleError> {
51        let direct_client = match &source {
52            EmbeddingSource::Direct(config) => {
53                if config.api_base.trim().is_empty() {
54                    return Err(VectorLifecycleError::MissingEmbeddingConfig);
55                }
56                Some(embedding_client(config)?)
57            }
58            EmbeddingSource::Daemon(_) => None,
59        };
60        Ok(Self {
61            source,
62            direct_client,
63        })
64    }
65
66    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>, VectorLifecycleError> {
67        let texts = vec![text.to_string()];
68        let mut embeddings = self.embed_text_batch(&texts)?;
69        embeddings.pop().ok_or_else(|| {
70            VectorLifecycleError::EmbeddingResponse("embedding response was empty".to_string())
71        })
72    }
73
74    pub fn embed_query(&self, text: &str) -> Result<Vec<f32>, VectorLifecycleError> {
75        match &self.source {
76            EmbeddingSource::Direct(config) => {
77                let prefix = config.query_prefix.as_deref().unwrap_or("").trim();
78                let input = if prefix.is_empty() {
79                    text.to_string()
80                } else {
81                    format!("{prefix} {text}")
82                };
83                let client = self.direct_client.as_ref().ok_or_else(|| {
84                    VectorLifecycleError::EmbeddingResponse(
85                        "direct embedding client is not initialized".to_string(),
86                    )
87                })?;
88                embed_text(client, config, &input)
89            }
90            EmbeddingSource::Daemon(context) => {
91                let texts = vec![text.to_string()];
92                let result = daemon::embed_via_daemon(context, &texts, true)
93                    .map_err(|error| VectorLifecycleError::EmbeddingResponse(error.to_string()))?;
94                result.embeddings.into_iter().next().ok_or_else(|| {
95                    VectorLifecycleError::EmbeddingResponse(
96                        "daemon embedding response was empty".to_string(),
97                    )
98                })
99            }
100        }
101    }
102
103    pub fn embed_text_batch(
104        &self,
105        texts: &[String],
106    ) -> Result<Vec<Vec<f32>>, VectorLifecycleError> {
107        match &self.source {
108            EmbeddingSource::Direct(config) => {
109                let client = self.direct_client.as_ref().ok_or_else(|| {
110                    VectorLifecycleError::EmbeddingResponse(
111                        "direct embedding client is not initialized".to_string(),
112                    )
113                })?;
114                embed_text_batch(client, config, texts)
115            }
116            EmbeddingSource::Daemon(context) => daemon::embed_via_daemon(context, texts, false)
117                .map(|result| result.embeddings)
118                .map_err(|error| VectorLifecycleError::EmbeddingResponse(error.to_string())),
119        }
120    }
121}
122
123pub fn embedding_source_from_context(ctx: &Context) -> Option<EmbeddingSource> {
124    let resolved = resolve_embedding_ai_context(ctx);
125    embedding_source_from_resolved_ai_context(resolved.context, resolved.direct_config)
126}
127
128fn embedding_source_from_resolved_ai_context(
129    ai_context: AiContext,
130    direct_config: Option<EmbeddingConfig>,
131) -> Option<EmbeddingSource> {
132    match effective_route(&ai_context, AiCapability::Embed) {
133        gobby_core::config::AiRouting::Off => None,
134        gobby_core::config::AiRouting::Daemon => {
135            Some(EmbeddingSource::Daemon(Box::new(ai_context)))
136        }
137        gobby_core::config::AiRouting::Direct => direct_config.map(EmbeddingSource::Direct),
138        gobby_core::config::AiRouting::Auto => None,
139    }
140}
141
142struct ResolvedEmbeddingAiContext {
143    context: AiContext,
144    direct_config: Option<EmbeddingConfig>,
145}
146
147fn resolve_embedding_ai_context(ctx: &Context) -> ResolvedEmbeddingAiContext {
148    let standalone = crate::config::read_standalone_config_optional();
149    if let Ok(mut conn) = db::connect_readonly(&ctx.database_url) {
150        let primary = PostgresAiConfigSource::new(&mut conn, secrets::resolve_config_value);
151        let mut source = AiConfigSource::with_primary(primary, standalone);
152        let context = AiContext::resolve(Some(ctx.project_id.clone()), &mut source);
153        let direct_config = gobby_core::config::resolve_embedding_config_from_binding(
154            &mut source,
155            context.binding(AiCapability::Embed),
156        );
157        return ResolvedEmbeddingAiContext {
158            context,
159            direct_config,
160        };
161    }
162
163    let mut source = AiConfigSource::with_primary(NoPrimaryAiConfigSource, standalone);
164    let mut context = AiContext::resolve(Some(ctx.project_id.clone()), &mut source);
165    if let Some(embedding) = &ctx.embedding {
166        context.bindings.embed.api_base = Some(embedding.api_base.clone());
167        context.bindings.embed.model = Some(embedding.model.clone());
168        context.bindings.embed.api_key = embedding.api_key.clone();
169    }
170    let direct_config = gobby_core::config::resolve_embedding_config_from_binding(
171        &mut source,
172        context.binding(AiCapability::Embed),
173    )
174    .or_else(|| ctx.embedding.clone());
175    ResolvedEmbeddingAiContext {
176        context,
177        direct_config,
178    }
179}
180
181pub fn embedding_client(
182    config: &EmbeddingConfig,
183) -> Result<reqwest::blocking::Client, VectorLifecycleError> {
184    let mut clients = match EMBEDDING_CLIENTS
185        .get_or_init(|| Mutex::new(HashMap::new()))
186        .lock()
187    {
188        Ok(guard) => guard,
189        Err(poisoned) => poisoned.into_inner(),
190    };
191    // The blocking HTTP client is keyed only by timeout because request-specific
192    // embedding endpoint, model, and auth details are applied per request.
193    match clients.entry(config.timeout_seconds) {
194        Entry::Occupied(entry) => Ok(entry.get().clone()),
195        Entry::Vacant(entry) => {
196            let client = reqwest::blocking::Client::builder()
197                .timeout(std::time::Duration::from_secs(config.timeout_seconds))
198                .build()
199                .map_err(|err| VectorLifecycleError::EmbeddingResponse(err.to_string()))?;
200            Ok(entry.insert(client).clone())
201        }
202    }
203}
204
205pub fn embed_text(
206    client: &reqwest::blocking::Client,
207    config: &EmbeddingConfig,
208    text: &str,
209) -> Result<Vec<f32>, VectorLifecycleError> {
210    gobby_core::ai::embeddings::embed_one(client, config, text).map_err(embedding_error)
211}
212
213pub fn probe_embedding_dim(config: &EmbeddingConfig) -> Result<usize, VectorLifecycleError> {
214    let client = embedding_client(config)?;
215    Ok(embed_text(&client, config, dimension_probe_text())?.len())
216}
217
218pub fn embed_text_batch(
219    client: &reqwest::blocking::Client,
220    config: &EmbeddingConfig,
221    texts: &[String],
222) -> Result<Vec<Vec<f32>>, VectorLifecycleError> {
223    gobby_core::ai::embeddings::embed_batch(client, config, texts).map_err(embedding_error)
224}
225
226fn embedding_error(error: AiError) -> VectorLifecycleError {
227    match error {
228        AiError::HttpStatus { status, body } => VectorLifecycleError::EmbeddingHttp {
229            status,
230            body: body.unwrap_or_default(),
231        },
232        AiError::RateLimited {
233            status: Some(status),
234            body,
235            ..
236        } => VectorLifecycleError::EmbeddingHttp {
237            status,
238            body: body.unwrap_or_default(),
239        },
240        AiError::TransportFailure {
241            status: Some(status),
242            body: Some(body),
243            ..
244        } => VectorLifecycleError::EmbeddingHttp { status, body },
245        other => VectorLifecycleError::EmbeddingResponse(other.to_string()),
246    }
247}
248
249pub fn embed_query(config: &EmbeddingConfig, text: &str) -> Option<Vec<f32>> {
250    let prefix = config.query_prefix.as_deref().unwrap_or("").trim();
251    let input = if prefix.is_empty() {
252        text.to_string()
253    } else {
254        format!("{prefix} {text}")
255    };
256    let client = match embedding_client(config) {
257        Ok(client) => client,
258        Err(error) => {
259            eprintln!("gcode: query embedding failed: {error}");
260            return None;
261        }
262    };
263    match embed_text(&client, config, &input) {
264        Ok(embedding) => Some(embedding),
265        Err(error) => {
266            eprintln!("gcode: query embedding failed: {error}");
267            None
268        }
269    }
270}
271
272pub fn embed_query_with_source(source: &EmbeddingSource, text: &str) -> Option<Vec<f32>> {
273    let backend = match EmbeddingBackend::new(source.clone()) {
274        Ok(backend) => backend,
275        Err(error) => {
276            eprintln!("gcode: query embedding failed: {error}");
277            return None;
278        }
279    };
280    match backend.embed_query(text) {
281        Ok(embedding) => Some(embedding),
282        Err(error) => {
283            eprintln!("gcode: query embedding failed: {error}");
284            None
285        }
286    }
287}
288
289pub fn vector_text_for_symbol(symbol: &Symbol) -> String {
290    let mut lines = vec![
291        format!("name: {}", symbol.name),
292        format!("qualified_name: {}", symbol.qualified_name),
293        format!("kind: {}", symbol.kind),
294        format!("language: {}", symbol.language),
295        format!("file_path: {}", symbol.file_path),
296        format!("range: {}-{}", symbol.line_start, symbol.line_end),
297    ];
298    if let Some(signature) = symbol
299        .signature
300        .as_deref()
301        .filter(|value| !value.trim().is_empty())
302    {
303        lines.push(format!("signature: {signature}"));
304    }
305    if let Some(docstring) = symbol
306        .docstring
307        .as_deref()
308        .filter(|value| !value.trim().is_empty())
309    {
310        lines.push(format!("docstring: {docstring}"));
311    }
312    if let Some(summary) = symbol
313        .summary
314        .as_deref()
315        .filter(|value| !value.trim().is_empty())
316    {
317        lines.push(format!("summary: {summary}"));
318    }
319    lines.join("\n")
320}
321
#[cfg(test)]
mod tests {
    use super::{EmbeddingSource, embedding_source_from_resolved_ai_context};
    use crate::config::EmbeddingConfig;
    use gobby_core::ai_context::AiContext;
    use gobby_core::config::{ConfigSource, ai_keys, embedding_keys};
    use std::collections::HashMap;

    /// In-memory config source backed by a fixed key/value table.
    #[derive(Default)]
    struct TestSource {
        values: HashMap<&'static str, &'static str>,
    }

    impl TestSource {
        /// Builds a source from `(key, value)` pairs.
        fn with_values(values: impl IntoIterator<Item = (&'static str, &'static str)>) -> Self {
            let values = values.into_iter().collect();
            Self { values }
        }
    }

    impl ConfigSource for TestSource {
        fn config_value(&mut self, key: &str) -> Option<String> {
            self.values.get(key).map(|value| (*value).to_string())
        }

        /// Resolves the one known secret reference; everything else passes through.
        fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
            let resolved = match value {
                "$secret:EMBEDDING_KEY" => "resolved-embedding-key",
                literal => literal,
            };
            Ok(resolved.to_string())
        }
    }

    #[test]
    fn resolves_via_shared_routing() {
        // Auto routing with an endpoint configured resolves to a direct config.
        let mut auto_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "auto"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://embeddings.local:11434/v1",
            ),
        ]);
        let auto_config =
            crate::config::resolve_embedding_config_from_source(None, &mut auto_source)
                .expect("auto route with endpoint should use direct embeddings");
        assert_eq!(auto_config.api_base, "http://embeddings.local:11434/v1");

        // Daemon routing never yields a direct config, even with an endpoint set.
        let mut daemon_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "daemon"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://daemon-should-not-be-used:11434/v1",
            ),
        ]);
        let daemon_config =
            crate::config::resolve_embedding_config_from_source(None, &mut daemon_source);
        assert!(daemon_config.is_none());

        // Routing switched off yields no direct config either.
        let mut off_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "off"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://off-should-not-be-used:11434/v1",
            ),
        ]);
        let off_config =
            crate::config::resolve_embedding_config_from_source(None, &mut off_source);
        assert!(off_config.is_none());
    }

    #[test]
    fn reads_endpoint_from_shared_binding() {
        let mut source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "direct"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (
                ai_keys::EMBEDDINGS_API_BASE,
                "http://shared-binding.local:11434/v1",
            ),
            (ai_keys::EMBEDDINGS_MODEL, "shared-embed-model"),
            (ai_keys::EMBEDDINGS_API_KEY, "$secret:EMBEDDING_KEY"),
            (embedding_keys::AI_QUERY_PREFIX, "query:"),
            (embedding_keys::AI_TIMEOUT_SECONDS, "12"),
        ]);

        let config = crate::config::resolve_embedding_config_from_source(None, &mut source)
            .expect("embedding config from shared binding");

        // Every field comes from the binding; the API key is secret-resolved.
        assert_eq!(config.api_base, "http://shared-binding.local:11434/v1");
        assert_eq!(config.model, "shared-embed-model");
        assert_eq!(config.api_key.as_deref(), Some("resolved-embedding-key"));
        assert_eq!(config.query_prefix.as_deref(), Some("query:"));
        assert_eq!(config.timeout_seconds, 12);
    }

    #[test]
    fn direct_source_uses_resolved_embedding_config() {
        let mut config_source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "direct"),
            (ai_keys::EMBEDDINGS_TRANSPORT, "openai_compatible_http"),
            (ai_keys::EMBEDDINGS_API_BASE, "http://resolved.local/v1"),
            (ai_keys::EMBEDDINGS_MODEL, "resolved-embed-model"),
        ]);
        let context = AiContext::resolve(None, &mut config_source);
        let direct_config = EmbeddingConfig {
            api_base: "http://resolved.local/v1".to_string(),
            model: "resolved-embed-model".to_string(),
            api_key: None,
            query_prefix: None,
            timeout_seconds: 10,
        };

        let resolved =
            embedding_source_from_resolved_ai_context(context, Some(direct_config.clone()));

        // The direct route must hand back exactly the config that was passed in.
        match resolved {
            Some(EmbeddingSource::Direct(config)) => assert_eq!(config, direct_config),
            other => panic!("expected direct embedding source, got {other:?}"),
        }
    }
}
443}