Skip to main content

gunmetal_core/
lib.rs

1use std::{borrow::Cow, path::Path};
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::{Map, Value};
6use uuid::Uuid;
7
/// Filesystem locations a provider implementation can rely on.
///
/// The `Send + Sync` supertraits let one context be shared across threads,
/// for example behind an `Arc<dyn ProviderContext>`.
pub trait ProviderContext: Send + Sync {
    /// Directory containing helper files for providers.
    // NOTE(review): presumably helper binaries/scripts that providers invoke — confirm with implementors.
    fn helpers_dir(&self) -> &Path;

    /// Directory meant to be kept empty.
    // NOTE(review): presumably used as a neutral working directory for spawned tools — confirm.
    fn empty_workspace_dir(&self) -> &Path;
}
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
15pub enum ProviderKind {
16    Codex,
17    Copilot,
18    OpenRouter,
19    Zen,
20    OpenAi,
21    Azure,
22    Nvidia,
23    Custom(String),
24}
25
26impl ProviderKind {
27    pub fn slug(&self) -> Cow<'_, str> {
28        match self {
29            Self::Codex => Cow::Borrowed("codex"),
30            Self::Copilot => Cow::Borrowed("copilot"),
31            Self::OpenRouter => Cow::Borrowed("openrouter"),
32            Self::Zen => Cow::Borrowed("zen"),
33            Self::OpenAi => Cow::Borrowed("openai"),
34            Self::Azure => Cow::Borrowed("azure"),
35            Self::Nvidia => Cow::Borrowed("nvidia"),
36            Self::Custom(value) => Cow::Borrowed(value.as_str()),
37        }
38    }
39}
40
41impl std::fmt::Display for ProviderKind {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        write!(f, "{}", self.slug())
44    }
45}
46
47impl std::str::FromStr for ProviderKind {
48    type Err = String;
49
50    fn from_str(value: &str) -> Result<Self, Self::Err> {
51        match value {
52            "codex" => Ok(Self::Codex),
53            "copilot" => Ok(Self::Copilot),
54            "openrouter" => Ok(Self::OpenRouter),
55            "zen" => Ok(Self::Zen),
56            "openai" => Ok(Self::OpenAi),
57            "azure" => Ok(Self::Azure),
58            "nvidia" => Ok(Self::Nvidia),
59            value if !value.trim().is_empty() => Ok(Self::Custom(value.to_owned())),
60            _ => Err("provider kind cannot be empty".to_owned()),
61        }
62    }
63}
64
65#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
66#[serde(rename_all = "snake_case")]
67pub enum KeyScope {
68    Inference,
69    ModelsRead,
70    LogsRead,
71}
72
73impl std::fmt::Display for KeyScope {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        let value = match self {
76            Self::Inference => "inference",
77            Self::ModelsRead => "models_read",
78            Self::LogsRead => "logs_read",
79        };
80
81        write!(f, "{value}")
82    }
83}
84
85impl std::str::FromStr for KeyScope {
86    type Err = String;
87
88    fn from_str(value: &str) -> Result<Self, Self::Err> {
89        match value {
90            "inference" => Ok(Self::Inference),
91            "models_read" => Ok(Self::ModelsRead),
92            "logs_read" => Ok(Self::LogsRead),
93            _ => Err(format!("unknown key scope: {value}")),
94        }
95    }
96}
97
98#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
99#[serde(rename_all = "snake_case")]
100pub enum KeyState {
101    Active,
102    Disabled,
103    Revoked,
104}
105
106impl std::fmt::Display for KeyState {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        let value = match self {
109            Self::Active => "active",
110            Self::Disabled => "disabled",
111            Self::Revoked => "revoked",
112        };
113
114        write!(f, "{value}")
115    }
116}
117
118impl std::str::FromStr for KeyState {
119    type Err = String;
120
121    fn from_str(value: &str) -> Result<Self, Self::Err> {
122        match value {
123            "active" => Ok(Self::Active),
124            "disabled" => Ok(Self::Disabled),
125            "revoked" => Ok(Self::Revoked),
126            _ => Err(format!("unknown key state: {value}")),
127        }
128    }
129}
130
131#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
132pub struct GunmetalKey {
133    pub id: Uuid,
134    pub name: String,
135    pub prefix: String,
136    pub state: KeyState,
137    pub scopes: Vec<KeyScope>,
138    pub allowed_providers: Vec<ProviderKind>,
139    pub expires_at: Option<DateTime<Utc>>,
140    pub created_at: DateTime<Utc>,
141    pub updated_at: DateTime<Utc>,
142    pub last_used_at: Option<DateTime<Utc>>,
143}
144
145impl GunmetalKey {
146    pub fn can_access_provider(&self, provider: &ProviderKind) -> bool {
147        self.allowed_providers.is_empty()
148            || self.allowed_providers.iter().any(|item| item == provider)
149    }
150
151    pub fn is_usable_at(&self, now: DateTime<Utc>) -> bool {
152        self.state == KeyState::Active && self.expires_at.is_none_or(|value| value > now)
153    }
154}
155
156#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
157pub struct NewGunmetalKey {
158    pub name: String,
159    pub scopes: Vec<KeyScope>,
160    pub allowed_providers: Vec<ProviderKind>,
161    pub expires_at: Option<DateTime<Utc>>,
162}
163
164#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
165pub struct CreatedGunmetalKey {
166    pub record: GunmetalKey,
167    pub secret: String,
168}
169
170#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
171pub struct ProviderProfile {
172    pub id: Uuid,
173    pub provider: ProviderKind,
174    pub name: String,
175    pub base_url: Option<String>,
176    pub enabled: bool,
177    pub credentials: Option<Value>,
178    pub created_at: DateTime<Utc>,
179    pub updated_at: DateTime<Utc>,
180}
181
182#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
183pub struct NewProviderProfile {
184    pub provider: ProviderKind,
185    pub name: String,
186    pub base_url: Option<String>,
187    pub enabled: bool,
188    pub credentials: Option<Value>,
189}
190
191#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192pub struct ModelDescriptor {
193    pub id: String,
194    pub provider: ProviderKind,
195    pub profile_id: Option<Uuid>,
196    pub upstream_name: String,
197    pub display_name: String,
198    pub metadata: Option<ModelMetadata>,
199}
200
201#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
202pub struct ModelMetadata {
203    pub family: Option<String>,
204    pub release_date: Option<String>,
205    pub last_updated: Option<String>,
206    #[serde(default)]
207    pub input_modalities: Vec<String>,
208    #[serde(default)]
209    pub output_modalities: Vec<String>,
210    pub context_window: Option<u32>,
211    pub max_output_tokens: Option<u32>,
212    pub supports_attachments: Option<bool>,
213    pub supports_reasoning: Option<bool>,
214    pub supports_tools: Option<bool>,
215    pub open_weights: Option<bool>,
216}
217
218#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
219#[serde(rename_all = "snake_case")]
220pub enum ChatRole {
221    System,
222    User,
223    Assistant,
224}
225
226impl std::fmt::Display for ChatRole {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        let value = match self {
229            Self::System => "system",
230            Self::User => "user",
231            Self::Assistant => "assistant",
232        };
233        write!(f, "{value}")
234    }
235}
236
237impl std::str::FromStr for ChatRole {
238    type Err = String;
239
240    fn from_str(value: &str) -> Result<Self, Self::Err> {
241        match value {
242            "system" => Ok(Self::System),
243            "user" => Ok(Self::User),
244            "assistant" => Ok(Self::Assistant),
245            _ => Err(format!("unknown chat role: {value}")),
246        }
247    }
248}
249
250#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
251pub struct ChatMessage {
252    pub role: ChatRole,
253    pub content: String,
254}
255
256#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
257pub struct TokenUsage {
258    pub input_tokens: Option<u32>,
259    pub output_tokens: Option<u32>,
260    pub total_tokens: Option<u32>,
261}
262
263#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
264#[serde(rename_all = "snake_case")]
265pub enum RequestMode {
266    #[default]
267    Normalized,
268    Passthrough,
269}
270
271#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
272pub struct RequestOptions {
273    pub temperature: Option<f32>,
274    pub top_p: Option<f32>,
275    pub max_output_tokens: Option<u32>,
276    #[serde(default)]
277    pub stop: Vec<String>,
278    #[serde(default)]
279    pub metadata: Map<String, Value>,
280    #[serde(default)]
281    pub provider_options: Map<String, Value>,
282    #[serde(default)]
283    pub mode: RequestMode,
284}
285
286#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
287pub struct ChatCompletionRequest {
288    pub model: String,
289    pub messages: Vec<ChatMessage>,
290    pub stream: bool,
291    #[serde(default)]
292    pub options: RequestOptions,
293}
294
295#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
296pub struct ChatCompletionResult {
297    pub model: String,
298    pub message: ChatMessage,
299    pub finish_reason: String,
300    pub usage: TokenUsage,
301}
302
303#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
304#[serde(rename_all = "snake_case")]
305pub enum ProviderAuthState {
306    SignedOut,
307    SigningIn,
308    Connected,
309    Expired,
310    Error,
311}
312
313#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
314pub struct ProviderAuthStatus {
315    pub state: ProviderAuthState,
316    pub label: String,
317}
318
319#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
320pub struct ProviderLoginSession {
321    pub login_id: String,
322    pub auth_url: String,
323    pub user_code: Option<String>,
324    pub interval_seconds: Option<u64>,
325}
326
327#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
328pub struct RequestLogEntry {
329    pub id: Uuid,
330    pub started_at: DateTime<Utc>,
331    pub key_id: Option<Uuid>,
332    pub profile_id: Option<Uuid>,
333    pub provider: ProviderKind,
334    pub model: String,
335    pub endpoint: String,
336    pub status_code: Option<u16>,
337    pub duration_ms: u64,
338    pub usage: TokenUsage,
339    pub error_message: Option<String>,
340}
341
342#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
343pub struct NewRequestLogEntry {
344    pub key_id: Option<Uuid>,
345    pub profile_id: Option<Uuid>,
346    pub provider: ProviderKind,
347    pub model: String,
348    pub endpoint: String,
349    pub status_code: Option<u16>,
350    pub duration_ms: u64,
351    pub usage: TokenUsage,
352    pub error_message: Option<String>,
353}
354
#[cfg(test)]
mod tests {
    use chrono::Duration;

    use super::*;

    /// An active, provider-unrestricted, non-expiring key created at `now`.
    ///
    /// Tests override individual fields with struct-update syntax.
    fn key_at(now: DateTime<Utc>) -> GunmetalKey {
        GunmetalKey {
            id: Uuid::new_v4(),
            name: "default".to_owned(),
            prefix: "gm_test".to_owned(),
            state: KeyState::Active,
            scopes: vec![KeyScope::Inference],
            allowed_providers: Vec::new(),
            expires_at: None,
            created_at: now,
            updated_at: now,
            last_used_at: None,
        }
    }

    #[test]
    fn provider_parses_known_and_custom_variants() {
        assert_eq!(
            "codex".parse::<ProviderKind>().unwrap(),
            ProviderKind::Codex
        );
        assert_eq!(
            "edgebox".parse::<ProviderKind>().unwrap(),
            ProviderKind::Custom("edgebox".to_owned())
        );
    }

    #[test]
    fn provider_rejects_blank_input() {
        assert!("".parse::<ProviderKind>().is_err());
        assert!("   ".parse::<ProviderKind>().is_err());
    }

    #[test]
    fn provider_slug_roundtrips_through_display_and_parse() {
        for kind in [
            ProviderKind::Codex,
            ProviderKind::Copilot,
            ProviderKind::OpenRouter,
            ProviderKind::Zen,
            ProviderKind::OpenAi,
            ProviderKind::Azure,
            ProviderKind::Nvidia,
            ProviderKind::Custom("edgebox".to_owned()),
        ] {
            let parsed: ProviderKind = kind.to_string().parse().unwrap();
            assert_eq!(parsed, kind);
        }
    }

    #[test]
    fn provider_serde_tags_differ_from_slugs() {
        // serde uses snake_case variant tags, while Display/FromStr use slugs.
        let value = serde_json::to_value(ProviderKind::OpenAi).unwrap();
        assert_eq!(value["kind"], "open_ai");
        assert_eq!(ProviderKind::OpenAi.slug(), "openai");
        assert_eq!(
            serde_json::to_value(ProviderKind::Custom("edgebox".to_owned())).unwrap(),
            serde_json::json!({"kind": "custom", "value": "edgebox"})
        );
    }

    #[test]
    fn active_key_checks_state_expiry_and_provider() {
        let now = Utc::now();
        let key = GunmetalKey {
            allowed_providers: vec![ProviderKind::Codex],
            expires_at: Some(now + Duration::hours(1)),
            ..key_at(now)
        };

        assert!(key.can_access_provider(&ProviderKind::Codex));
        assert!(!key.can_access_provider(&ProviderKind::Copilot));
        assert!(key.is_usable_at(now));
        assert!(!key.is_usable_at(now + Duration::hours(2)));
    }

    #[test]
    fn key_without_provider_list_allows_any_provider() {
        let key = key_at(Utc::now());
        assert!(key.can_access_provider(&ProviderKind::Codex));
        assert!(key.can_access_provider(&ProviderKind::Custom("edgebox".to_owned())));
    }

    #[test]
    fn key_usability_respects_state_and_expiry_boundary() {
        let now = Utc::now();
        assert!(key_at(now).is_usable_at(now + Duration::days(365)));

        for state in [KeyState::Disabled, KeyState::Revoked] {
            let key = GunmetalKey {
                state,
                ..key_at(now)
            };
            assert!(!key.is_usable_at(now));
        }

        // Expiry is exclusive: the key is already unusable at its exact expiry instant.
        let key = GunmetalKey {
            expires_at: Some(now),
            ..key_at(now)
        };
        assert!(!key.is_usable_at(now));
    }

    #[test]
    fn key_scope_and_state_roundtrip_through_strings() {
        for scope in [KeyScope::Inference, KeyScope::ModelsRead, KeyScope::LogsRead] {
            assert_eq!(scope.to_string().parse::<KeyScope>().unwrap(), scope);
        }
        for state in [KeyState::Active, KeyState::Disabled, KeyState::Revoked] {
            assert_eq!(state.to_string().parse::<KeyState>().unwrap(), state);
        }
        assert!("admin".parse::<KeyScope>().is_err());
        assert!("deleted".parse::<KeyState>().is_err());
    }

    #[test]
    fn chat_role_parses_known_values() {
        assert_eq!("user".parse::<ChatRole>().unwrap(), ChatRole::User);
        assert!("tool".parse::<ChatRole>().is_err());
    }

    #[test]
    fn chat_role_display_matches_parse() {
        for role in [ChatRole::System, ChatRole::User, ChatRole::Assistant] {
            assert_eq!(role.to_string().parse::<ChatRole>().unwrap(), role);
        }
    }

    #[test]
    fn request_options_default_to_normalized_mode() {
        let options = RequestOptions::default();
        assert_eq!(options.mode, RequestMode::Normalized);
        assert!(options.provider_options.is_empty());
        assert!(options.metadata.is_empty());
    }

    #[test]
    fn gunmetal_key_roundtrip() {
        let now = Utc::now();
        let original = GunmetalKey {
            name: "test-key".to_owned(),
            scopes: vec![KeyScope::Inference, KeyScope::ModelsRead],
            allowed_providers: vec![
                ProviderKind::Codex,
                ProviderKind::Custom("edge".to_owned()),
            ],
            expires_at: Some(now + Duration::hours(1)),
            ..key_at(now)
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: GunmetalKey = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn provider_profile_roundtrip() {
        let now = Utc::now();
        let original = ProviderProfile {
            id: Uuid::new_v4(),
            provider: ProviderKind::OpenAi,
            name: "openai".to_owned(),
            base_url: Some("https://api.openai.com".to_owned()),
            enabled: true,
            credentials: Some(serde_json::json!({"key": "secret"})),
            created_at: now,
            updated_at: now,
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ProviderProfile = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn model_descriptor_roundtrip() {
        let original = ModelDescriptor {
            id: "openai/gpt-4".to_owned(),
            provider: ProviderKind::OpenAi,
            profile_id: Some(Uuid::new_v4()),
            upstream_name: "gpt-4".to_owned(),
            display_name: "GPT-4".to_owned(),
            metadata: Some(ModelMetadata {
                family: Some("gpt".to_owned()),
                release_date: Some("2023-03-14".to_owned()),
                last_updated: None,
                input_modalities: vec!["text".to_owned()],
                output_modalities: vec!["text".to_owned()],
                context_window: Some(8192),
                max_output_tokens: Some(4096),
                supports_attachments: Some(false),
                supports_reasoning: Some(true),
                supports_tools: Some(true),
                open_weights: Some(false),
            }),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ModelDescriptor = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn token_usage_roundtrip() {
        let original = TokenUsage {
            input_tokens: Some(10),
            output_tokens: Some(20),
            total_tokens: Some(30),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: TokenUsage = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn request_options_roundtrip() {
        let mut metadata = Map::new();
        metadata.insert(
            "user".to_owned(),
            serde_json::Value::String("alice".to_owned()),
        );
        let original = RequestOptions {
            temperature: Some(0.7),
            top_p: Some(0.9),
            max_output_tokens: Some(256),
            stop: vec!["STOP".to_owned()],
            metadata,
            provider_options: Map::new(),
            mode: RequestMode::Passthrough,
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: RequestOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn chat_completion_request_roundtrip() {
        let original = ChatCompletionRequest {
            model: "gpt-4".to_owned(),
            messages: vec![
                ChatMessage {
                    role: ChatRole::System,
                    content: "You are helpful.".to_owned(),
                },
                ChatMessage {
                    role: ChatRole::User,
                    content: "Hello".to_owned(),
                },
            ],
            stream: true,
            options: RequestOptions::default(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ChatCompletionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn chat_completion_result_roundtrip() {
        let original = ChatCompletionResult {
            model: "gpt-4".to_owned(),
            message: ChatMessage {
                role: ChatRole::Assistant,
                content: "Hi there!".to_owned(),
            },
            finish_reason: "stop".to_owned(),
            usage: TokenUsage {
                input_tokens: Some(1),
                output_tokens: Some(2),
                total_tokens: Some(3),
            },
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ChatCompletionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn chat_message_roundtrip() {
        let original = ChatMessage {
            role: ChatRole::User,
            content: "test".to_owned(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ChatMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn provider_auth_status_roundtrip() {
        let original = ProviderAuthStatus {
            state: ProviderAuthState::Connected,
            label: "Connected to OpenAI".to_owned(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ProviderAuthStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn chat_message_empty_content_roundtrip() {
        let original = ChatMessage {
            role: ChatRole::Assistant,
            content: "".to_owned(),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ChatMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn token_usage_missing_fields_deserialize() {
        let json = r#"{"input_tokens":10}"#;
        let deserialized: TokenUsage = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.input_tokens, Some(10));
        assert_eq!(deserialized.output_tokens, None);
        assert_eq!(deserialized.total_tokens, None);
    }

    #[test]
    fn request_options_defaults_when_missing() {
        let json = r#"{"temperature":0.5}"#;
        let deserialized: RequestOptions = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.temperature, Some(0.5));
        assert!(deserialized.stop.is_empty());
        assert!(deserialized.metadata.is_empty());
        assert!(deserialized.provider_options.is_empty());
        assert_eq!(deserialized.mode, RequestMode::Normalized);
    }

    #[test]
    fn model_descriptor_null_metadata() {
        let json = r#"{
            "id": "openai/gpt-4",
            "provider": {"kind":"open_ai","value":null},
            "profile_id": null,
            "upstream_name": "gpt-4",
            "display_name": "GPT-4",
            "metadata": null
        }"#;
        let deserialized: ModelDescriptor = serde_json::from_str(json).unwrap();
        assert_eq!(deserialized.metadata, None);
    }

    #[test]
    fn provider_auth_state_enum_variants() {
        for state in [
            ProviderAuthState::SignedOut,
            ProviderAuthState::SigningIn,
            ProviderAuthState::Connected,
            ProviderAuthState::Expired,
            ProviderAuthState::Error,
        ] {
            let status = ProviderAuthStatus {
                state,
                label: "test".to_owned(),
            };
            let json = serde_json::to_string(&status).unwrap();
            let deserialized: ProviderAuthStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(status, deserialized);
        }
    }

    #[test]
    fn provider_kind_enum_variants() {
        for kind in [
            ProviderKind::Codex,
            ProviderKind::Copilot,
            ProviderKind::OpenRouter,
            ProviderKind::Zen,
            ProviderKind::OpenAi,
            ProviderKind::Azure,
            ProviderKind::Nvidia,
            ProviderKind::Custom("x".to_owned()),
        ] {
            let json = serde_json::to_string(&kind).unwrap();
            let deserialized: ProviderKind = serde_json::from_str(&json).unwrap();
            assert_eq!(kind, deserialized);
        }
    }
}