use crate::runtime::inference::context_message::ContextMessage;
use crate::runtime::inference::transform::InferenceRequestTransform;
use crate::runtime::tool_call::ToolDescriptor;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Selects which compaction strategy a [`ContextWindowPolicy`] asks for.
///
/// Variants are (de)serialized under their `snake_case` names, i.e.
/// `"keep_recent_raw_suffix"` and `"compact_to_safe_frontier"`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextCompactionMode {
    /// The default mode.
    ///
    /// NOTE(review): going by the name, this keeps a raw (uncompacted) suffix
    /// of recent messages; the compactor itself is not in this file — confirm.
    #[default]
    KeepRecentRawSuffix,
    /// Alternative mode.
    ///
    /// NOTE(review): presumably compacts everything up to some "safe
    /// frontier"; the definition of that frontier lives elsewhere — confirm.
    CompactToSafeFrontier,
}
/// Fallback value for `ContextWindowPolicy::compaction_raw_suffix_messages`.
///
/// Referenced by name from that field's `#[serde(default = "...")]` attribute
/// and called by `ContextWindowPolicy::default()`, so both paths agree.
const fn default_compaction_raw_suffix_messages() -> usize {
    const RAW_SUFFIX_MESSAGES: usize = 2;
    RAW_SUFFIX_MESSAGES
}
/// Token-budget and compaction settings for a context window.
///
/// When deserializing, the first four fields carry no serde default and must
/// be present. `autocompact_threshold`, `compaction_mode` and
/// `compaction_raw_suffix_messages` fall back to their defaults when absent,
/// so payloads that lack them still load.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextWindowPolicy {
    /// Context token limit; `Default::default()` sets `200_000`.
    pub max_context_tokens: usize,
    /// Output token limit; `Default::default()` sets `16_384`.
    pub max_output_tokens: usize,
    /// Minimum number of recent messages; `Default::default()` sets `10`.
    ///
    /// NOTE(review): how this interacts with compaction is decided by the
    /// consumer of this policy, which is not visible here.
    pub min_recent_messages: usize,
    /// Prompt-cache toggle; `Default::default()` enables it.
    pub enable_prompt_cache: bool,
    /// Optional auto-compaction threshold.
    ///
    /// Missing input deserializes to `None`, and `None` is left out of the
    /// serialized output entirely.
    /// NOTE(review): the unit is presumably tokens — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub autocompact_threshold: Option<usize>,
    /// Compaction strategy; falls back to `ContextCompactionMode::default()`
    /// when the field is absent from the input.
    #[serde(default)]
    pub compaction_mode: ContextCompactionMode,
    /// Raw-suffix message count; falls back to
    /// `default_compaction_raw_suffix_messages()` when absent from the input.
    #[serde(default = "default_compaction_raw_suffix_messages")]
    pub compaction_raw_suffix_messages: usize,
}
impl Default for ContextWindowPolicy {
    /// Returns the baseline policy: 200 000 context tokens, 16 384 output
    /// tokens, 10 recent messages, prompt caching on, no auto-compaction
    /// threshold, and the default compaction mode and raw-suffix count.
    fn default() -> Self {
        // Both compaction fields go through the same sources serde uses for a
        // missing field, so a defaulted policy and a backfilled one agree.
        let compaction_mode = ContextCompactionMode::default();
        let compaction_raw_suffix_messages = default_compaction_raw_suffix_messages();

        Self {
            max_context_tokens: 200_000,
            max_output_tokens: 16_384,
            min_recent_messages: 10,
            enable_prompt_cache: true,
            autocompact_threshold: None,
            compaction_mode,
            compaction_raw_suffix_messages,
        }
    }
}
/// A model choice together with its fallback models.
///
/// Converts into `InferenceOverride` through `From`; in that conversion an
/// empty `fallback_models` list becomes `None`.
///
/// Derives `PartialEq`/`Eq` so values can be compared directly (for example
/// with `assert_eq!`) instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InferenceModelOverride {
    /// Identifier of the model to use.
    pub model: String,
    /// Identifiers of fallback models; may be empty.
    ///
    /// NOTE(review): whether the order expresses priority is decided by the
    /// consumer, which is not visible in this file.
    pub fallback_models: Vec<String>,
}
/// Requested reasoning effort for an inference call.
///
/// Variant names are (de)serialized in `snake_case`. With serde's default
/// externally tagged representation, unit variants become plain strings
/// (e.g. `"high"`) and `Budget(n)` becomes `{"budget": n}` in JSON.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`u32`) has a
/// total equality; this matches `ContextCompactionMode` in this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    /// NOTE(review): presumably disables reasoning — confirm with the provider mapping.
    None,
    /// Low effort level.
    Low,
    /// Medium effort level.
    Medium,
    /// High effort level.
    High,
    /// Maximum effort level.
    Max,
    /// Explicit numeric budget.
    ///
    /// NOTE(review): the unit is presumably tokens — confirm against the code
    /// that translates this into a provider request.
    Budget(u32),
}
/// A sparse set of inference parameter overrides.
///
/// Every field is optional. `InferenceOverride::merge` treats `None` as
/// "no opinion": a `None` in the incoming override leaves the existing value
/// untouched, while `Some` replaces it. `Default` yields an override with all
/// fields `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct InferenceOverride {
    /// Model identifier to use instead of the configured one.
    pub model: Option<String>,
    /// Replacement list of fallback models.
    pub fallback_models: Option<Vec<String>>,
    /// Sampling temperature override.
    pub temperature: Option<f64>,
    /// Token limit override.
    ///
    /// NOTE(review): presumably caps generated tokens — confirm with the consumer.
    pub max_tokens: Option<u32>,
    /// `top_p` sampling override.
    pub top_p: Option<f64>,
    /// Reasoning effort override.
    pub reasoning_effort: Option<ReasoningEffort>,
}
impl InferenceOverride {
pub fn merge(&mut self, other: InferenceOverride) {
if other.model.is_some() {
self.model = other.model;
}
if other.fallback_models.is_some() {
self.fallback_models = other.fallback_models;
}
if other.temperature.is_some() {
self.temperature = other.temperature;
}
if other.max_tokens.is_some() {
self.max_tokens = other.max_tokens;
}
if other.top_p.is_some() {
self.top_p = other.top_p;
}
if other.reasoning_effort.is_some() {
self.reasoning_effort = other.reasoning_effort;
}
}
}
impl From<InferenceModelOverride> for InferenceOverride {
fn from(m: InferenceModelOverride) -> Self {
Self {
model: Some(m.model),
fallback_models: if m.fallback_models.is_empty() {
None
} else {
Some(m.fallback_models)
},
..Default::default()
}
}
}
/// Groups tool descriptors, request transforms, an optional
/// `InferenceOverride`, and context messages into one value.
///
/// `Default` produces empty collections and no override. `Clone` copies the
/// `Arc` handles in `request_transforms`; the transforms themselves are
/// shared. `Debug` is implemented by hand in this file rather than derived.
///
/// NOTE(review): how and when the runtime consumes these fields is defined
/// outside this file.
#[derive(Default, Clone)]
pub struct InferenceContext {
    /// Tool descriptors carried with the context.
    pub tools: Vec<ToolDescriptor>,
    /// Shared, dynamically dispatched request transforms.
    pub request_transforms: Vec<Arc<dyn InferenceRequestTransform>>,
    /// Optional parameter overrides; `None` means no override is attached.
    pub inference_override: Option<InferenceOverride>,
    /// Context messages carried with the context.
    pub context_messages: Vec<ContextMessage>,
}
impl std::fmt::Debug for InferenceContext {
    /// Writes a struct-style debug representation.
    ///
    /// `tools` and `inference_override` are printed in full, whereas
    /// `request_transforms` and `context_messages` are reduced to their
    /// element counts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the sizes of these two collections are reported.
        let transform_count = self.request_transforms.len();
        let message_count = self.context_messages.len();

        let mut out = f.debug_struct("InferenceContext");
        out.field("tools", &self.tools);
        out.field("request_transforms", &transform_count);
        out.field("inference_override", &self.inference_override);
        out.field("context_messages", &message_count);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `Default` selects suffix compaction with a two-message raw suffix.
    #[test]
    fn default_policy_uses_suffix_compaction_defaults() {
        let defaults = ContextWindowPolicy::default();

        assert_eq!(
            defaults.compaction_mode,
            ContextCompactionMode::KeepRecentRawSuffix
        );
        assert_eq!(defaults.compaction_raw_suffix_messages, 2);
    }

    /// A payload without the compaction fields still deserializes and picks
    /// up the serde defaults for them.
    #[test]
    fn policy_deserialization_backfills_new_compaction_fields() {
        let payload_without_compaction = json!({
            "max_context_tokens": 4096,
            "max_output_tokens": 512,
            "min_recent_messages": 4,
            "enable_prompt_cache": false,
            "autocompact_threshold": 2048
        });

        let decoded: ContextWindowPolicy =
            serde_json::from_value(payload_without_compaction).unwrap();

        assert_eq!(
            decoded.compaction_mode,
            ContextCompactionMode::KeepRecentRawSuffix
        );
        assert_eq!(decoded.compaction_raw_suffix_messages, 2);
    }

    /// Non-default compaction settings survive a JSON round trip and use the
    /// `snake_case` wire name for the mode.
    #[test]
    fn policy_serialization_roundtrip_preserves_frontier_mode() {
        let original = ContextWindowPolicy {
            max_context_tokens: 8192,
            max_output_tokens: 1024,
            min_recent_messages: 6,
            enable_prompt_cache: false,
            autocompact_threshold: Some(4096),
            compaction_mode: ContextCompactionMode::CompactToSafeFrontier,
            compaction_raw_suffix_messages: 5,
        };

        let wire = serde_json::to_value(&original).unwrap();
        assert_eq!(wire["compaction_mode"], "compact_to_safe_frontier");
        assert_eq!(wire["compaction_raw_suffix_messages"], 5);

        let decoded: ContextWindowPolicy = serde_json::from_value(wire).unwrap();
        assert_eq!(
            decoded.compaction_mode,
            ContextCompactionMode::CompactToSafeFrontier
        );
        assert_eq!(decoded.compaction_raw_suffix_messages, 5);
    }

    /// Fields set in the incoming override win; fields it leaves unset keep
    /// the base value.
    #[test]
    fn inference_override_merge_last_wins() {
        let mut merged = InferenceOverride {
            model: Some("model-a".into()),
            temperature: Some(0.5),
            ..Default::default()
        };
        let incoming = InferenceOverride {
            model: Some("model-b".into()),
            reasoning_effort: Some(ReasoningEffort::High),
            ..Default::default()
        };

        merged.merge(incoming);

        assert_eq!(merged.model.as_deref(), Some("model-b"));
        assert_eq!(merged.temperature, Some(0.5));
        assert_eq!(merged.reasoning_effort, Some(ReasoningEffort::High));
    }

    /// Merging an all-`None` override must not clear existing values.
    #[test]
    fn inference_override_merge_none_preserves_existing() {
        let mut merged = InferenceOverride {
            max_tokens: Some(1024),
            top_p: Some(0.9),
            ..Default::default()
        };

        merged.merge(InferenceOverride::default());

        assert_eq!(merged.max_tokens, Some(1024));
        assert_eq!(merged.top_p, Some(0.9));
    }

    /// The model and a non-empty fallback list are carried over; unrelated
    /// fields stay unset.
    #[test]
    fn from_model_override_converts_correctly() {
        let source = InferenceModelOverride {
            model: "claude-sonnet".into(),
            fallback_models: vec!["claude-haiku".into()],
        };

        let converted = InferenceOverride::from(source);

        assert_eq!(converted.model.as_deref(), Some("claude-sonnet"));
        assert_eq!(
            converted.fallback_models,
            Some(vec![String::from("claude-haiku")])
        );
        assert!(converted.temperature.is_none());
    }

    /// An empty fallback list converts to `None` rather than `Some(vec![])`.
    #[test]
    fn from_model_override_empty_fallbacks() {
        let source = InferenceModelOverride {
            model: "gpt-4o".into(),
            fallback_models: vec![],
        };

        let converted = InferenceOverride::from(source);

        assert!(converted.fallback_models.is_none());
    }

    /// The payload-carrying variant survives a JSON string round trip.
    #[test]
    fn reasoning_effort_serde_roundtrip() {
        let original = ReasoningEffort::Budget(4096);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ReasoningEffort = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded, original);
    }
}