use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::Arc;
use uuid::Uuid;
use crate::{PricingSnapshot, UsageAccumulator};
#[derive(Debug, Clone)]
pub struct RequestContext {
pub request_id: Uuid,
pub user_id: Uuid,
pub tenant_id: Uuid,
pub produce_ai_key_id: Uuid,
pub model: String,
pub provider: Option<String>,
pub messages: Vec<Message>,
pub stream: bool,
pub pricing_snapshot: PricingSnapshot, usage: Arc<UsageAccumulator>, pub started_at: DateTime<Utc>,
}
impl RequestContext {
pub fn new(
user_id: Uuid,
tenant_id: Uuid,
produce_ai_key_id: Uuid,
model: impl Into<String>,
messages: Vec<Message>,
stream: bool,
pricing_snapshot: PricingSnapshot,
) -> Self {
Self {
request_id: Uuid::new_v4(),
user_id,
tenant_id,
produce_ai_key_id,
model: model.into(),
provider: None,
messages,
stream,
pricing_snapshot,
usage: Arc::new(UsageAccumulator::new()),
started_at: Utc::now(),
}
}
pub fn set_provider(&mut self, provider: impl Into<String>) {
self.provider = Some(provider.into());
}
pub fn update_pricing(&mut self, pricing: PricingSnapshot) {
self.pricing_snapshot = pricing;
}
pub fn duration(&self) -> chrono::Duration {
Utc::now() - self.started_at
}
pub fn usage_snapshot(&self) -> (u32, u32) {
self.usage.snapshot()
}
pub fn add_output_tokens(&self, tokens: u32) {
self.usage.add_output(tokens);
}
pub fn set_input_tokens(&self, tokens: u32) {
self.usage.set_input(tokens);
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
#[default]
User,
Assistant,
Tool,
}
impl MessageRole {
pub fn as_str(&self) -> &'static str {
match self {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
}
}
}
impl fmt::Display for MessageRole {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.as_str())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: MessageRole,
pub content: String,
}
impl Message {
pub fn new(role: MessageRole, content: impl Into<String>) -> Self {
Self {
role,
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new(MessageRole::System, content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new(MessageRole::User, content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new(MessageRole::Assistant, content)
}
pub fn tool(content: impl Into<String>) -> Self {
Self::new(MessageRole::Tool, content)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
}
impl ChatCompletionRequest {
pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
Self {
model: model.into(),
messages,
stream: None,
max_tokens: None,
temperature: None,
top_p: None,
n: None,
stop: None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_role_as_str() {
assert_eq!(MessageRole::System.as_str(), "system");
assert_eq!(MessageRole::User.as_str(), "user");
assert_eq!(MessageRole::Assistant.as_str(), "assistant");
assert_eq!(MessageRole::Tool.as_str(), "tool");
}
#[test]
fn test_message_role_all_variants() {
let roles = vec![
(MessageRole::System, "system"),
(MessageRole::User, "user"),
(MessageRole::Assistant, "assistant"),
(MessageRole::Tool, "tool"),
];
for (role, expected) in roles {
assert_eq!(role.as_str(), expected);
assert_eq!(format!("{}", role), expected);
}
}
#[test]
fn test_message_role_display() {
assert_eq!(format!("{}", MessageRole::System), "system");
assert_eq!(format!("{}", MessageRole::User), "user");
}
#[test]
fn test_message_role_default() {
assert_eq!(MessageRole::default(), MessageRole::User);
}
#[test]
fn test_message_role_serialize() {
let role = MessageRole::Assistant;
let json = serde_json::to_string(&role).unwrap();
assert_eq!(json, "\"assistant\"");
}
#[test]
fn test_message_role_deserialize() {
let json = "\"system\"";
let role: MessageRole = serde_json::from_str(json).unwrap();
assert_eq!(role, MessageRole::System);
}
#[test]
fn test_message_role_deserialize_invalid() {
let json = "\"invalid_role\"";
let result: Result<MessageRole, _> = serde_json::from_str(json);
assert!(result.is_err());
}
#[test]
fn test_message_creation() {
let msg = Message::new(MessageRole::User, "Hello");
assert_eq!(msg.role, MessageRole::User);
assert_eq!(msg.content, "Hello");
}
#[test]
fn test_message_convenience_constructors() {
let system_msg = Message::system("You are a helpful assistant");
assert_eq!(system_msg.role, MessageRole::System);
let user_msg = Message::user("Hello");
assert_eq!(user_msg.role, MessageRole::User);
let assistant_msg = Message::assistant("Hi there!");
assert_eq!(assistant_msg.role, MessageRole::Assistant);
let tool_msg = Message::tool("Tool result");
assert_eq!(tool_msg.role, MessageRole::Tool);
}
#[test]
fn test_message_serialize() {
let msg = Message::user("Hello");
let json = serde_json::to_string(&msg).unwrap();
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"content\":\"Hello\""));
}
#[test]
fn test_message_deserialize() {
let json = r#"{"role":"assistant","content":"Hello!"}"#;
let msg: Message = serde_json::from_str(json).unwrap();
assert_eq!(msg.role, MessageRole::Assistant);
assert_eq!(msg.content, "Hello!");
}
#[test]
fn test_request_context_new() {
let ctx = RequestContext::new(
Uuid::new_v4(),
Uuid::new_v4(),
Uuid::new_v4(),
"gpt-4",
vec![Message::user("Hello")],
false,
PricingSnapshot::default(),
);
assert_eq!(ctx.model, "gpt-4");
assert!(!ctx.stream);
}
#[test]
fn test_request_context_usage_shared() {
let ctx = RequestContext::new(
Uuid::new_v4(),
Uuid::new_v4(),
Uuid::new_v4(),
"gpt-4",
vec![Message::user("Hello")],
false,
PricingSnapshot::default(),
);
ctx.add_output_tokens(100);
ctx.set_input_tokens(50);
let (input, output) = ctx.usage_snapshot();
assert_eq!(input, 50);
assert_eq!(output, 100);
let ctx2 = ctx.clone();
ctx2.add_output_tokens(50);
let (_, output2) = ctx.usage_snapshot();
assert_eq!(output2, 150);
}
#[test]
fn test_chat_completion_request_new() {
let req = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
assert_eq!(req.model, "gpt-4");
assert_eq!(req.messages.len(), 1);
assert!(req.stream.is_none());
}
#[test]
fn test_chat_completion_request_serialize() {
let req = ChatCompletionRequest::new("gpt-4", vec![Message::user("Hello")]);
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"model\":\"gpt-4\""));
assert!(json.contains("\"role\":\"user\""));
}
}