Skip to main content

hanzo_quant/
pending_layer.rs

1use std::{
2    borrow::Cow,
3    fmt::Debug,
4    sync::{
5        atomic::AtomicUsize,
6        mpsc::{self, Receiver},
7        Arc, Mutex,
8    },
9};
10
11use hanzo_ml::{DType, Device, Result, Tensor};
12
13use crate::{
14    DistributedKind, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde,
15};
16
17enum PendingState {
18    Pending(Receiver<Result<Arc<dyn QuantMethod>>>),
19    Ready(Arc<dyn QuantMethod>),
20    /// Transitional state used during resolve to avoid holding both the old
21    /// receiver and the new layer simultaneously.
22    Taken,
23}
24
25/// A wrapper around a `QuantMethod` that resolves lazily from a background
26/// quantization task. Created by `apply_immediate_isq` when a thread pool is
27/// available for parallel immediate ISQ.
28pub struct PendingIsqLayer {
29    inner: Mutex<PendingState>,
30}
31
32impl PendingIsqLayer {
33    pub fn new(rx: Receiver<Result<Arc<dyn QuantMethod>>>) -> Self {
34        Self {
35            inner: Mutex::new(PendingState::Pending(rx)),
36        }
37    }
38
39    /// Block until the background quantization task completes and return the
40    /// resolved layer. Subsequent calls return the cached result immediately.
41    fn resolve(&self) -> Result<Arc<dyn QuantMethod>> {
42        let mut state = self.inner.lock().expect("PendingIsqLayer lock poisoned");
43        match &*state {
44            PendingState::Ready(layer) => Ok(layer.clone()),
45            PendingState::Taken => {
46                hanzo_ml::bail!("PendingIsqLayer is in an invalid transitional state")
47            }
48            PendingState::Pending(_) => {
49                // Take the receiver out so we can receive without holding the
50                // lock on the enum variant (swap to Taken first).
51                let old = std::mem::replace(&mut *state, PendingState::Taken);
52                if let PendingState::Pending(rx) = old {
53                    let result = rx
54                        .recv()
55                        .map_err(|e| hanzo_ml::Error::Msg(format!("ISQ channel error: {e}")))?;
56                    match result {
57                        Ok(layer) => {
58                            *state = PendingState::Ready(layer.clone());
59                            Ok(layer)
60                        }
61                        Err(e) => {
62                            // Leave in Taken state; the error is propagated.
63                            Err(e)
64                        }
65                    }
66                } else {
67                    unreachable!()
68                }
69            }
70        }
71    }
72}
73
74impl Debug for PendingIsqLayer {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        let state_str = match &*self.inner.lock().unwrap() {
77            PendingState::Pending(_) => "Pending",
78            PendingState::Ready(_) => "Ready",
79            PendingState::Taken => "Taken",
80        };
81        write!(f, "PendingIsqLayer({state_str})")
82    }
83}
84
85impl QuantizedSerde for PendingIsqLayer {
86    fn name(&self) -> &'static str {
87        "pending-isq"
88    }
89
90    fn isq_serde_supported(&self) -> bool {
91        match self.resolve() {
92            Ok(layer) => layer.isq_serde_supported(),
93            Err(_) => false,
94        }
95    }
96
97    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
98        // We must return owned data since we can't borrow from the resolved layer
99        let layer = self.resolve()?;
100        let data = layer.serialize()?;
101        Ok(Cow::Owned(data.into_owned()))
102    }
103
104    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
105        let layer = self.resolve()?;
106        let data = layer.serialize_with_bias(bias)?;
107        Ok(Cow::Owned(data.into_owned()))
108    }
109}
110
111impl QuantMethod for PendingIsqLayer {
112    fn new(_method: QuantMethodConfig) -> Result<Self>
113    where
114        Self: Sized,
115    {
116        hanzo_ml::bail!("PendingIsqLayer cannot be created via QuantMethodConfig")
117    }
118
119    fn dequantize_w(&self) -> Result<Tensor> {
120        self.resolve()?.dequantize_w()
121    }
122
123    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
124        self.resolve()?.forward_raw(a)
125    }
126
127    fn forward(&self, a: &Tensor) -> Result<Tensor> {
128        self.resolve()?.forward(a)
129    }
130
131    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
132        self.resolve()?.gather_forward(a, indices)
133    }
134
135    fn gather_forward_raw(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
136        self.resolve()?.gather_forward_raw(a, indices)
137    }
138
139    fn quantized_act_type(&self) -> Option<DType> {
140        self.resolve().ok()?.quantized_act_type()
141    }
142
143    fn dtype_and_device(&self) -> (DType, Device) {
144        self.resolve()
145            .expect("PendingIsqLayer failed to resolve for dtype_and_device")
146            .dtype_and_device()
147    }
148
149    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
150        self.resolve()?.add_delta_w(delta)
151    }
152
153    fn apply_isq(
154        self: Arc<Self>,
155        dtype: Option<IsqType>,
156        device: Device,
157        n_quantized: &AtomicUsize,
158        imatrix_weight: Option<Vec<f32>>,
159        guard: QuantizeOntoGuard,
160    ) -> Result<Arc<dyn QuantMethod>> {
161        self.resolve()?
162            .clone()
163            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
164    }
165
166    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
167        self.resolve().ok()?.unquant_weight_bias()
168    }
169
170    fn begin_track_stats(&mut self) -> Result<()> {
171        // Immediate ISQ is the no-imatrix path, so stats tracking is never used.
172        hanzo_ml::bail!("`PendingIsqLayer` does not support tracking stats.")
173    }
174
175    fn end_track_stats(&self) -> Result<Tensor> {
176        hanzo_ml::bail!("`PendingIsqLayer` does not support tracking stats.")
177    }
178
179    fn is_distributed(&self) -> Option<DistributedKind> {
180        self.resolve().ok()?.is_distributed()
181    }
182}
183
184pub type IsqSender = mpsc::SyncSender<Result<Arc<dyn QuantMethod>>>;
185pub type IsqReceiver = Receiver<Result<Arc<dyn QuantMethod>>>;
186
187/// Create a channel pair for use with `PendingIsqLayer`. The sender should be
188/// given to a thread pool task; the receiver is passed to `PendingIsqLayer::new`.
189pub fn pending_isq_channel() -> (IsqSender, IsqReceiver) {
190    mpsc::sync_channel(1)
191}