Skip to main content

ferrum_models/
weights.rs

1//! Weight loading - MVP stub implementation
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use ferrum_interfaces::{
8    backend::{
9        TensorSpec, TransformationType, WeightFormat, WeightLoaderCapabilities, WeightMetadata,
10        WeightSource, WeightSourceType,
11    },
12    TensorFactory, TensorRef, WeightLoader,
13};
14use ferrum_types::{DataType, Result};
15use tracing::debug;
16
17/// Weight loader handle wrapping a trait object
18#[derive(Clone)]
19pub struct WeightLoaderHandle(pub Arc<dyn WeightLoader + Send + Sync>);
20
21impl std::fmt::Debug for WeightLoaderHandle {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.debug_struct("WeightLoaderHandle").finish()
24    }
25}
26
27/// Create default stub weight loader
28pub fn default_weight_loader() -> Result<WeightLoaderHandle> {
29    debug!("Creating default stub weight loader");
30    Ok(WeightLoaderHandle(Arc::new(StubWeightLoader::new())))
31}
32
33/// Stub weight loader - MVP implementation
34pub struct StubWeightLoader {
35    factory: Option<Arc<dyn TensorFactory>>,
36}
37
38impl StubWeightLoader {
39    pub fn new() -> Self {
40        Self { factory: None }
41    }
42
43    pub fn with_factory(factory: Arc<dyn TensorFactory>) -> Self {
44        Self {
45            factory: Some(factory),
46        }
47    }
48}
49
50impl Default for StubWeightLoader {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl std::fmt::Debug for StubWeightLoader {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("StubWeightLoader")
59            .field("has_factory", &self.factory.is_some())
60            .finish()
61    }
62}
63
64#[async_trait]
65impl WeightLoader for StubWeightLoader {
66    async fn load_tensor(&self, spec: &TensorSpec) -> Result<TensorRef> {
67        debug!(
68            "StubWeightLoader: creating zeros for '{}' {:?}",
69            spec.name, spec.shape
70        );
71
72        if let Some(factory) = &self.factory {
73            factory.zeros(&spec.shape, spec.dtype, &spec.device)
74        } else {
75            Err(ferrum_types::FerrumError::model(
76                "No tensor factory configured in stub weight loader",
77            ))
78        }
79    }
80
81    async fn load_tensors(&self, specs: &[TensorSpec]) -> Result<Vec<TensorRef>> {
82        let mut tensors = Vec::with_capacity(specs.len());
83        for spec in specs {
84            tensors.push(self.load_tensor(spec).await?);
85        }
86        Ok(tensors)
87    }
88
89    async fn is_available(&self, _source: &WeightSource) -> bool {
90        true
91    }
92
93    async fn get_metadata(&self, _source: &WeightSource) -> Result<WeightMetadata> {
94        Ok(WeightMetadata {
95            tensors: HashMap::new(),
96            format: WeightFormat::SafeTensors,
97            total_size_bytes: 1024 * 1024,
98            dtypes: vec![DataType::FP16],
99            extra: HashMap::new(),
100        })
101    }
102
103    async fn preload(&self, _source: &WeightSource) -> Result<()> {
104        Ok(())
105    }
106
107    fn capabilities(&self) -> WeightLoaderCapabilities {
108        WeightLoaderCapabilities {
109            supported_formats: vec![WeightFormat::SafeTensors],
110            supported_sources: vec![WeightSourceType::File, WeightSourceType::HuggingFace],
111            max_tensor_size: 10 * 1024 * 1024 * 1024, // 10GB
112            supports_streaming: false,
113            supports_concurrent: false,
114            supported_transformations: vec![
115                TransformationType::Transpose,
116                TransformationType::Reshape,
117                TransformationType::Cast,
118            ],
119        }
120    }
121}
122
/// Placeholder for a SafeTensors-format weight loader.
///
/// It carries no state; no behaviour is attached to it in this module.
#[derive(Debug, Clone, Copy, Default)]
pub struct SafeTensorsLoader;
125
/// Placeholder for a GGUF-format weight loader.
///
/// It carries no state; no behaviour is attached to it in this module.
#[derive(Debug, Clone, Copy, Default)]
pub struct GGUFLoader;