use serde::{Deserialize, Serialize};
use std::fmt;
/// Optional identifying metadata about a server: assistant, graph, user,
/// deployment, version and instance.
///
/// Every field defaults to `None`. Fields are (de)serialized with camelCase
/// keys, e.g. `assistant_id` <-> `"assistantId"`. Build one with
/// [`ServerInfo::new`] and the `with_*` methods, or with struct-update syntax.
#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
    /// Assistant identifier, if known.
    pub assistant_id: Option<String>,
    /// Graph identifier, if known.
    pub graph_id: Option<String>,
    /// User associated with this server context, if known.
    pub user: Option<String>,
    /// Deployment name, if known.
    pub deployment: Option<String>,
    /// Version string, if known.
    pub version: Option<String>,
    /// Instance identifier, if known.
    pub instance_id: Option<String>,
}
impl ServerInfo {
    /// Returns a `ServerInfo` whose fields are all `None`.
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Consumes `self` and returns it with `assistant_id` set to `id`.
    #[must_use]
    pub fn with_assistant_id(self, id: impl Into<String>) -> Self {
        Self {
            assistant_id: Some(id.into()),
            ..self
        }
    }

    /// Consumes `self` and returns it with `graph_id` set to `id`.
    #[must_use]
    pub fn with_graph_id(self, id: impl Into<String>) -> Self {
        Self {
            graph_id: Some(id.into()),
            ..self
        }
    }

    /// Consumes `self` and returns it with `user` set to `user`.
    #[must_use]
    pub fn with_user(self, user: impl Into<String>) -> Self {
        Self {
            user: Some(user.into()),
            ..self
        }
    }

    /// Consumes `self` and returns it with `deployment` set to `deployment`.
    #[must_use]
    pub fn with_deployment(self, deployment: impl Into<String>) -> Self {
        Self {
            deployment: Some(deployment.into()),
            ..self
        }
    }

    /// Consumes `self` and returns it with `version` set to `version`.
    #[must_use]
    pub fn with_version(self, version: impl Into<String>) -> Self {
        Self {
            version: Some(version.into()),
            ..self
        }
    }

    /// Consumes `self` and returns it with `instance_id` set to `id`.
    #[must_use]
    pub fn with_instance_id(self, id: impl Into<String>) -> Self {
        Self {
            instance_id: Some(id.into()),
            ..self
        }
    }
}
/// Policy controlling how cache keys are derived for LLM calls.
///
/// By default (`key_func == None`) no custom key function is configured. Which
/// key is used in that case is decided by code outside this type.
/// Cloning is cheap: the key function is shared through an `Arc`.
#[derive(Clone, Default)]
pub struct LlmCachePolicy {
    /// Optional custom key function; `None` means none is configured.
    pub key_func: Option<LlmCacheKeyFn>,
}

/// Shared, thread-safe function that maps an [`LlmCacheKeyInput`] to a cache key.
pub type LlmCacheKeyFn = std::sync::Arc<dyn Fn(&LlmCacheKeyInput) -> String + Send + Sync>;

impl fmt::Debug for LlmCachePolicy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("LlmCachePolicy");
        // Closures have no `Debug` impl, so only report whether one is set.
        // `format_args!` prints the placeholder unquoted, so the output reads
        // like an `Option`'s Debug rather than a string-valued field.
        match self.key_func {
            Some(_) => out.field("key_func", &format_args!("Some(custom function)")),
            None => out.field("key_func", &format_args!("None")),
        };
        out.finish()
    }
}

impl LlmCachePolicy {
    /// Creates a policy with no custom key function.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `self` configured to derive cache keys with `f`.
    ///
    /// Any previously configured key function is replaced. `f` must be
    /// `Send + Sync + 'static` because it is stored in a shared [`LlmCacheKeyFn`].
    #[must_use]
    pub fn with_key_func<F>(mut self, f: F) -> Self
    where
        F: Fn(&LlmCacheKeyInput) -> String + Send + Sync + 'static,
    {
        self.key_func = Some(std::sync::Arc::new(f));
        self
    }
}
/// Data handed to a custom cache-key function (see `LlmCachePolicy::with_key_func`).
#[derive(Debug, Clone)]
pub struct LlmCacheKeyInput {
    /// Model name, e.g. `"gpt-4"`.
    pub model: String,
    /// Conversation messages as raw JSON values.
    pub messages: Vec<serde_json::Value>,
    /// Tools as raw JSON values.
    pub tools: Vec<serde_json::Value>,
    /// Typed as `Option<()>`, so it only records presence and carries no data.
    /// NOTE(review): presumably a placeholder for a future config type — confirm.
    pub config: Option<()>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_server_info_default() {
        let info = ServerInfo::default();
        assert!(info.assistant_id.is_none());
        assert!(info.graph_id.is_none());
        assert!(info.user.is_none());
        assert!(info.deployment.is_none());
        assert!(info.version.is_none());
        assert!(info.instance_id.is_none());
    }

    #[test]
    fn test_server_info_builder() {
        let info = ServerInfo::new()
            .with_assistant_id("asst_123")
            .with_graph_id("graph_456")
            .with_user("user@example.com")
            .with_deployment("production")
            .with_version("1.0.0")
            .with_instance_id("pod-abc123");
        assert_eq!(info.assistant_id, Some("asst_123".to_string()));
        assert_eq!(info.graph_id, Some("graph_456".to_string()));
        assert_eq!(info.user, Some("user@example.com".to_string()));
        assert_eq!(info.deployment, Some("production".to_string()));
        assert_eq!(info.version, Some("1.0.0".to_string()));
        assert_eq!(info.instance_id, Some("pod-abc123".to_string()));
    }

    #[test]
    fn test_server_info_builder_last_write_wins() {
        let info = ServerInfo::new().with_user("first").with_user("second");
        assert_eq!(info.user.as_deref(), Some("second"));
        assert!(info.assistant_id.is_none());
    }

    #[test]
    fn test_server_info_serialization() {
        let info = ServerInfo {
            assistant_id: Some("asst_123".to_string()),
            deployment: Some("production".to_string()),
            ..Default::default()
        };
        let json = serde_json::to_string(&info).unwrap();
        // The camelCase keys are part of the wire format; a round trip alone
        // would still pass if `rename_all = "camelCase"` were removed.
        assert!(json.contains("\"assistantId\":\"asst_123\""));
        assert!(json.contains("\"deployment\":\"production\""));
        assert!(!json.contains("assistant_id"));
        let deserialized: ServerInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.assistant_id, info.assistant_id);
        assert_eq!(deserialized.deployment, info.deployment);
        assert!(deserialized.graph_id.is_none());
    }

    #[test]
    fn test_server_info_deserialize_camel_case() {
        // Missing `Option` fields deserialize as `None`.
        let info: ServerInfo = serde_json::from_value(json!({
            "graphId": "graph_456",
            "instanceId": "pod-abc123"
        }))
        .unwrap();
        assert_eq!(info.graph_id.as_deref(), Some("graph_456"));
        assert_eq!(info.instance_id.as_deref(), Some("pod-abc123"));
        assert!(info.assistant_id.is_none());
        assert!(info.user.is_none());
    }

    #[test]
    fn test_llm_cache_policy_default() {
        let policy = LlmCachePolicy::default();
        assert!(policy.key_func.is_none());
    }

    #[test]
    fn test_llm_cache_policy_with_custom_func() {
        let policy = LlmCachePolicy::new()
            .with_key_func(|input| format!("custom:{}:{}", input.model, input.messages.len()));
        assert!(policy.key_func.is_some());
        let input = LlmCacheKeyInput {
            model: "gpt-4".to_string(),
            messages: vec![json!({}), json!({})],
            tools: vec![],
            config: None,
        };
        let key = policy.key_func.as_ref().unwrap()(&input);
        assert_eq!(key, "custom:gpt-4:2");
    }

    #[test]
    fn test_llm_cache_policy_key_func_replaced() {
        let policy = LlmCachePolicy::new()
            .with_key_func(|_| "first".to_string())
            .with_key_func(|input| format!("second:{}", input.model));
        let input = LlmCacheKeyInput {
            model: "m".to_string(),
            messages: vec![],
            tools: vec![],
            config: None,
        };
        let key = policy.key_func.as_ref().unwrap()(&input);
        assert_eq!(key, "second:m");
    }

    #[test]
    fn test_llm_cache_policy_debug() {
        let policy_without = LlmCachePolicy::default();
        let debug_str = format!("{policy_without:?}");
        assert!(debug_str.contains("None"));
        let policy_with = LlmCachePolicy::new().with_key_func(|_| "key".to_string());
        let debug_str = format!("{policy_with:?}");
        assert!(debug_str.contains("Some"));
    }

    #[test]
    fn test_llm_cache_key_input() {
        let input = LlmCacheKeyInput {
            model: "claude-3".to_string(),
            messages: vec![json!({"role": "user"})],
            tools: vec![],
            config: None,
        };
        assert_eq!(input.model, "claude-3");
        assert_eq!(input.messages.len(), 1);
        assert!(input.tools.is_empty());
        assert!(input.config.is_none());
    }
}