// shift_preflight/policy/provider.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5/// Constraints for a specific model or provider default.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct ModelConstraints {
8    pub max_images: usize,
9    pub max_image_dim: u32,
10    pub max_image_size_bytes: usize,
11    #[serde(default)]
12    pub max_image_megapixels: Option<f64>,
13    pub supported_formats: Vec<String>,
14}
15
16/// Full provider profile with per-model constraints.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ProviderProfile {
19    pub name: String,
20    #[serde(default)]
21    pub models: HashMap<String, ModelConstraints>,
22    pub default: ModelConstraints,
23}
24
25impl ProviderProfile {
26    /// Get constraints for a specific model, falling back to provider defaults.
27    pub fn constraints_for(&self, model: Option<&str>) -> &ModelConstraints {
28        if let Some(model_name) = model {
29            if let Some(constraints) = self.models.get(model_name) {
30                return constraints;
31            }
32        }
33        &self.default
34    }
35
36    /// Load a provider profile from JSON bytes.
37    pub fn from_json(data: &[u8]) -> Result<Self> {
38        serde_json::from_slice(data).context("failed to parse provider profile JSON")
39    }
40}
41
42// Embedded profiles compiled into the binary
43const OPENAI_PROFILE: &str = include_str!("../../profiles/openai.json");
44const ANTHROPIC_PROFILE: &str = include_str!("../../profiles/anthropic.json");
45
46/// Load a built-in provider profile by name.
47pub fn load_builtin(provider: &str) -> Result<ProviderProfile> {
48    let json = match provider.to_lowercase().as_str() {
49        "openai" => OPENAI_PROFILE,
50        "anthropic" | "claude" => ANTHROPIC_PROFILE,
51        _ => anyhow::bail!(
52            "unknown provider '{}': supported providers are 'openai' and 'anthropic'",
53            provider
54        ),
55    };
56    serde_json::from_str(json).context("failed to parse built-in provider profile")
57}
58
59/// Load a provider profile from an external JSON file.
60pub fn load_from_file(path: &str) -> Result<ProviderProfile> {
61    let data =
62        std::fs::read(path).with_context(|| format!("failed to read profile from {}", path))?;
63    ProviderProfile::from_json(&data)
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Inline profile whose model override differs from the default in every
    /// numeric field, so a test can tell a model hit from a fallback.
    const SAMPLE_PROFILE: &str = r#"{
        "name": "sample",
        "models": {
            "big-model": {
                "max_images": 5,
                "max_image_dim": 4096,
                "max_image_size_bytes": 1000,
                "supported_formats": ["png"]
            }
        },
        "default": {
            "max_images": 1,
            "max_image_dim": 512,
            "max_image_size_bytes": 100,
            "max_image_megapixels": 0.5,
            "supported_formats": ["png", "jpeg"]
        }
    }"#;

    /// Parse `SAMPLE_PROFILE`; a failure here is a bug in the fixture itself.
    fn sample_profile() -> ProviderProfile {
        ProviderProfile::from_json(SAMPLE_PROFILE.as_bytes())
            .expect("SAMPLE_PROFILE is a valid provider profile")
    }

    #[test]
    fn test_load_openai_profile() {
        let profile = load_builtin("openai").unwrap();
        assert_eq!(profile.name, "openai");
        assert_eq!(profile.default.max_images, 10);
        assert_eq!(profile.default.max_image_dim, 2048);
        assert!(profile.models.contains_key("gpt-4o"));
    }

    #[test]
    fn test_load_anthropic_profile() {
        let profile = load_builtin("anthropic").unwrap();
        assert_eq!(profile.name, "anthropic");
        assert_eq!(profile.default.max_images, 20);
        assert_eq!(profile.default.max_image_megapixels, Some(1.15));
    }

    #[test]
    fn test_load_claude_alias() {
        let profile = load_builtin("claude").unwrap();
        assert_eq!(profile.name, "anthropic");
    }

    #[test]
    fn test_load_builtin_is_case_insensitive() {
        assert_eq!(load_builtin("OpenAI").unwrap().name, "openai");
        assert_eq!(load_builtin("ANTHROPIC").unwrap().name, "anthropic");
    }

    #[test]
    fn test_unknown_provider() {
        assert!(load_builtin("unknown").is_err());
    }

    #[test]
    fn test_constraints_for_specific_model() {
        let profile = load_builtin("openai").unwrap();
        let constraints = profile.constraints_for(Some("gpt-4o"));
        assert_eq!(constraints.max_image_dim, 2048);
    }

    #[test]
    fn test_constraints_for_unknown_model_falls_back() {
        let profile = load_builtin("openai").unwrap();
        let constraints = profile.constraints_for(Some("gpt-99"));
        assert_eq!(constraints.max_image_dim, profile.default.max_image_dim);
    }

    #[test]
    fn test_constraints_for_none_uses_default() {
        let profile = load_builtin("openai").unwrap();
        let constraints = profile.constraints_for(None);
        assert_eq!(constraints.max_image_dim, profile.default.max_image_dim);
    }

    #[test]
    fn test_from_json_model_override_is_used() {
        // The built-in gpt-4o entry shares its max_image_dim with the default,
        // so this fixture is what actually proves the per-model lookup.
        let profile = sample_profile();
        let constraints = profile.constraints_for(Some("big-model"));
        assert_eq!(constraints.max_images, 5);
        assert_eq!(constraints.max_image_dim, 4096);
        assert_eq!(constraints.max_image_size_bytes, 1000);
        assert_eq!(constraints.max_image_megapixels, None);
    }

    #[test]
    fn test_from_json_unknown_model_uses_default() {
        let profile = sample_profile();
        assert_eq!(profile.constraints_for(Some("small-model")).max_image_dim, 512);
        assert_eq!(profile.constraints_for(None).max_image_megapixels, Some(0.5));
    }

    #[test]
    fn test_from_json_models_field_is_optional() {
        let json = br#"{
            "name": "bare",
            "default": {
                "max_images": 2,
                "max_image_dim": 1024,
                "max_image_size_bytes": 10,
                "supported_formats": []
            }
        }"#;
        let profile = ProviderProfile::from_json(json).unwrap();
        assert!(profile.models.is_empty());
        assert_eq!(profile.constraints_for(Some("anything")).max_images, 2);
    }

    #[test]
    fn test_from_json_rejects_invalid_input() {
        assert!(ProviderProfile::from_json(b"not json").is_err());
        assert!(ProviderProfile::from_json(br#"{"name": "no-default"}"#).is_err());
    }

    #[test]
    fn test_load_from_file_missing_path_errors() {
        assert!(load_from_file("this/path/does/not/exist/profile.json").is_err());
    }

    #[test]
    fn test_supported_formats() {
        let profile = load_builtin("openai").unwrap();
        for format in ["png", "jpeg", "gif", "webp"] {
            assert!(
                profile.default.supported_formats.iter().any(|f| f == format),
                "openai default profile should support {}",
                format
            );
        }
    }
}