Skip to main content

katu_llm/
model.rs

1//! # katu_llm::model
2//!
3//! ## 职责
4//! 定义模型引用 (`ModelRef`) 及其组成类型:能力描述、限制、定价、思考配置等。
5//!
6//! ## 对外接口
7//! - `ModelRef` — 可执行的模型引用(身份 + 连接 + 能力 + 默认参数)
8//! - `ModelLimits` — token 上限
9//! - `ModelPricing` — 费率定义
10//! - `ModelCapabilities` — 功能标志
11//! - `InputModality` — 支持的输入模态
12//! - `ThinkingMode` — 思考控制模式
13//! - `ThinkingConfig` — 思考能力配置
14//! - `ReasoningEffort` — 推理强度级别
15
16use std::collections::HashMap;
17
18use serde::{Deserialize, Serialize};
19
20use katu_core::{ModelId, ProviderId, RouteId};
21
22use crate::cache::CachePolicy;
23use katu_core::GenerationOptions;
24use crate::http::HttpOptions;
25
26// ---------------------------------------------------------------------------
27// InputModality
28// ---------------------------------------------------------------------------
29
30/// 模型支持的输入模态。
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum InputModality {
34    /// 文本输入
35    Text,
36    /// 图像输入
37    Image,
38    /// 音频输入
39    Audio,
40    /// 视频输入
41    Video,
42}
43
44// ---------------------------------------------------------------------------
45// ReasoningEffort
46// ---------------------------------------------------------------------------
47
48/// 推理强度级别。
49///
50/// 对应 OpenAI `reasoning_effort` 和 Anthropic `thinking` 级别的统一抽象。
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
52#[serde(rename_all = "snake_case")]
53pub enum ReasoningEffort {
54    None,
55    Low,
56    Medium,
57    High,
58    XHigh,
59    Max
60}
61
62// ---------------------------------------------------------------------------
63// ThinkingMode
64// ---------------------------------------------------------------------------
65
66/// 思考/推理的控制模式。
67///
68/// 不同 provider 使用不同机制控制推理行为:
69/// - Anthropic: adaptive 或 budget(指定 token 预算)
70/// - OpenAI: effort 级别(low/medium/high)
71/// - 其他 provider: 可能不支持
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
73#[serde(rename_all = "snake_case")]
74pub enum ThinkingMode {
75    /// 自适应思考(Anthropic adaptive thinking)
76    Adaptive,
77    /// Budget 模式(指定 token 预算上限)
78    Budget,
79    /// Effort 级别模式(OpenAI reasoning_effort)
80    Effort,
81}
82
83// ---------------------------------------------------------------------------
84// ThinkingConfig
85// ---------------------------------------------------------------------------
86
87/// 模型的思考/推理能力配置。
88///
89/// 描述模型如何支持推理,以及推理的控制参数。
90/// 仅在 `ModelCapabilities::thinking` 为 `Some` 时有意义。
91///
92/// # Examples
93/// ```
94/// use katu_llm::model::{ThinkingConfig, ThinkingMode, ReasoningEffort};
95///
96/// // Anthropic adaptive
97/// let config = ThinkingConfig {
98///     mode: ThinkingMode::Adaptive,
99///     default_budget: None,
100///     min_effort: None,
101///     max_effort: None,
102/// };
103///
104/// // OpenAI effort-based
105/// let config = ThinkingConfig {
106///     mode: ThinkingMode::Effort,
107///     default_budget: None,
108///     min_effort: Some(ReasoningEffort::Low),
109///     max_effort: Some(ReasoningEffort::High),
110/// };
111/// ```
112#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
113pub struct ThinkingConfig {
114    /// 思考控制模式
115    pub mode: ThinkingMode,
116    /// 默认思考 token 预算(仅 Budget 模式有意义)
117    pub default_budget: Option<u32>,
118    /// 支持的最低 effort 级别
119    pub min_effort: Option<ReasoningEffort>,
120    /// 支持的最高 effort 级别
121    pub max_effort: Option<ReasoningEffort>,
122}
123
124// ---------------------------------------------------------------------------
125// ModelCapabilities
126// ---------------------------------------------------------------------------
127
128/// 模型功能标志。
129///
130/// 描述模型支持哪些特性,供 Agent loop 和 Provider 适配层参考。
131///
132/// # Examples
133/// ```
134/// use katu_llm::model::{ModelCapabilities, InputModality};
135///
136/// let caps = ModelCapabilities {
137///     input_modalities: vec![InputModality::Text, InputModality::Image],
138///     tool_calls: true,
139///     streaming_tool_input: true,
140///     structured_output: true,
141///     prompt_caching: true,
142///     thinking: None,
143/// };
144/// assert!(caps.supports_modality(InputModality::Image));
145/// ```
146#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
147pub struct ModelCapabilities {
148    /// 支持的输入模态列表
149    pub input_modalities: Vec<InputModality>,
150    /// 是否支持工具调用
151    pub tool_calls: bool,
152    /// 是否支持流式工具参数输入
153    pub streaming_tool_input: bool,
154    /// 是否支持结构化输出(JSON mode / response_format)
155    pub structured_output: bool,
156    /// 是否支持 prompt caching
157    pub prompt_caching: bool,
158    /// 思考/推理能力配置,`None` 表示不支持
159    pub thinking: Option<ThinkingConfig>,
160}
161
162impl ModelCapabilities {
163    /// 检查是否支持指定输入模态。
164    pub fn supports_modality(&self, modality: InputModality) -> bool {
165        self.input_modalities.contains(&modality)
166    }
167
168    /// 检查是否支持推理/思考。
169    pub fn supports_thinking(&self) -> bool {
170        self.thinking.is_some()
171    }
172}
173
174impl Default for ModelCapabilities {
175    fn default() -> Self {
176        Self {
177            input_modalities: vec![InputModality::Text],
178            tool_calls: true,
179            streaming_tool_input: false,
180            structured_output: false,
181            prompt_caching: false,
182            thinking: None,
183        }
184    }
185}
186
187// ---------------------------------------------------------------------------
188// ModelLimits
189// ---------------------------------------------------------------------------
190
191/// 模型 token 上限。
192///
193/// # Examples
194/// ```
195/// use katu_llm::ModelLimits;
196///
197/// let limits = ModelLimits {
198///     context_window: 200_000,
199///     max_output_tokens: 8192,
200/// };
201/// ```
202#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
203pub struct ModelLimits {
204    /// 上下文窗口大小(input + output 总量上限)
205    pub context_window: u32,
206    /// 最大输出 token 数
207    pub max_output_tokens: u32,
208}
209
210// ---------------------------------------------------------------------------
211// ModelPricing
212// ---------------------------------------------------------------------------
213
214/// 模型费率定义(单位:美元 / 百万 token)。
215///
216/// 用于从 `Usage` 计算 `Cost`。
217///
218/// # Examples
219/// ```
220/// use katu_llm::ModelPricing;
221///
222/// let pricing = ModelPricing {
223///     input: 3.0,        // $3 / M input tokens
224///     output: 15.0,      // $15 / M output tokens
225///     cache_read: 0.30,  // $0.30 / M cache read tokens
226///     cache_write: 3.75, // $3.75 / M cache write tokens
227/// };
228/// ```
229#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
230pub struct ModelPricing {
231    /// 输入 token 费率($/M tokens)
232    pub input: f64,
233    /// 输出 token 费率($/M tokens)
234    pub output: f64,
235    /// 缓存读取费率($/M tokens)
236    pub cache_read: f64,
237    /// 缓存写入费率($/M tokens)
238    pub cache_write: f64,
239}
240
241// ---------------------------------------------------------------------------
242// ModelRef
243// ---------------------------------------------------------------------------
244
245/// 可执行的模型引用。
246///
247/// `ModelRef` 是从"选择模型"到"发出请求"所需全部信息的聚合体:
248/// - **身份**:provider/model/route 三元组 + 可选显示名
249/// - **连接**:base_url, api_key, 额外 headers/query_params
250/// - **能力**:token 限制、功能标志、输入模态、思考配置
251/// - **默认参数**:生成选项、缓存策略(被 Request 级覆盖)
252/// - **定价**:用于 Usage → Cost 计算
253/// - **Provider 私有**:非标选项(如 Bedrock region, Vertex project_id)
254///
255/// # 参数合并链
256/// ```text
257/// LlmRequest.generation > Agent 配置 > ModelRef.generation > Route defaults
258/// ```
259///
260/// # Examples
261/// ```
262/// use katu_core::{ModelId, ProviderId, RouteId};
263/// use katu_llm::model::*;
264/// use katu_llm::GenerationOptions;
265///
266/// let model = ModelRef::new(
267///     ModelId::new("claude-sonnet-4-20250514"),
268///     ProviderId::new("anthropic"),
269///     RouteId::new("anthropic-messages"),
270///     "https://api.anthropic.com/v1",
271///     ModelLimits {
272///         context_window: 200_000,
273///         max_output_tokens: 8192,
274///     },
275/// )
276/// .with_display_name("Claude Sonnet 4")
277/// .with_api_key("sk-ant-xxx")
278/// .with_capabilities(ModelCapabilities {
279///     input_modalities: vec![InputModality::Text, InputModality::Image],
280///     tool_calls: true,
281///     streaming_tool_input: true,
282///     structured_output: false,
283///     prompt_caching: true,
284///     thinking: Some(ThinkingConfig {
285///         mode: ThinkingMode::Adaptive,
286///         default_budget: None,
287///         min_effort: None,
288///         max_effort: None,
289///     }),
290/// })
291/// .with_pricing(ModelPricing {
292///     input: 3.0,
293///     output: 15.0,
294///     cache_read: 0.30,
295///     cache_write: 3.75,
296/// })
297/// .with_generation(GenerationOptions::new().with_max_tokens(4096));
298///
299/// assert_eq!(model.id.as_str(), "claude-sonnet-4-20250514");
300/// assert!(model.capabilities.supports_thinking());
301/// ```
302#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
303pub struct ModelRef {
304    // ─── 身份标识 ───
305
306    /// 模型 ID(发送给 API 的 wire 值)
307    pub id: ModelId,
308    /// Provider 标识
309    pub provider: ProviderId,
310    /// 路由标识(决定使用哪个 Protocol 转换器)
311    pub route: RouteId,
312    /// 人类可读名称
313    #[serde(skip_serializing_if = "Option::is_none")]
314    pub display_name: Option<String>,
315
316    // ─── 连接信息 ───
317
318    /// API Base URL
319    pub base_url: String,
320    /// API Key
321    #[serde(skip_serializing_if = "Option::is_none")]
322    pub api_key: Option<String>,
323    /// 额外的固定请求头
324    #[serde(skip_serializing_if = "Option::is_none")]
325    pub headers: Option<HashMap<String, String>>,
326    /// URL 查询参数(如 Azure api-version)
327    #[serde(skip_serializing_if = "Option::is_none")]
328    pub query_params: Option<HashMap<String, String>>,
329
330    // ─── 能力与限制 ───
331
332    /// Token 上限
333    pub limits: ModelLimits,
334    /// 功能标志
335    pub capabilities: ModelCapabilities,
336
337    // ─── 默认参数 ───
338
339    /// 模型级默认生成参数(被 Request 级覆盖)
340    #[serde(skip_serializing_if = "Option::is_none")]
341    pub generation: Option<GenerationOptions>,
342    /// 默认缓存策略
343    #[serde(skip_serializing_if = "Option::is_none")]
344    pub cache_policy: Option<CachePolicy>,
345
346    // ─── 定价 ───
347
348    /// 费率(用于 Usage → Cost 计算)
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub pricing: Option<ModelPricing>,
351
352    // ─── Provider 私有 ───
353
354    /// Provider 特有的非标选项
355    #[serde(skip_serializing_if = "Option::is_none")]
356    pub provider_options: Option<serde_json::Value>,
357    /// HTTP 传输层覆写
358    #[serde(skip_serializing_if = "Option::is_none")]
359    pub http: Option<HttpOptions>,
360}
361
362impl ModelRef {
363    /// 创建一个 ModelRef,仅包含必需字段。
364    pub fn new(
365        id: ModelId,
366        provider: ProviderId,
367        route: RouteId,
368        base_url: impl Into<String>,
369        limits: ModelLimits,
370    ) -> Self {
371        Self {
372            id,
373            provider,
374            route,
375            display_name: None,
376            base_url: base_url.into(),
377            api_key: None,
378            headers: None,
379            query_params: None,
380            limits,
381            capabilities: ModelCapabilities::default(),
382            generation: None,
383            cache_policy: None,
384            pricing: None,
385            provider_options: None,
386            http: None,
387        }
388    }
389
390    /// 设置显示名称。
391    pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
392        self.display_name = Some(name.into());
393        self
394    }
395
396    /// 设置 API key。
397    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
398        self.api_key = Some(key.into());
399        self
400    }
401
402    /// 添加一个额外请求头。
403    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
404        self.headers
405            .get_or_insert_with(HashMap::new)
406            .insert(key.into(), value.into());
407        self
408    }
409
410    /// 添加一个查询参数。
411    pub fn with_query_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
412        self.query_params
413            .get_or_insert_with(HashMap::new)
414            .insert(key.into(), value.into());
415        self
416    }
417
418    /// 设置模型能力。
419    pub fn with_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
420        self.capabilities = capabilities;
421        self
422    }
423
424    /// 设置默认生成参数。
425    pub fn with_generation(mut self, generation: GenerationOptions) -> Self {
426        self.generation = Some(generation);
427        self
428    }
429
430    /// 设置缓存策略。
431    pub fn with_cache_policy(mut self, policy: CachePolicy) -> Self {
432        self.cache_policy = Some(policy);
433        self
434    }
435
436    /// 设置定价。
437    pub fn with_pricing(mut self, pricing: ModelPricing) -> Self {
438        self.pricing = Some(pricing);
439        self
440    }
441
442    /// 设置 provider 私有选项。
443    pub fn with_provider_options(mut self, options: serde_json::Value) -> Self {
444        self.provider_options = Some(options);
445        self
446    }
447
448    /// 设置 HTTP 覆写选项。
449    pub fn with_http(mut self, http: HttpOptions) -> Self {
450        self.http = Some(http);
451        self
452    }
453}
454
455// ---------------------------------------------------------------------------
456// Tests
457// ---------------------------------------------------------------------------
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal Anthropic model reference with only the required fields set.
    fn sample_model() -> ModelRef {
        ModelRef::new(
            ModelId::new("claude-sonnet-4-20250514"),
            ProviderId::new("anthropic"),
            RouteId::new("anthropic-messages"),
            "https://api.anthropic.com/v1",
            ModelLimits {
                context_window: 200_000,
                max_output_tokens: 8192,
            },
        )
    }

    #[test]
    fn test_new_has_required_fields() {
        let m = sample_model();
        assert_eq!(m.id.as_str(), "claude-sonnet-4-20250514");
        assert_eq!(m.provider.as_str(), "anthropic");
        assert_eq!(m.route.as_str(), "anthropic-messages");
        assert_eq!(m.base_url, "https://api.anthropic.com/v1");
        assert_eq!(m.limits.context_window, 200_000);
        assert_eq!(m.limits.max_output_tokens, 8192);
    }

    #[test]
    fn test_new_optional_fields_are_none() {
        let m = sample_model();
        assert_eq!(m.display_name, None);
        assert_eq!(m.api_key, None);
        assert_eq!(m.headers, None);
        assert_eq!(m.query_params, None);
        assert_eq!(m.generation, None);
        assert_eq!(m.cache_policy, None);
        assert_eq!(m.pricing, None);
        assert_eq!(m.provider_options, None);
        assert_eq!(m.http, None);
        assert_eq!(m.capabilities, ModelCapabilities::default());
    }

    #[test]
    fn test_builder_chain() {
        let m = sample_model()
            .with_display_name("Claude Sonnet 4")
            .with_api_key("sk-ant-xxx")
            .with_header("x-custom", "value")
            .with_query_param("version", "1")
            .with_generation(GenerationOptions::new().with_max_tokens(4096))
            .with_cache_policy(CachePolicy::Auto)
            .with_pricing(ModelPricing {
                input: 3.0,
                output: 15.0,
                cache_read: 0.30,
                cache_write: 3.75,
            });

        assert_eq!(m.display_name.as_deref(), Some("Claude Sonnet 4"));
        assert_eq!(m.api_key.as_deref(), Some("sk-ant-xxx"));
        assert_eq!(
            m.headers.as_ref().unwrap().get("x-custom").map(String::as_str),
            Some("value")
        );
        assert_eq!(
            m.query_params.as_ref().unwrap().get("version").map(String::as_str),
            Some("1")
        );
        assert_eq!(m.generation.as_ref().unwrap().max_tokens, Some(4096));
        assert_eq!(m.cache_policy, Some(CachePolicy::Auto));
        assert_eq!(m.pricing.as_ref().unwrap().input, 3.0);
    }

    #[test]
    fn test_with_header_overwrites_existing_key() {
        let m = sample_model()
            .with_header("x-custom", "first")
            .with_header("x-custom", "second")
            .with_header("x-other", "other");

        let headers = m.headers.expect("headers were set");
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.get("x-custom").map(String::as_str), Some("second"));
        assert_eq!(headers.get("x-other").map(String::as_str), Some("other"));
    }

    #[test]
    fn test_capabilities_default() {
        let m = sample_model();
        assert!(m.capabilities.supports_modality(InputModality::Text));
        assert!(!m.capabilities.supports_modality(InputModality::Image));
        assert!(m.capabilities.tool_calls);
        assert!(!m.capabilities.streaming_tool_input);
        assert!(!m.capabilities.structured_output);
        assert!(!m.capabilities.prompt_caching);
        assert!(!m.capabilities.supports_thinking());
    }

    #[test]
    fn test_capabilities_with_thinking() {
        let m = sample_model().with_capabilities(ModelCapabilities {
            input_modalities: vec![InputModality::Text, InputModality::Image],
            tool_calls: true,
            streaming_tool_input: true,
            structured_output: false,
            prompt_caching: true,
            thinking: Some(ThinkingConfig {
                mode: ThinkingMode::Adaptive,
                default_budget: None,
                min_effort: None,
                max_effort: None,
            }),
        });

        assert!(m.capabilities.supports_thinking());
        assert!(m.capabilities.supports_modality(InputModality::Image));
        assert!(m.capabilities.streaming_tool_input);
    }

    #[test]
    fn test_enum_wire_names() {
        assert_eq!(serde_json::to_string(&InputModality::Image).unwrap(), "\"image\"");
        assert_eq!(serde_json::to_string(&ThinkingMode::Adaptive).unwrap(), "\"adaptive\"");
        // `rename_all = "snake_case"` inserts an underscore before the inner capital.
        assert_eq!(serde_json::to_string(&ReasoningEffort::XHigh).unwrap(), "\"x_high\"");
        let parsed: ReasoningEffort = serde_json::from_str("\"max\"").unwrap();
        assert_eq!(parsed, ReasoningEffort::Max);
    }

    #[test]
    fn test_serde_roundtrip_minimal() {
        let m = sample_model();
        let json = serde_json::to_string(&m).unwrap();
        let restored: ModelRef = serde_json::from_str(&json).unwrap();
        // Omitted optional fields must come back as `None`, so the whole value matches.
        assert_eq!(m, restored);
    }

    #[test]
    fn test_serde_roundtrip_full() {
        let m = sample_model()
            .with_display_name("Claude Sonnet 4")
            .with_api_key("sk-test")
            .with_capabilities(ModelCapabilities {
                input_modalities: vec![InputModality::Text, InputModality::Image],
                tool_calls: true,
                streaming_tool_input: true,
                structured_output: true,
                prompt_caching: true,
                thinking: Some(ThinkingConfig {
                    mode: ThinkingMode::Budget,
                    default_budget: Some(10000),
                    min_effort: Some(ReasoningEffort::Low),
                    max_effort: Some(ReasoningEffort::High),
                }),
            })
            .with_generation(GenerationOptions::new().with_max_tokens(4096).with_temperature(0.7))
            .with_cache_policy(CachePolicy::Auto)
            .with_pricing(ModelPricing {
                input: 3.0,
                output: 15.0,
                cache_read: 0.30,
                cache_write: 3.75,
            })
            .with_provider_options(serde_json::json!({"region": "us-east-1"}))
            .with_http(HttpOptions::new().with_header("x-extra", "val"));

        let json = serde_json::to_string_pretty(&m).unwrap();
        let restored: ModelRef = serde_json::from_str(&json).unwrap();
        assert_eq!(m, restored);
    }

    #[test]
    fn test_serde_skips_none_fields() {
        let m = sample_model();
        // Inspect object keys rather than substrings: `base_url` contains "http",
        // so a substring check for the `http` field would be a false positive.
        let value = serde_json::to_value(&m).unwrap();
        let obj = value.as_object().expect("ModelRef serializes to a JSON object");

        for key in [
            "display_name",
            "api_key",
            "headers",
            "query_params",
            "generation",
            "cache_policy",
            "pricing",
            "provider_options",
            "http",
        ] {
            assert!(!obj.contains_key(key), "`{}` should be skipped when None", key);
        }
        for key in ["id", "provider", "route", "base_url", "limits", "capabilities"] {
            assert!(obj.contains_key(key), "`{}` should always be serialized", key);
        }
    }
}