candle_crf/
lib.rs

use candle_core::{shape::Dim, DType, Device, Error, IndexOp, Result, Tensor, D};
use candle_nn::{Init, VarBuilder};
use std::fmt::Display;

/// How the per-sequence log-likelihoods returned by [`CRF::forward`] are reduced.
#[derive(Debug)]
pub enum Reduction {
    /// No reduction: return one log-likelihood per batch element.
    None,
    /// Sum the log-likelihoods over the batch (the default).
    Sum,
    /// Average the log-likelihoods over the batch.
    Mean,
    /// Sum over the batch, divided by the number of unmasked tokens.
    TokenMean,
}

impl Default for Reduction {
    fn default() -> Self {
        Reduction::Sum
    }
}

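/// Builds a [`CRF`] whose parameters are created through (and can be loaded from) a
/// `VarBuilder`. A minimal sketch, assuming the crate is used as `candle_crf` and the
/// weights live in a freshly created `VarMap`:
///
/// ```no_run
/// use candle_core::{DType, Device};
/// use candle_nn::{VarBuilder, VarMap};
///
/// let device = Device::Cpu;
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
/// let model = candle_crf::crf(5, false, vb).unwrap();
/// ```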
pub fn crf(num_tags: usize, batch_first: bool, vb: VarBuilder) -> Result<CRF> {
    let start_transitions = vb.get_with_hints(
        num_tags,
        "start_transitions",
        Init::Uniform {
            lo: -0.1_f64,
            up: 1.0_f64,
        },
    )?;

    let end_transitions = vb.get_with_hints(
        num_tags,
        "end_transitions",
        Init::Uniform {
            lo: -0.1_f64,
            up: 1.0_f64,
        },
    )?;

    let transitions = vb.get_with_hints(
        (num_tags, num_tags),
        "transitions",
        Init::Uniform {
            lo: -0.1_f64,
            up: 1.0_f64,
        },
    )?;

    Ok(CRF {
        num_tags,
        batch_first,
        start_transitions,
        end_transitions,
        transitions,
    })
}

/// CRF
/// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L9
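///
/// A minimal usage sketch (the device, dtype and toy values below are illustrative only,
/// and the crate is assumed to be used as `candle_crf`):
///
/// ```no_run
/// use candle_core::{Device, Tensor};
/// use candle_crf::{CRF, Reduction};
///
/// let device = Device::Cpu;
/// // 3 tags, (seq_len, batch) layout since batch_first = false.
/// let crf = CRF::new(3, false, &device).unwrap();
///
/// // emissions: (seq_len = 2, batch = 1, num_tags = 3)
/// let emissions = Tensor::new(&[[[0.1_f32, 0.2, 0.3]], [[0.3, 0.1, 0.2]]], &device).unwrap();
/// // tags are i64 (u32 when the `metal` feature is enabled)
/// let tags = Tensor::new(&[[1_i64], [2]], &device).unwrap();
///
/// let llh = crf.forward(&emissions, &tags, None, Reduction::Sum).unwrap();
/// let best_paths = crf.decode(&emissions, None).unwrap();
/// ```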
pub struct CRF {
    pub(crate) num_tags: usize,
    pub(crate) batch_first: bool,

    pub(crate) start_transitions: Tensor,
    pub(crate) end_transitions: Tensor,
    pub(crate) transitions: Tensor,
}

impl Display for CRF {
    /// Display
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L60
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CRF(num_tags: {}, batch_first: {})",
            self.num_tags, self.batch_first
        )
    }
}

impl CRF {
    /// Create a new CRF
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L38
    pub fn new(num_tags: usize, batch_first: bool, device: &Device) -> Result<Self> {
        Self::new_with_dtype(num_tags, batch_first, DType::F32, device)
    }

    pub fn new_with_dtype(
        num_tags: usize,
        batch_first: bool,
        dtype: DType,
        device: &Device,
    ) -> Result<Self> {
        {
            use DType::*;
            match dtype {
                #[cfg(feature = "metal")]
                F32 => {}
                #[cfg(feature = "cuda")]
                F32 | F64 => {}
                #[cfg(not(any(feature = "cuda", feature = "metal")))]
                BF16 | F16 | F32 | F64 => {}
                _ => return Err(Error::UnsupportedDTypeForOp(dtype, "unsupported dtype")),
            }
        }

        if num_tags == 0 {
            return Err(Error::Msg("num_tags must be greater than 0".to_string()));
        }

        let start_transitions = Tensor::rand(-0.1_f32, 1.0, num_tags, device)?.to_dtype(dtype)?;
        let end_transitions = Tensor::rand(-0.1_f32, 1.0, num_tags, device)?.to_dtype(dtype)?;
        let transitions =
            Tensor::rand(-0.1_f32, 1.0, (num_tags, num_tags), device)?.to_dtype(dtype)?;

        Ok(Self {
            num_tags,
            batch_first,
            start_transitions,
            end_transitions,
            transitions,
        })
    }

    pub fn load(num_tags: usize, batch_first: bool, vb: VarBuilder) -> Result<Self> {
        crf(num_tags, batch_first, vb)
    }

    /// validate
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L142
    fn validate(
        &self,
        emissions: &Tensor,
        tags: Option<&Tensor>,
        mask: Option<&Tensor>,
    ) -> Result<()> {
        {
            let dtype_transitions = self.transitions.dtype();
            let dtype_emissions = emissions.dtype();
            if dtype_transitions != dtype_emissions {
                return Err(Error::Msg(format!(
                    "emissions and CRF must have the same dtype, expected {:?}, got {:?}",
                    dtype_transitions, dtype_emissions
                )));
            }
        }

        {
            // check if the tensor has 3 dimensions
            let dims = emissions.dims().len();
            if dims != 3 {
                return Err(Error::Msg(format!(
                    "emissions must have 3 dimensions, got {}",
                    dims
                )));
            }
        }

        let (d1, d2, d3) = emissions.dims3()?;

        if d3 != self.num_tags {
            // check if the last dimension of the tensor is equal to the number of tags
            return Err(Error::Msg(format!(
                "expected last dimension of emissions is {}, got {}",
                self.num_tags, d3
            )));
        }

        if let Some(tags) = tags {
            #[cfg(feature = "metal")]
            if tags.dtype() != DType::U32 {
                return Err(Error::Msg("tags must be of type u32".to_string()));
            }

            #[cfg(not(feature = "metal"))]
            if tags.dtype() != DType::I64 {
                return Err(Error::Msg("tags must be of type i64".to_string()));
            }

            if tags.dims().len() != 2 {
                // check if the tensor has 2 dimensions
                return Err(Error::Msg(format!(
                    "tags must have 2 dimensions, got {}",
                    tags.dims().len()
                )));
            }

            let (tag_d1, tag_d2) = tags.dims2()?;
            if (d1, d2) != (tag_d1, tag_d2) {
                return Err(Error::Msg(format!(
                    "the first two dimensions of emissions and tags must match, got ({}, {}) and ({}, {})",
                    d1, d2, tag_d1, tag_d2
                )));
            }
        }

        if let Some(mask) = mask {
            if mask.dtype() != DType::U8 {
                return Err(Error::Msg("mask must be of type u8".to_string()));
            }

            if mask.dims().len() != 2 {
                // check if the tensor has 2 dimensions
                return Err(Error::Msg(format!(
                    "mask must have 2 dimensions, got {}",
                    mask.dims().len()
                )));
            }

            let (mask_d1, mask_d2) = mask.dims2()?;
            if (d1, d2) != (mask_d1, mask_d2) {
                return Err(Error::Msg(format!(
                    "the first two dimensions of emissions and mask must match, got ({}, {}) and ({}, {})",
                    d1, d2, mask_d1, mask_d2
                )));
            }

            let no_empty_seq = !self.batch_first && all(&mask.i(0)?)?;
            let no_empty_seq_bf = self.batch_first && all(&mask.i((.., 0))?)?;

            if !no_empty_seq && !no_empty_seq_bf {
                return Err(Error::Msg(
                    "mask of the first timestep must all be on".to_string(),
                ));
            }
        }

        Ok(())
    }

    /// compute_score
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L172
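    ///
    /// Numerator of the log-likelihood for a given tag sequence `y` (a sketch of what
    /// the loop below computes):
    ///
    ///   score(x, y) = start[y_0] + emit_0[y_0]
    ///               + sum_{t=1..T-1} mask_t * (trans[y_{t-1}, y_t] + emit_t[y_t])
    ///               + end[y_last]
    ///
    /// where `y_last` is the tag at each sequence's last unmasked timestep.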
    fn compute_score(&self, emissions: &Tensor, tags: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let (d1, d2, d3) = emissions.dims3()?;
        let (seq_length, batch_size) = tags.dims2()?;
        assert_eq!(d1, seq_length);
        assert_eq!(d2, batch_size);
        assert_eq!(d3, self.num_tags);
        assert_eq!(mask.shape(), tags.shape());
        assert!(all(&mask.i(0)?)?);

        let mask = mask.to_dtype(emissions.dtype())?;

        let mut score = self.start_transitions.i(&tags.i(0)?)?;

        let z = gather(&emissions.i((0, 0..batch_size))?, &tags.i(0)?)?;

        score = score.broadcast_add(&z)?;

        for i in 1..seq_length {
            let z = gather(&self.transitions.i(&tags.i(i - 1)?)?, &tags.i(i)?)?;
            score = score.broadcast_add(&z.broadcast_mul(&mask.i(i)?)?)?;

            let z = gather(&emissions.i((i, 0..batch_size))?, &tags.i(i)?)?;
            score = score.broadcast_add(&z.broadcast_mul(&mask.i(i)?)?)?;
        }

        let seq_ends = mask
            .to_dtype(DType::I64)?
            .sum(0)?
            .broadcast_sub(&Tensor::ones(1, DType::I64, mask.device())?)?;

        #[cfg(feature = "metal")]
        let last_tags = {
            let tags2 = tags.to_dtype(DType::F32)?;
            gather(
                &tags2.i(&seq_ends)?,
                &Tensor::arange(0, batch_size as u32, mask.device())?,
            )?
        };

        #[cfg(not(feature = "metal"))]
        let last_tags = gather(
            &tags.i(&seq_ends)?,
            &Tensor::arange(0, batch_size as i64, mask.device())?,
        )?;

        #[cfg(feature = "metal")]
        let last_tags = last_tags.to_dtype(DType::U32)?;

        score.broadcast_add(&self.end_transitions.i(&last_tags)?)
    }

    /// compute_normalizer
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L211
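    ///
    /// Log-partition function computed with the forward algorithm in log space (a
    /// sketch of the recurrence implemented by the loop below):
    ///
    ///   alpha_0[j] = start[j] + emit_0[j]
    ///   alpha_t[j] = logsumexp_i(alpha_{t-1}[i] + trans[i, j] + emit_t[j])   if mask_t = 1
    ///   alpha_t[j] = alpha_{t-1}[j]                                          if mask_t = 0
    ///   normalizer = logsumexp_j(alpha_{T-1}[j] + end[j])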
    fn compute_normalizer(&self, emissions: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let (d1, d2, d3) = emissions.dims3()?;
        let (seq_length, batch_size) = mask.dims2()?;
        assert_eq!(d1, seq_length);
        assert_eq!(d2, batch_size);
        assert_eq!(d3, self.num_tags);
        assert!(all(&mask.i(0)?)?);

        let mut score = self.start_transitions.broadcast_add(&emissions.i(0)?)?;

        for i in 1..seq_length {
            let broadcast_score = score.unsqueeze(2)?;

            let broadcast_emissions = emissions.i(i)?.unsqueeze(1)?;
            let next_score = broadcast_score
                .broadcast_add(&self.transitions)?
                .broadcast_add(&broadcast_emissions)?;

            let next_score = next_score.log_sum_exp(1)?;
            let z = mask.i(i)?.unsqueeze(1)?.broadcast_as(next_score.shape())?;
            score = z.where_cond(&next_score, &score)?;
        }

        score = score.broadcast_add(&self.end_transitions)?;
        score.log_sum_exp(1)
    }

    /// viterbi_decode
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L262
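    ///
    /// Same recurrence as `compute_normalizer`, but with `max`/`argmax` in place of
    /// `logsumexp`; the argmax indices are stacked in `history` and walked backwards from
    /// each sequence's last unmasked timestep to recover the best tag path.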
    fn viterbi_decode(&self, emissions: &Tensor, mask: &Tensor) -> Result<Vec<Vec<u32>>> {
        let (d1, d2, d3) = emissions.dims3()?;
        let (seq_length, batch_size) = mask.dims2()?;
        assert_eq!(d1, seq_length);
        assert_eq!(d2, batch_size);
        assert_eq!(d3, self.num_tags);
        assert!(all(&mask.i(0)?)?);

        let mut score = self.start_transitions.broadcast_add(&emissions.i(0)?)?;

        let mut history = Vec::with_capacity(seq_length);
        for i in 1..seq_length {
            let broadcast_score = score.unsqueeze(2)?;

            let broadcast_emission = emissions.i(i)?.unsqueeze(1)?;

            let next_score = broadcast_score
                .broadcast_add(&self.transitions)?
                .broadcast_add(&broadcast_emission)?;

            let (next_score, indices) = max_indices(&next_score, 1)?;

            let z = mask.i(i)?.unsqueeze(1)?.broadcast_as(next_score.shape())?;
            score = z.where_cond(&next_score, &score)?;
            history.push(indices);
        }

        score = score.broadcast_add(&self.end_transitions)?;

        let seq_ends = mask
            .to_dtype(DType::I64)?
            .sum(0)?
            .broadcast_sub(&Tensor::ones(1, DType::I64, mask.device())?)?;

        let mut best_tags_list = vec![];

        for idx in 0..batch_size {
            let best_last_tag = score.i(idx)?.argmax(0)?;

            let mut best_tags = vec![best_last_tag.to_scalar::<u32>()?];

            let z = seq_ends.i(idx)?.to_scalar::<i64>()? as usize;
            let mut a = history[..z].to_vec();
            a.reverse();
            for hist in a.iter() {
                let last_idx = *best_tags.last().unwrap() as usize;
                let best_last_tag = hist.i(idx)?.i(last_idx)?;
                best_tags.push(best_last_tag.to_scalar::<u32>()?);
            }

            best_tags.reverse();
            best_tags_list.push(best_tags);
        }

        Ok(best_tags_list)
    }

    /// decode
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L118
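    ///
    /// Returns the most likely tag sequence for each batch element. A sketch with
    /// illustrative shapes and values (here `batch_first = true`, so `emissions` is
    /// `(batch, seq_len, num_tags)`):
    ///
    /// ```no_run
    /// # use candle_core::{Device, Tensor};
    /// # use candle_crf::CRF;
    /// # let device = Device::Cpu;
    /// let crf = CRF::new(3, true, &device).unwrap();
    /// let emissions = Tensor::new(&[[[0.1_f32, 0.2, 0.3], [0.3, 0.1, 0.2]]], &device).unwrap();
    /// // One sequence in the batch; only its first timestep is a real token.
    /// let mask = Tensor::new(&[[1_u8, 0]], &device).unwrap();
    /// let best: Vec<Vec<u32>> = crf.decode(&emissions, Some(&mask)).unwrap();
    /// ```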
    pub fn decode(&self, emissions: &Tensor, mask: Option<&Tensor>) -> Result<Vec<Vec<u32>>> {
        self.validate(emissions, None, mask)?;
        let mask = if let Some(mask) = mask {
            mask.clone()
        } else {
            let (d1, d2, _) = emissions.dims3()?;
            Tensor::ones((d1, d2), DType::U8, emissions.device())?
        };

        let (emissions, mask) = if self.batch_first {
            (emissions.transpose(0, 1)?, mask.transpose(0, 1)?)
        } else {
            (emissions.clone(), mask.clone())
        };
        self.viterbi_decode(&emissions, &mask)
    }

    /// Forward
    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/torchcrf/__init__.py#L63
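    ///
    /// Computes the log-likelihood of `tags` given `emissions`. The optional mask is a
    /// `u8` tensor with the same layout as `tags`; `1` marks a real token and the first
    /// timestep of every sequence must be unmasked. A sketch with illustrative values
    /// (`batch_first = false`, so everything is `(seq_len, batch, ...)`):
    ///
    /// ```no_run
    /// # use candle_core::{Device, Tensor};
    /// # use candle_crf::{CRF, Reduction};
    /// # let device = Device::Cpu;
    /// # let crf = CRF::new(3, false, &device).unwrap();
    /// # let emissions = Tensor::new(&[[[0.1_f32, 0.2, 0.3]], [[0.3, 0.1, 0.2]]], &device).unwrap();
    /// # let tags = Tensor::new(&[[1_i64], [2]], &device).unwrap();
    /// // The second timestep of the single sequence is padding.
    /// let mask = Tensor::new(&[[1_u8], [0]], &device).unwrap();
    /// let llh = crf.forward(&emissions, &tags, Some(&mask), Reduction::TokenMean).unwrap();
    /// ```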
    pub fn forward(
        &self,
        emissions: &Tensor,
        tags: &Tensor,
        mask: Option<&Tensor>,
        reduction: Reduction,
    ) -> Result<Tensor> {
        self.validate(emissions, Some(tags), mask)?;
        let mask = if let Some(mask) = mask {
            mask.clone()
        } else {
            Tensor::ones_like(tags)?.to_dtype(DType::U8)?
        };

        let (emissions, tags, mask) = if self.batch_first {
            (
                emissions.transpose(0, 1)?,
                tags.transpose(0, 1)?,
                mask.transpose(0, 1)?,
            )
        } else {
            (emissions.clone(), tags.clone(), mask.clone())
        };

        let numerator = self.compute_score(&emissions, &tags, &mask)?;
        let denominator = self.compute_normalizer(&emissions, &mask)?;

        let llh = numerator.broadcast_sub(&denominator)?;

        match reduction {
            Reduction::Sum => llh.sum_all(),
            Reduction::Mean => llh.mean_all(),
            Reduction::TokenMean => {
                let mask = mask.to_dtype(llh.dtype())?;
                let z = mask.sum_all()?;
                llh.sum_all()?.broadcast_div(&z)
            }
            Reduction::None => Ok(llh),
        }
    }
}

// -----------------------------------------------------------------------------

/// Returns `true` if every element of `x` is non-zero (like PyTorch's `Tensor.all()`).
pub(crate) fn all(x: &Tensor) -> Result<bool> {
    let zero = x.zeros_like()?;
    Ok(x.broadcast_ne(&zero)?.min_all()?.to_scalar::<u8>()? != 0)
}

// -----------------------------------------------------------------------------

/// Per-row gather: picks `src[r, idx[r]]` for every row `r` (gather along the last dim).
pub(crate) fn gather(src: &Tensor, idx: &Tensor) -> Result<Tensor> {
    let index = idx.reshape((idx.dim(0)?, 1))?;
    src.gather(&index, D::Minus1)?.squeeze(D::Minus1)
}

// -----------------------------------------------------------------------------

/// Returns `(max, argmax)` of `x` along `dim`.
pub(crate) fn max_indices<D: Dim + Copy>(x: &Tensor, dim: D) -> Result<(Tensor, Tensor)> {
    let max = x.max(dim)?;
    let idx = x.argmax(dim)?;
    Ok((max, idx))
}

// -----------------------------------------------------------------------------

#[cfg(test)]
mod tests {

    /*
    The following tests correspond to this PyTorch script:
    https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L12

    RANDOM_SEED = 1478754

    random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
     */

    use super::*;
    use anyhow::Result;
    use candle_core::{utils, DType, Device, IndexOp, Tensor};
    use itertools::Itertools;

    #[cfg(feature = "metal")]
    const OK_TYPES: [DType; 1] = [DType::F32];
    #[cfg(feature = "cuda")]
    const OK_TYPES: [DType; 2] = [DType::F32, DType::F64];
    #[cfg(not(any(feature = "cuda", feature = "metal")))]
    const OK_TYPES: [DType; 4] = [DType::F32, DType::F64, DType::F16, DType::BF16];

    #[cfg(feature = "metal")]
    const FAIL_TYPES: [DType; 6] = [
        DType::U8,
        DType::U32,
        DType::I64,
        DType::F16,
        DType::BF16,
        DType::F64,
    ];
    #[cfg(feature = "cuda")]
    const FAIL_TYPES: [DType; 5] = [DType::U8, DType::U32, DType::I64, DType::F16, DType::BF16];
    #[cfg(not(any(feature = "cuda", feature = "metal")))]
    const FAIL_TYPES: [DType; 3] = [DType::U8, DType::U32, DType::I64];

    enum DTypeCase {
        DType(DType),
        PyTorchCRF,
    }

    #[cfg(feature = "metal")]
    fn epsilon(type_case: DTypeCase) -> f32 {
        match type_case {
            DTypeCase::DType(dtype) => match dtype {
                DType::F64 => 1e-6,
                DType::F32 => 1e-4,
                DType::F16 => 1e-2,
                DType::BF16 => 1e-1,
                _ => panic!("dtype not supported"),
            },
            DTypeCase::PyTorchCRF => 1e-3,
        }
    }

    #[cfg(not(feature = "metal"))]
    fn epsilon(type_case: DTypeCase) -> f64 {
        match type_case {
            DTypeCase::DType(dtype) => match dtype {
                DType::F64 => 1e-6,
                DType::F32 => 1e-4,
                DType::F16 => 1e-2,
                DType::BF16 => 1e-1,
                _ => panic!("dtype not supported"),
            },
            DTypeCase::PyTorchCRF => 1e-3,
        }
    }

    #[cfg(feature = "metal")]
    fn assert_tensor_close(a: &Tensor, b: &Tensor, epsilon: f32) -> Result<()> {
        assert!(a.dtype() == b.dtype());
        assert_eq!(a.shape(), b.shape());
        let epsilon = Tensor::full(epsilon, a.shape(), a.device())?.to_dtype(a.dtype())?;
        let diff = a.broadcast_sub(b)?.abs()?;
        let result = all(&diff.broadcast_le(&epsilon)?)?;
        assert!(result);
        Ok(())
    }

    #[cfg(not(feature = "metal"))]
    fn assert_tensor_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
        assert!(a.dtype() == b.dtype());
        assert_eq!(a.shape(), b.shape());
        let epsilon = Tensor::full(epsilon, a.shape(), a.device())?.to_dtype(a.dtype())?;
        let diff = a.broadcast_sub(b)?.abs()?;
        let result = all(&diff.broadcast_le(&epsilon)?)?;
        assert!(result);
        Ok(())
    }

    #[cfg(feature = "metal")]
    fn cartestian_product(r: Vec<u32>, repeat: usize, dev: &Device) -> Result<Vec<Tensor>> {
        use itertools::Itertools;

        if repeat <= 1 {
            return Ok(vec![Tensor::new(r.as_slice(), dev)?]);
        }

        let mut a: Vec<Vec<u32>> = r
            .iter()
            .cartesian_product(r.iter())
            .map(|(&x, &y)| vec![x, y])
            .collect();
        for _ in 2..repeat {
            a = a
                .iter()
                .cartesian_product(r.iter())
                .map(|(x, &y)| {
                    let mut z = Vec::from(x.to_owned());
                    z.push(y);
                    z
                })
                .collect();
        }
        Ok(a.iter()
            .map(|x| Tensor::new(x.as_slice(), dev).unwrap())
            .collect())
    }

    #[cfg(not(feature = "metal"))]
    fn cartestian_product(r: Vec<i64>, repeat: usize, dev: &Device) -> Result<Vec<Tensor>> {
        use itertools::Itertools;

        if repeat <= 1 {
            return Ok(vec![Tensor::new(r.as_slice(), dev)?]);
        }

        let mut a: Vec<Vec<i64>> = r
            .iter()
            .cartesian_product(r.iter())
            .map(|(&x, &y)| vec![x, y])
            .collect();
        for _ in 2..repeat {
            a = a
                .iter()
                .cartesian_product(r.iter())
                .map(|(x, &y)| {
                    let mut z = Vec::from(x.to_owned());
                    z.push(y);
                    z
                })
                .collect();
        }
        Ok(a.iter()
            .map(|x| Tensor::new(x.as_slice(), dev).unwrap())
            .collect())
    }

    fn cat_scalar_tensor(tensors: Vec<Tensor>) -> candle_core::Result<Tensor> {
        let tensors: Vec<Tensor> = tensors
            .into_iter()
            .map(|t| t.unsqueeze(0).unwrap())
            .collect();
        Tensor::cat(&tensors, 0)
    }

    fn use_gpu(gpu: bool) -> candle_core::Result<Device> {
        if gpu {
            if utils::cuda_is_available() {
                println!("CUDA is available");
                Device::new_cuda(0)
            } else if utils::metal_is_available() {
                println!("Metal is available");
                Device::new_metal(0)
            } else {
                println!("CUDA and Metal are not available, using CPU");
                Ok(Device::Cpu)
            }
        } else {
            println!("Using CPU");
            Ok(Device::Cpu)
        }
    }

    #[test]
    fn test_cat_scalar_tensor() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        let mut lst = vec![];
        for i in 0..10 {
            let x = Tensor::full(i as f32, (), &device)?;
            lst.push(x);
        }

        let result = cat_scalar_tensor(lst)?;
        assert_eq!(
            result.to_vec1::<f32>().unwrap(),
            vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]
        );
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L37
    fn make_crf(
        num_tags: usize,
        batch_first: bool,
        start: Option<Tensor>,
        end: Option<Tensor>,
        transition: Option<Tensor>,
        device: &Device,
    ) -> candle_core::Result<CRF> {
        let mut crf = CRF::new(num_tags, batch_first, device)?;
        if let Some(start) = start {
            crf.start_transitions = start
        }
        if let Some(end) = end {
            crf.end_transitions = end
        }
        if let Some(transition) = transition {
            crf.transitions = transition
        }
        Ok(crf)
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L23
    fn assert_all<D: candle_core::WithDType>(x: &Tensor, lo: D, up: D) -> Result<bool> {
        assert_eq!(x.dims().len(), 1);
        let dim = x.dims1()?;
        for i in 0..dim {
            let a = x.i(i)?.to_scalar::<D>()?;
            if a < lo || a > up {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L18
    fn compute_score(crf: &CRF, emission: &Tensor, tag: &Tensor) -> Result<Tensor> {
        assert_eq!(emission.dims().len(), 2);
        let (emission_dim1, emission_dim2) = emission.dims2()?;
        let tag_dim1 = tag.dims1()?;
        assert_eq!(emission_dim1, tag_dim1);
        assert_eq!(emission_dim2, crf.num_tags);

        #[cfg(feature = "metal")]
        assert!(assert_all(tag, 0_u32, crf.num_tags as u32 - 1)?);
        #[cfg(not(feature = "metal"))]
        assert!(assert_all(tag, 0_i64, crf.num_tags as i64 - 1)?);

        #[cfg(feature = "metal")]
        let tag_vec = tag.to_vec1::<u32>()?;
        #[cfg(not(feature = "metal"))]
        let tag_vec = tag.to_vec1::<i64>()?;

        let mut score = crf
            .start_transitions
            .i(tag_vec[0] as usize)?
            .broadcast_add(&crf.end_transitions.i(tag_vec[tag_vec.len() - 1] as usize)?)?;

        for (cur_tag, next_tag) in tag_vec.iter().zip(tag_vec.iter().skip(1)) {
            let z = crf.transitions.i((*cur_tag as usize, *next_tag as usize))?;
            score = score.broadcast_add(&z)?;
        }

        for (i, &t) in tag_vec.iter().enumerate() {
            let z = emission.i((i, t as usize))?;
            score = score.broadcast_add(&z)?;
        }

        Ok(score)
    }

    #[test]
    fn test_init_with_dtype() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        for dtype in OK_TYPES {
            assert!(CRF::new_with_dtype(10, false, dtype, &device).is_ok());
        }

        for dtype in FAIL_TYPES {
            assert!(CRF::new_with_dtype(10, false, dtype, &device).is_err());
        }
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L60
    #[test]
    fn test_init_minimal() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        let num_tags = 10;
        let crf = CRF::new(num_tags, false, &device)?;
        assert_eq!(crf.num_tags, num_tags);
        assert!(!crf.batch_first);
        assert_eq!(crf.start_transitions.dims1()?, num_tags);
        assert_eq!(crf.end_transitions.dims1()?, num_tags);
        assert_eq!(crf.transitions.dims2()?, (num_tags, num_tags));
        println!("crf: {}", crf);
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L74
    #[test]
    fn test_init_full() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        let crf = CRF::new(10, true, &device)?;
        assert!(crf.batch_first);
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L78C9-L78C34
    #[test]
    fn test_init_nonpositive_num_tags() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        let crf = CRF::new(0, false, &device);
        assert!(crf.is_err());

        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L85
    fn forward_works_with_mask(dtype: DType, device: &Device) -> Result<()> {
        let crf = make_crf(
            5,
            false,
            Some(
                Tensor::new(&[-0.0687_f32, 0.0698, -0.0447, 0.0421, 0.0782], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(&[0.0061_f32, -0.0671, -0.0797, 0.0629, -0.0136], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(
                    &[
                        [0.0489_f32, -0.0002, 0.0619, 0.0458, 0.0662],
                        [0.0707, 0.0297, -0.0422, 0.0831, -0.0038],
                        [0.0439, 0.0178, -0.0754, 0.0260, 0.0681],
                        [0.0191, 0.0755, 0.0230, 0.0209, -0.0768],
                        [0.0303, 0.0592, -0.0297, 0.0681, 0.0801],
                    ],
                    device,
                )?
                .to_dtype(dtype)?,
            ),
            device,
        )?;

        let emissions = Tensor::new(
            &[
                [
                    [1.1699_f32, 1.1900, -0.7254, 0.1490, -1.4910],
                    [-1.2101, 0.4538, 1.3654, 0.0135, -1.8480],
                ],
                [
                    [0.5861, -0.1651, 0.9721, 0.4464, -0.5512],
                    [-1.2701, -1.5360, 0.0037, 0.5853, -0.9926],
                ],
                [
                    [-1.7625, 0.5437, 1.6322, -1.1274, -0.1313],
                    [-0.9301, 0.8906, -2.6483, 0.5849, -1.1069],
                ],
            ],
            device,
        )?
        .to_dtype(dtype)?;

        #[cfg(feature = "metal")]
        let tags = Tensor::new(&[[2_u32, 4], [3, 3], [4, 2]], device)?;
        #[cfg(not(feature = "metal"))]
        let tags = Tensor::new(&[[2_i64, 4], [3, 3], [4, 2]], device)?;
        let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
        let llh = crf.forward(&emissions, &tags, Some(&mask), Reduction::default())?;
        println!("llh: {:?}", llh);

        let emissions = emissions.transpose(0, 1)?;
        let tags = tags.transpose(0, 1)?;
        let mask = mask.transpose(0, 1)?;

        let (a, _, _) = emissions.dims3()?;
        let mut manual_llh = Tensor::zeros(llh.shape(), llh.dtype(), llh.device())?;

        for i in 0..a {
            let emission = emissions.i(i)?;
            let tag = tags.i(i)?;
            let mask = mask.i(i)?;

            let seq_len = mask.sum_all().unwrap().to_scalar::<u8>()? as usize;
            let emission = emission.i(..seq_len)?;
            let tag = tag.i(..seq_len)?;
            let numerator = compute_score(&crf, &emission, &tag)?;

            #[cfg(feature = "metal")]
            let num_tags = crf.num_tags as u32;
            #[cfg(not(feature = "metal"))]
            let num_tags = crf.num_tags as i64;

            let product = cartestian_product((0..num_tags).collect_vec(), seq_len, device)?;
            let all_scores = product
                .iter()
                .map(|t| compute_score(&crf, &emission, &t).unwrap());

            let mut denominator =
                Tensor::zeros(numerator.shape(), numerator.dtype(), numerator.device())?;

            for s in all_scores.into_iter() {
                denominator = denominator.broadcast_add(&s.exp()?)?;
            }
            let denominator = denominator.log()?;

            manual_llh = manual_llh.broadcast_add(&numerator.broadcast_sub(&denominator)?)?;
        }
        println!("manual_llh: {:?}", manual_llh);
        assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(llh.dtype())))?;

        if llh.dtype() == DType::F32 {
            let manual_llh = Tensor::full(-11.0540_f32, llh.shape(), llh.device())?;
            println!("Compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
            assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
        }
        llh.backward()?;
        Ok(())
    }

    #[test]
    fn test_forward_works_with_mask() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        for dtype in OK_TYPES {
            forward_works_with_mask(dtype, &device)?;
        }
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/623e3402d00a2728e99d6e8486010d67c754267b/tests/test_crf.py#L122C9-L122C32
    fn forward_works_without_mask(dtype: DType, device: &Device) -> Result<()> {
        let crf = make_crf(
            5,
            false,
            Some(
                Tensor::new(&[0.0266_f32, -0.0539, 0.0572, -0.0199, -0.0167], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(&[0.0084_f32, 0.0892, 0.0942, -0.0179, 0.0112], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(
                    &[
                        [0.0456_f32, 0.0560, 0.0396, 0.0289, 0.0187],
                        [-0.0951, -0.0286, 0.0582, 0.0384, 0.0863],
                        [-0.0137, 0.0764, -0.0414, 0.0722, -0.0287],
                        [0.0365, -0.0033, 0.0726, -0.0620, 0.0516],
                        [0.0925, -0.0708, 0.0765, 0.0671, -0.0344],
                    ],
                    device,
                )?
                .to_dtype(dtype)?,
            ),
            device,
        )?;

        let emissions = Tensor::new(
            &[
                [
                    [0.5463_f32, 2.0856, -0.6247, -1.0225, 0.5277],
                    [-0.4172, -1.4281, -0.5658, -0.5217, -0.6321],
                ],
                [
                    [0.4759, -0.8485, 1.0046, 0.0720, 0.3853],
                    [-0.7525, 0.1041, 0.2371, 0.5746, -0.5599],
                ],
                [
                    [-0.5022, -0.2030, 0.3655, 0.0714, 1.2449],
                    [0.1266, 0.6654, -1.1915, -0.1181, 0.0167],
                ],
            ],
            device,
        )?
        .to_dtype(dtype)?;

        #[cfg(feature = "metal")]
        let tags = Tensor::new(&[[3_u32, 2], [3, 1], [4, 3]], device)?;
        #[cfg(not(feature = "metal"))]
        let tags = Tensor::new(&[[3_i64, 2], [3, 1], [4, 3]], device)?;

        let llh_no_mask = crf.forward(&emissions, &tags, None, Reduction::default())?;

        let llh_mask = crf.forward(
            &emissions,
            &tags,
            Some(&Tensor::ones_like(&tags)?.to_dtype(DType::U8)?),
            Reduction::default(),
        )?;

        println!("llh_no_mask: {:?}", llh_no_mask);
        println!("llh_mask: {:?}", llh_mask);

        if llh_no_mask.dtype() == DType::F32 {
            let manual_llh = Tensor::full(-11.0571_f32, llh_no_mask.shape(), llh_no_mask.device())?;
            println!(
                "compare with pytorch-crf: {:?}, {:?}",
                llh_no_mask, manual_llh
            );
            assert_tensor_close(&llh_no_mask, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
        }

        assert_tensor_close(
            &llh_no_mask,
            &llh_mask,
            epsilon(DTypeCase::DType(llh_no_mask.dtype())),
        )
    }

    #[test]
    fn test_forward_works_without_mask() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        for dtype in OK_TYPES {
            forward_works_without_mask(dtype, &device)?;
        }
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L135
    fn forward_batched_loss(dtype: DType, device: &Device) -> Result<()> {
        let batch_size = 10;
        let crf = make_crf(
            5,
            false,
            Some(
                Tensor::new(&[-0.0695_f32, -0.0117, -0.0021, -0.0635, -0.0328], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(&[-0.0524_f32, -0.0827, 0.0868, -0.0140, 0.0131], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(
                    &[
                        [-0.0330_f32, 0.0894, -0.0732, 0.0996, 0.0014],
                        [-0.0514, -0.0677, -0.0611, -0.0168, 0.0297],
                        [0.0580, -0.0224, -0.0465, -0.0527, -0.0133],
                        [0.0506, 0.0535, -0.0378, -0.0537, 0.0516],
                        [-0.0037, 0.0763, -0.0867, 0.0410, 0.0368],
                    ],
                    device,
                )?
                .to_dtype(dtype)?,
            ),
            device,
        )?;

        let emissions = Tensor::new(
            &[
                [
                    [
                        -1.5867e+00_f32,
                        -4.0363e-01,
                        1.7869e-02,
                        -5.0247e-01,
                        8.2934e-01,
                    ],
                    [-8.5983e-01, -4.0548e-01, 1.3188e-01, 9.8255e-01, 2.2588e-01],
                    [
                        -6.8282e-01,
                        1.8752e+00,
                        -3.4774e-01,
                        -1.0902e+00,
                        1.7499e-01,
                    ],
                    [3.7244e-01, 1.1534e+00, 7.7696e-01, 3.4387e-01, -9.8422e-01],
                    [3.8489e-02, 8.2353e-01, -8.2190e-01, 8.6248e-02, 1.2238e-01],
                    [4.4424e-02, 1.3664e+00, -1.3658e+00, -2.4691e-01, 1.1135e+00],
                    [1.2708e+00, 2.9114e-01, 1.0744e+00, 5.3505e-02, -1.5511e-01],
                    [
                        1.3976e+00,
                        -1.1226e+00,
                        -9.2870e-01,
                        1.1908e-01,
                        -1.6336e+00,
                    ],
                    [6.0694e-01, 2.5764e-01, -6.8925e-01, 1.1807e+00, -6.5968e-01],
                    [3.5677e-01, -1.4314e+00, 9.4358e-01, 7.9112e-01, -2.1923e-01],
                ],
                [
                    [1.3654e+00, -2.3797e-01, 6.2540e-02, 1.5489e+00, -2.0502e+00],
                    [1.3639e+00, -5.9433e-01, 7.0876e-01, -4.9674e-01, 6.9055e-02],
                    [
                        2.3545e-01,
                        -4.0388e-01,
                        1.2455e+00,
                        -4.1925e-01,
                        -3.9647e-01,
                    ],
                    [
                        -6.3912e-01,
                        -1.4287e+00,
                        -2.2617e+00,
                        -5.6802e-01,
                        7.4044e-01,
                    ],
                    [3.8845e-01, -9.7110e-01, -5.7113e-01, 1.3628e+00, 7.4219e-01],
                    [
                        -7.7064e-01,
                        9.3300e-01,
                        -1.4319e+00,
                        -1.5991e+00,
                        2.6631e-01,
                    ],
                    [
                        1.7472e+00,
                        -5.9296e-01,
                        -1.3249e-03,
                        1.4543e-01,
                        -6.5364e-01,
                    ],
                    [
                        -3.5911e-01,
                        4.4189e-02,
                        -1.2928e+00,
                        -1.1482e+00,
                        1.2672e+00,
                    ],
                    [1.3452e+00, -2.3875e+00, 1.4895e+00, -7.3329e-01, 2.1750e-01],
                    [
                        -3.9819e-01,
                        4.5757e-01,
                        -5.0534e-01,
                        -3.0911e+00,
                        -1.1324e+00,
                    ],
                ],
                [
                    [
                        -3.4185e-01,
                        -1.0406e+00,
                        -4.3079e-01,
                        -4.5273e-02,
                        1.1170e+00,
                    ],
                    [
                        -8.5589e-01,
                        9.4792e-01,
                        -8.8419e-01,
                        -7.7756e-01,
                        -1.7976e-01,
                    ],
                    [-1.8891e-01, 1.7120e-01, -4.3634e-01, 1.2762e+00, 1.0334e+00],
                    [
                        2.7852e-01,
                        -1.5482e+00,
                        5.6432e-01,
                        -1.1859e+00,
                        -7.0821e-02,
                    ],
                    [3.4364e-01, 1.2222e+00, 1.0542e+00, -1.7861e-01, 6.4608e-01],
                    [-8.4590e-01, 1.4749e+00, 3.7927e-01, 2.2527e+00, -3.5637e-02],
                    [
                        4.5344e-01,
                        -1.4359e+00,
                        -2.2955e+00,
                        -9.4110e-01,
                        -8.5992e-01,
                    ],
                    [6.8505e-01, -1.5822e-01, -6.9359e-01, 5.9559e-02, 6.8955e-01],
                    [
                        -3.4006e-01,
                        1.7685e+00,
                        2.1671e-01,
                        -8.6512e-01,
                        -2.6517e-01,
                    ],
                    [1.0503e-01, 1.6486e+00, -2.4486e-01, 5.4843e-01, 1.9252e+00],
                ],
            ],
            device,
        )?
        .to_dtype(dtype)?;

        #[cfg(feature = "metal")]
        let tags = Tensor::new(
            &[
                [1_u32, 0, 0, 1, 2, 1, 0, 4, 3, 0],
                [3, 2, 3, 4, 4, 1, 2, 1, 4, 1],
                [0, 1, 4, 4, 0, 4, 0, 1, 4, 1],
            ],
            device,
        )?;
        #[cfg(not(feature = "metal"))]
        let tags = Tensor::new(
            &[
                [1_i64, 0, 0, 1, 2, 1, 0, 4, 3, 0],
                [3, 2, 3, 4, 4, 1, 2, 1, 4, 1],
                [0, 1, 4, 4, 0, 4, 0, 1, 4, 1],
            ],
            device,
        )?;

        let llh = crf.forward(&emissions, &tags, None, Reduction::default())?;

        llh.dims0()?;
        let mut total_llh = Tensor::zeros(llh.shape(), llh.dtype(), llh.device())?;
        for i in 0..batch_size {
            let emissions = emissions
                .i((.., i, ..))?
                .contiguous()? // force contiguous to fix tensor indexer.
                .unsqueeze(1)?;

            let tags = tags
                .i((.., i))?
                .contiguous()? // force contiguous to fix tensor indexer.
                .unsqueeze(1)?;

            total_llh = total_llh.broadcast_add(&crf.forward(
                &emissions,
                &tags,
                None,
                Reduction::default(),
            )?)?;
        }

        println!("llh: {:?}", llh);
        println!("total_llh: {:?}", total_llh);

        if llh.dtype() == DType::F32 {
            let manual_llh = Tensor::full(-49.2024_f32, llh.shape(), llh.device())?;
            println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
            assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
        }

        assert_tensor_close(&llh, &total_llh, epsilon(DTypeCase::DType(llh.dtype())))
    }

    #[test]
    fn test_forward_batched_loss() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        for dtype in OK_TYPES {
            forward_batched_loss(dtype, &device)?;
        }
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L159
    fn forward_reduction_none(dtype: DType, device: &Device) -> Result<()> {
        let crf = make_crf(
            5,
            false,
            Some(
                Tensor::new(&[0.0432_f32, 0.0507, -0.0286, 0.0476, -0.0603], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(&[0.0824_f32, 0.0845, -0.0180, -0.0773, 0.0414], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(
                    &[
                        [-0.0894_f32, 0.0512, 0.0066, 0.0534, -0.0182],
                        [0.0043, 0.0328, -0.0805, -0.0945, 0.0495],
                        [-0.0020, 0.0416, -0.0441, 0.0390, 0.0690],
                        [-0.0260, 0.0720, 0.0017, -0.0552, -0.0470],
                        [0.0104, 0.0299, -0.0182, -0.0515, 0.0424],
                    ],
                    device,
                )?
                .to_dtype(dtype)?,
            ),
            device,
        )?;

        let emissions = Tensor::new(
            &[
                [
                    [-0.2560_f32, 1.2459, -2.0063, -0.5449, 1.0978],
                    [0.7233, -0.6967, 0.3394, 0.7784, -3.0362],
                ],
                [
                    [-1.3406, -1.1565, 0.0870, 1.8249, 1.3740],
                    [-1.3396, -1.0208, 0.6608, -0.5917, 1.3850],
                ],
                [
                    [1.1436, 0.4477, 0.6606, 1.5938, -0.1054],
                    [-0.5401, 1.1908, -1.7266, -0.5858, -1.4395],
                ],
            ],
            device,
        )?
        .to_dtype(dtype)?;

        #[cfg(feature = "metal")]
        let tags = Tensor::new(&[[1_u32, 3], [1, 3], [1, 2]], device)?;
        #[cfg(not(feature = "metal"))]
        let tags = Tensor::new(&[[1_i64, 3], [1, 3], [1, 2]], device)?;

        let llh = crf.forward(&emissions, &tags, None, Reduction::None)?;
        println!("llh: {:?}", llh);
        let (seq_length, batch_size) = tags.dims2()?;
        assert_eq!(llh.dims1()?, batch_size);

        let emissions = emissions.transpose(0, 1)?;
        let tags = tags.transpose(0, 1)?;

        let (a, _, _) = emissions.dims3()?;
        let mut manual_llh = vec![];
        for i in 0..a {
            let emission = emissions.i(i)?;
            let tag = tags.i(i)?;

            let numerator = compute_score(&crf, &emission, &tag)?;

            #[cfg(feature = "metal")]
            let num_tags = crf.num_tags as u32;
            #[cfg(not(feature = "metal"))]
            let num_tags = crf.num_tags as i64;
            let product = cartestian_product((0..num_tags).collect_vec(), seq_length, device)?;

            let all_scores = product
                .iter()
                .map(|t| compute_score(&crf, &emission, &t).unwrap());

            let mut denominator = numerator.zeros_like()?;

            for s in all_scores.into_iter() {
                denominator = denominator.broadcast_add(&s.exp()?)?;
            }

            let denominator = denominator.log()?;
            manual_llh.push((numerator - denominator)?);
        }

        let manual_llh = cat_scalar_tensor(manual_llh)?;
        println!("manual_llh: {:?}", manual_llh);
        if dtype == DType::F32 {
            let manual_llh = Tensor::new(&[-6.3064_f32, -7.0368], device)?;
            println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
            assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
        }
        assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(llh.dtype())))
    }

    #[test]
    fn test_forward_reduction_none() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        for dtype in OK_TYPES {
            forward_reduction_none(dtype, &device)?;
        }
        Ok(())
    }

    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L192
    fn forward_reduction_mean(dtype: DType, device: &Device) -> Result<()> {
        let crf = make_crf(
            5,
            false,
            Some(
                Tensor::new(&[0.0606_f32, -0.0597, 0.0217, -0.0760, 0.0096], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(&[-0.0791_f32, -0.0159, 0.0525, 0.0451, -0.0373], device)?
                    .to_dtype(dtype)?,
            ),
            Some(
                Tensor::new(
                    &[
                        [
                            5.0599e-02_f32,
                            -1.4571e-02,
                            2.2383e-02,
                            3.3254e-02,
                            2.5206e-03,
                        ],
                        [6.6520e-02, 7.3251e-02, 1.0225e-02, -9.4751e-02, -3.4146e-02],
                        [
                            -6.7073e-02,
                            2.9719e-02,
                            -8.5645e-02,
                            4.6357e-02,
                            -7.2483e-03,
                        ],
                        [
                            4.4980e-02,
                            -8.0436e-02,
                            6.4611e-05,
                            -5.1731e-02,
                            -8.2973e-02,
                        ],
                        [-5.0593e-02, 4.5717e-03, 6.8714e-03, 8.9858e-02, -8.2813e-02],
                    ],
                    device,
                )?
                .to_dtype(dtype)?,
            ),
            device,
        )?;

        let emissions = Tensor::new(
            &[
                [
                    [0.0535_f32, 0.6821, -0.2587, 1.2250, 0.5327],
                    [-2.5028, 0.5942, -0.2508, 0.0597, 1.3800],
                ],
                [
                    [-0.0640, -1.3170, 0.6408, -0.1368, -0.2137],
                    [-0.3985, 0.0530, -0.0448, 0.8268, 0.7622],
                ],
                [
                    [1.4061, -0.4045, -0.3174, 0.0737, -1.8753],
                    [-1.0892, -0.8641, 0.4778, -0.4032, 0.2838],
                ],
            ],
            device,
        )?
        .to_dtype(dtype)?;

        #[cfg(feature = "metal")]
        let tags = Tensor::new(&[[3_u32, 0], [0, 3], [2, 4]], device)?;
        #[cfg(not(feature = "metal"))]
        let tags = Tensor::new(&[[3_i64, 0], [0, 3], [2, 4]], device)?;

        let llh = crf.forward(&emissions, &tags, None, Reduction::Mean)?;
        println!("llh: {:?}", llh);

        let (seq_length, batch_size) = tags.dims2().unwrap();

        let emissions = emissions.transpose(0, 1).unwrap();
        let tags = tags.transpose(0, 1).unwrap();

        let (a, _, _) = emissions.dims3().unwrap();
        let mut manual_llh = llh.zeros_like()?;
        for i in 0..a {
            let emission = emissions.i(i).unwrap();
            let tag = tags.i(i).unwrap();

            let numerator = compute_score(&crf, &emission, &tag)?;

            #[cfg(feature = "metal")]
            let num_tags = crf.num_tags as u32;
            #[cfg(not(feature = "metal"))]
            let num_tags = crf.num_tags as i64;

            let product =
                cartestian_product((0..num_tags).collect_vec(), seq_length, &Device::Cpu)?;

            let all_scores = product
                .iter()
                .map(|t| compute_score(&crf, &emission, &t).unwrap());

            let mut denominator = numerator.zeros_like()?;

            for s in all_scores.into_iter() {
                denominator = denominator.broadcast_add(&s.exp()?)?;
            }

            let denominator = denominator.log()?;

            manual_llh = (manual_llh + (numerator - denominator)?)?;
        }

        #[cfg(feature = "metal")]
        let manual_llh = (&manual_llh
            / Tensor::full(batch_size as f32, manual_llh.shape(), manual_llh.device())?
                .to_dtype(manual_llh.dtype())?)?;

        #[cfg(not(feature = "metal"))]
        let manual_llh = (&manual_llh
            / Tensor::full(batch_size as f64, manual_llh.shape(), manual_llh.device())?
                .to_dtype(manual_llh.dtype())?)?;

        println!("manual_llh: {:?}", manual_llh);

        if dtype == DType::F32 {
            let manual_llh = Tensor::new(-5.7756_f32, &llh.device())?;
            println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
            assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
        }
        assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(dtype)))
    }

    #[test]
    fn test_forward_reduction_mean() -> Result<()> {
        #[cfg(any(feature = "cuda", feature = "metal"))]
        let device = use_gpu(true)?;
        #[cfg(not(any(feature = "cuda", feature = "metal")))]
        let device = use_gpu(false)?;

        for dtype in OK_TYPES {
            forward_reduction_mean(dtype, &device)?;
        }
        Ok(())
    }
1481
1482    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L192
1483    fn forward_token_mean(dtype: DType, device: &Device) -> Result<()> {
1484        let crf = make_crf(
1485            5,
1486            false,
1487            Some(
1488                Tensor::new(&[0.0687_f32, 0.0533, 0.0204, 0.0250, -0.0785], device)?
1489                    .to_dtype(dtype)?,
1490            ),
1491            Some(
1492                Tensor::new(
1493                    &[
1494                        4.8827e-02_f32,
1495                        -9.9134e-05,
1496                        9.3184e-02,
1497                        -7.6271e-02,
1498                        3.6482e-02,
1499                    ],
1500                    device,
1501                )?
1502                .to_dtype(dtype)?,
1503            ),
1504            Some(
1505                Tensor::new(
1506                    &[
1507                        [0.0173_f32, -0.0058, -0.0699, -0.0374, 0.0797],
1508                        [-0.0405, 0.0141, -0.0002, 0.0790, 0.0205],
1509                        [-0.0473, 0.0554, -0.0036, 0.0878, 0.0210],
1510                        [0.0761, -0.0406, -0.0905, 0.0590, -0.0030],
1511                        [0.0613, 0.0871, -0.0343, 0.0384, 0.0485],
1512                    ],
1513                    device,
1514                )?
1515                .to_dtype(dtype)?,
1516            ),
1517            device,
1518        )?;
1519
1520        let emissions = Tensor::new(
1521            &[
1522                [
1523                    [0.0110_f32, -0.8502, 0.9678, -0.3219, -0.6029],
1524                    [1.0804, -1.2822, 1.4129, 0.9475, -2.6282],
1525                ],
1526                [
1527                    [0.8993, 0.3029, -0.0686, -0.3108, 0.6216],
1528                    [-2.1503, 1.4301, -0.0301, 0.3572, 0.5460],
1529                ],
1530                [
1531                    [1.3384, 0.8500, 0.0194, -0.6371, 0.1516],
1532                    [-0.7357, 0.3116, 1.5733, -0.8246, -0.4224],
1533                ],
1534            ],
1535            device,
1536        )?
1537        .to_dtype(dtype)?;
1538
1539        #[cfg(feature = "metal")]
1540        let tags = Tensor::new(&[[2_u32, 2], [0, 4], [3, 3]], device)?;
1541        #[cfg(not(feature = "metal"))]
1542        let tags = Tensor::new(&[[2_i64, 2], [0, 4], [3, 3]], device)?;
1543
1544        let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
1545        let llh = crf.forward(&emissions, &tags, Some(&mask), Reduction::TokenMean)?;
1546        println!("llh: {:?}", llh);
1547
1548        let emissions = emissions.transpose(0, 1)?;
1549        let tags = tags.transpose(0, 1)?;
1550        let mask = mask.transpose(0, 1)?;
1551
1552        let (a, _, _) = emissions.dims3()?;
1553        let mut manual_llh = llh.zeros_like()?;
1554        let mut total_tokens = 0;
1555        for i in 0..a {
1556            let emission = emissions.i(i)?;
1557            let tag = tags.i(i)?;
1558            let mask = mask.i(i)?;
1559
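            // Only the unmasked prefix of each sequence counts: truncate the emission
            // and tag tensors to the mask length before scoring.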
1560            let seq_len = mask.sum_all()?.to_scalar::<u8>()? as usize;
1561            let emission = emission.i(..seq_len)?;
1562            let tag = tag.i(..seq_len)?;
1563            let numerator = compute_score(&crf, &emission, &tag)?;
1564
1565            #[cfg(feature = "metal")]
1566            let num_tags = crf.num_tags as u32;
1567            #[cfg(not(feature = "metal"))]
1568            let num_tags = crf.num_tags as i64;
1569
1570            let product = cartestian_product((0..num_tags).collect_vec(), seq_len, device)?;
1571
1572            let all_scores = product
1573                .iter()
1574                .map(|t| compute_score(&crf, &emission, &t).unwrap());
1575
1576            let mut denominator = numerator.zeros_like()?;
1577            for t in all_scores {
1578                denominator = denominator.broadcast_add(&t.exp()?)?;
1579            }
1580
1581            let denominator = denominator.log()?;
1582
1583            manual_llh = (manual_llh + (numerator - denominator)?)?;
1584            total_tokens += seq_len;
1585        }
1586
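        // TokenMean: divide the summed log-likelihood by the total number of valid
        // tokens across the batch.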
1587        #[cfg(feature = "metal")]
1588        let total_tokens =
1589            Tensor::full(total_tokens as f32, manual_llh.shape(), manual_llh.device())?
1590                .to_dtype(manual_llh.dtype())?;
1591        #[cfg(not(feature = "metal"))]
1592        let total_tokens =
1593            Tensor::full(total_tokens as f64, manual_llh.shape(), manual_llh.device())?
1594                .to_dtype(manual_llh.dtype())?;
1595
1596        let manual_llh = (manual_llh / total_tokens)?;
1597        println!("manual_llh: {:?}", manual_llh);
1598
1599        if dtype == DType::F32 {
1600            let manual_llh = Tensor::new(-1.4603_f32, &llh.device())?;
1601            println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
1602            assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
1603        }
1604        assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(dtype)))?;
1605        llh.backward()?;
1606        Ok(())
1607    }
1608
1609    #[test]
1610    fn test_forward_token_mean() -> Result<()> {
1611        #[cfg(any(feature = "cuda", feature = "metal"))]
1612        let device = use_gpu(true)?;
1613        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1614        let device = use_gpu(false)?;
1615
1616        for dtype in OK_TYPES {
1617            forward_token_mean(dtype, &device)?;
1618        }
1619        Ok(())
1620    }
1621
1622    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L263
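    ///
    /// A batch-first CRF built from the same parameters must produce the same
    /// log-likelihood once emissions and tags are transposed to batch-first layout.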
1623    fn forward_batch_first(dtype: DType, device: &Device) -> Result<()> {
1624        let crf = make_crf(
1625            5,
1626            false,
1627            Some(
1628                Tensor::new(&[0.0384_f32, -0.0811, -0.0291, 0.0444, 0.0943], device)?
1629                    .to_dtype(dtype)?,
1630            ),
1631            Some(
1632                Tensor::new(&[0.0146_f32, 0.0455, 0.0991, 0.0640, -0.0298], device)?
1633                    .to_dtype(dtype)?,
1634            ),
1635            Some(
1636                Tensor::new(
1637                    &[
1638                        [0.0063_f32, 0.0014, 0.0804, -0.0385, -0.0485],
1639                        [0.0485, -0.0963, 0.0799, 0.0198, -0.0549],
1640                        [0.0016, 0.0012, -0.0411, 0.0540, -0.0823],
1641                        [0.0111, 0.0320, 0.0769, 0.0292, -0.0707],
1642                        [-0.0990, -0.0971, 0.0635, 0.0166, 0.0292],
1643                    ],
1644                    device,
1645                )?
1646                .to_dtype(dtype)?,
1647            ),
1648            device,
1649        )?;
1650
1651        let emissions = Tensor::new(
1652            &[
1653                [
1654                    [-1.1338_f32, -0.9228, 0.3260, 0.0327, -1.0345],
1655                    [0.1106, 0.8005, 0.3860, -0.1214, -1.8224],
1656                ],
1657                [
1658                    [-1.3724, -2.2578, -1.8705, -0.1109, 0.3845],
1659                    [-0.4223, 0.8414, -1.4423, -1.2734, 0.5193],
1660                ],
1661                [
1662                    [0.4189, -1.4048, -1.6877, 1.0891, 0.6978],
1663                    [-0.2521, -1.4185, -0.6026, 1.6335, 1.0366],
1664                ],
1665            ],
1666            device,
1667        )?
1668        .to_dtype(dtype)?;
1669
1670        #[cfg(feature = "metal")]
1671        let tags = Tensor::new(&[[1_u32, 4], [1, 4], [1, 4]], device)?;
1672        #[cfg(not(feature = "metal"))]
1673        let tags = Tensor::new(&[[1_i64, 4], [1, 4], [1, 4]], device)?;
1674
1675        let llh = crf.forward(&emissions, &tags, None, Reduction::default())?;
1676        println!("llh: {:?}", llh);
1677
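        // Rebuild the CRF with batch_first = true, reusing the exact same transition
        // parameters, and feed it the transposed (batch, seq_len, ...) inputs.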
1678        let crf_bf = make_crf(
1679            5,
1680            true,
1681            Some(crf.start_transitions),
1682            Some(crf.end_transitions),
1683            Some(crf.transitions),
1684            device,
1685        )?;
1686
1687        let emissions = emissions.transpose(0, 1)?;
1688        let tags = tags.transpose(0, 1)?;
1689
1690        let llh_bf = crf_bf.forward(&emissions, &tags, None, Reduction::default())?;
1691        println!("llh_bf: {:?}", llh_bf);
1692
1693        if dtype == DType::F32 {
1694            let llh_bf = Tensor::new(-14.8640_f32, &llh.device())?;
1695            println!("compare with pytorch-crf: {:?}, {:?}", llh, llh_bf);
1696            assert_tensor_close(&llh, &llh_bf, epsilon(DTypeCase::PyTorchCRF))?;
1697        }
1698        assert_tensor_close(&llh, &llh_bf, epsilon(DTypeCase::DType(dtype)))
1699    }
1700
1701    #[test]
1702    fn test_forward_batch_first() -> Result<()> {
1703        #[cfg(any(feature = "cuda", feature = "metal"))]
1704        let device = use_gpu(true)?;
1705        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1706        let device = use_gpu(false)?;
1707
1708        for dtype in OK_TYPES {
1709            forward_batch_first(dtype, &device)?;
1710        }
1711        Ok(())
1712    }
1713
1714    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L286
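    ///
    /// Error path: `forward` rejects emissions that are not rank-3
    /// (seq_length, batch_size, num_tags).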
1715    #[test]
1716    fn test_forward_emissions_has_bad_number_of_dimension() -> Result<()> {
1717        #[cfg(any(feature = "cuda", feature = "metal"))]
1718        let device = use_gpu(true)?;
1719        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1720        let device = use_gpu(false)?;
1721
1722        let emissions = Tensor::randn(0.0_f32, 1., (1, 2), &device)?;
1723
1724        #[cfg(feature = "metal")]
1725        let tags = Tensor::zeros((2, 2), DType::U32, &device)?;
1726        #[cfg(not(feature = "metal"))]
1727        let tags = Tensor::zeros((2, 2), DType::I64, &device)?;
1728
1729        let crf = make_crf(5, false, None, None, None, &device)?;
1730        let result = crf.forward(&emissions, &tags, None, Reduction::default());
1731        println!("{:?}", result);
1732        assert!(result.is_err());
1733        assert_eq!(
1734            result.err().unwrap().to_string(),
1735            "emissions must have 3 dimensions, got 2"
1736        );
1737        Ok(())
1738    }
1739
1740    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L295C9-L295C46
1741    #[test]
1742    fn test_forward_emissions_and_tags_size_mismatch() -> Result<()> {
1743        #[cfg(any(feature = "cuda", feature = "metal"))]
1744        let device = use_gpu(true)?;
1745        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1746        let device = use_gpu(false)?;
1747
1748        let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
1749        #[cfg(feature = "metal")]
1750        let tags = Tensor::zeros((2, 2), DType::U32, &device)?;
1751        #[cfg(not(feature = "metal"))]
1752        let tags = Tensor::zeros((2, 2), DType::I64, &device)?;
1753
1754        let crf = make_crf(3, false, None, None, None, &device)?;
1755        let result = crf.forward(&emissions, &tags, None, Reduction::default());
1756        println!("{:?}", result);
1757        assert!(result.is_err());
1758        assert_eq!(
1759            result.err().unwrap().to_string(),
1760            "the first two dimensions of emissions and tags must match, got (1, 2) and (1, 2)"
1761        );
1762        Ok(())
1763    }
1764
1765    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L306C9-L306C66
1766    #[test]
1767    fn test_forward_emissions_last_dimension_not_equal_to_number_of_tags() -> Result<()> {
1768        #[cfg(any(feature = "cuda", feature = "metal"))]
1769        let device = use_gpu(true)?;
1770        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1771        let device = use_gpu(false)?;
1772
1773        let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
1774        #[cfg(feature = "metal")]
1775        let tags = Tensor::zeros((1, 2), DType::U32, &device)?;
1776        #[cfg(not(feature = "metal"))]
1777        let tags = Tensor::zeros((1, 2), DType::I64, &device)?;
1778
1779        let crf = make_crf(10, false, None, None, None, &device)?;
1780        let result = crf.forward(&emissions, &tags, None, Reduction::default());
1781        println!("{:?}", result);
1782        assert!(result.is_err());
1783        assert_eq!(
1784            result.err().unwrap().to_string(),
1785            "expected last dimension of emissions is 10, got 3"
1786        );
1787        Ok(())
1788    }
1789
1790    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L315C9-L315C47
1791    #[test]
1792    fn test_forward_first_timestep_mask_is_not_all_on() -> Result<()> {
1793        #[cfg(any(feature = "cuda", feature = "metal"))]
1794        let device = use_gpu(true)?;
1795        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1796        let device = use_gpu(false)?;
1797
1798        let emissions = Tensor::randn(0.0_f32, 1., (3, 2, 4), &device)?;
1799        #[cfg(feature = "metal")]
1800        let tags = Tensor::zeros((3, 2), DType::U32, &device)?;
1801        #[cfg(not(feature = "metal"))]
1802        let tags = Tensor::zeros((3, 2), DType::I64, &device)?;
1803
1804        let mask = Tensor::new(&[[1_u8, 1, 1], [0, 0, 0]], &device)?.transpose(0, 1)?;
1807        let crf = make_crf(4, false, None, None, None, &device)?;
1808
1809        let result = crf.forward(&emissions, &tags, Some(&mask), Reduction::default());
1810        println!("{:?}", result);
1811        assert!(result.is_err());
1812        assert_eq!(
1813            result.err().unwrap().to_string(),
1814            "mask of the first timestep must all be on"
1815        );
1816
1817        let emissions = emissions.transpose(0, 1)?;
1818        let tags = tags.transpose(0, 1)?;
1819        let mask = mask.transpose(0, 1)?;
1820        let crf = make_crf(4, true, None, None, None, &device)?;
1821
1822        let result = crf.forward(&emissions, &tags, Some(&mask), Reduction::default());
1823        println!("{:?}", result);
1824        assert!(result.is_err());
1825        assert_eq!(
1826            result.err().unwrap().to_string(),
1827            "mask of the first timestep must all be on"
1828        );
1829        Ok(())
1830    }
1831
1832    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L345
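    ///
    /// Viterbi decoding with a mask: the second sequence is masked to length 2, so
    /// its decoded path is two tags long.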
1833    fn decode_works_with_mask(dtype: DType, device: &Device) -> Result<()> {
1834        let crf = make_crf(
1835            5,
1836            false,
1837            Some(
1838                Tensor::new(&[0.0548_f32, -0.0239, -0.0291, -0.0208, 0.0665], device)?
1839                    .to_dtype(dtype)?,
1840            ),
1841            Some(
1842                Tensor::new(&[-0.0612_f32, -0.0615, -0.0557, 0.0672, 0.0470], device)?
1843                    .to_dtype(dtype)?,
1844            ),
1845            Some(
1846                Tensor::new(
1847                    &[
1848                        [-0.0751_f32, -0.0941, 0.0248, 0.0900, -0.0776],
1849                        [0.0381, -0.0550, -0.0333, -0.0124, -0.0356],
1850                        [-0.0383, -0.0910, 0.0914, -0.0330, -0.0119],
1851                        [0.0358, 0.0513, 0.0013, -0.0380, 0.0626],
1852                        [-0.0168, 0.0871, 0.0489, 0.0019, -0.0548],
1853                    ],
1854                    device,
1855                )?
1856                .to_dtype(dtype)?,
1857            ),
1858            device,
1859        )?;
1860
1861        let emissions = Tensor::new(
1862            &[
1863                [
1864                    [1.8238_f32, 1.3041, -0.0845, 1.3981, 0.1027],
1865                    [1.1092, -0.1616, 1.9770, -1.6850, -1.4289],
1866                ],
1867                [
1868                    [0.2831, 0.0936, -1.1957, 0.2637, -0.8048],
1869                    [0.4553, -0.0393, 2.3307, -0.3505, -2.3531],
1870                ],
1871                [
1872                    [1.6232, 0.2230, 0.3585, -0.7957, -0.2464],
1873                    [-0.3805, 0.3646, -1.0142, -1.2563, -0.6568],
1874                ],
1875            ],
1876            device,
1877        )?
1878        .to_dtype(dtype)?;
1879
1880        let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
1881        let best_tags = crf.decode(&emissions, Some(&mask))?;
1882        println!("best_tags: {:?}", best_tags);
1883        assert_eq!(best_tags, vec![vec![0, 3, 0], vec![2, 2]]);
1884        Ok(())
1885    }
1886
1887    #[test]
1888    fn test_decode_works_with_mask() -> Result<()> {
1889        #[cfg(any(feature = "cuda", feature = "metal"))]
1890        let device = use_gpu(true)?;
1891        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1892        let device = use_gpu(false)?;
1893
1894        for dtype in OK_TYPES {
1895            decode_works_with_mask(dtype, &device)?;
1896        }
1897        Ok(())
1898    }
1899
1900    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L372
1901    fn decode_works_without_mask(dtype: DType, device: &Device) -> Result<()> {
1902        let crf = make_crf(
1903            5,
1904            false,
1905            Some(
1906                Tensor::new(&[0.0762_f32, 0.0743, 0.0234, -0.0387, -0.0269], device)?
1907                    .to_dtype(dtype)?,
1908            ),
1909            Some(
1910                Tensor::new(&[0.0102_f32, -0.0137, -0.0149, 0.0700, -0.0701], device)?
1911                    .to_dtype(dtype)?,
1912            ),
1913            Some(
1914                Tensor::new(
1915                    &[
1916                        [-0.0620_f32, -0.0527, 0.0034, 0.0694, -0.0853],
1917                        [0.0922, -0.0613, -0.0592, 0.0482, 0.0632],
1918                        [-0.0433, 0.0069, -0.0161, -0.0330, -0.0602],
1919                        [-0.0649, 0.0047, 0.0593, 0.0733, 0.0203],
1920                        [0.0997, 0.0007, 0.0938, 0.0427, 0.0823],
1921                    ],
1922                    device,
1923                )?
1924                .to_dtype(dtype)?,
1925            ),
1926            device,
1927        )?;
1928
1929        let emissions = Tensor::new(
1930            &[
1931                [
1932                    [0.8913_f32, -0.0355, -1.4378, 0.8390, -0.7296],
1933                    [1.5530, -1.3165, -0.5769, -0.8085, -0.2610],
1934                ],
1935                [
1936                    [-0.9622, -0.3234, -0.5353, -0.4424, -0.1456],
1937                    [-0.3844, 0.2524, 1.9393, 0.1217, -1.2519],
1938                ],
1939                [
1940                    [-0.1619, -0.2520, 1.9566, 0.4863, 1.5627],
1941                    [-0.3999, 1.4914, 1.0620, -0.6408, -0.3032],
1942                ],
1943            ],
1944            device,
1945        )?
1946        .to_dtype(dtype)?;
1947
1948        let best_tags_no_mask = crf.decode(&emissions, None)?;
1949        println!("best_tags: {:?}", best_tags_no_mask);
1950        assert_eq!(best_tags_no_mask, vec![vec![0, 4, 2], vec![0, 2, 1]]);
1951        Ok(())
1952    }
1953
1954    #[test]
1955    fn test_decode_works_without_mask() -> Result<()> {
1956        #[cfg(any(feature = "cuda", feature = "metal"))]
1957        let device = use_gpu(true)?;
1958        #[cfg(not(any(feature = "cuda", feature = "metal")))]
1959        let device = use_gpu(false)?;
1960
1961        for dtype in OK_TYPES {
1962            decode_works_without_mask(dtype, &device)?;
1963        }
1964        Ok(())
1965    }
1966
1967    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L384C9-L384C28
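    ///
    /// Decoding a whole batch must agree with decoding each sequence on its own.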
1968    fn decode_batched_decode(dtype: DType, device: &Device) -> Result<()> {
1969        let crf = make_crf(
1970            5,
1971            false,
1972            Some(
1973                Tensor::new(&[-0.0489_f32, 0.0460, -0.0924, -0.0722, 0.0736], device)?
1974                    .to_dtype(dtype)?,
1975            ),
1976            Some(
1977                Tensor::new(&[0.0843_f32, 0.0344, -0.0996, 0.0944, 0.0622], device)?
1978                    .to_dtype(dtype)?,
1979            ),
1980            Some(
1981                Tensor::new(
1982                    &[
1983                        [0.0780_f32, -0.0794, 0.0208, 0.0039, 0.0080],
1984                        [-0.0923, -0.0359, 0.0103, 0.0550, -0.0029],
1985                        [0.0628, -0.0787, -0.0256, 0.0554, -0.0969],
1986                        [0.0655, -0.0055, 0.0718, -0.0275, -0.0994],
1987                        [-0.0492, -0.0953, 0.0862, 0.0580, 0.0422],
1988                    ],
1989                    device,
1990                )?
1991                .to_dtype(dtype)?,
1992            ),
1993            device,
1994        )?;
1995
1996        let emissions = Tensor::new(
1997            &[
1998                [
1999                    [0.7720_f32, 0.9488, 0.6672, 1.8839, -0.6844],
2000                    [1.6192, 0.2733, 0.8063, -0.0377, -2.3208],
2001                ],
2002                [
2003                    [-0.4374, -1.4631, -0.1330, -0.2155, 1.6044],
2004                    [0.7017, -1.1525, -1.0692, 0.3463, 0.9816],
2005                ],
2006                [
2007                    [-1.3011, 0.5237, -1.1700, -0.9017, -0.5747],
2008                    [-1.0040, 0.7791, -0.3735, 0.8300, 1.5138],
2009                ],
2010            ],
2011            device,
2012        )?
2013        .to_dtype(dtype)?;
2014
2015        let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
2016
2017        let batched = crf.decode(&emissions, Some(&mask))?;
2018        println!("batched: {:?}", batched);
2019
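        // Decode each sequence separately by slicing out column `i` and restoring a
        // batch dimension of size 1.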
2020        let batch_size = 2;
2021        let mut non_batched = vec![];
2022        for i in 0..batch_size {
2023            let emissions = emissions.i((.., i, ..))?.unsqueeze(1)?.contiguous()?;
2024            let mask = mask.i((.., i))?.unsqueeze(1)?.contiguous()?;
2025
2026            let result = crf.decode(&emissions, Some(&mask))?;
2027            non_batched.push(result[0].clone());
2028        }
2029        println!("non_batched: {:?}", non_batched);
2030
2031        assert_eq!(batched, non_batched);
2032        assert_eq!(batched, vec![vec![3, 4, 1], vec![0, 4]]);
2033        Ok(())
2034    }
2035
2036    #[test]
2037    fn test_decode_batched_decode() -> Result<()> {
2038        #[cfg(any(feature = "cuda", feature = "metal"))]
2039        let device = use_gpu(true)?;
2040        #[cfg(not(any(feature = "cuda", feature = "metal")))]
2041        let device = use_gpu(false)?;
2042
2043        for dtype in OK_TYPES {
2044            decode_batched_decode(dtype, &device)?;
2045        }
2046        Ok(())
2047    }
2048
2049    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L408
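    ///
    /// Decoding with batch_first = true on transposed emissions must match the
    /// sequence-first result.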
2050    fn decode_batch_first(dtype: DType, device: &Device) -> Result<()> {
2051        let crf = make_crf(
2052            5,
2053            false,
2054            Some(
2055                Tensor::new(&[-0.0464_f32, 0.0818, 0.0829, -0.0121, -0.0788], device)?
2056                    .to_dtype(dtype)?,
2057            ),
2058            Some(
2059                Tensor::new(&[-0.0088_f32, 0.0586, 0.0057, 0.0316, -0.0388], device)?
2060                    .to_dtype(dtype)?,
2061            ),
2062            Some(
2063                Tensor::new(
2064                    &[
2065                        [-0.0536_f32, -0.0093, 0.0276, 0.0351, 0.0604],
2066                        [0.0734, 0.0764, -0.0773, 0.0821, 0.0294],
2067                        [-0.0540, -0.0158, 0.0437, 0.0992, 0.0473],
2068                        [0.0875, 0.0324, -0.0941, 0.0585, 0.0761],
2069                        [-0.0930, -0.0832, 0.0290, 0.0974, 0.0914],
2070                    ],
2071                    device,
2072                )?
2073                .to_dtype(dtype)?,
2074            ),
2075            device,
2076        )?;
2077
2078        let emissions = Tensor::new(
2079            &[
2080                [
2081                    [-0.6633_f32, 1.4045, -1.3710, 1.5054, 0.8431],
2082                    [-0.1157, -0.0201, -0.2685, -0.6683, 0.0213],
2083                ],
2084                [
2085                    [-0.7870, -0.2497, -0.3901, 0.0181, 0.0976],
2086                    [0.4487, 0.2629, 2.2021, -0.7489, 0.1199],
2087                ],
2088                [
2089                    [0.7837, -0.0174, -0.3873, -0.4722, -0.2462],
2090                    [-0.6268, -0.9438, 0.6666, -0.6545, 1.0409],
2091                ],
2092            ],
2093            device,
2094        )?
2095        .to_dtype(dtype)?;
2096
2097        let best_tags = crf.decode(&emissions, None)?;
2098        println!("best_tags: {:?}", best_tags);
2099
2100        let crf_bf = make_crf(
2101            5,
2102            true,
2103            Some(crf.start_transitions),
2104            Some(crf.end_transitions),
2105            Some(crf.transitions),
2106            device,
2107        )?;
2108
2109        let emissions = emissions.transpose(0, 1)?;
2110        let best_tags_bf = crf_bf.decode(&emissions, None)?;
2111        println!("best_tags_bf: {:?}", best_tags_bf);
2112
2113        assert_eq!(best_tags, best_tags_bf);
2114        assert_eq!(best_tags, vec![vec![1, 3, 0], vec![1, 2, 4]]);
2115        Ok(())
2116    }
2117
2118    #[test]
2119    fn test_decode_batch_first() -> Result<()> {
2120        #[cfg(any(feature = "cuda", feature = "metal"))]
2121        let device = use_gpu(true)?;
2122        #[cfg(not(any(feature = "cuda", feature = "metal")))]
2123        let device = use_gpu(false)?;
2124
2125        for dtype in OK_TYPES {
2126            decode_batch_first(dtype, &device)?;
2127        }
2128        Ok(())
2129    }
2130
2131    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L427
2132    #[test]
2133    fn test_decode_emissions_has_bad_number_of_dimension() -> Result<()> {
2134        #[cfg(any(feature = "cuda", feature = "metal"))]
2135        let device = use_gpu(true)?;
2136        #[cfg(not(any(feature = "cuda", feature = "metal")))]
2137        let device = use_gpu(false)?;
2138
2139        let emissions = Tensor::randn(0.0_f32, 1., (1, 2), &device)?;
2140        let crf = make_crf(5, false, None, None, None, &device)?;
2141        let result = crf.decode(&emissions, None);
2142        println!("{:?}", result);
2143        assert!(result.is_err());
2144        assert_eq!(
2145            result.err().unwrap().to_string(),
2146            "emissions must have 3 dimensions, got 2"
2147        );
2148        Ok(())
2149    }
2150
2151    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L435
2152    #[test]
2153    fn test_decode_emissions_last_dimension_not_equal_to_number_of_tags() -> Result<()> {
2154        #[cfg(any(feature = "cuda", feature = "metal"))]
2155        let device = use_gpu(true)?;
2156        #[cfg(not(any(feature = "cuda", feature = "metal")))]
2157        let device = use_gpu(false)?;
2158
2159        let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
2160        let crf = make_crf(10, false, None, None, None, &device)?;
2161        let result = crf.decode(&emissions, None);
2162        println!("{:?}", result);
2163        assert!(result.is_err());
2164        assert_eq!(
2165            result.err().unwrap().to_string(),
2166            "expected last dimension of emissions is 10, got 3"
2167        );
2168        Ok(())
2169    }
2170
2171    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L443
2172    #[test]
2173    fn test_decode_emissions_and_mask_size_mismatch() -> Result<()> {
2174        #[cfg(any(feature = "cuda", feature = "metal"))]
2175        let device = use_gpu(true)?;
2176        #[cfg(not(any(feature = "cuda", feature = "metal")))]
2177        let device = use_gpu(false)?;
2178
2179        let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
2180        let mask = Tensor::new(&[[1_u8, 1], [1, 0]], &device)?;
2181        let crf = make_crf(3, false, None, None, None, &device)?;
2182        let result = crf.decode(&emissions, Some(&mask));
2183        println!("{:?}", result);
2184        assert!(result.is_err());
2185        assert_eq!(
2186            result.err().unwrap().to_string(),
2187            "the first two dimensions of emissions and mask must match, got (1, 2) and (2, 2)"
2188        );
2189        Ok(())
2190    }
2191
2192    /// https://github.com/kmkurn/pytorch-crf/blob/8f3203a1f1d7984c87718bfe31853242670258db/tests/test_crf.py#L454
2193    #[test]
2194    fn test_decode_first_timestep_mask_is_not_all_on() -> Result<()> {
2195        #[cfg(any(feature = "cuda", feature = "metal"))]
2196        let device = use_gpu(true)?;
2197        #[cfg(not(any(feature = "cuda", feature = "metal")))]
2198        let device = use_gpu(false)?;
2199
2200        let emissions = Tensor::randn(0.0_f32, 1., (3, 2, 4), &device)?;
2201        let mask = Tensor::new(&[[1_u8, 1, 1], [0, 0, 0]], &device)?.transpose(0, 1)?;
2202
2203        let crf = make_crf(4, false, None, None, None, &device)?;
2204        let result = crf.decode(&emissions, Some(&mask));
2205        println!("{:?}", result);
2206        assert!(result.is_err());
2207        assert_eq!(
2208            result.err().unwrap().to_string(),
2209            "mask of the first timestep must all be on"
2210        );
2211
2212        let emissions = emissions.transpose(0, 1)?;
2213        let mask = mask.transpose(0, 1)?;
2214        let crf = make_crf(4, true, None, None, None, &device)?;
2215        let result = crf.decode(&emissions, Some(&mask));
2216        println!("{:?}", result);
2217        assert!(result.is_err());
2218        assert_eq!(
2219            result.err().unwrap().to_string(),
2220            "mask of the first timestep must all be on"
2221        );
2222        Ok(())
2223    }
2224}