1use std::collections::HashMap;
85
86use ferrotorch_core::grad_fns::arithmetic::{add, mul};
87use ferrotorch_core::{
88 FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage, numeric_cast,
89};
90use ferrotorch_nn::module::{Module, StateDict};
91use ferrotorch_nn::parameter::Parameter;
92use ferrotorch_nn::{
93 Embedding, GELU, GeluApproximate, LayerNorm, Linear, reshape_to_heads, standard_attention,
94 transpose_heads_to_2d,
95};
96
/// Hyper-parameters of the CLIP text encoder built in this file.
///
/// [`ClipTextConfig::sd_v1_5`] provides the default values and
/// [`ClipTextConfig::validate`] enforces the invariants between the fields.
#[derive(Debug, Clone)]
pub struct ClipTextConfig {
    /// Width of the hidden states; also the width of both embedding tables.
    pub hidden_size: usize,
    /// Output width of the first MLP projection (`fc1`).
    pub intermediate_size: usize,
    /// Number of attention heads; must divide `hidden_size` evenly.
    pub num_attention_heads: usize,
    /// Number of encoder layers stacked by the encoder.
    pub num_hidden_layers: usize,
    /// Number of rows in the position-embedding table, i.e. the longest
    /// sequence the embeddings accept.
    pub max_position_embeddings: usize,
    /// Number of rows in the token-embedding table.
    pub vocab_size: usize,
    /// Epsilon passed to every `LayerNorm` constructed from this config.
    pub layer_norm_eps: f64,
}
121
122impl Default for ClipTextConfig {
123 fn default() -> Self {
124 Self::sd_v1_5()
125 }
126}
127
128impl ClipTextConfig {
129 pub fn sd_v1_5() -> Self {
131 Self {
132 hidden_size: 768,
133 intermediate_size: 3072,
134 num_attention_heads: 12,
135 num_hidden_layers: 12,
136 max_position_embeddings: 77,
137 vocab_size: 49408,
138 layer_norm_eps: 1e-5,
139 }
140 }
141
142 #[inline]
144 #[must_use]
145 pub fn head_dim(&self) -> usize {
146 self.hidden_size / self.num_attention_heads
147 }
148
149 pub fn validate(&self) -> FerrotorchResult<()> {
156 if self.hidden_size == 0
157 || self.intermediate_size == 0
158 || self.num_attention_heads == 0
159 || self.num_hidden_layers == 0
160 || self.max_position_embeddings == 0
161 || self.vocab_size == 0
162 {
163 return Err(FerrotorchError::InvalidArgument {
164 message: "ClipTextConfig: all size fields must be > 0".into(),
165 });
166 }
167 if self.hidden_size % self.num_attention_heads != 0 {
168 return Err(FerrotorchError::InvalidArgument {
169 message: format!(
170 "ClipTextConfig: hidden_size {} not divisible by num_attention_heads {}",
171 self.hidden_size, self.num_attention_heads,
172 ),
173 });
174 }
175 if !self.layer_norm_eps.is_finite() || self.layer_norm_eps <= 0.0 {
176 return Err(FerrotorchError::InvalidArgument {
177 message: format!(
178 "ClipTextConfig: layer_norm_eps must be finite and > 0, got {}",
179 self.layer_norm_eps,
180 ),
181 });
182 }
183 Ok(())
184 }
185
186 pub fn from_json_str(s: &str) -> FerrotorchResult<Self> {
198 let v: serde_json::Value =
199 serde_json::from_str(s).map_err(|e| FerrotorchError::InvalidArgument {
200 message: format!("ClipTextConfig::from_json_str: bad JSON: {e}"),
201 })?;
202 let mut cfg = Self::default();
203 if let Some(x) = v.get("hidden_size").and_then(serde_json::Value::as_u64) {
204 cfg.hidden_size = x as usize;
205 }
206 if let Some(x) = v.get("intermediate_size").and_then(serde_json::Value::as_u64) {
207 cfg.intermediate_size = x as usize;
208 }
209 if let Some(x) = v.get("num_attention_heads").and_then(serde_json::Value::as_u64) {
210 cfg.num_attention_heads = x as usize;
211 }
212 if let Some(x) = v.get("num_hidden_layers").and_then(serde_json::Value::as_u64) {
213 cfg.num_hidden_layers = x as usize;
214 }
215 if let Some(x) = v
216 .get("max_position_embeddings")
217 .and_then(serde_json::Value::as_u64)
218 {
219 cfg.max_position_embeddings = x as usize;
220 }
221 if let Some(x) = v.get("vocab_size").and_then(serde_json::Value::as_u64) {
222 cfg.vocab_size = x as usize;
223 }
224 if let Some(x) = v.get("layer_norm_eps").and_then(serde_json::Value::as_f64) {
225 cfg.layer_norm_eps = x;
226 }
227 cfg.validate()?;
228 Ok(cfg)
229 }
230
231 pub fn from_file(path: &std::path::Path) -> FerrotorchResult<Self> {
238 let s = std::fs::read_to_string(path).map_err(|e| FerrotorchError::InvalidArgument {
239 message: format!(
240 "ClipTextConfig::from_file: failed to read {}: {e}",
241 path.display(),
242 ),
243 })?;
244 Self::from_json_str(&s)
245 }
246}
247
248fn reshape_owned<T: Float>(t: &Tensor<T>, shape: Vec<usize>) -> FerrotorchResult<Tensor<T>> {
254 let prod: usize = shape.iter().product();
255 if prod != t.numel() {
256 return Err(FerrotorchError::ShapeMismatch {
257 message: format!(
258 "ClipTextEncoder reshape: target {shape:?} (= {prod} elements) does not \
259 match source numel {}",
260 t.numel()
261 ),
262 });
263 }
264 let data = t.data_vec()?;
265 Tensor::from_storage(TensorStorage::cpu(data), shape, t.requires_grad())
266}
267
268fn float_index_tensor<T: Float>(ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
271 let data: Vec<T> = ids
272 .iter()
273 .map(|&i| numeric_cast::cast::<u32, T>(i))
274 .collect::<FerrotorchResult<Vec<T>>>()?;
275 let n = data.len();
276 Tensor::from_storage(TensorStorage::cpu(data), vec![n], false)
277}
278
/// Token + position embedding stage of the text encoder.
///
/// The forward paths look up one row per token and one row per position
/// `0..seq_len`, add them, and return a `[1, seq_len, hidden_size]` tensor.
#[derive(Debug)]
pub struct ClipTextEmbeddings<T: Float> {
    /// Lookup table with `vocab_size` rows of width `hidden_size`.
    pub token_embedding: Embedding<T>,
    /// Lookup table with `max_position_embeddings` rows of width `hidden_size`.
    pub position_embedding: Embedding<T>,
    // Copied from the config; used for the output shape.
    hidden_size: usize,
    // Copied from the config; upper bound on accepted sequence length.
    max_position_embeddings: usize,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
298
299impl<T: Float> ClipTextEmbeddings<T> {
300 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
307 cfg.validate()?;
308 Ok(Self {
309 token_embedding: Embedding::new(cfg.vocab_size, cfg.hidden_size, None)?,
310 position_embedding: Embedding::new(cfg.max_position_embeddings, cfg.hidden_size, None)?,
311 hidden_size: cfg.hidden_size,
312 max_position_embeddings: cfg.max_position_embeddings,
313 training: false,
314 })
315 }
316
317 pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
328 if input_ids.is_empty() {
329 return Err(FerrotorchError::InvalidArgument {
330 message: "ClipTextEmbeddings::forward_from_ids needs at least one token".into(),
331 });
332 }
333 let seq_len = input_ids.len();
334 if seq_len > self.max_position_embeddings {
335 return Err(FerrotorchError::InvalidArgument {
336 message: format!(
337 "ClipTextEmbeddings: sequence length {seq_len} exceeds \
338 max_position_embeddings {}",
339 self.max_position_embeddings,
340 ),
341 });
342 }
343
344 let word_idx = float_index_tensor::<T>(input_ids)?;
345 let word_2d = self.token_embedding.forward(&word_idx)?; let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
348 let pos_idx = float_index_tensor::<T>(&pos_ids)?;
349 let pos_2d = self.position_embedding.forward(&pos_idx)?; let summed = add(&word_2d, &pos_2d)?;
352 reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
354 }
355}
356
357impl<T: Float> Module<T> for ClipTextEmbeddings<T> {
358 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
363 let word_2d = self.token_embedding.forward(input)?;
364 let seq_len = input.numel();
365 let pos_ids: Vec<u32> = (0..seq_len as u32).collect();
366 let pos_idx = float_index_tensor::<T>(&pos_ids)?;
367 let pos_2d = self.position_embedding.forward(&pos_idx)?;
368 let summed = add(&word_2d, &pos_2d)?;
369 reshape_owned(&summed, vec![1, seq_len, self.hidden_size])
370 }
371
372 fn parameters(&self) -> Vec<&Parameter<T>> {
373 let mut out = Vec::new();
374 out.extend(self.token_embedding.parameters());
375 out.extend(self.position_embedding.parameters());
376 out
377 }
378
379 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
380 let mut out = Vec::new();
381 out.extend(self.token_embedding.parameters_mut());
382 out.extend(self.position_embedding.parameters_mut());
383 out
384 }
385
386 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
387 let mut out = Vec::new();
388 for (n, p) in self.token_embedding.named_parameters() {
389 out.push((format!("token_embedding.{n}"), p));
390 }
391 for (n, p) in self.position_embedding.named_parameters() {
392 out.push((format!("position_embedding.{n}"), p));
393 }
394 out
395 }
396
397 fn train(&mut self) {
398 self.training = true;
399 }
400
401 fn eval(&mut self) {
402 self.training = false;
403 }
404
405 fn is_training(&self) -> bool {
406 self.training
407 }
408
409 fn state_dict(&self) -> StateDict<T> {
410 self.named_parameters()
411 .into_iter()
412 .map(|(n, p)| (n, p.tensor().clone()))
413 .collect()
414 }
415
416 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
417 let extract = |prefix: &str| -> StateDict<T> {
418 let p = format!("{prefix}.");
419 state
420 .iter()
421 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
422 .collect()
423 };
424 if strict {
425 let prefixes = ["token_embedding", "position_embedding"];
426 for k in state.keys() {
427 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
428 return Err(FerrotorchError::InvalidArgument {
429 message: format!("unexpected key in ClipTextEmbeddings state_dict: {k:?}"),
430 });
431 }
432 }
433 }
434 self.token_embedding
435 .load_state_dict(&extract("token_embedding"), strict)?;
436 self.position_embedding
437 .load_state_dict(&extract("position_embedding"), strict)?;
438 Ok(())
439 }
440}
441
/// Multi-head self-attention block with separate query, key, value and output
/// projections, each a `hidden -> hidden` linear layer with bias.
#[derive(Debug)]
pub struct ClipSelfAttention<T: Float> {
    /// Query projection.
    pub q_proj: Linear<T>,
    /// Key projection.
    pub k_proj: Linear<T>,
    /// Value projection.
    pub v_proj: Linear<T>,
    /// Projection applied to the merged per-head context.
    pub out_proj: Linear<T>,
    // Number of attention heads.
    num_heads: usize,
    // Width of one head (`hidden / num_heads`).
    head_dim: usize,
    // Width of the hidden states.
    hidden: usize,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
475
476impl<T: Float> ClipSelfAttention<T> {
477 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
483 cfg.validate()?;
484 Ok(Self {
485 q_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
486 k_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
487 v_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
488 out_proj: Linear::new(cfg.hidden_size, cfg.hidden_size, true)?,
489 num_heads: cfg.num_attention_heads,
490 head_dim: cfg.head_dim(),
491 hidden: cfg.hidden_size,
492 training: false,
493 })
494 }
495}
496
497impl<T: Float> Module<T> for ClipSelfAttention<T> {
498 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
504 let shape = input.shape();
505 if shape.len() != 3 || shape[0] != 1 || shape[2] != self.hidden {
506 return Err(FerrotorchError::ShapeMismatch {
507 message: format!(
508 "ClipSelfAttention expects [1, S, {}], got {:?}",
509 self.hidden, shape,
510 ),
511 });
512 }
513 let seq_len = shape[1];
514
515 let q = self.q_proj.forward(input)?;
517 let k = self.k_proj.forward(input)?;
518 let v = self.v_proj.forward(input)?;
519
520 let q2 = reshape_owned(&q, vec![seq_len, self.hidden])?;
523 let k2 = reshape_owned(&k, vec![seq_len, self.hidden])?;
524 let v2 = reshape_owned(&v, vec![seq_len, self.hidden])?;
525
526 let q_h = reshape_to_heads(&q2, self.num_heads, seq_len, self.head_dim)?;
528 let k_h = reshape_to_heads(&k2, self.num_heads, seq_len, self.head_dim)?;
529 let v_h = reshape_to_heads(&v2, self.num_heads, seq_len, self.head_dim)?;
530
531 let ctx = standard_attention(&q_h, &k_h, &v_h, true)?;
535
536 let ctx2 = transpose_heads_to_2d(&ctx, self.num_heads, seq_len, self.head_dim)?;
538 let ctx3 = reshape_owned(&ctx2, vec![1, seq_len, self.hidden])?;
539
540 self.out_proj.forward(&ctx3)
542 }
543
544 fn parameters(&self) -> Vec<&Parameter<T>> {
545 let mut out = Vec::new();
546 out.extend(self.q_proj.parameters());
547 out.extend(self.k_proj.parameters());
548 out.extend(self.v_proj.parameters());
549 out.extend(self.out_proj.parameters());
550 out
551 }
552
553 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
554 let mut out = Vec::new();
555 out.extend(self.q_proj.parameters_mut());
556 out.extend(self.k_proj.parameters_mut());
557 out.extend(self.v_proj.parameters_mut());
558 out.extend(self.out_proj.parameters_mut());
559 out
560 }
561
562 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
563 let mut out = Vec::new();
564 for (n, p) in self.q_proj.named_parameters() {
565 out.push((format!("q_proj.{n}"), p));
566 }
567 for (n, p) in self.k_proj.named_parameters() {
568 out.push((format!("k_proj.{n}"), p));
569 }
570 for (n, p) in self.v_proj.named_parameters() {
571 out.push((format!("v_proj.{n}"), p));
572 }
573 for (n, p) in self.out_proj.named_parameters() {
574 out.push((format!("out_proj.{n}"), p));
575 }
576 out
577 }
578
579 fn train(&mut self) {
580 self.training = true;
581 }
582
583 fn eval(&mut self) {
584 self.training = false;
585 }
586
587 fn is_training(&self) -> bool {
588 self.training
589 }
590
591 fn state_dict(&self) -> StateDict<T> {
592 self.named_parameters()
593 .into_iter()
594 .map(|(n, p)| (n, p.tensor().clone()))
595 .collect()
596 }
597
598 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
599 let extract = |prefix: &str| -> StateDict<T> {
600 let p = format!("{prefix}.");
601 state
602 .iter()
603 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
604 .collect()
605 };
606 if strict {
607 let prefixes = ["q_proj", "k_proj", "v_proj", "out_proj"];
608 for k in state.keys() {
609 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
610 return Err(FerrotorchError::InvalidArgument {
611 message: format!("unexpected key in ClipSelfAttention state_dict: {k:?}"),
612 });
613 }
614 }
615 }
616 self.q_proj.load_state_dict(&extract("q_proj"), strict)?;
617 self.k_proj.load_state_dict(&extract("k_proj"), strict)?;
618 self.v_proj.load_state_dict(&extract("v_proj"), strict)?;
619 self.out_proj
620 .load_state_dict(&extract("out_proj"), strict)?;
621 Ok(())
622 }
623}
624
/// Two-layer feed-forward block: `fc1`, a GELU activation, then `fc2`.
#[derive(Debug)]
pub struct ClipMlp<T: Float> {
    /// Expansion projection, `hidden_size -> intermediate_size`.
    pub fc1: Linear<T>,
    /// Contraction projection, `intermediate_size -> hidden_size`.
    pub fc2: Linear<T>,
    // Activation applied between the two projections; constructed with
    // `GeluApproximate::Sigmoid`.
    activation: GELU,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
644
645impl<T: Float> ClipMlp<T> {
646 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
652 cfg.validate()?;
653 Ok(Self {
654 fc1: Linear::new(cfg.hidden_size, cfg.intermediate_size, true)?,
655 fc2: Linear::new(cfg.intermediate_size, cfg.hidden_size, true)?,
656 activation: GELU::with_approximate(GeluApproximate::Sigmoid),
658 training: false,
659 })
660 }
661}
662
663impl<T: Float> Module<T> for ClipMlp<T> {
664 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
665 let h = self.fc1.forward(input)?;
666 let h = self.activation.forward(&h)?;
667 self.fc2.forward(&h)
668 }
669
670 fn parameters(&self) -> Vec<&Parameter<T>> {
671 let mut out = Vec::new();
672 out.extend(self.fc1.parameters());
673 out.extend(self.fc2.parameters());
674 out
675 }
676
677 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
678 let mut out = Vec::new();
679 out.extend(self.fc1.parameters_mut());
680 out.extend(self.fc2.parameters_mut());
681 out
682 }
683
684 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
685 let mut out = Vec::new();
686 for (n, p) in self.fc1.named_parameters() {
687 out.push((format!("fc1.{n}"), p));
688 }
689 for (n, p) in self.fc2.named_parameters() {
690 out.push((format!("fc2.{n}"), p));
691 }
692 out
693 }
694
695 fn train(&mut self) {
696 self.training = true;
697 }
698
699 fn eval(&mut self) {
700 self.training = false;
701 }
702
703 fn is_training(&self) -> bool {
704 self.training
705 }
706
707 fn state_dict(&self) -> StateDict<T> {
708 self.named_parameters()
709 .into_iter()
710 .map(|(n, p)| (n, p.tensor().clone()))
711 .collect()
712 }
713
714 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
715 let extract = |prefix: &str| -> StateDict<T> {
716 let p = format!("{prefix}.");
717 state
718 .iter()
719 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
720 .collect()
721 };
722 if strict {
723 for k in state.keys() {
724 if !(k.starts_with("fc1.") || k.starts_with("fc2.")) {
725 return Err(FerrotorchError::InvalidArgument {
726 message: format!("unexpected key in ClipMlp state_dict: {k:?}"),
727 });
728 }
729 }
730 }
731 self.fc1.load_state_dict(&extract("fc1"), strict)?;
732 self.fc2.load_state_dict(&extract("fc2"), strict)?;
733 Ok(())
734 }
735}
736
/// One encoder layer: layer norm → self-attention → residual add, then
/// layer norm → MLP → residual add.
#[derive(Debug)]
pub struct ClipEncoderLayer<T: Float> {
    /// Normalisation applied before self-attention.
    pub layer_norm1: LayerNorm<T>,
    /// Self-attention sub-block.
    pub self_attn: ClipSelfAttention<T>,
    /// Normalisation applied before the MLP.
    pub layer_norm2: LayerNorm<T>,
    /// Feed-forward sub-block.
    pub mlp: ClipMlp<T>,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
761
762impl<T: Float> ClipEncoderLayer<T> {
763 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
769 Ok(Self {
770 layer_norm1: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
771 self_attn: ClipSelfAttention::new(cfg)?,
772 layer_norm2: LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?,
773 mlp: ClipMlp::new(cfg)?,
774 training: false,
775 })
776 }
777}
778
779impl<T: Float> Module<T> for ClipEncoderLayer<T> {
780 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
782 let normed = self.layer_norm1.forward(input)?;
784 let attn_out = self.self_attn.forward(&normed)?;
785 let after_attn = add(input, &attn_out)?;
786
787 let normed_ffn = self.layer_norm2.forward(&after_attn)?;
789 let mlp_out = self.mlp.forward(&normed_ffn)?;
790 add(&after_attn, &mlp_out)
791 }
792
793 fn parameters(&self) -> Vec<&Parameter<T>> {
794 let mut out = Vec::new();
795 out.extend(self.layer_norm1.parameters());
796 out.extend(self.self_attn.parameters());
797 out.extend(self.layer_norm2.parameters());
798 out.extend(self.mlp.parameters());
799 out
800 }
801
802 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
803 let mut out = Vec::new();
804 out.extend(self.layer_norm1.parameters_mut());
805 out.extend(self.self_attn.parameters_mut());
806 out.extend(self.layer_norm2.parameters_mut());
807 out.extend(self.mlp.parameters_mut());
808 out
809 }
810
811 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
812 let mut out = Vec::new();
813 for (n, p) in self.layer_norm1.named_parameters() {
814 out.push((format!("layer_norm1.{n}"), p));
815 }
816 for (n, p) in self.self_attn.named_parameters() {
817 out.push((format!("self_attn.{n}"), p));
818 }
819 for (n, p) in self.layer_norm2.named_parameters() {
820 out.push((format!("layer_norm2.{n}"), p));
821 }
822 for (n, p) in self.mlp.named_parameters() {
823 out.push((format!("mlp.{n}"), p));
824 }
825 out
826 }
827
828 fn train(&mut self) {
829 self.training = true;
830 }
831
832 fn eval(&mut self) {
833 self.training = false;
834 }
835
836 fn is_training(&self) -> bool {
837 self.training
838 }
839
840 fn state_dict(&self) -> StateDict<T> {
841 self.named_parameters()
842 .into_iter()
843 .map(|(n, p)| (n, p.tensor().clone()))
844 .collect()
845 }
846
847 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
848 let extract = |prefix: &str| -> StateDict<T> {
849 let p = format!("{prefix}.");
850 state
851 .iter()
852 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
853 .collect()
854 };
855 if strict {
856 let prefixes = ["layer_norm1", "self_attn", "layer_norm2", "mlp"];
857 for k in state.keys() {
858 if !prefixes.iter().any(|p| k.starts_with(&format!("{p}."))) {
859 return Err(FerrotorchError::InvalidArgument {
860 message: format!("unexpected key in ClipEncoderLayer state_dict: {k:?}"),
861 });
862 }
863 }
864 }
865 self.layer_norm1
866 .load_state_dict(&extract("layer_norm1"), strict)?;
867 self.self_attn
868 .load_state_dict(&extract("self_attn"), strict)?;
869 self.layer_norm2
870 .load_state_dict(&extract("layer_norm2"), strict)?;
871 self.mlp.load_state_dict(&extract("mlp"), strict)?;
872 Ok(())
873 }
874}
875
/// Stack of encoder layers applied one after another.
#[derive(Debug)]
pub struct ClipEncoder<T: Float> {
    /// The layers, in application order; parameter names use the index in
    /// this vector (`layers.<i>.…`).
    pub layers: Vec<ClipEncoderLayer<T>>,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
887
888impl<T: Float> ClipEncoder<T> {
889 pub fn new(cfg: &ClipTextConfig) -> FerrotorchResult<Self> {
895 cfg.validate()?;
896 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
897 for _ in 0..cfg.num_hidden_layers {
898 layers.push(ClipEncoderLayer::new(cfg)?);
899 }
900 Ok(Self {
901 layers,
902 training: false,
903 })
904 }
905}
906
907impl<T: Float> Module<T> for ClipEncoder<T> {
908 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
909 let mut h = input.clone();
910 for l in &self.layers {
911 h = l.forward(&h)?;
912 }
913 Ok(h)
914 }
915
916 fn parameters(&self) -> Vec<&Parameter<T>> {
917 let mut out = Vec::new();
918 for l in &self.layers {
919 out.extend(l.parameters());
920 }
921 out
922 }
923
924 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
925 let mut out = Vec::new();
926 for l in &mut self.layers {
927 out.extend(l.parameters_mut());
928 }
929 out
930 }
931
932 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
933 let mut out = Vec::new();
934 for (i, l) in self.layers.iter().enumerate() {
935 for (n, p) in l.named_parameters() {
936 out.push((format!("layers.{i}.{n}"), p));
937 }
938 }
939 out
940 }
941
942 fn train(&mut self) {
943 self.training = true;
944 for l in &mut self.layers {
945 l.train();
946 }
947 }
948
949 fn eval(&mut self) {
950 self.training = false;
951 for l in &mut self.layers {
952 l.eval();
953 }
954 }
955
956 fn is_training(&self) -> bool {
957 self.training
958 }
959
960 fn state_dict(&self) -> StateDict<T> {
961 self.named_parameters()
962 .into_iter()
963 .map(|(n, p)| (n, p.tensor().clone()))
964 .collect()
965 }
966
967 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
968 let extract = |prefix: &str| -> StateDict<T> {
969 let p = format!("{prefix}.");
970 state
971 .iter()
972 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
973 .collect()
974 };
975 if strict {
976 for k in state.keys() {
977 if !k.starts_with("layers.") {
978 return Err(FerrotorchError::InvalidArgument {
979 message: format!("unexpected key in ClipEncoder state_dict: {k:?}"),
980 });
981 }
982 }
983 }
984 for (i, l) in self.layers.iter_mut().enumerate() {
985 l.load_state_dict(&extract(&format!("layers.{i}")), strict)?;
986 }
987 Ok(())
988 }
989}
990
/// Complete text encoder: embeddings, the encoder stack and a final layer
/// norm, together with the configuration it was built from.
#[derive(Debug)]
pub struct ClipTextEncoder<T: Float> {
    /// Token + position embedding stage.
    pub embeddings: ClipTextEmbeddings<T>,
    /// Stack of encoder layers.
    pub encoder: ClipEncoder<T>,
    /// Normalisation applied to the encoder output.
    pub final_layer_norm: LayerNorm<T>,
    /// Configuration passed to `new`.
    pub config: ClipTextConfig,
    // Train/eval flag reported by `is_training`.
    training: bool,
}
1020
1021impl<T: Float> ClipTextEncoder<T> {
1022 pub fn new(cfg: ClipTextConfig) -> FerrotorchResult<Self> {
1029 cfg.validate()?;
1030 let embeddings = ClipTextEmbeddings::new(&cfg)?;
1031 let encoder = ClipEncoder::new(&cfg)?;
1032 let final_layer_norm = LayerNorm::new(vec![cfg.hidden_size], cfg.layer_norm_eps, true)?;
1033 Ok(Self {
1034 embeddings,
1035 encoder,
1036 final_layer_norm,
1037 config: cfg,
1038 training: false,
1039 })
1040 }
1041
1042 pub fn forward_from_ids(&self, input_ids: &[u32]) -> FerrotorchResult<Tensor<T>> {
1055 let h = self.embeddings.forward_from_ids(input_ids)?;
1056 let h = self.encoder.forward(&h)?;
1057 self.final_layer_norm.forward(&h)
1058 }
1059
1060 pub fn forward_from_id_tensor(&self, ids: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1073 if ids.ndim() != 1 {
1076 return Err(FerrotorchError::ShapeMismatch {
1077 message: format!(
1078 "ClipTextEncoder::forward_from_id_tensor expects 1-D ids, got {:?}",
1079 ids.shape()
1080 ),
1081 });
1082 }
1083 let data = ids.data_vec()?;
1084 let mut u32_ids: Vec<u32> = Vec::with_capacity(data.len());
1085 for (i, v) in data.iter().enumerate() {
1086 let f = num_traits::ToPrimitive::to_f64(v).ok_or_else(|| {
1087 FerrotorchError::InvalidArgument {
1088 message: format!(
1089 "ClipTextEncoder::forward_from_id_tensor: id at {i} \
1090 not representable as f64"
1091 ),
1092 }
1093 })?;
1094 if !f.is_finite() || f < 0.0 || f > u32::MAX as f64 || f.fract() != 0.0 {
1095 return Err(FerrotorchError::InvalidArgument {
1096 message: format!(
1097 "ClipTextEncoder::forward_from_id_tensor: id at {i} ({f}) \
1098 is not a non-negative integer"
1099 ),
1100 });
1101 }
1102 u32_ids.push(f as u32);
1103 }
1104 self.forward_from_ids(&u32_ids)
1105 }
1106
1107 pub fn load_hf_state_dict(
1130 &mut self,
1131 hf_state: &StateDict<T>,
1132 strict: bool,
1133 ) -> FerrotorchResult<crate::safetensors_loader::DropReport> {
1134 let mut remapped: StateDict<T> = HashMap::with_capacity(hf_state.len());
1135 let mut dropped: Vec<String> = Vec::new();
1136 for (k, v) in hf_state {
1137 let after = k.strip_prefix("text_model.").map_or_else(|| k.clone(), str::to_owned);
1139
1140 if after == "embeddings.position_ids" {
1143 dropped.push(k.clone());
1144 continue;
1145 }
1146
1147 let is_known = after.starts_with("embeddings.token_embedding.")
1148 || after.starts_with("embeddings.position_embedding.")
1149 || after.starts_with("encoder.")
1150 || after.starts_with("final_layer_norm.");
1151 if is_known {
1152 remapped.insert(after, v.clone());
1153 continue;
1154 }
1155
1156 if strict {
1157 return Err(FerrotorchError::InvalidArgument {
1158 message: format!(
1159 "ClipTextEncoder::load_hf_state_dict: key {k:?} is not a \
1160 known CLIP text-tower parameter and strict mode is on. \
1161 Pass strict=false to drop unknown keys."
1162 ),
1163 });
1164 }
1165 dropped.push(k.clone());
1166 }
1167 dropped.sort();
1168 self.load_state_dict(&remapped, strict)?;
1169 Ok(crate::safetensors_loader::DropReport { dropped })
1170 }
1171}
1172
1173impl<T: Float> Module<T> for ClipTextEncoder<T> {
1174 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1179 let h = self.encoder.forward(input)?;
1180 self.final_layer_norm.forward(&h)
1181 }
1182
1183 fn parameters(&self) -> Vec<&Parameter<T>> {
1184 let mut out = Vec::new();
1185 out.extend(self.embeddings.parameters());
1186 out.extend(self.encoder.parameters());
1187 out.extend(self.final_layer_norm.parameters());
1188 out
1189 }
1190
1191 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1192 let mut out = Vec::new();
1193 out.extend(self.embeddings.parameters_mut());
1194 out.extend(self.encoder.parameters_mut());
1195 out.extend(self.final_layer_norm.parameters_mut());
1196 out
1197 }
1198
1199 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1200 let mut out = Vec::new();
1201 for (n, p) in self.embeddings.named_parameters() {
1202 out.push((format!("embeddings.{n}"), p));
1203 }
1204 for (n, p) in self.encoder.named_parameters() {
1205 out.push((format!("encoder.{n}"), p));
1206 }
1207 for (n, p) in self.final_layer_norm.named_parameters() {
1208 out.push((format!("final_layer_norm.{n}"), p));
1209 }
1210 out
1211 }
1212
1213 fn train(&mut self) {
1214 self.training = true;
1215 self.embeddings.train();
1216 self.encoder.train();
1217 self.final_layer_norm.train();
1218 }
1219
1220 fn eval(&mut self) {
1221 self.training = false;
1222 self.embeddings.eval();
1223 self.encoder.eval();
1224 self.final_layer_norm.eval();
1225 }
1226
1227 fn is_training(&self) -> bool {
1228 self.training
1229 }
1230
1231 fn state_dict(&self) -> StateDict<T> {
1232 self.named_parameters()
1233 .into_iter()
1234 .map(|(n, p)| (n, p.tensor().clone()))
1235 .collect()
1236 }
1237
1238 fn load_state_dict(&mut self, state: &StateDict<T>, strict: bool) -> FerrotorchResult<()> {
1239 let extract = |prefix: &str| -> StateDict<T> {
1240 let p = format!("{prefix}.");
1241 state
1242 .iter()
1243 .filter_map(|(k, v)| k.strip_prefix(&p).map(|r| (r.to_string(), v.clone())))
1244 .collect()
1245 };
1246 if strict {
1247 for k in state.keys() {
1248 if !(k.starts_with("embeddings.")
1249 || k.starts_with("encoder.")
1250 || k.starts_with("final_layer_norm."))
1251 {
1252 return Err(FerrotorchError::InvalidArgument {
1253 message: format!("unexpected key in ClipTextEncoder state_dict: {k:?}"),
1254 });
1255 }
1256 }
1257 }
1258 self.embeddings
1259 .load_state_dict(&extract("embeddings"), strict)?;
1260 self.encoder
1261 .load_state_dict(&extract("encoder"), strict)?;
1262 self.final_layer_norm
1263 .load_state_dict(&extract("final_layer_norm"), strict)?;
1264 Ok(())
1265 }
1266}
1267
/// Thin wrapper that forwards to the imported `mul`.
///
/// NOTE(review): nothing in this file calls this function; it appears to
/// exist only so the `mul` import stays referenced — confirm, and remove both
/// if the import is no longer wanted.
// `dead_code` is allowed because the function is intentionally never called.
#[allow(dead_code)]
fn _unused_mul_ref<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    mul(a, b)
}
1275
#[cfg(test)]
mod tests {
    use super::*;

    /// Hidden width used by every tensor built in these tests.
    const HIDDEN: usize = 8;

    /// Small configuration that keeps the tests fast.
    fn tiny_cfg() -> ClipTextConfig {
        ClipTextConfig {
            hidden_size: 8,
            intermediate_size: 16,
            num_attention_heads: 2,
            num_hidden_layers: 1,
            max_position_embeddings: 6,
            vocab_size: 32,
            layer_norm_eps: 1e-5,
        }
    }

    /// `[1, rows, 8]` tensor with every element set to `fill`.
    fn filled_input(rows: usize, fill: f32) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(vec![fill; rows * HIDDEN]),
            vec![1, rows, HIDDEN],
            false,
        )
        .unwrap()
    }

    /// Asserts that every element of `t` is finite.
    fn assert_all_finite(t: &Tensor<f32>) {
        for &v in t.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    /// Asserts that `a` and `b` agree element-wise within `1e-5`.
    fn assert_close(a: &Tensor<f32>, b: &Tensor<f32>) {
        for (x, y) in a.data().unwrap().iter().zip(b.data().unwrap().iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }

    /// Asserts that every expected key appears in `names`.
    fn assert_has_keys(names: &[String], expected: &[&str]) {
        for k in expected {
            assert!(
                names.iter().any(|n| n == k),
                "missing parameter key {k:?} in {names:?}"
            );
        }
    }

    #[test]
    fn sd_v1_5_config_is_canonical() {
        let preset = ClipTextConfig::sd_v1_5();
        assert_eq!(preset.hidden_size, 768);
        assert_eq!(preset.intermediate_size, 3072);
        assert_eq!(preset.num_attention_heads, 12);
        assert_eq!(preset.num_hidden_layers, 12);
        assert_eq!(preset.max_position_embeddings, 77);
        assert_eq!(preset.vocab_size, 49408);
        assert_eq!(preset.head_dim(), 64);
        preset.validate().unwrap();
    }

    #[test]
    fn validate_catches_bad_head_count() {
        // 8 is not divisible by 3.
        let cfg = ClipTextConfig {
            num_attention_heads: 3,
            ..tiny_cfg()
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn from_json_str_round_trip() {
        let json = r#"{
 "hidden_size": 768,
 "intermediate_size": 3072,
 "num_attention_heads": 12,
 "num_hidden_layers": 12,
 "max_position_embeddings": 77,
 "vocab_size": 49408,
 "layer_norm_eps": 1e-5,
 "hidden_act": "quick_gelu"
 }"#;
        let parsed = ClipTextConfig::from_json_str(json).unwrap();
        assert_eq!(parsed.hidden_size, 768);
        assert_eq!(parsed.intermediate_size, 3072);
        assert_eq!(parsed.num_attention_heads, 12);
        assert_eq!(parsed.num_hidden_layers, 12);
        assert_eq!(parsed.max_position_embeddings, 77);
    }

    #[test]
    fn embeddings_forward_shape() {
        let embeddings = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        let out = embeddings.forward_from_ids(&[1u32, 5, 7, 9]).unwrap();
        assert_eq!(out.shape(), &[1, 4, 8]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite(), "embedding non-finite: {v}");
        }
    }

    #[test]
    fn embeddings_reject_too_long_sequence() {
        let embeddings = ClipTextEmbeddings::<f32>::new(&tiny_cfg()).unwrap();
        // Seven ids against a table of six positions.
        let too_long: Vec<u32> = (0..7).collect();
        assert!(embeddings.forward_from_ids(&too_long).is_err());
    }

    #[test]
    fn self_attention_forward_shape() {
        let attention = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();
        let out = attention.forward(&filled_input(5, 0.1)).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        assert_all_finite(&out);
    }

    #[test]
    fn self_attention_is_actually_causal() {
        let attention = ClipSelfAttention::<f32>::new(&tiny_cfg()).unwrap();

        // Two inputs that share their first two rows and differ afterwards.
        let shared = 2 * HIDDEN;
        let mut base = vec![0.1f32; 4 * HIDDEN];
        for (i, slot) in base.iter_mut().enumerate().take(shared) {
            *slot = ((i + 1) as f32).sin();
        }
        let mut perturbed = base.clone();
        for (i, slot) in perturbed.iter_mut().enumerate().skip(shared) {
            *slot = ((i + 11) as f32).sin();
        }

        let input_a =
            Tensor::from_storage(TensorStorage::cpu(base), vec![1, 4, 8], false).unwrap();
        let input_b =
            Tensor::from_storage(TensorStorage::cpu(perturbed), vec![1, 4, 8], false).unwrap();
        let out_a = attention.forward(&input_a).unwrap();
        let out_b = attention.forward(&input_b).unwrap();
        let da = out_a.data().unwrap();
        let db = out_b.data().unwrap();

        // The shared rows must not see the rows that come after them.
        for (i, (x, y)) in da[..shared].iter().zip(db[..shared].iter()).enumerate() {
            assert!(
                (x - y).abs() < 1e-5,
                "row {} ({}) differs between runs: {} vs {}",
                i / 8,
                i % 8,
                x,
                y
            );
        }
    }

    #[test]
    fn mlp_uses_quick_gelu() {
        let mlp = ClipMlp::<f32>::new(&tiny_cfg()).unwrap();
        let out = mlp.forward(&filled_input(3, 0.0)).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        assert_all_finite(&out);
    }

    #[test]
    fn encoder_layer_forward_shape() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let out = layer.forward(&filled_input(5, 0.1)).unwrap();
        assert_eq!(out.shape(), &[1, 5, 8]);
        assert_all_finite(&out);
    }

    #[test]
    fn encoder_layer_named_parameters_use_hf_layout() {
        let layer = ClipEncoderLayer::<f32>::new(&tiny_cfg()).unwrap();
        let names: Vec<String> = layer
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_has_keys(
            &names,
            &[
                "layer_norm1.weight",
                "layer_norm1.bias",
                "self_attn.q_proj.weight",
                "self_attn.q_proj.bias",
                "self_attn.k_proj.weight",
                "self_attn.v_proj.weight",
                "self_attn.out_proj.weight",
                "self_attn.out_proj.bias",
                "layer_norm2.weight",
                "mlp.fc1.weight",
                "mlp.fc1.bias",
                "mlp.fc2.weight",
                "mlp.fc2.bias",
            ],
        );
    }

    #[test]
    fn tiny_encoder_forward_from_ids_shape() {
        let encoder = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let out = encoder.forward_from_ids(&[1u32, 5, 7]).unwrap();
        assert_eq!(out.shape(), &[1, 3, 8]);
        assert_all_finite(&out);
    }

    #[test]
    fn tiny_named_parameters_use_hf_layout() {
        let encoder = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let names: Vec<String> = encoder
            .named_parameters()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_has_keys(
            &names,
            &[
                "embeddings.token_embedding.weight",
                "embeddings.position_embedding.weight",
                "encoder.layers.0.layer_norm1.weight",
                "encoder.layers.0.self_attn.q_proj.weight",
                "encoder.layers.0.self_attn.out_proj.bias",
                "encoder.layers.0.layer_norm2.bias",
                "encoder.layers.0.mlp.fc1.weight",
                "encoder.layers.0.mlp.fc2.bias",
                "final_layer_norm.weight",
                "final_layer_norm.bias",
            ],
        );
    }

    #[test]
    fn round_trip_state_dict() {
        let source = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut target = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        target.load_state_dict(&source.state_dict(), true).unwrap();

        let ids = [2u32, 4, 6];
        let expected = source.forward_from_ids(&ids).unwrap();
        let actual = target.forward_from_ids(&ids).unwrap();
        let pairs = expected
            .data()
            .unwrap()
            .iter()
            .zip(actual.data().unwrap().iter());
        for (x, y) in pairs {
            assert!((x - y).abs() < 1e-5, "round-trip differs: {x} vs {y}");
        }
    }

    #[test]
    fn load_hf_state_dict_strips_text_model_prefix() {
        let source = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();

        // Re-key the state dict with a `text_model.` prefix and add a
        // `position_ids` entry that the loader is expected to drop.
        let mut prefixed: StateDict<f32> = source
            .state_dict()
            .into_iter()
            .map(|(key, tensor)| (format!("text_model.{key}"), tensor))
            .collect();
        prefixed.insert(
            "text_model.embeddings.position_ids".into(),
            ferrotorch_core::zeros::<f32>(&[1, 6]).unwrap(),
        );

        let mut target = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let report = target.load_hf_state_dict(&prefixed, false).unwrap();
        assert_eq!(
            report.dropped,
            vec!["text_model.embeddings.position_ids".to_string()]
        );

        let ids = [1u32, 2, 3];
        let expected = source.forward_from_ids(&ids).unwrap();
        let actual = target.forward_from_ids(&ids).unwrap();
        assert_close(&expected, &actual);
    }

    #[test]
    fn load_hf_state_dict_strict_rejects_unknown_key() {
        let mut target = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let mut state: StateDict<f32> = HashMap::new();
        state.insert(
            "mystery.key".into(),
            ferrotorch_core::zeros::<f32>(&[1]).unwrap(),
        );
        assert!(target.load_hf_state_dict(&state, true).is_err());
    }

    #[test]
    fn forward_from_id_tensor_matches_forward_from_ids() {
        let encoder = ClipTextEncoder::<f32>::new(tiny_cfg()).unwrap();
        let ids = [1u32, 5, 7];
        let as_tensor = float_index_tensor::<f32>(&ids).unwrap();
        let from_slice = encoder.forward_from_ids(&ids).unwrap();
        let from_tensor = encoder.forward_from_id_tensor(&as_tensor).unwrap();
        assert_close(&from_slice, &from_tensor);
    }
}