Skip to main content

katu_core/
generation.rs

1//! # katu_core::generation
2//!
3//! ## 职责
4//! 定义 provider 无关的 LLM 生成参数。
5//!
6//! ## 对外接口
7//! - `GenerationOptions` — 生成控制参数
8
9use serde::{Deserialize, Serialize};
10
11// ---------------------------------------------------------------------------
12// GenerationOptions
13// ---------------------------------------------------------------------------
14
15/// Provider 无关的 LLM 生成参数。
16///
17/// 所有字段均为 `Option`,`None` 表示使用 provider/model 默认值。
18/// 支持多级合并:Request 级 > Agent 级 > Model 级 > Provider 默认。
19///
20/// # Examples
21/// ```
22/// use katu_core::GenerationOptions;
23///
24/// let opts = GenerationOptions::new()
25///     .with_max_tokens(4096)
26///     .with_temperature(0.7);
27/// assert_eq!(opts.max_tokens, Some(4096));
28/// assert_eq!(opts.temperature, Some(0.7));
29/// ```
30#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
31pub struct GenerationOptions {
32    /// 最大输出 token 数
33    pub max_tokens: Option<u32>,
34    /// 采样温度 (0.0 = 确定性, 2.0 = 最大随机)
35    pub temperature: Option<f32>,
36    /// Top-p (nucleus) 采样阈值
37    pub top_p: Option<f32>,
38    /// Top-k 采样(部分 provider 支持)
39    pub top_k: Option<u32>,
40    /// 频率惩罚 (-2.0 ~ 2.0)
41    pub frequency_penalty: Option<f32>,
42    /// 存在惩罚 (-2.0 ~ 2.0)
43    pub presence_penalty: Option<f32>,
44    /// 停止序列
45    pub stop: Option<Vec<String>>,
46    /// 随机种子(可复现生成)
47    pub seed: Option<u64>,
48}
49
50impl GenerationOptions {
51    /// 创建空的 GenerationOptions(所有字段为 None)。
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    /// 设置 max_tokens。
57    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
58        self.max_tokens = Some(max_tokens);
59        self
60    }
61
62    /// 设置 temperature。
63    pub fn with_temperature(mut self, temperature: f32) -> Self {
64        self.temperature = Some(temperature);
65        self
66    }
67
68    /// 设置 top_p。
69    pub fn with_top_p(mut self, top_p: f32) -> Self {
70        self.top_p = Some(top_p);
71        self
72    }
73
74    /// 设置 top_k。
75    pub fn with_top_k(mut self, top_k: u32) -> Self {
76        self.top_k = Some(top_k);
77        self
78    }
79
80    /// 设置 frequency_penalty。
81    pub fn with_frequency_penalty(mut self, penalty: f32) -> Self {
82        self.frequency_penalty = Some(penalty);
83        self
84    }
85
86    /// 设置 presence_penalty。
87    pub fn with_presence_penalty(mut self, penalty: f32) -> Self {
88        self.presence_penalty = Some(penalty);
89        self
90    }
91
92    /// 设置停止序列。
93    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
94        self.stop = Some(stop);
95        self
96    }
97
98    /// 设置随机种子。
99    pub fn with_seed(mut self, seed: u64) -> Self {
100        self.seed = Some(seed);
101        self
102    }
103
104    /// 合并两个 GenerationOptions,`other` 的非 None 值覆盖 `self`。
105    ///
106    /// 用于实现参数合并链:Request > Agent > Model > Provider。
107    pub fn merge(&self, other: &GenerationOptions) -> GenerationOptions {
108        GenerationOptions {
109            max_tokens: other.max_tokens.or(self.max_tokens),
110            temperature: other.temperature.or(self.temperature),
111            top_p: other.top_p.or(self.top_p),
112            top_k: other.top_k.or(self.top_k),
113            frequency_penalty: other.frequency_penalty.or(self.frequency_penalty),
114            presence_penalty: other.presence_penalty.or(self.presence_penalty),
115            stop: other.stop.clone().or_else(|| self.stop.clone()),
116            seed: other.seed.or(self.seed),
117        }
118    }
119}
120
121// ---------------------------------------------------------------------------
122// Tests
123// ---------------------------------------------------------------------------
124
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed value leaves every field unset.
    #[test]
    fn test_default_is_all_none() {
        let opts = GenerationOptions::new();
        assert_eq!(opts.max_tokens, None);
        assert_eq!(opts.temperature, None);
        assert_eq!(opts.top_p, None);
        assert_eq!(opts.top_k, None);
        assert_eq!(opts.frequency_penalty, None);
        assert_eq!(opts.presence_penalty, None);
        assert_eq!(opts.stop, None);
        assert_eq!(opts.seed, None);
    }

    /// Each builder method stores its argument in the matching field.
    #[test]
    fn test_builder_methods() {
        let opts = GenerationOptions::new()
            .with_max_tokens(4096)
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_top_k(40)
            .with_frequency_penalty(0.5)
            .with_presence_penalty(-0.5)
            .with_stop(vec!["END".into()])
            .with_seed(42);

        assert_eq!(opts.max_tokens, Some(4096));
        assert_eq!(opts.temperature, Some(0.7));
        assert_eq!(opts.top_p, Some(0.9));
        assert_eq!(opts.top_k, Some(40));
        assert_eq!(opts.frequency_penalty, Some(0.5));
        assert_eq!(opts.presence_penalty, Some(-0.5));
        assert_eq!(opts.stop, Some(vec!["END".to_string()]));
        assert_eq!(opts.seed, Some(42));
    }

    /// `Some` values in the argument replace or extend the receiver's values.
    #[test]
    fn test_merge_other_overrides_self() {
        let base = GenerationOptions::new()
            .with_max_tokens(1024)
            .with_temperature(0.5);

        let override_opts = GenerationOptions::new()
            .with_max_tokens(4096)
            .with_top_p(0.9);

        let merged = base.merge(&override_opts);
        assert_eq!(merged.max_tokens, Some(4096)); // overridden
        assert_eq!(merged.temperature, Some(0.5)); // kept from base
        assert_eq!(merged.top_p, Some(0.9)); // new from override
    }

    /// `None` values in the argument never erase the receiver's values.
    #[test]
    fn test_merge_none_does_not_override() {
        let base = GenerationOptions::new()
            .with_max_tokens(2048)
            .with_seed(123);

        let empty = GenerationOptions::new();
        let merged = base.merge(&empty);
        assert_eq!(merged.max_tokens, Some(2048));
        assert_eq!(merged.seed, Some(123));
        // Merging with an empty layer is the identity.
        assert_eq!(merged, base);
    }

    /// The stop list is the only non-`Copy` field, so it goes through a
    /// separate cloning branch in `merge`; cover both directions.
    #[test]
    fn test_merge_stop_sequences() {
        let base = GenerationOptions::new().with_stop(vec!["A".into()]);
        let override_opts = GenerationOptions::new().with_stop(vec!["B".into(), "C".into()]);

        let merged = base.merge(&override_opts);
        assert_eq!(merged.stop, Some(vec!["B".to_string(), "C".to_string()]));

        let merged = base.merge(&GenerationOptions::new());
        assert_eq!(merged.stop, Some(vec!["A".to_string()]));

        // `merge` borrows both operands, so they must be unchanged afterwards.
        assert_eq!(base.stop, Some(vec!["A".to_string()]));
        assert_eq!(
            override_opts.stop,
            Some(vec!["B".to_string(), "C".to_string()])
        );
    }

    /// Chained merges implement Request > Agent > Model > Provider.
    #[test]
    fn test_merge_chain_precedence() {
        let provider = GenerationOptions::new()
            .with_max_tokens(256)
            .with_temperature(1.0)
            .with_seed(1);
        let model = GenerationOptions::new().with_max_tokens(1024).with_top_k(40);
        let agent = GenerationOptions::new().with_temperature(0.2);
        let request = GenerationOptions::new().with_max_tokens(4096);

        let effective = provider.merge(&model).merge(&agent).merge(&request);
        assert_eq!(effective.max_tokens, Some(4096)); // request wins
        assert_eq!(effective.temperature, Some(0.2)); // agent beats provider
        assert_eq!(effective.top_k, Some(40)); // only the model set it
        assert_eq!(effective.seed, Some(1)); // falls through to provider
        assert_eq!(effective.top_p, None); // nobody set it
    }

    /// A populated value survives a JSON round trip unchanged.
    #[test]
    fn test_serde_roundtrip() {
        let opts = GenerationOptions::new()
            .with_max_tokens(4096)
            .with_temperature(0.8)
            .with_stop(vec!["<|end|>".into(), "STOP".into()]);

        let json = serde_json::to_string(&opts).unwrap();
        let restored: GenerationOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(opts, restored);
    }

    /// Unset fields come back as `None`, whether the JSON was produced by
    /// this type, omits the keys entirely, or spells them out as `null`.
    #[test]
    fn test_serde_skips_none_fields() {
        let opts = GenerationOptions::new().with_max_tokens(100);
        let json = serde_json::to_string(&opts).unwrap();
        let restored: GenerationOptions = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.max_tokens, Some(100));
        assert_eq!(restored.temperature, None);

        // Hand-written input that leaves the optional keys out altogether.
        let sparse: GenerationOptions = serde_json::from_str(r#"{"max_tokens":100}"#).unwrap();
        assert_eq!(sparse, opts);

        // Explicit nulls are equivalent to absent keys.
        let with_null: GenerationOptions =
            serde_json::from_str(r#"{"max_tokens":100,"temperature":null,"stop":null}"#).unwrap();
        assert_eq!(with_null, opts);
    }
}