use anyhow::Result;
use candle_core::Tensor;
use candle_transformers::models::mmdit::model::MMDiT;
use super::quantized_mmdit::QuantizedMMDiT;
/// A single SD3 transformer that can be backed by one of three model
/// implementations, selected by variant.
///
/// Callers drive every variant through the same `forward` method.
// The allow below silences clippy's `large_enum_variant` lint, which fires when
// one variant is much larger than the others.
// NOTE(review): presumably the inline variants are kept unboxed on purpose
// (only one such value is expected to exist at a time) — confirm and record
// the actual reason here.
#[allow(clippy::large_enum_variant)]
pub(crate) enum SD3Transformer {
/// The `MMDiT` model from `candle_transformers`, stored inline.
// NOTE(review): the variant name suggests bf16 weights; nothing in this
// file enforces a dtype — confirm at the construction site.
BF16(MMDiT),
/// The `OffloadedMMDiT` model from the sibling `offload` module, held
/// behind a `Box`.
Offloaded(Box<super::offload::OffloadedMMDiT>),
/// The `QuantizedMMDiT` model from the sibling `quantized_mmdit` module,
/// stored inline.
Quantized(QuantizedMMDiT),
}
// Manually asserts that `SD3Transformer` may be moved to another thread
// (`Send`) and shared by reference between threads (`Sync`). The compiler does
// not verify either claim for `unsafe impl`s.
//
// SAFETY: NOTE(review): no justification is visible in this file. Presumably at
// least one wrapped model type is not automatically `Send`/`Sync`, and callers
// are relied on to serialize access — confirm which field requires this and
// document the invariant that makes it sound, or drop these impls if every
// variant is already `Send + Sync`.
unsafe impl Send for SD3Transformer {}
unsafe impl Sync for SD3Transformer {}
impl SD3Transformer {
    /// Runs one forward pass through whichever backend this value wraps.
    ///
    /// `x`, `t`, `y`, `context` and `skip_layers` are handed to the wrapped
    /// model's own `forward` unchanged and in the same order; this method adds
    /// no processing of its own.
    // NOTE(review): the meaning of each tensor (and of `skip_layers`) is
    // defined by the wrapped models, not here — see `MMDiT::forward` for the
    // expected shapes.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped model's `forward`,
    /// converted into the `anyhow` error type where necessary.
    pub fn forward(
        &self,
        x: &Tensor,
        t: &Tensor,
        y: &Tensor,
        context: &Tensor,
        skip_layers: Option<&[usize]>,
    ) -> Result<Tensor> {
        // Each arm unwraps with `?` so the differing error types of the
        // backends are unified in one place before the single `Ok` below.
        let output = match self {
            Self::BF16(dense) => dense.forward(x, t, y, context, skip_layers)?,
            Self::Offloaded(offloaded) => offloaded.forward(x, t, y, context, skip_layers)?,
            Self::Quantized(quantized) => quantized.forward(x, t, y, context, skip_layers)?,
        };
        Ok(output)
    }
}