use serde::{Deserialize, Serialize};
use crate::schema::ModelCapability;
/// Coarse task category a caller can attach to an [`IntentHint`].
///
/// Serialized in `snake_case` (e.g. `"code"`). Derives `Hash` like the other
/// routing enums in this module so it can key hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskHint {
    Chat,
    Classify,
    Reasoning,
    Code,
}
/// Optional hints a caller can attach to describe what it wants from a model.
///
/// Every field has a serde default and is omitted when serializing at that
/// default. So `IntentHint::default()` serializes to `{}`, and `{}`
/// deserializes back to the default.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct IntentHint {
    /// Coarse task category, if the caller supplied one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub task: Option<TaskHint>,
    /// Capabilities the caller explicitly asks for; empty when none were given.
    /// NOTE(review): whether the router treats these as hard filters is not
    /// visible in this module — confirm at the call site.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub require: Vec<ModelCapability>,
    /// Caller preference for local models; `false` when unset.
    /// NOTE(review): presumably a soft preference rather than a filter — confirm.
    #[serde(default, skip_serializing_if = "is_false")]
    pub prefer_local: bool,
    /// Caller preference for faster models; `false` when unset.
    /// NOTE(review): presumably a soft preference rather than a filter — confirm.
    #[serde(default, skip_serializing_if = "is_false")]
    pub prefer_fast: bool,
}
/// `skip_serializing_if` predicate that is `true` when a boolean flag is unset.
///
/// Serde hands the field over by reference, which is why this takes `&bool`.
fn is_false(b: &bool) -> bool {
    !b
}
/// Broad family of model that serves a [`UseCase`] (see [`UseCase::role`]).
///
/// Serialized in `snake_case` (e.g. `"retrieval"`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UseCaseRole {
    /// Role of the use cases whose required capabilities include `Generate`.
    Generative,
    /// Role of `UseCase::Search`, which requires `Embed`.
    Retrieval,
    /// Role of `UseCase::Transcription`, which requires `SpeechToText`.
    Audio,
}
/// High-level job a caller wants a model for.
///
/// Serialized in `snake_case` (e.g. `"coding"`, `"transcription"`).
/// Defaults to [`UseCase::Assistant`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UseCase {
    #[default]
    Assistant,
    Coding,
    Summarize,
    Vision,
    Transcription,
    Search,
}

impl UseCase {
    /// The broad family of model that serves this use case.
    ///
    /// `Search` maps to [`UseCaseRole::Retrieval`], `Transcription` to
    /// [`UseCaseRole::Audio`], and every other use case to
    /// [`UseCaseRole::Generative`].
    pub fn role(self) -> UseCaseRole {
        match self {
            UseCase::Assistant
            | UseCase::Coding
            | UseCase::Summarize
            | UseCase::Vision => UseCaseRole::Generative,
            UseCase::Search => UseCaseRole::Retrieval,
            UseCase::Transcription => UseCaseRole::Audio,
        }
    }

    /// Capabilities a model must have to serve this use case.
    ///
    /// Never empty. Every generative use case includes `Generate`; `Search`
    /// needs only `Embed` and `Transcription` only `SpeechToText`.
    pub fn required_capabilities(self) -> &'static [ModelCapability] {
        // Variants are glob-imported to keep the table terse. The match
        // patterns stay `UseCase::`-qualified because `Vision` and `Summarize`
        // are also `ModelCapability` variants.
        use ModelCapability::*;
        match self {
            UseCase::Assistant => &[Generate],
            UseCase::Coding => &[Generate, Code],
            UseCase::Summarize => &[Generate],
            UseCase::Vision => &[Vision, Generate],
            UseCase::Transcription => &[SpeechToText],
            UseCase::Search => &[Embed],
        }
    }

    /// Capabilities that are desirable, but not required, for this use case.
    ///
    /// May be empty (`Vision`, `Transcription`). Never overlaps
    /// [`required_capabilities`](Self::required_capabilities).
    /// NOTE(review): how these affect model ranking is not visible here.
    pub fn preferred_capabilities(self) -> &'static [ModelCapability] {
        // Same local glob import as `required_capabilities`; see note there.
        use ModelCapability::*;
        match self {
            UseCase::Assistant => &[ToolUse],
            UseCase::Coding => &[ToolUse, Reasoning],
            UseCase::Summarize => &[Summarize],
            UseCase::Vision => &[],
            UseCase::Transcription => &[],
            UseCase::Search => &[Rerank],
        }
    }
}
/// How to trade answer quality against latency and memory use.
///
/// Serialized in `snake_case` (e.g. `"most_capable"`). Defaults to
/// [`QualityTier::Balanced`]. The numeric trade-off for each tier comes
/// from [`QualityTier::weights`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum QualityTier {
    Fastest,
    #[default]
    Balanced,
    MostCapable,
}
/// Relative weights that a [`QualityTier`] assigns to each scoring axis.
///
/// Values are produced by [`QualityTier::weights`].
/// NOTE(review): how the weights are combined with per-model scores is not
/// visible in this module.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TierWeights {
    /// Weight on the quality axis.
    pub quality: f32,
    /// Weight on the latency axis.
    pub latency: f32,
    /// Weight on the memory-pressure axis.
    pub memory_pressure: f32,
}
impl QualityTier {
    /// Scoring weights for this tier.
    ///
    /// | tier          | quality | latency | memory_pressure |
    /// |---------------|---------|---------|-----------------|
    /// | `Fastest`     | 0.2     | 0.6     | 0.2             |
    /// | `Balanced`    | 0.5     | 0.2     | 0.3             |
    /// | `MostCapable` | 0.8     | 0.0     | 0.2             |
    ///
    /// Every row is non-negative and sums to 1.0.
    pub fn weights(self) -> TierWeights {
        let (quality, latency, memory_pressure) = match self {
            QualityTier::Fastest => (0.2, 0.6, 0.2),
            QualityTier::Balanced => (0.5, 0.2, 0.3),
            QualityTier::MostCapable => (0.8, 0.0, 0.2),
        };
        TierWeights {
            quality,
            latency,
            memory_pressure,
        }
    }
}
/// Where a request is allowed to be served.
///
/// Serialized in `snake_case` (`"on_device"` / `"cloud_ok"`). Defaults to
/// [`Privacy::OnDevice`].
/// NOTE(review): variant meanings (local-only vs. remote allowed) are inferred
/// from their names; the code that enforces them is not in this module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Privacy {
    #[default]
    OnDevice,
    CloudOk,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every [`UseCase`] variant, shared by the table-driven tests below.
    const ALL_USE_CASES: [UseCase; 6] = [
        UseCase::Assistant,
        UseCase::Coding,
        UseCase::Summarize,
        UseCase::Vision,
        UseCase::Transcription,
        UseCase::Search,
    ];

    /// Every [`QualityTier`] variant, shared by the table-driven tests below.
    const ALL_TIERS: [QualityTier; 3] = [
        QualityTier::Fastest,
        QualityTier::Balanced,
        QualityTier::MostCapable,
    ];

    /// Tripwire for the shared lists above. Adding a variant to either enum
    /// makes one of these matches non-exhaustive, so this module stops
    /// compiling until the variant is handled here. That is the cue to
    /// append it to the matching list too.
    #[test]
    fn variant_lists_tripwire() {
        for uc in ALL_USE_CASES {
            match uc {
                UseCase::Assistant
                | UseCase::Coding
                | UseCase::Summarize
                | UseCase::Vision
                | UseCase::Transcription
                | UseCase::Search => {}
            }
        }
        for tier in ALL_TIERS {
            match tier {
                QualityTier::Fastest | QualityTier::Balanced | QualityTier::MostCapable => {}
            }
        }
    }

    #[test]
    fn empty_intent_serializes_compactly() {
        let hint = IntentHint::default();
        let json = serde_json::to_string(&hint).unwrap();
        assert_eq!(json, "{}");
    }

    #[test]
    fn round_trip_with_capability_require() {
        let hint = IntentHint {
            task: Some(TaskHint::Code),
            require: vec![ModelCapability::Code, ModelCapability::ToolUse],
            prefer_local: true,
            prefer_fast: false,
        };
        let json = serde_json::to_string(&hint).unwrap();
        let back: IntentHint = serde_json::from_str(&json).unwrap();
        assert_eq!(back.task, Some(TaskHint::Code));
        assert_eq!(
            back.require,
            vec![ModelCapability::Code, ModelCapability::ToolUse]
        );
        assert!(back.prefer_local);
        assert!(!back.prefer_fast);
    }

    #[test]
    fn missing_fields_default_cleanly() {
        let hint: IntentHint = serde_json::from_str("{}").unwrap();
        assert_eq!(hint.task, None);
        assert!(hint.require.is_empty());
        assert!(!hint.prefer_local);
        assert!(!hint.prefer_fast);
    }

    #[test]
    fn prefer_fast_round_trips_and_skips_when_false() {
        let off = IntentHint::default();
        assert_eq!(serde_json::to_string(&off).unwrap(), "{}");
        let on = IntentHint {
            prefer_fast: true,
            ..IntentHint::default()
        };
        let json = serde_json::to_string(&on).unwrap();
        assert!(json.contains("prefer_fast"));
        let back: IntentHint = serde_json::from_str(&json).unwrap();
        assert!(back.prefer_fast);
    }

    #[test]
    fn use_case_defaults_to_assistant_and_balanced_on_device() {
        assert_eq!(UseCase::default(), UseCase::Assistant);
        assert_eq!(QualityTier::default(), QualityTier::Balanced);
        assert_eq!(Privacy::default(), Privacy::OnDevice);
    }

    #[test]
    fn coding_requires_both_generate_and_code() {
        let req = UseCase::Coding.required_capabilities();
        assert!(req.contains(&ModelCapability::Generate));
        assert!(req.contains(&ModelCapability::Code));
    }

    #[test]
    fn search_is_a_retrieval_role_not_generative() {
        assert_eq!(UseCase::Search.role(), UseCaseRole::Retrieval);
        assert_eq!(UseCase::Assistant.role(), UseCaseRole::Generative);
        assert_eq!(UseCase::Transcription.role(), UseCaseRole::Audio);
        assert_eq!(
            UseCase::Search.required_capabilities(),
            &[ModelCapability::Embed]
        );
    }

    #[test]
    fn required_and_preferred_are_disjoint() {
        for uc in ALL_USE_CASES {
            for p in uc.preferred_capabilities() {
                assert!(
                    !uc.required_capabilities().contains(p),
                    "{uc:?}: {p:?} is both required and preferred"
                );
            }
        }
    }

    #[test]
    fn every_use_case_requires_at_least_one_capability() {
        for uc in ALL_USE_CASES {
            assert!(
                !uc.required_capabilities().is_empty(),
                "{uc:?}: no required capabilities"
            );
        }
    }

    #[test]
    fn generative_role_matches_generate_requirement() {
        for uc in ALL_USE_CASES {
            let generative = uc.role() == UseCaseRole::Generative;
            let needs_generate = uc
                .required_capabilities()
                .contains(&ModelCapability::Generate);
            assert_eq!(
                generative, needs_generate,
                "{uc:?}: role and required capabilities disagree"
            );
        }
    }

    #[test]
    fn tier_weights_match_documented_table() {
        let b = QualityTier::Balanced.weights();
        assert_eq!((b.quality, b.latency, b.memory_pressure), (0.5, 0.2, 0.3));
        let f = QualityTier::Fastest.weights();
        assert!(f.latency > f.quality, "Fastest must favor latency");
        let c = QualityTier::MostCapable.weights();
        assert!(c.quality > c.latency, "MostCapable must favor quality");
    }

    #[test]
    fn tier_weights_are_non_negative_and_sum_to_one() {
        for tier in ALL_TIERS {
            let w = tier.weights();
            for axis in [w.quality, w.latency, w.memory_pressure] {
                assert!(axis >= 0.0, "{tier:?}: negative weight {axis}");
            }
            let sum = w.quality + w.latency + w.memory_pressure;
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "{tier:?}: weights sum to {sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn use_case_round_trips_snake_case() {
        let json = serde_json::to_string(&UseCase::Coding).unwrap();
        assert_eq!(json, "\"coding\"");
        let back: UseCase = serde_json::from_str("\"search\"").unwrap();
        assert_eq!(back, UseCase::Search);
    }

    #[test]
    fn multiword_variants_use_snake_case_on_the_wire() {
        assert_eq!(
            serde_json::to_string(&QualityTier::MostCapable).unwrap(),
            "\"most_capable\""
        );
        assert_eq!(
            serde_json::to_string(&Privacy::OnDevice).unwrap(),
            "\"on_device\""
        );
        let back: Privacy = serde_json::from_str("\"cloud_ok\"").unwrap();
        assert_eq!(back, Privacy::CloudOk);
    }
}