1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use std::sync::Arc;
5use uuid::Uuid;
6
7use crate::{PricingSnapshot, UsageAccumulator};
8
9#[derive(Debug, Clone)]
17pub struct RequestContext {
18 pub request_id: Uuid,
19 pub user_id: Uuid,
20 pub tenant_id: Uuid,
21 pub produce_ai_key_id: Uuid,
22 pub model: String,
23 pub provider: Option<String>,
25 pub messages: Vec<Message>,
26 pub stream: bool,
27 pub pricing_snapshot: PricingSnapshot, usage: Arc<UsageAccumulator>, pub started_at: DateTime<Utc>,
30}
31
32impl RequestContext {
33 pub fn new(
34 user_id: Uuid,
35 tenant_id: Uuid,
36 produce_ai_key_id: Uuid,
37 model: impl Into<String>,
38 messages: Vec<Message>,
39 stream: bool,
40 pricing_snapshot: PricingSnapshot,
41 ) -> Self {
42 Self {
43 request_id: Uuid::new_v4(),
44 user_id,
45 tenant_id,
46 produce_ai_key_id,
47 model: model.into(),
48 provider: None,
49 messages,
50 stream,
51 pricing_snapshot,
52 usage: Arc::new(UsageAccumulator::new()),
53 started_at: Utc::now(),
54 }
55 }
56
57 pub fn set_provider(&mut self, provider: impl Into<String>) {
59 self.provider = Some(provider.into());
60 }
61
62 pub fn update_pricing(&mut self, pricing: PricingSnapshot) {
64 self.pricing_snapshot = pricing;
65 }
66
67 pub fn duration(&self) -> chrono::Duration {
69 Utc::now() - self.started_at
70 }
71
72 pub fn usage_snapshot(&self) -> (u32, u32) {
74 self.usage.snapshot()
75 }
76
77 pub fn add_output_tokens(&self, tokens: u32) {
79 self.usage.add_output(tokens);
80 }
81
82 pub fn set_input_tokens(&self, tokens: u32) {
84 self.usage.set_input(tokens);
85 }
86}
87
88#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
90#[serde(rename_all = "lowercase")]
91pub enum MessageRole {
92 System,
93 #[default]
94 User,
95 Assistant,
96 Tool,
97}
98
99impl MessageRole {
100 pub fn as_str(&self) -> &'static str {
102 match self {
103 MessageRole::System => "system",
104 MessageRole::User => "user",
105 MessageRole::Assistant => "assistant",
106 MessageRole::Tool => "tool",
107 }
108 }
109}
110
111impl fmt::Display for MessageRole {
112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
113 write!(f, "{}", self.as_str())
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct Message {
120 pub role: MessageRole,
121 pub content: String,
122}
123
124impl Message {
125 pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
126 Self {
127 role,
128 content: content.into(),
129 }
130 }
131
132 pub fn system(content: impl Into<String>) -> Self {
133 Self::new(MessageRole::System, content)
134 }
135
136 pub fn user(content: impl Into<String>) -> Self {
137 Self::new(MessageRole::User, content)
138 }
139
140 pub fn assistant(content: impl Into<String>) -> Self {
141 Self::new(MessageRole::Assistant, content)
142 }
143
144 pub fn tool(content: impl Into<String>) -> Self {
145 Self::new(MessageRole::Tool, content)
146 }
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct ChatCompletionRequest {
152 pub model: String,
153 pub messages: Vec<Message>,
154 #[serde(skip_serializing_if = "Option::is_none")]
155 pub stream: Option<bool>,
156 #[serde(skip_serializing_if = "Option::is_none")]
157 pub max_tokens: Option<u32>,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 pub temperature: Option<f32>,
160 #[serde(skip_serializing_if = "Option::is_none")]
161 pub top_p: Option<f32>,
162 #[serde(skip_serializing_if = "Option::is_none")]
163 pub n: Option<u32>,
164 #[serde(skip_serializing_if = "Option::is_none")]
165 pub stop: Option<Vec<String>>,
166}
167
168impl ChatCompletionRequest {
169 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
170 Self {
171 model: model.into(),
172 messages,
173 stream: None,
174 max_tokens: None,
175 temperature: None,
176 top_p: None,
177 n: None,
178 stop: None,
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context with fresh ids, model `gpt-4`, one user message and
    /// streaming disabled.
    fn sample_context() -> RequestContext {
        RequestContext::new(
            Uuid::new_v4(),
            Uuid::new_v4(),
            Uuid::new_v4(),
            "gpt-4",
            vec![Message::user("Hello")],
            false,
            PricingSnapshot::default(),
        )
    }

    #[test]
    fn test_message_role_as_str() {
        assert_eq!(MessageRole::System.as_str(), "system");
        assert_eq!(MessageRole::User.as_str(), "user");
        assert_eq!(MessageRole::Assistant.as_str(), "assistant");
        assert_eq!(MessageRole::Tool.as_str(), "tool");
    }

    #[test]
    fn test_message_role_all_variants() {
        let roles = [
            (MessageRole::System, "system"),
            (MessageRole::User, "user"),
            (MessageRole::Assistant, "assistant"),
            (MessageRole::Tool, "tool"),
        ];
        for (role, expected) in roles {
            assert_eq!(role.as_str(), expected);
            assert_eq!(role.to_string(), expected);
        }
    }

    #[test]
    fn test_message_role_display() {
        assert_eq!(format!("{}", MessageRole::System), "system");
        assert_eq!(format!("{}", MessageRole::User), "user");
        assert_eq!(format!("{}", MessageRole::Assistant), "assistant");
        assert_eq!(format!("{}", MessageRole::Tool), "tool");
    }

    #[test]
    fn test_message_role_default() {
        assert_eq!(MessageRole::default(), MessageRole::User);
    }

    #[test]
    fn test_message_role_serialize() {
        let role = MessageRole::Assistant;
        let json = serde_json::to_string(&role).unwrap();
        assert_eq!(json, "\"assistant\"");
    }

    #[test]
    fn test_message_role_deserialize() {
        let json = "\"system\"";
        let role: MessageRole = serde_json::from_str(json).unwrap();
        assert_eq!(role, MessageRole::System);
    }

    #[test]
    fn test_message_role_deserialize_invalid() {
        let json = "\"invalid_role\"";
        let result: Result<MessageRole, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    /// `as_str` and the serde representation must agree for every role.
    #[test]
    fn test_message_role_serde_round_trip() {
        let roles = [
            MessageRole::System,
            MessageRole::User,
            MessageRole::Assistant,
            MessageRole::Tool,
        ];
        for role in roles {
            let json = serde_json::to_string(&role).unwrap();
            assert_eq!(json, format!("\"{}\"", role.as_str()));
            let back: MessageRole = serde_json::from_str(&json).unwrap();
            assert_eq!(back, role);
        }
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::new(MessageRole::User, "Hello");
        assert_eq!(msg.role, MessageRole::User);
        assert_eq!(msg.content, "Hello");
    }

    #[test]
    fn test_message_convenience_constructors() {
        let system_msg = Message::system("You are a helpful assistant");
        assert_eq!(system_msg.role, MessageRole::System);

        let user_msg = Message::user("Hello");
        assert_eq!(user_msg.role, MessageRole::User);

        let assistant_msg = Message::assistant("Hi there!");
        assert_eq!(assistant_msg.role, MessageRole::Assistant);

        let tool_msg = Message::tool("Tool result");
        assert_eq!(tool_msg.role, MessageRole::Tool);
    }

    #[test]
    fn test_message_serialize() {
        let msg = Message::user("Hello");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello\""));
    }

    #[test]
    fn test_message_deserialize() {
        let json = r#"{"role":"assistant","content":"Hello!"}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert_eq!(msg.role, MessageRole::Assistant);
        assert_eq!(msg.content, "Hello!");
    }

    #[test]
    fn test_message_round_trip() {
        let original = Message::tool("result");
        let json = serde_json::to_string(&original).unwrap();
        let parsed: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.role, original.role);
        assert_eq!(parsed.content, original.content);
    }

    #[test]
    fn test_request_context_new() {
        let ctx = sample_context();
        assert_eq!(ctx.model, "gpt-4");
        assert!(!ctx.stream);
        assert!(ctx.provider.is_none());
        assert_eq!(ctx.messages.len(), 1);
    }

    #[test]
    fn test_request_context_unique_request_ids() {
        assert_ne!(sample_context().request_id, sample_context().request_id);
    }

    #[test]
    fn test_request_context_set_provider() {
        let mut ctx = sample_context();
        ctx.set_provider("openai");
        assert_eq!(ctx.provider.as_deref(), Some("openai"));
        ctx.set_provider(String::from("anthropic"));
        assert_eq!(ctx.provider.as_deref(), Some("anthropic"));
    }

    #[test]
    fn test_request_context_usage_shared() {
        let ctx = sample_context();

        ctx.add_output_tokens(100);
        ctx.set_input_tokens(50);

        let (input, output) = ctx.usage_snapshot();
        assert_eq!(input, 50);
        assert_eq!(output, 100);

        // Clones share the same accumulator, so updates through one are
        // visible through the other.
        let ctx2 = ctx.clone();
        ctx2.add_output_tokens(50);

        let (_, output2) = ctx.usage_snapshot();
        assert_eq!(output2, 150);
    }

    #[test]
    fn test_chat_completion_request_new() {
        let req = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 1);
        assert!(req.stream.is_none());
    }

    #[test]
    fn test_chat_completion_request_serialize() {
        let req = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"role\":\"user\""));
    }

    #[test]
    fn test_chat_completion_request_omits_unset_optionals() {
        let req = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        let json = serde_json::to_string(&req).unwrap();
        for key in ["stream", "max_tokens", "temperature", "top_p", "n", "stop"] {
            let needle = format!("\"{}\":", key);
            assert!(!json.contains(&needle), "unexpected {} in {}", needle, json);
        }
    }

    #[test]
    fn test_chat_completion_request_serializes_set_optionals() {
        let mut req = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
        req.stream = Some(true);
        req.max_tokens = Some(16);
        req.stop = Some(vec!["END".to_string()]);
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"stream\":true"));
        assert!(json.contains("\"max_tokens\":16"));
        assert!(json.contains("\"stop\":[\"END\"]"));
        assert!(!json.contains("\"temperature\":"));
    }

    #[test]
    fn test_chat_completion_request_deserialize_missing_optionals() {
        let req: ChatCompletionRequest =
            serde_json::from_str(r#"{"model":"gpt-4","messages":[]}"#).unwrap();
        assert_eq!(req.model, "gpt-4");
        assert!(req.messages.is_empty());
        assert!(req.stream.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.stop.is_none());
    }
}