use crate::runtime::tool_call::ToolDescriptor;
use crate::thread::Message;
use std::sync::Arc;
/// Result returned by an [`InferenceRequestTransform`] and by
/// [`apply_request_transforms`].
pub struct InferenceTransformOutput {
/// The message list after transformation.
pub messages: Vec<Message>,
/// Whether prompt caching was requested. [`apply_request_transforms`] ORs
/// this flag across every transform it runs.
/// NOTE(review): how callers act on this flag is not visible in this file.
pub enable_prompt_cache: bool,
}
/// A single step that rewrites the message list of an inference request.
///
/// Implementations must be `Send + Sync`; [`apply_request_transforms`]
/// receives them as `Arc<dyn InferenceRequestTransform>`.
/// NOTE(review): presumably the bound exists so transforms can be shared
/// across threads — confirm against the callers.
pub trait InferenceRequestTransform: Send + Sync {
/// Consumes `messages` and returns the transformed list together with this
/// transform's own `enable_prompt_cache` request.
///
/// `tool_descriptors` is provided read-only; when run through
/// [`apply_request_transforms`], every transform receives the same slice.
fn transform(
&self,
messages: Vec<Message>,
tool_descriptors: &[ToolDescriptor],
) -> InferenceTransformOutput;
}
/// Runs `transforms` in slice order as a pipeline over `messages`.
///
/// Each transform receives the message list produced by the previous one
/// (the first receives `messages` itself) together with the same
/// `tool_descriptors` slice. The returned `enable_prompt_cache` is `true` if
/// at least one transform requested it. With an empty `transforms` slice the
/// messages are returned untouched and the flag is `false`.
pub fn apply_request_transforms(
    messages: Vec<Message>,
    tool_descriptors: &[ToolDescriptor],
    transforms: &[Arc<dyn InferenceRequestTransform>],
) -> InferenceTransformOutput {
    // Seed the fold with the untouched input and no caching request.
    let initial = InferenceTransformOutput {
        messages,
        enable_prompt_cache: false,
    };

    transforms.iter().fold(initial, |so_far, stage| {
        let step = stage.transform(so_far.messages, tool_descriptors);
        InferenceTransformOutput {
            messages: step.messages,
            // Once any stage asks for caching, the request sticks.
            enable_prompt_cache: so_far.enable_prompt_cache | step.enable_prompt_cache,
        }
    })
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`apply_request_transforms`]: pass-through with no
    //! transforms, ordering/chaining of transforms, and OR-aggregation of the
    //! `enable_prompt_cache` flag.

    use super::*;
    use crate::thread::Message;

    /// Inserts a system message with the given text at the front of the list.
    /// Never requests prompt caching.
    struct PrependSystem(String);

    impl InferenceRequestTransform for PrependSystem {
        fn transform(
            &self,
            mut messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            messages.insert(0, Message::system(&self.0));
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: false,
            }
        }
    }

    /// Leaves the messages untouched and requests prompt caching.
    struct EnableCache;

    impl InferenceRequestTransform for EnableCache {
        fn transform(
            &self,
            messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: true,
            }
        }
    }

    /// Keeps at most the first `self.0` messages. Never requests prompt
    /// caching.
    struct LimitMessages(usize);

    impl InferenceRequestTransform for LimitMessages {
        fn transform(
            &self,
            mut messages: Vec<Message>,
            _tool_descriptors: &[ToolDescriptor],
        ) -> InferenceTransformOutput {
            // Truncate in place instead of rebuilding the vector.
            messages.truncate(self.0);
            InferenceTransformOutput {
                messages,
                enable_prompt_cache: false,
            }
        }
    }

    #[test]
    fn empty_transforms_is_passthrough() {
        let messages = vec![Message::user("Hello"), Message::assistant("Hi")];
        // `messages` is not used afterwards, so it is moved rather than cloned.
        let output = apply_request_transforms(messages, &[], &[]);
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "Hello");
        assert_eq!(output.messages[1].content, "Hi");
        assert!(!output.enable_prompt_cache);
    }

    #[test]
    fn single_transform_applied() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> =
            vec![Arc::new(PrependSystem("System prompt".into()))];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "System prompt");
        assert_eq!(output.messages[1].content, "Hello");
    }

    #[test]
    fn transforms_chain_pipes_output_to_next() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("First".into())),
            Arc::new(PrependSystem("Second".into())),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        // The second transform prepends to the first one's output, so its
        // message ends up in front.
        assert_eq!(output.messages.len(), 3);
        assert_eq!(output.messages[0].content, "Second");
        assert_eq!(output.messages[1].content, "First");
        assert_eq!(output.messages[2].content, "Hello");
    }

    #[test]
    fn enable_prompt_cache_or_aggregated() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("sys".into())),
            Arc::new(EnableCache),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert!(
            output.enable_prompt_cache,
            "should be true via OR aggregation"
        );
    }

    #[test]
    fn enable_prompt_cache_survives_later_false() {
        // With the cache-enabling transform last, a "last output wins"
        // aggregation would also pass. Putting it first, followed by a
        // transform that reports `false`, distinguishes OR from overwrite.
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(EnableCache),
            Arc::new(PrependSystem("sys".into())),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert!(
            output.enable_prompt_cache,
            "an earlier request must not be cleared by a later transform"
        );
        // The later transform still ran on the messages.
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "sys");
        assert_eq!(output.messages[1].content, "Hello");
    }

    #[test]
    fn enable_prompt_cache_stays_false_when_none_request() {
        let messages = vec![Message::user("Hello")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("a".into())),
            Arc::new(PrependSystem("b".into())),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        assert!(!output.enable_prompt_cache);
    }

    #[test]
    fn chain_with_limiting_transform() {
        let messages = vec![Message::user("1"), Message::user("2"), Message::user("3")];
        let transforms: Vec<Arc<dyn InferenceRequestTransform>> = vec![
            Arc::new(PrependSystem("sys".into())),
            Arc::new(LimitMessages(2)),
        ];
        let output = apply_request_transforms(messages, &[], &transforms);
        // The limit runs after the prepend, so the system message counts
        // towards the two kept messages.
        assert_eq!(output.messages.len(), 2);
        assert_eq!(output.messages[0].content, "sys");
        assert_eq!(output.messages[1].content, "1");
    }
}