use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use crate::llm::{CallOptions, ToolDefinition};
use crate::state::Message;
/// Everything that identifies a single model call for caching purposes.
///
/// [`CacheKeyInput::hash`] folds these fields into a 64-bit value, which the default
/// cache-key format combines with `model`.
#[derive(Debug, Clone)]
pub struct CacheKeyInput {
    /// Model identifier; also the prefix of the default cache key (`"{model}:{hash}"`).
    pub model: String,
    /// Conversation messages; their role, content and tool calls are hashed.
    pub messages: Vec<Message>,
    /// Tool definitions; only each tool's `name` and `parameters` are hashed.
    pub tools: Vec<ToolDefinition>,
    /// Optional call options; only `temperature`, `max_tokens` and `top_p` are hashed.
    pub config: Option<CallOptions>,
}
impl CacheKeyInput {
    /// Builds a cache key input from its parts.
    ///
    /// `model` accepts anything convertible into a `String` (e.g. `&str`).
    pub fn new(
        model: impl Into<String>,
        messages: Vec<Message>,
        tools: Vec<ToolDefinition>,
        config: Option<CallOptions>,
    ) -> Self {
        Self {
            model: model.into(),
            messages,
            tools,
            config,
        }
    }

    /// Computes a 64-bit hash over the model, messages, tools and selected call options.
    ///
    /// The following are mixed into the hash so that structurally different inputs do not
    /// collide trivially:
    /// - sequence lengths;
    /// - enum variant discriminants;
    /// - `Option` presence.
    ///
    /// For example, plain text and a single text part hash differently. So do a temperature
    /// and a max-token limit that share a bit pattern.
    ///
    /// Only these fields contribute:
    /// - `model`;
    /// - each message's `role`, `content` and tool calls (`id`, `name`, JSON-serialized
    ///   `arguments`);
    /// - each tool's `name` and JSON-serialized `parameters`;
    /// - the config's `temperature`, `max_tokens` and `top_p`.
    ///
    /// The value is produced by [`DefaultHasher`], whose algorithm is unspecified and may
    /// change between Rust releases. Do not rely on it staying the same across toolchain
    /// upgrades.
    #[must_use]
    pub fn hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        self.model.hash(&mut hasher);

        // Length prefixes keep element boundaries unambiguous between sequences.
        self.messages.len().hash(&mut hasher);
        for msg in &self.messages {
            Self::hash_message(msg, &mut hasher);
        }

        self.tools.len().hash(&mut hasher);
        for tool in &self.tools {
            tool.name.hash(&mut hasher);
            // Hashing the `Option` records a serialization failure as its own state
            // instead of silently contributing nothing.
            serde_json::to_string(&tool.parameters)
                .ok()
                .hash(&mut hasher);
        }

        Self::hash_config(self.config.as_ref(), &mut hasher);
        hasher.finish()
    }

    /// Feeds one message (role, content and tool calls) into `hasher`.
    fn hash_message(msg: &Message, hasher: &mut DefaultHasher) {
        msg.role.hash(hasher);
        Self::hash_content(&msg.content, hasher);
        msg.tool_calls.len().hash(hasher);
        for call in &msg.tool_calls {
            call.id.hash(hasher);
            call.name.hash(hasher);
            serde_json::to_string(&call.arguments).ok().hash(hasher);
        }
    }

    /// Feeds message content into `hasher`, tagging it with its variant.
    fn hash_content(content: &crate::state::Content, hasher: &mut DefaultHasher) {
        // The discriminant distinguishes `Text(s)` from `MultiPart([Text { s }])`.
        std::mem::discriminant(content).hash(hasher);
        match content {
            crate::state::Content::Text(text) => text.hash(hasher),
            crate::state::Content::MultiPart(parts) => {
                parts.len().hash(hasher);
                for part in parts {
                    Self::hash_part(part, hasher);
                }
            }
        }
    }

    /// Feeds a single multi-part content element into `hasher`, tagging it with its variant.
    fn hash_part(part: &crate::state::ContentPart, hasher: &mut DefaultHasher) {
        std::mem::discriminant(part).hash(hasher);
        match part {
            crate::state::ContentPart::Text { text } => text.hash(hasher),
            crate::state::ContentPart::Image(image) => {
                image.media_type.hash(hasher);
                // Tag the source kind so a base64 payload and a URL with equal text differ.
                std::mem::discriminant(&image.source).hash(hasher);
                match &image.source {
                    crate::state::ImageSource::Base64(encoded) => encoded.hash(hasher),
                    crate::state::ImageSource::Url(url) => url.hash(hasher),
                }
            }
            crate::state::ContentPart::Thinking { text, signature } => {
                text.hash(hasher);
                signature.hash(hasher);
            }
        }
    }

    /// Feeds the hashed subset of call options into `hasher`.
    ///
    /// Each option is hashed as an `Option`, so its presence is recorded. Without that, an
    /// unset temperature followed by a set `max_tokens` could collide with a set temperature
    /// whose bit pattern equals that token count.
    fn hash_config(config: Option<&CallOptions>, hasher: &mut DefaultHasher) {
        config.is_some().hash(hasher);
        if let Some(config) = config {
            config.temperature.map(|t| t.to_bits()).hash(hasher);
            config.max_tokens.hash(hasher);
            config.top_p.map(|p| p.to_bits()).hash(hasher);
        }
    }
}
/// Controls how cache keys are derived from a [`CacheKeyInput`].
///
/// A default policy has no custom key function. A custom function can be installed by setting
/// `key_func`.
#[derive(Default, Clone)]
#[allow(
    clippy::type_complexity,
    reason = "A boxed trait object is required so callers can supply an arbitrary key function."
)]
pub struct CachePolicy {
    /// Optional custom key generator; `None` means the built-in key format is used.
    pub key_func: Option<Arc<dyn Fn(&CacheKeyInput) -> String + Send + Sync>>,
}

impl std::fmt::Debug for CachePolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn Fn` has no `Debug` impl, so report only whether a custom function is set.
        f.debug_struct("CachePolicy")
            .field("key_func", &self.key_func.as_ref().map(|_| "<fn>"))
            .finish()
    }
}
impl CachePolicy {
    /// Returns a policy without a custom key function.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Installs `f` as the key generator and returns the policy.
    ///
    /// Any previously installed function is replaced.
    #[must_use]
    pub fn with_key_func(mut self, f: Arc<dyn Fn(&CacheKeyInput) -> String + Send + Sync>) -> Self {
        self.key_func = Some(f);
        self
    }

    /// Builds the cache key for `input`.
    ///
    /// If a custom key function is installed, its result is returned unchanged. Otherwise the
    /// key is `"{model}:{hash}"`, where `hash` comes from [`CacheKeyInput::hash`].
    #[must_use]
    pub fn generate_key(&self, input: &CacheKeyInput) -> String {
        match &self.key_func {
            Some(custom) => custom(input),
            None => format!("{}:{}", input.model, input.hash()),
        }
    }
}
/// Sink for counter, histogram and gauge metrics.
///
/// Implementors must be `Send + Sync + 'static`, so one instance can be shared across threads.
pub trait MetricsCollector: Send + Sync + 'static {
    /// Increment the counter named `name` by `value`.
    fn inc_counter(&self, name: &str, value: u64);
    /// Record `value` as an observation for the histogram named `name`.
    fn record_histogram(&self, name: &str, value: f64);
    /// Set the gauge named `name` to `value`.
    fn set_gauge(&self, name: &str, value: u64);
}
/// Hooks for observing graph lifecycle events.
///
/// The events are node start/end/error, graph completion and checkpoint saves. Every hook
/// has a do-nothing default body, so implementors override only the events they need.
pub trait GraphLifecycleCallback: Send + Sync + 'static {
    /// Hook for the start of `node` running task `task_id`. Default: no-op.
    fn on_node_start(&self, node: &str, task_id: &str) {
        // Discard the arguments so the default body has no unused-variable warnings
        // while the signature keeps descriptive parameter names.
        _ = (node, task_id);
    }

    /// Hook for the end of `node` running task `task_id`, taking `duration_ms`. Default: no-op.
    fn on_node_end(&self, node: &str, task_id: &str, duration_ms: u64) {
        _ = (node, task_id, duration_ms);
    }

    /// Hook for `node` failing with `error`. Default: no-op.
    fn on_node_error(&self, node: &str, error: &crate::JunctureError) {
        _ = (node, error);
    }

    /// Hook for the whole graph finishing with `result`. Default: no-op.
    fn on_graph_end(&self, result: &Result<(), crate::JunctureError>) {
        _ = result;
    }

    /// Hook for checkpoint `checkpoint_id` being saved at `step`. Default: no-op.
    fn on_checkpoint_saved(&self, checkpoint_id: &str, step: usize) {
        _ = (checkpoint_id, step);
    }
}