use std::collections::HashMap;
use llm_tokenizer::TokenizerTrait;
use serde_json::Value;
use thiserror::Error;
use crate::{
types::{FieldLayout, ImageSize, Modality, PromptReplacement, TokenId},
vision::image_processor::PreprocessedImages,
};
/// Errors reported by the model registry helpers in this module.
///
/// `Display` text for each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum ModelRegistryError {
/// Carries a string that is interpolated after `unsupported model: `
/// (presumably the model identifier — confirm against the construction sites).
#[error("unsupported model: {0}")]
UnsupportedModel(String),
/// A token string had no id in the tokenizer vocabulary; `token` is the
/// string that was looked up.
#[error("token '{token}' not found in tokenizer vocabulary")]
TokenNotFound { token: String },
/// A config field was absent; `field` is the text shown in the message.
#[error("missing config field '{field}'")]
MissingConfigField { field: String },
}
/// Shorthand for a `Result` whose error type is [`ModelRegistryError`].
pub type RegistryResult<T> = Result<T, ModelRegistryError>;
/// Borrowed view of the per-model inputs handed to processor specs.
///
/// All three fields are references with lifetime `'a`; the struct owns nothing.
pub struct ModelMetadata<'a> {
/// Identifier string for the model (exact format is not visible here).
pub model_id: &'a str,
/// Tokenizer used to resolve token strings to ids (see `token_id`).
pub tokenizer: &'a dyn TokenizerTrait,
/// JSON configuration value that `config_u32` reads from.
pub config: &'a Value,
}
impl<'a> ModelMetadata<'a> {
pub fn token_id(&self, token: &str) -> RegistryResult<TokenId> {
self.tokenizer
.token_to_id(token)
.map(|id| id as TokenId)
.ok_or_else(|| ModelRegistryError::TokenNotFound {
token: token.to_string(),
})
}
pub fn config_u32(&self, path: &[&str]) -> Option<u32> {
Self::find_value(self.config, path).and_then(|value| value.as_u64().map(|v| v as u32))
}
fn find_value<'v>(value: &'v Value, path: &[&str]) -> Option<&'v Value> {
let mut current = value;
for key in path {
current = current.get(*key)?;
}
Some(current)
}
}
/// Per-model hooks for the processor registry.
///
/// Implementors must be `Send + Sync`. Every method other than
/// [`field_layouts`](Self::field_layouts) and
/// [`keep_on_cpu_keys`](Self::keep_on_cpu_keys) is required; their exact
/// semantics are defined by each implementation.
pub trait ModelProcessorSpec: Send + Sync {
    /// Static name of this spec.
    fn name(&self) -> &'static str;

    /// Returns a boolean for `metadata`; presumably whether this spec applies
    /// to that model — confirm against the registry's lookup code.
    fn matches(&self, metadata: &ModelMetadata) -> bool;

    /// Placeholder token string for `metadata`, or a registry error.
    fn placeholder_token(&self, metadata: &ModelMetadata) -> RegistryResult<String>;

    /// Placeholder token id for `metadata`, or a registry error.
    fn placeholder_token_id(&self, metadata: &ModelMetadata) -> RegistryResult<TokenId>;

    /// Map from modality to a `usize` limit for `metadata`, or a registry error.
    fn modality_limits(
        &self,
        metadata: &ModelMetadata,
    ) -> RegistryResult<HashMap<Modality, usize>>;

    /// JSON value of processor keyword arguments for `metadata`, or a
    /// registry error.
    fn processor_kwargs(&self, metadata: &ModelMetadata) -> RegistryResult<Value>;

    /// Prompt replacements computed from `metadata` and the preprocessed
    /// images, or a registry error.
    fn prompt_replacements(
        &self,
        metadata: &ModelMetadata,
        preprocessed: &PreprocessedImages,
    ) -> RegistryResult<Vec<PromptReplacement>>;

    /// Layout for each named field.
    ///
    /// The default contains a single entry mapping `"pixel_values"` to
    /// [`FieldLayout::Batched`].
    fn field_layouts(&self) -> HashMap<String, FieldLayout> {
        let mut layouts = HashMap::with_capacity(1);
        layouts.insert(String::from("pixel_values"), FieldLayout::Batched);
        layouts
    }

    /// Keys to keep on the CPU; the default is an empty list.
    fn keep_on_cpu_keys(&self) -> Vec<String> {
        Vec::new()
    }
}
/// Converts each entry of `preprocessed.image_sizes` into an [`ImageSize`].
///
/// Every entry is a two-element tuple: the first element becomes `height`
/// and the second becomes `width`. Output order matches input order.
pub fn image_sizes_hw(preprocessed: &PreprocessedImages) -> Vec<ImageSize> {
    let dims = preprocessed.image_sizes.iter();
    let mut sizes = Vec::with_capacity(dims.size_hint().0);
    for &(height, width) in dims {
        sizes.push(ImageSize { width, height });
    }
    sizes
}