use alloc::{string::String, vec::Vec};
use schemars::Schema;
/// Optional generation parameters attached to a request.
///
/// `Default` leaves every `Option` field as `None` (meaning "not specified") and
/// `structured_outputs` as `false`. Most optional fields also get a chainable
/// setter from `impl_with_methods!`, e.g. `Parameters::default().temperature(0.7).seed(42)`.
/// The exact meaning of each knob is defined by whichever backend consumes these
/// values; that is not visible in this file.
///
/// `Clone` is derived so a base configuration can be copied and tweaked per request,
/// matching the other public data types in this module.
#[derive(Debug, Default, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Parameters {
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub repetition_penalty: Option<f32>,
    pub min_p: Option<f32>,
    pub top_a: Option<f32>,
    pub seed: Option<u32>,
    pub max_tokens: Option<u32>,
    /// Bias entries as `(token, bias)` pairs. Whether the string is a token id or
    /// token text is not defined here — TODO confirm with the consumer.
    pub logit_bias: Option<Vec<(String, f32)>>,
    pub logprobs: Option<bool>,
    pub top_logprobs: Option<u8>,
    /// Stop sequences.
    pub stop: Option<Vec<String>>,
    /// Tool selection as a list of strings — presumably tool names; confirm.
    pub tool_choice: Option<Vec<String>>,
    /// Plain flag rather than an `Option`, so it always has a value (`false` by default).
    pub structured_outputs: bool,
    /// A JSON Schema (`schemars::Schema`), presumably describing the expected
    /// response shape.
    pub response_format: Option<Schema>,
}
/// Generates chainable setters for `Option` fields of a struct.
///
/// Each `field: Type` entry expands to `pub fn field(self, value: Type) -> Self`.
/// That method returns the struct with `field` set to `Some(value)`; any previous
/// value is dropped and all other fields are carried over unchanged.
macro_rules! impl_with_methods {
    (impl $target:ty { $($name:ident : $value_ty:ty),* $(,)? }) => {
        impl $target {
            $(
                #[doc = concat!("Returns `self` with `", stringify!($name), "` set to `Some(value)`.")]
                // Not `const`: some field types (e.g. `Vec`) have drop glue, and a
                // `const fn` cannot drop the value it replaces.
                #[allow(clippy::missing_const_for_fn)]
                #[must_use]
                pub fn $name(self, value: $value_ty) -> Self {
                    Self {
                        $name: Some(value),
                        ..self
                    }
                }
            )*
        }
    };
}
impl_with_methods! {
impl Parameters {
temperature: f32,
top_p: f32,
top_k: u32,
frequency_penalty: f32,
presence_penalty: f32,
repetition_penalty: f32,
min_p: f32,
top_a: f32,
seed: u32,
max_tokens: u32,
logit_bias: Vec<(String, f32)>,
logprobs: bool,
top_logprobs: u8,
stop: Vec<String>,
}
}
/// Descriptive metadata for a model: identity strings, capabilities, context size
/// and optional pricing.
///
/// Construct with `Profile::new` plus the `with_*` builder methods. Because the type
/// is `#[non_exhaustive]`, code outside this crate cannot build it with a struct literal.
///
/// Field order is significant: the derived `PartialOrd` compares fields in
/// declaration order, starting with `name`.
#[derive(Clone, Debug, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Profile {
    /// Human-readable display name.
    pub name: String,
    /// Author / provider label.
    pub author: String,
    /// Short identifier (e.g. `"test-model"`).
    pub slug: String,
    /// Free-form description.
    pub description: String,
    /// Advertised capabilities. The builders append without de-duplicating.
    pub abilities: Vec<Ability>,
    /// Context window size — presumably measured in tokens; confirm.
    pub context_length: u32,
    /// Pricing information, if known.
    pub pricing: Option<Pricing>,
}
/// Price list for a model.
///
/// Every field is an `f64` and `Default` sets all of them to `0.0`. The currency and
/// the unit each price applies to are not specified in this file — presumably
/// per-token for `prompt`/`completion` and per-item for the others; confirm with
/// the data source.
///
/// Field order is significant: the derived `PartialOrd` compares fields
/// lexicographically in declaration order.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct Pricing {
    /// Prompt (input) price.
    pub prompt: f64,
    /// Completion (output) price.
    pub completion: f64,
    /// Per-request price.
    pub request: f64,
    /// Image price.
    pub image: f64,
    /// Web-search price.
    pub web_search: f64,
    /// Internal-reasoning price.
    pub internal_reasoning: f64,
    /// Price for reading from the input cache.
    pub input_cache_read: f64,
    /// Price for writing to the input cache.
    pub input_cache_write: f64,
}
/// Per-option support flags, one `bool` per request parameter.
///
/// Most flags share a name with a [`Parameters`] field; `reasoning` and
/// `include_reasoning` have no counterpart there. What `true` signifies is not
/// established in this file — presumably that a given model/provider accepts the
/// option; confirm with whoever produces these values. `Default` sets every flag
/// to `false`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
// Serde support mirrors the other public data types in this module.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// Each field is an independent on/off flag, so a struct of bools is the intended shape.
#[allow(clippy::struct_excessive_bools)]
#[non_exhaustive]
pub struct SupportedParameters {
    pub max_tokens: bool,
    pub temperature: bool,
    pub top_p: bool,
    pub reasoning: bool,
    pub include_reasoning: bool,
    pub structured_outputs: bool,
    pub response_format: bool,
    pub stop: bool,
    pub frequency_penalty: bool,
    pub presence_penalty: bool,
    pub seed: bool,
}
impl Profile {
    /// Builds a profile from its identity fields and context length.
    ///
    /// Each string argument accepts anything convertible into `String` (such as `&str`).
    /// The new profile has an empty ability list and no pricing; fill those in with
    /// [`Profile::with_ability`], [`Profile::with_abilities`] and [`Profile::with_pricing`].
    pub fn new(
        name: impl Into<String>,
        author: impl Into<String>,
        slug: impl Into<String>,
        description: impl Into<String>,
        context_length: u32,
    ) -> Self {
        // Convert in parameter order before assembling the struct.
        let name: String = name.into();
        let author: String = author.into();
        let slug: String = slug.into();
        let description: String = description.into();
        Self {
            name,
            author,
            slug,
            description,
            abilities: Vec::new(),
            context_length,
            pricing: None,
        }
    }

    /// Appends a single ability to the end of the list (no de-duplication).
    #[must_use]
    pub fn with_ability(mut self, ability: Ability) -> Self {
        self.abilities.push(ability);
        self
    }

    /// Appends every ability yielded by `abilities`, in iteration order
    /// (no de-duplication).
    #[must_use]
    pub fn with_abilities(mut self, abilities: impl IntoIterator<Item = Ability>) -> Self {
        let list = &mut self.abilities;
        list.extend(abilities);
        self
    }

    /// Sets the pricing, replacing any previously set value.
    #[must_use]
    pub const fn with_pricing(mut self, pricing: Pricing) -> Self {
        self.pricing = Some(pricing);
        self
    }
}
/// A capability a model can advertise in its profile.
///
/// Variant order is significant: the derived `Ord` ranks variants in declaration
/// order (`ToolUse` < `Vision` < `Audio` < `WebSearch`).
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Ability {
    /// Tool use.
    ToolUse,
    /// Vision.
    Vision,
    /// Audio.
    Audio,
    /// Web search.
    WebSearch,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fully populated pricing fixture shared by several tests.
    fn sample_pricing() -> Pricing {
        Pricing {
            prompt: 0.001,
            completion: 0.002,
            request: 0.01,
            image: 0.1,
            web_search: 0.05,
            internal_reasoning: 0.003,
            input_cache_read: 0.0005,
            input_cache_write: 0.001,
        }
    }

    #[test]
    fn profile_creation() {
        let profile = Profile::new("Test model", "test", "test-model", "A test model", 4096);
        // `name` holds the display name; the identifier is stored in `slug`.
        assert_eq!(profile.name, "Test model");
        assert_eq!(profile.author, "test");
        assert_eq!(profile.slug, "test-model");
        assert_eq!(profile.description, "A test model");
        assert_eq!(profile.context_length, 4096);
        assert!(profile.abilities.is_empty());
        assert!(profile.pricing.is_none());
    }

    #[test]
    fn profile_with_single_ability() {
        let profile = Profile::new(
            "Test vision model",
            "test",
            "vision-model",
            "A vision model",
            8192,
        )
        .with_ability(Ability::Vision);
        assert_eq!(profile.abilities, [Ability::Vision]);
    }

    #[test]
    fn profile_with_multiple_abilities() {
        let abilities = [Ability::ToolUse, Ability::Vision, Ability::Audio];
        let profile = Profile::new(
            "Test",
            "test",
            "multimodal-model",
            "A multimodal model",
            16384,
        )
        .with_abilities(abilities);
        assert_eq!(profile.abilities.len(), 3);
        assert_eq!(profile.abilities, abilities);
    }

    #[test]
    fn profile_with_pricing() {
        let pricing = Pricing {
            prompt: 0.0001,
            completion: 0.0002,
            request: 0.001,
            image: 0.01,
            web_search: 0.005,
            internal_reasoning: 0.0003,
            input_cache_read: 0.00005,
            input_cache_write: 0.0001,
        };
        let profile = Profile::new(
            "Test paid model",
            "test",
            "paid-model",
            "A paid model",
            2048,
        )
        .with_pricing(pricing.clone());
        // Derived `PartialEq` compares every field exactly, so this checks that the
        // builder stored the value unchanged.
        assert_eq!(profile.pricing, Some(pricing));
    }

    #[test]
    fn profile_builder_pattern() {
        let profile = Profile::new("Test", "test", "full-model", "A full-featured model", 32768)
            .with_ability(Ability::ToolUse)
            .with_ability(Ability::Vision)
            .with_abilities([Ability::Audio, Ability::WebSearch])
            .with_pricing(sample_pricing());
        assert_eq!(profile.name, "Test");
        assert_eq!(profile.slug, "full-model");
        assert_eq!(profile.description, "A full-featured model");
        assert_eq!(profile.context_length, 32768);
        // The builders append in call order.
        assert_eq!(
            profile.abilities,
            [
                Ability::ToolUse,
                Ability::Vision,
                Ability::Audio,
                Ability::WebSearch
            ]
        );
        assert_eq!(profile.pricing, Some(sample_pricing()));
    }

    #[test]
    fn ability_equality() {
        assert_eq!(Ability::ToolUse, Ability::ToolUse);
        assert_eq!(Ability::Vision, Ability::Vision);
        assert_eq!(Ability::Audio, Ability::Audio);
        assert_eq!(Ability::WebSearch, Ability::WebSearch);
        assert_ne!(Ability::ToolUse, Ability::Vision);
        assert_ne!(Ability::Audio, Ability::WebSearch);
    }

    #[test]
    fn ability_debug() {
        let ability = Ability::ToolUse;
        assert_eq!(alloc::format!("{ability:?}"), "ToolUse");
    }

    #[test]
    fn profile_debug() {
        let profile = Profile::new("Test model", "test", "debug-model", "A debug model", 1024);
        let debug_str = alloc::format!("{profile:?}");
        assert!(debug_str.contains("debug-model"));
        assert!(debug_str.contains("A debug model"));
        assert!(debug_str.contains("1024"));
    }

    #[test]
    fn profile_clone() {
        let original = Profile::new("Test model", "test", "original", "Original model", 2048)
            .with_ability(Ability::Vision)
            .with_pricing(sample_pricing());
        let cloned = original.clone();
        assert_eq!(cloned, original);
    }

    #[test]
    fn pricing_debug() {
        let pricing = sample_pricing();
        let debug_str = alloc::format!("{pricing:?}");
        assert!(debug_str.contains("0.001"));
        assert!(debug_str.contains("0.002"));
    }

    #[test]
    fn pricing_clone() {
        let original = sample_pricing();
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn pricing_equality() {
        let pricing1 = sample_pricing();
        let pricing2 = sample_pricing();
        let pricing3 = Pricing {
            prompt: 0.002,
            ..sample_pricing()
        };
        assert_eq!(pricing1, pricing2);
        assert_ne!(pricing1, pricing3);
    }

    #[test]
    fn supported_parameters() {
        let params = SupportedParameters {
            max_tokens: true,
            temperature: true,
            top_p: false,
            structured_outputs: true,
            stop: true,
            presence_penalty: true,
            ..Default::default()
        };
        assert!(params.max_tokens);
        assert!(params.temperature);
        assert!(!params.top_p);
        assert!(params.structured_outputs);
        assert!(params.stop);
        assert!(params.presence_penalty);
        // Fields not set above keep their `Default` value of `false`.
        assert!(!params.reasoning);
        assert!(!params.include_reasoning);
        assert!(!params.response_format);
        assert!(!params.frequency_penalty);
        assert!(!params.seed);
    }

    #[test]
    fn parameters_debug() {
        let params = Parameters::default()
            .temperature(0.7)
            .top_p(0.9)
            .top_k(40)
            .seed(42)
            .max_tokens(1000);
        let debug_str = alloc::format!("{params:?}");
        assert!(debug_str.contains("0.7"));
        assert!(debug_str.contains("42"));
        assert!(debug_str.contains("1000"));
    }

    #[test]
    fn parameters_builder_sets_fields() {
        let params = Parameters::default()
            .top_k(40)
            .seed(42)
            .max_tokens(1000)
            .stop(alloc::vec![String::from("END")]);
        assert_eq!(params.top_k, Some(40));
        assert_eq!(params.seed, Some(42));
        assert_eq!(params.max_tokens, Some(1000));
        assert_eq!(params.stop, Some(alloc::vec![String::from("END")]));
        // Fields that were never set stay at their `Default` value.
        assert_eq!(params.temperature, None);
        assert!(!params.structured_outputs);
    }
}