#[cfg(feature = "inference")]
#[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
pub use ort::session::builder::GraphOptimizationLevel;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Error, Result};
/// Sampling parameters (temperature, `min_p`, repetition penalty) plus the
/// limit on newly generated tokens.
///
/// All fields are private: they are read through the same-named getters,
/// changed through the `with_*` / `set_*` methods, and range-checked by
/// `RequestOptions::validate`. Nothing is checked at assignment time.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RequestOptions {
/// Sampling temperature. `validate` accepts exactly `0.0` (its message calls
/// this "greedy") or any finite value `>= 1e-3`.
temperature: f32,
/// `min_p` threshold. `validate` requires a finite value in `[0.0, 1.0]`.
min_p: f32,
/// Repetition penalty. `validate` requires a finite value in
/// `[1.0, MAX_REPETITION_PENALTY]`.
repetition_penalty: f32,
/// Limit on new tokens. `validate` requires `1..=MAX_NEW_TOKENS_CAP`.
max_new_tokens: usize,
}
impl RequestOptions {
pub const fn new() -> Self {
Self {
temperature: 0.1,
min_p: 0.15,
repetition_penalty: 1.05,
max_new_tokens: 512,
}
}
pub const fn deterministic() -> Self {
Self {
temperature: 0.0,
min_p: 0.0,
repetition_penalty: 1.05,
max_new_tokens: 512,
}
}
pub const fn temperature(&self) -> f32 {
self.temperature
}
pub const fn min_p(&self) -> f32 {
self.min_p
}
pub const fn repetition_penalty(&self) -> f32 {
self.repetition_penalty
}
pub const fn max_new_tokens(&self) -> usize {
self.max_new_tokens
}
pub const fn with_temperature(mut self, v: f32) -> Self {
self.temperature = v;
self
}
pub const fn with_min_p(mut self, v: f32) -> Self {
self.min_p = v;
self
}
pub const fn with_repetition_penalty(mut self, v: f32) -> Self {
self.repetition_penalty = v;
self
}
pub const fn with_max_new_tokens(mut self, v: usize) -> Self {
self.max_new_tokens = v;
self
}
pub fn set_temperature(&mut self, v: f32) -> &mut Self {
self.temperature = v;
self
}
pub fn set_min_p(&mut self, v: f32) -> &mut Self {
self.min_p = v;
self
}
pub fn set_repetition_penalty(&mut self, v: f32) -> &mut Self {
self.repetition_penalty = v;
self
}
pub fn set_max_new_tokens(&mut self, v: usize) -> &mut Self {
self.max_new_tokens = v;
self
}
pub const fn validate(&self) -> Result<()> {
if self.temperature.is_nan() || self.temperature.is_infinite() {
return Err(Error::InvalidRequest("temperature must be finite"));
}
if self.temperature < 0.0 {
return Err(Error::InvalidRequest("temperature must be >= 0.0"));
}
const MIN_TEMPERATURE: f32 = 1e-3;
if self.temperature > 0.0 && self.temperature < MIN_TEMPERATURE {
return Err(Error::InvalidRequest(
"temperature must be either exactly 0.0 (greedy) or >= 1e-3 (1/temp would overflow for smaller positive values, poisoning softmax with NaN/inf)",
));
}
if self.min_p.is_nan() || self.min_p.is_infinite() {
return Err(Error::InvalidRequest("min_p must be finite"));
}
if self.min_p < 0.0 || self.min_p > 1.0 {
return Err(Error::InvalidRequest("min_p must be in [0.0, 1.0]"));
}
if self.repetition_penalty.is_nan() || self.repetition_penalty.is_infinite() {
return Err(Error::InvalidRequest("repetition_penalty must be finite"));
}
if self.repetition_penalty < 1.0 {
return Err(Error::InvalidRequest("repetition_penalty must be >= 1.0"));
}
if self.repetition_penalty > MAX_REPETITION_PENALTY {
return Err(Error::InvalidRequest(
"repetition_penalty must be <= 100.0 (penalty × negative logit could otherwise overflow to -inf and poison sampling)",
));
}
if self.max_new_tokens == 0 {
return Err(Error::InvalidRequest("max_new_tokens must be > 0"));
}
if self.max_new_tokens > MAX_NEW_TOKENS_CAP {
return Err(Error::InvalidRequest(
"max_new_tokens must be <= 32768 (model context is 128K; this leaves headroom for prompt + image tokens and prevents OOM from oversized preallocation)",
));
}
Ok(())
}
}
/// Largest `max_new_tokens` accepted by `RequestOptions::validate`.
pub const MAX_NEW_TOKENS_CAP: usize = 32_768;
/// Largest `repetition_penalty` accepted by `RequestOptions::validate`; the
/// validation message cites overflow of `penalty × negative logit` to `-inf`.
pub const MAX_REPETITION_PENALTY: f32 = 100.0;
/// Model context size in tokens. Not referenced by any code visible in this
/// file; the `max_new_tokens` validation message mentions a "128K" context.
pub const MODEL_CONTEXT_TOKENS: usize = 128_000;
impl Default for RequestOptions {
fn default() -> Self {
Self::new()
}
}
/// Per-image limits: image-token range, tile-count range, a thumbnail flag
/// and a pixel tolerance factor.
///
/// All fields are private: they are read through the same-named getters,
/// changed through the `with_*` / `set_*` methods, and range-checked by
/// `ImageBudget::validate`. Nothing is checked at assignment time.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImageBudget {
/// Lower image-token bound. `validate` requires `> 0`.
min_image_tokens: usize,
/// Upper image-token bound. `validate` requires `>= min_image_tokens` and
/// `<= MAX_IMAGE_TOKENS_CAP`.
max_image_tokens: usize,
/// Lower tile-count bound. `validate` requires `> 0`.
min_tiles: usize,
/// Upper tile-count bound. `validate` requires `>= min_tiles` and
/// `<= MAX_TOKENIZER_TILE_DIM`.
max_tiles: usize,
/// Thumbnail flag; stored and returned only, never interpreted in this file.
use_thumbnail: bool,
/// Pixel tolerance factor. `validate` requires a finite value `> 0.0`.
max_pixels_tolerance: f32,
}
impl ImageBudget {
    /// Default preset: 64–256 image tokens, 2–10 tiles, thumbnail enabled,
    /// pixel tolerance `2.0`.
    pub const fn new() -> Self {
        Self {
            min_image_tokens: 64,
            max_image_tokens: 256,
            min_tiles: 2,
            max_tiles: 10,
            use_thumbnail: true,
            max_pixels_tolerance: 2.0,
        }
    }

    /// Smaller preset: 32–64 image tokens, 2–4 tiles, thumbnail disabled,
    /// pixel tolerance `2.0`.
    pub const fn fast() -> Self {
        Self {
            min_image_tokens: 32,
            max_image_tokens: 64,
            min_tiles: 2,
            max_tiles: 4,
            use_thumbnail: false,
            max_pixels_tolerance: 2.0,
        }
    }

    /// Alias for [`Self::new`].
    pub const fn quality() -> Self {
        Self::new()
    }

    /// Returns the stored lower image-token bound.
    pub const fn min_image_tokens(&self) -> usize {
        self.min_image_tokens
    }

    /// Returns the stored upper image-token bound.
    pub const fn max_image_tokens(&self) -> usize {
        self.max_image_tokens
    }

    /// Returns the stored lower tile-count bound.
    pub const fn min_tiles(&self) -> usize {
        self.min_tiles
    }

    /// Returns the stored upper tile-count bound.
    pub const fn max_tiles(&self) -> usize {
        self.max_tiles
    }

    /// Returns the stored thumbnail flag.
    pub const fn use_thumbnail(&self) -> bool {
        self.use_thumbnail
    }

    /// Returns the stored pixel tolerance factor.
    pub const fn max_pixels_tolerance(&self) -> f32 {
        self.max_pixels_tolerance
    }

    /// Upper bound on tokens for one image, computed as
    /// `max_tiles * 256 + max_image_tokens`.
    ///
    /// The arithmetic saturates at `usize::MAX`. This method can be called on
    /// a budget that has not passed [`Self::validate`] (setters and
    /// deserialization accept any `usize`), and an upper bound must never
    /// wrap around to a small value or panic on overflow.
    pub const fn max_tokens_per_image(&self) -> usize {
        const TOKENS_PER_FULL_TILE: usize = 256;
        self.max_tiles
            .saturating_mul(TOKENS_PER_FULL_TILE)
            .saturating_add(self.max_image_tokens)
    }

    /// By-value builder: returns the budget with `min_image_tokens` set to
    /// `v`. Not checked here; call [`Self::validate`].
    pub const fn with_min_image_tokens(mut self, v: usize) -> Self {
        self.min_image_tokens = v;
        self
    }

    /// By-value builder: returns the budget with `max_image_tokens` set to
    /// `v`. Not checked here; call [`Self::validate`].
    pub const fn with_max_image_tokens(mut self, v: usize) -> Self {
        self.max_image_tokens = v;
        self
    }

    /// By-value builder: returns the budget with `min_tiles` set to `v`.
    /// Not checked here; call [`Self::validate`].
    pub const fn with_min_tiles(mut self, v: usize) -> Self {
        self.min_tiles = v;
        self
    }

    /// By-value builder: returns the budget with `max_tiles` set to `v`.
    /// Not checked here; call [`Self::validate`].
    pub const fn with_max_tiles(mut self, v: usize) -> Self {
        self.max_tiles = v;
        self
    }

    /// By-value builder: returns the budget with `use_thumbnail` set to `v`.
    pub const fn with_use_thumbnail(mut self, v: bool) -> Self {
        self.use_thumbnail = v;
        self
    }

    /// By-value builder: returns the budget with `max_pixels_tolerance` set
    /// to `v`, stored exactly as given. Not checked here; call
    /// [`Self::validate`]. `const` like the other `with_*` builders.
    pub const fn with_max_pixels_tolerance(mut self, v: f32) -> Self {
        self.max_pixels_tolerance = v;
        self
    }

    /// In-place setter for `min_image_tokens`; returns `self` for chaining.
    pub fn set_min_image_tokens(&mut self, v: usize) -> &mut Self {
        self.min_image_tokens = v;
        self
    }

    /// In-place setter for `max_image_tokens`; returns `self` for chaining.
    pub fn set_max_image_tokens(&mut self, v: usize) -> &mut Self {
        self.max_image_tokens = v;
        self
    }

    /// In-place setter for `min_tiles`; returns `self` for chaining.
    pub fn set_min_tiles(&mut self, v: usize) -> &mut Self {
        self.min_tiles = v;
        self
    }

    /// In-place setter for `max_tiles`; returns `self` for chaining.
    pub fn set_max_tiles(&mut self, v: usize) -> &mut Self {
        self.max_tiles = v;
        self
    }

    /// In-place setter for `use_thumbnail`; returns `self` for chaining.
    pub fn set_use_thumbnail(&mut self, v: bool) -> &mut Self {
        self.use_thumbnail = v;
        self
    }

    /// In-place setter for `max_pixels_tolerance`; returns `self` for chaining.
    pub fn set_max_pixels_tolerance(&mut self, v: f32) -> &mut Self {
        self.max_pixels_tolerance = v;
        self
    }

    /// Checks every field against its accepted range.
    ///
    /// Checks run in the order written below and only the first violation is
    /// reported.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidBudget` with a static message when
    /// - `min_image_tokens` is `0`, or `max_image_tokens < min_image_tokens`;
    /// - `min_tiles` is `0`, or `max_tiles < min_tiles`;
    /// - `max_tiles > MAX_TOKENIZER_TILE_DIM`;
    /// - `max_image_tokens > MAX_IMAGE_TOKENS_CAP`;
    /// - `max_pixels_tolerance` is NaN, infinite, zero or negative.
    pub const fn validate(&self) -> Result<()> {
        if self.min_image_tokens == 0 {
            return Err(Error::InvalidBudget("min_image_tokens must be > 0"));
        }
        if self.max_image_tokens < self.min_image_tokens {
            return Err(Error::InvalidBudget(
                "max_image_tokens must be >= min_image_tokens",
            ));
        }
        if self.min_tiles == 0 {
            return Err(Error::InvalidBudget("min_tiles must be > 0"));
        }
        if self.max_tiles < self.min_tiles {
            return Err(Error::InvalidBudget("max_tiles must be >= min_tiles"));
        }
        if self.max_tiles > MAX_TOKENIZER_TILE_DIM {
            return Err(Error::InvalidBudget(
                "max_tiles must be <= 10 (bundled tokenizer's row/col marker grid is 10x10)",
            ));
        }
        // The cap is checked after the min/max ordering, so a budget whose
        // min and max are both above the cap is still rejected here.
        if self.max_image_tokens > MAX_IMAGE_TOKENS_CAP {
            return Err(Error::InvalidBudget(
                "max_image_tokens must be <= 1024 (4× the model default; protects against unbounded smart_resize / pixel_values allocation)",
            ));
        }
        if !self.max_pixels_tolerance.is_finite() || self.max_pixels_tolerance <= 0.0 {
            return Err(Error::InvalidBudget(
                "max_pixels_tolerance must be a finite, positive f32 (NaN/Inf/<=0 reject)",
            ));
        }
        Ok(())
    }
}
/// Largest `max_tiles` accepted by `ImageBudget::validate`; the validation
/// message attributes it to the bundled tokenizer's 10x10 row/col marker grid.
pub const MAX_TOKENIZER_TILE_DIM: usize = 10;
/// Largest `max_image_tokens` accepted by `ImageBudget::validate`.
pub const MAX_IMAGE_TOKENS_CAP: usize = 1024;
impl Default for ImageBudget {
fn default() -> Self {
Self::new()
}
}
/// Optional `intra_threads` / `inter_threads` counts.
///
/// Both are `None` in `ThreadOptions::new` and `Some(1)` in
/// `ThreadOptions::deterministic`. No validation is defined for this type in
/// this file.
// NOTE(review): presumably `None` defers to the runtime's own default thread
// count — confirm where these values are consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ThreadOptions {
/// Intra thread count, or `None` when never set.
intra_threads: Option<usize>,
/// Inter thread count, or `None` when never set.
inter_threads: Option<usize>,
}
impl ThreadOptions {
    /// Creates options with both thread counts unset (`None`).
    pub const fn new() -> Self {
        Self {
            inter_threads: None,
            intra_threads: None,
        }
    }

    /// Preset with both thread counts pinned to `Some(1)`.
    pub const fn deterministic() -> Self {
        Self::new().with_intra_threads(1).with_inter_threads(1)
    }

    /// Returns the stored intra thread count, if one was set.
    pub const fn intra_threads(&self) -> Option<usize> {
        self.intra_threads
    }

    /// Returns the stored inter thread count, if one was set.
    pub const fn inter_threads(&self) -> Option<usize> {
        self.inter_threads
    }

    /// By-value builder: returns a copy with `intra_threads` set to `Some(v)`.
    pub const fn with_intra_threads(self, v: usize) -> Self {
        Self {
            intra_threads: Some(v),
            ..self
        }
    }

    /// By-value builder: returns a copy with `inter_threads` set to `Some(v)`.
    pub const fn with_inter_threads(self, v: usize) -> Self {
        Self {
            inter_threads: Some(v),
            ..self
        }
    }

    /// In-place setter storing `Some(v)` in `intra_threads`; returns `self`
    /// for chaining.
    pub fn set_intra_threads(&mut self, v: usize) -> &mut Self {
        *self = (*self).with_intra_threads(v);
        self
    }

    /// In-place setter storing `Some(v)` in `inter_threads`; returns `self`
    /// for chaining.
    pub fn set_inter_threads(&mut self, v: usize) -> &mut Self {
        *self = (*self).with_inter_threads(v);
        self
    }
}
impl Default for ThreadOptions {
fn default() -> Self {
Self::new()
}
}
/// Top-level bundle of `RequestOptions`, `ImageBudget` and `ThreadOptions`,
/// plus a graph optimization level when the `inference` feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Options {
/// Sampling / length settings.
request: RequestOptions,
/// Per-image token and tile limits.
image_budget: ImageBudget,
/// Thread-count settings.
thread: ThreadOptions,
/// Optimization level, held as the crate-private `GraphOptLevelMirror` and
/// converted to/from `GraphOptimizationLevel` by the accessor methods.
// NOTE(review): presumably mirrored so this struct can derive
// Serialize/Deserialize/PartialEq — confirm against the ort type's impls.
#[cfg(feature = "inference")]
optimization_level: GraphOptLevelMirror,
}
impl Options {
    /// Creates the default bundle: `RequestOptions::deterministic()` (not
    /// `RequestOptions::new()`), `ImageBudget::new()`, `ThreadOptions::new()`,
    /// and, with the `inference` feature, optimization level `Level1`.
    pub const fn new() -> Self {
        let request = RequestOptions::deterministic();
        let image_budget = ImageBudget::new();
        let thread = ThreadOptions::new();
        Self {
            request,
            image_budget,
            thread,
            #[cfg(feature = "inference")]
            optimization_level: GraphOptLevelMirror::Level1,
        }
    }

    /// Borrows the request (sampling / length) settings.
    pub const fn request(&self) -> &RequestOptions {
        &self.request
    }

    /// Borrows the image budget.
    pub const fn image_budget(&self) -> &ImageBudget {
        &self.image_budget
    }

    /// Borrows the thread settings.
    pub const fn thread(&self) -> &ThreadOptions {
        &self.thread
    }

    /// Returns the stored optimization level converted back to
    /// `GraphOptimizationLevel`.
    #[cfg(feature = "inference")]
    #[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
    pub fn optimization_level(&self) -> GraphOptimizationLevel {
        GraphOptimizationLevel::from(self.optimization_level)
    }

    /// By-value builder: returns a copy with the request settings replaced
    /// by `r`.
    pub const fn with_request(self, r: RequestOptions) -> Self {
        Self { request: r, ..self }
    }

    /// By-value builder: returns a copy with the image budget replaced by `b`.
    pub const fn with_image_budget(self, b: ImageBudget) -> Self {
        Self {
            image_budget: b,
            ..self
        }
    }

    /// By-value builder: returns a copy with the thread settings replaced
    /// by `t`.
    pub const fn with_thread(self, t: ThreadOptions) -> Self {
        Self { thread: t, ..self }
    }

    /// By-value builder: returns the options with the optimization level
    /// replaced by `lvl`.
    #[cfg(feature = "inference")]
    #[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
    pub fn with_optimization_level(mut self, lvl: GraphOptimizationLevel) -> Self {
        self.optimization_level = GraphOptLevelMirror::from(lvl);
        self
    }

    /// In-place setter for the request settings; returns `self` for chaining.
    pub fn set_request(&mut self, r: RequestOptions) -> &mut Self {
        *self = (*self).with_request(r);
        self
    }

    /// In-place setter for the image budget; returns `self` for chaining.
    pub fn set_image_budget(&mut self, b: ImageBudget) -> &mut Self {
        *self = (*self).with_image_budget(b);
        self
    }

    /// In-place setter for the thread settings; returns `self` for chaining.
    pub fn set_thread(&mut self, t: ThreadOptions) -> &mut Self {
        *self = (*self).with_thread(t);
        self
    }

    /// In-place setter for the optimization level; returns `self` for
    /// chaining.
    #[cfg(feature = "inference")]
    #[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
    pub fn set_optimization_level(&mut self, lvl: GraphOptimizationLevel) -> &mut Self {
        *self = (*self).with_optimization_level(lvl);
        self
    }
}
impl Default for Options {
fn default() -> Self {
Self::new()
}
}
/// Crate-private enum with the same five variant names that this file matches
/// on for `GraphOptimizationLevel`.
///
/// It is the type actually stored in `Options`, and with the `serde` feature
/// it serializes using snake_case variant names. Conversions in both
/// directions are provided by the `From` impls in this file.
#[cfg(feature = "inference")]
#[cfg_attr(docsrs, doc(cfg(feature = "inference")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "snake_case"))]
enum GraphOptLevelMirror {
/// Counterpart of `GraphOptimizationLevel::Disable`.
Disable,
/// Counterpart of `GraphOptimizationLevel::Level1`; the value used by `Options::new`.
Level1,
/// Counterpart of `GraphOptimizationLevel::Level2`.
Level2,
/// Counterpart of `GraphOptimizationLevel::Level3`.
Level3,
/// Counterpart of `GraphOptimizationLevel::All`.
All,
}
#[cfg(feature = "inference")]
impl From<GraphOptimizationLevel> for GraphOptLevelMirror {
    /// Maps each `GraphOptimizationLevel` variant to the mirror variant of
    /// the same name. The match has no wildcard arm, so a variant added
    /// upstream becomes a compile error here.
    fn from(level: GraphOptimizationLevel) -> Self {
        match level {
            GraphOptimizationLevel::Disable => GraphOptLevelMirror::Disable,
            GraphOptimizationLevel::Level1 => GraphOptLevelMirror::Level1,
            GraphOptimizationLevel::Level2 => GraphOptLevelMirror::Level2,
            GraphOptimizationLevel::Level3 => GraphOptLevelMirror::Level3,
            GraphOptimizationLevel::All => GraphOptLevelMirror::All,
        }
    }
}
#[cfg(feature = "inference")]
impl From<GraphOptLevelMirror> for GraphOptimizationLevel {
    /// Maps each mirror variant back to the `GraphOptimizationLevel` variant
    /// of the same name.
    fn from(mirror: GraphOptLevelMirror) -> Self {
        match mirror {
            GraphOptLevelMirror::Disable => GraphOptimizationLevel::Disable,
            GraphOptLevelMirror::Level1 => GraphOptimizationLevel::Level1,
            GraphOptLevelMirror::Level2 => GraphOptimizationLevel::Level2,
            GraphOptLevelMirror::Level3 => GraphOptimizationLevel::Level3,
            GraphOptLevelMirror::All => GraphOptimizationLevel::All,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the option types: preset values, builder/setter round
    //! trips, and every rejection branch of the two `validate` methods.

    use super::*;

    /// `RequestOptions::new` carries the documented default values.
    #[test]
    fn request_options_new_matches_model_card() {
        let r = RequestOptions::new();
        assert_eq!(r.temperature(), 0.1);
        assert_eq!(r.min_p(), 0.15);
        assert_eq!(r.repetition_penalty(), 1.05);
        assert_eq!(r.max_new_tokens(), 512);
    }

    /// The deterministic preset zeroes temperature and `min_p` and keeps the
    /// remaining defaults.
    #[test]
    fn request_options_deterministic_is_greedy() {
        let r = RequestOptions::deterministic();
        assert_eq!(r.temperature(), 0.0);
        assert_eq!(r.min_p(), 0.0);
        assert_eq!(r.repetition_penalty(), 1.05);
        assert_eq!(r.max_new_tokens(), 512);
    }

    /// Both request presets pass their own validation.
    #[test]
    fn request_options_presets_validate() {
        assert!(RequestOptions::new().validate().is_ok());
        assert!(RequestOptions::deterministic().validate().is_ok());
        assert!(RequestOptions::default().validate().is_ok());
    }

    /// One out-of-range value per field is rejected.
    #[test]
    fn request_options_validate_rejects_bad_inputs() {
        assert!(
            RequestOptions::new()
                .with_max_new_tokens(0)
                .validate()
                .is_err()
        );
        assert!(
            RequestOptions::new()
                .with_temperature(-1.0)
                .validate()
                .is_err()
        );
        assert!(RequestOptions::new().with_min_p(2.0).validate().is_err());
        assert!(RequestOptions::new().with_min_p(-0.1).validate().is_err());
        assert!(
            RequestOptions::new()
                .with_repetition_penalty(0.5)
                .validate()
                .is_err()
        );
    }

    /// NaN and both infinities are rejected for every float field.
    #[test]
    fn request_options_validate_rejects_non_finite() {
        for nan_temp in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(
                RequestOptions::new()
                    .with_temperature(nan_temp)
                    .validate()
                    .is_err(),
                "temperature {nan_temp:?} must be rejected"
            );
        }
        for nan_min_p in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(
                RequestOptions::new()
                    .with_min_p(nan_min_p)
                    .validate()
                    .is_err(),
                "min_p {nan_min_p:?} must be rejected"
            );
        }
        for nan_rep in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(
                RequestOptions::new()
                    .with_repetition_penalty(nan_rep)
                    .validate()
                    .is_err(),
                "repetition_penalty {nan_rep:?} must be rejected"
            );
        }
    }

    /// The by-value builders chain and each one stores its value.
    #[test]
    fn request_options_with_chains() {
        let r = RequestOptions::new()
            .with_temperature(0.3)
            .with_min_p(0.05)
            .with_repetition_penalty(1.10)
            .with_max_new_tokens(1024);
        assert_eq!(r.temperature(), 0.3);
        assert_eq!(r.min_p(), 0.05);
        assert_eq!(r.repetition_penalty(), 1.10);
        assert_eq!(r.max_new_tokens(), 1024);
    }

    /// The in-place setters chain and agree with the by-value builders.
    #[test]
    fn request_options_set_chains() {
        let mut r = RequestOptions::new();
        r.set_temperature(0.3)
            .set_min_p(0.05)
            .set_repetition_penalty(1.10)
            .set_max_new_tokens(1024);
        let expected = RequestOptions::new()
            .with_temperature(0.3)
            .with_min_p(0.05)
            .with_repetition_penalty(1.10)
            .with_max_new_tokens(1024);
        assert_eq!(r, expected);
    }

    /// Positive temperatures below `1e-3` are rejected; `1e-3` and exactly
    /// `0.0` are accepted.
    #[test]
    fn request_options_validate_rejects_subnormal_positive_temperature() {
        assert!(
            RequestOptions::new()
                .with_temperature(1e-40)
                .validate()
                .is_err()
        );
        assert!(
            RequestOptions::new()
                .with_temperature(1e-6)
                .validate()
                .is_err()
        );
        assert!(
            RequestOptions::new()
                .with_temperature(1e-3)
                .validate()
                .is_ok()
        );
        assert!(
            RequestOptions::new()
                .with_temperature(0.0)
                .validate()
                .is_ok()
        );
    }

    /// `MAX_REPETITION_PENALTY` itself is accepted; anything above is not.
    #[test]
    fn request_options_validate_caps_repetition_penalty() {
        let r = RequestOptions::new().with_repetition_penalty(MAX_REPETITION_PENALTY + 0.001);
        assert!(matches!(r.validate(), Err(Error::InvalidRequest(_))));
        let r_at = RequestOptions::new().with_repetition_penalty(MAX_REPETITION_PENALTY);
        assert!(r_at.validate().is_ok());
        let r_max = RequestOptions::new().with_repetition_penalty(f32::MAX);
        assert!(matches!(r_max.validate(), Err(Error::InvalidRequest(_))));
    }

    /// `MAX_NEW_TOKENS_CAP` itself is accepted; one more is not.
    #[test]
    fn request_options_validate_caps_max_new_tokens() {
        let r = RequestOptions::new().with_max_new_tokens(MAX_NEW_TOKENS_CAP + 1);
        assert!(matches!(r.validate(), Err(Error::InvalidRequest(_))));
        let r_ok = RequestOptions::new().with_max_new_tokens(MAX_NEW_TOKENS_CAP);
        assert!(r_ok.validate().is_ok());
    }

    /// `ImageBudget::new` carries the documented default values.
    #[test]
    fn image_budget_new_matches_preprocessor_config() {
        let b = ImageBudget::new();
        assert_eq!(b.min_image_tokens(), 64);
        assert_eq!(b.max_image_tokens(), 256);
        assert_eq!(b.min_tiles(), 2);
        assert_eq!(b.max_tiles(), 10);
        assert!(b.use_thumbnail());
    }

    /// The fast preset has a lower token ceiling and no thumbnail.
    #[test]
    fn image_budget_fast_is_smaller() {
        let f = ImageBudget::fast();
        assert!(f.max_image_tokens() < ImageBudget::new().max_image_tokens());
        assert!(!f.use_thumbnail());
    }

    /// All image presets pass their own validation, and `quality` equals
    /// `new`.
    #[test]
    fn image_budget_presets_validate() {
        assert!(ImageBudget::new().validate().is_ok());
        assert!(ImageBudget::fast().validate().is_ok());
        assert!(ImageBudget::default().validate().is_ok());
        assert_eq!(ImageBudget::quality(), ImageBudget::new());
    }

    /// Zero `min_image_tokens` and `max < min` are rejected.
    #[test]
    fn image_budget_validate_rejects_bad_inputs() {
        let mut b = ImageBudget::new();
        b.set_min_image_tokens(0);
        assert!(b.validate().is_err());
        let mut b2 = ImageBudget::new();
        b2.set_max_image_tokens(b2.min_image_tokens() - 1);
        assert!(b2.validate().is_err());
    }

    /// Zero `min_tiles` and `max_tiles < min_tiles` are rejected.
    #[test]
    fn image_budget_validate_rejects_bad_tiles() {
        let zero_min = ImageBudget::new().with_min_tiles(0);
        assert!(matches!(zero_min.validate(), Err(Error::InvalidBudget(_))));
        let inverted = ImageBudget::new().with_min_tiles(4).with_max_tiles(3);
        assert!(matches!(inverted.validate(), Err(Error::InvalidBudget(_))));
        let equal = ImageBudget::new().with_min_tiles(4).with_max_tiles(4);
        assert!(equal.validate().is_ok());
    }

    /// Default preset: 10 * 256 + 256.
    #[test]
    fn image_budget_max_tokens_per_image_default() {
        assert_eq!(ImageBudget::new().max_tokens_per_image(), 2816);
    }

    /// Fast preset: 4 * 256 + 64.
    #[test]
    fn image_budget_max_tokens_per_image_fast() {
        assert_eq!(ImageBudget::fast().max_tokens_per_image(), 1088);
    }

    /// `MAX_IMAGE_TOKENS_CAP` itself is accepted; one more is not, even when
    /// the minimum is raised along with it.
    #[test]
    fn image_budget_validate_caps_max_image_tokens() {
        let mut b = ImageBudget::new();
        b.set_max_image_tokens(MAX_IMAGE_TOKENS_CAP + 1);
        assert!(matches!(b.validate(), Err(Error::InvalidBudget(_))));
        let mut b_ok = ImageBudget::new();
        b_ok.set_max_image_tokens(MAX_IMAGE_TOKENS_CAP);
        assert!(b_ok.validate().is_ok());
        let mut b_min = ImageBudget::new();
        b_min.set_min_image_tokens(MAX_IMAGE_TOKENS_CAP + 1);
        b_min.set_max_image_tokens(MAX_IMAGE_TOKENS_CAP + 1);
        assert!(matches!(b_min.validate(), Err(Error::InvalidBudget(_))));
    }

    /// `MAX_TOKENIZER_TILE_DIM` itself is accepted; one more is not.
    #[test]
    fn image_budget_validate_rejects_max_tiles_above_tokenizer_grid() {
        let mut b = ImageBudget::new();
        b.set_max_tiles(MAX_TOKENIZER_TILE_DIM + 1);
        assert!(b.validate().is_err());
        let mut b2 = ImageBudget::new();
        b2.set_max_tiles(MAX_TOKENIZER_TILE_DIM);
        assert!(b2.validate().is_ok());
    }

    /// Builder and setter both store the tolerance; the default is `2.0`.
    #[test]
    fn image_budget_max_pixels_tolerance_round_trip() {
        let b = ImageBudget::new().with_max_pixels_tolerance(2.5);
        assert_eq!(b.max_pixels_tolerance(), 2.5);
        let mut b2 = ImageBudget::new();
        b2.set_max_pixels_tolerance(1.75);
        assert_eq!(b2.max_pixels_tolerance(), 1.75);
        assert_eq!(ImageBudget::new().max_pixels_tolerance(), 2.0);
    }

    /// The tolerance is stored exactly, not rounded.
    #[test]
    fn image_budget_max_pixels_tolerance_preserves_sub_hundredth_precision() {
        let b = ImageBudget::new().with_max_pixels_tolerance(2.067);
        assert_eq!(b.max_pixels_tolerance(), 2.067);
    }

    /// NaN, infinity, zero and negative tolerances are all rejected.
    #[test]
    fn image_budget_validate_rejects_non_finite_tolerance() {
        let mut b = ImageBudget::new();
        b.set_max_pixels_tolerance(f32::NAN);
        assert!(b.validate().is_err());
        b.set_max_pixels_tolerance(f32::INFINITY);
        assert!(b.validate().is_err());
        b.set_max_pixels_tolerance(0.0);
        assert!(b.validate().is_err());
        b.set_max_pixels_tolerance(-1.0);
        assert!(b.validate().is_err());
    }

    /// Thread counts start unset, the deterministic preset pins both to one,
    /// and builders/setters store `Some(v)`.
    #[test]
    fn thread_options_presets_and_builders() {
        let unset = ThreadOptions::new();
        assert_eq!(unset.intra_threads(), None);
        assert_eq!(unset.inter_threads(), None);
        assert_eq!(ThreadOptions::default(), unset);

        let pinned = ThreadOptions::deterministic();
        assert_eq!(pinned.intra_threads(), Some(1));
        assert_eq!(pinned.inter_threads(), Some(1));

        let built = ThreadOptions::new().with_intra_threads(4);
        assert_eq!(built.intra_threads(), Some(4));
        assert_eq!(built.inter_threads(), None);

        let mut set = ThreadOptions::new();
        set.set_intra_threads(4).set_inter_threads(2);
        assert_eq!(set.intra_threads(), Some(4));
        assert_eq!(set.inter_threads(), Some(2));
    }

    /// `Options::new` bundles the deterministic request preset with the
    /// default image budget and thread options; builders replace one part.
    #[test]
    fn options_new_and_builders() {
        let o = Options::new();
        assert_eq!(*o.request(), RequestOptions::deterministic());
        assert_eq!(*o.image_budget(), ImageBudget::new());
        assert_eq!(*o.thread(), ThreadOptions::new());
        assert_eq!(Options::default(), o);

        let changed = Options::new()
            .with_request(RequestOptions::new())
            .with_image_budget(ImageBudget::fast())
            .with_thread(ThreadOptions::deterministic());
        assert_eq!(*changed.request(), RequestOptions::new());
        assert_eq!(*changed.image_budget(), ImageBudget::fast());
        assert_eq!(*changed.thread(), ThreadOptions::deterministic());

        let mut set = Options::new();
        set.set_request(RequestOptions::new())
            .set_image_budget(ImageBudget::fast())
            .set_thread(ThreadOptions::deterministic());
        assert_eq!(set, changed);
    }

    /// The optimization level defaults to `Level1` and survives the round
    /// trip through the private mirror enum.
    #[cfg(feature = "inference")]
    #[test]
    fn options_optimization_level_round_trip() {
        assert!(matches!(
            Options::new().optimization_level(),
            GraphOptimizationLevel::Level1
        ));
        let o = Options::new().with_optimization_level(GraphOptimizationLevel::Level3);
        assert!(matches!(
            o.optimization_level(),
            GraphOptimizationLevel::Level3
        ));
        let mut o2 = Options::new();
        o2.set_optimization_level(GraphOptimizationLevel::Disable);
        assert!(matches!(
            o2.optimization_level(),
            GraphOptimizationLevel::Disable
        ));
    }

    /// Compile-time check that every option type is `Send + Sync` and also
    /// `Copy + Clone`, as the test name promises.
    #[test]
    fn options_are_send_sync_copy_or_clone() {
        fn req<T: Send + Sync + Copy + Clone>() {}
        req::<RequestOptions>();
        req::<ImageBudget>();
        req::<ThreadOptions>();
        req::<Options>();
    }
}