Skip to main content

inference_runtime_gemini/
config.rs

1use serde::{Deserialize, Serialize};
2use url::Url;
3
4use inference_core::deployment::{RateLimits, RetryPolicy, Timeouts};
5use inference_core::runtime::CircuitBreakerConfig;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct GeminiConfig {
9    #[serde(flatten)]
10    pub variant: GeminiVariant,
11    /// Auth credential. AI Studio uses an API key (`StaticApiKey`);
12    /// Vertex uses an OAuth2 access token via the operator-supplied
13    /// credential provider.
14    pub credential: SecretRef,
15    #[serde(default)]
16    pub safety: Vec<SafetySetting>,
17    #[serde(default)]
18    pub rate_limits: RateLimits,
19    #[serde(default)]
20    pub retry: RetryPolicy,
21    #[serde(default)]
22    pub circuit_breaker: CircuitBreakerConfig,
23    #[serde(default)]
24    pub timeouts: Timeouts,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "variant", rename_all = "snake_case")]
29pub enum GeminiVariant {
30    AiStudio {
31        #[serde(default = "default_aistudio_endpoint")]
32        endpoint: Url,
33    },
34    Vertex {
35        project: String,
36        region: String,
37    },
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct SafetySetting {
42    pub category: String,
43    pub threshold: String,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47#[serde(tag = "from", rename_all = "snake_case")]
48pub enum SecretRef {
49    Env {
50        name: String,
51    },
52    File {
53        path: std::path::PathBuf,
54    },
55    Inline {
56        value: String,
57    },
58    /// Vertex uses application default credentials; resolved by the
59    /// operator-supplied `CredentialProvider`.
60    Adc,
61}
62
63fn default_aistudio_endpoint() -> Url {
64    Url::parse("https://generativelanguage.googleapis.com/v1beta/").expect("static url")
65}
66
67impl GeminiConfig {
68    pub fn generate_content_url(&self, model: &str, stream: bool) -> Result<Url, url::ParseError> {
69        let suffix = if stream {
70            ":streamGenerateContent?alt=sse"
71        } else {
72            ":generateContent"
73        };
74        match &self.variant {
75            GeminiVariant::AiStudio { endpoint } => {
76                endpoint.join(&format!("models/{model}{suffix}"))
77            }
78            GeminiVariant::Vertex { project, region } => Url::parse(&format!(
79                "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}{suffix}"
80            )),
81        }
82    }
83}