1use std::collections::HashMap;
17
18use serde::{Deserialize, Serialize};
19
20use katu_core::{ModelId, ProviderId, RouteId};
21
22use crate::cache::CachePolicy;
23use katu_core::GenerationOptions;
24use crate::http::HttpOptions;
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum InputModality {
34 Text,
36 Image,
38 Audio,
40 Video,
42}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
52#[serde(rename_all = "snake_case")]
53pub enum ReasoningEffort {
54 None,
55 Low,
56 Medium,
57 High,
58 XHigh,
59 Max
60}
61
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
73#[serde(rename_all = "snake_case")]
74pub enum ThinkingMode {
75 Adaptive,
77 Budget,
79 Effort,
81}
82
83#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
113pub struct ThinkingConfig {
114 pub mode: ThinkingMode,
116 pub default_budget: Option<u32>,
118 pub min_effort: Option<ReasoningEffort>,
120 pub max_effort: Option<ReasoningEffort>,
122}
123
124#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
147pub struct ModelCapabilities {
148 pub input_modalities: Vec<InputModality>,
150 pub tool_calls: bool,
152 pub streaming_tool_input: bool,
154 pub structured_output: bool,
156 pub prompt_caching: bool,
158 pub thinking: Option<ThinkingConfig>,
160}
161
162impl ModelCapabilities {
163 pub fn supports_modality(&self, modality: InputModality) -> bool {
165 self.input_modalities.contains(&modality)
166 }
167
168 pub fn supports_thinking(&self) -> bool {
170 self.thinking.is_some()
171 }
172}
173
174impl Default for ModelCapabilities {
175 fn default() -> Self {
176 Self {
177 input_modalities: vec![InputModality::Text],
178 tool_calls: true,
179 streaming_tool_input: false,
180 structured_output: false,
181 prompt_caching: false,
182 thinking: None,
183 }
184 }
185}
186
187#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
203pub struct ModelLimits {
204 pub context_window: u32,
206 pub max_output_tokens: u32,
208}
209
210#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
230pub struct ModelPricing {
231 pub input: f64,
233 pub output: f64,
235 pub cache_read: f64,
237 pub cache_write: f64,
239}
240
241#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
303pub struct ModelRef {
304 pub id: ModelId,
308 pub provider: ProviderId,
310 pub route: RouteId,
312 #[serde(skip_serializing_if = "Option::is_none")]
314 pub display_name: Option<String>,
315
316 pub base_url: String,
320 #[serde(skip_serializing_if = "Option::is_none")]
322 pub api_key: Option<String>,
323 #[serde(skip_serializing_if = "Option::is_none")]
325 pub headers: Option<HashMap<String, String>>,
326 #[serde(skip_serializing_if = "Option::is_none")]
328 pub query_params: Option<HashMap<String, String>>,
329
330 pub limits: ModelLimits,
334 pub capabilities: ModelCapabilities,
336
337 #[serde(skip_serializing_if = "Option::is_none")]
341 pub generation: Option<GenerationOptions>,
342 #[serde(skip_serializing_if = "Option::is_none")]
344 pub cache_policy: Option<CachePolicy>,
345
346 #[serde(skip_serializing_if = "Option::is_none")]
350 pub pricing: Option<ModelPricing>,
351
352 #[serde(skip_serializing_if = "Option::is_none")]
356 pub provider_options: Option<serde_json::Value>,
357 #[serde(skip_serializing_if = "Option::is_none")]
359 pub http: Option<HttpOptions>,
360}
361
362impl ModelRef {
363 pub fn new(
365 id: ModelId,
366 provider: ProviderId,
367 route: RouteId,
368 base_url: impl Into<String>,
369 limits: ModelLimits,
370 ) -> Self {
371 Self {
372 id,
373 provider,
374 route,
375 display_name: None,
376 base_url: base_url.into(),
377 api_key: None,
378 headers: None,
379 query_params: None,
380 limits,
381 capabilities: ModelCapabilities::default(),
382 generation: None,
383 cache_policy: None,
384 pricing: None,
385 provider_options: None,
386 http: None,
387 }
388 }
389
390 pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
392 self.display_name = Some(name.into());
393 self
394 }
395
396 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
398 self.api_key = Some(key.into());
399 self
400 }
401
402 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
404 self.headers
405 .get_or_insert_with(HashMap::new)
406 .insert(key.into(), value.into());
407 self
408 }
409
410 pub fn with_query_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
412 self.query_params
413 .get_or_insert_with(HashMap::new)
414 .insert(key.into(), value.into());
415 self
416 }
417
418 pub fn with_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
420 self.capabilities = capabilities;
421 self
422 }
423
424 pub fn with_generation(mut self, generation: GenerationOptions) -> Self {
426 self.generation = Some(generation);
427 self
428 }
429
430 pub fn with_cache_policy(mut self, policy: CachePolicy) -> Self {
432 self.cache_policy = Some(policy);
433 self
434 }
435
436 pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
438 self.pricing = Some(pricing);
439 self
440 }
441
442 pub fn with_provider_options(mut self, options: serde_json::Value) -> Self {
444 self.provider_options = Some(options);
445 self
446 }
447
448 pub fn with_http(mut self, http: HttpOptions) -> Self {
450 self.http = Some(http);
451 self
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal model built only from the required constructor arguments.
    fn sample_model() -> ModelRef {
        ModelRef::new(
            ModelId::new("claude-sonnet-4-20250514"),
            ProviderId::new("anthropic"),
            RouteId::new("anthropic-messages"),
            "https://api.anthropic.com/v1",
            ModelLimits {
                context_window: 200_000,
                max_output_tokens: 8192,
            },
        )
    }

    #[test]
    fn test_new_has_required_fields() {
        let m = sample_model();
        assert_eq!(m.id.as_str(), "claude-sonnet-4-20250514");
        assert_eq!(m.provider.as_str(), "anthropic");
        assert_eq!(m.route.as_str(), "anthropic-messages");
        assert_eq!(m.base_url, "https://api.anthropic.com/v1");
        assert_eq!(m.limits.context_window, 200_000);
        assert_eq!(m.limits.max_output_tokens, 8192);
    }

    #[test]
    fn test_new_optional_fields_are_none() {
        let m = sample_model();
        assert_eq!(m.display_name, None);
        assert_eq!(m.api_key, None);
        assert_eq!(m.headers, None);
        assert_eq!(m.query_params, None);
        assert_eq!(m.generation, None);
        assert_eq!(m.cache_policy, None);
        assert_eq!(m.pricing, None);
        assert_eq!(m.provider_options, None);
        assert_eq!(m.http, None);
    }

    #[test]
    fn test_builder_chain() {
        let m = sample_model()
            .with_display_name("Claude Sonnet 4")
            .with_api_key("sk-ant-xxx")
            .with_header("x-custom", "value")
            .with_query_param("version", "1")
            .with_generation(GenerationOptions::new().with_max_tokens(4096))
            .with_cache_policy(CachePolicy::Auto)
            .with_pricing(ModelPricing {
                input: 3.0,
                output: 15.0,
                cache_read: 0.30,
                cache_write: 3.75,
            });

        assert_eq!(m.display_name.as_deref(), Some("Claude Sonnet 4"));
        assert_eq!(m.api_key.as_deref(), Some("sk-ant-xxx"));
        assert_eq!(
            m.headers
                .as_ref()
                .and_then(|h| h.get("x-custom"))
                .map(String::as_str),
            Some("value")
        );
        assert_eq!(
            m.query_params
                .as_ref()
                .and_then(|q| q.get("version"))
                .map(String::as_str),
            Some("1")
        );
        assert_eq!(m.generation.as_ref().unwrap().max_tokens, Some(4096));
        assert_eq!(m.cache_policy, Some(CachePolicy::Auto));
        assert_eq!(m.pricing.as_ref().unwrap().input, 3.0);
    }

    #[test]
    fn test_with_header_accumulates_and_overwrites() {
        let m = sample_model()
            .with_header("a", "1")
            .with_header("b", "2")
            .with_header("a", "3");
        let headers = m.headers.expect("headers were set by the builder");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("a").map(String::as_str), Some("3"));
        assert_eq!(headers.get("b").map(String::as_str), Some("2"));
    }

    #[test]
    fn test_capabilities_default() {
        let m = sample_model();
        assert!(m.capabilities.supports_modality(InputModality::Text));
        assert!(!m.capabilities.supports_modality(InputModality::Image));
        assert!(m.capabilities.tool_calls);
        assert!(!m.capabilities.supports_thinking());
    }

    #[test]
    fn test_capabilities_with_thinking() {
        let m = sample_model().with_capabilities(ModelCapabilities {
            input_modalities: vec![InputModality::Text, InputModality::Image],
            tool_calls: true,
            streaming_tool_input: true,
            structured_output: false,
            prompt_caching: true,
            thinking: Some(ThinkingConfig {
                mode: ThinkingMode::Adaptive,
                default_budget: None,
                min_effort: None,
                max_effort: None,
            }),
        });

        assert!(m.capabilities.supports_thinking());
        assert!(m.capabilities.supports_modality(InputModality::Image));
        assert!(m.capabilities.streaming_tool_input);
    }

    #[test]
    fn test_serde_roundtrip_minimal() {
        let m = sample_model();
        let json = serde_json::to_string(&m).unwrap();
        let restored: ModelRef = serde_json::from_str(&json).unwrap();
        assert_eq!(m, restored);
    }

    #[test]
    fn test_serde_roundtrip_full() {
        let m = sample_model()
            .with_display_name("Claude Sonnet 4")
            .with_api_key("sk-test")
            .with_capabilities(ModelCapabilities {
                input_modalities: vec![InputModality::Text, InputModality::Image],
                tool_calls: true,
                streaming_tool_input: true,
                structured_output: true,
                prompt_caching: true,
                thinking: Some(ThinkingConfig {
                    mode: ThinkingMode::Budget,
                    default_budget: Some(10000),
                    min_effort: Some(ReasoningEffort::Low),
                    max_effort: Some(ReasoningEffort::High),
                }),
            })
            .with_generation(GenerationOptions::new().with_max_tokens(4096).with_temperature(0.7))
            .with_cache_policy(CachePolicy::Auto)
            .with_pricing(ModelPricing {
                input: 3.0,
                output: 15.0,
                cache_read: 0.30,
                cache_write: 3.75,
            })
            .with_provider_options(serde_json::json!({"region": "us-east-1"}))
            .with_http(HttpOptions::new().with_header("x-extra", "val"));

        let json = serde_json::to_string_pretty(&m).unwrap();
        let restored: ModelRef = serde_json::from_str(&json).unwrap();
        assert_eq!(m, restored);
    }

    #[test]
    fn test_serde_skips_none_fields() {
        let m = sample_model();
        let value = serde_json::to_value(&m).unwrap();
        let obj = value
            .as_object()
            .expect("ModelRef serializes as a JSON object");
        // Check object keys rather than substrings: `base_url` contains "http",
        // so a substring test for the `http` key would always fail.
        for key in [
            "display_name",
            "api_key",
            "headers",
            "query_params",
            "generation",
            "cache_policy",
            "pricing",
            "provider_options",
            "http",
        ] {
            assert!(
                !obj.contains_key(key),
                "unexpected key `{}` in {}",
                key,
                value
            );
        }
    }
}