use crate::apr_transformer::{AprTransformer, AprTransformerConfig, AprTransformerLayer};
use crate::error::{RealizarError, Result};
use crate::safetensors::validation::ValidatedAprTransformer;
#[cfg(not(target_arch = "wasm32"))]
use crate::safetensors::ShardedSafeTensorsModel;
use crate::safetensors::{MappedSafeTensorsModel, SafetensorsConfig};
use std::path::Path;
/// Read-only, name-keyed view over a collection of SafeTensors tensors.
///
/// Abstracts over the concrete container (this module implements it for a
/// single mapped file and for a sharded model) so conversion code can be
/// written once against either.
pub(crate) trait TensorSource {
    /// Returns `true` when a tensor called `name` exists in the source.
    fn has_tensor(&self, name: &str) -> bool;

    /// Lists the names of every tensor available in the source.
    fn tensor_names(&self) -> Vec<&str>;

    /// Returns the dimensions of tensor `name`, or `None` when the source has
    /// no metadata for it.
    fn tensor_shape(&self, name: &str) -> Option<Vec<usize>>;

    /// Loads tensor `name` as a flat vector of `f32` values.
    ///
    /// NOTE(review): the "auto" suffix presumably means the stored dtype is
    /// converted to `f32` on load — confirm against the implementors.
    ///
    /// # Errors
    /// Propagates whatever error the underlying container reports, for
    /// example when the tensor cannot be found or decoded.
    fn get_tensor_auto(&self, name: &str) -> Result<Vec<f32>>;
}
#[cfg(not(target_arch = "wasm32"))]
/// Exposes a memory-mapped single-file model through [`TensorSource`].
///
/// Every method forwards to the inherent method of the same name on
/// `MappedSafeTensorsModel`; `Self::` paths are used so the inherent item is
/// selected rather than recursing into this trait impl.
impl TensorSource for MappedSafeTensorsModel {
    fn has_tensor(&self, name: &str) -> bool {
        Self::has_tensor(self, name)
    }

    fn tensor_names(&self) -> Vec<&str> {
        Self::tensor_names(self)
    }

    /// Copies the shape out of the tensor's header metadata, if any.
    fn tensor_shape(&self, name: &str) -> Option<Vec<usize>> {
        let info = Self::get_tensor_info(self, name)?;
        Some(info.shape.clone())
    }

    fn get_tensor_auto(&self, name: &str) -> Result<Vec<f32>> {
        Self::get_tensor_auto(self, name)
    }
}
#[cfg(not(target_arch = "wasm32"))]
/// Exposes a multi-shard model through [`TensorSource`].
///
/// Every method forwards to the inherent method of the same name on
/// `ShardedSafeTensorsModel`; `Self::` paths are used so the inherent item is
/// selected rather than recursing into this trait impl.
impl TensorSource for ShardedSafeTensorsModel {
    fn has_tensor(&self, name: &str) -> bool {
        Self::has_tensor(self, name)
    }

    fn tensor_names(&self) -> Vec<&str> {
        Self::tensor_names(self)
    }

    /// Copies the shape out of the tensor's header metadata, if any.
    fn tensor_shape(&self, name: &str) -> Option<Vec<usize>> {
        let info = Self::get_tensor_info(self, name)?;
        Some(info.shape.clone())
    }

    fn get_tensor_auto(&self, name: &str) -> Result<Vec<f32>> {
        Self::get_tensor_auto(self, name)
    }
}
/// Stateless (zero-sized, unit) type that namespaces SafeTensors → APR
/// transformer conversion.
///
/// NOTE(review): its associated functions are presumably defined in the
/// companion source files pulled in with `include!` in this module — confirm
/// there before relying on a specific API.
pub struct SafetensorsToAprConverter;
include!("safetensors_infer_convert.rs");
include!("safetensors_infer_convert_02.rs");