Skip to main content

agy_bridge/config/
models.rs

1//! Model configuration types.
2
3use serde::{Deserialize, Serialize};
4
5use super::{DEFAULT_IMAGE_GENERATION_MODEL, DEFAULT_MODEL};
6
7/// Controls the depth of extended thinking for models that support it.
8///
9/// Higher levels allow the model more internal reasoning steps at the cost
10/// of increased latency and token usage.
11#[non_exhaustive]
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13#[serde(rename_all = "lowercase")]
14#[derive(Default)]
15pub enum ThinkingLevel {
16    /// Least reasoning depth; fastest and cheapest.
17    Minimal,
18    /// Below-average reasoning depth.
19    Low,
20    /// Balanced reasoning depth (the default).
21    #[default]
22    Medium,
23    /// Maximum reasoning depth; highest latency and token usage.
24    High,
25}
26
27impl ThinkingLevel {
28    /// Returns the lowercase string representation used in serialization.
29    #[must_use]
30    pub const fn as_str(&self) -> &'static str {
31        match self {
32            Self::Minimal => "minimal",
33            Self::Low => "low",
34            Self::Medium => "medium",
35            Self::High => "high",
36        }
37    }
38}
39
40impl std::fmt::Display for ThinkingLevel {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.write_str(self.as_str())
43    }
44}
45
46/// Generation parameters for a model, mirroring the SDK's `GenerationConfig`.
47///
48/// Currently only `thinking_level` is forwarded to the Gemini backend via
49/// the Antigravity SDK. Additional generation parameters (temperature,
50/// `top_p`, etc.) will be added when the SDK exposes them.
51#[derive(Debug, Clone, Default, Serialize, Deserialize)]
52pub struct GenerationConfig {
53    /// Thinking level for models that support extended thinking.
54    /// When `None`, the model's default level is used.
55    #[serde(default)]
56    pub thinking_level: Option<ThinkingLevel>,
57}
58
59/// A single model slot with its name, optional API key, and generation config.
60#[derive(Clone, Serialize, Deserialize)]
61pub struct ModelEntry {
62    /// Model identifier (e.g. `"gemini-3.5-flash"`).
63    pub name: String,
64    /// Per-model API key override.
65    pub api_key: Option<String>,
66    /// Generation parameters for this model.
67    #[serde(default)]
68    pub generation: GenerationConfig,
69}
70
71impl Default for ModelEntry {
72    fn default() -> Self {
73        default_model_entry()
74    }
75}
76
77impl std::fmt::Debug for ModelEntry {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("ModelEntry")
80            .field("name", &self.name)
81            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
82            .field("generation", &self.generation)
83            .finish()
84    }
85}
86
87/// Model selection for each capability, mirroring the SDK's `ModelConfig`.
88///
89/// Each slot holds a full [`ModelEntry`] (with optional per-model API key
90/// and generation config). Bare model name strings are accepted via
91/// `#[serde(deserialize_with)]` coercion on the Python side.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ModelConfig {
94    /// The primary reasoning model.
95    #[serde(default = "default_model_entry")]
96    pub default: ModelEntry,
97    /// The model used for image generation.
98    #[serde(default = "default_image_model_entry")]
99    pub image_generation: ModelEntry,
100}
101
102pub(crate) fn default_model_entry() -> ModelEntry {
103    ModelEntry {
104        name: DEFAULT_MODEL.to_owned(),
105        api_key: None,
106        generation: GenerationConfig::default(),
107    }
108}
109
110pub(crate) fn default_image_model_entry() -> ModelEntry {
111    ModelEntry {
112        name: DEFAULT_IMAGE_GENERATION_MODEL.to_owned(),
113        api_key: None,
114        generation: GenerationConfig::default(),
115    }
116}
117
118impl Default for ModelConfig {
119    fn default() -> Self {
120        Self {
121            default: default_model_entry(),
122            image_generation: default_image_model_entry(),
123        }
124    }
125}
126
127/// Configuration for the Gemini model backend.
128#[derive(Clone, Default, Serialize, Deserialize)]
129pub struct GeminiConfig {
130    /// Shared API key for all models. Falls back to `$GEMINI_API_KEY` env var.
131    /// Individual `ModelEntry` instances can override this.
132    pub api_key: Option<String>,
133    /// Base URL for the Gemini API endpoint.
134    /// When set, overrides the default Gemini API endpoint (e.g., for a local
135    /// proxy, staging environment, or alternative API-compatible gateway).
136    #[serde(default, skip_serializing_if = "Option::is_none")]
137    pub base_url: Option<String>,
138    /// Per-modality model selection and configuration.
139    #[serde(default)]
140    pub models: ModelConfig,
141}
142
143impl std::fmt::Debug for GeminiConfig {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        f.debug_struct("GeminiConfig")
146            .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]"))
147            .field("base_url", &self.base_url)
148            .field("models", &self.models)
149            .finish()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_level_serde() {
        let level = ThinkingLevel::Minimal;
        let json = serde_json::to_string(&level).unwrap();
        assert_eq!(json, "\"minimal\"");
        let parsed: ThinkingLevel = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, ThinkingLevel::Minimal);

        let level = ThinkingLevel::High;
        let json = serde_json::to_string(&level).unwrap();
        assert_eq!(json, "\"high\"");

        assert_eq!(ThinkingLevel::Medium.as_str(), "medium");
    }

    #[test]
    fn model_entry_serde_roundtrip() {
        let entry = ModelEntry {
            name: "gemini-3.5-flash".to_string(),
            api_key: Some("mock_test_api_key_123".to_string()),
            generation: GenerationConfig {
                thinking_level: Some(ThinkingLevel::High),
            },
        };
        let json = serde_json::to_string(&entry).unwrap();
        let parsed: ModelEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "gemini-3.5-flash");
        assert_eq!(parsed.api_key.as_deref(), Some("mock_test_api_key_123"));
        assert_eq!(parsed.generation.thinking_level, Some(ThinkingLevel::High));
    }

    #[test]
    fn model_entry_minimal_serde() {
        let json = r#"{"name":"flash"}"#;
        let parsed: ModelEntry = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.name, "flash");
        assert!(parsed.api_key.is_none());
        assert!(parsed.generation.thinking_level.is_none());
    }

    #[test]
    fn model_config_serde_roundtrip() {
        let config = ModelConfig {
            default: ModelEntry {
                name: "gemini-3.5-flash".to_string(),
                api_key: None,
                generation: GenerationConfig::default(),
            },
            image_generation: ModelEntry {
                name: "imagen-3".to_string(),
                api_key: None,
                generation: GenerationConfig::default(),
            },
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: ModelConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.default.name, "gemini-3.5-flash");
        assert_eq!(parsed.image_generation.name, "imagen-3");
    }

    #[test]
    fn model_config_defaults() {
        let config = ModelConfig::default();
        assert_eq!(config.default.name, DEFAULT_MODEL);
        assert_eq!(config.image_generation.name, DEFAULT_IMAGE_GENERATION_MODEL);
    }

    #[test]
    fn gemini_config_serde_roundtrip() {
        let config = GeminiConfig {
            api_key: Some("global-key".to_string()),
            base_url: None,
            models: ModelConfig {
                default: ModelEntry {
                    name: "gemini-3.5-flash".to_string(),
                    api_key: None,
                    generation: GenerationConfig::default(),
                },
                image_generation: default_image_model_entry(),
            },
        };
        let json = serde_json::to_string(&config).unwrap();
        let parsed: GeminiConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.api_key.as_deref(), Some("global-key"));
        assert!(parsed.base_url.is_none());
        assert_eq!(parsed.models.default.name, "gemini-3.5-flash");
    }

    #[test]
    fn gemini_config_default() {
        let config = GeminiConfig::default();
        assert!(config.api_key.is_none());
        assert_eq!(config.models.default.name, DEFAULT_MODEL);
        assert_eq!(
            config.models.image_generation.name,
            DEFAULT_IMAGE_GENERATION_MODEL
        );
    }

    #[test]
    fn thinking_level_all_variants_python_str() {
        assert_eq!(ThinkingLevel::Minimal.as_str(), "minimal");
        assert_eq!(ThinkingLevel::Low.as_str(), "low");
        assert_eq!(ThinkingLevel::Medium.as_str(), "medium");
        assert_eq!(ThinkingLevel::High.as_str(), "high");
    }

    #[test]
    fn thinking_level_all_variants_serde() {
        for (variant, expected) in [
            (ThinkingLevel::Minimal, "\"minimal\""),
            (ThinkingLevel::Low, "\"low\""),
            (ThinkingLevel::Medium, "\"medium\""),
            (ThinkingLevel::High, "\"high\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected);
            let parsed: ThinkingLevel = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn thinking_level_default_is_medium() {
        assert_eq!(ThinkingLevel::default(), ThinkingLevel::Medium);
    }

    #[test]
    fn thinking_level_display_matches_as_str() {
        for level in [
            ThinkingLevel::Minimal,
            ThinkingLevel::Low,
            ThinkingLevel::Medium,
            ThinkingLevel::High,
        ] {
            assert_eq!(level.to_string(), level.as_str());
        }
    }

    #[test]
    fn thinking_level_rejects_unknown_or_miscased_names() {
        // `rename_all = "lowercase"` means only the exact lowercase names parse.
        assert!(serde_json::from_str::<ThinkingLevel>("\"extreme\"").is_err());
        assert!(serde_json::from_str::<ThinkingLevel>("\"HIGH\"").is_err());
    }

    #[test]
    fn model_entry_debug_redacts_api_key() {
        let with_key = ModelEntry {
            name: "m".to_string(),
            api_key: Some("secret-entry-key".to_string()),
            generation: GenerationConfig::default(),
        };
        let rendered = format!("{with_key:?}");
        assert!(!rendered.contains("secret-entry-key"));
        assert!(rendered.contains("[REDACTED]"));

        let without_key = ModelEntry {
            api_key: None,
            ..with_key
        };
        let rendered = format!("{without_key:?}");
        assert!(rendered.contains("api_key: None"));
        assert!(!rendered.contains("[REDACTED]"));
    }

    #[test]
    fn gemini_config_debug_redacts_all_api_keys() {
        let config = GeminiConfig {
            api_key: Some("secret-global-key".to_string()),
            base_url: None,
            models: ModelConfig {
                default: ModelEntry {
                    name: "m".to_string(),
                    api_key: Some("secret-entry-key".to_string()),
                    generation: GenerationConfig::default(),
                },
                image_generation: default_image_model_entry(),
            },
        };
        let rendered = format!("{config:?}");
        assert!(!rendered.contains("secret-global-key"));
        assert!(!rendered.contains("secret-entry-key"));
        assert!(rendered.contains("[REDACTED]"));
    }

    #[test]
    fn gemini_config_empty_json_uses_defaults() {
        let parsed: GeminiConfig = serde_json::from_str("{}").unwrap();
        assert!(parsed.api_key.is_none());
        assert!(parsed.base_url.is_none());
        assert_eq!(parsed.models.default.name, DEFAULT_MODEL);
        assert_eq!(
            parsed.models.image_generation.name,
            DEFAULT_IMAGE_GENERATION_MODEL
        );
    }

    #[test]
    fn model_config_missing_slot_uses_default() {
        let parsed: ModelConfig =
            serde_json::from_str(r#"{"default":{"name":"flash"}}"#).unwrap();
        assert_eq!(parsed.default.name, "flash");
        assert_eq!(parsed.image_generation.name, DEFAULT_IMAGE_GENERATION_MODEL);
    }

    #[test]
    fn model_config_rejects_bare_model_name_strings() {
        // Slots must be objects; bare-string coercion is not done by this type.
        assert!(serde_json::from_str::<ModelConfig>(r#"{"default":"flash"}"#).is_err());
    }

    #[test]
    fn gemini_config_base_url_omitted_when_none() {
        let json = serde_json::to_value(GeminiConfig::default()).unwrap();
        assert!(json.get("base_url").is_none());

        let config = GeminiConfig {
            base_url: Some("http://localhost:8080".to_string()),
            ..GeminiConfig::default()
        };
        let json = serde_json::to_value(&config).unwrap();
        assert_eq!(json["base_url"], "http://localhost:8080");
    }
}