1use std::collections::HashMap;
98
99use ferrotorch_core::grad_fns::arithmetic::{add, mul};
100use ferrotorch_core::{
101 FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
102};
103use ferrotorch_nn::module::{Module, StateDict};
104use ferrotorch_nn::parameter::Parameter;
105use ferrotorch_nn::{
106 Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
107 transpose_heads_to_2d,
108};
109
/// Hyper-parameters describing the shape of a CLIP text transformer.
///
/// The `Default` value is the preset returned by `ClipTextConfig::sd_v1_5`.
#[derive(Clone, Debug)]
pub struct ClipTextConfig {
    /// Width of the embeddings and of every hidden state.
    pub hidden_size: usize,
    /// Width of the MLP's inner projection.
    pub intermediate_size: usize,
    /// Number of attention heads; `validate` requires it to divide `hidden_size`.
    pub num_attention_heads: usize,
    /// Number of stacked encoder layers.
    pub num_hidden_layers: usize,
    /// Size of the position-embedding table, i.e. the longest accepted sequence.
    pub max_position_embeddings: usize,
    /// Size of the token-embedding table.
    pub vocab_size: usize,
    /// Epsilon handed to every `LayerNorm`; `validate` requires it finite and > 0.
    pub layer_norm_eps: f64,
}
134
135impl Default for ClipTextConfig {
136 fn default() -> Self {
137 Self::sd_v1_5()
138 }
139}
140
141impl ClipTextConfig {
142 pub fn sd_v1_5() -> Self {
144 Self {
145 hidden_size: 768,
146 intermediate_size: 3072,
147 num_attention_heads: 12,
148 num_hidden_layers: 12,
149 max_position_embeddings: 77,
150 vocab_size: 49408,
151 layer_norm_eps: 1e-5,
152 }
153 }
154
155 #[inline]
157 #[must_use]
158 pub fn head_dim(&self) -> usize {
159 self.hidden_size / self.num_attention_heads
160 }
161
162 pub fn validate(&self) -> FerrotorchResult<()> {
169 if self.hidden_size == 0
170 || self.intermediate_size == 0
171 || self.num_attention_heads == 0
172 || self.num_hidden_layers == 0
173 || self.max_position_embeddings == 0
174 || self.vocab_size == 0
175 {
176 return Err(FerrotorchError::InvalidArgument {
177 message: "ClipTextConfig: all size fields must be > 0".into(),
178 });
179 }
180 if self.hidden_size % self.num_attention_heads != 0 {
181 return Err(FerrotorchError::InvalidArgument {
182 message: format!(
183 "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
184 self.hidden_size, self.num_attention_heads,
185 ),
186 });
187 }
188 if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
189 return Err(FerrotorchError::InvalidArgument {
190 message: format!(
191 "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
192 self.layer_norm_eps,
193 ),
194 });
195 }
196 Ok(())
197 }
198
199 pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
211 let v: serde_json::Value =
212 serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
213 message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
214 })?;
215 let mut cfg = Self::default();
216 if let Some(x) = v.get("hidden_size").and_then(serde_json::Value::as_u64) {
217 cfg.hidden_size = x as usize;
218 }
219 if let Some(x) = v
220 .get("intermediate_size")
221 .and_then(serde_json::Value::as_u64)
222 {
223 cfg.intermediate_size = x as usize;
224 }
225 if let Some(x) = v
226 .get("num_attention_heads")
227 .and_then(serde_json::Value::as_u64)
228 {
229 cfg.num_attention_heads = x as usize;
230 }
231 if let Some(x) = v
232 .get("num_hidden_layers")
233 .and_then(serde_json::Value::as_u64)
234 {
235 cfg.num_hidden_layers = x as usize;
236 }
237 if let Some(x) = v
238 .get("max_position_embeddings")
239 .and_then(serde_json::Value::as_u64)
240 {
241 cfg.max_position_embeddings = x as usize;
242 }
243 if let Some(x) = v.get("vocab_size").and_then(serde_json::Value::as_u64) {
244 cfg.vocab_size = x as usize;
245 }
246 if let Some(x) = v.get("layer_norm_eps").and_then(serde_json::Value::as_f64) {
247 cfg.layer_norm_eps = x;
248 }
249 cfg.validate()?;
250 Ok(cfg)
251 }
252
253 pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
260 let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
261 message: format!(
262 "ClipTextConfig::from_file: failed to read {}: {e}",
263 path.display(),
264 ),
265 })?;
266 Self::from_json_str(&s)
267 }
268}
269
270fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
276 let prod: usize = shape.iter().product();
277 if prod != t.numel() {
278 return Err(FerrotorchError::ShapeMismatch {
279 message: format!(
280 "ClipTextEncoder reshape: target {shape:?} (= {prod} elements) does not \
281 match source numel {}",
282 t.numel()
283 ),
284 });
285 }
286 let data = t.data_vec()?;
287 Tensor::from_storage(TensorStorage::cpu(data), shape, t.requires_grad())
288}
289
290fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
293 let data: Vec<T> = ids
294 .iter()
295 .map(|&i| numeric_cast::cast::<u32, T>(i))
296 .collect::<FerrotorchResult<Vec<T>>>()?;
297 let n = data.len();
298 Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
299}
300
301#[derive(Debug)]
311pub struct ClipTextEmbeddings<T: Float> {
312 pub token_embedding: Embedding<T>,
314 pub position_embedding: Embedding<T>,
316 hidden_size: usize,
317 max_position_embeddings: usize,
318 training: bool,
319}
320
321impl<T: Float> ClipTextEmbeddings<T> {
322 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
329 cfg.validate()?;
330 Ok(Self {
331 token_embedding: Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?,
332 position_embedding: Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?,
333 hidden_size: cfg.hidden_size,
334 max_position_embeddings: cfg.max_position_embeddings,
335 training: false,
336 })
337 }
338
339 pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
350 if input_ids.is_empty() {
351 return Err(FerrotorchError::InvalidArgument {
352 message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
353 });
354 }
355 let seq_len = input_ids.len();
356 if seq_len > self.max_position_embeddings {
357 return Err(FerrotorchError::InvalidArgument {
358 message: format!(
359 "ClipTextEmbeddings: sequence length {seq_len} exceeds \
360 max_position_embeddings {}",
361 self.max_position_embeddings,
362 ),
363 });
364 }
365
366 let word_idx = float_index_tensor::<T>(input_ids)?;
367 let word_2d = self.token_embedding.forward(&word_idx)?; let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
370 let pos_idx = float_index_tensor::<T>(&pos_ids)?;
371 let pos_2d = self.position_embedding.forward(&pos_idx)?; let summed = add(&word_2d, &pos_2d)?;
374 reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
376 }
377}
378
379impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
380 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
385 let word_2d = self.token_embedding.forward(input)?;
386 let seq_len = input.numel();
387 let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
388 let pos_idx = float_index_tensor::<T>(&pos_ids)?;
389 let pos_2d = self.position_embedding.forward(&pos_idx)?;
390 let summed = add(&word_2d, &pos_2d)?;
391 reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
392 }
393
394 fn parameters(&self) -> Vec<&Parameter<T>> {
395 let mut out = Vec::new();
396 out.extend(self.token_embedding.parameters());
397 out.extend(self.position_embedding.parameters());
398 out
399 }
400
401 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
402 let mut out = Vec::new();
403 out.extend(self.token_embedding.parameters_mut());
404 out.extend(self.position_embedding.parameters_mut());
405 out
406 }
407
408 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
409 let mut out = Vec::new();
410 for (n, p) in self.token_embedding.named_parameters() {
411 out.push((format!("token_embedding.{n}"), p));
412 }
413 for (n, p) in self.position_embedding.named_parameters() {
414 out.push((format!("position_embedding.{n}"), p));
415 }
416 out
417 }
418
419 fn train(&mut self) {
420 self.training = true;
421 }
422
423 fn eval(&mut self) {
424 self.training = false;
425 }
426
427 fn is_training(&self) -> bool {
428 self.training
429 }
430
431 fn state_dict(&self) -> StateDict<T> {
432 self.named_parameters()
433 .into_iter()
434 .map(|(n, p)| (n, p.tensor().clone()))
435 .collect()
436 }
437
438 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
439 let extract = |prefix: &str| -> StateDict<T> {
440 let p = format!("{prefix}.");
441 state
442 .iter()
443 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
444 .collect()
445 };
446 if strict {
447 let prefixes = ["token_embedding", "position_embedding"];
448 for k in state.keys() {
449 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
450 return Err(FerrotorchError::InvalidArgument {
451 message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
452 });
453 }
454 }
455 }
456 self.token_embedding
457 .load_state_dict(&extract("token_embedding"), strict)?;
458 self.position_embedding
459 .load_state_dict(&extract("position_embedding"), strict)?;
460 Ok(())
461 }
462}
463
464#[derive(Debug)]
483pub struct ClipSelfAttention<T: Float> {
484 pub q_proj: Linear<T>,
486 pub k_proj: Linear<T>,
488 pub v_proj: Linear<T>,
490 pub out_proj: Linear<T>,
492 num_heads: usize,
493 head_dim: usize,
494 hidden: usize,
495 training: bool,
496}
497
498impl<T: Float> ClipSelfAttention<T> {
499 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
505 cfg.validate()?;
506 Ok(Self {
507 q_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
508 k_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
509 v_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
510 out_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
511 num_heads: cfg.num_attention_heads,
512 head_dim: cfg.head_dim(),
513 hidden: cfg.hidden_size,
514 training: false,
515 })
516 }
517}
518
519impl<T: Float> Module<T> for ClipSelfAttention<T> {
520 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
526 let shape = input.shape();
527 if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
528 return Err(FerrotorchError::ShapeMismatch {
529 message: format!(
530 "ClipSelfAttention expects [1, S, {}], got {:?}",
531 self.hidden, shape,
532 ),
533 });
534 }
535 let seq_len = shape[1];
536
537 let q = self.q_proj.forward(input)?;
539 let k = self.k_proj.forward(input)?;
540 let v = self.v_proj.forward(input)?;
541
542 let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
545 let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
546 let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
547
548 let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
550 let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
551 let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
552
553 let ctx = standard_attention(&q_h, &k_h, &v_h, true)?;
557
558 let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
560 let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
561
562 self.out_proj.forward(&ctx3)
564 }
565
566 fn parameters(&self) -> Vec<&Parameter<T>> {
567 let mut out = Vec::new();
568 out.extend(self.q_proj.parameters());
569 out.extend(self.k_proj.parameters());
570 out.extend(self.v_proj.parameters());
571 out.extend(self.out_proj.parameters());
572 out
573 }
574
575 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
576 let mut out = Vec::new();
577 out.extend(self.q_proj.parameters_mut());
578 out.extend(self.k_proj.parameters_mut());
579 out.extend(self.v_proj.parameters_mut());
580 out.extend(self.out_proj.parameters_mut());
581 out
582 }
583
584 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
585 let mut out = Vec::new();
586 for (n, p) in self.q_proj.named_parameters() {
587 out.push((format!("q_proj.{n}"), p));
588 }
589 for (n, p) in self.k_proj.named_parameters() {
590 out.push((format!("k_proj.{n}"), p));
591 }
592 for (n, p) in self.v_proj.named_parameters() {
593 out.push((format!("v_proj.{n}"), p));
594 }
595 for (n, p) in self.out_proj.named_parameters() {
596 out.push((format!("out_proj.{n}"), p));
597 }
598 out
599 }
600
601 fn train(&mut self) {
602 self.training = true;
603 }
604
605 fn eval(&mut self) {
606 self.training = false;
607 }
608
609 fn is_training(&self) -> bool {
610 self.training
611 }
612
613 fn state_dict(&self) -> StateDict<T> {
614 self.named_parameters()
615 .into_iter()
616 .map(|(n, p)| (n, p.tensor().clone()))
617 .collect()
618 }
619
620 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
621 let extract = |prefix: &str| -> StateDict<T> {
622 let p = format!("{prefix}.");
623 state
624 .iter()
625 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
626 .collect()
627 };
628 if strict {
629 let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
630 for k in state.keys() {
631 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
632 return Err(FerrotorchError::InvalidArgument {
633 message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
634 });
635 }
636 }
637 }
638 self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
639 self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
640 self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
641 self.out_proj
642 .load_state_dict(&extract("out_proj"), strict)?;
643 Ok(())
644 }
645}
646
647#[derive(Debug)]
658pub struct ClipMlp<T: Float> {
659 pub fc1: Linear<T>,
661 pub fc2: Linear<T>,
663 activation: GELU,
664 training: bool,
665}
666
667impl<T: Float> ClipMlp<T> {
668 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
674 cfg.validate()?;
675 Ok(Self {
676 fc1: Linear::new(cfg.hidden_size, cfg.intermediate_size, true)?,
677 fc2: Linear::new(cfg.intermediate_size, cfg.hidden_size, true)?,
678 activation: GELU::with_approximate(GeluApproximate::Sigmoid),
680 training: false,
681 })
682 }
683}
684
685impl<T: Float> Module<T> for ClipMlp<T> {
686 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
687 let h = self.fc1.forward(input)?;
688 let h = self.activation.forward(&h)?;
689 self.fc2.forward(&h)
690 }
691
692 fn parameters(&self) -> Vec<&Parameter<T>> {
693 let mut out = Vec::new();
694 out.extend(self.fc1.parameters());
695 out.extend(self.fc2.parameters());
696 out
697 }
698
699 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
700 let mut out = Vec::new();
701 out.extend(self.fc1.parameters_mut());
702 out.extend(self.fc2.parameters_mut());
703 out
704 }
705
706 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
707 let mut out = Vec::new();
708 for (n, p) in self.fc1.named_parameters() {
709 out.push((format!("fc1.{n}"), p));
710 }
711 for (n, p) in self.fc2.named_parameters() {
712 out.push((format!("fc2.{n}"), p));
713 }
714 out
715 }
716
717 fn train(&mut self) {
718 self.training = true;
719 }
720
721 fn eval(&mut self) {
722 self.training = false;
723 }
724
725 fn is_training(&self) -> bool {
726 self.training
727 }
728
729 fn state_dict(&self) -> StateDict<T> {
730 self.named_parameters()
731 .into_iter()
732 .map(|(n, p)| (n, p.tensor().clone()))
733 .collect()
734 }
735
736 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
737 let extract = |prefix: &str| -> StateDict<T> {
738 let p = format!("{prefix}.");
739 state
740 .iter()
741 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
742 .collect()
743 };
744 if strict {
745 for k in state.keys() {
746 if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
747 return Err(FerrotorchError::InvalidArgument {
748 message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
749 });
750 }
751 }
752 }
753 self.fc1.load_state_dict(&extract("fc1"), strict)?;
754 self.fc2.load_state_dict(&extract("fc2"), strict)?;
755 Ok(())
756 }
757}
758
759#[derive(Debug)]
772pub struct ClipEncoderLayer<T: Float> {
773 pub layer_norm1: LayerNorm<T>,
775 pub self_attn: ClipSelfAttention<T>,
777 pub layer_norm2: LayerNorm<T>,
779 pub mlp: ClipMlp<T>,
781 training: bool,
782}
783
784impl<T: Float> ClipEncoderLayer<T> {
785 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
791 Ok(Self {
792 layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
793 self_attn: ClipSelfAttention::new(cfg)?,
794 layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
795 mlp: ClipMlp::new(cfg)?,
796 training: false,
797 })
798 }
799}
800
801impl<T: Float> Module<T> for ClipEncoderLayer<T> {
802 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
804 let normed = self.layer_norm1.forward(input)?;
806 let attn_out = self.self_attn.forward(&normed)?;
807 let after_attn = add(input, &attn_out)?;
808
809 let normed_ffn = self.layer_norm2.forward(&after_attn)?;
811 let mlp_out = self.mlp.forward(&normed_ffn)?;
812 add(&after_attn, &mlp_out)
813 }
814
815 fn parameters(&self) -> Vec<&Parameter<T>> {
816 let mut out = Vec::new();
817 out.extend(self.layer_norm1.parameters());
818 out.extend(self.self_attn.parameters());
819 out.extend(self.layer_norm2.parameters());
820 out.extend(self.mlp.parameters());
821 out
822 }
823
824 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
825 let mut out = Vec::new();
826 out.extend(self.layer_norm1.parameters_mut());
827 out.extend(self.self_attn.parameters_mut());
828 out.extend(self.layer_norm2.parameters_mut());
829 out.extend(self.mlp.parameters_mut());
830 out
831 }
832
833 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
834 let mut out = Vec::new();
835 for (n, p) in self.layer_norm1.named_parameters() {
836 out.push((format!("layer_norm1.{n}"), p));
837 }
838 for (n, p) in self.self_attn.named_parameters() {
839 out.push((format!("self_attn.{n}"), p));
840 }
841 for (n, p) in self.layer_norm2.named_parameters() {
842 out.push((format!("layer_norm2.{n}"), p));
843 }
844 for (n, p) in self.mlp.named_parameters() {
845 out.push((format!("mlp.{n}"), p));
846 }
847 out
848 }
849
850 fn train(&mut self) {
851 self.training = true;
852 }
853
854 fn eval(&mut self) {
855 self.training = false;
856 }
857
858 fn is_training(&self) -> bool {
859 self.training
860 }
861
862 fn state_dict(&self) -> StateDict<T> {
863 self.named_parameters()
864 .into_iter()
865 .map(|(n, p)| (n, p.tensor().clone()))
866 .collect()
867 }
868
869 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
870 let extract = |prefix: &str| -> StateDict<T> {
871 let p = format!("{prefix}.");
872 state
873 .iter()
874 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
875 .collect()
876 };
877 if strict {
878 let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
879 for k in state.keys() {
880 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
881 return Err(FerrotorchError::InvalidArgument {
882 message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
883 });
884 }
885 }
886 }
887 self.layer_norm1
888 .load_state_dict(&extract("layer_norm1"), strict)?;
889 self.self_attn
890 .load_state_dict(&extract("self_attn"), strict)?;
891 self.layer_norm2
892 .load_state_dict(&extract("layer_norm2"), strict)?;
893 self.mlp.load_state_dict(&extract("mlp"), strict)?;
894 Ok(())
895 }
896}
897
898#[derive(Debug)]
904pub struct ClipEncoder<T: Float> {
905 pub layers: Vec<ClipEncoderLayer<T>>,
907 training: bool,
908}
909
910impl<T: Float> ClipEncoder<T> {
911 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
917 cfg.validate()?;
918 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
919 for _ in 0..cfg.num_hidden_layers {
920 layers.push(ClipEncoderLayer::new(cfg)?);
921 }
922 Ok(Self {
923 layers,
924 training: false,
925 })
926 }
927}
928
929impl<T: Float> Module<T> for ClipEncoder<T> {
930 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
931 let mut h = input.clone();
932 for l in &self.layers {
933 h = l.forward(&h)?;
934 }
935 Ok(h)
936 }
937
938 fn parameters(&self) -> Vec<&Parameter<T>> {
939 let mut out = Vec::new();
940 for l in &self.layers {
941 out.extend(l.parameters());
942 }
943 out
944 }
945
946 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
947 let mut out = Vec::new();
948 for l in &mut self.layers {
949 out.extend(l.parameters_mut());
950 }
951 out
952 }
953
954 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
955 let mut out = Vec::new();
956 for (i, l) in self.layers.iter().enumerate() {
957 for (n, p) in l.named_parameters() {
958 out.push((format!("layers.{i}.{n}"), p));
959 }
960 }
961 out
962 }
963
964 fn train(&mut self) {
965 self.training = true;
966 for l in &mut self.layers {
967 l.train();
968 }
969 }
970
971 fn eval(&mut self) {
972 self.training = false;
973 for l in &mut self.layers {
974 l.eval();
975 }
976 }
977
978 fn is_training(&self) -> bool {
979 self.training
980 }
981
982 fn state_dict(&self) -> StateDict<T> {
983 self.named_parameters()
984 .into_iter()
985 .map(|(n, p)| (n, p.tensor().clone()))
986 .collect()
987 }
988
989 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
990 let extract = |prefix: &str| -> StateDict<T> {
991 let p = format!("{prefix}.");
992 state
993 .iter()
994 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
995 .collect()
996 };
997 if strict {
998 for k in state.keys() {
999 if !k.starts_with("layers.") {
1000 return Err(FerrotorchError::InvalidArgument {
1001 message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
1002 });
1003 }
1004 }
1005 }
1006 for (i, l) in self.layers.iter_mut().enumerate() {
1007 l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
1008 }
1009 Ok(())
1010 }
1011}
1012
1013#[derive(Debug)]
1031pub struct ClipTextEncoder<T: Float> {
1032 pub embeddings: ClipTextEmbeddings<T>,
1034 pub encoder: ClipEncoder<T>,
1036 pub final_layer_norm: LayerNorm<T>,
1038 pub config: ClipTextConfig,
1040 training: bool,
1041}
1042
1043impl<T: Float> ClipTextEncoder<T> {
1044 pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
1051 cfg.validate()?;
1052 let embeddings = ClipTextEmbeddings::new(&cfg)?;
1053 let encoder = ClipEncoder::new(&cfg)?;
1054 let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
1055 Ok(Self {
1056 embeddings,
1057 encoder,
1058 final_layer_norm,
1059 config: cfg,
1060 training: false,
1061 })
1062 }
1063
1064 pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
1077 let h = self.embeddings.forward_from_ids(input_ids)?;
1078 let h = self.encoder.forward(&h)?;
1079 self.final_layer_norm.forward(&h)
1080 }
1081
1082 pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1095 if ids.ndim() != 1 {
1098 return Err(FerrotorchError::ShapeMismatch {
1099 message: format!(
1100 "ClipTextEncoder::forward_from_id_tensor expects 1-D ids, got {:?}",
1101 ids.shape()
1102 ),
1103 });
1104 }
1105 let data = ids.data_vec()?;
1106 let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
1107 for (i, v) in data.iter().enumerate() {
1108 let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
1109 FerrotorchError::InvalidArgument {
1110 message: format!(
1111 "ClipTextEncoder::forward_from_id_tensor: id at {i} \
1112 not representable as f64"
1113 ),
1114 }
1115 })?;
1116 if !f.is_finite() || f < 0.0 || f > u32::MAX as f64 || f.fract() != 0.0 {
1117 return Err(FerrotorchError::InvalidArgument {
1118 message: format!(
1119 "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
1120 is not a non-negative integer"
1121 ),
1122 });
1123 }
1124 u32_ids.push(f as u32);
1125 }
1126 self.forward_from_ids(&u32_ids)
1127 }
1128
1129 pub fn load_hf_state_dict(
1152 &mut self,
1153 hf_state: &StateDict<T>,
1154 strict: bool,
1155 ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
1156 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
1157 let mut dropped: Vec<String> = Vec::new();
1158 for (k, v) in hf_state {
1159 let after = k
1161 .strip_prefix("text_model.")
1162 .map_or_else(|| k.clone(), str::to_owned);
1163
1164 if after == "embeddings.position_ids" {
1167 dropped.push(k.clone());
1168 continue;
1169 }
1170
1171 let is_known = after.starts_with("embeddings.token_embedding.")
1172 || after.starts_with("embeddings.position_embedding.")
1173 || after.starts_with("encoder.")
1174 || after.starts_with("final_layer_norm.");
1175 if is_known {
1176 remapped.insert(after, v.clone());
1177 continue;
1178 }
1179
1180 if strict {
1181 return Err(FerrotorchError::InvalidArgument {
1182 message: format!(
1183 "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
1184 known CLIP text-tower parameter and strict mode is on. \
1185 Pass strict=false to drop unknown keys."
1186 ),
1187 });
1188 }
1189 dropped.push(k.clone());
1190 }
1191 dropped.sort();
1192 self.load_state_dict(&remapped, strict)?;
1193 Ok(crate::safetensors_loader::DropReport { dropped })
1194 }
1195}
1196
1197impl<T: Float> Module<T> for ClipTextEncoder<T> {
1198 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1203 let h = self.encoder.forward(input)?;
1204 self.final_layer_norm.forward(&h)
1205 }
1206
1207 fn parameters(&self) -> Vec<&Parameter<T>> {
1208 let mut out = Vec::new();
1209 out.extend(self.embeddings.parameters());
1210 out.extend(self.encoder.parameters());
1211 out.extend(self.final_layer_norm.parameters());
1212 out
1213 }
1214
1215 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1216 let mut out = Vec::new();
1217 out.extend(self.embeddings.parameters_mut());
1218 out.extend(self.encoder.parameters_mut());
1219 out.extend(self.final_layer_norm.parameters_mut());
1220 out
1221 }
1222
1223 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1224 let mut out = Vec::new();
1225 for (n, p) in self.embeddings.named_parameters() {
1226 out.push((format!("embeddings.{n}"), p));
1227 }
1228 for (n, p) in self.encoder.named_parameters() {
1229 out.push((format!("encoder.{n}"), p));
1230 }
1231 for (n, p) in self.final_layer_norm.named_parameters() {
1232 out.push((format!("final_layer_norm.{n}"), p));
1233 }
1234 out
1235 }
1236
1237 fn train(&mut self) {
1238 self.training = true;
1239 self.embeddings.train();
1240 self.encoder.train();
1241 self.final_layer_norm.train();
1242 }
1243
1244 fn eval(&mut self) {
1245 self.training = false;
1246 self.embeddings.eval();
1247 self.encoder.eval();
1248 self.final_layer_norm.eval();
1249 }
1250
1251 fn is_training(&self) -> bool {
1252 self.training
1253 }
1254
1255 fn state_dict(&self) -> StateDict<T> {
1256 self.named_parameters()
1257 .into_iter()
1258 .map(|(n, p)| (n, p.tensor().clone()))
1259 .collect()
1260 }
1261
1262 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1263 let extract = |prefix: &str| -> StateDict<T> {
1264 let p = format!("{prefix}.");
1265 state
1266 .iter()
1267 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1268 .collect()
1269 };
1270 if strict {
1271 for k in state.keys() {
1272 if !(k.starts_with("embeddings.")
1273 || k.starts_with("encoder.")
1274 || k.starts_with("final_layer_norm."))
1275 {
1276 return Err(FerrotorchError::InvalidArgument {
1277 message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
1278 });
1279 }
1280 }
1281 }
1282 self.embeddings
1283 .load_state_dict(&extract("embeddings"), strict)?;
1284 self.encoder.load_state_dict(&extract("encoder"), strict)?;
1285 self.final_layer_norm
1286 .load_state_dict(&extract("final_layer_norm"), strict)?;
1287 Ok(())
1288 }
1289}
1290
1291#[allow(dead_code)]
1295fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1296 mul(a, b)
1297}
1298
1299#[cfg(test)]
1304mod tests {
1305 use super::*;
1306
1307 fn tiny_cfg() -> ClipTextConfig {
1308 ClipTextConfig {
1311 hidden_size: 8,
1312 intermediate_size: 16,
1313 num_attention_heads: 2,
1314 num_hidden_layers: 1,
1315 max_position_embeddings: 6,
1316 vocab_size: 32,
1317 layer_norm_eps: 1e-5,
1318 }
1319 }
1320
1321 #[test]
1322 fn sd_v1_5_config_is_canonical() {
1323 let c = ClipTextConfig::sd_v1_5();
1324 assert_eq!(c.hidden_size, 768);
1325 assert_eq!(c.intermediate_size, 3072);
1326 assert_eq!(c.num_attention_heads, 12);
1327 assert_eq!(c.num_hidden_layers, 12);
1328 assert_eq!(c.max_position_embeddings, 77);
1329 assert_eq!(c.vocab_size, 49408);
1330 assert_eq!(c.head_dim(), 64);
1331 c.validate().unwrap();
1332 }
1333
1334 #[test]
1335 fn validate_catches_bad_head_count() {
1336 let mut c = tiny_cfg();
1337 c.num_attention_heads = 3; assert!(c.validate().is_err());
1339 }
1340
1341 #[test]
1342 fn from_json_str_round_trip() {
1343 let json = r#"{
1344 "hidden_size": 768,
1345 "intermediate_size": 3072,
1346 "num_attention_heads": 12,
1347 "num_hidden_layers": 12,
1348 "max_position_embeddings": 77,
1349 "vocab_size": 49408,
1350 "layer_norm_eps": 1e-5,
1351 "hidden_act": "quick_gelu"
1352 }"#;
1353 let c = ClipTextConfig::from_json_str(json).unwrap();
1354 assert_eq!(c.hidden_size, 768);
1355 assert_eq!(c.intermediate_size, 3072);
1356 assert_eq!(c.num_attention_heads, 12);
1357 assert_eq!(c.num_hidden_layers, 12);
1358 assert_eq!(c.max_position_embeddings, 77);
1359 }
1360
1361 #[test]
1362 fn embeddings_forward_shape() {
1363 let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
1364 let ids = [1u32, 5, 7, 9];
1365 let out = emb.forward_from_ids(&ids).unwrap();
1366 assert_eq!(out.shape(), &[1, 4, 8]);
1367 for &v in out.data().unwrap() {
1368 assert!(v.is_finite(), "embedding non-finite: {v}");
1369 }
1370 }
1371
1372 #[test]
1373 fn embeddings_reject_too_long_sequence() {
1374 let emb = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
1375 let ids: Vec<u32> = (0..7).collect(); assert!(emb.forward_from_ids(&ids).is_err());
1377 }
1378
1379 #[test]
1380 fn self_attention_forward_shape() {
1381 let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
1382 let x = Tensor::from_storage(
1383 TensorStorage::cpu(vec![0.1f32; 5 * 8]),
1384 vec![1, 5, 8],
1385 false,
1386 )
1387 .unwrap();
1388 let out = attn.forward(&x).unwrap();
1389 assert_eq!(out.shape(), &[1, 5, 8]);
1390 for &v in out.data().unwrap() {
1391 assert!(v.is_finite());
1392 }
1393 }
1394
1395 #[test]
1396 fn self_attention_is_actually_causal() {
1397 let attn = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
1403 let mut a = vec![0.1f32; 4 * 8];
1404 for i in 0..2 * 8 {
1405 a[i] = ((i + 1) as f32).sin();
1406 }
1407 let mut b = a.clone();
1408 for i in (2 * 8)..(4 * 8) {
1410 b[i] = ((i + 11) as f32).sin();
1411 }
1412 let xa = Tensor::from_storage(TensorStorage::cpu(a), vec![1, 4, 8], false).unwrap();
1413 let xb = Tensor::from_storage(TensorStorage::cpu(b), vec![1, 4, 8], false).unwrap();
1414 let oa = attn.forward(&xa).unwrap();
1415 let ob = attn.forward(&xb).unwrap();
1416 let da = oa.data().unwrap();
1417 let db = ob.data().unwrap();
1418 for i in 0..2 * 8 {
1419 assert!(
1420 (da[i] - db[i]).abs() < 1e-5,
1421 "row {} ({}) differs between runs: {} vs {}",
1422 i / 8,
1423 i % 8,
1424 da[i],
1425 db[i]
1426 );
1427 }
1428 }
1429
1430 #[test]
1431 fn mlp_uses_quick_gelu() {
1432 let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
1437 let x = Tensor::from_storage(
1438 TensorStorage::cpu(vec![0.0f32; 3 * 8]),
1439 vec![1, 3, 8],
1440 false,
1441 )
1442 .unwrap();
1443 let out = mlp.forward(&x).unwrap();
1444 assert_eq!(out.shape(), &[1, 3, 8]);
1445 for &v in out.data().unwrap() {
1446 assert!(v.is_finite());
1447 }
1448 }
1449
1450 #[test]
1451 fn encoder_layer_forward_shape() {
1452 let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
1453 let x = Tensor::from_storage(
1454 TensorStorage::cpu(vec![0.1f32; 5 * 8]),
1455 vec![1, 5, 8],
1456 false,
1457 )
1458 .unwrap();
1459 let out = layer.forward(&x).unwrap();
1460 assert_eq!(out.shape(), &[1, 5, 8]);
1461 for &v in out.data().unwrap() {
1462 assert!(v.is_finite());
1463 }
1464 }
1465
/// An encoder layer must expose its parameters under the expected
/// dot-separated key names (`layer_norm1.*`, `self_attn.*`, `layer_norm2.*`,
/// `mlp.*`).
///
/// Each listed key must appear in `named_parameters()`; extra keys are
/// tolerated.
#[test]
fn encoder_layer_named_parameters_use_hf_layout() {
    let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
    let names: Vec<String> = layer
        .named_parameters()
        .into_iter()
        .map(|(n, _)| n)
        .collect();
    for k in [
        "layer_norm1.weight",
        "layer_norm1.bias",
        "self_attn.q_proj.weight",
        "self_attn.q_proj.bias",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.out_proj.weight",
        "self_attn.out_proj.bias",
        "layer_norm2.weight",
        // Mirrors the `layer_norm1.bias` check above; the full-encoder layout
        // test expects `encoder.layers.0.layer_norm2.bias` as well.
        "layer_norm2.bias",
        "mlp.fc1.weight",
        "mlp.fc1.bias",
        "mlp.fc2.weight",
        "mlp.fc2.bias",
        // NOTE(review): `self_attn.k_proj.bias` / `self_attn.v_proj.bias` are
        // not asserted here — confirm whether those projections carry a bias.
    ] {
        assert!(
            names.iter().any(|n| n == k),
            "missing parameter key {k:?} in {names:?}"
        );
    }
}
1495
/// Encoding three token ids with the tiny config must yield a finite
/// `[1, 3, 8]` tensor.
#[test]
fn tiny_encoder_forward_from_ids_shape() {
    let encoder = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    let token_ids = vec![1u32, 5, 7];

    let hidden = encoder.forward_from_ids(&token_ids).unwrap();

    assert_eq!(hidden.shape(), &[1, 3, 8]);
    hidden
        .data()
        .unwrap()
        .iter()
        .for_each(|v| assert!(v.is_finite()));
}
1506
/// The full encoder must expose embedding, per-layer and final-norm
/// parameters under the expected dot-separated key names.
///
/// Only presence of the sampled keys is checked; extra keys are tolerated.
#[test]
fn tiny_named_parameters_use_hf_layout() {
    // A sample of keys covering each top-level component.
    const EXPECTED: [&str; 10] = [
        "embeddings.token_embedding.weight",
        "embeddings.position_embedding.weight",
        "encoder.layers.0.layer_norm1.weight",
        "encoder.layers.0.self_attn.q_proj.weight",
        "encoder.layers.0.self_attn.out_proj.bias",
        "encoder.layers.0.layer_norm2.bias",
        "encoder.layers.0.mlp.fc1.weight",
        "encoder.layers.0.mlp.fc2.bias",
        "final_layer_norm.weight",
        "final_layer_norm.bias",
    ];

    let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    let names: Vec<String> = enc
        .named_parameters()
        .into_iter()
        .map(|(name, _)| name)
        .collect();

    for k in EXPECTED {
        let present = names.iter().any(|n| n.as_str() == k);
        assert!(present, "missing parameter key {k:?} in {names:?}");
    }
}
1529
/// Exporting a state dict from one encoder and strictly loading it into a
/// second encoder must make both produce the same output for the same ids.
#[test]
fn round_trip_state_dict() {
    let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    let sd = src.state_dict();
    let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    dst.load_state_dict(&sd, true).unwrap();

    let ids = vec![2u32, 4, 6];
    let a = src.forward_from_ids(&ids).unwrap();
    let b = dst.forward_from_ids(&ids).unwrap();

    // `zip` stops at the shorter side, so a size mismatch would otherwise
    // slip through the element-wise comparison below.
    assert_eq!(
        a.shape(),
        b.shape(),
        "round-trip changed the output shape"
    );
    for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
        assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
    }
}
1543
/// `load_hf_state_dict` must accept keys carrying a `text_model.` prefix and,
/// in non-strict mode, report the extra `text_model.embeddings.position_ids`
/// entry in `dropped` instead of failing.
///
/// After loading, source and destination encoders must agree on the output
/// for the same ids.
#[test]
fn load_hf_state_dict_strips_text_model_prefix() {
    let src = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();

    // Re-key every entry of the bare state dict under `text_model.`.
    let mut prefixed: StateDict<f32> = src
        .state_dict()
        .into_iter()
        .map(|(k, v)| (format!("text_model.{k}"), v))
        .collect();
    // Extra entry that the loader is expected to report as dropped.
    prefixed.insert(
        "text_model.embeddings.position_ids".into(),
        ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
    );

    let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    let rep = dst.load_hf_state_dict(&prefixed, false).unwrap();
    assert_eq!(
        rep.dropped,
        vec!["text_model.embeddings.position_ids".to_string()]
    );

    let ids = vec![1u32, 2, 3];
    let a = src.forward_from_ids(&ids).unwrap();
    let b = dst.forward_from_ids(&ids).unwrap();

    // `zip` stops at the shorter side, so pin the shapes before comparing
    // values element by element.
    assert_eq!(
        a.shape(),
        b.shape(),
        "prefixed load changed the output shape"
    );
    for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
        assert!(
            (x - y).abs() < 1e-5,
            "output differs after prefixed load: {x} vs {y}"
        );
    }
}
1570
/// In strict mode, `load_hf_state_dict` must return an error when the state
/// dict contains a key the encoder does not recognise.
#[test]
fn load_hf_state_dict_strict_rejects_unknown_key() {
    // A state dict whose only entry is a key the test expects to be rejected.
    let sd: StateDict<f32> = HashMap::from([(
        String::from("mystery.key"),
        ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
    )]);

    let mut dst = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    let outcome = dst.load_hf_state_dict(&sd, true);
    assert!(outcome.is_err());
}
1581
/// Feeding ids as a tensor built by `float_index_tensor` must give the same
/// output as feeding the raw id slice.
#[test]
fn forward_from_id_tensor_matches_forward_from_ids() {
    let enc = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
    let ids = vec![1u32, 5, 7];
    let id_tensor = float_index_tensor::<f32>(&ids).unwrap();

    let a = enc.forward_from_ids(&ids).unwrap();
    let b = enc.forward_from_id_tensor(&id_tensor).unwrap();

    // `zip` stops at the shorter side, so a size mismatch between the two
    // entry points would otherwise go unnoticed.
    assert_eq!(
        a.shape(),
        b.shape(),
        "id-tensor path produced a different output shape"
    );
    for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
        assert!(
            (x - y).abs() < 1e-5,
            "id-tensor path differs from id-slice path: {x} vs {y}"
        );
    }
}
1593}