1use std::collections::HashMap;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use ferrum_interfaces::{
8 backend::{
9 TensorSpec, TransformationType, WeightFormat, WeightLoaderCapabilities, WeightMetadata,
10 WeightSource, WeightSourceType,
11 },
12 TensorFactory, TensorRef, WeightLoader,
13};
14use ferrum_types::{DataType, Result};
15use tracing::debug;
16
/// Cloneable handle to a shared, type-erased [`WeightLoader`].
///
/// The loader sits behind an `Arc`, so cloning the handle shares the same
/// loader instance instead of copying it.
#[derive(Clone)]
pub struct WeightLoaderHandle(pub Arc<dyn WeightLoader + Send + Sync>);
20
21impl std::fmt::Debug for WeightLoaderHandle {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("WeightLoaderHandle").finish()
24 }
25}
26
27pub fn default_weight_loader() -> Result<WeightLoaderHandle> {
29 debug!("Creating default stub weight loader");
30 Ok(WeightLoaderHandle(Arc::new(StubWeightLoader::new())))
31}
32
/// Placeholder [`WeightLoader`] that produces zero-filled tensors instead of
/// reading weights from the given source.
pub struct StubWeightLoader {
    /// Factory used to create the zero tensors; when `None`, `load_tensor` fails.
    factory: Option<Arc<dyn TensorFactory>>,
}
37
38impl StubWeightLoader {
39 pub fn new() -> Self {
40 Self { factory: None }
41 }
42
43 pub fn with_factory(factory: Arc<dyn TensorFactory>) -> Self {
44 Self {
45 factory: Some(factory),
46 }
47 }
48}
49
50impl Default for StubWeightLoader {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl std::fmt::Debug for StubWeightLoader {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("StubWeightLoader")
59 .field("has_factory", &self.factory.is_some())
60 .finish()
61 }
62}
63
64#[async_trait]
65impl WeightLoader for StubWeightLoader {
66 async fn load_tensor(&self, spec: &TensorSpec) -> Result<TensorRef> {
67 debug!(
68 "StubWeightLoader: creating zeros for '{}' {:?}",
69 spec.name, spec.shape
70 );
71
72 if let Some(factory) = &self.factory {
73 factory.zeros(&spec.shape, spec.dtype, &spec.device)
74 } else {
75 Err(ferrum_types::FerrumError::model(
76 "No tensor factory configured in stub weight loader",
77 ))
78 }
79 }
80
81 async fn load_tensors(&self, specs: &[TensorSpec]) -> Result<Vec<TensorRef>> {
82 let mut tensors = Vec::with_capacity(specs.len());
83 for spec in specs {
84 tensors.push(self.load_tensor(spec).await?);
85 }
86 Ok(tensors)
87 }
88
89 async fn is_available(&self, _source: &WeightSource) -> bool {
90 true
91 }
92
93 async fn get_metadata(&self, _source: &WeightSource) -> Result<WeightMetadata> {
94 Ok(WeightMetadata {
95 tensors: HashMap::new(),
96 format: WeightFormat::SafeTensors,
97 total_size_bytes: 1024 * 1024,
98 dtypes: vec![DataType::FP16],
99 extra: HashMap::new(),
100 })
101 }
102
103 async fn preload(&self, _source: &WeightSource) -> Result<()> {
104 Ok(())
105 }
106
107 fn capabilities(&self) -> WeightLoaderCapabilities {
108 WeightLoaderCapabilities {
109 supported_formats: vec![WeightFormat::SafeTensors],
110 supported_sources: vec![WeightSourceType::File, WeightSourceType::HuggingFace],
111 max_tensor_size: 10 * 1024 * 1024 * 1024, supports_streaming: false,
113 supports_concurrent: false,
114 supported_transformations: vec![
115 TransformationType::Transpose,
116 TransformationType::Reshape,
117 TransformationType::Cast,
118 ],
119 }
120 }
121}
122
/// Unit type presumably reserved for a SafeTensors-format weight loader.
///
/// NOTE(review): no fields or impls are visible in this section; confirm
/// whether this is still a placeholder.
pub struct SafeTensorsLoader;
125
/// Unit type presumably reserved for a GGUF-format weight loader.
///
/// NOTE(review): no fields or impls are visible in this section; confirm
/// whether this is still a placeholder.
pub struct GGUFLoader;