use std::fmt;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
use std::sync::Arc;
use super::core_interop::CoreInterop;
use super::model_variant::ModelVariant;
use crate::error::{FoundryLocalError, Result};
use crate::openai::AudioClient;
use crate::openai::ChatClient;
use crate::openai::EmbeddingClient;
use crate::types::ModelInfo;
/// Public handle to a model.
///
/// Either wraps a single concrete variant, or a group of variants sharing one
/// alias with one of them selected. The accessors in this file delegate to the
/// selected variant (or to the single wrapped variant).
pub struct Model {
// Which representation this handle uses; see `ModelKind`.
inner: ModelKind,
}
// The lint fires because one variant is much larger than the other.
// NOTE(review): presumably `ModelVariant` (which exposes a `ModelInfo`) is the
// large one; it is kept unboxed rather than adding a heap allocation per
// single-variant `Model` — confirm this trade-off is intended.
#[allow(clippy::large_enum_variant)]
enum ModelKind {
// A single variant, created by `Model::from_variant`; variant selection is
// not supported on this form (`select_variant_by_id` returns an error).
ModelVariant(ModelVariant),
// A group of variants under one alias, created by `Model::from_group` and
// populated by `Model::add_variant`.
Model {
// Alias reported by `Model::alias` for the group.
alias: String,
// Shared core interop handle; cloned by `Arc` reference when the model is cloned.
core: Arc<CoreInterop>,
// All variants in the group, in insertion order.
variants: Vec<ModelVariant>,
// Index into `variants` of the selected variant. Atomic so that
// `select_variant_by_id` can change the selection through `&self`.
selected: AtomicUsize,
},
}
impl Clone for Model {
fn clone(&self) -> Self {
Self {
inner: match &self.inner {
ModelKind::ModelVariant(v) => ModelKind::ModelVariant(v.clone()),
ModelKind::Model {
alias,
core,
variants,
selected,
} => ModelKind::Model {
alias: alias.clone(),
core: Arc::clone(core),
variants: variants.clone(),
selected: AtomicUsize::new(selected.load(Relaxed)),
},
},
}
}
}
impl fmt::Debug for Model {
    /// Formats a short diagnostic summary of the model.
    ///
    /// A single-variant model shows its `id` and `alias`. A group shows its
    /// `alias`, the selected variant's `id`, the number of variants, and the
    /// selected index.
    ///
    /// For a group, the selection index is read exactly once, so `id` and
    /// `selected_index` always describe the same variant. An empty group, or an
    /// out-of-range index, renders `id` as `<none>` instead of panicking.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.inner {
            ModelKind::ModelVariant(v) => f
                .debug_struct("Model::ModelVariant")
                .field("id", &v.id())
                .field("alias", &v.alias())
                .finish(),
            ModelKind::Model {
                alias,
                variants,
                selected,
                ..
            } => {
                // Load once: a concurrent selection change between two loads could
                // otherwise make `id` and `selected_index` disagree.
                let index = selected.load(Relaxed);
                let mut builder = f.debug_struct("Model::Model");
                builder.field("alias", alias);
                // `from_group` starts with an empty variant list, so plain indexing
                // could panic here; Debug output must never panic.
                match variants.get(index) {
                    Some(variant) => builder.field("id", &variant.id()),
                    None => builder.field("id", &format_args!("<none>")),
                };
                builder
                    .field("variants_count", &variants.len())
                    .field("selected_index", &index)
                    .finish()
            }
        }
    }
}
impl Model {
    /// Wraps a single concrete variant as a `Model`.
    ///
    /// Variant selection is not supported on the resulting model.
    pub(crate) fn from_variant(variant: ModelVariant) -> Self {
        Self {
            inner: ModelKind::ModelVariant(variant),
        }
    }

    /// Creates an empty variant group for `alias`.
    ///
    /// The group starts with no variants and selection index 0. Callers are
    /// expected to add at least one variant with `add_variant` before using the
    /// public accessors, which index into the variant list.
    pub(crate) fn from_group(alias: String, core: Arc<CoreInterop>) -> Self {
        Self {
            inner: ModelKind::Model {
                alias,
                core,
                variants: Vec::new(),
                selected: AtomicUsize::new(0),
            },
        }
    }

    /// Appends `variant` to this group and updates the selection.
    ///
    /// The selection moves to the new variant only when the new variant is
    /// cached and the currently selected one is not. As a result, the first
    /// cached variant added wins over earlier uncached ones. If none are cached,
    /// the first variant added stays selected.
    ///
    /// # Panics
    ///
    /// Panics if `self` is a single-variant model (created with `from_variant`).
    pub(crate) fn add_variant(&mut self, variant: ModelVariant) {
        match &mut self.inner {
            ModelKind::Model {
                variants, selected, ..
            } => {
                variants.push(variant);
                let new_idx = variants.len() - 1;
                // `&mut self` guarantees exclusive access, so the index can be read
                // and written directly instead of via atomic load/store.
                let current = selected.get_mut();
                let new_is_cached = variants[new_idx].info_ref().cached;
                if new_is_cached && !variants[*current].info_ref().cached {
                    *current = new_idx;
                }
            }
            ModelKind::ModelVariant(_) => {
                panic!("add_variant called on a single-variant Model");
            }
        }
    }
}
impl Model {
    /// Returns the variant that every public accessor delegates to.
    ///
    /// A single-variant model returns its wrapped variant. A group returns the
    /// entry at its current `selected` index.
    ///
    /// # Panics
    ///
    /// Panics if a group's selected index is out of bounds, e.g. a group built
    /// with `from_group` that has not had any variant added yet.
    fn selected_variant(&self) -> &ModelVariant {
        let (variants, index) = match &self.inner {
            ModelKind::ModelVariant(single) => return single,
            ModelKind::Model {
                variants, selected, ..
            } => (variants, selected.load(Relaxed)),
        };
        &variants[index]
    }
}
impl Model {
pub fn id(&self) -> &str {
self.selected_variant().id()
}
pub fn alias(&self) -> &str {
match &self.inner {
ModelKind::ModelVariant(v) => v.alias(),
ModelKind::Model { alias, .. } => alias,
}
}
pub fn info(&self) -> &ModelInfo {
self.selected_variant().info()
}
pub fn context_length(&self) -> Option<u64> {
self.selected_variant().info().context_length
}
pub fn input_modalities(&self) -> Option<&str> {
self.selected_variant().info().input_modalities.as_deref()
}
pub fn output_modalities(&self) -> Option<&str> {
self.selected_variant().info().output_modalities.as_deref()
}
pub fn capabilities(&self) -> Option<&str> {
self.selected_variant().info().capabilities.as_deref()
}
pub fn supports_tool_calling(&self) -> Option<bool> {
self.selected_variant().info().supports_tool_calling
}
pub async fn is_cached(&self) -> Result<bool> {
self.selected_variant().is_cached().await
}
pub async fn is_loaded(&self) -> Result<bool> {
self.selected_variant().is_loaded().await
}
pub async fn download<F>(&self, progress: Option<F>) -> Result<()>
where
F: FnMut(f64) + Send + 'static,
{
self.selected_variant().download(progress).await
}
pub async fn path(&self) -> Result<PathBuf> {
self.selected_variant().path().await
}
pub async fn load(&self) -> Result<()> {
self.selected_variant().load().await
}
pub async fn unload(&self) -> Result<String> {
self.selected_variant().unload().await
}
pub async fn remove_from_cache(&self) -> Result<String> {
self.selected_variant().remove_from_cache().await
}
pub fn create_chat_client(&self) -> ChatClient {
self.selected_variant().create_chat_client()
}
pub fn create_audio_client(&self) -> AudioClient {
self.selected_variant().create_audio_client()
}
pub fn create_embedding_client(&self) -> EmbeddingClient {
self.selected_variant().create_embedding_client()
}
pub fn variants(&self) -> Vec<Arc<Model>> {
match &self.inner {
ModelKind::ModelVariant(v) => {
vec![Arc::new(Model::from_variant(v.clone()))]
}
ModelKind::Model { variants, .. } => variants
.iter()
.map(|v| Arc::new(Model::from_variant(v.clone())))
.collect(),
}
}
pub fn select_variant(&self, variant: &Model) -> Result<()> {
self.select_variant_by_id(variant.id())
}
pub fn select_variant_by_id(&self, id: &str) -> Result<()> {
match &self.inner {
ModelKind::ModelVariant(v) => Err(FoundryLocalError::ModelOperation {
reason: format!(
"select_variant is not supported on a single variant. \
Call Catalog::get_model(\"{}\") to get a model with all variants available.",
v.alias()
),
}),
ModelKind::Model {
variants,
selected,
alias,
..
} => match variants.iter().position(|v| v.id() == id) {
Some(pos) => {
selected.store(pos, Relaxed);
Ok(())
}
None => {
let available: Vec<&str> = variants.iter().map(|v| v.id()).collect();
Err(FoundryLocalError::ModelOperation {
reason: format!(
"Variant '{id}' not found for model '{alias}'. Available: {available:?}",
),
})
}
},
}
}
}