inference_runtime_gemini/
config.rs1use serde::{Deserialize, Serialize};
2use url::Url;
3
4use inference_core::deployment::{RateLimits, RetryPolicy, Timeouts};
5use inference_core::runtime::CircuitBreakerConfig;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct GeminiConfig {
9 #[serde(flatten)]
10 pub variant: GeminiVariant,
11 pub credential: SecretRef,
15 #[serde(default)]
16 pub safety: Vec<SafetySetting>,
17 #[serde(default)]
18 pub rate_limits: RateLimits,
19 #[serde(default)]
20 pub retry: RetryPolicy,
21 #[serde(default)]
22 pub circuit_breaker: CircuitBreakerConfig,
23 #[serde(default)]
24 pub timeouts: Timeouts,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "variant", rename_all = "snake_case")]
29pub enum GeminiVariant {
30 AiStudio {
31 #[serde(default = "default_aistudio_endpoint")]
32 endpoint: Url,
33 },
34 Vertex {
35 project: String,
36 region: String,
37 },
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct SafetySetting {
42 pub category: String,
43 pub threshold: String,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47#[serde(tag = "from", rename_all = "snake_case")]
48pub enum SecretRef {
49 Env {
50 name: String,
51 },
52 File {
53 path: std::path::PathBuf,
54 },
55 Inline {
56 value: String,
57 },
58 Adc,
61}
62
63fn default_aistudio_endpoint() -> Url {
64 Url::parse("https://generativelanguage.googleapis.com/v1beta/").expect("static url")
65}
66
67impl GeminiConfig {
68 pub fn generate_content_url(&self, model: &str, stream: bool) -> Result<Url, url::ParseError> {
69 let suffix = if stream {
70 ":streamGenerateContent?alt=sse"
71 } else {
72 ":generateContent"
73 };
74 match &self.variant {
75 GeminiVariant::AiStudio { endpoint } => {
76 endpoint.join(&format!("models/{model}{suffix}"))
77 }
78 GeminiVariant::Vertex { project, region } => Url::parse(&format!(
79 "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}{suffix}"
80 )),
81 }
82 }
83}