Skip to main content

agy_bridge/config/
models.rs

1//! Model configuration types.
2
3use serde::{Deserialize, Serialize};
4
5use super::{DEFAULT_IMAGE_GENERATION_MODEL, DEFAULT_MODEL};
6
7/// Controls the depth of extended thinking for models that support it.
8///
9/// Higher levels allow the model more internal reasoning steps at the cost
10/// of increased latency and token usage.
11#[non_exhaustive]
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14#[derive(Default)]
15pub enum ThinkingLevel {
16    Minimal,
17    Low,
18    #[default]
19    Medium,
20    High,
21}
22
23impl ThinkingLevel {
24    /// Returns the lowercase string representation used in serialization.
25    #[must_use]
26    pub const fn as_str(&self) -> &'static str {
27        match self {
28            Self::Minimal => "minimal",
29            Self::Low => "low",
30            Self::Medium => "medium",
31            Self::High => "high",
32        }
33    }
34}
35
36impl std::fmt::Display for ThinkingLevel {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.write_str(self.as_str())
39    }
40}
41
42/// Generation parameters for a model, mirroring the SDK's `GenerationConfig`.
43///
44/// Currently only `thinking_level` is forwarded to the Gemini backend via
45/// the Antigravity SDK. Additional generation parameters (temperature,
46/// `top_p`, etc.) will be added when the SDK exposes them.
47#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48pub struct GenerationConfig {
49    /// Thinking level for models that support extended thinking.
50    /// When `None`, the model's default level is used.
51    #[serde(default)]
52    pub thinking_level: Option<ThinkingLevel>,
53}
54
55/// A single model slot with its name, optional API key, and generation config.
56#[derive(Clone, Serialize, Deserialize)]
57pub struct ModelEntry {
58    /// Model identifier (e.g. `"gemini-3.5-flash"`).
59    pub name: String,
60    /// Per-model API key override.
61    pub api_key: Option<String>,
62    /// Generation parameters for this model.
63    #[serde(default)]
64    pub generation: GenerationConfig,
65}
66
67impl Default for ModelEntry {
68    fn default() -> Self {
69        default_model_entry()
70    }
71}
72
73impl std::fmt::Debug for ModelEntry {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("ModelEntry")
76            .field("name", &self.name)
77            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
78            .field("generation", &self.generation)
79            .finish()
80    }
81}
82
83/// Model selection for each capability, mirroring the SDK's `ModelConfig`.
84///
85/// Each slot holds a full [`ModelEntry`] (with optional per-model API key
86/// and generation config). Bare model name strings are accepted via
87/// `#[serde(deserialize_with)]` coercion on the Python side.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ModelConfig {
90    /// The primary reasoning model.
91    #[serde(default = "default_model_entry")]
92    pub default: ModelEntry,
93    /// The model used for image generation.
94    #[serde(default = "default_image_model_entry")]
95    pub image_generation: ModelEntry,
96}
97
98pub(crate) fn default_model_entry() -> ModelEntry {
99    ModelEntry {
100        name: DEFAULT_MODEL.to_owned(),
101        api_key: None,
102        generation: GenerationConfig::default(),
103    }
104}
105
106pub(crate) fn default_image_model_entry() -> ModelEntry {
107    ModelEntry {
108        name: DEFAULT_IMAGE_GENERATION_MODEL.to_owned(),
109        api_key: None,
110        generation: GenerationConfig::default(),
111    }
112}
113
114impl Default for ModelConfig {
115    fn default() -> Self {
116        Self {
117            default: default_model_entry(),
118            image_generation: default_image_model_entry(),
119        }
120    }
121}
122
123/// Configuration for the Gemini model backend.
124#[derive(Clone, Default, Serialize, Deserialize)]
125pub struct GeminiConfig {
126    /// Shared API key for all models. Falls back to `$GEMINI_API_KEY` env var.
127    /// Individual `ModelEntry` instances can override this.
128    pub api_key: Option<String>,
129    /// Base URL for the Gemini API endpoint.
130    /// When set, overrides the default Gemini API endpoint (e.g., for a local
131    /// proxy, staging environment, or alternative API-compatible gateway).
132    #[serde(default, skip_serializing_if = "Option::is_none")]
133    pub base_url: Option<String>,
134    /// Per-modality model selection and configuration.
135    #[serde(default)]
136    pub models: ModelConfig,
137}
138
139impl std::fmt::Debug for GeminiConfig {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("GeminiConfig")
142            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
143            .field("base_url", &self.base_url)
144            .field("models", &self.models)
145            .finish()
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_level_serde() {
        let level = ThinkingLevel::Minimal;
        let json = serde_json::to_string(&level).unwrap();
        assert_eq!(json, "\"minimal\"");
        let parsed: ThinkingLevel = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ThinkingLevel::Minimal);

        let level = ThinkingLevel::High;
        let json = serde_json::to_string(&level).unwrap();
        assert_eq!(json, "\"high\"");

        assert_eq!(ThinkingLevel::Medium.as_str(), "medium");
    }

    #[test]
    fn thinking_level_default_is_medium() {
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Medium);
    }

    #[test]
    fn thinking_level_display_matches_as_str() {
        for level in [
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
        ] {
            assert_eq!(level.to_string(), level.as_str());
        }
    }

    #[test]
    fn thinking_level_rejects_unknown_or_miscased() {
        // `rename_all = "lowercase"` means only the exact lowercase names parse.
        assert!(serde_json::from_str::<ThinkingLevel>("\"extreme\"").is_err());
        assert!(serde_json::from_str::<ThinkingLevel>("\"HIGH\"").is_err());
    }

    #[test]
    fn model_entry_serde_roundtrip() {
        let entry = ModelEntry {
            name: "gemini-3.5-flash".to_string(),
            api_key: Some("mock_test_api_key_123".to_string()),
            generation: GenerationConfig {
                thinking_level: Some(ThinkingLevel::High),
            },
        };
        let json = serde_json::to_string(&entry).unwrap();
        let parsed: ModelEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "gemini-3.5-flash");
        assert_eq!(parsed.api_key.as_deref(), Some("mock_test_api_key_123"));
        assert_eq!(parsed.generation.thinking_level, Some(ThinkingLevel::High));
    }

    #[test]
    fn model_entry_minimal_serde() {
        let json = r#"{"name":"flash"}"#;
        let parsed: ModelEntry = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.name, "flash");
        assert!(parsed.api_key.is_none());
        assert!(parsed.generation.thinking_level.is_none());
    }

    #[test]
    fn model_entry_debug_redacts_api_key() {
        let entry = ModelEntry {
            name: "gemini-3.5-flash".to_string(),
            api_key: Some("super-secret-key".to_string()),
            generation: GenerationConfig::default(),
        };
        let rendered = format!("{entry:?}");
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("super-secret-key"));
        assert!(rendered.contains("gemini-3.5-flash"));
    }

    #[test]
    fn model_config_serde_roundtrip() {
        let config = ModelConfig {
            default: ModelEntry {
                name: "gemini-3.5-flash".to_string(),
                api_key: None,
                generation: GenerationConfig::default(),
            },
            image_generation: ModelEntry {
                name: "imagen-3".to_string(),
                api_key: None,
                generation: GenerationConfig::default(),
            },
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.default.name, "gemini-3.5-flash");
        assert_eq!(parsed.image_generation.name, "imagen-3");
    }

    #[test]
    fn model_config_defaults() {
        let config = ModelConfig::default();
        assert_eq!(config.default.name, DEFAULT_MODEL);
        assert_eq!(config.image_generation.name, DEFAULT_IMAGE_GENERATION_MODEL);
    }

    #[test]
    fn model_config_empty_json_uses_defaults() {
        let parsed: ModelConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.default.name, DEFAULT_MODEL);
        assert_eq!(parsed.image_generation.name, DEFAULT_IMAGE_GENERATION_MODEL);
    }

    #[test]
    fn model_config_rejects_bare_model_name() {
        // Shorthand strings must be coerced into the object form upstream;
        // the Rust deserializer only accepts full `ModelEntry` objects.
        assert!(serde_json::from_str::<ModelConfig>(r#"{"default":"flash"}"#).is_err());
    }

    #[test]
    fn gemini_config_serde_roundtrip() {
        let config = GeminiConfig {
            api_key: Some("global-key".to_string()),
            base_url: None,
            models: ModelConfig {
                default: ModelEntry {
                    name: "gemini-3.5-flash".to_string(),
                    api_key: None,
                    generation: GenerationConfig::default(),
                },
                image_generation: default_image_model_entry(),
            },
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: GeminiConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.api_key.as_deref(), Some("global-key"));
        assert!(parsed.base_url.is_none());
        assert_eq!(parsed.models.default.name, "gemini-3.5-flash");
    }

    #[test]
    fn gemini_config_default() {
        let config = GeminiConfig::default();
        assert!(config.api_key.is_none());
        assert_eq!(config.models.default.name, DEFAULT_MODEL);
        assert_eq!(
            config.models.image_generation.name,
            DEFAULT_IMAGE_GENERATION_MODEL
        );
    }

    #[test]
    fn gemini_config_empty_json_uses_defaults() {
        let parsed: GeminiConfig = serde_json::from_str("{}").unwrap();
        assert!(parsed.api_key.is_none());
        assert!(parsed.base_url.is_none());
        assert_eq!(parsed.models.default.name, DEFAULT_MODEL);
        assert_eq!(
            parsed.models.image_generation.name,
            DEFAULT_IMAGE_GENERATION_MODEL
        );
    }

    #[test]
    fn gemini_config_base_url_serialization() {
        // `skip_serializing_if` keeps an unset base URL out of the output.
        let without = serde_json::to_string(&GeminiConfig::default()).unwrap();
        assert!(!without.contains("base_url"));

        let config = GeminiConfig {
            base_url: Some("http://localhost:8080".to_string()),
            ..GeminiConfig::default()
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: GeminiConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.base_url.as_deref(), Some("http://localhost:8080"));
    }

    #[test]
    fn gemini_config_debug_redacts_all_api_keys() {
        let config = GeminiConfig {
            api_key: Some("global-secret".to_string()),
            base_url: Some("http://localhost:8080".to_string()),
            models: ModelConfig {
                default: ModelEntry {
                    name: "gemini-3.5-flash".to_string(),
                    api_key: Some("entry-secret".to_string()),
                    generation: GenerationConfig::default(),
                },
                image_generation: default_image_model_entry(),
            },
        };
        let rendered = format!("{config:?}");
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("global-secret"));
        assert!(!rendered.contains("entry-secret"));
        assert!(rendered.contains("http://localhost:8080"));
    }

    #[test]
    fn thinking_level_all_variants_python_str() {
        assert_eq!(ThinkingLevel::Minimal.as_str(), "minimal");
        assert_eq!(ThinkingLevel::Low.as_str(), "low");
        assert_eq!(ThinkingLevel::Medium.as_str(), "medium");
        assert_eq!(ThinkingLevel::High.as_str(), "high");
    }

    #[test]
    fn thinking_level_all_variants_serde() {
        for (variant, expected) in [
            (ThinkingLevel::Minimal, "\"minimal\""),
            (ThinkingLevel::Low, "\"low\""),
            (ThinkingLevel::Medium, "\"medium\""),
            (ThinkingLevel::High, "\"high\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected);
            let parsed: ThinkingLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }
}