hanzo_quant/
pending_layer.rs1use std::{
2 borrow::Cow,
3 fmt::Debug,
4 sync::{
5 atomic::AtomicUsize,
6 mpsc::{self, Receiver},
7 Arc, Mutex,
8 },
9};
10
11use hanzo_ml::{DType, Device, Result, Tensor};
12
13use crate::{
14 DistributedKind, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde,
15};
16
17enum PendingState {
18 Pending(Receiver<Result<Arc<dyn QuantMethod>>>),
19 Ready(Arc<dyn QuantMethod>),
20 Taken,
23}
24
25pub struct PendingIsqLayer {
29 inner: Mutex<PendingState>,
30}
31
32impl PendingIsqLayer {
33 pub fn new(rx: Receiver<Result<Arc<dyn QuantMethod>>>) -> Self {
34 Self {
35 inner: Mutex::new(PendingState::Pending(rx)),
36 }
37 }
38
39 fn resolve(&self) -> Result<Arc<dyn QuantMethod>> {
42 let mut state = self.inner.lock().expect("PendingIsqLayer lock poisoned");
43 match &*state {
44 PendingState::Ready(layer) => Ok(layer.clone()),
45 PendingState::Taken => {
46 hanzo_ml::bail!("PendingIsqLayer is in an invalid transitional state")
47 }
48 PendingState::Pending(_) => {
49 let old = std::mem::replace(&mut *state, PendingState::Taken);
52 if let PendingState::Pending(rx) = old {
53 let result = rx
54 .recv()
55 .map_err(|e| hanzo_ml::Error::Msg(format!("ISQ channel error: {e}")))?;
56 match result {
57 Ok(layer) => {
58 *state = PendingState::Ready(layer.clone());
59 Ok(layer)
60 }
61 Err(e) => {
62 Err(e)
64 }
65 }
66 } else {
67 unreachable!()
68 }
69 }
70 }
71 }
72}
73
74impl Debug for PendingIsqLayer {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 let state_str = match &*self.inner.lock().unwrap() {
77 PendingState::Pending(_) => "Pending",
78 PendingState::Ready(_) => "Ready",
79 PendingState::Taken => "Taken",
80 };
81 write!(f, "PendingIsqLayer({state_str})")
82 }
83}
84
85impl QuantizedSerde for PendingIsqLayer {
86 fn name(&self) -> &'static str {
87 "pending-isq"
88 }
89
90 fn isq_serde_supported(&self) -> bool {
91 match self.resolve() {
92 Ok(layer) => layer.isq_serde_supported(),
93 Err(_) => false,
94 }
95 }
96
97 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
98 let layer = self.resolve()?;
100 let data = layer.serialize()?;
101 Ok(Cow::Owned(data.into_owned()))
102 }
103
104 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
105 let layer = self.resolve()?;
106 let data = layer.serialize_with_bias(bias)?;
107 Ok(Cow::Owned(data.into_owned()))
108 }
109}
110
111impl QuantMethod for PendingIsqLayer {
112 fn new(_method: QuantMethodConfig) -> Result<Self>
113 where
114 Self: Sized,
115 {
116 hanzo_ml::bail!("PendingIsqLayer cannot be created via QuantMethodConfig")
117 }
118
119 fn dequantize_w(&self) -> Result<Tensor> {
120 self.resolve()?.dequantize_w()
121 }
122
123 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
124 self.resolve()?.forward_raw(a)
125 }
126
127 fn forward(&self, a: &Tensor) -> Result<Tensor> {
128 self.resolve()?.forward(a)
129 }
130
131 fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
132 self.resolve()?.gather_forward(a, indices)
133 }
134
135 fn gather_forward_raw(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
136 self.resolve()?.gather_forward_raw(a, indices)
137 }
138
139 fn quantized_act_type(&self) -> Option<DType> {
140 self.resolve().ok()?.quantized_act_type()
141 }
142
143 fn dtype_and_device(&self) -> (DType, Device) {
144 self.resolve()
145 .expect("PendingIsqLayer failed to resolve for dtype_and_device")
146 .dtype_and_device()
147 }
148
149 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
150 self.resolve()?.add_delta_w(delta)
151 }
152
153 fn apply_isq(
154 self: Arc<Self>,
155 dtype: Option<IsqType>,
156 device: Device,
157 n_quantized: &AtomicUsize,
158 imatrix_weight: Option<Vec<f32>>,
159 guard: QuantizeOntoGuard,
160 ) -> Result<Arc<dyn QuantMethod>> {
161 self.resolve()?
162 .clone()
163 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
164 }
165
166 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
167 self.resolve().ok()?.unquant_weight_bias()
168 }
169
170 fn begin_track_stats(&mut self) -> Result<()> {
171 hanzo_ml::bail!("`PendingIsqLayer` does not support tracking stats.")
173 }
174
175 fn end_track_stats(&self) -> Result<Tensor> {
176 hanzo_ml::bail!("`PendingIsqLayer` does not support tracking stats.")
177 }
178
179 fn is_distributed(&self) -> Option<DistributedKind> {
180 self.resolve().ok()?.is_distributed()
181 }
182}
183
184pub type IsqSender = mpsc::SyncSender<Result<Arc<dyn QuantMethod>>>;
185pub type IsqReceiver = Receiver<Result<Arc<dyn QuantMethod>>>;
186
187pub fn pending_isq_channel() -> (IsqSender, IsqReceiver) {
190 mpsc::sync_channel(1)
191}