use serde::{Deserialize, Serialize};
/// Display mode for thinking output carried by a [`ThinkingConfig`].
///
/// Serialized as a lowercase string: `"summarized"` or `"omitted"`.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingDisplay {
    /// Wire value `"summarized"`.
    #[serde(rename = "summarized")]
    Summarized,
    /// Wire value `"omitted"`.
    #[serde(rename = "omitted")]
    Omitted,
}
/// Thinking configuration, serialized as an internally tagged JSON object.
///
/// The `type` field selects the variant (`"enabled"`, `"adaptive"`, or `"disabled"`). Any
/// variant fields sit alongside it in the same object. An unset (`None`) `display` is left
/// out of the serialized output entirely.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "type")]
pub enum ThinkingConfig {
    /// `{"type": "enabled", "budget_tokens": N}`, plus an optional `display`.
    #[serde(rename = "enabled")]
    Enabled {
        /// Explicit token budget; this is the value `num_tokens` reports.
        budget_tokens: u32,
        /// Optional display mode; skipped when serializing if `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        display: Option<ThinkingDisplay>,
    },
    /// `{"type": "adaptive"}`, plus an optional `display`; carries no token budget.
    #[serde(rename = "adaptive")]
    Adaptive {
        /// Optional display mode; skipped when serializing if `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        display: Option<ThinkingDisplay>,
    },
    /// `{"type": "disabled"}`.
    #[serde(rename = "disabled")]
    Disabled,
}
impl ThinkingConfig {
pub fn num_tokens(&self) -> u32 {
match self {
ThinkingConfig::Enabled { budget_tokens, .. } => *budget_tokens,
ThinkingConfig::Adaptive { .. } | ThinkingConfig::Disabled => 0,
}
}
pub fn enabled(budget_tokens: u32) -> Self {
Self::Enabled { budget_tokens, display: None }
}
pub fn adaptive() -> Self {
Self::Adaptive { display: None }
}
pub fn disabled() -> Self {
Self::Disabled
}
}
impl Default for ThinkingConfig {
fn default() -> Self {
Self::disabled()
}
}
#[cfg(test)]
mod tests {
    // Wire-format tests: each case pins the exact JSON produced or accepted via serde.
    use super::*;
    use serde_json::{json, to_value};

    #[test]
    fn enabled_serialization() {
        let config = ThinkingConfig::enabled(2048);
        assert_eq!(to_value(config).unwrap(), json!({"type": "enabled", "budget_tokens": 2048}));
    }

    #[test]
    fn enabled_with_display() {
        let config = ThinkingConfig::Enabled {
            budget_tokens: 4096,
            display: Some(ThinkingDisplay::Omitted),
        };
        assert_eq!(
            to_value(config).unwrap(),
            json!({"type": "enabled", "budget_tokens": 4096, "display": "omitted"})
        );
    }

    #[test]
    fn adaptive_serialization() {
        let config = ThinkingConfig::adaptive();
        assert_eq!(to_value(config).unwrap(), json!({"type": "adaptive"}));
    }

    #[test]
    fn adaptive_with_display() {
        let config = ThinkingConfig::Adaptive { display: Some(ThinkingDisplay::Omitted) };
        assert_eq!(to_value(config).unwrap(), json!({"type": "adaptive", "display": "omitted"}));
    }

    #[test]
    fn disabled_serialization() {
        assert_eq!(to_value(ThinkingConfig::disabled()).unwrap(), json!({"type": "disabled"}));
    }

    #[test]
    fn display_wire_names() {
        // Both variants' lowercase wire names, including the otherwise-untested "summarized".
        assert_eq!(to_value(ThinkingDisplay::Summarized).unwrap(), json!("summarized"));
        assert_eq!(to_value(ThinkingDisplay::Omitted).unwrap(), json!("omitted"));
    }

    #[test]
    fn enabled_deserialization() {
        let config: ThinkingConfig =
            serde_json::from_value(json!({"type": "enabled", "budget_tokens": 2048})).unwrap();
        assert_eq!(config, ThinkingConfig::Enabled { budget_tokens: 2048, display: None });
    }

    #[test]
    fn enabled_with_display_deserialization() {
        let config: ThinkingConfig = serde_json::from_value(json!({
            "type": "enabled",
            "budget_tokens": 1024,
            "display": "summarized"
        }))
        .unwrap();
        assert_eq!(
            config,
            ThinkingConfig::Enabled {
                budget_tokens: 1024,
                display: Some(ThinkingDisplay::Summarized),
            }
        );
    }

    #[test]
    fn adaptive_deserialization() {
        let config: ThinkingConfig =
            serde_json::from_value(json!({"type": "adaptive", "display": "omitted"})).unwrap();
        assert_eq!(config, ThinkingConfig::Adaptive { display: Some(ThinkingDisplay::Omitted) });
    }

    #[test]
    fn adaptive_without_display_deserialization() {
        // A missing `display` key must deserialize to `None`.
        let config: ThinkingConfig = serde_json::from_value(json!({"type": "adaptive"})).unwrap();
        assert_eq!(config, ThinkingConfig::adaptive());
    }

    #[test]
    fn disabled_deserialization() {
        let config: ThinkingConfig = serde_json::from_value(json!({"type": "disabled"})).unwrap();
        assert_eq!(config, ThinkingConfig::Disabled);
    }

    #[test]
    fn round_trip() {
        let configs = [
            ThinkingConfig::enabled(512),
            ThinkingConfig::Enabled {
                budget_tokens: 512,
                display: Some(ThinkingDisplay::Summarized),
            },
            ThinkingConfig::adaptive(),
            ThinkingConfig::Adaptive { display: Some(ThinkingDisplay::Omitted) },
            ThinkingConfig::disabled(),
        ];
        for config in &configs {
            let value = to_value(config).unwrap();
            let back: ThinkingConfig = serde_json::from_value(value).unwrap();
            assert_eq!(&back, config);
        }
    }

    #[test]
    fn rejects_invalid_input() {
        // Unknown `type` tag.
        assert!(serde_json::from_value::<ThinkingConfig>(json!({"type": "bogus"})).is_err());
        // `enabled` requires `budget_tokens`.
        assert!(serde_json::from_value::<ThinkingConfig>(json!({"type": "enabled"})).is_err());
        // Display names are lowercase only.
        assert!(serde_json::from_value::<ThinkingConfig>(
            json!({"type": "adaptive", "display": "Omitted"})
        )
        .is_err());
    }

    #[test]
    fn default_is_disabled() {
        assert_eq!(ThinkingConfig::default(), ThinkingConfig::Disabled);
    }

    #[test]
    fn num_tokens() {
        assert_eq!(ThinkingConfig::enabled(8192).num_tokens(), 8192);
        assert_eq!(ThinkingConfig::adaptive().num_tokens(), 0);
        assert_eq!(ThinkingConfig::disabled().num_tokens(), 0);
    }
}