1use std::{borrow::Cow, path::Path};
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::{Map, Value};
6use uuid::Uuid;
7
/// Filesystem locations that provider implementations can query at runtime.
///
/// The `Sync + Send` supertraits let a context be shared across threads,
/// including behind `dyn ProviderContext`.
pub trait ProviderContext: Sync + Send {
    /// Returns the helpers directory configured for this context.
    fn helpers_dir(&self) -> &Path;

    /// Returns the empty-workspace directory configured for this context.
    fn empty_workspace_dir(&self) -> &Path;
}
12
/// Identifies the upstream provider that a key, profile, model, or log entry refers to.
///
/// Serialized as an adjacently tagged object `{"kind": <name>, "value": <payload>}`.
/// `kind` is the snake_case variant name, and `value` carries the string for `Custom`.
///
/// NOTE(review): `rename_all = "snake_case"` makes the serde `kind` of `OpenAi` equal to
/// `"open_ai"` and of `OpenRouter` equal to `"open_router"`. In contrast, `slug()`,
/// `Display` and `FromStr` use `"openai"` / `"openrouter"`. Do not feed one spelling into
/// the other API.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum ProviderKind {
    Codex,
    Copilot,
    OpenRouter,
    Zen,
    OpenAi,
    Azure,
    Nvidia,
    /// Any provider without a built-in variant, identified by its own slug string.
    Custom(String),
}
25
26impl ProviderKind {
27 pub fn slug(&self) -> Cow<'_, str> {
28 match self {
29 Self::Codex => Cow::Borrowed("codex"),
30 Self::Copilot => Cow::Borrowed("copilot"),
31 Self::OpenRouter => Cow::Borrowed("openrouter"),
32 Self::Zen => Cow::Borrowed("zen"),
33 Self::OpenAi => Cow::Borrowed("openai"),
34 Self::Azure => Cow::Borrowed("azure"),
35 Self::Nvidia => Cow::Borrowed("nvidia"),
36 Self::Custom(value) => Cow::Borrowed(value.as_str()),
37 }
38 }
39}
40
41impl std::fmt::Display for ProviderKind {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 write!(f, "{}", self.slug())
44 }
45}
46
47impl std::str::FromStr for ProviderKind {
48 type Err = String;
49
50 fn from_str(value: &str) -> Result<Self, Self::Err> {
51 match value {
52 "codex" => Ok(Self::Codex),
53 "copilot" => Ok(Self::Copilot),
54 "openrouter" => Ok(Self::OpenRouter),
55 "zen" => Ok(Self::Zen),
56 "openai" => Ok(Self::OpenAi),
57 "azure" => Ok(Self::Azure),
58 "nvidia" => Ok(Self::Nvidia),
59 value if !value.trim().is_empty() => Ok(Self::Custom(value.to_owned())),
60 _ => Err("provider kind cannot be empty".to_owned()),
61 }
62 }
63}
64
/// A permission that can be granted to a [`GunmetalKey`] through its `scopes` list.
///
/// Serialized in snake_case (`inference`, `models_read`, `logs_read`). These are the same
/// strings used by the `Display` and `FromStr` impls.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KeyScope {
    Inference,
    ModelsRead,
    LogsRead,
}
72
73impl std::fmt::Display for KeyScope {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 let value = match self {
76 Self::Inference => "inference",
77 Self::ModelsRead => "models_read",
78 Self::LogsRead => "logs_read",
79 };
80
81 write!(f, "{value}")
82 }
83}
84
85impl std::str::FromStr for KeyScope {
86 type Err = String;
87
88 fn from_str(value: &str) -> Result<Self, Self::Err> {
89 match value {
90 "inference" => Ok(Self::Inference),
91 "models_read" => Ok(Self::ModelsRead),
92 "logs_read" => Ok(Self::LogsRead),
93 _ => Err(format!("unknown key scope: {value}")),
94 }
95 }
96}
97
/// Lifecycle state of a [`GunmetalKey`].
///
/// Only `Active` keys pass `GunmetalKey::is_usable_at`. Serialized in snake_case
/// (`active`, `disabled`, `revoked`), the same strings used by `Display` and `FromStr`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KeyState {
    Active,
    Disabled,
    Revoked,
}
105
106impl std::fmt::Display for KeyState {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 let value = match self {
109 Self::Active => "active",
110 Self::Disabled => "disabled",
111 Self::Revoked => "revoked",
112 };
113
114 write!(f, "{value}")
115 }
116}
117
118impl std::str::FromStr for KeyState {
119 type Err = String;
120
121 fn from_str(value: &str) -> Result<Self, Self::Err> {
122 match value {
123 "active" => Ok(Self::Active),
124 "disabled" => Ok(Self::Disabled),
125 "revoked" => Ok(Self::Revoked),
126 _ => Err(format!("unknown key state: {value}")),
127 }
128 }
129}
130
/// A stored API key record.
///
/// This holds metadata only. The key secret is not a field here; it is returned once,
/// alongside the record, in `CreatedGunmetalKey`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GunmetalKey {
    pub id: Uuid,
    /// Human-readable label for the key.
    pub name: String,
    /// NOTE(review): presumably a non-secret leading fragment of the key used for display
    /// or lookup. Confirm against the key-generation code.
    pub prefix: String,
    /// Lifecycle state. Only `KeyState::Active` keys pass `is_usable_at`.
    pub state: KeyState,
    /// Permissions granted to this key.
    pub scopes: Vec<KeyScope>,
    /// Providers this key may reach. An empty list allows every provider
    /// (see `can_access_provider`).
    pub allowed_providers: Vec<ProviderKind>,
    /// Expiry instant. `None` means the key never expires. A key is no longer usable at
    /// or after this instant.
    pub expires_at: Option<DateTime<Utc>>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// When the key was last used, or `None` if it has not been used.
    pub last_used_at: Option<DateTime<Utc>>,
}
144
145impl GunmetalKey {
146 pub fn can_access_provider(&self, provider: &ProviderKind) -> bool {
147 self.allowed_providers.is_empty()
148 || self.allowed_providers.iter().any(|item| item == provider)
149 }
150
151 pub fn is_usable_at(&self, now: DateTime<Utc>) -> bool {
152 self.state == KeyState::Active && self.expires_at.is_none_or(|value| value > now)
153 }
154}
155
/// Input for creating a [`GunmetalKey`].
///
/// Fields present on `GunmetalKey` but absent here (`id`, `prefix`, `state`, timestamps)
/// are presumably filled in by the code that creates the key. Confirm with the caller.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct NewGunmetalKey {
    /// Human-readable label for the key.
    pub name: String,
    /// Permissions to grant.
    pub scopes: Vec<KeyScope>,
    /// Providers the key may reach. Empty means unrestricted (see
    /// `GunmetalKey::can_access_provider`).
    pub allowed_providers: Vec<ProviderKind>,
    /// Optional expiry instant. `None` means the key never expires.
    pub expires_at: Option<DateTime<Utc>>,
}
163
164#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
165pub struct CreatedGunmetalKey {
166 pub record: GunmetalKey,
167 pub secret: String,
168}
169
/// A configured connection to one upstream provider.
///
/// NOTE(review): `Debug` is derived, so `{:?}` output includes `credentials` verbatim.
/// Avoid logging this type with `{:?}` if credentials contain secrets.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderProfile {
    pub id: Uuid,
    /// Which provider this profile targets.
    pub provider: ProviderKind,
    /// Human-readable profile name.
    pub name: String,
    /// Optional base URL for the provider API. `None` when unset; presumably a
    /// provider default is used then. Confirm with the provider implementations.
    pub base_url: Option<String>,
    /// Whether the profile is enabled.
    pub enabled: bool,
    /// Provider-specific credential payload as arbitrary JSON. No schema is enforced here.
    pub credentials: Option<Value>,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}
181
/// Input for creating a [`ProviderProfile`].
///
/// `id` and the timestamps are absent; presumably they are assigned on creation.
/// `Debug` is derived, so `{:?}` output includes `credentials` verbatim.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NewProviderProfile {
    /// Which provider this profile targets.
    pub provider: ProviderKind,
    /// Human-readable profile name.
    pub name: String,
    /// Optional base URL for the provider API.
    pub base_url: Option<String>,
    /// Whether the profile starts enabled.
    pub enabled: bool,
    /// Provider-specific credential payload as arbitrary JSON.
    pub credentials: Option<Value>,
}
190
/// A model that can be requested, together with the provider that serves it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelDescriptor {
    /// Identifier for the model. Tests use `openai/gpt-4`; the format is presumably
    /// `<provider>/<upstream_name>`. Confirm with the model catalog code.
    pub id: String,
    /// Provider that serves this model.
    pub provider: ProviderKind,
    /// Profile this model is tied to, if any.
    pub profile_id: Option<Uuid>,
    /// Model name on the upstream provider's side.
    pub upstream_name: String,
    /// Human-readable name.
    pub display_name: String,
    /// Optional capability and release details. May be JSON `null`.
    pub metadata: Option<ModelMetadata>,
}
200
/// Optional descriptive and capability information about a model.
///
/// Every field may be omitted from JSON: `Option` fields become `None`, and the
/// modality lists become empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ModelMetadata {
    pub family: Option<String>,
    /// Release date as a free-form string (tests use `2023-03-14`). It is not parsed here.
    pub release_date: Option<String>,
    /// Last-updated date as a free-form string. It is not parsed here.
    pub last_updated: Option<String>,
    /// Accepted input modalities, e.g. `"text"`.
    #[serde(default)]
    pub input_modalities: Vec<String>,
    /// Produced output modalities, e.g. `"text"`.
    #[serde(default)]
    pub output_modalities: Vec<String>,
    /// Context window size, presumably in tokens. Confirm the unit with the metadata source.
    pub context_window: Option<u32>,
    pub max_output_tokens: Option<u32>,
    pub supports_attachments: Option<bool>,
    pub supports_reasoning: Option<bool>,
    pub supports_tools: Option<bool>,
    pub open_weights: Option<bool>,
}
217
/// Author of a [`ChatMessage`].
///
/// Serialized in snake_case (`system`, `user`, `assistant`), the same strings used by
/// `Display` and `FromStr`. Other roles (e.g. `tool`) are not representable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ChatRole {
    System,
    User,
    Assistant,
}
225
226impl std::fmt::Display for ChatRole {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 let value = match self {
229 Self::System => "system",
230 Self::User => "user",
231 Self::Assistant => "assistant",
232 };
233 write!(f, "{value}")
234 }
235}
236
237impl std::str::FromStr for ChatRole {
238 type Err = String;
239
240 fn from_str(value: &str) -> Result<Self, Self::Err> {
241 match value {
242 "system" => Ok(Self::System),
243 "user" => Ok(Self::User),
244 "assistant" => Ok(Self::Assistant),
245 _ => Err(format!("unknown chat role: {value}")),
246 }
247 }
248}
249
/// One message in a conversation: who said it and its plain-text content.
///
/// `content` may be empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: ChatRole,
    pub content: String,
}
255
/// Token counts for a request.
///
/// Each count is `None` when unavailable. A field missing from JSON deserializes as `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    pub input_tokens: Option<u32>,
    pub output_tokens: Option<u32>,
    pub total_tokens: Option<u32>,
}
262
/// How a chat request should be handled. Defaults to `Normalized`.
///
/// Serialized in snake_case (`normalized`, `passthrough`).
/// NOTE(review): the exact difference between the modes is defined by the request
/// handlers, not here. Presumably `Passthrough` forwards provider-specific options more
/// directly. Confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RequestMode {
    #[default]
    Normalized,
    Passthrough,
}
270
/// Optional generation parameters and extra data attached to a chat completion request.
///
/// Every field may be omitted from JSON: `Option` fields become `None`, collections
/// become empty, and `mode` becomes [`RequestMode::Normalized`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct RequestOptions {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub max_output_tokens: Option<u32>,
    /// Stop sequences. Empty when none are given.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Arbitrary caller-supplied JSON key/value metadata.
    #[serde(default)]
    pub metadata: Map<String, Value>,
    /// Provider-specific options as raw JSON. Interpretation is left to the provider code.
    #[serde(default)]
    pub provider_options: Map<String, Value>,
    #[serde(default)]
    pub mode: RequestMode,
}
285
/// A chat completion request.
///
/// When deserializing, `model`, `messages` and `stream` are required (`stream` has no
/// serde default). `options` may be omitted and then takes `RequestOptions::default()`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    /// Requested model identifier.
    pub model: String,
    /// Conversation so far, in order.
    pub messages: Vec<ChatMessage>,
    /// Whether the caller asked for a streamed response.
    pub stream: bool,
    #[serde(default)]
    pub options: RequestOptions,
}
294
/// The outcome of a non-streamed chat completion.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionResult {
    /// Model that produced the reply.
    pub model: String,
    /// The generated message.
    pub message: ChatMessage,
    /// Why generation ended, as a free-form string (tests use `"stop"`).
    pub finish_reason: String,
    pub usage: TokenUsage,
}
302
/// Authentication state of a provider connection.
///
/// Serialized in snake_case (`signed_out`, `signing_in`, `connected`, `expired`, `error`).
/// The state transitions are managed outside this module.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProviderAuthState {
    SignedOut,
    SigningIn,
    Connected,
    Expired,
    Error,
}
312
/// A provider's authentication state plus a human-readable description of it
/// (e.g. `"Connected to OpenAI"` in tests).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderAuthStatus {
    pub state: ProviderAuthState,
    pub label: String,
}
318
/// An in-progress provider login.
///
/// NOTE(review): the fields suggest a browser or device-code sign-in flow. Presumably
/// the user visits `auth_url`, optionally enters `user_code`, and the caller polls every
/// `interval_seconds`. Confirm against the provider login implementations.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProviderLoginSession {
    /// Identifier for this login attempt.
    pub login_id: String,
    /// URL the user is directed to.
    pub auth_url: String,
    /// Code to present to the user, when the flow uses one.
    pub user_code: Option<String>,
    /// Suggested interval in seconds (per the field name), when provided.
    pub interval_seconds: Option<u64>,
}
326
/// A recorded request, as stored in the request log.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RequestLogEntry {
    pub id: Uuid,
    /// When the request started.
    pub started_at: DateTime<Utc>,
    /// Key that made the request, if any.
    pub key_id: Option<Uuid>,
    /// Provider profile that handled the request, if any.
    pub profile_id: Option<Uuid>,
    pub provider: ProviderKind,
    pub model: String,
    /// Endpoint that was called, as a free-form string.
    pub endpoint: String,
    /// HTTP-style status code. Presumably `None` when no response status was obtained;
    /// confirm with the logging call sites.
    pub status_code: Option<u16>,
    /// Request duration in milliseconds (per the field name).
    pub duration_ms: u64,
    pub usage: TokenUsage,
    /// Error text when the request failed, if any.
    pub error_message: Option<String>,
}
341
/// Input for recording a request in the log.
///
/// Mirrors [`RequestLogEntry`] without `id` and `started_at`. Presumably those are
/// assigned when the entry is stored; confirm with the store implementation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct NewRequestLogEntry {
    pub key_id: Option<Uuid>,
    pub profile_id: Option<Uuid>,
    pub provider: ProviderKind,
    pub model: String,
    pub endpoint: String,
    pub status_code: Option<u16>,
    /// Request duration in milliseconds (per the field name).
    pub duration_ms: u64,
    pub usage: TokenUsage,
    pub error_message: Option<String>,
}
354
#[cfg(test)]
mod tests {
    use chrono::Duration;

    use super::*;

    /// Encodes `value` as JSON, decodes it again, and asserts nothing changed.
    fn assert_json_roundtrip<T>(value: &T)
    where
        T: Serialize + for<'de> Deserialize<'de> + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_string(value).unwrap();
        let decoded: T = serde_json::from_str(&encoded).unwrap();
        assert_eq!(*value, decoded);
    }

    /// Builds an active `gm_test` key created at `now` that expires one hour later.
    fn active_key(
        now: DateTime<Utc>,
        name: &str,
        scopes: Vec<KeyScope>,
        allowed_providers: Vec<ProviderKind>,
    ) -> GunmetalKey {
        GunmetalKey {
            id: Uuid::new_v4(),
            name: name.to_owned(),
            prefix: "gm_test".to_owned(),
            state: KeyState::Active,
            scopes,
            allowed_providers,
            expires_at: Some(now + Duration::hours(1)),
            created_at: now,
            updated_at: now,
            last_used_at: None,
        }
    }

    #[test]
    fn provider_parses_known_and_custom_variants() {
        let parse = |text: &str| text.parse::<ProviderKind>().unwrap();
        assert_eq!(parse("codex"), ProviderKind::Codex);
        assert_eq!(parse("edgebox"), ProviderKind::Custom("edgebox".to_owned()));
    }

    #[test]
    fn active_key_checks_state_expiry_and_provider() {
        let now = Utc::now();
        let key = active_key(
            now,
            "default",
            vec![KeyScope::Inference],
            vec![ProviderKind::Codex],
        );

        assert!(key.can_access_provider(&ProviderKind::Codex));
        assert!(!key.can_access_provider(&ProviderKind::Copilot));
        assert!(key.is_usable_at(now));
        assert!(!key.is_usable_at(now + Duration::hours(2)));
    }

    #[test]
    fn chat_role_parses_known_values() {
        assert_eq!("user".parse::<ChatRole>(), Ok(ChatRole::User));
        assert!("tool".parse::<ChatRole>().is_err());
    }

    #[test]
    fn request_options_default_to_normalized_mode() {
        let RequestOptions {
            mode,
            provider_options,
            metadata,
            ..
        } = RequestOptions::default();
        assert_eq!(mode, RequestMode::Normalized);
        assert!(provider_options.is_empty());
        assert!(metadata.is_empty());
    }

    #[test]
    fn gunmetal_key_roundtrip() {
        let key = active_key(
            Utc::now(),
            "test-key",
            vec![KeyScope::Inference, KeyScope::ModelsRead],
            vec![ProviderKind::Codex, ProviderKind::Custom("edge".to_owned())],
        );
        assert_json_roundtrip(&key);
    }

    #[test]
    fn provider_profile_roundtrip() {
        let now = Utc::now();
        assert_json_roundtrip(&ProviderProfile {
            id: Uuid::new_v4(),
            provider: ProviderKind::OpenAi,
            name: "openai".to_owned(),
            base_url: Some("https://api.openai.com".to_owned()),
            enabled: true,
            credentials: Some(serde_json::json!({"key": "secret"})),
            created_at: now,
            updated_at: now,
        });
    }

    #[test]
    fn model_descriptor_roundtrip() {
        let metadata = ModelMetadata {
            family: Some("gpt".to_owned()),
            release_date: Some("2023-03-14".to_owned()),
            last_updated: None,
            input_modalities: vec!["text".to_owned()],
            output_modalities: vec!["text".to_owned()],
            context_window: Some(8192),
            max_output_tokens: Some(4096),
            supports_attachments: Some(false),
            supports_reasoning: Some(true),
            supports_tools: Some(true),
            open_weights: Some(false),
        };
        assert_json_roundtrip(&ModelDescriptor {
            id: "openai/gpt-4".to_owned(),
            provider: ProviderKind::OpenAi,
            profile_id: Some(Uuid::new_v4()),
            upstream_name: "gpt-4".to_owned(),
            display_name: "GPT-4".to_owned(),
            metadata: Some(metadata),
        });
    }

    #[test]
    fn token_usage_roundtrip() {
        assert_json_roundtrip(&TokenUsage {
            input_tokens: Some(10),
            output_tokens: Some(20),
            total_tokens: Some(30),
        });
    }

    #[test]
    fn request_options_roundtrip() {
        let metadata: Map<String, Value> =
            std::iter::once(("user".to_owned(), Value::String("alice".to_owned()))).collect();
        assert_json_roundtrip(&RequestOptions {
            temperature: Some(0.7),
            top_p: Some(0.9),
            max_output_tokens: Some(256),
            stop: vec!["STOP".to_owned()],
            metadata,
            provider_options: Map::new(),
            mode: RequestMode::Passthrough,
        });
    }

    #[test]
    fn chat_completion_request_roundtrip() {
        let system = ChatMessage {
            role: ChatRole::System,
            content: "You are helpful.".to_owned(),
        };
        let user = ChatMessage {
            role: ChatRole::User,
            content: "Hello".to_owned(),
        };
        assert_json_roundtrip(&ChatCompletionRequest {
            model: "gpt-4".to_owned(),
            messages: vec![system, user],
            stream: true,
            options: RequestOptions::default(),
        });
    }

    #[test]
    fn chat_completion_result_roundtrip() {
        assert_json_roundtrip(&ChatCompletionResult {
            model: "gpt-4".to_owned(),
            message: ChatMessage {
                role: ChatRole::Assistant,
                content: "Hi there!".to_owned(),
            },
            finish_reason: "stop".to_owned(),
            usage: TokenUsage {
                input_tokens: Some(1),
                output_tokens: Some(2),
                total_tokens: Some(3),
            },
        });
    }

    #[test]
    fn chat_message_roundtrip() {
        assert_json_roundtrip(&ChatMessage {
            role: ChatRole::User,
            content: "test".to_owned(),
        });
    }

    #[test]
    fn provider_auth_status_roundtrip() {
        assert_json_roundtrip(&ProviderAuthStatus {
            state: ProviderAuthState::Connected,
            label: "Connected to OpenAI".to_owned(),
        });
    }

    #[test]
    fn chat_message_empty_content_roundtrip() {
        assert_json_roundtrip(&ChatMessage {
            role: ChatRole::Assistant,
            content: String::new(),
        });
    }

    #[test]
    fn token_usage_missing_fields_deserialize() {
        let usage: TokenUsage = serde_json::from_str(r#"{"input_tokens":10}"#).unwrap();
        assert_eq!(
            usage,
            TokenUsage {
                input_tokens: Some(10),
                output_tokens: None,
                total_tokens: None,
            }
        );
    }

    #[test]
    fn request_options_defaults_when_missing() {
        let options: RequestOptions = serde_json::from_str(r#"{"temperature":0.5}"#).unwrap();
        assert_eq!(options.temperature, Some(0.5));
        assert!(options.stop.is_empty());
        assert!(options.metadata.is_empty());
        assert!(options.provider_options.is_empty());
        assert_eq!(options.mode, RequestMode::Normalized);
    }

    #[test]
    fn model_descriptor_null_metadata() {
        let payload = r#"{
            "id": "openai/gpt-4",
            "provider": {"kind":"open_ai","value":null},
            "profile_id": null,
            "upstream_name": "gpt-4",
            "display_name": "GPT-4",
            "metadata": null
        }"#;
        let descriptor: ModelDescriptor = serde_json::from_str(payload).unwrap();
        assert!(descriptor.metadata.is_none());
    }

    #[test]
    fn provider_auth_state_enum_variants() {
        let states = [
            ProviderAuthState::SignedOut,
            ProviderAuthState::SigningIn,
            ProviderAuthState::Connected,
            ProviderAuthState::Expired,
            ProviderAuthState::Error,
        ];
        for state in states {
            assert_json_roundtrip(&ProviderAuthStatus {
                state,
                label: "test".to_owned(),
            });
        }
    }

    #[test]
    fn provider_kind_enum_variants() {
        let kinds = [
            ProviderKind::Codex,
            ProviderKind::Copilot,
            ProviderKind::OpenRouter,
            ProviderKind::Zen,
            ProviderKind::OpenAi,
            ProviderKind::Azure,
            ProviderKind::Nvidia,
            ProviderKind::Custom("x".to_owned()),
        ];
        kinds.iter().for_each(assert_json_roundtrip);
    }
}