use std::collections::BTreeSet;
use serde::{Deserialize, Serialize};
/// Identifier of a model provider, stored as a plain string.
///
/// `#[serde(transparent)]` makes the type serialize and deserialize as the
/// bare inner string rather than as a wrapper object. The tests in this file
/// use values such as `"ollama"` and `"openai"`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProviderId(pub String);
impl ProviderId {
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
pub fn as_str(&self) -> &str {
&self.0
}
}
impl From<&str> for ProviderId {
fn from(value: &str) -> Self {
Self(value.to_string())
}
}
impl From<String> for ProviderId {
fn from(value: String) -> Self {
Self(value)
}
}
impl std::fmt::Display for ProviderId {
    /// Writes the identifier text.
    ///
    /// Formatting flags are honoured, so `format!("{:>10}", id)` pads and
    /// `format!("{:.3}", id)` truncates exactly as they would for a `&str`.
    /// Plain `{}` produces the inner string unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write_str` would bypass the formatter and silently drop any
        // width / fill / alignment / precision the caller asked for;
        // `pad` applies them, matching the behaviour of `str`'s own `Display`.
        f.pad(&self.0)
    }
}
/// Feature flag a model may advertise; `ModelDescriptor::capabilities` holds a
/// set of these.
///
/// Serialized in snake_case (for example `StructuredOutput` becomes
/// `"structured_output"` and `ImageGen` becomes `"image_gen"`). The derived
/// `Ord` follows declaration order, which is therefore the iteration order of
/// a `BTreeSet<Capability>`. The enum is `#[non_exhaustive]`, so matches
/// outside the defining crate need a wildcard arm.
///
/// NOTE(review): the precise meaning of each variant is not defined in this
/// file; the per-variant notes below are inferred from the names — confirm
/// against the code that populates them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Capability {
/// Presumably plain text/chat completion.
Completion,
/// Presumably tool / function calling.
Tools,
/// Presumably image input.
Vision,
/// Presumably embedding generation.
Embedding,
/// Presumably schema-constrained (structured) output.
StructuredOutput,
/// Presumably an explicit reasoning ("thinking") mode.
Thinking,
/// Presumably image generation.
ImageGen,
}
/// Quantization / weight-precision label attached to a model.
///
/// `Quantization::parse` maps free-form labels such as `"Q4_K_M"` or `"F16"`
/// onto these variants; anything it does not recognise is kept in `Other`.
///
/// With `#[serde(rename_all = "UPPERCASE")]` the serialized names are the
/// upper-cased variant identifiers (`Q4KM` -> `"Q4KM"`, `Fp16` -> `"FP16"`,
/// `Other` -> `"OTHER"`). Note that these are NOT the underscore-separated
/// spellings (`"Q4_K_M"`) that `parse` also accepts.
/// The enum is `#[non_exhaustive]`, so matches outside the defining crate need
/// a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
#[non_exhaustive]
pub enum Quantization {
/// Produced by `parse` for `"Q4_K_M"` / `"Q4KM"`.
Q4KM,
/// Produced by `parse` for `"Q4_K_S"` / `"Q4KS"`.
Q4KS,
/// Produced by `parse` for `"Q5_K_M"` / `"Q5KM"`.
Q5KM,
/// Produced by `parse` for `"Q8_0"` / `"Q8"`.
Q8_0,
/// Produced by `parse` for `"F16"` / `"FP16"`.
Fp16,
/// Produced by `parse` for `"BF16"`.
Bf16,
/// Any label not covered by the variants above, carried verbatim.
Other(String),
}
impl Quantization {
    /// Maps a free-form quantization label onto a known variant.
    ///
    /// Surrounding whitespace and ASCII case are ignored, and both the
    /// underscore-separated and compact spellings are accepted (for example
    /// `"q4_k_m"` and `"Q4KM"` both yield [`Quantization::Q4KM`]).
    ///
    /// Labels that match no known variant are returned as
    /// [`Quantization::Other`] holding the label with surrounding whitespace
    /// removed and its original casing preserved. This function never fails.
    pub fn parse(label: &str) -> Self {
        // Trim once and reuse the trimmed text for the fallback too, so that
        // `parse(" Q3_K_S ")` and `parse("Q3_K_S")` compare equal — whitespace
        // is treated as insignificant for every label, not just known ones.
        let label = label.trim();
        match label.to_ascii_uppercase().as_str() {
            "Q4_K_M" | "Q4KM" => Self::Q4KM,
            "Q4_K_S" | "Q4KS" => Self::Q4KS,
            "Q5_K_M" | "Q5KM" => Self::Q5KM,
            "Q8_0" | "Q8" => Self::Q8_0,
            "F16" | "FP16" => Self::Fp16,
            "BF16" => Self::Bf16,
            // Keep the caller's casing: only the match key was upper-cased.
            _ => Self::Other(label.to_string()),
        }
    }
}
/// Description of a single model as offered by a provider.
///
/// Because the struct is `#[non_exhaustive]`, code outside the defining crate
/// cannot create it with a struct literal; `ModelDescriptor::new` and
/// `ModelDescriptor::builder` in this file are the constructors.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ModelDescriptor {
/// Identifier of the provider this descriptor belongs to.
pub provider: ProviderId,
/// Model name as a free-form string (the tests use e.g. `"gpt-4o"`).
pub model: String,
/// Context window size in tokens, if known. Used as the denominator by
/// `context_used_fraction`.
pub context_window: Option<u64>,
/// Maximum number of output tokens, if known.
pub max_output_tokens: Option<u64>,
/// Capabilities recorded for the model; empty when none were set.
pub capabilities: BTreeSet<Capability>,
/// Model family label, if known (the tests use `"llama"`).
pub family: Option<String>,
/// Parameter count, if known.
/// NOTE(review): presumably the raw number of model weights — confirm units.
pub parameter_count: Option<u64>,
/// Quantization label, if known.
pub quantization: Option<Quantization>,
/// Optional arbitrary JSON value. Left out of the serialized form when
/// `None`, and read back as `None` when the key is absent.
/// NOTE(review): presumably the provider's original payload — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub raw: Option<serde_json::Value>,
}
impl ModelDescriptor {
pub fn new(provider: impl Into<ProviderId>, model: impl Into<String>) -> Self {
Self {
provider: provider.into(),
model: model.into(),
context_window: None,
max_output_tokens: None,
capabilities: BTreeSet::new(),
family: None,
parameter_count: None,
quantization: None,
raw: None,
}
}
pub fn builder(
provider: impl Into<ProviderId>,
model: impl Into<String>,
) -> ModelDescriptorBuilder {
ModelDescriptorBuilder {
inner: Self::new(provider, model),
}
}
pub fn context_used_fraction(&self, input_tokens: u64) -> Option<f64> {
match self.context_window {
Some(window) if window > 0 => Some(input_tokens as f64 / window as f64),
_ => None,
}
}
pub fn has_capability(&self, cap: Capability) -> bool {
self.capabilities.contains(&cap)
}
}
/// Chainable builder for [`ModelDescriptor`].
///
/// Obtained from `ModelDescriptor::builder`; each setter takes the builder by
/// value and returns it, and `build` hands back the finished descriptor.
#[derive(Debug, Clone)]
pub struct ModelDescriptorBuilder {
// Descriptor being assembled; setters mutate it in place.
inner: ModelDescriptor,
}
impl ModelDescriptorBuilder {
    /// Records the context window size, in tokens.
    pub fn context_window(self, tokens: u64) -> Self {
        let Self { mut inner } = self;
        inner.context_window = Some(tokens);
        Self { inner }
    }

    /// Records the maximum number of output tokens.
    pub fn max_output_tokens(self, tokens: u64) -> Self {
        let Self { mut inner } = self;
        inner.max_output_tokens = Some(tokens);
        Self { inner }
    }

    /// Adds one capability to the set, keeping those already present.
    pub fn capability(self, cap: Capability) -> Self {
        let Self { mut inner } = self;
        inner.capabilities.insert(cap);
        Self { inner }
    }

    /// Replaces the whole capability set with `caps`.
    ///
    /// Capabilities added earlier via [`Self::capability`] are discarded.
    pub fn capabilities(self, caps: impl IntoIterator<Item = Capability>) -> Self {
        let Self { mut inner } = self;
        let replacement = caps.into_iter().collect();
        inner.capabilities = replacement;
        Self { inner }
    }

    /// Records the model family label.
    pub fn family(self, family: impl Into<String>) -> Self {
        let Self { mut inner } = self;
        let label: String = family.into();
        inner.family = Some(label);
        Self { inner }
    }

    /// Records the parameter count.
    pub fn parameter_count(self, count: u64) -> Self {
        let Self { mut inner } = self;
        inner.parameter_count = Some(count);
        Self { inner }
    }

    /// Records the quantization label.
    pub fn quantization(self, q: Quantization) -> Self {
        let Self { mut inner } = self;
        inner.quantization = Some(q);
        Self { inner }
    }

    /// Attaches an arbitrary JSON value to the descriptor.
    pub fn raw(self, raw: serde_json::Value) -> Self {
        let Self { mut inner } = self;
        inner.raw = Some(raw);
        Self { inner }
    }

    /// Consumes the builder and returns the assembled descriptor.
    pub fn build(self) -> ModelDescriptor {
        let Self { inner } = self;
        inner
    }
}
// Unit tests for label parsing, the builder, the context-usage ratio and the
// serde round trip of `ModelDescriptor`.
#[cfg(test)]
// Tests fail by panicking, so `unwrap` and `panic!` are intentional here.
// NOTE(review): no indexing or slicing is visible in these tests; the
// `indexing_slicing` allowance may be unnecessary — confirm before removing.
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
use super::*;
// Known labels map to their variants regardless of ASCII case, and an
// unknown label falls through to `Other` with its text preserved.
#[test]
fn quantization_parses_common_labels() {
assert_eq!(Quantization::parse("Q4_K_M"), Quantization::Q4KM);
assert_eq!(Quantization::parse("q4_k_s"), Quantization::Q4KS);
assert_eq!(Quantization::parse("F16"), Quantization::Fp16);
match Quantization::parse("Q3_K_S") {
Quantization::Other(label) => assert_eq!(label, "Q3_K_S"),
other => panic!("expected Other, got {other:?}"),
}
}
// Values set through the builder are visible on the built descriptor, and
// capabilities that were never added are reported as absent.
#[test]
fn builder_roundtrip() {
let d = ModelDescriptor::builder("ollama", "qwen3.5:9b")
.context_window(128_000)
.capability(Capability::Completion)
.capability(Capability::Tools)
.family("llama")
.quantization(Quantization::Q4KM)
.build();
assert_eq!(d.context_window, Some(128_000));
assert!(d.has_capability(Capability::Completion));
assert!(d.has_capability(Capability::Tools));
assert!(!d.has_capability(Capability::Vision));
// 64_000 / 128_000 is exactly representable, so exact float equality is safe.
assert_eq!(d.context_used_fraction(64_000), Some(0.5));
}
// The ratio is `None` both for an unknown window and for a zero window.
#[test]
fn context_used_fraction_handles_zero_and_unknown() {
let mut d = ModelDescriptor::new("ollama", "x");
assert_eq!(d.context_used_fraction(100), None);
d.context_window = Some(0);
assert_eq!(d.context_used_fraction(100), None);
d.context_window = Some(1000);
assert_eq!(d.context_used_fraction(250), Some(0.25));
}
// Serializing to JSON and back yields an equal descriptor. The fixture leaves
// `family`, `parameter_count`, `quantization` and `raw` unset.
#[test]
fn descriptor_serde_roundtrip() {
let d = ModelDescriptor::builder("openai", "gpt-4o")
.context_window(128_000)
.max_output_tokens(16_384)
.capability(Capability::Completion)
.capability(Capability::Tools)
.capability(Capability::Vision)
.build();
let json = serde_json::to_string(&d).unwrap();
let back: ModelDescriptor = serde_json::from_str(&json).unwrap();
assert_eq!(d, back);
}
}