use std::path::Path;
use trustformers_core::errors::{ErrorKind, Result, TrustformersError};
use super::config::{DistributedConfig, WeightLoadingConfig};
use super::distributed::DistributedWeightLoader;
use super::gguf::GGUFLoader;
use super::huggingface::{HuggingFaceLoader, WeightLoader};
use super::memory_mapped::MemoryMappedLoader;
/// Creates a [`HuggingFaceLoader`] rooted at `model_dir` and returns it as a boxed
/// [`WeightLoader`].
///
/// When `config` is `None`, `WeightLoadingConfig::default()` is used.
///
/// # Errors
///
/// Propagates any error returned by `HuggingFaceLoader::new`.
pub fn create_huggingface_loader(
    model_dir: impl AsRef<Path>,
    config: Option<WeightLoadingConfig>,
) -> Result<Box<dyn WeightLoader>> {
    let effective_config = config.unwrap_or_default();
    let boxed: Box<dyn WeightLoader> =
        Box::new(HuggingFaceLoader::new(model_dir, effective_config)?);
    Ok(boxed)
}
/// Constructs a [`MemoryMappedLoader`] for `path` and returns it as a boxed
/// [`WeightLoader`].
///
/// # Errors
///
/// Propagates any error returned by `MemoryMappedLoader::new`.
pub fn create_memory_mapped_loader(path: impl AsRef<Path>) -> Result<Box<dyn WeightLoader>> {
    Ok(Box::new(MemoryMappedLoader::new(path)?))
}
/// Constructs a [`GGUFLoader`] for the file at `path` and returns it as a boxed
/// [`WeightLoader`].
///
/// # Errors
///
/// Propagates any error returned by `GGUFLoader::new`.
pub fn create_gguf_loader(path: impl AsRef<Path>) -> Result<Box<dyn WeightLoader>> {
    let boxed: Box<dyn WeightLoader> = Box::new(GGUFLoader::new(path)?);
    Ok(boxed)
}
/// Constructs a [`DistributedWeightLoader`] from the general loading configuration and
/// the distributed-specific configuration, returning it as a boxed [`WeightLoader`].
///
/// # Errors
///
/// Propagates any error returned by `DistributedWeightLoader::new`.
pub fn create_distributed_loader(
    config: WeightLoadingConfig,
    distributed_config: DistributedConfig,
) -> Result<Box<dyn WeightLoader>> {
    Ok(Box::new(DistributedWeightLoader::new(
        config,
        distributed_config,
    )?))
}
/// Picks and builds a [`WeightLoader`] based on `config` and the shape of `path`.
///
/// The selection order is:
/// 1. If `config.distributed` is set, a distributed loader is built and `path` is not
///    consulted.
/// 2. If `path` is a directory, it is loaded as a HuggingFace model directory.
/// 3. If `path` has a `.safetensors` extension, the behavior depends on
///    `config.memory_mapped`. When true, the file is opened with the memory-mapped
///    loader. Otherwise, the file's containing directory is loaded with the
///    HuggingFace loader; a bare file name resolves to the current directory.
/// 4. If `path` has a `.gguf` extension, it is opened with the GGUF loader.
///
/// Extension matching is ASCII case-insensitive.
///
/// When `config` is `None`, `WeightLoadingConfig::default()` is used.
///
/// # Errors
///
/// Returns an `ErrorKind::InvalidFormat` error naming `path` if none of the rules
/// above match. This includes non-existent directories. Any error from the selected
/// loader's constructor is propagated unchanged.
pub fn auto_create_loader(
    path: impl AsRef<Path>,
    config: Option<WeightLoadingConfig>,
) -> Result<Box<dyn WeightLoader>> {
    let path = path.as_ref();
    let config = config.unwrap_or_default();

    // The full config is still handed to the distributed loader, so the nested
    // distributed settings are cloned rather than taken out of it.
    if let Some(distributed_config) = config.distributed.clone() {
        return create_distributed_loader(config, distributed_config);
    }

    if path.is_dir() {
        return create_huggingface_loader(path, Some(config));
    }

    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase);

    match extension.as_deref() {
        Some("safetensors") if config.memory_mapped => create_memory_mapped_loader(path),
        Some("safetensors") => {
            // `Path::parent` yields `Some("")` for a bare file name such as
            // "model.safetensors", so map an empty parent to the current directory.
            let model_dir = path
                .parent()
                .filter(|dir| !dir.as_os_str().is_empty())
                .unwrap_or_else(|| Path::new("."));
            create_huggingface_loader(model_dir, Some(config))
        }
        Some("gguf") => create_gguf_loader(path),
        _ => Err(TrustformersError::new(ErrorKind::InvalidFormat {
            expected: "HuggingFace directory, .safetensors, or .gguf".to_string(),
            actual: format!("Unknown weight format: {}", path.display()),
        })),
    }
}