use serde::{Deserialize, Serialize};
/// Capacity limits of a model: its context window and its maximum output size.
///
/// A value of `0` in either field stands for "not known"; see
/// `ModelCapacity::UNKNOWN` and `ModelCapacity::context_window_is_unknown`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelCapacity {
/// Size of the context window (presumably counted in tokens — confirm with the
/// metadata producer); `0` means unknown.
pub context_window: u32,
/// Maximum number of output tokens; `0` when no value is available.
pub max_output_tokens: u32,
}
impl ModelCapacity {
    /// Capacity with both limits set to `0`, the value this type uses for "unknown".
    pub const UNKNOWN: Self = Self {
        context_window: 0,
        max_output_tokens: 0,
    };

    /// Builds a capacity from explicit limits.
    pub fn new(context_window: u32, max_output_tokens: u32) -> Self {
        Self {
            context_window,
            max_output_tokens,
        }
    }

    /// Reads `context_window` and `max_output_tokens` from a JSON metadata entry.
    ///
    /// A field that is missing, or whose value cannot be represented as a `u64`
    /// (for example a string, a float or a negative number), is read as `0`.
    /// Values larger than `u32::MAX` saturate at `u32::MAX` rather than wrapping
    /// around, so an oversized limit can never be mistaken for a small or
    /// unknown one.
    pub fn from_metadata_model_entry(value: &serde_json::Value) -> Self {
        Self::new(
            Self::saturating_u32_field(value, "context_window"),
            Self::saturating_u32_field(value, "max_output_tokens"),
        )
    }

    /// Returns `true` when the context window is `0`, i.e. not known.
    pub fn context_window_is_unknown(&self) -> bool {
        self.context_window == 0
    }

    /// Reads `key` from `value` as an unsigned integer, clamped into the `u32` range.
    fn saturating_u32_field(value: &serde_json::Value, key: &str) -> u32 {
        value
            .get(key)
            .and_then(serde_json::Value::as_u64)
            // Clamp first: after `min` the value fits in `u32`, so the cast is lossless.
            .map_or(0, |raw| raw.min(u64::from(u32::MAX)) as u32)
    }
}
/// Token budget for a request: how much input is allowed and how much is set
/// aside for the output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContextBudget {
/// Maximum number of tokens the input may occupy.
pub max_input_tokens: u32,
/// Number of tokens reserved for the model's output.
pub reserve_output_tokens: u32,
/// NOTE(review): presumably the minimum number of trailing messages to keep
/// regardless of the budget — confirm against the code that consumes it.
pub min_tail_messages: usize,
}
impl ContextBudget {
    /// Input budget used when the model's context window is unknown (`0`).
    pub const DEFAULT_FALLBACK_INPUT: u32 = 24_000;
    /// Output reserve used when the model's maximum output size is unknown (`0`).
    pub const DEFAULT_FALLBACK_OUTPUT_RESERVE: u32 = 4_096;

    /// Builds a budget from explicit values.
    pub fn new(
        max_input_tokens: u32,
        reserve_output_tokens: u32,
        min_tail_messages: usize,
    ) -> Self {
        Self {
            max_input_tokens,
            reserve_output_tokens,
            min_tail_messages,
        }
    }

    /// Derives a budget from a model's capacity.
    ///
    /// The output reserve is `capacity.max_output_tokens`, or
    /// [`Self::DEFAULT_FALLBACK_OUTPUT_RESERVE`] when that is `0`.
    ///
    /// The input budget is chosen as follows:
    /// - context window is `0`: [`Self::DEFAULT_FALLBACK_INPUT`];
    /// - context window is non-zero but not larger than the reserve: the whole
    ///   context window;
    /// - otherwise: the context window minus the reserve.
    pub fn from_capacity(capacity: ModelCapacity, min_tail_messages: usize) -> Self {
        let output_reserve = match capacity.max_output_tokens {
            0 => Self::DEFAULT_FALLBACK_OUTPUT_RESERVE,
            declared => declared,
        };

        let input_budget = match capacity.context_window {
            0 => Self::DEFAULT_FALLBACK_INPUT,
            // Too small to carve the reserve out of: hand the full window to the input.
            window if window <= output_reserve => window,
            window => window - output_reserve,
        };

        Self::new(input_budget, output_reserve, min_tail_messages)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parses the pt075 fixture entry and checks both the parsed capacity and
    /// the input budget derived from it.
    #[test]
    fn from_metadata_model_entry_parses_pt075_fixture() {
        let fixture = json!({
            "context_window": 128000,
            "max_output_tokens": 16384,
            "status": "active"
        });

        let parsed = ModelCapacity::from_metadata_model_entry(&fixture);
        assert_eq!(parsed, ModelCapacity::new(128_000, 16_384));

        let derived = ContextBudget::from_capacity(parsed, 2);
        assert_eq!(derived.max_input_tokens, 128_000 - 16_384);
    }
}