1use candle_core::{shape::Dim, DType, Device, Error, IndexOp, Result, Tensor, D};
2use candle_nn::{Init, VarBuilder};
3use std::fmt::Display;
4
/// Selects how the per-sequence log-likelihoods computed by [`CRF::forward`]
/// are combined into the returned tensor.
///
/// The enum is a plain tag, so it is `Copy` and comparable; callers can store
/// one value and reuse it across several `forward` calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Reduction {
    /// Return the log-likelihoods unreduced.
    None,
    /// Sum all log-likelihoods into a single value.
    Sum,
    /// Average all log-likelihoods into a single value.
    Mean,
    /// Sum all log-likelihoods and divide by the sum of the mask.
    TokenMean,
}

impl Default for Reduction {
    /// The default reduction is [`Reduction::Sum`].
    fn default() -> Self {
        Self::Sum
    }
}
19
20pub fn crf(num_tags: usize, batch_first: bool, vb: VarBuilder) -> Result<CRF> {
21 let start_transitions = vb.get_with_hints(
22 num_tags,
23 "start_transitions",
24 Init::Uniform {
25 lo: -0.1_f64,
26 up: 1.0_f64,
27 },
28 )?;
29
30 let end_transitions = vb.get_with_hints(
31 num_tags,
32 "end_transitions",
33 Init::Uniform {
34 lo: -0.1_f64,
35 up: 1.0_f64,
36 },
37 )?;
38
39 let transitions = vb.get_with_hints(
40 (num_tags, num_tags),
41 "transitions",
42 Init::Uniform {
43 lo: -0.1_f64,
44 up: 1.0_f64,
45 },
46 )?;
47
48 Ok(CRF {
49 num_tags,
50 batch_first,
51 start_transitions,
52 end_transitions,
53 transitions,
54 })
55}
56
57pub struct CRF {
60 pub(crate) num_tags: usize,
61 pub(crate) batch_first: bool,
62
63 pub(crate) start_transitions: Tensor,
64 pub(crate) end_transitions: Tensor,
65 pub(crate) transitions: Tensor,
66}
67
68impl Display for CRF {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(
73 f,
74 "CRF(num_tags: {}, batch_first: {})",
75 self.num_tags, self.batch_first
76 )
77 }
78}
79
80impl CRF {
81 pub fn new(num_tags: usize, batch_first: bool, device: &Device) -> Result<Self> {
84 Self::new_with_dtype(num_tags, batch_first, DType::F32, device)
85 }
86
87 pub fn new_with_dtype(
88 num_tags: usize,
89 batch_first: bool,
90 dtype: DType,
91 device: &Device,
92 ) -> Result<Self> {
93 {
94 use DType::*;
95 match dtype {
96 #[cfg(feature = "metal")]
97 F32 => {}
98 #[cfg(feature = "cuda")]
99 F32 | F64 => {}
100 #[cfg(not(any(feature = "cuda", feature = "metal")))]
101 BF16 | F16 | F32 | F64 => {}
102 _ => return Err(Error::UnsupportedDTypeForOp(dtype, "unsupported dtype")),
103 }
104 }
105
106 if num_tags == 0 {
107 return Err(Error::Msg("num_tags must be greater than 0".to_string()));
108 }
109
110 let start_transitions = Tensor::rand(-0.1_f32, 1.0, num_tags, device)?.to_dtype(dtype)?;
111 let end_transitions = Tensor::rand(-0.1_f32, 1.0, num_tags, device)?.to_dtype(dtype)?;
112 let transitions =
113 Tensor::rand(-0.1_f32, 1.0, (num_tags, num_tags), device)?.to_dtype(dtype)?;
114
115 Ok(Self {
116 num_tags,
117 batch_first,
118 start_transitions,
119 end_transitions,
120 transitions,
121 })
122 }
123
124 pub fn load(num_tags: usize, batch_first: bool, vb: VarBuilder) -> Result<Self> {
125 crf(num_tags, batch_first, vb)
126 }
127
128 fn validate(
131 &self,
132 emissions: &Tensor,
133 tags: Option<&Tensor>,
134 mask: Option<&Tensor>,
135 ) -> Result<()> {
136 {
137 let dtype_transitions = self.transitions.dtype();
138 let dtype_emissions = emissions.dtype();
139 if dtype_transitions != dtype_emissions {
140 return Err(Error::Msg(format!(
141 "emissions and CRF must have the same dtype, expected {:?}, got {:?}",
142 dtype_transitions, dtype_emissions
143 )));
144 }
145 }
146
147 {
148 let dims = emissions.dims().len();
150 if dims != 3 {
151 return Err(Error::Msg(format!(
152 "emissions must have 3 dimensions, got {}",
153 dims
154 )));
155 }
156 }
157
158 let (d1, d2, d3) = emissions.dims3()?;
159
160 if d3 != self.num_tags {
161 return Err(Error::Msg(format!(
163 "expected last dimension of emissions is {}, got {}",
164 self.num_tags, d3
165 )));
166 }
167
168 if let Some(tags) = tags {
169 #[cfg(feature = "metal")]
170 if tags.dtype() != DType::U32 {
171 return Err(Error::Msg("tags must be of type u32".to_string()));
172 }
173
174 #[cfg(not(feature = "metal"))]
175 if tags.dtype() != DType::I64 {
176 return Err(Error::Msg("tags must be of type i64".to_string()));
177 }
178
179 if tags.dims().len() != 2 {
180 return Err(Error::Msg(format!(
182 "tags must have 2 dimensions, got {}",
183 tags.dims().len()
184 )));
185 }
186
187 let (tag_d1, tag_d2) = tags.dims2()?;
188 if (d1, d2) != (tag_d1, tag_d2) {
189 return Err(Error::Msg(format!(
190 "the first two dimensions of emissions and tags must match, got ({}, {}) and ({}, {})",
191 d1, d2, d1, d2
192 )));
193 }
194 }
195
196 if let Some(mask) = mask {
197 if mask.dtype() != DType::U8 {
198 return Err(Error::Msg("mask must be of type u8".to_string()));
199 }
200
201 if mask.dims().len() != 2 {
202 return Err(Error::Msg(format!(
204 "mask must have 2 dimensions, got {}",
205 mask.dims().len()
206 )));
207 }
208
209 let (mask_d1, mask_d2) = mask.dims2()?;
210 if (d1, d2) != (mask_d1, mask_d2) {
211 return Err(Error::Msg(format!(
212 "the first two dimensions of emissions and mask must match, got ({}, {}) and ({}, {})",
213 d1, d2, mask_d1, mask_d2
214 )));
215 }
216
217 let no_empty_seq = !self.batch_first && all(&mask.i(0)?)?;
218 let no_empty_seq_bf = self.batch_first && all(&mask.i((.., 0))?)?;
219
220 if !no_empty_seq && !no_empty_seq_bf {
221 return Err(Error::Msg(
222 "mask of the first timestep must all be on".to_string(),
223 ));
224 }
225 }
226
227 Ok(())
228 }
229
230 fn compute_score(&self, emissions: &Tensor, tags: &Tensor, mask: &Tensor) -> Result<Tensor> {
233 let (d1, d2, d3) = emissions.dims3()?;
234 let (seq_length, batch_size) = tags.dims2()?;
235 assert_eq!(d1, seq_length);
236 assert_eq!(d2, batch_size);
237 assert_eq!(d3, self.num_tags);
238 assert_eq!(mask.shape(), tags.shape());
239 assert!(all(&mask.i(0)?)?);
240
241 let mask = mask.to_dtype(emissions.dtype())?;
242
243 let mut score = self.start_transitions.i(&tags.i(0)?)?;
244
245 let z = gather(&emissions.i((0, 0..batch_size))?, &tags.i(0)?)?;
246
247 score = score.broadcast_add(&z)?;
248
249 for i in 1..seq_length {
250 let z = gather(&self.transitions.i(&tags.i(i - 1)?)?, &tags.i(i)?)?;
251 score = score.broadcast_add(&z.broadcast_mul(&mask.i(i)?)?)?;
252
253 let z = gather(&emissions.i((i, 0..batch_size))?, &tags.i(i)?)?;
254 score = score.broadcast_add(&z.broadcast_mul(&mask.i(i)?)?)?;
255 }
256
257 let seq_ends = mask
258 .to_dtype(DType::I64)?
259 .sum(0)?
260 .broadcast_sub(&Tensor::ones(1, DType::I64, mask.device())?)?;
261
262 #[cfg(feature = "metal")]
263 let last_tags = {
264 let tags2 = tags.to_dtype(DType::F32)?;
265 gather(
266 &tags2.i(&seq_ends)?,
267 &Tensor::arange(0, batch_size as u32, mask.device())?,
268 )?
269 };
270
271 #[cfg(not(feature = "metal"))]
272 let last_tags = gather(
273 &tags.i(&seq_ends)?,
274 &Tensor::arange(0, batch_size as i64, mask.device())?,
275 )?;
276
277 #[cfg(feature = "metal")]
278 let last_tags = last_tags.to_dtype(DType::U32)?;
279
280 score.broadcast_add(&self.end_transitions.i(&last_tags)?)
281 }
282
283 fn compute_normalizer(&self, emissions: &Tensor, mask: &Tensor) -> Result<Tensor> {
286 let (d1, d2, d3) = emissions.dims3()?;
287 let (seq_length, batch_size) = mask.dims2()?;
288 assert_eq!(d1, seq_length);
289 assert_eq!(d2, batch_size);
290 assert_eq!(d3, self.num_tags);
291 assert!(all(&mask.i(0)?)?);
292
293 let mut score = self.start_transitions.broadcast_add(&emissions.i(0)?)?;
294
295 for i in 1..seq_length {
296 let broadcast_score = score.unsqueeze(2)?;
297
298 let broadcast_emissions = emissions.i(i)?.unsqueeze(1)?;
299 let next_score = broadcast_score
300 .broadcast_add(&self.transitions)?
301 .broadcast_add(&broadcast_emissions)?;
302
303 let next_score = next_score.log_sum_exp(1)?;
304 let z = mask.i(i)?.unsqueeze(1)?.broadcast_as(next_score.shape())?;
305 score = z.where_cond(&next_score, &score)?;
306 }
307
308 score = score.broadcast_add(&self.end_transitions)?;
309 score.log_sum_exp(1)
310 }
311
312 fn viterbi_decode(&self, emissions: &Tensor, mask: &Tensor) -> Result<Vec<Vec<u32>>> {
315 let (d1, d2, d3) = emissions.dims3()?;
316 let (seq_length, batch_size) = mask.dims2()?;
317 assert_eq!(d1, seq_length);
318 assert_eq!(d2, batch_size);
319 assert_eq!(d3, self.num_tags);
320 assert!(all(&mask.i(0)?)?);
321
322 let mut score = self.start_transitions.broadcast_add(&emissions.i(0)?)?;
323
324 let mut history = Vec::with_capacity(seq_length);
325 for i in 1..seq_length {
326 let broadcast_sore = score.unsqueeze(2)?;
327
328 let broadcast_emission = emissions.i(i)?.unsqueeze(1)?;
329
330 let next_score = broadcast_sore
331 .broadcast_add(&self.transitions)?
332 .broadcast_add(&broadcast_emission)?;
333
334 let (next_score, indices) = max_indices(&next_score, 1)?;
335
336 let z = mask.i(i)?.unsqueeze(1)?.broadcast_as(next_score.shape())?;
337 score = z.where_cond(&next_score, &score)?;
338 history.push(indices);
339 }
340
341 score = score.broadcast_add(&self.end_transitions)?;
342
343 let seq_ends = mask
344 .to_dtype(DType::I64)?
345 .sum(0)?
346 .broadcast_sub(&Tensor::ones(1, DType::I64, mask.device())?)?;
347
348 let mut best_tags_list = vec![];
349
350 for idx in 0..batch_size {
351 let best_last_tag = score.i(idx)?.argmax(0)?;
352
353 let mut best_tags = vec![best_last_tag.to_scalar::<u32>()?];
354
355 let z = seq_ends.i(idx)?.to_scalar::<i64>()? as usize;
356 let mut a = history[..z].to_vec();
357 a.reverse();
358 for hist in a.iter() {
359 let last_idx = *best_tags.last().unwrap() as usize;
360 let best_last_tag = hist.i(idx)?.i(last_idx)?;
361 best_tags.push(best_last_tag.to_scalar::<u32>()?);
362 }
363
364 best_tags.reverse();
365 best_tags_list.push(best_tags);
366 }
367
368 Ok(best_tags_list)
369 }
370
371 pub fn decode(&self, emissions: &Tensor, mask: Option<&Tensor>) -> Result<Vec<Vec<u32>>> {
374 self.validate(emissions, None, mask)?;
375 let mask = if let Some(mask) = mask {
376 mask.clone()
377 } else {
378 let (d1, d2, _) = emissions.dims3()?;
379 Tensor::ones((d1, d2), DType::U8, emissions.device())?
380 };
381
382 let (emissions, mask) = if self.batch_first {
383 (emissions.transpose(0, 1)?, mask.transpose(0, 1)?)
384 } else {
385 (emissions.clone(), mask.clone())
386 };
387 self.viterbi_decode(&emissions, &mask)
388 }
389
390 pub fn forward(
393 &self,
394 emissions: &Tensor,
395 tags: &Tensor,
396 mask: Option<&Tensor>,
397 reduction: Reduction,
398 ) -> Result<Tensor> {
399 self.validate(emissions, Some(tags), mask)?;
400 let mask = if let Some(mask) = mask {
401 mask.clone()
402 } else {
403 Tensor::ones_like(tags)?.to_dtype(DType::U8)?
404 };
405
406 let (emissions, tags, mask) = if self.batch_first {
407 (
408 emissions.transpose(0, 1)?,
409 tags.transpose(0, 1)?,
410 mask.transpose(0, 1)?,
411 )
412 } else {
413 (emissions.clone(), tags.clone(), mask.clone())
414 };
415
416 let numerator = self.compute_score(&emissions, &tags, &mask)?;
417 let denominator = self.compute_normalizer(&emissions, &mask)?;
418
419 let llh = numerator.broadcast_sub(&denominator)?;
420
421 match reduction {
422 Reduction::Sum => llh.sum_all(),
423 Reduction::Mean => llh.mean_all(),
424 Reduction::TokenMean => {
425 let mask = mask.to_dtype(llh.dtype())?;
426 let z = mask.sum_all()?;
427 llh.sum_all()?.broadcast_div(&z)
428 }
429 Reduction::None => Ok(llh),
430 }
431 }
432}
433
434pub(crate) fn all(x: &Tensor) -> Result<bool> {
437 let zero = x.zeros_like()?;
438 Ok(x.broadcast_ne(&zero)?.min_all()?.to_scalar::<u8>()? != 0)
439}
440
441pub(crate) fn gather(src: &Tensor, idx: &Tensor) -> Result<Tensor> {
444 let index = idx.reshape((idx.dim(0)?, 1))?;
445 src.gather(&index, D::Minus1)?.squeeze(D::Minus1)
446}
447
448pub(crate) fn max_indices<D: Dim + Copy>(x: &Tensor, dim: D) -> Result<(Tensor, Tensor)> {
451 let max = x.max(dim)?;
452 let idx = x.argmax(dim)?;
453 Ok((max, idx))
454}
455
456#[cfg(test)]
459mod tests {
460
461 use super::*;
472 use anyhow::Result;
473 use candle_core::{utils, DType, Device, IndexOp, Tensor};
474 use itertools::Itertools;
475
// Dtypes that `CRF::new_with_dtype` is expected to accept, per backend feature.
// NOTE(review): the `metal` and `cuda` cfgs are not mutually exclusive, so
// enabling both features would define each constant twice — confirm that the
// two features are never enabled together.
#[cfg(feature = "metal")]
const OK_TYPES: [DType; 1] = [DType::F32];
#[cfg(feature = "cuda")]
const OK_TYPES: [DType; 2] = [DType::F32, DType::F64];
#[cfg(not(any(feature = "cuda", feature = "metal")))]
const OK_TYPES: [DType; 4] = [DType::F32, DType::F64, DType::F16, DType::BF16];

// Dtypes that `CRF::new_with_dtype` is expected to reject, per backend feature.
#[cfg(feature = "metal")]
const FAIL_TYPES: [DType; 6] = [
    DType::U8,
    DType::U32,
    DType::I64,
    DType::F16,
    DType::BF16,
    DType::F64,
];
#[cfg(feature = "cuda")]
const FAIL_TYPES: [DType; 5] = [DType::U8, DType::U32, DType::I64, DType::F16, DType::BF16];
#[cfg(not(any(feature = "cuda", feature = "metal")))]
const FAIL_TYPES: [DType; 3] = [DType::U8, DType::U32, DType::I64];
496
// Selects which comparison tolerance `epsilon` returns.
enum DTypeCase {
    // Tolerance chosen by the dtype of the tensors being compared.
    DType(DType),
    // Tolerance used when comparing against reference values from pytorch-crf.
    PyTorchCRF,
}
501
502 #[cfg(feature = "metal")]
503 fn epsilon(type_case: DTypeCase) -> f32 {
504 match type_case {
505 DTypeCase::DType(dtype) => match dtype {
506 DType::F64 => 1e-6,
507 DType::F32 => 1e-4,
508 DType::F16 => 1e-2,
509 DType::BF16 => 1e-1,
510 _ => panic!("dtype not supported"),
511 },
512 DTypeCase::PyTorchCRF => 1e-3,
513 }
514 }
515
516 #[cfg(not(feature = "metal"))]
517 fn epsilon(type_case: DTypeCase) -> f64 {
518 match type_case {
519 DTypeCase::DType(dtype) => match dtype {
520 DType::F64 => 1e-6,
521 DType::F32 => 1e-4,
522 DType::F16 => 1e-2,
523 DType::BF16 => 1e-1,
524 _ => panic!("dtype not supported"),
525 },
526 DTypeCase::PyTorchCRF => 1e-3,
527 }
528 }
529
530 #[cfg(feature = "metal")]
531 fn assert_tensor_close(a: &Tensor, b: &Tensor, epsilon: f32) -> Result<()> {
532 assert!(a.dtype() == b.dtype());
533 assert_eq!(a.shape(), b.shape());
534 let epsilon = Tensor::full(epsilon, a.shape(), a.device())?.to_dtype(a.dtype())?;
535 let diff = a.broadcast_sub(b)?.abs()?;
536 let result = all(&diff.broadcast_le(&epsilon)?)?;
537 assert!(result);
538 Ok(())
539 }
540
541 #[cfg(not(feature = "metal"))]
542 fn assert_tensor_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
543 assert!(a.dtype() == b.dtype());
544 assert_eq!(a.shape(), b.shape());
545 let epsilon = Tensor::full(epsilon, a.shape(), a.device())?.to_dtype(a.dtype())?;
546 let diff = a.broadcast_sub(b)?.abs()?;
547 let result = all(&diff.broadcast_le(&epsilon)?)?;
548 assert!(result);
549 Ok(())
550 }
551
552 #[cfg(feature = "metal")]
553 fn cartestian_product(r: Vec<u32>, repeat: usize, dev: &Device) -> Result<Vec<Tensor>> {
554 use itertools::Itertools;
555
556 if repeat <= 1 {
557 return Ok(vec![Tensor::new(r.as_slice(), dev)?]);
558 }
559
560 let mut a: Vec<Vec<u32>> = r
561 .iter()
562 .cartesian_product(r.iter())
563 .map(|(&x, &y)| vec![x, y])
564 .collect();
565 for _ in 2..repeat {
566 a = a
567 .iter()
568 .cartesian_product(r.iter())
569 .map(|(x, &y)| {
570 let mut z = Vec::from(x.to_owned());
571 z.push(y);
572 z
573 })
574 .collect();
575 }
576 Ok(a.iter()
577 .map(|x| Tensor::new(x.as_slice(), dev).unwrap())
578 .collect())
579 }
580
581 #[cfg(not(feature = "metal"))]
582 fn cartestian_product(r: Vec<i64>, repeat: usize, dev: &Device) -> Result<Vec<Tensor>> {
583 use itertools::Itertools;
584
585 if repeat <= 1 {
586 return Ok(vec![Tensor::new(r.as_slice(), dev)?]);
587 }
588
589 let mut a: Vec<Vec<i64>> = r
590 .iter()
591 .cartesian_product(r.iter())
592 .map(|(&x, &y)| vec![x, y])
593 .collect();
594 for _ in 2..repeat {
595 a = a
596 .iter()
597 .cartesian_product(r.iter())
598 .map(|(x, &y)| {
599 let mut z = Vec::from(x.to_owned());
600 z.push(y);
601 z
602 })
603 .collect();
604 }
605 Ok(a.iter()
606 .map(|x| Tensor::new(x.as_slice(), dev).unwrap())
607 .collect())
608 }
609
610 fn cat_scalar_tensor(tensors: Vec<Tensor>) -> candle_core::Result<Tensor> {
611 let tensors: Vec<Tensor> = tensors
612 .into_iter()
613 .map(|t| t.unsqueeze(0).unwrap())
614 .collect();
615 Tensor::cat(&tensors, 0)
616 }
617
618 fn use_gpu(gpu: bool) -> candle_core::Result<Device> {
619 if gpu {
620 if utils::cuda_is_available() {
621 println!("CUDA is available");
622 Device::new_cuda(0)
623 } else if utils::metal_is_available() {
624 println!("Metal is available");
625 Device::new_metal(0)
626 } else {
627 println!("CUDA and Metal are not available, using CPU");
628 Ok(Device::Cpu)
629 }
630 } else {
631 println!("Using CPU");
632 Ok(Device::Cpu)
633 }
634 }
635
// `cat_scalar_tensor` turns ten scalar tensors into one vector of ten values.
#[test]
fn test_cat_scalar_tensor() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    let mut scalars = Vec::new();
    let mut value = 0;
    while value < 10 {
        scalars.push(Tensor::full(value as f32, (), &device)?);
        value += 1;
    }

    let joined = cat_scalar_tensor(scalars)?;
    let expected = vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.];
    assert_eq!(joined.to_vec1::<f32>().unwrap(), expected);
    Ok(())
}
656
657 fn make_crf(
659 num_tags: usize,
660 batch_first: bool,
661 start: Option<Tensor>,
662 end: Option<Tensor>,
663 transition: Option<Tensor>,
664 device: &Device,
665 ) -> candle_core::Result<CRF> {
666 let mut crf = CRF::new(num_tags, batch_first, device)?;
667 if let Some(start) = start {
668 crf.start_transitions = start
669 }
670 if let Some(end) = end {
671 crf.end_transitions = end
672 }
673 if let Some(transition) = transition {
674 crf.transitions = transition
675 }
676 Ok(crf)
677 }
678
679 fn assert_all<D: candle_core::WithDType>(x: &Tensor, lo: D, up: D) -> Result<bool> {
681 assert_eq!(x.dims().len(), 1);
682 let dim = x.dims1()?;
683 for i in 0..dim {
684 let a = x.i(i)?.to_scalar::<D>()?;
685 if a < lo || a > up {
686 return Ok(false);
687 }
688 }
689 Ok(true)
690 }
691
692 fn compute_score(crf: &CRF, emission: &Tensor, tag: &Tensor) -> Result<Tensor> {
694 assert_eq!(emission.dims().len(), 2);
695 let (emission_dim1, emission_dim2) = emission.dims2()?;
696 let tag_dim1 = tag.dims1()?;
697 assert_eq!(emission_dim1, tag_dim1);
698 assert_eq!(emission_dim2, crf.num_tags);
699
700 #[cfg(feature = "metal")]
701 assert_all(tag, 0_u32, crf.num_tags as u32 - 1)?;
702 #[cfg(not(feature = "metal"))]
703 assert_all(tag, 0_i64, crf.num_tags as i64 - 1)?;
704
705 #[cfg(feature = "metal")]
706 let tag_vec = tag.to_vec1::<u32>()?;
707 #[cfg(not(feature = "metal"))]
708 let tag_vec = tag.to_vec1::<i64>()?;
709
710 let mut score = crf
711 .start_transitions
712 .i(tag_vec[0] as usize)?
713 .broadcast_add(&crf.end_transitions.i(tag_vec[tag_vec.len() - 1] as usize)?)?;
714
715 for (cur_tag, next_tag) in tag_vec.iter().zip(tag_vec.iter().skip(1)) {
716 let z = crf.transitions.i((*cur_tag as usize, *next_tag as usize))?;
717 score = score.broadcast_add(&z)?;
718 }
719
720 for (i, &t) in tag_vec.iter().enumerate() {
721 let z = emission.i((i, t as usize))?;
722 score = score.broadcast_add(&z)?;
723 }
724
725 Ok(score)
726 }
727
// Construction succeeds for every dtype in OK_TYPES and fails for every dtype
// in FAIL_TYPES.
#[test]
fn test_init_with_dtype() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    OK_TYPES.iter().for_each(|&dtype| {
        let built = CRF::new_with_dtype(10, false, dtype, &device);
        assert!(built.is_ok());
    });

    FAIL_TYPES.iter().for_each(|&dtype| {
        let built = CRF::new_with_dtype(10, false, dtype, &device);
        assert!(built.is_err());
    });

    Ok(())
}
744
// A freshly built CRF stores its configuration and has correctly shaped parameters.
#[test]
fn test_init_minial() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    let num_tags = 10;
    let crf = CRF::new(num_tags, false, &device)?;

    assert_eq!(crf.num_tags, num_tags);
    assert!(!crf.batch_first);

    let start_len = crf.start_transitions.dims1()?;
    assert_eq!(start_len, num_tags);
    let end_len = crf.end_transitions.dims1()?;
    assert_eq!(end_len, num_tags);
    let transition_shape = crf.transitions.dims2()?;
    assert_eq!(transition_shape, (num_tags, num_tags));

    println!("crf:{}", crf);
    Ok(())
}
763
// The `batch_first` flag passed to the constructor is stored on the CRF.
#[test]
fn test_init_full() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    let batch_first = CRF::new(10, true, &device)?.batch_first;
    assert!(batch_first);
    Ok(())
}
776
// Building a CRF with zero tags is rejected.
#[test]
fn test_init_nonpositive_num_tags() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    assert!(CRF::new(0, false, &device).is_err());

    Ok(())
}
790
791 fn forward_works_with_mask(dtype: DType, device: &Device) -> Result<()> {
793 let crf = make_crf(
794 5,
795 false,
796 Some(
797 Tensor::new(&[-0.0687_f32, 0.0698, -0.0447, 0.0421, 0.0782], device)?
798 .to_dtype(dtype)?,
799 ),
800 Some(
801 Tensor::new(&[0.0061_f32, -0.0671, -0.0797, 0.0629, -0.0136], device)?
802 .to_dtype(dtype)?,
803 ),
804 Some(
805 Tensor::new(
806 &[
807 [0.0489_f32, -0.0002, 0.0619, 0.0458, 0.0662],
808 [0.0707, 0.0297, -0.0422, 0.0831, -0.0038],
809 [0.0439, 0.0178, -0.0754, 0.0260, 0.0681],
810 [0.0191, 0.0755, 0.0230, 0.0209, -0.0768],
811 [0.0303, 0.0592, -0.0297, 0.0681, 0.0801],
812 ],
813 device,
814 )?
815 .to_dtype(dtype)?,
816 ),
817 device,
818 )?;
819
820 let emissions = Tensor::new(
821 &[
822 [
823 [1.1699_f32, 1.1900, -0.7254, 0.1490, -1.4910],
824 [-1.2101, 0.4538, 1.3654, 0.0135, -1.8480],
825 ],
826 [
827 [0.5861, -0.1651, 0.9721, 0.4464, -0.5512],
828 [-1.2701, -1.5360, 0.0037, 0.5853, -0.9926],
829 ],
830 [
831 [-1.7625, 0.5437, 1.6322, -1.1274, -0.1313],
832 [-0.9301, 0.8906, -2.6483, 0.5849, -1.1069],
833 ],
834 ],
835 device,
836 )?
837 .to_dtype(dtype)?;
838
839 #[cfg(feature = "metal")]
840 let tags = Tensor::new(&[[2_u32, 4], [3, 3], [4, 2]], device)?;
841 #[cfg(not(feature = "metal"))]
842 let tags = Tensor::new(&[[2_i64, 4], [3, 3], [4, 2]], device)?;
843 let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
844 let llh = crf.forward(&emissions, &tags, Some(&mask), Reduction::default())?;
845 println!("llh: {:?}", llh);
846
847 let emissions = emissions.transpose(0, 1)?;
848 let tags = tags.transpose(0, 1)?;
849 let mask = mask.transpose(0, 1)?;
850
851 let (a, _, _) = emissions.dims3()?;
852 let mut manual_llh = Tensor::zeros(llh.shape(), llh.dtype(), llh.device())?;
853
854 for i in 0..a {
855 let emission = emissions.i(i)?;
856 let tag = tags.i(i)?;
857 let mask = mask.i(i)?;
858
859 let seq_len = mask.sum_all().unwrap().to_scalar::<u8>()? as usize;
860 let emission = emission.i(..seq_len)?;
861 let tag = tag.i(..seq_len)?;
862 let numerator = compute_score(&crf, &emission, &tag)?;
863
864 #[cfg(feature = "metal")]
865 let num_tags = crf.num_tags as u32;
866 #[cfg(not(feature = "metal"))]
867 let num_tags = crf.num_tags as i64;
868
869 let product = cartestian_product((0..num_tags).collect_vec(), seq_len, device)?;
870 let all_scores = product
871 .iter()
872 .map(|t| compute_score(&crf, &emission, &t).unwrap());
873
874 let mut denominator =
875 Tensor::zeros(numerator.shape(), numerator.dtype(), numerator.device())?;
876
877 for s in all_scores.into_iter() {
878 denominator = denominator.broadcast_add(&s.exp()?)?;
879 }
880 let denominator = denominator.log()?;
881
882 manual_llh = manual_llh.broadcast_add(&numerator.broadcast_sub(&denominator)?)?;
883 }
884 println!("manual_llh: {:?}", manual_llh);
885 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(llh.dtype())))?;
886
887 if llh.dtype() == DType::F32 {
888 let manual_llh = Tensor::full(-11.0540_f32, llh.shape(), llh.device())?;
889 println!("Compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
890 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
891 }
892 llh.backward()?;
893 Ok(())
894 }
895
// Runs the masked forward check for every supported dtype.
#[test]
fn test_forward_works_with_mask() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    OK_TYPES
        .iter()
        .try_for_each(|&dtype| forward_works_with_mask(dtype, &device))
}
908
909 fn forward_works_without_mask(dtype: DType, device: &Device) -> Result<()> {
911 let crf = make_crf(
912 5,
913 false,
914 Some(
915 Tensor::new(&[0.0266_f32, -0.0539, 0.0572, -0.0199, -0.0167], device)?
916 .to_dtype(dtype)?,
917 ),
918 Some(
919 Tensor::new(&[0.0084_f32, 0.0892, 0.0942, -0.0179, 0.0112], device)?
920 .to_dtype(dtype)?,
921 ),
922 Some(
923 Tensor::new(
924 &[
925 [0.0456_f32, 0.0560, 0.0396, 0.0289, 0.0187],
926 [-0.0951, -0.0286, 0.0582, 0.0384, 0.0863],
927 [-0.0137, 0.0764, -0.0414, 0.0722, -0.0287],
928 [0.0365, -0.0033, 0.0726, -0.0620, 0.0516],
929 [0.0925, -0.0708, 0.0765, 0.0671, -0.0344],
930 ],
931 device,
932 )?
933 .to_dtype(dtype)?,
934 ),
935 device,
936 )?;
937
938 let emissions = Tensor::new(
939 &[
940 [
941 [0.5463_f32, 2.0856, -0.6247, -1.0225, 0.5277],
942 [-0.4172, -1.4281, -0.5658, -0.5217, -0.6321],
943 ],
944 [
945 [0.4759, -0.8485, 1.0046, 0.0720, 0.3853],
946 [-0.7525, 0.1041, 0.2371, 0.5746, -0.5599],
947 ],
948 [
949 [-0.5022, -0.2030, 0.3655, 0.0714, 1.2449],
950 [0.1266, 0.6654, -1.1915, -0.1181, 0.0167],
951 ],
952 ],
953 device,
954 )?
955 .to_dtype(dtype)?;
956
957 #[cfg(feature = "metal")]
958 let tags = Tensor::new(&[[3_u32, 2], [3, 1], [4, 3]], device)?;
959 #[cfg(not(feature = "metal"))]
960 let tags = Tensor::new(&[[3_i64, 2], [3, 1], [4, 3]], device)?;
961
962 let llh_no_mask = crf.forward(&emissions, &tags, None, Reduction::default())?;
963
964 let llh_mask = crf.forward(
965 &emissions,
966 &tags,
967 Some(&Tensor::ones_like(&tags)?.to_dtype(DType::U8)?),
968 Reduction::default(),
969 )?;
970
971 println!("llh_no_mask: {:?}", llh_no_mask);
972 println!("llh_mask: {:?}", llh_mask);
973
974 if llh_no_mask.dtype() == DType::F32 {
975 let manual_llh = Tensor::full(-11.0571_f32, llh_no_mask.shape(), llh_no_mask.device())?;
976 println!(
977 "compare with pytorch-crf: {:?}, {:?}",
978 llh_no_mask, manual_llh
979 );
980 assert_tensor_close(&llh_no_mask, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
981 }
982
983 assert_tensor_close(
984 &llh_no_mask,
985 &llh_mask,
986 epsilon(DTypeCase::DType(llh_no_mask.dtype())),
987 )
988 }
989
// Runs the unmasked forward check for every supported dtype.
#[test]
fn test_forward_works_without_mask() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    OK_TYPES
        .iter()
        .try_for_each(|&dtype| forward_works_without_mask(dtype, &device))
}
1002
1003 fn forward_batched_loss(dtype: DType, device: &Device) -> Result<()> {
1005 let batch_size = 10;
1006 let crf = make_crf(
1007 5,
1008 false,
1009 Some(
1010 Tensor::new(&[-0.0695_f32, -0.0117, -0.0021, -0.0635, -0.0328], device)?
1011 .to_dtype(dtype)?,
1012 ),
1013 Some(
1014 Tensor::new(&[-0.0524_f32, -0.0827, 0.0868, -0.0140, 0.0131], device)?
1015 .to_dtype(dtype)?,
1016 ),
1017 Some(
1018 Tensor::new(
1019 &[
1020 [-0.0330_f32, 0.0894, -0.0732, 0.0996, 0.0014],
1021 [-0.0514, -0.0677, -0.0611, -0.0168, 0.0297],
1022 [0.0580, -0.0224, -0.0465, -0.0527, -0.0133],
1023 [0.0506, 0.0535, -0.0378, -0.0537, 0.0516],
1024 [-0.0037, 0.0763, -0.0867, 0.0410, 0.0368],
1025 ],
1026 device,
1027 )?
1028 .to_dtype(dtype)?,
1029 ),
1030 device,
1031 )?;
1032
1033 let emissions = Tensor::new(
1034 &[
1035 [
1036 [
1037 -1.5867e+00_f32,
1038 -4.0363e-01,
1039 1.7869e-02,
1040 -5.0247e-01,
1041 8.2934e-01,
1042 ],
1043 [-8.5983e-01, -4.0548e-01, 1.3188e-01, 9.8255e-01, 2.2588e-01],
1044 [
1045 -6.8282e-01,
1046 1.8752e+00,
1047 -3.4774e-01,
1048 -1.0902e+00,
1049 1.7499e-01,
1050 ],
1051 [3.7244e-01, 1.1534e+00, 7.7696e-01, 3.4387e-01, -9.8422e-01],
1052 [3.8489e-02, 8.2353e-01, -8.2190e-01, 8.6248e-02, 1.2238e-01],
1053 [4.4424e-02, 1.3664e+00, -1.3658e+00, -2.4691e-01, 1.1135e+00],
1054 [1.2708e+00, 2.9114e-01, 1.0744e+00, 5.3505e-02, -1.5511e-01],
1055 [
1056 1.3976e+00,
1057 -1.1226e+00,
1058 -9.2870e-01,
1059 1.1908e-01,
1060 -1.6336e+00,
1061 ],
1062 [6.0694e-01, 2.5764e-01, -6.8925e-01, 1.1807e+00, -6.5968e-01],
1063 [3.5677e-01, -1.4314e+00, 9.4358e-01, 7.9112e-01, -2.1923e-01],
1064 ],
1065 [
1066 [1.3654e+00, -2.3797e-01, 6.2540e-02, 1.5489e+00, -2.0502e+00],
1067 [1.3639e+00, -5.9433e-01, 7.0876e-01, -4.9674e-01, 6.9055e-02],
1068 [
1069 2.3545e-01,
1070 -4.0388e-01,
1071 1.2455e+00,
1072 -4.1925e-01,
1073 -3.9647e-01,
1074 ],
1075 [
1076 -6.3912e-01,
1077 -1.4287e+00,
1078 -2.2617e+00,
1079 -5.6802e-01,
1080 7.4044e-01,
1081 ],
1082 [3.8845e-01, -9.7110e-01, -5.7113e-01, 1.3628e+00, 7.4219e-01],
1083 [
1084 -7.7064e-01,
1085 9.3300e-01,
1086 -1.4319e+00,
1087 -1.5991e+00,
1088 2.6631e-01,
1089 ],
1090 [
1091 1.7472e+00,
1092 -5.9296e-01,
1093 -1.3249e-03,
1094 1.4543e-01,
1095 -6.5364e-01,
1096 ],
1097 [
1098 -3.5911e-01,
1099 4.4189e-02,
1100 -1.2928e+00,
1101 -1.1482e+00,
1102 1.2672e+00,
1103 ],
1104 [1.3452e+00, -2.3875e+00, 1.4895e+00, -7.3329e-01, 2.1750e-01],
1105 [
1106 -3.9819e-01,
1107 4.5757e-01,
1108 -5.0534e-01,
1109 -3.0911e+00,
1110 -1.1324e+00,
1111 ],
1112 ],
1113 [
1114 [
1115 -3.4185e-01,
1116 -1.0406e+00,
1117 -4.3079e-01,
1118 -4.5273e-02,
1119 1.1170e+00,
1120 ],
1121 [
1122 -8.5589e-01,
1123 9.4792e-01,
1124 -8.8419e-01,
1125 -7.7756e-01,
1126 -1.7976e-01,
1127 ],
1128 [-1.8891e-01, 1.7120e-01, -4.3634e-01, 1.2762e+00, 1.0334e+00],
1129 [
1130 2.7852e-01,
1131 -1.5482e+00,
1132 5.6432e-01,
1133 -1.1859e+00,
1134 -7.0821e-02,
1135 ],
1136 [3.4364e-01, 1.2222e+00, 1.0542e+00, -1.7861e-01, 6.4608e-01],
1137 [-8.4590e-01, 1.4749e+00, 3.7927e-01, 2.2527e+00, -3.5637e-02],
1138 [
1139 4.5344e-01,
1140 -1.4359e+00,
1141 -2.2955e+00,
1142 -9.4110e-01,
1143 -8.5992e-01,
1144 ],
1145 [6.8505e-01, -1.5822e-01, -6.9359e-01, 5.9559e-02, 6.8955e-01],
1146 [
1147 -3.4006e-01,
1148 1.7685e+00,
1149 2.1671e-01,
1150 -8.6512e-01,
1151 -2.6517e-01,
1152 ],
1153 [1.0503e-01, 1.6486e+00, -2.4486e-01, 5.4843e-01, 1.9252e+00],
1154 ],
1155 ],
1156 device,
1157 )?
1158 .to_dtype(dtype)?;
1159
1160 #[cfg(feature = "metal")]
1161 let tags = Tensor::new(
1162 &[
1163 [1_u32, 0, 0, 1, 2, 1, 0, 4, 3, 0],
1164 [3, 2, 3, 4, 4, 1, 2, 1, 4, 1],
1165 [0, 1, 4, 4, 0, 4, 0, 1, 4, 1],
1166 ],
1167 device,
1168 )?;
1169 #[cfg(not(feature = "metal"))]
1170 let tags = Tensor::new(
1171 &[
1172 [1_i64, 0, 0, 1, 2, 1, 0, 4, 3, 0],
1173 [3, 2, 3, 4, 4, 1, 2, 1, 4, 1],
1174 [0, 1, 4, 4, 0, 4, 0, 1, 4, 1],
1175 ],
1176 device,
1177 )?;
1178
1179 let llh = crf.forward(&emissions, &tags, None, Reduction::default())?;
1180
1181 llh.dims0()?;
1182 let mut total_llh = Tensor::zeros(llh.shape(), llh.dtype(), llh.device())?;
1183 for i in 0..batch_size {
1184 let emissions = emissions
1185 .i((.., i, ..))?
1186 .contiguous()? .unsqueeze(1)?;
1188
1189 let tags = tags
1190 .i((.., i))?
1191 .contiguous()? .unsqueeze(1)?;
1193
1194 total_llh = total_llh.broadcast_add(&crf.forward(
1195 &emissions,
1196 &tags,
1197 None,
1198 Reduction::default(),
1199 )?)?;
1200 }
1201
1202 println!("llh: {:?}", llh);
1203 println!("total_llh: {:?}", total_llh);
1204
1205 if llh.dtype() == DType::F32 {
1206 let manual_llh = Tensor::full(-49.2024_f32, llh.shape(), llh.device())?;
1207 println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
1208 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
1209 }
1210
1211 assert_tensor_close(&llh, &total_llh, epsilon(DTypeCase::DType(llh.dtype())))
1212 }
1213
// Runs the batched-loss check for every supported dtype.
#[test]
fn test_forward_batched_loss() -> Result<()> {
    // Use an accelerator only when the crate was built with one.
    let device = use_gpu(cfg!(any(feature = "cuda", feature = "metal")))?;

    OK_TYPES
        .iter()
        .try_for_each(|&dtype| forward_batched_loss(dtype, &device))
}
1226
1227 fn forward_reduction_none(dtype: DType, device: &Device) -> Result<()> {
1229 let crf = make_crf(
1230 5,
1231 false,
1232 Some(
1233 Tensor::new(&[0.0432_f32, 0.0507, -0.0286, 0.0476, -0.0603], device)?
1234 .to_dtype(dtype)?,
1235 ),
1236 Some(
1237 Tensor::new(&[0.0824_f32, 0.0845, -0.0180, -0.0773, 0.0414], device)?
1238 .to_dtype(dtype)?,
1239 ),
1240 Some(
1241 Tensor::new(
1242 &[
1243 [-0.0894_f32, 0.0512, 0.0066, 0.0534, -0.0182],
1244 [0.0043, 0.0328, -0.0805, -0.0945, 0.0495],
1245 [-0.0020, 0.0416, -0.0441, 0.0390, 0.0690],
1246 [-0.0260, 0.0720, 0.0017, -0.0552, -0.0470],
1247 [0.0104, 0.0299, -0.0182, -0.0515, 0.0424],
1248 ],
1249 device,
1250 )?
1251 .to_dtype(dtype)?,
1252 ),
1253 device,
1254 )?;
1255
1256 let emissions = Tensor::new(
1257 &[
1258 [
1259 [-0.2560_f32, 1.2459, -2.0063, -0.5449, 1.0978],
1260 [0.7233, -0.6967, 0.3394, 0.7784, -3.0362],
1261 ],
1262 [
1263 [-1.3406, -1.1565, 0.0870, 1.8249, 1.3740],
1264 [-1.3396, -1.0208, 0.6608, -0.5917, 1.3850],
1265 ],
1266 [
1267 [1.1436, 0.4477, 0.6606, 1.5938, -0.1054],
1268 [-0.5401, 1.1908, -1.7266, -0.5858, -1.4395],
1269 ],
1270 ],
1271 device,
1272 )?
1273 .to_dtype(dtype)?;
1274
1275 #[cfg(feature = "metal")]
1276 let tags = Tensor::new(&[[1_u32, 3], [1, 3], [1, 2]], device)?;
1277 #[cfg(not(feature = "metal"))]
1278 let tags = Tensor::new(&[[1_i64, 3], [1, 3], [1, 2]], device)?;
1279
1280 let llh = crf.forward(&emissions, &tags, None, Reduction::None)?;
1281 println!("llh: {:?}", llh);
1282 let (seq_length, batch_size) = tags.dims2()?;
1283 assert_eq!(llh.dims1()?, batch_size);
1284
1285 let emissions = emissions.transpose(0, 1)?;
1286 let tags = tags.transpose(0, 1)?;
1287
1288 let (a, _, _) = emissions.dims3()?;
1289 let mut manual_llh = vec![];
1290 for i in 0..a {
1291 let emission = emissions.i(i)?;
1292 let tag = tags.i(i)?;
1293
1294 let numerator = compute_score(&crf, &emission, &tag)?;
1295
1296 #[cfg(feature = "metal")]
1297 let num_tags = crf.num_tags as u32;
1298 #[cfg(not(feature = "metal"))]
1299 let num_tags = crf.num_tags as i64;
1300 let product = cartestian_product((0..num_tags).collect_vec(), seq_length, device)?;
1301
1302 let all_scores = product
1303 .iter()
1304 .map(|t| compute_score(&crf, &emission, &t).unwrap());
1305
1306 let mut denominator = numerator.zeros_like()?;
1307
1308 for s in all_scores.into_iter() {
1309 denominator = denominator.broadcast_add(&s.exp()?)?;
1310 }
1311
1312 let denominator = denominator.log()?;
1313 manual_llh.push((numerator - denominator)?);
1314 }
1315
1316 let manual_llh = cat_scalar_tensor(manual_llh)?;
1317 println!("manual_llh: {:?}", manual_llh);
1318 if dtype == DType::F32 {
1319 let manual_llh = Tensor::new(&[-6.3064_f32, -7.0368], device)?;
1320 println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
1321 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
1322 }
1323 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(llh.dtype())))
1324 }
1325
1326 #[test]
1327 fn test_forward_reduction_none() -> Result<()> {
1328 #[cfg(any(feature = "cuda", feature = "metal"))]
1329 let device = use_gpu(true)?;
1330 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1331 let device = use_gpu(false)?;
1332
1333 for dtype in OK_TYPES {
1334 forward_reduction_none(dtype, &device)?;
1335 }
1336 Ok(())
1337 }
1338
1339 fn forward_reduction_mean(dtype: DType, device: &Device) -> Result<()> {
1341 let crf = make_crf(
1342 5,
1343 false,
1344 Some(
1345 Tensor::new(&[0.0606_f32, -0.0597, 0.0217, -0.0760, 0.0096], device)?
1346 .to_dtype(dtype)?,
1347 ),
1348 Some(
1349 Tensor::new(&[-0.0791_f32, -0.0159, 0.0525, 0.0451, -0.0373], device)?
1350 .to_dtype(dtype)?,
1351 ),
1352 Some(
1353 Tensor::new(
1354 &[
1355 [
1356 5.0599e-02_f32,
1357 -1.4571e-02,
1358 2.2383e-02,
1359 3.3254e-02,
1360 2.5206e-03,
1361 ],
1362 [6.6520e-02, 7.3251e-02, 1.0225e-02, -9.4751e-02, -3.4146e-02],
1363 [
1364 -6.7073e-02,
1365 2.9719e-02,
1366 -8.5645e-02,
1367 4.6357e-02,
1368 -7.2483e-03,
1369 ],
1370 [
1371 4.4980e-02,
1372 -8.0436e-02,
1373 6.4611e-05,
1374 -5.1731e-02,
1375 -8.2973e-02,
1376 ],
1377 [-5.0593e-02, 4.5717e-03, 6.8714e-03, 8.9858e-02, -8.2813e-02],
1378 ],
1379 device,
1380 )?
1381 .to_dtype(dtype)?,
1382 ),
1383 device,
1384 )?;
1385
1386 let emissions = Tensor::new(
1387 &[
1388 [
1389 [0.0535_f32, 0.6821, -0.2587, 1.2250, 0.5327],
1390 [-2.5028, 0.5942, -0.2508, 0.0597, 1.3800],
1391 ],
1392 [
1393 [-0.0640, -1.3170, 0.6408, -0.1368, -0.2137],
1394 [-0.3985, 0.0530, -0.0448, 0.8268, 0.7622],
1395 ],
1396 [
1397 [1.4061, -0.4045, -0.3174, 0.0737, -1.8753],
1398 [-1.0892, -0.8641, 0.4778, -0.4032, 0.2838],
1399 ],
1400 ],
1401 device,
1402 )?
1403 .to_dtype(dtype)?;
1404
1405 #[cfg(feature = "metal")]
1406 let tags = Tensor::new(&[[3_u32, 0], [0, 3], [2, 4]], device)?;
1407 #[cfg(not(feature = "metal"))]
1408 let tags = Tensor::new(&[[3_i64, 0], [0, 3], [2, 4]], device)?;
1409
1410 let llh = crf.forward(&emissions, &tags, None, Reduction::Mean)?;
1411 println!("llh: {:?}", llh);
1412
1413 let (seq_length, batch_size) = tags.dims2().unwrap();
1414
1415 let emissions = emissions.transpose(0, 1).unwrap();
1416 let tags = tags.transpose(0, 1).unwrap();
1417
1418 let (a, _, _) = emissions.dims3().unwrap();
1419 let mut manual_llh = llh.zeros_like()?;
1420 for i in 0..a {
1421 let emission = emissions.i(i).unwrap();
1422 let tag = tags.i(i).unwrap();
1423
1424 let numerator = compute_score(&crf, &emission, &tag)?;
1425
1426 #[cfg(feature = "metal")]
1427 let num_tags = crf.num_tags as u32;
1428 #[cfg(not(feature = "metal"))]
1429 let num_tags = crf.num_tags as i64;
1430
1431 let product =
1432 cartestian_product((0..num_tags).collect_vec(), seq_length, &Device::Cpu)?;
1433
1434 let all_scores = product
1435 .iter()
1436 .map(|t| compute_score(&crf, &emission, &t).unwrap());
1437
1438 let mut denominator = numerator.zeros_like()?;
1439
1440 for s in all_scores.into_iter() {
1441 denominator = denominator.broadcast_add(&s.exp()?)?;
1442 }
1443
1444 let denominator = denominator.log()?;
1445
1446 manual_llh = (manual_llh + (numerator - denominator)?)?;
1447 }
1448
1449 #[cfg(feature = "metal")]
1450 let manual_llh = (&manual_llh
1451 / Tensor::full(batch_size as f32, manual_llh.shape(), manual_llh.device())?
1452 .to_dtype(manual_llh.dtype())?)?;
1453
1454 #[cfg(not(feature = "metal"))]
1455 let manual_llh = (&manual_llh
1456 / Tensor::full(batch_size as f64, manual_llh.shape(), manual_llh.device())?
1457 .to_dtype(manual_llh.dtype())?)?;
1458
1459 println!("manual_llh: {:?}", manual_llh);
1460
1461 if dtype == DType::F32 {
1462 let manual_llh = Tensor::new(-5.7756_f32, &llh.device())?;
1463 println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
1464 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
1465 }
1466 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(dtype)))
1467 }
1468
1469 #[test]
1470 fn test_forward_reduction_mean() -> Result<()> {
1471 #[cfg(any(feature = "cuda", feature = "metal"))]
1472 let device = use_gpu(true)?;
1473 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1474 let device = use_gpu(false)?;
1475
1476 for dtype in OK_TYPES {
1477 forward_reduction_mean(dtype, &device)?;
1478 }
1479 Ok(())
1480 }
1481
1482 fn forward_token_mean(dtype: DType, device: &Device) -> Result<()> {
1484 let crf = make_crf(
1485 5,
1486 false,
1487 Some(
1488 Tensor::new(&[0.0687_f32, 0.0533, 0.0204, 0.0250, -0.0785], device)?
1489 .to_dtype(dtype)?,
1490 ),
1491 Some(
1492 Tensor::new(
1493 &[
1494 4.8827e-02_f32,
1495 -9.9134e-05,
1496 9.3184e-02,
1497 -7.6271e-02,
1498 3.6482e-02,
1499 ],
1500 device,
1501 )?
1502 .to_dtype(dtype)?,
1503 ),
1504 Some(
1505 Tensor::new(
1506 &[
1507 [0.0173_f32, -0.0058, -0.0699, -0.0374, 0.0797],
1508 [-0.0405, 0.0141, -0.0002, 0.0790, 0.0205],
1509 [-0.0473, 0.0554, -0.0036, 0.0878, 0.0210],
1510 [0.0761, -0.0406, -0.0905, 0.0590, -0.0030],
1511 [0.0613, 0.0871, -0.0343, 0.0384, 0.0485],
1512 ],
1513 device,
1514 )?
1515 .to_dtype(dtype)?,
1516 ),
1517 device,
1518 )?;
1519
1520 let emissions = Tensor::new(
1521 &[
1522 [
1523 [0.0110_f32, -0.8502, 0.9678, -0.3219, -0.6029],
1524 [1.0804, -1.2822, 1.4129, 0.9475, -2.6282],
1525 ],
1526 [
1527 [0.8993, 0.3029, -0.0686, -0.3108, 0.6216],
1528 [-2.1503, 1.4301, -0.0301, 0.3572, 0.5460],
1529 ],
1530 [
1531 [1.3384, 0.8500, 0.0194, -0.6371, 0.1516],
1532 [-0.7357, 0.3116, 1.5733, -0.8246, -0.4224],
1533 ],
1534 ],
1535 device,
1536 )?
1537 .to_dtype(dtype)?;
1538
1539 #[cfg(feature = "metal")]
1540 let tags = Tensor::new(&[[2_u32, 2], [0, 4], [3, 3]], device)?;
1541 #[cfg(not(feature = "metal"))]
1542 let tags = Tensor::new(&[[2_i64, 2], [0, 4], [3, 3]], device)?;
1543
1544 let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
1545 let llh = crf.forward(&emissions, &tags, Some(&mask), Reduction::TokenMean)?;
1546 println!("llh: {:?}", llh);
1547
1548 let emissions = emissions.transpose(0, 1)?;
1549 let tags = tags.transpose(0, 1)?;
1550 let mask = mask.transpose(0, 1)?;
1551
1552 let (a, _, _) = emissions.dims3()?;
1553 let mut manual_llh = llh.zeros_like()?;
1554 let mut total_tokens = 0;
1555 for i in 0..a {
1556 let emission = emissions.i(i)?;
1557 let tag = tags.i(i)?;
1558 let mask = mask.i(i)?;
1559
1560 let seq_len = mask.sum_all()?.to_scalar::<u8>()? as usize;
1561 let emission = emission.i(..seq_len)?;
1562 let tag = tag.i(..seq_len)?;
1563 let numerator = compute_score(&crf, &emission, &tag)?;
1564
1565 #[cfg(feature = "metal")]
1566 let num_tags = crf.num_tags as u32;
1567 #[cfg(not(feature = "metal"))]
1568 let num_tags = crf.num_tags as i64;
1569
1570 let product = cartestian_product((0..num_tags).collect_vec(), seq_len, device)?;
1571
1572 let all_scores = product
1573 .iter()
1574 .map(|t| compute_score(&crf, &emission, &t).unwrap());
1575
1576 let mut denominator = numerator.zeros_like()?;
1577 for t in all_scores.into_iter() {
1578 denominator = denominator.broadcast_add(&t.exp()?)?;
1579 }
1580
1581 let denominator = denominator.log()?;
1582
1583 manual_llh = (manual_llh + (numerator - denominator)?)?;
1584 total_tokens += seq_len;
1585 }
1586
1587 #[cfg(feature = "metal")]
1588 let total_tokens =
1589 Tensor::full(total_tokens as f32, manual_llh.shape(), manual_llh.device())?
1590 .to_dtype(manual_llh.dtype())?;
1591 #[cfg(not(feature = "metal"))]
1592 let total_tokens =
1593 Tensor::full(total_tokens as f64, manual_llh.shape(), manual_llh.device())?
1594 .to_dtype(manual_llh.dtype())?;
1595
1596 let manual_llh = (manual_llh / total_tokens)?;
1597 println!("manual_llh: {:?}", manual_llh);
1598
1599 if dtype == DType::F32 {
1600 let manual_llh = Tensor::new(-1.4603_f32, &llh.device())?;
1601 println!("compare with pytorch-crf: {:?}, {:?}", llh, manual_llh);
1602 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::PyTorchCRF))?;
1603 }
1604 assert_tensor_close(&llh, &manual_llh, epsilon(DTypeCase::DType(dtype)))?;
1605 llh.backward()?;
1606 Ok(())
1607 }
1608
1609 #[test]
1610 fn test_forward_token_mean() -> Result<()> {
1611 #[cfg(any(feature = "cuda", feature = "metal"))]
1612 let device = use_gpu(true)?;
1613 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1614 let device = use_gpu(false)?;
1615
1616 for dtype in OK_TYPES {
1617 forward_token_mean(dtype, &device)?;
1618 }
1619 Ok(())
1620 }
1621
1622 fn forward_batch_first(dtype: DType, device: &Device) -> Result<()> {
1624 let crf = make_crf(
1625 5,
1626 false,
1627 Some(
1628 Tensor::new(&[0.0384_f32, -0.0811, -0.0291, 0.0444, 0.0943], device)?
1629 .to_dtype(dtype)?,
1630 ),
1631 Some(
1632 Tensor::new(&[0.0146_f32, 0.0455, 0.0991, 0.0640, -0.0298], device)?
1633 .to_dtype(dtype)?,
1634 ),
1635 Some(
1636 Tensor::new(
1637 &[
1638 [0.0063_f32, 0.0014, 0.0804, -0.0385, -0.0485],
1639 [0.0485, -0.0963, 0.0799, 0.0198, -0.0549],
1640 [0.0016, 0.0012, -0.0411, 0.0540, -0.0823],
1641 [0.0111, 0.0320, 0.0769, 0.0292, -0.0707],
1642 [-0.0990, -0.0971, 0.0635, 0.0166, 0.0292],
1643 ],
1644 device,
1645 )?
1646 .to_dtype(dtype)?,
1647 ),
1648 device,
1649 )?;
1650
1651 let emissions = Tensor::new(
1652 &[
1653 [
1654 [-1.1338_f32, -0.9228, 0.3260, 0.0327, -1.0345],
1655 [0.1106, 0.8005, 0.3860, -0.1214, -1.8224],
1656 ],
1657 [
1658 [-1.3724, -2.2578, -1.8705, -0.1109, 0.3845],
1659 [-0.4223, 0.8414, -1.4423, -1.2734, 0.5193],
1660 ],
1661 [
1662 [0.4189, -1.4048, -1.6877, 1.0891, 0.6978],
1663 [-0.2521, -1.4185, -0.6026, 1.6335, 1.0366],
1664 ],
1665 ],
1666 device,
1667 )?
1668 .to_dtype(dtype)?;
1669
1670 #[cfg(feature = "metal")]
1671 let tags = Tensor::new(&[[1_u32, 4], [1, 4], [1, 4]], device)?;
1672 #[cfg(not(feature = "metal"))]
1673 let tags = Tensor::new(&[[1_i64, 4], [1, 4], [1, 4]], device)?;
1674
1675 let llh = crf.forward(&emissions, &tags, None, Reduction::default())?;
1676 println!("llh: {:?}", llh);
1677
1678 let crf_bf = make_crf(
1679 5,
1680 true,
1681 Some(crf.start_transitions),
1682 Some(crf.end_transitions),
1683 Some(crf.transitions),
1684 device,
1685 )?;
1686
1687 let emissions = emissions.transpose(0, 1).unwrap();
1688 let tags = tags.transpose(0, 1).unwrap();
1689
1690 let llh_bf = crf_bf.forward(&emissions, &tags, None, Reduction::default())?;
1691 println!("llh_bf: {:?}", llh_bf);
1692
1693 if dtype == DType::F32 {
1694 let llh_bf = Tensor::new(-14.8640_f32, &llh.device())?;
1695 println!("compare with pytorch-crf: {:?}, {:?}", llh, llh_bf);
1696 assert_tensor_close(&llh, &llh_bf, epsilon(DTypeCase::PyTorchCRF))?;
1697 }
1698 assert_tensor_close(&llh, &llh_bf, epsilon(DTypeCase::DType(dtype)))
1699 }
1700
1701 #[test]
1702 fn test_forward_batch_first() -> Result<()> {
1703 #[cfg(any(feature = "cuda", feature = "metal"))]
1704 let device = use_gpu(true)?;
1705 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1706 let device = use_gpu(false)?;
1707
1708 for dtype in OK_TYPES {
1709 forward_batch_first(dtype, &device)?;
1710 }
1711 Ok(())
1712 }
1713
1714 #[test]
1716 fn test_forward_emissions_has_bad_number_of_dimension() -> Result<()> {
1717 #[cfg(any(feature = "cuda", feature = "metal"))]
1718 let device = use_gpu(true)?;
1719 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1720 let device = use_gpu(false)?;
1721
1722 let emissions = Tensor::randn(0.0_f32, 1., (1, 2), &device)?;
1723
1724 #[cfg(feature = "metal")]
1725 let tags = Tensor::zeros((2, 2), DType::U32, &device)?;
1726 #[cfg(not(feature = "metal"))]
1727 let tags = Tensor::zeros((2, 2), DType::I64, &device)?;
1728
1729 let crf = make_crf(5, false, None, None, None, &device)?;
1730 let result = crf.forward(&emissions, &tags, None, Reduction::default());
1731 println!("{:?}", result);
1732 assert!(result.is_err());
1733 assert_eq!(
1734 result.err().unwrap().to_string(),
1735 "emissions must have 3 dimensions, got 2"
1736 );
1737 Ok(())
1738 }
1739
1740 #[test]
1742 fn test_forward_emissions_and_tags_size_mismatch() -> Result<()> {
1743 #[cfg(any(feature = "cuda", feature = "metal"))]
1744 let device = use_gpu(true)?;
1745 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1746 let device = use_gpu(false)?;
1747
1748 let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
1749 #[cfg(feature = "metal")]
1750 let tags = Tensor::zeros((2, 2), DType::U32, &device)?;
1751 #[cfg(not(feature = "metal"))]
1752 let tags = Tensor::zeros((2, 2), DType::I64, &device)?;
1753
1754 let crf = make_crf(3, false, None, None, None, &device)?;
1755 let result = crf.forward(&emissions, &tags, None, Reduction::default());
1756 println!("{:?}", result);
1757 assert!(result.is_err());
1758 assert_eq!(
1759 result.err().unwrap().to_string(),
1760 "the first two dimensions of emissions and tags must match, got (1, 2) and (1, 2)"
1761 );
1762 Ok(())
1763 }
1764
1765 #[test]
1767 fn test_forward_emissions_last_dimension_not_equal_to_number_of_tags() -> Result<()> {
1768 #[cfg(any(feature = "cuda", feature = "metal"))]
1769 let device = use_gpu(true)?;
1770 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1771 let device = use_gpu(false)?;
1772
1773 let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
1774 #[cfg(feature = "metal")]
1775 let tags = Tensor::zeros((1, 2), DType::U32, &device)?;
1776 #[cfg(not(feature = "metal"))]
1777 let tags = Tensor::zeros((1, 2), DType::I64, &device)?;
1778
1779 let crf = make_crf(10, false, None, None, None, &device)?;
1780 let result = crf.forward(&emissions, &tags, None, Reduction::default());
1781 println!("{:?}", result);
1782 assert!(result.is_err());
1783 assert_eq!(
1784 result.err().unwrap().to_string(),
1785 "expected last dimension of emissions is 10, got 3"
1786 );
1787 Ok(())
1788 }
1789
1790 #[test]
1792 fn test_forward_first_timestep_mask_is_not_all_on() -> Result<()> {
1793 #[cfg(any(feature = "cuda", feature = "metal"))]
1794 let device = use_gpu(true)?;
1795 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1796 let device = use_gpu(false)?;
1797
1798 let emissions = Tensor::randn(0.0_f32, 1., (3, 2, 4), &device)?;
1799 #[cfg(feature = "metal")]
1800 let tags = Tensor::zeros((3, 2), DType::U32, &device)?;
1801 #[cfg(not(feature = "metal"))]
1802 let tags = Tensor::zeros((3, 2), DType::I64, &device)?;
1803
1804 let mask = Tensor::new(&[[1_u8, 1, 1], [0, 0, 0]], &device)
1805 .unwrap()
1806 .transpose(0, 1)?;
1807 let crf = make_crf(4, false, None, None, None, &device)?;
1808
1809 let result = crf.forward(&emissions, &tags, Some(&mask), Reduction::default());
1810 println!("{:?}", result);
1811 assert!(result.is_err());
1812 assert_eq!(
1813 result.err().unwrap().to_string(),
1814 "mask of the first timestep must all be on"
1815 );
1816
1817 let emissions = emissions.transpose(0, 1)?;
1818 let tags = tags.transpose(0, 1)?;
1819 let mask = mask.transpose(0, 1)?;
1820 let crf = make_crf(4, true, None, None, None, &device)?;
1821
1822 let result = crf.forward(&emissions, &tags, Some(&mask), Reduction::default());
1823 println!("{:?}", result);
1824 assert!(result.is_err());
1825 assert_eq!(
1826 result.err().unwrap().to_string(),
1827 "mask of the first timestep must all be on"
1828 );
1829 Ok(())
1830 }
1831
1832 fn decode_works_with_mask(dtype: DType, device: &Device) -> Result<()> {
1834 let crf = make_crf(
1835 5,
1836 false,
1837 Some(
1838 Tensor::new(&[0.0548_f32, -0.0239, -0.0291, -0.0208, 0.0665], device)?
1839 .to_dtype(dtype)?,
1840 ),
1841 Some(
1842 Tensor::new(&[-0.0612_f32, -0.0615, -0.0557, 0.0672, 0.0470], device)?
1843 .to_dtype(dtype)?,
1844 ),
1845 Some(
1846 Tensor::new(
1847 &[
1848 [-0.0751_f32, -0.0941, 0.0248, 0.0900, -0.0776],
1849 [0.0381, -0.0550, -0.0333, -0.0124, -0.0356],
1850 [-0.0383, -0.0910, 0.0914, -0.0330, -0.0119],
1851 [0.0358, 0.0513, 0.0013, -0.0380, 0.0626],
1852 [-0.0168, 0.0871, 0.0489, 0.0019, -0.0548],
1853 ],
1854 device,
1855 )?
1856 .to_dtype(dtype)?,
1857 ),
1858 device,
1859 )?;
1860
1861 let emissions = Tensor::new(
1862 &[
1863 [
1864 [1.8238_f32, 1.3041, -0.0845, 1.3981, 0.1027],
1865 [1.1092, -0.1616, 1.9770, -1.6850, -1.4289],
1866 ],
1867 [
1868 [0.2831, 0.0936, -1.1957, 0.2637, -0.8048],
1869 [0.4553, -0.0393, 2.3307, -0.3505, -2.3531],
1870 ],
1871 [
1872 [1.6232, 0.2230, 0.3585, -0.7957, -0.2464],
1873 [-0.3805, 0.3646, -1.0142, -1.2563, -0.6568],
1874 ],
1875 ],
1876 device,
1877 )?
1878 .to_dtype(dtype)?;
1879
1880 let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
1881 let best_tags = crf.decode(&emissions, Some(&mask))?;
1882 println!("best_tags: {:?}", best_tags);
1883 assert_eq!(best_tags, vec![vec![0, 3, 0], vec![2, 2]]);
1884 Ok(())
1885 }
1886
1887 #[test]
1888 fn test_decode_works_with_mask() -> Result<()> {
1889 #[cfg(any(feature = "cuda", feature = "metal"))]
1890 let device = use_gpu(true)?;
1891 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1892 let device = use_gpu(false)?;
1893
1894 for dtype in OK_TYPES {
1895 decode_works_with_mask(dtype, &device)?;
1896 }
1897 Ok(())
1898 }
1899
1900 fn decode_works_without_mask(dtype: DType, device: &Device) -> Result<()> {
1902 let crf = make_crf(
1903 5,
1904 false,
1905 Some(
1906 Tensor::new(&[0.0762_f32, 0.0743, 0.0234, -0.0387, -0.0269], device)?
1907 .to_dtype(dtype)?,
1908 ),
1909 Some(
1910 Tensor::new(&[0.0102_f32, -0.0137, -0.0149, 0.0700, -0.0701], device)?
1911 .to_dtype(dtype)?,
1912 ),
1913 Some(
1914 Tensor::new(
1915 &[
1916 [-0.0620_f32, -0.0527, 0.0034, 0.0694, -0.0853],
1917 [0.0922, -0.0613, -0.0592, 0.0482, 0.0632],
1918 [-0.0433, 0.0069, -0.0161, -0.0330, -0.0602],
1919 [-0.0649, 0.0047, 0.0593, 0.0733, 0.0203],
1920 [0.0997, 0.0007, 0.0938, 0.0427, 0.0823],
1921 ],
1922 device,
1923 )?
1924 .to_dtype(dtype)?,
1925 ),
1926 device,
1927 )?;
1928
1929 let emissions = Tensor::new(
1930 &[
1931 [
1932 [0.8913_f32, -0.0355, -1.4378, 0.8390, -0.7296],
1933 [1.5530, -1.3165, -0.5769, -0.8085, -0.2610],
1934 ],
1935 [
1936 [-0.9622, -0.3234, -0.5353, -0.4424, -0.1456],
1937 [-0.3844, 0.2524, 1.9393, 0.1217, -1.2519],
1938 ],
1939 [
1940 [-0.1619, -0.2520, 1.9566, 0.4863, 1.5627],
1941 [-0.3999, 1.4914, 1.0620, -0.6408, -0.3032],
1942 ],
1943 ],
1944 device,
1945 )?
1946 .to_dtype(dtype)?;
1947
1948 let best_tags_no_mask = crf.decode(&emissions, None)?;
1949 println!("best_tags: {:?}", best_tags_no_mask);
1950 assert_eq!(best_tags_no_mask, vec![vec![0, 4, 2], vec![0, 2, 1]]);
1951 Ok(())
1952 }
1953
1954 #[test]
1955 fn test_decode_works_without_mask() -> Result<()> {
1956 #[cfg(any(feature = "cuda", feature = "metal"))]
1957 let device = use_gpu(true)?;
1958 #[cfg(not(any(feature = "cuda", feature = "metal")))]
1959 let device = use_gpu(false)?;
1960
1961 for dtype in OK_TYPES {
1962 decode_works_without_mask(dtype, &device)?;
1963 }
1964 Ok(())
1965 }
1966
1967 fn decode_batched_decode(dtype: DType, device: &Device) -> Result<()> {
1969 let crf = make_crf(
1970 5,
1971 false,
1972 Some(
1973 Tensor::new(&[-0.0489_f32, 0.0460, -0.0924, -0.0722, 0.0736], device)?
1974 .to_dtype(dtype)?,
1975 ),
1976 Some(
1977 Tensor::new(&[0.0843_f32, 0.0344, -0.0996, 0.0944, 0.0622], device)?
1978 .to_dtype(dtype)?,
1979 ),
1980 Some(
1981 Tensor::new(
1982 &[
1983 [0.0780_f32, -0.0794, 0.0208, 0.0039, 0.0080],
1984 [-0.0923, -0.0359, 0.0103, 0.0550, -0.0029],
1985 [0.0628, -0.0787, -0.0256, 0.0554, -0.0969],
1986 [0.0655, -0.0055, 0.0718, -0.0275, -0.0994],
1987 [-0.0492, -0.0953, 0.0862, 0.0580, 0.0422],
1988 ],
1989 device,
1990 )?
1991 .to_dtype(dtype)?,
1992 ),
1993 device,
1994 )?;
1995
1996 let emissions = Tensor::new(
1997 &[
1998 [
1999 [0.7720_f32, 0.9488, 0.6672, 1.8839, -0.6844],
2000 [1.6192, 0.2733, 0.8063, -0.0377, -2.3208],
2001 ],
2002 [
2003 [-0.4374, -1.4631, -0.1330, -0.2155, 1.6044],
2004 [0.7017, -1.1525, -1.0692, 0.3463, 0.9816],
2005 ],
2006 [
2007 [-1.3011, 0.5237, -1.1700, -0.9017, -0.5747],
2008 [-1.0040, 0.7791, -0.3735, 0.8300, 1.5138],
2009 ],
2010 ],
2011 device,
2012 )?
2013 .to_dtype(dtype)?;
2014
2015 let mask = Tensor::new(&[[1_u8, 1, 1], [1, 1, 0]], device)?.transpose(0, 1)?;
2016
2017 let batched = crf.decode(&emissions, Some(&mask))?;
2018 println!("batched: {:?}", batched);
2019
2020 let batch_size = 2;
2021 let mut non_batched = vec![];
2022 for i in 0..batch_size {
2023 let emissions = emissions.i((.., i, ..))?.unsqueeze(1)?.contiguous()?;
2024 let mask = mask.i((.., i))?.unsqueeze(1)?.contiguous()?;
2025
2026 let result = crf.decode(&emissions, Some(&mask))?;
2027 non_batched.push(result[0].clone());
2028 }
2029 println!("non_batched: {:?}", non_batched);
2030
2031 assert_eq!(batched, non_batched);
2032 assert_eq!(batched, vec![vec![3, 4, 1], vec![0, 4]]);
2033 Ok(())
2034 }
2035
2036 #[test]
2037 fn test_decode_batched_decode() -> Result<()> {
2038 #[cfg(any(feature = "cuda", feature = "metal"))]
2039 let device = use_gpu(true)?;
2040 #[cfg(not(any(feature = "cuda", feature = "metal")))]
2041 let device = use_gpu(false)?;
2042
2043 for dtype in OK_TYPES {
2044 decode_batched_decode(dtype, &device)?;
2045 }
2046 Ok(())
2047 }
2048
2049 fn decode_batch_first(dtype: DType, device: &Device) -> Result<()> {
2051 let crf = make_crf(
2052 5,
2053 false,
2054 Some(
2055 Tensor::new(&[-0.0464_f32, 0.0818, 0.0829, -0.0121, -0.0788], device)?
2056 .to_dtype(dtype)?,
2057 ),
2058 Some(
2059 Tensor::new(&[-0.0088_f32, 0.0586, 0.0057, 0.0316, -0.0388], device)?
2060 .to_dtype(dtype)?,
2061 ),
2062 Some(
2063 Tensor::new(
2064 &[
2065 [-0.0536_f32, -0.0093, 0.0276, 0.0351, 0.0604],
2066 [0.0734, 0.0764, -0.0773, 0.0821, 0.0294],
2067 [-0.0540, -0.0158, 0.0437, 0.0992, 0.0473],
2068 [0.0875, 0.0324, -0.0941, 0.0585, 0.0761],
2069 [-0.0930, -0.0832, 0.0290, 0.0974, 0.0914],
2070 ],
2071 device,
2072 )?
2073 .to_dtype(dtype)?,
2074 ),
2075 device,
2076 )?;
2077
2078 let emissions = Tensor::new(
2079 &[
2080 [
2081 [-0.6633_f32, 1.4045, -1.3710, 1.5054, 0.8431],
2082 [-0.1157, -0.0201, -0.2685, -0.6683, 0.0213],
2083 ],
2084 [
2085 [-0.7870, -0.2497, -0.3901, 0.0181, 0.0976],
2086 [0.4487, 0.2629, 2.2021, -0.7489, 0.1199],
2087 ],
2088 [
2089 [0.7837, -0.0174, -0.3873, -0.4722, -0.2462],
2090 [-0.6268, -0.9438, 0.6666, -0.6545, 1.0409],
2091 ],
2092 ],
2093 device,
2094 )?
2095 .to_dtype(dtype)?;
2096
2097 let best_tags = crf.decode(&emissions, None)?;
2098 println!("best_tags: {:?}", best_tags);
2099
2100 let crf_bf = make_crf(
2101 5,
2102 true,
2103 Some(crf.start_transitions),
2104 Some(crf.end_transitions),
2105 Some(crf.transitions),
2106 device,
2107 )?;
2108
2109 let emissions = emissions.transpose(0, 1)?;
2110 let best_tags_bf = crf_bf.decode(&emissions, None)?;
2111 println!("best_tags_bf: {:?}", best_tags_bf);
2112
2113 assert_eq!(best_tags, best_tags_bf);
2114 assert_eq!(best_tags, vec![vec![1, 3, 0], vec![1, 2, 4]]);
2115 Ok(())
2116 }
2117
2118 #[test]
2119 fn test_decode_batch_first() -> Result<()> {
2120 #[cfg(any(feature = "cuda", feature = "metal"))]
2121 let device = use_gpu(true)?;
2122 #[cfg(not(any(feature = "cuda", feature = "metal")))]
2123 let device = use_gpu(false)?;
2124
2125 for dtype in OK_TYPES {
2126 decode_batch_first(dtype, &device)?;
2127 }
2128 Ok(())
2129 }
2130
2131 #[test]
2133 fn test_decode_emissions_has_bad_number_of_dimension() -> Result<()> {
2134 #[cfg(any(feature = "cuda", feature = "metal"))]
2135 let device = use_gpu(true)?;
2136 #[cfg(not(any(feature = "cuda", feature = "metal")))]
2137 let device = use_gpu(false)?;
2138
2139 let emissions = Tensor::randn(0.0_f32, 1., (1, 2), &device)?;
2140 let crf = make_crf(5, false, None, None, None, &device)?;
2141 let result = crf.decode(&emissions, None);
2142 println!("{:?}", result);
2143 assert!(result.is_err());
2144 assert_eq!(
2145 result.err().unwrap().to_string(),
2146 "emissions must have 3 dimensions, got 2"
2147 );
2148 Ok(())
2149 }
2150
2151 #[test]
2153 fn test_decode_emissions_last_dimension_not_equal_to_number_of_tags() -> Result<()> {
2154 #[cfg(any(feature = "cuda", feature = "metal"))]
2155 let device = use_gpu(true)?;
2156 #[cfg(not(any(feature = "cuda", feature = "metal")))]
2157 let device = use_gpu(false)?;
2158
2159 let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
2160 let crf = make_crf(10, false, None, None, None, &device)?;
2161 let result = crf.decode(&emissions, None);
2162 println!("{:?}", result);
2163 assert!(result.is_err());
2164 assert_eq!(
2165 result.err().unwrap().to_string(),
2166 "expected last dimension of emissions is 10, got 3"
2167 );
2168 Ok(())
2169 }
2170
2171 #[test]
2173 fn test_decode_emissions_and_mask_size_mismatch() -> Result<()> {
2174 #[cfg(any(feature = "cuda", feature = "metal"))]
2175 let device = use_gpu(true)?;
2176 #[cfg(not(any(feature = "cuda", feature = "metal")))]
2177 let device = use_gpu(false)?;
2178
2179 let emissions = Tensor::randn(0.0_f32, 1., (1, 2, 3), &device)?;
2180 let mask = Tensor::new(&[[1_u8, 1], [1, 0]], &device)?;
2181 let crf = make_crf(3, false, None, None, None, &device)?;
2182 let result = crf.decode(&emissions, Some(&mask));
2183 println!("{:?}", result);
2184 assert!(result.is_err());
2185 assert_eq!(
2186 result.err().unwrap().to_string(),
2187 "the first two dimensions of emissions and mask must match, got (1, 2) and (2, 2)"
2188 );
2189 Ok(())
2190 }
2191
2192 #[test]
2194 fn test_decode_first_timestep_mask_is_not_all_on() -> Result<()> {
2195 #[cfg(any(feature = "cuda", feature = "metal"))]
2196 let device = use_gpu(true)?;
2197 #[cfg(not(any(feature = "cuda", feature = "metal")))]
2198 let device = use_gpu(false)?;
2199
2200 let emissions = Tensor::randn(0.0_f32, 1., (3, 2, 4), &device)?;
2201 let mask = Tensor::new(&[[1_u8, 1, 1], [0, 0, 0]], &device)?.transpose(0, 1)?;
2202
2203 let crf = make_crf(4, false, None, None, None, &device)?;
2204 let result = crf.decode(&emissions, Some(&mask));
2205 println!("{:?}", result);
2206 assert!(result.is_err());
2207 assert_eq!(
2208 result.err().unwrap().to_string(),
2209 "mask of the first timestep must all be on"
2210 );
2211
2212 let emissions = emissions.transpose(0, 1)?;
2213 let mask = mask.transpose(0, 1)?;
2214 let crf = make_crf(4, true, None, None, None, &device)?;
2215 let result = crf.decode(&emissions, Some(&mask));
2216 println!("{:?}", result);
2217 assert!(result.is_err());
2218 assert_eq!(
2219 result.err().unwrap().to_string(),
2220 "mask of the first timestep must all be on"
2221 );
2222 Ok(())
2223 }
2224}