use crate::sd3::{
config::{Sd3Config, Sd3ConfigError},
model::{Sd3Error, Sd3TextEmbeddings, Sd3TextEncoderPipeline},
};
/// Errors produced by the task-level [`Sd3TextEncoder`] wrapper.
///
/// Lower-level errors are wrapped through `#[from]`, so `?` converts them
/// automatically. The remaining variants are raised by the wrapper's own
/// input/output checks.
#[derive(Debug, thiserror::Error)]
pub enum Sd3TaskError {
    /// Wraps an [`Sd3ConfigError`] (converted via `From`).
    // Presumably produced by `Sd3Config::validate` during construction — confirm
    // the error type `validate` returns.
    #[error("Configuration error: {0}")]
    Config(#[from] Sd3ConfigError),
    /// Wraps an [`Sd3Error`] from the SD3 model module (converted via `From`).
    #[error("Encoder error: {0}")]
    Encoder(#[from] Sd3Error),
    /// Returned by `Sd3TextEncoder::encode` when it is given an empty token slice.
    #[error("Empty token sequence provided")]
    EmptyTokens,
    /// The pooled embedding returned by the pipeline has the wrong length.
    ///
    /// `expected` is the configured `pooled_embedding_dim`; `got` is the actual
    /// length of `pooled_embeddings`.
    #[error("Invalid output embeddings: expected pooled_dim={expected}, got {got}")]
    InvalidEmbeddingDim { expected: usize, got: usize },
}
/// Task-level front end for SD3 text encoding.
///
/// Wraps an [`Sd3TextEncoderPipeline`]. The constructor validates the
/// configuration before building the pipeline. Encoding rejects empty input and
/// verifies the pooled embedding length against the configuration.
pub struct Sd3TextEncoder {
    // The underlying pipeline. It also owns the configuration, which is exposed
    // through `config()`.
    pipeline: Sd3TextEncoderPipeline,
}
impl Sd3TextEncoder {
    /// Creates an encoder from `config`.
    ///
    /// The configuration is checked with `Sd3Config::validate` before it is
    /// handed to [`Sd3TextEncoderPipeline::new`].
    ///
    /// # Errors
    ///
    /// Failures from validation or from constructing the pipeline are
    /// converted into [`Sd3TaskError`] via `From` and returned.
    pub fn new(config: Sd3Config) -> Result<Self, Sd3TaskError> {
        config.validate()?;
        Ok(Self {
            pipeline: Sd3TextEncoderPipeline::new(config)?,
        })
    }

    /// Encodes `token_ids` into SD3 text embeddings.
    ///
    /// The whole slice is forwarded to the pipeline, with the slice length
    /// passed as the sequence length.
    ///
    /// # Errors
    ///
    /// - [`Sd3TaskError::EmptyTokens`] if `token_ids` is empty. The pipeline
    ///   is not invoked in that case.
    /// - Any error from the pipeline's `encode_text`, converted into
    ///   [`Sd3TaskError`].
    /// - [`Sd3TaskError::InvalidEmbeddingDim`] if the returned
    ///   `pooled_embeddings` length differs from the configured
    ///   `pooled_embedding_dim`.
    pub fn encode(&self, token_ids: &[u32]) -> Result<Sd3TextEmbeddings, Sd3TaskError> {
        if token_ids.is_empty() {
            return Err(Sd3TaskError::EmptyTokens);
        }

        let output = self.pipeline.encode_text(token_ids, token_ids.len())?;

        // Reject output whose pooled vector does not have the width the
        // configuration declares.
        let expected = self.pipeline.config().pooled_embedding_dim;
        let got = output.pooled_embeddings.len();
        if got == expected {
            Ok(output)
        } else {
            Err(Sd3TaskError::InvalidEmbeddingDim { expected, got })
        }
    }

    /// Borrows the wrapped [`Sd3TextEncoderPipeline`].
    pub fn pipeline(&self) -> &Sd3TextEncoderPipeline {
        &self.pipeline
    }

    /// Borrows the configuration, as reported by the wrapped pipeline.
    pub fn config(&self) -> &Sd3Config {
        self.pipeline.config()
    }
}