Skip to main content

adk_managed/types/
model_ref.rs

1//! Model reference types for provider-neutral model declaration.
2//!
3//! These types enable a `ManagedAgentDef` to specify which model to use
4//! without being tied to a specific provider's API.
5
6use serde::{Deserialize, Serialize};
7
8/// Provider-neutral model reference.
9///
10/// Supports two forms:
11/// - **Shorthand**: a plain string like `"gemini-2.5-flash"` or `"gpt-4.1"`
12/// - **Structured**: an explicit provider + model config + optional speed hint
13///
14/// # Examples
15///
16/// ```rust
17/// use adk_managed::types::ModelRef;
18///
19/// // Shorthand form
20/// let json = serde_json::json!("gemini-2.5-flash");
21/// let model_ref: ModelRef = serde_json::from_value(json).unwrap();
22///
23/// // Structured form
24/// let json = serde_json::json!({
25///     "provider": "openai",
26///     "model": "gpt-4.1"
27/// });
28/// let model_ref: ModelRef = serde_json::from_value(json).unwrap();
29/// ```
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(untagged)]
32pub enum ModelRef {
33    /// A plain model name string (e.g. `"gemini-2.5-flash"`).
34    Shorthand(String),
35    /// A structured model reference with explicit provider.
36    Structured {
37        /// The LLM provider.
38        provider: Provider,
39        /// The model identifier or compatible configuration.
40        model: ModelConfig,
41        /// Optional speed hint (e.g. `"fast"`, `"balanced"`).
42        #[serde(skip_serializing_if = "Option::is_none")]
43        speed: Option<String>,
44    },
45}
46
47/// Supported LLM providers.
48///
49/// Serializes to/from lowercase snake_case strings.
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
51#[serde(rename_all = "snake_case")]
52pub enum Provider {
53    /// Google Gemini.
54    Gemini,
55    /// OpenAI (GPT family).
56    Openai,
57    /// Anthropic (Claude family).
58    Anthropic,
59    /// Ollama (local models).
60    Ollama,
61    /// OpenAI-compatible endpoint with custom base URL.
62    OpenaiCompatible,
63}
64
65/// Model configuration — either a simple name or a full compatible endpoint config.
66///
67/// # Examples
68///
69/// ```rust
70/// use adk_managed::types::ModelConfig;
71///
72/// // Simple name
73/// let json = serde_json::json!("gpt-4.1");
74/// let config: ModelConfig = serde_json::from_value(json).unwrap();
75///
76/// // Compatible endpoint
77/// let json = serde_json::json!({
78///     "model": "deepseek-chat",
79///     "base_url": "https://api.deepseek.com/v1",
80///     "api_key": "sk-xxx"
81/// });
82/// let config: ModelConfig = serde_json::from_value(json).unwrap();
83/// ```
84#[derive(Debug, Clone, Serialize, Deserialize)]
85#[serde(untagged)]
86pub enum ModelConfig {
87    /// A simple model name string.
88    Name(String),
89    /// A full compatible endpoint configuration with model, base URL, and API key.
90    Compatible {
91        /// The model identifier.
92        model: String,
93        /// The base URL for the compatible API endpoint.
94        base_url: String,
95        /// The resolved API key (plaintext — platform resolves refs before passing to runtime).
96        api_key: String,
97    },
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{from_value, json, to_value};

    /// A bare JSON string deserializes into the shorthand variant.
    #[test]
    fn test_shorthand_parses() {
        let parsed: ModelRef = from_value(json!("gemini-2.5-flash")).unwrap();

        let ModelRef::Shorthand(name) = parsed else {
            panic!("expected Shorthand variant");
        };
        assert_eq!(name, "gemini-2.5-flash");
    }

    /// The shorthand variant serializes to a plain string and reads back unchanged.
    #[test]
    fn test_shorthand_round_trip() {
        let original = ModelRef::Shorthand(String::from("gpt-4.1"));
        let encoded = to_value(&original).unwrap();
        assert_eq!(encoded, json!("gpt-4.1"));

        let decoded: ModelRef = from_value(encoded).unwrap();
        let ModelRef::Shorthand(name) = decoded else {
            panic!("expected Shorthand variant");
        };
        assert_eq!(name, "gpt-4.1");
    }

    /// An object with `provider` and a string `model` parses as structured, with no speed.
    #[test]
    fn test_structured_parses() {
        let parsed: ModelRef = from_value(json!({
            "provider": "openai",
            "model": "gpt-4.1"
        }))
        .unwrap();

        let ModelRef::Structured { provider, model, speed } = parsed else {
            panic!("expected Structured variant");
        };
        assert_eq!(provider, Provider::Openai);
        let ModelConfig::Name(name) = model else {
            panic!("expected Name variant");
        };
        assert_eq!(name, "gpt-4.1");
        assert_eq!(speed, None);
    }

    /// The optional `speed` hint is picked up when it is present.
    #[test]
    fn test_structured_with_speed() {
        let parsed: ModelRef = from_value(json!({
            "provider": "gemini",
            "model": "gemini-2.5-flash",
            "speed": "fast"
        }))
        .unwrap();

        let ModelRef::Structured { provider, model, speed } = parsed else {
            panic!("expected Structured variant");
        };
        assert_eq!(provider, Provider::Gemini);
        let ModelConfig::Name(name) = model else {
            panic!("expected Name variant");
        };
        assert_eq!(name, "gemini-2.5-flash");
        assert_eq!(speed, Some("fast".to_string()));
    }

    /// A nested object under `model` parses as a compatible-endpoint configuration.
    #[test]
    fn test_openai_compatible_with_base_url() {
        let parsed: ModelRef = from_value(json!({
            "provider": "openai_compatible",
            "model": {
                "model": "deepseek-chat",
                "base_url": "https://api.deepseek.com/v1",
                "api_key": "sk-test-key-123"
            }
        }))
        .unwrap();

        let ModelRef::Structured { provider, model, speed } = parsed else {
            panic!("expected Structured variant");
        };
        assert_eq!(provider, Provider::OpenaiCompatible);
        let ModelConfig::Compatible { model, base_url, api_key } = model else {
            panic!("expected Compatible variant");
        };
        assert_eq!(model, "deepseek-chat");
        assert_eq!(base_url, "https://api.deepseek.com/v1");
        assert_eq!(api_key, "sk-test-key-123");
        assert_eq!(speed, None);
    }

    /// Every provider serializes to its snake_case wire name.
    #[test]
    fn test_provider_serialization() {
        let expected = [
            (Provider::Gemini, "gemini"),
            (Provider::Openai, "openai"),
            (Provider::Anthropic, "anthropic"),
            (Provider::Ollama, "ollama"),
            (Provider::OpenaiCompatible, "openai_compatible"),
        ];

        for (provider, wire_name) in expected {
            assert_eq!(to_value(provider).unwrap(), json!(wire_name));
        }
    }

    /// A bare JSON string deserializes into `ModelConfig::Name`.
    #[test]
    fn test_model_config_name() {
        let parsed: ModelConfig = from_value(json!("claude-3.5-sonnet")).unwrap();

        let ModelConfig::Name(name) = parsed else {
            panic!("expected Name variant");
        };
        assert_eq!(name, "claude-3.5-sonnet");
    }

    /// An object with model, base URL and API key deserializes into `ModelConfig::Compatible`.
    #[test]
    fn test_model_config_compatible() {
        let parsed: ModelConfig = from_value(json!({
            "model": "local-llama",
            "base_url": "http://localhost:11434/v1",
            "api_key": "ollama"
        }))
        .unwrap();

        let ModelConfig::Compatible { model, base_url, api_key } = parsed else {
            panic!("expected Compatible variant");
        };
        assert_eq!(model, "local-llama");
        assert_eq!(base_url, "http://localhost:11434/v1");
        assert_eq!(api_key, "ollama");
    }

    /// A `None` speed hint leaves no `speed` key in the serialized object.
    #[test]
    fn test_structured_speed_omitted_in_serialization() {
        let reference = ModelRef::Structured {
            provider: Provider::Anthropic,
            model: ModelConfig::Name("claude-3.5-sonnet".to_string()),
            speed: None,
        };

        let encoded = to_value(&reference).unwrap();
        let fields = encoded.as_object().unwrap();
        assert!(!fields.contains_key("speed"));
    }
}