use super::types::CacheKey;
use crate::core::models::openai::{ChatCompletionRequest, EmbeddingRequest};
use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use tracing::warn;
/// Key-namespace prefix for chat-completion cache entries.
pub const CHAT_KEY_PREFIX: &str = "chat";
/// Key-namespace prefix for embedding cache entries.
pub const EMBEDDING_KEY_PREFIX: &str = "embed";
/// Key-namespace prefix for text-completion cache entries.
pub const COMPLETION_KEY_PREFIX: &str = "completion";
/// Generates a cache key for a chat completion request without user scoping.
///
/// Delegates to [`generate_chat_key_with_user`] with `user_id = None`.
pub fn generate_chat_key(request: &ChatCompletionRequest) -> CacheKey {
    generate_chat_key_with_user(request, None)
}
/// Generates a cache key for a chat completion request, optionally scoped to a
/// user so different users never share cache entries.
///
/// Hashed inputs: model, messages (role, content, name — in order), the
/// optional sampling parameters, stop sequences (order-normalized), response
/// format type, tool function names, and the user id when present.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed stable across
/// Rust releases — a toolchain upgrade may invalidate any persisted cache.
/// NOTE(review): absent message `content`/`name` contribute nothing to the
/// hash (no presence tag, unlike `hash_optional_*`), so a message with only
/// `content = "x"` and one with only `name = "x"` hash identically — confirm
/// this collision is acceptable.
/// NOTE(review): tools are hashed by function name only; requests whose tools
/// share names but differ in parameters/description collide. `tool_choice`
/// and message `tool_calls` are not hashed at all.
pub fn generate_chat_key_with_user(
    request: &ChatCompletionRequest,
    user_id: Option<&str>,
) -> CacheKey {
    let mut hasher = DefaultHasher::new();
    request.model.hash(&mut hasher);
    // Message order is significant for chat context, so hash in order.
    for message in &request.messages {
        message.role.hash(&mut hasher);
        if let Some(ref content) = message.content {
            content.hash(&mut hasher);
        }
        if let Some(ref name) = message.name {
            name.hash(&mut hasher);
        }
    }
    // Sampling parameters: presence/absence is tagged into the hash too.
    hash_optional_f32(&mut hasher, request.temperature);
    hash_optional_u32(&mut hasher, request.max_tokens);
    hash_optional_u32(&mut hasher, request.max_completion_tokens);
    hash_optional_f32(&mut hasher, request.top_p);
    hash_optional_u32(&mut hasher, request.n);
    hash_optional_f32(&mut hasher, request.presence_penalty);
    hash_optional_f32(&mut hasher, request.frequency_penalty);
    hash_optional_u32(&mut hasher, request.seed);
    // Stop sequences act as a set, so permutations are normalized to one key.
    if let Some(ref stops) = request.stop {
        let mut sorted_stops = stops.clone();
        sorted_stops.sort();
        for stop in sorted_stops {
            stop.hash(&mut hasher);
        }
    }
    if let Some(ref format) = request.response_format {
        format.format_type.hash(&mut hasher);
    }
    if let Some(ref tools) = request.tools {
        for tool in tools {
            tool.function.name.hash(&mut hasher);
        }
    }
    if let Some(uid) = user_id {
        uid.hash(&mut hasher);
    }
    let hash = hasher.finish();
    // Model stays in clear text for observability; the hash disambiguates.
    let key = format!("{}:{}:{:016x}", CHAT_KEY_PREFIX, request.model, hash);
    CacheKey::new(key)
}
/// Generates a cache key for an embedding request without user scoping.
///
/// Delegates to [`generate_embedding_key_with_user`] with `user_id = None`.
pub fn generate_embedding_key(request: &EmbeddingRequest) -> CacheKey {
    generate_embedding_key_with_user(request, None)
}
/// Generates a cache key for an embedding request, optionally scoped to a user.
///
/// Hashes the model, the input (tagged by JSON variant so a string never
/// collides with an array), and the user id when present. String-array inputs
/// are order-normalized so permutations of the same texts share a key.
pub fn generate_embedding_key_with_user(
    request: &EmbeddingRequest,
    user_id: Option<&str>,
) -> CacheKey {
    let mut hasher = DefaultHasher::new();
    request.model.hash(&mut hasher);
    match &request.input {
        serde_json::Value::String(s) => {
            "string".hash(&mut hasher);
            s.hash(&mut hasher);
        }
        serde_json::Value::Array(arr) => {
            "array".hash(&mut hasher);
            // Fix: non-string elements (e.g. token-id arrays) were previously
            // dropped from the hash entirely, so ["a", 1] and ["a", 2] shared
            // a key — a cache collision returning the wrong embeddings. They
            // now contribute via their JSON text; pure-string arrays hash
            // exactly as before.
            let mut texts: Vec<String> = arr
                .iter()
                .map(|v| match v {
                    serde_json::Value::String(s) => s.clone(),
                    other => other.to_string(),
                })
                .collect();
            // NOTE(review): embedding outputs are positional per input; the
            // sort means reordered inputs hit the same cached (ordered)
            // response — confirm callers do not depend on input order.
            texts.sort();
            for text in texts {
                text.hash(&mut hasher);
            }
        }
        _ => {
            // Any other JSON shape: hash its canonical serde_json rendering.
            "other".hash(&mut hasher);
            request.input.to_string().hash(&mut hasher);
        }
    }
    if let Some(uid) = user_id {
        uid.hash(&mut hasher);
    }
    let hash = hasher.finish();
    let key = format!("{}:{}:{:016x}", EMBEDDING_KEY_PREFIX, request.model, hash);
    CacheKey::new(key)
}
/// Generates a cache key by hashing the normalized JSON form of `request`.
///
/// Normalization (see [`normalize_json_string`]) sorts object keys and strips
/// known non-deterministic fields so logically identical requests map to the
/// same key regardless of field order.
pub fn generate_key_from_json<T: Serialize>(prefix: &str, request: &T) -> CacheKey {
    let json = serde_json::to_string(request).unwrap_or_else(|e| {
        warn!(
            "Failed to serialize request for cache key generation: {}",
            e
        );
        // Fix: previously every serialization failure fell back to hashing
        // the empty string, collapsing ALL unserializable requests onto one
        // shared cache key (cross-request collision). Hash an error marker
        // instead, narrowing collisions to requests failing identically.
        format!("__serialization_error__:{}", e)
    });
    let normalized = normalize_json_string(&json);
    let mut hasher = DefaultHasher::new();
    normalized.hash(&mut hasher);
    let hash = hasher.finish();
    CacheKey::new(format!("{}:{:016x}", prefix, hash))
}
/// Generates a cache key of the form `prefix:<16-hex-digit hash>` from raw
/// string content.
pub fn generate_key_from_content(prefix: &str, content: &str) -> CacheKey {
    let digest = {
        let mut state = DefaultHasher::new();
        content.hash(&mut state);
        state.finish()
    };
    CacheKey::new(format!("{}:{:016x}", prefix, digest))
}
/// Generates a cache key by hashing each part in order under the given prefix.
///
/// Part order matters: `["a", "b"]` and `["b", "a"]` produce different keys.
pub fn generate_key_from_parts(prefix: &str, parts: &[&str]) -> CacheKey {
    let mut state = DefaultHasher::new();
    parts.iter().for_each(|part| part.hash(&mut state));
    CacheKey::new(format!("{}:{:016x}", prefix, state.finish()))
}
/// Feeds an `Option<f32>` into `hasher` with an explicit presence tag
/// (1 = Some, 0 = None) so `None` can never collide with any `Some` value.
/// The float is hashed via its IEEE-754 bit pattern because `f32` itself
/// does not implement `Hash`.
fn hash_optional_f32<H: Hasher>(hasher: &mut H, value: Option<f32>) {
    match value {
        Some(v) => {
            1u8.hash(hasher);
            v.to_bits().hash(hasher);
        }
        None => 0u8.hash(hasher),
    }
}
/// Feeds an `Option<u32>` into `hasher` with an explicit presence tag
/// (1 = Some, 0 = None) so `None` can never collide with any `Some` value.
fn hash_optional_u32<H: Hasher>(hasher: &mut H, value: Option<u32>) {
    // Tag byte first, then the payload (when present) — same order as the
    // float variant.
    u8::from(value.is_some()).hash(hasher);
    if let Some(v) = value {
        v.hash(hasher);
    }
}
/// Parses `json` and returns its normalized rendering (sorted keys, volatile
/// fields stripped). Unparseable input is returned verbatim so it still
/// hashes deterministically.
fn normalize_json_string(json: &str) -> String {
    match serde_json::from_str::<serde_json::Value>(json) {
        Ok(value) => normalize_json_value(&value),
        Err(_) => json.to_string(),
    }
}
fn normalize_json_value(value: &serde_json::Value) -> String {
match value {
serde_json::Value::Object(map) => {
let sorted: BTreeMap<_, _> = map
.iter()
.filter(|(k, _)| !is_non_deterministic_field(k))
.map(|(k, v)| (k.clone(), normalize_json_value(v)))
.collect();
serde_json::to_string(&sorted).unwrap_or_default()
}
serde_json::Value::Array(arr) => {
let normalized: Vec<String> = arr.iter().map(normalize_json_value).collect();
format!("[{}]", normalized.join(","))
}
_ => value.to_string(),
}
}
/// Returns `true` for JSON field names whose values vary between otherwise
/// identical requests (ids, timestamps, streaming flags) and therefore must
/// be excluded from cache-key hashing. Matching is exact and case-sensitive.
fn is_non_deterministic_field(field: &str) -> bool {
    const VOLATILE_FIELDS: &[&str] = &[
        "timestamp",
        "request_id",
        "trace_id",
        "span_id",
        "created_at",
        "updated_at",
        "id",
        "stream",
        "stream_options",
    ];
    VOLATILE_FIELDS.contains(&field)
}
/// Fluent builder composing cache keys from an ordered list of string parts.
pub struct CacheKeyBuilder {
    /// Parts contribute to the key in insertion order.
    parts: Vec<String>,
    /// Namespace prefix prepended to the final key.
    prefix: String,
}
impl CacheKeyBuilder {
    /// Creates an empty builder for the given key-namespace prefix.
    pub fn new(prefix: impl Into<String>) -> Self {
        Self {
            parts: Vec::new(),
            prefix: prefix.into(),
        }
    }

    /// Appends one part; parts affect the key in insertion order.
    pub fn with_part(mut self, part: impl Into<String>) -> Self {
        self.parts.push(part.into());
        self
    }

    /// Appends a part only when `Some`; `None` leaves the builder unchanged.
    pub fn add_optional(self, part: Option<impl Into<String>>) -> Self {
        match part {
            Some(p) => self.with_part(p),
            None => self,
        }
    }

    /// Appends the `Display` rendering of any number (or displayable value).
    pub fn add_num<N: std::fmt::Display>(self, num: N) -> Self {
        self.with_part(num.to_string())
    }

    /// Builds an opaque key: `prefix:<16-hex-digit hash of the parts>`.
    pub fn build(self) -> CacheKey {
        let mut state = DefaultHasher::new();
        self.parts.iter().for_each(|part| part.hash(&mut state));
        CacheKey::new(format!("{}:{:016x}", self.prefix, state.finish()))
    }

    /// Builds a human-readable key: the prefix and parts joined with `:`.
    pub fn build_explicit(self) -> CacheKey {
        let mut segments = vec![self.prefix];
        segments.extend(self.parts);
        CacheKey::new(segments.join(":"))
    }
}
// Unit tests: key generation (chat/embedding/content/parts/json), the
// CacheKeyBuilder, and JSON normalization helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::models::openai::messages::{ChatMessage, MessageContent, MessageRole};

    // Convenience constructor for a plain-text user message with all optional
    // fields left unset.
    fn create_user_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: MessageRole::User,
            content: Some(MessageContent::Text(content.to_string())),
            name: None,
            function_call: None,
            tool_calls: None,
            tool_call_id: None,
            audio: None,
        }
    }

    // Keys carry the "chat:<model>:" prefix in clear text.
    #[test]
    fn test_generate_chat_key_basic() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![create_user_message("Hello")],
            ..Default::default()
        };
        let key = generate_chat_key(&request);
        assert!(key.as_str().starts_with("chat:gpt-4:"));
    }

    // Same request must always hash to the same key (determinism).
    #[test]
    fn test_generate_chat_key_consistency() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![create_user_message("Hello")],
            temperature: Some(0.7),
            ..Default::default()
        };
        let key1 = generate_chat_key(&request);
        let key2 = generate_chat_key(&request);
        assert_eq!(key1, key2);
    }

    // Different message content must produce different keys.
    #[test]
    fn test_generate_chat_key_different_messages() {
        let request1 = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![create_user_message("Hello")],
            ..Default::default()
        };
        let request2 = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![create_user_message("Goodbye")],
            ..Default::default()
        };
        let key1 = generate_chat_key(&request1);
        let key2 = generate_chat_key(&request2);
        assert_ne!(key1, key2);
    }

    // Different models must produce different keys even for identical bodies.
    #[test]
    fn test_generate_chat_key_different_models() {
        let request1 = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            ..Default::default()
        };
        let request2 = ChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![],
            ..Default::default()
        };
        let key1 = generate_chat_key(&request1);
        let key2 = generate_chat_key(&request2);
        assert_ne!(key1, key2);
    }

    // User scoping: distinct users (and no-user) get distinct keys.
    #[test]
    fn test_generate_chat_key_with_user() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            ..Default::default()
        };
        let key1 = generate_chat_key_with_user(&request, Some("user-123"));
        let key2 = generate_chat_key_with_user(&request, Some("user-456"));
        let key3 = generate_chat_key_with_user(&request, None);
        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
    }

    // Sampling parameters (here: temperature) participate in the hash.
    #[test]
    fn test_generate_chat_key_with_parameters() {
        let request1 = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            temperature: Some(0.7),
            ..Default::default()
        };
        let request2 = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            temperature: Some(0.9),
            ..Default::default()
        };
        let key1 = generate_chat_key(&request1);
        let key2 = generate_chat_key(&request2);
        assert_ne!(key1, key2);
    }

    // Embedding keys carry the "embed:<model>:" prefix (string input).
    #[test]
    fn test_generate_embedding_key_string() {
        let request = EmbeddingRequest {
            model: "text-embedding-ada-002".to_string(),
            input: serde_json::json!("Hello world"),
            user: None,
        };
        let key = generate_embedding_key(&request);
        assert!(key.as_str().starts_with("embed:text-embedding-ada-002:"));
    }

    // Embedding keys carry the "embed:<model>:" prefix (array input).
    #[test]
    fn test_generate_embedding_key_array() {
        let request = EmbeddingRequest {
            model: "text-embedding-ada-002".to_string(),
            input: serde_json::json!(["Hello", "World"]),
            user: None,
        };
        let key = generate_embedding_key(&request);
        assert!(key.as_str().starts_with("embed:text-embedding-ada-002:"));
    }

    // Same embedding request must always hash to the same key.
    #[test]
    fn test_generate_embedding_key_consistency() {
        let request = EmbeddingRequest {
            model: "text-embedding-3-small".to_string(),
            input: serde_json::json!("Test input"),
            user: None,
        };
        let key1 = generate_embedding_key(&request);
        let key2 = generate_embedding_key(&request);
        assert_eq!(key1, key2);
    }

    // Pins the order-normalization behavior: permutations of the same string
    // array share one key.
    #[test]
    fn test_generate_embedding_key_array_order_normalized() {
        let request1 = EmbeddingRequest {
            model: "text-embedding-ada-002".to_string(),
            input: serde_json::json!(["Alpha", "Beta"]),
            user: None,
        };
        let request2 = EmbeddingRequest {
            model: "text-embedding-ada-002".to_string(),
            input: serde_json::json!(["Beta", "Alpha"]),
            user: None,
        };
        let key1 = generate_embedding_key(&request1);
        let key2 = generate_embedding_key(&request2);
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_generate_key_from_content() {
        let key = generate_key_from_content("test", "some content");
        assert!(key.as_str().starts_with("test:"));
    }

    #[test]
    fn test_generate_key_from_parts() {
        let key = generate_key_from_parts("prefix", &["part1", "part2", "part3"]);
        assert!(key.as_str().starts_with("prefix:"));
    }

    // Any Serialize type can be keyed via its normalized JSON form.
    #[test]
    fn test_generate_key_from_json() {
        #[derive(Serialize)]
        struct TestRequest {
            field1: String,
            field2: i32,
        }
        let request = TestRequest {
            field1: "value".to_string(),
            field2: 42,
        };
        let key = generate_key_from_json("test", &request);
        assert!(key.as_str().starts_with("test:"));
    }

    #[test]
    fn test_cache_key_builder_basic() {
        let key = CacheKeyBuilder::new("chat")
            .with_part("gpt-4")
            .with_part("user-123")
            .build();
        assert!(key.as_str().starts_with("chat:"));
    }

    #[test]
    fn test_cache_key_builder_with_nums() {
        let key = CacheKeyBuilder::new("session")
            .with_part("user")
            .add_num(123)
            .add_num(456)
            .build();
        assert!(key.as_str().starts_with("session:"));
    }

    // add_optional(None) must yield a different key than an added part.
    #[test]
    fn test_cache_key_builder_with_optional() {
        let key1 = CacheKeyBuilder::new("test")
            .with_part("base")
            .add_optional(Some("optional"))
            .build();
        let key2 = CacheKeyBuilder::new("test")
            .with_part("base")
            .add_optional(None::<String>)
            .build();
        assert_ne!(key1, key2);
    }

    // build_explicit joins prefix and parts verbatim rather than hashing.
    #[test]
    fn test_cache_key_builder_explicit() {
        let key = CacheKeyBuilder::new("chat")
            .with_part("gpt-4")
            .with_part("conversation-1")
            .build_explicit();
        assert_eq!(key.as_str(), "chat:gpt-4:conversation-1");
    }

    // Volatile fields (timestamp) are stripped before hashing.
    #[test]
    fn test_normalize_json_filters_timestamp() {
        let json1 = r#"{"message": "hello", "timestamp": "2024-01-01"}"#;
        let json2 = r#"{"message": "hello", "timestamp": "2024-12-31"}"#;
        let norm1 = normalize_json_string(json1);
        let norm2 = normalize_json_string(json2);
        assert_eq!(norm1, norm2);
    }

    // Key order is irrelevant after normalization.
    #[test]
    fn test_normalize_json_sorts_keys() {
        let json1 = r#"{"b": 2, "a": 1}"#;
        let json2 = r#"{"a": 1, "b": 2}"#;
        let norm1 = normalize_json_string(json1);
        let norm2 = normalize_json_string(json2);
        assert_eq!(norm1, norm2);
    }

    #[test]
    fn test_is_non_deterministic_field() {
        assert!(is_non_deterministic_field("timestamp"));
        assert!(is_non_deterministic_field("request_id"));
        assert!(is_non_deterministic_field("stream"));
        assert!(!is_non_deterministic_field("model"));
        assert!(!is_non_deterministic_field("messages"));
    }
}