1use std::sync::{Arc, Condvar, Mutex};
8
9use crate::config::{
10 AiCapability, AiRouting, AiTuning, CapabilityBinding, ConfigSource, resolve_ai_tuning,
11 resolve_capability_binding,
12};
13use crate::provisioning::{StandaloneConfig, gcore_config_path};
14
/// The five [`AiCapability`] values that have a binding slot in `AiBindings`.
///
/// `AiBindings::force_routing` walks this list to apply one routing mode to every capability.
const ALL_CAPABILITIES: [AiCapability; 5] = [
    AiCapability::Embed,
    AiCapability::AudioTranscribe,
    AiCapability::AudioTranslate,
    AiCapability::VisionExtract,
    AiCapability::TextGenerate,
];
22
23#[derive(Debug, Clone)]
25pub struct AiContext {
26 pub bindings: AiBindings,
27 pub tuning: AiTuning,
28 pub limiter: AiLimiter,
29 pub project_id: Option<String>,
30}
31
32impl AiContext {
33 pub fn resolve(project_id: Option<String>, source: &mut impl ConfigSource) -> Self {
35 Self::resolve_with_options(project_id, source, AiContextOptions::default())
36 }
37
38 pub fn resolve_with_options(
40 project_id: Option<String>,
41 source: &mut impl ConfigSource,
42 options: AiContextOptions,
43 ) -> Self {
44 let mut bindings = AiBindings::resolve(source);
45 let mut tuning = resolve_ai_tuning(source);
46
47 if options.no_ai {
48 bindings.force_routing(AiRouting::Off);
49 } else if let Some(routing) = options.forced_routing {
50 bindings.force_routing(routing);
51 }
52
53 if tuning.max_concurrency == 0 {
54 tuning.max_concurrency = 1;
55 }
56 let limiter = AiLimiter::new(tuning.max_concurrency);
57
58 Self {
59 bindings,
60 tuning,
61 limiter,
62 project_id,
63 }
64 }
65
66 pub fn binding(&self, capability: AiCapability) -> &CapabilityBinding {
67 self.bindings.get(capability)
68 }
69}
70
/// Caller-selected overrides that `AiContext::resolve_with_options` applies on top of the
/// routing resolved from configuration.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AiContextOptions {
    /// When `true`, every capability is forced to `AiRouting::Off`; this wins over
    /// `forced_routing`.
    pub no_ai: bool,
    /// When set and `no_ai` is `false`, every capability is forced to this routing.
    pub forced_routing: Option<AiRouting>,
}
77
78#[derive(Debug, Clone, PartialEq, Eq)]
80pub struct AiBindings {
81 pub embed: CapabilityBinding,
82 pub audio_transcribe: CapabilityBinding,
83 pub audio_translate: CapabilityBinding,
84 pub vision_extract: CapabilityBinding,
85 pub text_generate: CapabilityBinding,
86}
87
88impl AiBindings {
89 pub fn resolve(source: &mut impl ConfigSource) -> Self {
90 Self {
91 embed: resolve_capability_binding(source, AiCapability::Embed),
92 audio_transcribe: resolve_capability_binding(source, AiCapability::AudioTranscribe),
93 audio_translate: resolve_capability_binding(source, AiCapability::AudioTranslate),
94 vision_extract: resolve_capability_binding(source, AiCapability::VisionExtract),
95 text_generate: resolve_capability_binding(source, AiCapability::TextGenerate),
96 }
97 }
98
99 pub fn get(&self, capability: AiCapability) -> &CapabilityBinding {
100 match capability {
101 AiCapability::Embed => &self.embed,
102 AiCapability::AudioTranscribe => &self.audio_transcribe,
103 AiCapability::AudioTranslate => &self.audio_translate,
104 AiCapability::VisionExtract => &self.vision_extract,
105 AiCapability::TextGenerate => &self.text_generate,
106 }
107 }
108
109 fn get_mut(&mut self, capability: AiCapability) -> &mut CapabilityBinding {
110 match capability {
111 AiCapability::Embed => &mut self.embed,
112 AiCapability::AudioTranscribe => &mut self.audio_transcribe,
113 AiCapability::AudioTranslate => &mut self.audio_translate,
114 AiCapability::VisionExtract => &mut self.vision_extract,
115 AiCapability::TextGenerate => &mut self.text_generate,
116 }
117 }
118
119 fn force_routing(&mut self, routing: AiRouting) {
120 for capability in ALL_CAPABILITIES {
121 self.get_mut(capability).routing = routing;
122 }
123 }
124}
125
126pub fn route(context: &AiContext, capability: AiCapability) -> AiRouting {
128 context.binding(capability).routing
129}
130
/// Counting limiter that caps how many permits may be held at the same time.
///
/// Clones share one counter, so a permit taken through any clone counts against all of them.
#[derive(Clone)]
pub struct AiLimiter {
    inner: Arc<LimiterInner>,
}

/// State shared by every [`AiLimiter`] clone and every outstanding [`AiPermit`].
struct LimiterInner {
    /// Upper bound on simultaneously held permits; `AiLimiter::new` keeps it at 1 or more.
    max: u8,
    /// Number of permits currently held.
    active: Mutex<u8>,
    /// Notified once each time a permit is dropped.
    available: Condvar,
}

impl LimiterInner {
    /// Locks the permit counter, taking the guard even when the mutex is poisoned.
    fn lock_active(&self) -> std::sync::MutexGuard<'_, u8> {
        self.active
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Wraps this shared state in a new permit; the caller has already counted it.
    fn permit(self: &Arc<Self>) -> AiPermit {
        AiPermit {
            inner: Arc::clone(self),
        }
    }
}

impl AiLimiter {
    /// Creates a limiter allowing up to `max_concurrency` permits; 0 is treated as 1.
    pub fn new(max_concurrency: u8) -> Self {
        Self {
            inner: Arc::new(LimiterInner {
                max: max_concurrency.max(1),
                active: Mutex::new(0),
                available: Condvar::new(),
            }),
        }
    }

    /// Returns the maximum number of permits that can be held at once.
    pub fn max_concurrency(&self) -> u8 {
        self.inner.max
    }

    /// Blocks the calling thread until a slot is free, then returns a permit for it.
    ///
    /// The slot is given back when the returned [`AiPermit`] is dropped.
    #[must_use = "the slot is released as soon as the permit is dropped"]
    pub fn acquire(&self) -> AiPermit {
        let mut active = self.inner.lock_active();
        // Re-check after every wake-up: the wait can return spuriously, or another thread
        // may have claimed the freed slot first.
        while *active >= self.inner.max {
            active = self
                .inner
                .available
                .wait(active)
                .unwrap_or_else(|poisoned| poisoned.into_inner());
        }
        // Cannot overflow: `active < max <= u8::MAX` at this point.
        *active += 1;
        self.inner.permit()
    }

    /// Returns a permit if a slot is free right now, or `None` without blocking.
    #[must_use = "the slot is released as soon as the permit is dropped"]
    pub fn try_acquire(&self) -> Option<AiPermit> {
        let mut active = self.inner.lock_active();
        if *active >= self.inner.max {
            return None;
        }
        *active += 1;
        Some(self.inner.permit())
    }
}

impl std::fmt::Debug for AiLimiter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AiLimiter")
            .field("max_concurrency", &self.max_concurrency())
            .finish_non_exhaustive()
    }
}

/// Proof of one occupied slot in an [`AiLimiter`]; dropping it frees the slot.
#[must_use = "the slot is released as soon as the permit is dropped"]
#[derive(Debug)]
pub struct AiPermit {
    inner: Arc<LimiterInner>,
}

impl Drop for AiPermit {
    fn drop(&mut self) {
        let mut active = self.inner.lock_active();
        // Saturating so an inconsistent counter can never wrap around.
        *active = active.saturating_sub(1);
        // One slot was freed, so waking a single waiter is enough.
        self.inner.available.notify_one();
    }
}

impl std::fmt::Debug for LimiterInner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LimiterInner")
            .field("max", &self.max)
            .finish_non_exhaustive()
    }
}
226
227#[derive(Debug, Clone)]
232pub struct AiConfigSource<P = NoPrimaryAiConfigSource> {
233 primary: Option<P>,
234 standalone: Option<StandaloneConfig>,
235}
236
237pub type LocalAiConfigSource = AiConfigSource<NoPrimaryAiConfigSource>;
238
239impl LocalAiConfigSource {
240 pub fn from_gobby_home(gobby_home: &std::path::Path) -> anyhow::Result<Self> {
241 Ok(Self::with_primary(
242 NoPrimaryAiConfigSource,
243 StandaloneConfig::read_at(&gcore_config_path(gobby_home))?,
244 ))
245 }
246}
247
248impl<P> AiConfigSource<P>
249where
250 P: ConfigSource,
251{
252 pub fn with_primary(primary: P, standalone: Option<StandaloneConfig>) -> Self {
253 Self {
254 primary: Some(primary),
255 standalone,
256 }
257 }
258
259 pub fn with_primary_from_gobby_home(
260 primary: P,
261 gobby_home: &std::path::Path,
262 ) -> anyhow::Result<Self> {
263 Ok(Self::with_primary(
264 primary,
265 StandaloneConfig::read_at(&gcore_config_path(gobby_home))?,
266 ))
267 }
268}
269
270impl<P> ConfigSource for AiConfigSource<P>
271where
272 P: ConfigSource,
273{
274 fn config_value(&mut self, key: &str) -> Option<String> {
275 self.primary
276 .as_mut()
277 .and_then(|source| source.config_value(key))
278 .or_else(|| {
279 self.standalone
280 .as_mut()
281 .and_then(|standalone| standalone.config_value(key))
282 })
283 }
284
285 fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
286 if value.trim().starts_with("$secret:") {
287 let Some(primary) = self.primary.as_mut() else {
288 anyhow::bail!("secret resolution requires a daemon-backed AI config source");
289 };
290 return primary.resolve_value(value);
291 }
292 match self.standalone.as_mut() {
293 Some(standalone) => standalone.resolve_value(value),
294 None => resolve_non_secret_config_value(value),
295 }
296 }
297}
298
299fn resolve_non_secret_config_value(value: &str) -> anyhow::Result<String> {
300 crate::config::resolve_env_pattern(value)?
301 .ok_or_else(|| anyhow::anyhow!("unresolved pattern: {value}"))
302}
303
304#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
306pub struct NoPrimaryAiConfigSource;
307
308impl ConfigSource for NoPrimaryAiConfigSource {
309 fn config_value(&mut self, _key: &str) -> Option<String> {
310 None
311 }
312
313 fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
314 if value.trim().starts_with("$secret:") {
315 anyhow::bail!("secret resolution requires a daemon-backed AI config source");
316 }
317 resolve_non_secret_config_value(value)
318 }
319}
320
/// Config source that reads values through a borrowed Postgres client.
#[cfg(feature = "postgres")]
pub struct PostgresAiConfigSource<'a, R> {
    /// Connection used for config reads and handed to `resolver`.
    conn: &'a mut postgres::Client,
    /// Callback invoked for values that start with `$secret:`.
    resolver: R,
    /// Starts `true`; cleared once a read reports that the config table is undefined.
    config_store_available: bool,
}

#[cfg(feature = "postgres")]
impl<'a, R> PostgresAiConfigSource<'a, R>
where
    R: FnMut(&str, &mut postgres::Client) -> anyhow::Result<String>,
{
    /// Wraps `conn` and `resolver`, assuming the config store exists until a read says otherwise.
    pub fn new(conn: &'a mut postgres::Client, resolver: R) -> Self {
        let config_store_available = true;
        Self {
            conn,
            resolver,
            config_store_available,
        }
    }

    /// Reports whether the config store is still considered available.
    pub fn config_store_available(&self) -> bool {
        let Self {
            config_store_available,
            ..
        } = self;
        *config_store_available
    }
}
346
#[cfg(feature = "postgres")]
impl<R> ConfigSource for PostgresAiConfigSource<'_, R>
where
    R: FnMut(&str, &mut postgres::Client) -> anyhow::Result<String>,
{
    /// Reads and decodes `key` from the database.
    ///
    /// Returns `None` when the store is already marked unavailable, when the key is absent or
    /// fails to decode, or when the read fails. An undefined-table failure marks the store
    /// unavailable; any other failure is logged as a warning.
    fn config_value(&mut self, key: &str) -> Option<String> {
        if !self.config_store_available {
            return None;
        }

        let error = match crate::postgres::read_config_value(self.conn, key) {
            Ok(raw) => return raw.and_then(|raw| crate::config::decode_config_value(&raw)),
            Err(error) => error,
        };

        if config_store_missing(&error) {
            // Remember the missing table so later lookups skip the database.
            self.config_store_available = false;
        } else {
            log::warn!("failed to read AI config key {key:?}: {error}");
        }
        None
    }

    /// Passes `$secret:` references to the resolver callback; other values are returned as-is.
    fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
        let is_secret_reference = value.trim().starts_with("$secret:");
        if !is_secret_reference {
            return Ok(value.to_string());
        }
        (self.resolver)(value, self.conn)
    }
}
376
/// Returns `true` when any error in `error`'s chain is a Postgres database error whose
/// SQLSTATE is `UNDEFINED_TABLE`.
#[cfg(feature = "postgres")]
fn config_store_missing(error: &anyhow::Error) -> bool {
    for cause in error.chain() {
        let Some(pg_error) = cause.downcast_ref::<postgres::Error>() else {
            continue;
        };
        let Some(db_error) = pg_error.as_db_error() else {
            continue;
        };
        if *db_error.code() == postgres::error::SqlState::UNDEFINED_TABLE {
            return true;
        }
    }
    false
}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AiCapability, AiRouting, ConfigSource, ai_keys};
    use crate::provisioning::gcore_config_path;
    use std::collections::HashMap;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that change the process-wide current directory.
    static CWD_LOCK: Mutex<()> = Mutex::new(());

    /// In-memory `ConfigSource` driven by two lookup tables.
    struct TestSource {
        /// Values returned by `config_value`, keyed by config key.
        values: HashMap<&'static str, String>,
        /// Values returned by `resolve_value`, keyed by the unresolved input.
        resolved: HashMap<&'static str, String>,
    }

    impl TestSource {
        /// Creates a source answering `config_value` from `values`, with nothing resolvable.
        fn with_values(values: impl IntoIterator<Item = (&'static str, &'static str)>) -> Self {
            Self {
                values: values
                    .into_iter()
                    .map(|(key, value)| (key, value.to_string()))
                    .collect(),
                resolved: HashMap::new(),
            }
        }

        /// Replaces the table consulted by `resolve_value`.
        fn with_resolved(
            mut self,
            values: impl IntoIterator<Item = (&'static str, &'static str)>,
        ) -> Self {
            self.resolved = values
                .into_iter()
                .map(|(key, value)| (key, value.to_string()))
                .collect();
            self
        }
    }

    impl ConfigSource for TestSource {
        fn config_value(&mut self, key: &str) -> Option<String> {
            self.values.get(key).cloned()
        }

        fn resolve_value(&mut self, value: &str) -> anyhow::Result<String> {
            self.resolved
                .get(value)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("unresolved test value: {value}"))
        }
    }

    /// Switches the current directory for the guard's lifetime while holding `CWD_LOCK`.
    struct CurrentDirGuard {
        _lock: MutexGuard<'static, ()>,
        original: PathBuf,
    }

    impl CurrentDirGuard {
        /// Takes `CWD_LOCK`, remembers the current directory, then changes into `path`.
        fn set(path: &std::path::Path) -> Self {
            let guard = CWD_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let original = std::env::current_dir().expect("current dir");
            std::env::set_current_dir(path).expect("set current dir");
            Self {
                _lock: guard,
                original,
            }
        }
    }

    impl Drop for CurrentDirGuard {
        fn drop(&mut self) {
            let restored = std::env::set_current_dir(&self.original);
            if std::thread::panicking() {
                // Panicking again while the test is already unwinding would abort the whole
                // test binary and hide the original failure, so only report the problem.
                if let Err(error) = restored {
                    eprintln!("failed to restore current dir: {error}");
                }
            } else {
                restored.expect("restore current dir");
            }
        }
    }

    /// Writes `contents` to the gcore config path under `home`, creating parent directories.
    fn write_gcore_yaml(home: &std::path::Path, contents: &str) {
        let path = gcore_config_path(home);
        fs::create_dir_all(path.parent().expect("gcore config parent")).unwrap();
        fs::write(path, contents).unwrap();
    }

    #[test]
    fn resolves_in_db_and_no_db_modes() {
        let home = tempfile::tempdir().unwrap();
        // NOTE(review): the YAML nesting is reconstructed; `max_concurrency` is assumed to sit
        // directly under `ai` — confirm against the key layout in `ai_keys`.
        write_gcore_yaml(
            home.path(),
            r#"
ai:
  embeddings:
    api_base: http://yaml-embedding
    model: yaml-embedding-model
    api_key: yaml-key
  audio_transcribe:
    routing: direct
  max_concurrency: 3
"#,
        );

        // Without a primary source every value comes from the YAML file.
        let mut no_db = LocalAiConfigSource::from_gobby_home(home.path()).unwrap();
        let no_db_context = AiContext::resolve(Some("yaml-project".to_string()), &mut no_db);

        let no_db_embed = no_db_context.binding(AiCapability::Embed);
        assert_eq!(
            no_db_embed.api_base.as_deref(),
            Some("http://yaml-embedding")
        );
        assert_eq!(no_db_embed.model.as_deref(), Some("yaml-embedding-model"));
        assert_eq!(no_db_embed.api_key.as_deref(), Some("yaml-key"));
        assert_eq!(
            route(&no_db_context, AiCapability::AudioTranscribe),
            AiRouting::Direct
        );
        assert_eq!(no_db_context.tuning.max_concurrency, 3);
        assert_eq!(no_db_context.limiter.max_concurrency(), 3);
        assert_eq!(no_db_context.project_id.as_deref(), Some("yaml-project"));

        // With a primary source its values win; keys it lacks fall back to the YAML file.
        let primary = TestSource::with_values([
            (ai_keys::EMBEDDINGS_API_BASE, "http://db-embedding"),
            (ai_keys::EMBEDDINGS_API_KEY, "$secret:db-embedding-key"),
            (ai_keys::AUDIO_TRANSCRIBE_ROUTING, "daemon"),
            (ai_keys::MAX_CONCURRENCY, "2"),
        ])
        .with_resolved([("$secret:db-embedding-key", "resolved-db-key")]);
        let mut db = AiConfigSource::with_primary_from_gobby_home(primary, home.path()).unwrap();
        let db_context = AiContext::resolve(Some("db-project".to_string()), &mut db);

        let db_embed = db_context.binding(AiCapability::Embed);
        assert_eq!(db_embed.api_base.as_deref(), Some("http://db-embedding"));
        assert_eq!(db_embed.model.as_deref(), Some("yaml-embedding-model"));
        assert_eq!(db_embed.api_key.as_deref(), Some("resolved-db-key"));
        assert_eq!(
            route(&db_context, AiCapability::AudioTranscribe),
            AiRouting::Daemon
        );
        assert_eq!(db_context.tuning.max_concurrency, 2);
    }

    #[test]
    fn project_id_is_caller_supplied() {
        let home = tempfile::tempdir().unwrap();
        write_gcore_yaml(home.path(), "ai:\n routing: direct\n");
        // Plant a project file in the working directory; resolution must not pick it up.
        let cwd = tempfile::tempdir().unwrap();
        fs::create_dir_all(cwd.path().join(".gobby")).unwrap();
        fs::write(
            cwd.path().join(".gobby/project.json"),
            r#"{"id":"stray-cwd-project"}"#,
        )
        .unwrap();
        let _cwd = CurrentDirGuard::set(cwd.path());

        let mut topic_source = LocalAiConfigSource::from_gobby_home(home.path()).unwrap();
        let topic_context = AiContext::resolve(None, &mut topic_source);
        assert_eq!(topic_context.project_id, None);

        let mut project_source = LocalAiConfigSource::from_gobby_home(home.path()).unwrap();
        let project_context =
            AiContext::resolve(Some("scope-project".to_string()), &mut project_source);
        assert_eq!(project_context.project_id.as_deref(), Some("scope-project"));
    }

    #[test]
    fn db_without_config_store_falls_through() {
        let home = tempfile::tempdir().unwrap();
        write_gcore_yaml(
            home.path(),
            r#"
ai:
  text_generate:
    routing: direct
    api_base: http://yaml-text
"#,
        );
        // An empty primary source stands in for a database without a config store.
        let primary = TestSource::with_values([]);
        let mut source =
            AiConfigSource::with_primary_from_gobby_home(primary, home.path()).unwrap();

        let context = AiContext::resolve(None, &mut source);

        assert_eq!(
            route(&context, AiCapability::TextGenerate),
            AiRouting::Direct
        );
        assert_eq!(
            context
                .binding(AiCapability::TextGenerate)
                .api_base
                .as_deref(),
            Some("http://yaml-text")
        );
    }

    #[test]
    fn standalone_values_expand_env_patterns_for_db_fallback() {
        let home = tempfile::tempdir().unwrap();
        write_gcore_yaml(
            home.path(),
            r#"
ai:
  text_generate:
    routing: direct
    api_base: ${GOBBY_CONTEXT_TEST_MISSING:-http://expanded-text}
"#,
        );
        let primary = TestSource::with_values([]);
        let mut source =
            AiConfigSource::with_primary_from_gobby_home(primary, home.path()).unwrap();

        let context = AiContext::resolve(None, &mut source);

        assert_eq!(
            context
                .binding(AiCapability::TextGenerate)
                .api_base
                .as_deref(),
            Some("http://expanded-text")
        );
    }

    #[test]
    fn primary_only_values_expand_env_patterns_without_standalone() {
        let primary = TestSource::with_values([(
            ai_keys::TEXT_GENERATE_API_BASE,
            "${GOBBY_AI_CONTEXT_PRIMARY_FALLBACK_TEST_MISSING:-http://fallback}",
        )]);
        let mut source = AiConfigSource::with_primary(primary, None);

        let context = AiContext::resolve(None, &mut source);

        assert_eq!(
            context
                .binding(AiCapability::TextGenerate)
                .api_base
                .as_deref(),
            Some("http://fallback")
        );
    }

    #[test]
    fn no_primary_source_expands_env_patterns() {
        let mut source = NoPrimaryAiConfigSource;

        assert_eq!(
            source
                .resolve_value("${GOBBY_AI_CONTEXT_NO_PRIMARY_TEST_MISSING:-http://fallback}")
                .unwrap(),
            "http://fallback"
        );
    }

    #[test]
    fn concurrency_cap_enforced() {
        let limiter = AiLimiter::new(1);
        let permit = limiter
            .try_acquire()
            .expect("first permit should be available");

        assert!(limiter.try_acquire().is_none());

        drop(permit);

        assert!(limiter.try_acquire().is_some());
    }

    #[test]
    fn forced_routing_and_no_ai_override() {
        // Baseline: configured routing is used; unconfigured capabilities report `Auto`.
        let source = TestSource::with_values([
            (ai_keys::AUDIO_TRANSCRIBE_ROUTING, "daemon"),
            (ai_keys::VISION_EXTRACT_ROUTING, "direct"),
        ]);
        let mut source = AiConfigSource::with_primary(source, None);
        let context = AiContext::resolve(None, &mut source);
        assert_eq!(
            route(&context, AiCapability::AudioTranscribe),
            AiRouting::Daemon
        );
        assert_eq!(
            route(&context, AiCapability::VisionExtract),
            AiRouting::Direct
        );
        assert_eq!(route(&context, AiCapability::Embed), AiRouting::Auto);

        // `forced_routing` overrides every capability, including one configured as `off`.
        let source = TestSource::with_values([
            (ai_keys::AUDIO_TRANSCRIBE_ROUTING, "daemon"),
            (ai_keys::VISION_EXTRACT_ROUTING, "off"),
        ]);
        let mut source = AiConfigSource::with_primary(source, None);
        let forced = AiContext::resolve_with_options(
            None,
            &mut source,
            AiContextOptions {
                forced_routing: Some(AiRouting::Direct),
                ..AiContextOptions::default()
            },
        );
        for capability in [
            AiCapability::Embed,
            AiCapability::AudioTranscribe,
            AiCapability::AudioTranslate,
            AiCapability::VisionExtract,
            AiCapability::TextGenerate,
        ] {
            assert_eq!(route(&forced, capability), AiRouting::Direct);
        }

        // `no_ai` wins even when `forced_routing` is also set.
        let source = TestSource::with_values([(ai_keys::AUDIO_TRANSCRIBE_ROUTING, "daemon")]);
        let mut source = AiConfigSource::with_primary(source, None);
        let disabled = AiContext::resolve_with_options(
            None,
            &mut source,
            AiContextOptions {
                no_ai: true,
                forced_routing: Some(AiRouting::Direct),
            },
        );
        for capability in [
            AiCapability::Embed,
            AiCapability::AudioTranscribe,
            AiCapability::AudioTranslate,
            AiCapability::VisionExtract,
            AiCapability::TextGenerate,
        ] {
            assert_eq!(route(&disabled, capability), AiRouting::Off);
        }
    }

    #[test]
    fn resolve_does_not_discover_local_backend_endpoints() {
        let source = TestSource::with_values([
            (ai_keys::EMBEDDINGS_ROUTING, "auto"),
            (ai_keys::VISION_EXTRACT_ROUTING, "direct"),
            (ai_keys::TEXT_GENERATE_ROUTING, "direct"),
        ]);
        let mut source = AiConfigSource::with_primary(source, None);

        let context = AiContext::resolve(None, &mut source);

        assert_eq!(route(&context, AiCapability::Embed), AiRouting::Auto);
        assert_eq!(
            route(&context, AiCapability::VisionExtract),
            AiRouting::Direct
        );
        assert_eq!(
            route(&context, AiCapability::TextGenerate),
            AiRouting::Direct
        );
        // No endpoint was configured, so none may appear on the resolved bindings.
        assert_eq!(context.binding(AiCapability::Embed).api_base, None);
        assert_eq!(context.binding(AiCapability::VisionExtract).api_base, None);
        assert_eq!(context.binding(AiCapability::TextGenerate).api_base, None);
    }
}