agent_chain_core/language_models/
model_profile.rs1use std::collections::HashMap;
8
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct ModelProfile {
22 #[serde(skip_serializing_if = "Option::is_none")]
25 pub max_input_tokens: Option<u32>,
26
27 #[serde(skip_serializing_if = "Option::is_none")]
29 pub image_inputs: Option<bool>,
30
31 #[serde(skip_serializing_if = "Option::is_none")]
33 pub image_url_inputs: Option<bool>,
34
35 #[serde(skip_serializing_if = "Option::is_none")]
37 pub pdf_inputs: Option<bool>,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub audio_inputs: Option<bool>,
42
43 #[serde(skip_serializing_if = "Option::is_none")]
45 pub video_inputs: Option<bool>,
46
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub image_tool_message: Option<bool>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub pdf_tool_message: Option<bool>,
54
55 #[serde(skip_serializing_if = "Option::is_none")]
58 pub max_output_tokens: Option<u32>,
59
60 #[serde(skip_serializing_if = "Option::is_none")]
62 pub reasoning_output: Option<bool>,
63
64 #[serde(skip_serializing_if = "Option::is_none")]
66 pub image_outputs: Option<bool>,
67
68 #[serde(skip_serializing_if = "Option::is_none")]
70 pub audio_outputs: Option<bool>,
71
72 #[serde(skip_serializing_if = "Option::is_none")]
74 pub video_outputs: Option<bool>,
75
76 #[serde(skip_serializing_if = "Option::is_none")]
79 pub tool_calling: Option<bool>,
80
81 #[serde(skip_serializing_if = "Option::is_none")]
83 pub tool_choice: Option<bool>,
84
85 #[serde(skip_serializing_if = "Option::is_none")]
88 pub structured_output: Option<bool>,
89}
90
91impl ModelProfile {
92 pub fn new() -> Self {
94 Self::default()
95 }
96
97 pub fn with_max_input_tokens(mut self, tokens: u32) -> Self {
99 self.max_input_tokens = Some(tokens);
100 self
101 }
102
103 pub fn with_max_output_tokens(mut self, tokens: u32) -> Self {
105 self.max_output_tokens = Some(tokens);
106 self
107 }
108
109 pub fn with_image_inputs(mut self, supported: bool) -> Self {
111 self.image_inputs = Some(supported);
112 self
113 }
114
115 pub fn with_image_url_inputs(mut self, supported: bool) -> Self {
117 self.image_url_inputs = Some(supported);
118 self
119 }
120
121 pub fn with_pdf_inputs(mut self, supported: bool) -> Self {
123 self.pdf_inputs = Some(supported);
124 self
125 }
126
127 pub fn with_audio_inputs(mut self, supported: bool) -> Self {
129 self.audio_inputs = Some(supported);
130 self
131 }
132
133 pub fn with_video_inputs(mut self, supported: bool) -> Self {
135 self.video_inputs = Some(supported);
136 self
137 }
138
139 pub fn with_image_tool_message(mut self, supported: bool) -> Self {
141 self.image_tool_message = Some(supported);
142 self
143 }
144
145 pub fn with_pdf_tool_message(mut self, supported: bool) -> Self {
147 self.pdf_tool_message = Some(supported);
148 self
149 }
150
151 pub fn with_reasoning_output(mut self, supported: bool) -> Self {
153 self.reasoning_output = Some(supported);
154 self
155 }
156
157 pub fn with_image_outputs(mut self, supported: bool) -> Self {
159 self.image_outputs = Some(supported);
160 self
161 }
162
163 pub fn with_audio_outputs(mut self, supported: bool) -> Self {
165 self.audio_outputs = Some(supported);
166 self
167 }
168
169 pub fn with_video_outputs(mut self, supported: bool) -> Self {
171 self.video_outputs = Some(supported);
172 self
173 }
174
175 pub fn with_tool_calling(mut self, supported: bool) -> Self {
177 self.tool_calling = Some(supported);
178 self
179 }
180
181 pub fn with_tool_choice(mut self, supported: bool) -> Self {
183 self.tool_choice = Some(supported);
184 self
185 }
186
187 pub fn with_structured_output(mut self, supported: bool) -> Self {
189 self.structured_output = Some(supported);
190 self
191 }
192
193 pub fn supports_multimodal_inputs(&self) -> bool {
195 self.image_inputs.unwrap_or(false)
196 || self.audio_inputs.unwrap_or(false)
197 || self.video_inputs.unwrap_or(false)
198 || self.pdf_inputs.unwrap_or(false)
199 }
200
201 pub fn supports_multimodal_outputs(&self) -> bool {
203 self.image_outputs.unwrap_or(false)
204 || self.audio_outputs.unwrap_or(false)
205 || self.video_outputs.unwrap_or(false)
206 }
207
208 pub fn supports_tool_calling(&self) -> bool {
210 self.tool_calling.unwrap_or(false)
211 }
212
213 pub fn supports_structured_output(&self) -> bool {
215 self.structured_output.unwrap_or(false)
216 }
217}
218
219pub type ModelProfileRegistry = HashMap<String, ModelProfile>;
221
222#[allow(dead_code)]
224pub fn new_registry() -> ModelProfileRegistry {
225 HashMap::new()
226}
227
228#[allow(dead_code)]
230pub fn default_registry() -> ModelProfileRegistry {
231 let mut registry = HashMap::new();
232
233 registry.insert(
235 "gpt-4-turbo".to_string(),
236 ModelProfile::new()
237 .with_max_input_tokens(128000)
238 .with_max_output_tokens(4096)
239 .with_image_inputs(true)
240 .with_image_url_inputs(true)
241 .with_tool_calling(true)
242 .with_tool_choice(true)
243 .with_structured_output(true),
244 );
245
246 registry.insert(
248 "gpt-4o".to_string(),
249 ModelProfile::new()
250 .with_max_input_tokens(128000)
251 .with_max_output_tokens(16384)
252 .with_image_inputs(true)
253 .with_image_url_inputs(true)
254 .with_audio_inputs(true)
255 .with_audio_outputs(true)
256 .with_tool_calling(true)
257 .with_tool_choice(true)
258 .with_structured_output(true),
259 );
260
261 registry.insert(
263 "claude-3-5-sonnet-20241022".to_string(),
264 ModelProfile::new()
265 .with_max_input_tokens(200000)
266 .with_max_output_tokens(8192)
267 .with_image_inputs(true)
268 .with_pdf_inputs(true)
269 .with_tool_calling(true)
270 .with_tool_choice(true)
271 .with_structured_output(true),
272 );
273
274 registry.insert(
276 "claude-3-opus-20240229".to_string(),
277 ModelProfile::new()
278 .with_max_input_tokens(200000)
279 .with_max_output_tokens(4096)
280 .with_image_inputs(true)
281 .with_tool_calling(true)
282 .with_tool_choice(true),
283 );
284
285 registry
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_profile_builder() {
        let profile = ModelProfile::new()
            .with_max_input_tokens(100000)
            .with_max_output_tokens(4096)
            .with_image_inputs(true)
            .with_tool_calling(true)
            .with_structured_output(true);

        assert_eq!(profile.max_input_tokens, Some(100000));
        assert_eq!(profile.max_output_tokens, Some(4096));
        assert_eq!(profile.image_inputs, Some(true));
        assert_eq!(profile.tool_calling, Some(true));
        assert_eq!(profile.structured_output, Some(true));
    }

    #[test]
    fn test_supports_multimodal_inputs() {
        let profile = ModelProfile::new().with_image_inputs(true);
        assert!(profile.supports_multimodal_inputs());

        let profile = ModelProfile::new().with_audio_inputs(true);
        assert!(profile.supports_multimodal_inputs());

        let profile = ModelProfile::new().with_video_inputs(true);
        assert!(profile.supports_multimodal_inputs());

        let profile = ModelProfile::new().with_pdf_inputs(true);
        assert!(profile.supports_multimodal_inputs());

        let profile = ModelProfile::new();
        assert!(!profile.supports_multimodal_inputs());

        // An explicit `false` must behave like an unspecified capability.
        let profile = ModelProfile::new().with_image_inputs(false);
        assert!(!profile.supports_multimodal_inputs());
    }

    #[test]
    fn test_supports_multimodal_outputs() {
        let profile = ModelProfile::new().with_image_outputs(true);
        assert!(profile.supports_multimodal_outputs());

        let profile = ModelProfile::new();
        assert!(!profile.supports_multimodal_outputs());
    }

    #[test]
    fn test_unspecified_capabilities_are_unsupported() {
        let profile = ModelProfile::new();
        assert!(!profile.supports_tool_calling());
        assert!(!profile.supports_structured_output());

        let profile = ModelProfile::new()
            .with_tool_calling(false)
            .with_structured_output(false);
        assert!(!profile.supports_tool_calling());
        assert!(!profile.supports_structured_output());
    }

    #[test]
    fn test_default_registry() {
        let registry = default_registry();

        assert!(registry.contains_key("gpt-4-turbo"));
        assert!(registry.contains_key("gpt-4o"));
        assert!(registry.contains_key("claude-3-5-sonnet-20241022"));
        assert!(registry.contains_key("claude-3-opus-20240229"));

        let gpt4o = registry.get("gpt-4o").unwrap();
        assert!(gpt4o.supports_multimodal_inputs());
        assert!(gpt4o.supports_tool_calling());

        let opus = registry
            .get("claude-3-opus-20240229")
            .expect("opus profile is registered");
        assert!(opus.supports_tool_calling());
        assert!(!opus.supports_structured_output());
        assert_eq!(opus.pdf_inputs, None);
    }

    #[test]
    fn test_model_profile_serialization() {
        let profile = ModelProfile::new()
            .with_max_input_tokens(100000)
            .with_tool_calling(true);

        let json = serde_json::to_string(&profile).unwrap();
        let deserialized: ModelProfile = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.max_input_tokens, Some(100000));
        assert_eq!(deserialized.tool_calling, Some(true));
    }

    #[test]
    fn test_unset_fields_are_omitted_when_serialized() {
        // Every field carries `skip_serializing_if = "Option::is_none"`.
        let json = serde_json::to_string(&ModelProfile::new()).unwrap();
        assert_eq!(json, "{}");

        let json = serde_json::to_string(&ModelProfile::new().with_tool_calling(true)).unwrap();
        assert_eq!(json, r#"{"tool_calling":true}"#);
    }

    #[test]
    fn test_missing_fields_deserialize_as_none() {
        let profile: ModelProfile = serde_json::from_str(r#"{"max_output_tokens":2048}"#).unwrap();
        assert_eq!(profile.max_output_tokens, Some(2048));
        assert_eq!(profile.max_input_tokens, None);
        assert_eq!(profile.tool_calling, None);
        assert_eq!(profile.image_inputs, None);
    }
}