1use std::collections::HashMap;
85
86use ferrotorch_core::grad_fns::arithmetic::{add, mul};
87use ferrotorch_core::{
88 FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
89};
90use ferrotorch_nn::module::{Module, StateDict};
91use ferrotorch_nn::parameter::Parameter;
92use ferrotorch_nn::{
93 Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
94 transpose_heads_to_2d,
95};
96
/// Hyper-parameters of the CLIP text tower.
///
/// Field names mirror the JSON keys read by `ClipTextConfig::from_json_str`.
#[derive(Debug, Clone)]
pub struct ClipTextConfig {
    /// Width of the residual stream (embedding / hidden dimension).
    pub hidden_size: usize,
    /// Inner width of each encoder layer's MLP (`fc1` output, `fc2` input).
    pub intermediate_size: usize,
    /// Number of attention heads; `validate` requires it to divide `hidden_size`.
    pub num_attention_heads: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Number of learned position embeddings; also the longest sequence the
    /// embeddings accept.
    pub max_position_embeddings: usize,
    /// Size of the token vocabulary (first argument of the token `Embedding`).
    pub vocab_size: usize,
    /// Epsilon handed to every `LayerNorm`; `validate` requires it finite and > 0.
    pub layer_norm_eps: f64,
}
121
122impl Default for ClipTextConfig {
123 fn default() -> Self {
124 Self::sd_v1_5()
125 }
126}
127
128impl ClipTextConfig {
129 pub fn sd_v1_5() -> Self {
131 Self {
132 hidden_size: 768,
133 intermediate_size: 3072,
134 num_attention_heads: 12,
135 num_hidden_layers: 12,
136 max_position_embeddings: 77,
137 vocab_size: 49408,
138 layer_norm_eps: 1e-5,
139 }
140 }
141
142 #[inline]
144 #[must_use]
145 pub fn head_dim(&self) -> usize {
146 self.hidden_size / self.num_attention_heads
147 }
148
149 pub fn validate(&self) -> FerrotorchResult<()> {
156 if self.hidden_size == 0
157 || self.intermediate_size == 0
158 || self.num_attention_heads == 0
159 || self.num_hidden_layers == 0
160 || self.max_position_embeddings == 0
161 || self.vocab_size == 0
162 {
163 return Err(FerrotorchError::InvalidArgument {
164 message: "ClipTextConfig: all size fields must be > 0".into(),
165 });
166 }
167 if self.hidden_size % self.num_attention_heads != 0 {
168 return Err(FerrotorchError::InvalidArgument {
169 message: format!(
170 "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
171 self.hidden_size, self.num_attention_heads,
172 ),
173 });
174 }
175 if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
176 return Err(FerrotorchError::InvalidArgument {
177 message: format!(
178 "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
179 self.layer_norm_eps,
180 ),
181 });
182 }
183 Ok(())
184 }
185
186 pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
198 let v: serde_json::Value =
199 serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
200 message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
201 })?;
202 let mut cfg = Self::default();
203 if let Some(x) = v.get("hidden_size").and_then(serde_json::Value::as_u64) {
204 cfg.hidden_size = x as usize;
205 }
206 if let Some(x) = v
207 .get("intermediate_size")
208 .and_then(serde_json::Value::as_u64)
209 {
210 cfg.intermediate_size = x as usize;
211 }
212 if let Some(x) = v
213 .get("num_attention_heads")
214 .and_then(serde_json::Value::as_u64)
215 {
216 cfg.num_attention_heads = x as usize;
217 }
218 if let Some(x) = v
219 .get("num_hidden_layers")
220 .and_then(serde_json::Value::as_u64)
221 {
222 cfg.num_hidden_layers = x as usize;
223 }
224 if let Some(x) = v
225 .get("max_position_embeddings")
226 .and_then(serde_json::Value::as_u64)
227 {
228 cfg.max_position_embeddings = x as usize;
229 }
230 if let Some(x) = v.get("vocab_size").and_then(serde_json::Value::as_u64) {
231 cfg.vocab_size = x as usize;
232 }
233 if let Some(x) = v.get("layer_norm_eps").and_then(serde_json::Value::as_f64) {
234 cfg.layer_norm_eps = x;
235 }
236 cfg.validate()?;
237 Ok(cfg)
238 }
239
240 pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
247 let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
248 message: format!(
249 "ClipTextConfig::from_file: failed to read {}: {e}",
250 path.display(),
251 ),
252 })?;
253 Self::from_json_str(&s)
254 }
255}
256
257fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
263 let prod: usize = shape.iter().product();
264 if prod != t.numel() {
265 return Err(FerrotorchError::ShapeMismatch {
266 message: format!(
267 "ClipTextEncoder reshape: target {shape:?} (= {prod} elements) does not \
268 match source numel {}",
269 t.numel()
270 ),
271 });
272 }
273 let data = t.data_vec()?;
274 Tensor::from_storage(TensorStorage::cpu(data), shape, t.requires_grad())
275}
276
277fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
280 let data: Vec<T> = ids
281 .iter()
282 .map(|&i| numeric_cast::cast::<u32, T>(i))
283 .collect::<FerrotorchResult<Vec<T>>>()?;
284 let n = data.len();
285 Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
286}
287
/// Token embeddings plus learned absolute position embeddings of the CLIP
/// text tower.
#[derive(Debug)]
pub struct ClipTextEmbeddings<T: Float> {
    /// Token lookup table, built as `Embedding::new(vocab_size, hidden_size, None)`.
    pub token_embedding: Embedding<T>,
    /// Position lookup table, built as
    /// `Embedding::new(max_position_embeddings, hidden_size, None)`.
    pub position_embedding: Embedding<T>,
    // Width of each embedding row; used to shape the `[1, S, hidden]` output.
    hidden_size: usize,
    // Longest sequence `forward_from_ids` accepts.
    max_position_embeddings: usize,
    // Train/eval flag reported by `Module::is_training`.
    training: bool,
}
307
308impl<T: Float> ClipTextEmbeddings<T> {
309 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
316 cfg.validate()?;
317 Ok(Self {
318 token_embedding: Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?,
319 position_embedding: Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?,
320 hidden_size: cfg.hidden_size,
321 max_position_embeddings: cfg.max_position_embeddings,
322 training: false,
323 })
324 }
325
326 pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
337 if input_ids.is_empty() {
338 return Err(FerrotorchError::InvalidArgument {
339 message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
340 });
341 }
342 let seq_len = input_ids.len();
343 if seq_len > self.max_position_embeddings {
344 return Err(FerrotorchError::InvalidArgument {
345 message: format!(
346 "ClipTextEmbeddings: sequence length {seq_len} exceeds \
347 max_position_embeddings {}",
348 self.max_position_embeddings,
349 ),
350 });
351 }
352
353 let word_idx = float_index_tensor::<T>(input_ids)?;
354 let word_2d = self.token_embedding.forward(&word_idx)?; let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
357 let pos_idx = float_index_tensor::<T>(&pos_ids)?;
358 let pos_2d = self.position_embedding.forward(&pos_idx)?; let summed = add(&word_2d, &pos_2d)?;
361 reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
363 }
364}
365
366impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
367 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
372 let word_2d = self.token_embedding.forward(input)?;
373 let seq_len = input.numel();
374 let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
375 let pos_idx = float_index_tensor::<T>(&pos_ids)?;
376 let pos_2d = self.position_embedding.forward(&pos_idx)?;
377 let summed = add(&word_2d, &pos_2d)?;
378 reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
379 }
380
381 fn parameters(&self) -> Vec<&Parameter<T>> {
382 let mut out = Vec::new();
383 out.extend(self.token_embedding.parameters());
384 out.extend(self.position_embedding.parameters());
385 out
386 }
387
388 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
389 let mut out = Vec::new();
390 out.extend(self.token_embedding.parameters_mut());
391 out.extend(self.position_embedding.parameters_mut());
392 out
393 }
394
395 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
396 let mut out = Vec::new();
397 for (n, p) in self.token_embedding.named_parameters() {
398 out.push((format!("token_embedding.{n}"), p));
399 }
400 for (n, p) in self.position_embedding.named_parameters() {
401 out.push((format!("position_embedding.{n}"), p));
402 }
403 out
404 }
405
406 fn train(&mut self) {
407 self.training = true;
408 }
409
410 fn eval(&mut self) {
411 self.training = false;
412 }
413
414 fn is_training(&self) -> bool {
415 self.training
416 }
417
418 fn state_dict(&self) -> StateDict<T> {
419 self.named_parameters()
420 .into_iter()
421 .map(|(n, p)| (n, p.tensor().clone()))
422 .collect()
423 }
424
425 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
426 let extract = |prefix: &str| -> StateDict<T> {
427 let p = format!("{prefix}.");
428 state
429 .iter()
430 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
431 .collect()
432 };
433 if strict {
434 let prefixes = ["token_embedding", "position_embedding"];
435 for k in state.keys() {
436 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
437 return Err(FerrotorchError::InvalidArgument {
438 message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
439 });
440 }
441 }
442 }
443 self.token_embedding
444 .load_state_dict(&extract("token_embedding"), strict)?;
445 self.position_embedding
446 .load_state_dict(&extract("position_embedding"), strict)?;
447 Ok(())
448 }
449}
450
/// Multi-head causal self-attention block of a CLIP text encoder layer.
///
/// Field names follow the `q_proj` / `k_proj` / `v_proj` / `out_proj` layout
/// used by HuggingFace CLIP checkpoints.
#[derive(Debug)]
pub struct ClipSelfAttention<T: Float> {
    /// Query projection (`hidden_size → hidden_size`).
    pub q_proj: Linear<T>,
    /// Key projection (`hidden_size → hidden_size`).
    pub k_proj: Linear<T>,
    /// Value projection (`hidden_size → hidden_size`).
    pub v_proj: Linear<T>,
    /// Output projection applied after the heads are merged.
    pub out_proj: Linear<T>,
    // Number of attention heads.
    num_heads: usize,
    // Per-head width (`hidden / num_heads`).
    head_dim: usize,
    // Model width; the only accepted size of the input's last axis.
    hidden: usize,
    // Train/eval flag reported by `Module::is_training`.
    training: bool,
}
484
485impl<T: Float> ClipSelfAttention<T> {
486 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
492 cfg.validate()?;
493 Ok(Self {
494 q_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
495 k_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
496 v_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
497 out_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
498 num_heads: cfg.num_attention_heads,
499 head_dim: cfg.head_dim(),
500 hidden: cfg.hidden_size,
501 training: false,
502 })
503 }
504}
505
506impl<T: Float> Module<T> for ClipSelfAttention<T> {
507 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
513 let shape = input.shape();
514 if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
515 return Err(FerrotorchError::ShapeMismatch {
516 message: format!(
517 "ClipSelfAttention expects [1, S, {}], got {:?}",
518 self.hidden, shape,
519 ),
520 });
521 }
522 let seq_len = shape[1];
523
524 let q = self.q_proj.forward(input)?;
526 let k = self.k_proj.forward(input)?;
527 let v = self.v_proj.forward(input)?;
528
529 let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
532 let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
533 let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
534
535 let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
537 let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
538 let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
539
540 let ctx = standard_attention(&q_h, &k_h, &v_h, true)?;
544
545 let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
547 let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
548
549 self.out_proj.forward(&ctx3)
551 }
552
553 fn parameters(&self) -> Vec<&Parameter<T>> {
554 let mut out = Vec::new();
555 out.extend(self.q_proj.parameters());
556 out.extend(self.k_proj.parameters());
557 out.extend(self.v_proj.parameters());
558 out.extend(self.out_proj.parameters());
559 out
560 }
561
562 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
563 let mut out = Vec::new();
564 out.extend(self.q_proj.parameters_mut());
565 out.extend(self.k_proj.parameters_mut());
566 out.extend(self.v_proj.parameters_mut());
567 out.extend(self.out_proj.parameters_mut());
568 out
569 }
570
571 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
572 let mut out = Vec::new();
573 for (n, p) in self.q_proj.named_parameters() {
574 out.push((format!("q_proj.{n}"), p));
575 }
576 for (n, p) in self.k_proj.named_parameters() {
577 out.push((format!("k_proj.{n}"), p));
578 }
579 for (n, p) in self.v_proj.named_parameters() {
580 out.push((format!("v_proj.{n}"), p));
581 }
582 for (n, p) in self.out_proj.named_parameters() {
583 out.push((format!("out_proj.{n}"), p));
584 }
585 out
586 }
587
588 fn train(&mut self) {
589 self.training = true;
590 }
591
592 fn eval(&mut self) {
593 self.training = false;
594 }
595
596 fn is_training(&self) -> bool {
597 self.training
598 }
599
600 fn state_dict(&self) -> StateDict<T> {
601 self.named_parameters()
602 .into_iter()
603 .map(|(n, p)| (n, p.tensor().clone()))
604 .collect()
605 }
606
607 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
608 let extract = |prefix: &str| -> StateDict<T> {
609 let p = format!("{prefix}.");
610 state
611 .iter()
612 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
613 .collect()
614 };
615 if strict {
616 let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
617 for k in state.keys() {
618 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
619 return Err(FerrotorchError::InvalidArgument {
620 message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
621 });
622 }
623 }
624 }
625 self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
626 self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
627 self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
628 self.out_proj
629 .load_state_dict(&extract("out_proj"), strict)?;
630 Ok(())
631 }
632}
633
/// Feed-forward block of a CLIP encoder layer: `fc2(activation(fc1(x)))`.
#[derive(Debug)]
pub struct ClipMlp<T: Float> {
    /// Expansion projection, built as `Linear::new(hidden_size, intermediate_size, true)`.
    pub fc1: Linear<T>,
    /// Contraction projection, built as `Linear::new(intermediate_size, hidden_size, true)`.
    pub fc2: Linear<T>,
    // Activation between the projections, built as
    // `GELU::with_approximate(GeluApproximate::Sigmoid)`; presumably CLIP's
    // `quick_gelu` (x * sigmoid(1.702 * x)) — confirm against ferrotorch_nn.
    activation: GELU,
    // Train/eval flag reported by `Module::is_training`.
    training: bool,
}
653
654impl<T: Float> ClipMlp<T> {
655 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
661 cfg.validate()?;
662 Ok(Self {
663 fc1: Linear::new(cfg.hidden_size, cfg.intermediate_size, true)?,
664 fc2: Linear::new(cfg.intermediate_size, cfg.hidden_size, true)?,
665 activation: GELU::with_approximate(GeluApproximate::Sigmoid),
667 training: false,
668 })
669 }
670}
671
672impl<T: Float> Module<T> for ClipMlp<T> {
673 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
674 let h = self.fc1.forward(input)?;
675 let h = self.activation.forward(&h)?;
676 self.fc2.forward(&h)
677 }
678
679 fn parameters(&self) -> Vec<&Parameter<T>> {
680 let mut out = Vec::new();
681 out.extend(self.fc1.parameters());
682 out.extend(self.fc2.parameters());
683 out
684 }
685
686 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
687 let mut out = Vec::new();
688 out.extend(self.fc1.parameters_mut());
689 out.extend(self.fc2.parameters_mut());
690 out
691 }
692
693 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
694 let mut out = Vec::new();
695 for (n, p) in self.fc1.named_parameters() {
696 out.push((format!("fc1.{n}"), p));
697 }
698 for (n, p) in self.fc2.named_parameters() {
699 out.push((format!("fc2.{n}"), p));
700 }
701 out
702 }
703
704 fn train(&mut self) {
705 self.training = true;
706 }
707
708 fn eval(&mut self) {
709 self.training = false;
710 }
711
712 fn is_training(&self) -> bool {
713 self.training
714 }
715
716 fn state_dict(&self) -> StateDict<T> {
717 self.named_parameters()
718 .into_iter()
719 .map(|(n, p)| (n, p.tensor().clone()))
720 .collect()
721 }
722
723 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
724 let extract = |prefix: &str| -> StateDict<T> {
725 let p = format!("{prefix}.");
726 state
727 .iter()
728 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
729 .collect()
730 };
731 if strict {
732 for k in state.keys() {
733 if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
734 return Err(FerrotorchError::InvalidArgument {
735 message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
736 });
737 }
738 }
739 }
740 self.fc1.load_state_dict(&extract("fc1"), strict)?;
741 self.fc2.load_state_dict(&extract("fc2"), strict)?;
742 Ok(())
743 }
744}
745
/// One pre-LayerNorm transformer layer of the CLIP text encoder:
/// `h = x + self_attn(layer_norm1(x))`, then `h + mlp(layer_norm2(h))`.
#[derive(Debug)]
pub struct ClipEncoderLayer<T: Float> {
    /// LayerNorm applied before self-attention.
    pub layer_norm1: LayerNorm<T>,
    /// Causal multi-head self-attention.
    pub self_attn: ClipSelfAttention<T>,
    /// LayerNorm applied before the MLP.
    pub layer_norm2: LayerNorm<T>,
    /// Feed-forward block.
    pub mlp: ClipMlp<T>,
    // Train/eval flag reported by `Module::is_training`.
    training: bool,
}
770
771impl<T: Float> ClipEncoderLayer<T> {
772 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
778 Ok(Self {
779 layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
780 self_attn: ClipSelfAttention::new(cfg)?,
781 layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
782 mlp: ClipMlp::new(cfg)?,
783 training: false,
784 })
785 }
786}
787
788impl<T: Float> Module<T> for ClipEncoderLayer<T> {
789 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
791 let normed = self.layer_norm1.forward(input)?;
793 let attn_out = self.self_attn.forward(&normed)?;
794 let after_attn = add(input, &attn_out)?;
795
796 let normed_ffn = self.layer_norm2.forward(&after_attn)?;
798 let mlp_out = self.mlp.forward(&normed_ffn)?;
799 add(&after_attn, &mlp_out)
800 }
801
802 fn parameters(&self) -> Vec<&Parameter<T>> {
803 let mut out = Vec::new();
804 out.extend(self.layer_norm1.parameters());
805 out.extend(self.self_attn.parameters());
806 out.extend(self.layer_norm2.parameters());
807 out.extend(self.mlp.parameters());
808 out
809 }
810
811 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
812 let mut out = Vec::new();
813 out.extend(self.layer_norm1.parameters_mut());
814 out.extend(self.self_attn.parameters_mut());
815 out.extend(self.layer_norm2.parameters_mut());
816 out.extend(self.mlp.parameters_mut());
817 out
818 }
819
820 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
821 let mut out = Vec::new();
822 for (n, p) in self.layer_norm1.named_parameters() {
823 out.push((format!("layer_norm1.{n}"), p));
824 }
825 for (n, p) in self.self_attn.named_parameters() {
826 out.push((format!("self_attn.{n}"), p));
827 }
828 for (n, p) in self.layer_norm2.named_parameters() {
829 out.push((format!("layer_norm2.{n}"), p));
830 }
831 for (n, p) in self.mlp.named_parameters() {
832 out.push((format!("mlp.{n}"), p));
833 }
834 out
835 }
836
837 fn train(&mut self) {
838 self.training = true;
839 }
840
841 fn eval(&mut self) {
842 self.training = false;
843 }
844
845 fn is_training(&self) -> bool {
846 self.training
847 }
848
849 fn state_dict(&self) -> StateDict<T> {
850 self.named_parameters()
851 .into_iter()
852 .map(|(n, p)| (n, p.tensor().clone()))
853 .collect()
854 }
855
856 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
857 let extract = |prefix: &str| -> StateDict<T> {
858 let p = format!("{prefix}.");
859 state
860 .iter()
861 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
862 .collect()
863 };
864 if strict {
865 let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
866 for k in state.keys() {
867 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
868 return Err(FerrotorchError::InvalidArgument {
869 message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
870 });
871 }
872 }
873 }
874 self.layer_norm1
875 .load_state_dict(&extract("layer_norm1"), strict)?;
876 self.self_attn
877 .load_state_dict(&extract("self_attn"), strict)?;
878 self.layer_norm2
879 .load_state_dict(&extract("layer_norm2"), strict)?;
880 self.mlp.load_state_dict(&extract("mlp"), strict)?;
881 Ok(())
882 }
883}
884
/// Stack of CLIP encoder layers applied in order.
#[derive(Debug)]
pub struct ClipEncoder<T: Float> {
    /// The layers; `ClipEncoder::new` creates `num_hidden_layers` of them.
    pub layers: Vec<ClipEncoderLayer<T>>,
    // Train/eval flag reported by `Module::is_training`.
    training: bool,
}
896
897impl<T: Float> ClipEncoder<T> {
898 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
904 cfg.validate()?;
905 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
906 for _ in 0..cfg.num_hidden_layers {
907 layers.push(ClipEncoderLayer::new(cfg)?);
908 }
909 Ok(Self {
910 layers,
911 training: false,
912 })
913 }
914}
915
916impl<T: Float> Module<T> for ClipEncoder<T> {
917 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
918 let mut h = input.clone();
919 for l in &self.layers {
920 h = l.forward(&h)?;
921 }
922 Ok(h)
923 }
924
925 fn parameters(&self) -> Vec<&Parameter<T>> {
926 let mut out = Vec::new();
927 for l in &self.layers {
928 out.extend(l.parameters());
929 }
930 out
931 }
932
933 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
934 let mut out = Vec::new();
935 for l in &mut self.layers {
936 out.extend(l.parameters_mut());
937 }
938 out
939 }
940
941 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
942 let mut out = Vec::new();
943 for (i, l) in self.layers.iter().enumerate() {
944 for (n, p) in l.named_parameters() {
945 out.push((format!("layers.{i}.{n}"), p));
946 }
947 }
948 out
949 }
950
951 fn train(&mut self) {
952 self.training = true;
953 for l in &mut self.layers {
954 l.train();
955 }
956 }
957
958 fn eval(&mut self) {
959 self.training = false;
960 for l in &mut self.layers {
961 l.eval();
962 }
963 }
964
965 fn is_training(&self) -> bool {
966 self.training
967 }
968
969 fn state_dict(&self) -> StateDict<T> {
970 self.named_parameters()
971 .into_iter()
972 .map(|(n, p)| (n, p.tensor().clone()))
973 .collect()
974 }
975
976 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
977 let extract = |prefix: &str| -> StateDict<T> {
978 let p = format!("{prefix}.");
979 state
980 .iter()
981 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
982 .collect()
983 };
984 if strict {
985 for k in state.keys() {
986 if !k.starts_with("layers.") {
987 return Err(FerrotorchError::InvalidArgument {
988 message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
989 });
990 }
991 }
992 }
993 for (i, l) in self.layers.iter_mut().enumerate() {
994 l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
995 }
996 Ok(())
997 }
998}
999
/// Full CLIP text tower: embeddings → encoder stack → final LayerNorm.
#[derive(Debug)]
pub struct ClipTextEncoder<T: Float> {
    /// Token + position embeddings.
    pub embeddings: ClipTextEmbeddings<T>,
    /// Stack of encoder layers.
    pub encoder: ClipEncoder<T>,
    /// LayerNorm applied to the last encoder layer's output.
    pub final_layer_norm: LayerNorm<T>,
    /// Configuration the encoder was built from.
    pub config: ClipTextConfig,
    // Train/eval flag reported by `Module::is_training`.
    training: bool,
}
1029
1030impl<T: Float> ClipTextEncoder<T> {
1031 pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
1038 cfg.validate()?;
1039 let embeddings = ClipTextEmbeddings::new(&cfg)?;
1040 let encoder = ClipEncoder::new(&cfg)?;
1041 let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
1042 Ok(Self {
1043 embeddings,
1044 encoder,
1045 final_layer_norm,
1046 config: cfg,
1047 training: false,
1048 })
1049 }
1050
1051 pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
1064 let h = self.embeddings.forward_from_ids(input_ids)?;
1065 let h = self.encoder.forward(&h)?;
1066 self.final_layer_norm.forward(&h)
1067 }
1068
1069 pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1082 if ids.ndim() != 1 {
1085 return Err(FerrotorchError::ShapeMismatch {
1086 message: format!(
1087 "ClipTextEncoder::forward_from_id_tensor expects 1-D ids, got {:?}",
1088 ids.shape()
1089 ),
1090 });
1091 }
1092 let data = ids.data_vec()?;
1093 let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
1094 for (i, v) in data.iter().enumerate() {
1095 let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
1096 FerrotorchError::InvalidArgument {
1097 message: format!(
1098 "ClipTextEncoder::forward_from_id_tensor: id at {i} \
1099 not representable as f64"
1100 ),
1101 }
1102 })?;
1103 if !f.is_finite() || f < 0.0 || f > u32::MAX as f64 || f.fract() != 0.0 {
1104 return Err(FerrotorchError::InvalidArgument {
1105 message: format!(
1106 "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
1107 is not a non-negative integer"
1108 ),
1109 });
1110 }
1111 u32_ids.push(f as u32);
1112 }
1113 self.forward_from_ids(&u32_ids)
1114 }
1115
1116 pub fn load_hf_state_dict(
1139 &mut self,
1140 hf_state: &StateDict<T>,
1141 strict: bool,
1142 ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
1143 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
1144 let mut dropped: Vec<String> = Vec::new();
1145 for (k, v) in hf_state {
1146 let after = k
1148 .strip_prefix("text_model.")
1149 .map_or_else(|| k.clone(), str::to_owned);
1150
1151 if after == "embeddings.position_ids" {
1154 dropped.push(k.clone());
1155 continue;
1156 }
1157
1158 let is_known = after.starts_with("embeddings.token_embedding.")
1159 || after.starts_with("embeddings.position_embedding.")
1160 || after.starts_with("encoder.")
1161 || after.starts_with("final_layer_norm.");
1162 if is_known {
1163 remapped.insert(after, v.clone());
1164 continue;
1165 }
1166
1167 if strict {
1168 return Err(FerrotorchError::InvalidArgument {
1169 message: format!(
1170 "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
1171 known CLIP text-tower parameter and strict mode is on. \
1172 Pass strict=false to drop unknown keys."
1173 ),
1174 });
1175 }
1176 dropped.push(k.clone());
1177 }
1178 dropped.sort();
1179 self.load_state_dict(&remapped, strict)?;
1180 Ok(crate::safetensors_loader::DropReport { dropped })
1181 }
1182}
1183
1184impl<T: Float> Module<T> for ClipTextEncoder<T> {
1185 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1190 let h = self.encoder.forward(input)?;
1191 self.final_layer_norm.forward(&h)
1192 }
1193
1194 fn parameters(&self) -> Vec<&Parameter<T>> {
1195 let mut out = Vec::new();
1196 out.extend(self.embeddings.parameters());
1197 out.extend(self.encoder.parameters());
1198 out.extend(self.final_layer_norm.parameters());
1199 out
1200 }
1201
1202 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1203 let mut out = Vec::new();
1204 out.extend(self.embeddings.parameters_mut());
1205 out.extend(self.encoder.parameters_mut());
1206 out.extend(self.final_layer_norm.parameters_mut());
1207 out
1208 }
1209
1210 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1211 let mut out = Vec::new();
1212 for (n, p) in self.embeddings.named_parameters() {
1213 out.push((format!("embeddings.{n}"), p));
1214 }
1215 for (n, p) in self.encoder.named_parameters() {
1216 out.push((format!("encoder.{n}"), p));
1217 }
1218 for (n, p) in self.final_layer_norm.named_parameters() {
1219 out.push((format!("final_layer_norm.{n}"), p));
1220 }
1221 out
1222 }
1223
1224 fn train(&mut self) {
1225 self.training = true;
1226 self.embeddings.train();
1227 self.encoder.train();
1228 self.final_layer_norm.train();
1229 }
1230
1231 fn eval(&mut self) {
1232 self.training = false;
1233 self.embeddings.eval();
1234 self.encoder.eval();
1235 self.final_layer_norm.eval();
1236 }
1237
1238 fn is_training(&self) -> bool {
1239 self.training
1240 }
1241
1242 fn state_dict(&self) -> StateDict<T> {
1243 self.named_parameters()
1244 .into_iter()
1245 .map(|(n, p)| (n, p.tensor().clone()))
1246 .collect()
1247 }
1248
1249 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1250 let extract = |prefix: &str| -> StateDict<T> {
1251 let p = format!("{prefix}.");
1252 state
1253 .iter()
1254 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1255 .collect()
1256 };
1257 if strict {
1258 for k in state.keys() {
1259 if !(k.starts_with("embeddings.")
1260 || k.starts_with("encoder.")
1261 || k.starts_with("final_layer_norm."))
1262 {
1263 return Err(FerrotorchError::InvalidArgument {
1264 message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
1265 });
1266 }
1267 }
1268 }
1269 self.embeddings
1270 .load_state_dict(&extract("embeddings"), strict)?;
1271 self.encoder.load_state_dict(&extract("encoder"), strict)?;
1272 self.final_layer_norm
1273 .load_state_dict(&extract("final_layer_norm"), strict)?;
1274 Ok(())
1275 }
1276}
1277
1278#[allow(dead_code)]
1282fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1283 mul(a, b)
1284}
1285
1286#[cfg(test)]
1291mod tests {
1292 use super::*;
1293
1294 fn tiny_cfg() -> ClipTextConfig {
1295 ClipTextConfig {
1298 hidden_size: 8,
1299 intermediate_size: 16,
1300 num_attention_heads: 2,
1301 num_hidden_layers: 1,
1302 max_position_embeddings: 6,
1303 vocab_size: 32,
1304 layer_norm_eps: 1e-5,
1305 }
1306 }
1307
1308 #[test]
1309 fn sd_v1_5_config_is_canonical() {
1310 let c = ClipTextConfig::sd_v1_5();
1311 assert_eq!(c.hidden_size, 768);
1312 assert_eq!(c.intermediate_size, 3072);
1313 assert_eq!(c.num_attention_heads, 12);
1314 assert_eq!(c.num_hidden_layers, 12);
1315 assert_eq!(c.max_position_embeddings, 77);
1316 assert_eq!(c.vocab_size, 49408);
1317 assert_eq!(c.head_dim(), 64);
1318 c.validate().unwrap();
1319 }
1320
1321 #[test]
1322 fn validate_catches_bad_head_count() {
1323 let mut c = tiny_cfg();
1324 c.num_attention_heads = 3; assert!(c.validate().is_err());
1326 }
1327
1328 #[test]
1329 fn from_json_str_round_trip() {
1330 let json = r#"{
1331 "hidden_size": 768,
1332 "intermediate_size": 3072,
1333 "num_attention_heads": 12,
1334 "num_hidden_layers": 12,
1335 "max_position_embeddings": 77,
1336 "vocab_size": 49408,
1337 "layer_norm_eps": 1e-5,
1338 "hidden_act": "quick_gelu"
1339 }"#;
1340 let c = ClipTextConfig::from_json_str(json).unwrap();
1341 assert_eq!(c.hidden_size, 768);
1342 assert_eq!(c.intermediate_size, 3072);
1343 assert_eq!(c.num_attention_heads, 12);
1344 assert_eq!(c.num_hidden_layers, 12);
1345 assert_eq!(c.max_position_embeddings, 77);
1346 }
1347
1348 #[test]
1349 fn embeddings_forward_shape() {
1350 let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
1351 let ids = [1u32, 5, 7, 9];
1352 let out = emb.forward_from_ids(&ids).unwrap();
1353 assert_eq!(out.shape(), &[1, 4, 8]);
1354 for &v in out.data().unwrap() {
1355 assert!(v.is_finite(), "embedding non-finite: {v}");
1356 }
1357 }
1358
1359 #[test]
1360 fn embeddings_reject_too_long_sequence() {
1361 let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
1362 let ids: Vec<u32> = (0..7).collect(); assert!(emb.forward_from_ids(&ids).is_err());
1364 }
1365
1366 #[test]
1367 fn self_attention_forward_shape() {
1368 let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
1369 let x = Tensor::from_storage(
1370 TensorStorage::cpu(vec![0.1f32; 5 * 8]),
1371 vec![1, 5, 8],
1372 false,
1373 )
1374 .unwrap();
1375 let out = attn.forward(&x).unwrap();
1376 assert_eq!(out.shape(), &[1, 5, 8]);
1377 for &v in out.data().unwrap() {
1378 assert!(v.is_finite());
1379 }
1380 }
1381
1382 #[test]
1383 fn self_attention_is_actually_causal() {
1384 let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
1390 let mut a = vec![0.1f32; 4 * 8];
1391 for i in 0..2 * 8 {
1392 a[i] = ((i + 1) as f32).sin();
1393 }
1394 let mut b = a.clone();
1395 for i in (2 * 8)..(4 * 8) {
1397 b[i] = ((i + 11) as f32).sin();
1398 }
1399 let xa = Tensor::from_storage(TensorStorage::cpu(a), vec![1, 4, 8], false).unwrap();
1400 let xb = Tensor::from_storage(TensorStorage::cpu(b), vec![1, 4, 8], false).unwrap();
1401 let oa = attn.forward(&xa).unwrap();
1402 let ob = attn.forward(&xb).unwrap();
1403 let da = oa.data().unwrap();
1404 let db = ob.data().unwrap();
1405 for i in 0..2 * 8 {
1406 assert!(
1407 (da[i] - db[i]).abs() < 1e-5,
1408 "row {} ({}) differs between runs: {} vs {}",
1409 i / 8,
1410 i % 8,
1411 da[i],
1412 db[i]
1413 );
1414 }
1415 }
1416
1417 #[test]
1418 fn mlp_uses_quick_gelu() {
1419 let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
1424 let x = Tensor::from_storage(
1425 TensorStorage::cpu(vec![0.0f32; 3 * 8]),
1426 vec![1, 3, 8],
1427 false,
1428 )
1429 .unwrap();
1430 let out = mlp.forward(&x).unwrap();
1431 assert_eq!(out.shape(), &[1, 3, 8]);
1432 for &v in out.data().unwrap() {
1433 assert!(v.is_finite());
1434 }
1435 }
1436
1437 #[test]
1438 fn encoder_layer_forward_shape() {
1439 let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
1440 let x = Tensor::from_storage(
1441 TensorStorage::cpu(vec![0.1f32; 5 * 8]),
1442 vec![1, 5, 8],
1443 false,
1444 )
1445 .unwrap();
1446 let out = layer.forward(&x).unwrap();
1447 assert_eq!(out.shape(), &[1, 5, 8]);
1448 for &v in out.data().unwrap() {
1449 assert!(v.is_finite());
1450 }
1451 }
1452
1453 #[test]
1454 fn encoder_layer_named_parameters_use_hf_layout() {
1455 let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
1456 let names: Vec<String> = layer
1457 .named_parameters()
1458 .into_iter()
1459 .map(|(n, _)| n)
1460 .collect();
1461 for k in [
1462 "layer_norm1.weight",
1463 "layer_norm1.bias",
1464 "self_attn.q_proj.weight",
1465 "self_attn.q_proj.bias",
1466 "self_attn.k_proj.weight",
1467 "self_attn.v_proj.weight",
1468 "self_attn.out_proj.weight",
1469 "self_attn.out_proj.bias",
1470 "layer_norm2.weight",
1471 "mlp.fc1.weight",
1472 "mlp.fc1.bias",
1473 "mlp.fc2.weight",
1474 "mlp.fc2.bias",
1475 ] {
1476 assert!(
1477 names.iter().any(|n| n == k),
1478 "missing parameter key {k:?} in {names:?}"
1479 );
1480 }
1481 }
1482
1483 #[test]
1484 fn tiny_encoder_forward_from_ids_shape() {
1485 let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1486 let ids = vec![1u32, 5, 7];
1487 let out = enc.forward_from_ids(&ids).unwrap();
1488 assert_eq!(out.shape(), &[1, 3, 8]);
1489 for &v in out.data().unwrap() {
1490 assert!(v.is_finite());
1491 }
1492 }
1493
1494 #[test]
1495 fn tiny_named_parameters_use_hf_layout() {
1496 let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1497 let names: Vec<String> = enc.named_parameters().into_iter().map(|(n, _)| n).collect();
1498 for k in [
1499 "embeddings.token_embedding.weight",
1500 "embeddings.position_embedding.weight",
1501 "encoder.layers.0.layer_norm1.weight",
1502 "encoder.layers.0.self_attn.q_proj.weight",
1503 "encoder.layers.0.self_attn.out_proj.bias",
1504 "encoder.layers.0.layer_norm2.bias",
1505 "encoder.layers.0.mlp.fc1.weight",
1506 "encoder.layers.0.mlp.fc2.bias",
1507 "final_layer_norm.weight",
1508 "final_layer_norm.bias",
1509 ] {
1510 assert!(
1511 names.iter().any(|n| n == k),
1512 "missing parameter key {k:?} in {names:?}"
1513 );
1514 }
1515 }
1516
1517 #[test]
1518 fn round_trip_state_dict() {
1519 let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1520 let sd = src.state_dict();
1521 let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1522 dst.load_state_dict(&sd, true).unwrap();
1523 let ids = vec![2u32, 4, 6];
1524 let a = src.forward_from_ids(&ids).unwrap();
1525 let b = dst.forward_from_ids(&ids).unwrap();
1526 for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1527 assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
1528 }
1529 }
1530
1531 #[test]
1532 fn load_hf_state_dict_strips_text_model_prefix() {
1533 let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1534 let bare = src.state_dict();
1535 let mut prefixed: StateDict<f32> = HashMap::new();
1536 for (k, v) in bare {
1537 prefixed.insert(format!("text_model.{k}"), v);
1538 }
1539 prefixed.insert(
1541 "text_model.embeddings.position_ids".into(),
1542 ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
1543 );
1544 let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1545 let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
1546 assert_eq!(
1547 rep.dropped,
1548 vec!["text_model.embeddings.position_ids".to_string()]
1549 );
1550 let ids = vec![1u32, 2, 3];
1551 let a = src.forward_from_ids(&ids).unwrap();
1552 let b = dst.forward_from_ids(&ids).unwrap();
1553 for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1554 assert!((x - y).abs() < 1e-5);
1555 }
1556 }
1557
1558 #[test]
1559 fn load_hf_state_dict_strict_rejects_unknown_key() {
1560 let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1561 let mut sd: StateDict<f32> = HashMap::new();
1562 sd.insert(
1563 "mystery.key".into(),
1564 ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
1565 );
1566 assert!(dst.load_hf_state_dict(&sd, true).is_err());
1567 }
1568
1569 #[test]
1570 fn forward_from_id_tensor_matches_forward_from_ids() {
1571 let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
1572 let ids = vec![1u32, 5, 7];
1573 let id_tensor = float_index_tensor::<f32>(&ids).unwrap();
1574 let a = enc.forward_from_ids(&ids).unwrap();
1575 let b = enc.forward_from_id_tensor(&id_tensor).unwrap();
1576 for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
1577 assert!((x - y).abs() < 1e-5);
1578 }
1579 }
1580}