Skip to main content

key_token/
request.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use std::sync::Arc;
5use uuid::Uuid;
6
7use crate::{PricingSnapshot, UsageAccumulator};
8
9/// 请求上下文:贯穿全链路的唯一状态载体
10///
11/// # 设计说明
12/// - `usage` 字段使用 `Arc<UsageAccumulator>` 实现共享状态,Clone 时会共享同一个用量累积器
13/// - 通过 `add_output_tokens()` 和 `set_input_tokens()` 方法安全地更新用量
14/// - 使用 `usage_snapshot()` 获取当前用量快照
15/// - `provider` 字段在路由确定后被设置,用于精确的定价查询
16#[derive(Debug, Clone)]
17pub struct RequestContext {
18    pub request_id: Uuid,
19    pub user_id: Uuid,
20    pub tenant_id: Uuid,
21    pub produce_ai_key_id: Uuid,
22    pub model: String,
23    /// Provider 名称(路由确定后设置)
24    pub provider: Option<String>,
25    pub messages: Vec<Message>,
26    pub stream: bool,
27    pub pricing_snapshot: PricingSnapshot, // 请求开始时固化
28    usage: Arc<UsageAccumulator>,          // streaming 中累积(共享状态)
29    pub started_at: DateTime<Utc>,
30}
31
32impl RequestContext {
33    pub fn new(
34        user_id: Uuid,
35        tenant_id: Uuid,
36        produce_ai_key_id: Uuid,
37        model: impl Into<String>,
38        messages: Vec<Message>,
39        stream: bool,
40        pricing_snapshot: PricingSnapshot,
41    ) -> Self {
42        Self {
43            request_id: Uuid::new_v4(),
44            user_id,
45            tenant_id,
46            produce_ai_key_id,
47            model: model.into(),
48            provider: None,
49            messages,
50            stream,
51            pricing_snapshot,
52            usage: Arc::new(UsageAccumulator::new()),
53            started_at: Utc::now(),
54        }
55    }
56
57    /// 设置 Provider(路由确定后调用)
58    pub fn set_provider(&mut self, provider: impl Into<String>) {
59        self.provider = Some(provider.into());
60    }
61
62    /// 更新定价快照(路由后根据实际 provider 更新)
63    pub fn update_pricing(&mut self, pricing: PricingSnapshot) {
64        self.pricing_snapshot = pricing;
65    }
66
67    /// 获取请求持续时间
68    pub fn duration(&self) -> chrono::Duration {
69        Utc::now() - self.started_at
70    }
71
72    /// 获取当前用量快照
73    pub fn usage_snapshot(&self) -> (u32, u32) {
74        self.usage.snapshot()
75    }
76
77    /// 添加输出 token(原子更新)
78    pub fn add_output_tokens(&self, tokens: u32) {
79        self.usage.add_output(tokens);
80    }
81
82    /// 设置输入 token(原子更新)
83    pub fn set_input_tokens(&self, tokens: u32) {
84        self.usage.set_input(tokens);
85    }
86}
87
88/// 消息角色枚举
89#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
90#[serde(rename_all = "lowercase")]
91pub enum MessageRole {
92    System,
93    #[default]
94    User,
95    Assistant,
96    Tool,
97}
98
99impl MessageRole {
100    /// 获取角色字符串表示
101    pub fn as_str(&self) -> &'static str {
102        match self {
103            MessageRole::System => "system",
104            MessageRole::User => "user",
105            MessageRole::Assistant => "assistant",
106            MessageRole::Tool => "tool",
107        }
108    }
109}
110
111impl fmt::Display for MessageRole {
112    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113        write!(f, "{}", self.as_str())
114    }
115}
116
117/// 消息结构
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct Message {
120    pub role: MessageRole,
121    pub content: String,
122}
123
124impl Message {
125    pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
126        Self {
127            role,
128            content: content.into(),
129        }
130    }
131
132    pub fn system(content: impl Into<String>) -> Self {
133        Self::new(MessageRole::System, content)
134    }
135
136    pub fn user(content: impl Into<String>) -> Self {
137        Self::new(MessageRole::User, content)
138    }
139
140    pub fn assistant(content: impl Into<String>) -> Self {
141        Self::new(MessageRole::Assistant, content)
142    }
143
144    pub fn tool(content: impl Into<String>) -> Self {
145        Self::new(MessageRole::Tool, content)
146    }
147}
148
149/// OpenAI 兼容的请求体
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct ChatCompletionRequest {
152    pub model: String,
153    pub messages: Vec<Message>,
154    #[serde(skip_serializing_if = "Option::is_none")]
155    pub stream: Option<bool>,
156    #[serde(skip_serializing_if = "Option::is_none")]
157    pub max_tokens: Option<u32>,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub temperature: Option<f32>,
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub top_p: Option<f32>,
162    #[serde(skip_serializing_if = "Option::is_none")]
163    pub n: Option<u32>,
164    #[serde(skip_serializing_if = "Option::is_none")]
165    pub stop: Option<Vec<String>>,
166}
167
168impl ChatCompletionRequest {
169    pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
170        Self {
171            model: model.into(),
172            messages,
173            stream: None,
174            max_tokens: None,
175            temperature: None,
176            top_p: None,
177            n: None,
178            stop: None,
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Every role paired with its expected lowercase name.
    const ALL_ROLES: [(MessageRole, &str); 4] = [
        (MessageRole::System, "system"),
        (MessageRole::User, "user"),
        (MessageRole::Assistant, "assistant"),
        (MessageRole::Tool, "tool"),
    ];

    /// Builds a non-streaming `gpt-4` context with one user message and default pricing.
    fn sample_context() -> RequestContext {
        RequestContext::new(
            Uuid::new_v4(),
            Uuid::new_v4(),
            Uuid::new_v4(),
            "gpt-4",
            vec![Message::user("Hello")],
            false,
            PricingSnapshot::default(),
        )
    }

    // ---- MessageRole -------------------------------------------------------

    #[test]
    fn test_message_role_as_str() {
        for (role, name) in ALL_ROLES {
            assert_eq!(role.as_str(), name);
        }
    }

    #[test]
    fn test_message_role_all_variants() {
        // as_str and Display must agree for every variant.
        for (role, name) in ALL_ROLES {
            assert_eq!(role.as_str(), name);
            assert_eq!(format!("{}", role), name);
        }
    }

    #[test]
    fn test_message_role_display() {
        assert_eq!(format!("{}", MessageRole::System), "system");
        assert_eq!(format!("{}", MessageRole::User), "user");
    }

    #[test]
    fn test_message_role_default() {
        assert_eq!(MessageRole::default(), MessageRole::User);
    }

    #[test]
    fn test_message_role_serialize() {
        let encoded = serde_json::to_string(&MessageRole::Assistant).unwrap();
        assert_eq!(encoded, "\"assistant\"");
    }

    #[test]
    fn test_message_role_deserialize() {
        let decoded: MessageRole = serde_json::from_str("\"system\"").unwrap();
        assert_eq!(decoded, MessageRole::System);
    }

    #[test]
    fn test_message_role_deserialize_invalid() {
        let outcome = serde_json::from_str::<MessageRole>("\"invalid_role\"");
        assert!(outcome.is_err());
    }

    // ---- Message -----------------------------------------------------------

    #[test]
    fn test_message_creation() {
        let message = Message::new(MessageRole::User, "Hello");
        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "Hello");
    }

    #[test]
    fn test_message_convenience_constructors() {
        assert_eq!(
            Message::system("You are a helpful assistant").role,
            MessageRole::System
        );
        assert_eq!(Message::user("Hello").role, MessageRole::User);
        assert_eq!(Message::assistant("Hi there!").role, MessageRole::Assistant);
        assert_eq!(Message::tool("Tool result").role, MessageRole::Tool);
    }

    #[test]
    fn test_message_serialize() {
        let encoded = serde_json::to_string(&Message::user("Hello")).unwrap();
        assert!(encoded.contains("\"role\":\"user\""));
        assert!(encoded.contains("\"content\":\"Hello\""));
    }

    #[test]
    fn test_message_deserialize() {
        let decoded: Message =
            serde_json::from_str(r#"{"role":"assistant","content":"Hello!"}"#).unwrap();
        assert_eq!(decoded.role, MessageRole::Assistant);
        assert_eq!(decoded.content, "Hello!");
    }

    // ---- RequestContext ----------------------------------------------------

    #[test]
    fn test_request_context_new() {
        let ctx = sample_context();
        assert_eq!(ctx.model, "gpt-4");
        assert!(!ctx.stream);
    }

    #[test]
    fn test_request_context_usage_shared() {
        let ctx = sample_context();

        ctx.add_output_tokens(100);
        ctx.set_input_tokens(50);

        let (input, output) = ctx.usage_snapshot();
        assert_eq!(input, 50);
        assert_eq!(output, 100);

        // A clone shares the same accumulator, so its update is visible through `ctx`.
        let twin = ctx.clone();
        twin.add_output_tokens(50);

        let (_, output_after) = ctx.usage_snapshot();
        assert_eq!(output_after, 150);
    }

    // ---- ChatCompletionRequest ---------------------------------------------

    #[test]
    fn test_chat_completion_request_new() {
        let request = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        assert_eq!(request.model, "gpt-4");
        assert_eq!(request.messages.len(), 1);
        assert!(request.stream.is_none());
    }

    #[test]
    fn test_chat_completion_request_serialize() {
        let request = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"model\":\"gpt-4\""));
        assert!(encoded.contains("\"role\":\"user\""));
    }
}