use std::sync::Arc;
use crate::types::{Conversation, MessageRole, ToolDefinition};
use async_trait::async_trait;
use domain::entities::ToolCallingResult;
use futures::StreamExt;
use tracing::{debug, instrument, warn};
use application::error::ApplicationError;
use application::ports::{
InferenceOverrides, InferencePort, InferenceResult, InferenceStream, StreamingChunk,
};
use crate::optimizer::TokenOptimizer;
use crate::ports::InferencePortSummarizer;
use crate::stream::repetition::RepetitionState;
/// An [`InferencePort`] wrapper that consults a [`TokenOptimizer`] before
/// forwarding requests to the wrapped port.
pub struct TokenOptimizedInferencePort {
/// The wrapped port that performs the actual inference calls.
inner: Arc<dyn InferencePort>,
/// Optimizer used for conversation/tool optimization, override settings and
/// stream monitors.
optimizer: Arc<TokenOptimizer>,
/// Tool-usage state handed to progressive tool optimization. Kept behind a
/// mutex because it is updated from methods that only receive `&self`.
tool_tracker: std::sync::Mutex<crate::tools::progressive::ToolUsageTracker>,
}
impl std::fmt::Debug for TokenOptimizedInferencePort {
    /// Prints only the type name; no fields are listed and the output is
    /// marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TokenOptimizedInferencePort");
        builder.finish_non_exhaustive()
    }
}
impl TokenOptimizedInferencePort {
/// Builds a wrapper around `inner` that uses `optimizer` for its decisions.
///
/// The tool-usage tracker starts out freshly constructed.
#[must_use]
pub fn new(inner: Arc<dyn InferencePort>, optimizer: Arc<TokenOptimizer>) -> Self {
    let tool_tracker = std::sync::Mutex::new(crate::tools::progressive::ToolUsageTracker::new());
    Self {
        inner,
        optimizer,
        tool_tracker,
    }
}
/// Assembles the per-request inference overrides, if any apply.
///
/// Option keys produced:
/// - `num_predict` from `recommended_max_tokens`
/// - `repeat_penalty` from the optimizer config's `frequency_penalty`
/// - `presence_penalty` from the optimizer config's `presence_penalty`
///
/// Returns `None` when none of the three values is set; the `model` override
/// is always left as `None`.
fn build_overrides(&self, recommended_max_tokens: Option<u32>) -> Option<InferenceOverrides> {
    let cfg = self.optimizer.config();

    // Candidate entries in insertion order; unset values are filtered out below.
    let candidates = [
        (
            "num_predict",
            recommended_max_tokens.map(|max_tokens| serde_json::json!(max_tokens)),
        ),
        (
            "repeat_penalty",
            cfg.frequency_penalty.map(|freq| serde_json::json!(freq)),
        ),
        (
            "presence_penalty",
            cfg.presence_penalty.map(|pres| serde_json::json!(pres)),
        ),
    ];

    let options: serde_json::Map<String, serde_json::Value> = candidates
        .into_iter()
        .filter_map(|(key, value)| value.map(|v| (String::from(key), v)))
        .collect();

    if options.is_empty() {
        return None;
    }

    Some(InferenceOverrides {
        model: None,
        options: Some(serde_json::Value::Object(options)),
    })
}
/// Wraps `inner` so that every chunk is fed to a repetition monitor as it
/// streams through.
///
/// If the optimizer does not supply a monitor, `inner` is returned untouched.
/// Otherwise chunks are forwarded as they arrive, with these rules:
/// - a chunk already flagged `done` is forwarded without being fed to the
///   monitor, and ends the stream;
/// - when the monitor reports `Degenerate`, the current chunk is re-emitted
///   with `done: true` and the stream ends;
/// - an error item is forwarded and ends the stream.
fn wrap_stream_with_monitor(&self, inner: InferenceStream) -> InferenceStream {
    let Some(monitor) = self.optimizer.create_stream_monitor() else {
        return inner;
    };

    let initial = RepetitionStreamState {
        inner,
        monitor,
        done: false,
    };

    Box::pin(futures::stream::unfold(initial, |mut st| async move {
        // A previous step already emitted the terminal item.
        if st.done {
            return None;
        }

        let Some(item) = st.inner.next().await else {
            return None;
        };

        let chunk = match item {
            Ok(chunk) => chunk,
            Err(e) => {
                st.done = true;
                return Some((Err(e), st));
            },
        };

        // The final chunk is passed through without consulting the monitor.
        if chunk.done {
            st.done = true;
            return Some((Ok(chunk), st));
        }

        let verdict = st.monitor.feed(&chunk.content);
        let outgoing = match verdict {
            RepetitionState::Normal => chunk,
            RepetitionState::Warning(ratio) => {
                debug!(
                    repetition_ratio = ratio,
                    "Elevated repetition in output stream"
                );
                chunk
            },
            RepetitionState::Degenerate => {
                warn!("Degenerate repetition detected, terminating stream early");
                st.done = true;
                StreamingChunk {
                    content: chunk.content,
                    done: true,
                    model: chunk.model,
                }
            },
        };

        Some((Ok(outgoing), st))
    }))
}
}
/// State threaded through the `unfold` closure in `wrap_stream_with_monitor`.
struct RepetitionStreamState {
/// The upstream inference stream being monitored.
inner: InferenceStream,
/// Detector that receives the content of each non-final chunk.
monitor: crate::stream::repetition::RepetitionDetector,
/// Set once a terminal item (final chunk, early termination, or error) has
/// been emitted, so the next poll ends the stream.
done: bool,
}
#[async_trait]
impl InferencePort for TokenOptimizedInferencePort {
/// Forwards a single-message generation straight to the wrapped port.
///
/// The optimizer is not consulted on this path.
///
/// # Errors
///
/// Propagates any error returned by the wrapped port.
#[instrument(skip(self, message), fields(optimized = true))]
async fn generate(&self, message: &str) -> Result<InferenceResult, ApplicationError> {
    let delegate = self.inner.as_ref();
    delegate.generate(message).await
}
/// Generates a response for `conversation`, optimizing a copy of it first.
///
/// - With the optimizer disabled, the call is forwarded unchanged.
/// - If optimization succeeds, the optimized copy is sent, using overrides
///   when `build_overrides` yields any, and the actual token usage (when the
///   result carries it) is reported back to the optimizer.
/// - If optimization fails, a warning is logged and the original
///   conversation is sent instead.
///
/// # Errors
///
/// Propagates any error returned by the wrapped port.
#[instrument(skip(self, conversation), fields(optimized = true))]
async fn generate_with_context(
    &self,
    conversation: &Conversation,
) -> Result<InferenceResult, ApplicationError> {
    if !self.optimizer.is_enabled() {
        return self.inner.generate_with_context(conversation).await;
    }

    let mut working_copy = conversation.clone();
    let summarizer = InferencePortSummarizer(self.inner.as_ref());
    let outcome = self
        .optimizer
        .optimize_conversation(&mut working_copy, Some(&summarizer))
        .await;

    let report = match outcome {
        Ok(report) => report,
        Err(e) => {
            warn!(error = %e, "Conversation optimization failed, using original");
            return self.inner.generate_with_context(conversation).await;
        },
    };

    let response = match self.build_overrides(report.recommended_max_tokens) {
        Some(overrides) => {
            debug!(overrides = ?overrides.options, "Applying inference overrides");
            self.inner
                .generate_with_context_and_overrides(&working_copy, &overrides)
                .await?
        },
        None => self.inner.generate_with_context(&working_copy).await?,
    };

    // Report the real usage next to the post-optimization estimate.
    if let Some(actual) = response.tokens_used {
        self.optimizer.report_actual_tokens(
            &response.model,
            report.estimate_after.total,
            actual,
        );
    }

    Ok(response)
}
/// Generates a response for `message` under `system_prompt` by delegating to
/// the wrapped port.
///
/// No optimization is applied on this path: the request is forwarded
/// unchanged whether or not the optimizer is enabled, so no enabled-check is
/// performed here.
///
/// # Errors
///
/// Propagates any error returned by the wrapped port.
#[instrument(skip(self, system_prompt, message), fields(optimized = true))]
async fn generate_with_system(
    &self,
    system_prompt: &str,
    message: &str,
) -> Result<InferenceResult, ApplicationError> {
    self.inner
        .generate_with_system(system_prompt, message)
        .await
}
/// Opens a streaming generation for `message` on the wrapped port.
///
/// When the optimizer is enabled the resulting stream is passed through
/// `wrap_stream_with_monitor`; otherwise it is returned as-is.
///
/// # Errors
///
/// Propagates the wrapped port's error if the stream cannot be opened.
#[instrument(skip(self, message), fields(optimized = true))]
async fn generate_stream(&self, message: &str) -> Result<InferenceStream, ApplicationError> {
    let raw = self.inner.generate_stream(message).await?;
    if !self.optimizer.is_enabled() {
        return Ok(raw);
    }
    Ok(self.wrap_stream_with_monitor(raw))
}
/// Opens a streaming generation for `message` under `system_prompt` on the
/// wrapped port.
///
/// When the optimizer is enabled the resulting stream is passed through
/// `wrap_stream_with_monitor`; otherwise it is returned as-is.
///
/// # Errors
///
/// Propagates the wrapped port's error if the stream cannot be opened.
#[instrument(skip(self, system_prompt, message), fields(optimized = true))]
async fn generate_stream_with_system(
    &self,
    system_prompt: &str,
    message: &str,
) -> Result<InferenceStream, ApplicationError> {
    let raw = self
        .inner
        .generate_stream_with_system(system_prompt, message)
        .await?;
    if !self.optimizer.is_enabled() {
        return Ok(raw);
    }
    Ok(self.wrap_stream_with_monitor(raw))
}
/// Reports the health of the wrapped port.
async fn is_healthy(&self) -> bool {
    let delegate = self.inner.as_ref();
    delegate.is_healthy().await
}
/// Returns the model name reported by the wrapped port.
fn current_model(&self) -> String {
    let delegate = self.inner.as_ref();
    delegate.current_model()
}
/// Lists the models reported by the wrapped port.
///
/// # Errors
///
/// Propagates any error returned by the wrapped port.
async fn list_available_models(&self) -> Result<Vec<String>, ApplicationError> {
    let delegate = self.inner.as_ref();
    delegate.list_available_models().await
}
/// Asks the wrapped port to switch to `model_name`.
///
/// # Errors
///
/// Propagates any error returned by the wrapped port.
async fn switch_model(&self, model_name: &str) -> Result<(), ApplicationError> {
    let delegate = self.inner.as_ref();
    delegate.switch_model(model_name).await
}
/// Generates a tool-calling response after optimizing both the tool list and
/// a copy of the conversation.
///
/// - With the optimizer disabled, the call is forwarded unchanged.
/// - Otherwise the most recent user message (or an empty string when there
///   is none) is passed to the optimizer together with `tools`, using the
///   progressive variant when `progressive_tool_compression` is set.
/// - If conversation optimization succeeds, the optimized tools are recorded
///   as seen (progressive mode only) and the wrapped port is called with the
///   optimized inputs, using overrides when `build_overrides` yields any.
/// - If conversation optimization fails, a warning is logged and the original
///   conversation and tools are sent instead.
///
/// # Errors
///
/// Propagates any error returned by the wrapped port.
async fn generate_with_tools(
    &self,
    conversation: &Conversation,
    tools: &[ToolDefinition],
) -> Result<ToolCallingResult, ApplicationError> {
    if !self.optimizer.is_enabled() {
        return self.inner.generate_with_tools(conversation, tools).await;
    }
    let mut optimized_conv = conversation.clone();
    let last_user_message = conversation
        .messages
        .iter()
        .rev()
        .find(|m| m.role == MessageRole::User)
        .map_or("", |m| m.content.as_str());
    // Read the flag once so the optimize and mark-seen steps always agree.
    let progressive = self.optimizer.config().progressive_tool_compression;
    let optimized_tools = if progressive {
        // The guard lives only inside this block, so it is released before
        // any `.await` below.
        let tracker = self
            .tool_tracker
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        self.optimizer
            .optimize_tools_progressive(last_user_message, tools, &tracker)
    } else {
        self.optimizer.optimize_tools(last_user_message, tools)
    };
    let summarizer = InferencePortSummarizer(self.inner.as_ref());
    match self
        .optimizer
        .optimize_conversation_with_tools(
            &mut optimized_conv,
            &optimized_tools,
            Some(&summarizer),
        )
        .await
    {
        Ok(result) => {
            if progressive {
                // Recover from a poisoned lock the same way the read path
                // above does, instead of silently skipping the update.
                let mut tracker = self
                    .tool_tracker
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                tracker.mark_seen(&optimized_tools);
            }
            if let Some(overrides) = self.build_overrides(result.recommended_max_tokens) {
                debug!(overrides = ?overrides.options, "Applying inference overrides (tools)");
                self.inner
                    .generate_with_tools_and_overrides(
                        &optimized_conv,
                        &optimized_tools,
                        &overrides,
                    )
                    .await
            } else {
                self.inner
                    .generate_with_tools(&optimized_conv, &optimized_tools)
                    .await
            }
        },
        Err(e) => {
            warn!(error = %e, "Tool optimization failed, using original");
            self.inner.generate_with_tools(conversation, tools).await
        },
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use application::ports::InferenceResult;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::config::TokenOptimizationConfig;
/// Test double for [`InferencePort`] that returns canned responses.
struct MockInference {
/// Model name echoed in every result and returned by `current_model`.
model: String,
/// Value returned by `is_healthy`.
healthy: AtomicBool,
}
impl MockInference {
    /// Creates a healthy mock whose model is named `test-model`.
    fn new() -> Self {
        let model = String::from("test-model");
        let healthy = AtomicBool::new(true);
        Self { model, healthy }
    }
}
#[async_trait]
impl InferencePort for MockInference {
    /// Echoes the prompt inside a fixed response template.
    async fn generate(&self, message: &str) -> Result<InferenceResult, ApplicationError> {
        let content = format!("Response to: {message}");
        Ok(InferenceResult {
            content,
            model: self.model.clone(),
            tokens_used: Some(10),
            latency_ms: 100,
        })
    }

    /// Reports how many messages the conversation contained.
    async fn generate_with_context(
        &self,
        conversation: &Conversation,
    ) -> Result<InferenceResult, ApplicationError> {
        let msg_count = conversation.messages.len();
        let content = format!("Context response ({msg_count} messages)");
        Ok(InferenceResult {
            content,
            model: self.model.clone(),
            tokens_used: Some(20),
            latency_ms: 200,
        })
    }

    /// Echoes the message; the system prompt is ignored.
    async fn generate_with_system(
        &self,
        _system_prompt: &str,
        message: &str,
    ) -> Result<InferenceResult, ApplicationError> {
        let content = format!("System response to: {message}");
        Ok(InferenceResult {
            content,
            model: self.model.clone(),
            tokens_used: Some(15),
            latency_ms: 150,
        })
    }

    /// Yields two fixed chunks, the second one flagged as final.
    async fn generate_stream(
        &self,
        _message: &str,
    ) -> Result<InferenceStream, ApplicationError> {
        let opening = StreamingChunk {
            content: "Hello ".to_string(),
            done: false,
            model: None,
        };
        let closing = StreamingChunk {
            content: "world!".to_string(),
            done: true,
            model: Some(self.model.clone()),
        };
        let chunks = vec![Ok(opening), Ok(closing)];
        Ok(Box::pin(futures::stream::iter(chunks)))
    }

    /// Same fixed stream as `generate_stream`; both arguments are ignored.
    async fn generate_stream_with_system(
        &self,
        _system_prompt: &str,
        _message: &str,
    ) -> Result<InferenceStream, ApplicationError> {
        let stream = self.generate_stream("").await?;
        Ok(stream)
    }

    /// Returns the current value of the `healthy` flag.
    async fn is_healthy(&self) -> bool {
        self.healthy.load(Ordering::Relaxed)
    }

    /// Returns a copy of the configured model name.
    fn current_model(&self) -> String {
        self.model.clone()
    }

    /// Returns a one-element list holding the configured model name.
    async fn list_available_models(&self) -> Result<Vec<String>, ApplicationError> {
        let only_model = self.model.clone();
        Ok(vec![only_model])
    }

    /// Accepts any model name and does nothing.
    async fn switch_model(&self, _model_name: &str) -> Result<(), ApplicationError> {
        Ok(())
    }
}
/// Builds a decorator over a fresh `MockInference` with the default
/// optimization config.
fn create_decorator() -> TokenOptimizedInferencePort {
    let config = TokenOptimizationConfig::default();
    let optimizer = Arc::new(TokenOptimizer::new(config));
    let inner = Arc::new(MockInference::new());
    TokenOptimizedInferencePort::new(inner, optimizer)
}
/// Builds a decorator over a fresh `MockInference` whose config has
/// `enabled` set to `false` and every other field at its default.
fn create_disabled_decorator() -> TokenOptimizedInferencePort {
    let disabled = TokenOptimizationConfig {
        enabled: false,
        ..Default::default()
    };
    let optimizer = Arc::new(TokenOptimizer::new(disabled));
    let inner = Arc::new(MockInference::new());
    TokenOptimizedInferencePort::new(inner, optimizer)
}
/// `generate` should reach the mock and echo the prompt back.
#[tokio::test]
async fn generate_passes_through() {
    let port = create_decorator();
    let outcome = port.generate("test").await;
    assert!(outcome.is_ok());
    let response = outcome.expect("should succeed");
    assert!(response.content.contains("test"));
}
/// A two-message conversation should succeed with the optimizer enabled.
#[tokio::test]
async fn generate_with_context_optimizes() {
    let port = create_decorator();

    let mut conversation = Conversation::new();
    conversation.add_user_message("Hello");
    conversation.add_assistant_message("Hi!");

    let outcome = port.generate_with_context(&conversation).await;
    assert!(outcome.is_ok());
}
/// With the optimizer disabled, a context request should still succeed.
#[tokio::test]
async fn disabled_optimizer_passes_through() {
    let port = create_disabled_decorator();

    let mut conversation = Conversation::new();
    conversation.add_user_message("Hello");

    let outcome = port.generate_with_context(&conversation).await;
    assert!(outcome.is_ok());
}
/// The decorated stream should yield at least one chunk, all of them `Ok`.
#[tokio::test]
async fn stream_wraps_with_monitor() {
    let port = create_decorator();

    let opened = port.generate_stream("test").await;
    assert!(opened.is_ok());
    let stream = opened.expect("should succeed");

    // Drain the stream, panicking on the first failed chunk.
    let chunks: Vec<_> = stream
        .map(|item| item.expect("chunk should be ok"))
        .collect()
        .await;
    assert!(!chunks.is_empty());
}
/// Health, model name, model listing and model switching should all be
/// answered by the wrapped mock.
#[tokio::test]
async fn passthrough_methods_work() {
    let port = create_decorator();

    let healthy = port.is_healthy().await;
    assert!(healthy);

    assert_eq!(port.current_model(), "test-model");

    let listed = port.list_available_models().await;
    assert!(listed.is_ok());

    let switched = port.switch_model("other").await;
    assert!(switched.is_ok());
}
/// Test double for [`InferencePort`] that records the overrides it receives.
struct OverrideTrackingMock {
/// Set to `true` once `generate_with_context_and_overrides` has been called.
overrides_called: AtomicBool,
/// Last `num_predict` value seen in the override options, if any.
last_num_predict: std::sync::Mutex<Option<u64>>,
/// Last override options object received, if any.
last_options: std::sync::Mutex<Option<serde_json::Value>>,
}
impl OverrideTrackingMock {
    /// Creates a mock that has not recorded any override call yet.
    fn new() -> Self {
        let overrides_called = AtomicBool::new(false);
        let last_num_predict = std::sync::Mutex::new(None);
        let last_options = std::sync::Mutex::new(None);
        Self {
            overrides_called,
            last_num_predict,
            last_options,
        }
    }
}
#[async_trait]
impl InferencePort for OverrideTrackingMock {
    /// Returns an empty canned result.
    async fn generate(&self, _: &str) -> Result<InferenceResult, ApplicationError> {
        Ok(InferenceResult {
            content: String::new(),
            model: String::from("test"),
            tokens_used: Some(10),
            latency_ms: 50,
        })
    }

    /// Returns a result tagged `no-override`.
    async fn generate_with_context(
        &self,
        _: &Conversation,
    ) -> Result<InferenceResult, ApplicationError> {
        Ok(InferenceResult {
            content: String::from("no-override"),
            model: String::from("test"),
            tokens_used: Some(20),
            latency_ms: 100,
        })
    }

    /// Records the call and the received options, then returns a result
    /// tagged `with-override`.
    async fn generate_with_context_and_overrides(
        &self,
        _conversation: &Conversation,
        overrides: &InferenceOverrides,
    ) -> Result<InferenceResult, ApplicationError> {
        self.overrides_called.store(true, Ordering::SeqCst);

        if let Some(opts) = overrides.options.as_ref() {
            {
                let mut slot = self
                    .last_options
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                *slot = Some(opts.clone());
            }

            // Only overwrite the stored value when `num_predict` is present.
            let requested = opts.get("num_predict").and_then(serde_json::Value::as_u64);
            if let Some(np) = requested {
                let mut slot = self
                    .last_num_predict
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                *slot = Some(np);
            }
        }

        Ok(InferenceResult {
            content: String::from("with-override"),
            model: String::from("test"),
            tokens_used: Some(20),
            latency_ms: 100,
        })
    }

    /// Returns an empty canned result.
    async fn generate_with_system(
        &self,
        _: &str,
        _: &str,
    ) -> Result<InferenceResult, ApplicationError> {
        Ok(InferenceResult {
            content: String::new(),
            model: String::from("test"),
            tokens_used: Some(10),
            latency_ms: 50,
        })
    }

    /// Returns a stream with no items.
    async fn generate_stream(&self, _: &str) -> Result<InferenceStream, ApplicationError> {
        let nothing = futures::stream::empty();
        Ok(Box::pin(nothing))
    }

    /// Returns a stream with no items.
    async fn generate_stream_with_system(
        &self,
        _: &str,
        _: &str,
    ) -> Result<InferenceStream, ApplicationError> {
        let nothing = futures::stream::empty();
        Ok(Box::pin(nothing))
    }

    /// Always healthy.
    async fn is_healthy(&self) -> bool {
        true
    }

    /// Always `test`.
    fn current_model(&self) -> String {
        String::from("test")
    }

    /// Returns a one-element list containing `test`.
    async fn list_available_models(&self) -> Result<Vec<String>, ApplicationError> {
        Ok(vec![String::from("test")])
    }

    /// Accepts any model name and does nothing.
    async fn switch_model(&self, _: &str) -> Result<(), ApplicationError> {
        Ok(())
    }
}
/// With the default config, no `num_predict` override should be recorded.
#[tokio::test]
async fn generate_with_context_no_output_budget_by_default() {
    let mock = Arc::new(OverrideTrackingMock::new());
    let optimizer = Arc::new(TokenOptimizer::new(TokenOptimizationConfig::default()));
    let port = TokenOptimizedInferencePort::new(mock.clone(), optimizer);

    let mut conversation = Conversation::new();
    conversation.add_user_message("Tell me about Rust programming language.");

    let outcome = port.generate_with_context(&conversation).await;
    assert!(outcome.is_ok());

    // Copy the recorded value out so the lock is released immediately.
    let recorded: Option<u64> = *mock
        .last_num_predict
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert!(
        recorded.is_none(),
        "Expected no num_predict override by default"
    );
}
/// With `output_max_tokens` set to 128, overrides should be used and the
/// recorded `num_predict` should not exceed that cap.
#[tokio::test]
async fn output_budget_respects_config_cap() {
    let mock = Arc::new(OverrideTrackingMock::new());
    let capped = TokenOptimizationConfig {
        output_max_tokens: Some(128),
        ..Default::default()
    };
    let optimizer = Arc::new(TokenOptimizer::new(capped));
    let port = TokenOptimizedInferencePort::new(mock.clone(), optimizer);

    let mut conversation = Conversation::new();
    conversation.add_user_message("Explain quantum entanglement in great detail.");

    let outcome = port.generate_with_context(&conversation).await;
    assert!(outcome.is_ok());
    assert!(mock.overrides_called.load(Ordering::SeqCst));

    // Copy the recorded value out so the lock is released immediately.
    let recorded: Option<u64> = *mock
        .last_num_predict
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert!(recorded.is_some());
    let num_predict = recorded.expect("checked above");
    assert!(
        num_predict <= 128,
        "num_predict should be capped by output_max_tokens"
    );
}
/// Configured frequency and presence penalties should arrive in the override
/// options as `repeat_penalty` and `presence_penalty`.
#[tokio::test]
async fn sampling_params_included_in_overrides() {
    let mock = Arc::new(OverrideTrackingMock::new());
    let with_penalties = TokenOptimizationConfig {
        frequency_penalty: Some(1.2),
        presence_penalty: Some(0.6),
        ..Default::default()
    };
    let optimizer = Arc::new(TokenOptimizer::new(with_penalties));
    let port = TokenOptimizedInferencePort::new(mock.clone(), optimizer);

    let mut conversation = Conversation::new();
    conversation.add_user_message("Hello");

    let outcome = port.generate_with_context(&conversation).await;
    assert!(outcome.is_ok());
    assert!(
        mock.overrides_called.load(Ordering::SeqCst),
        "Expected overrides to be called when sampling params are set"
    );

    let guard = mock
        .last_options
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let opts = guard.as_ref().expect("options should be present");

    // Shared lookup for the two float-valued options checked below.
    let float_option = |key: &str| opts.get(key).and_then(serde_json::Value::as_f64);

    let repeat_penalty = float_option("repeat_penalty").expect("repeat_penalty should be set");
    assert!(
        (repeat_penalty - 1.2).abs() < 0.001,
        "repeat_penalty should be ~1.2, got {repeat_penalty}"
    );

    let presence_penalty =
        float_option("presence_penalty").expect("presence_penalty should be set");
    assert!(
        (presence_penalty - 0.6).abs() < 0.001,
        "presence_penalty should be ~0.6, got {presence_penalty}"
    );
}
/// A disabled optimizer should never route through the override-aware call.
#[tokio::test]
async fn no_overrides_when_nothing_configured() {
    let mock = Arc::new(OverrideTrackingMock::new());
    let disabled = TokenOptimizationConfig {
        enabled: false,
        ..Default::default()
    };
    let optimizer = Arc::new(TokenOptimizer::new(disabled));
    let port = TokenOptimizedInferencePort::new(mock.clone(), optimizer);

    let mut conversation = Conversation::new();
    conversation.add_user_message("Hello");

    let outcome = port.generate_with_context(&conversation).await;
    assert!(outcome.is_ok());

    let used_overrides = mock.overrides_called.load(Ordering::SeqCst);
    assert!(
        !used_overrides,
        "Disabled optimizer should not use overrides"
    );
}
}