// rstsr_native_impl/cpu_serial/op_tri.rs

1use crate::prelude_dev::*;
2use num::complex::ComplexFloat;
3use num::Num;
4
/* #region pack tri */
6
/// Pack the lower triangle (diagonal included) of a C-contiguous `n x n` matrix.
///
/// The matrix starts at `b[offset_b]` with row stride `n`. Its lower triangle is written
/// row by row — `(0, 0), (1, 0), (1, 1), (2, 0), ...` — into
/// `a[offset_a..offset_a + n * (n + 1) / 2]`, cloning each element exactly once.
///
/// # Panics
///
/// Panics if either slice is too short for the requested matrix.
#[inline]
pub fn inner_pack_tril_c_contig<T>(a: &mut [MaybeUninit<T>], offset_a: usize, b: &[T], offset_b: usize, n: usize)
where
    T: Clone,
{
    let dst = &mut a[offset_a..];
    let src = &b[offset_b..];
    let mut packed = 0;
    for row in 0..n {
        // Row `row` contributes its leading `row + 1` entries (columns 0..=row).
        let len = row + 1;
        let row_start = row * n;
        let src_row = &src[row_start..row_start + len];
        for (slot, value) in dst[packed..packed + len].iter_mut().zip(src_row) {
            slot.write(value.clone());
        }
        packed += len;
    }
}
24
25#[inline]
26pub fn inner_pack_tril_general<T>(a: &mut [MaybeUninit<T>], la: &Layout<Ix1>, b: &[T], lb: &Layout<Ix2>, n: usize)
27where
28    T: Clone,
29{
30    let mut idx_a = 0;
31    for i in 0..n {
32        for j in 0..=i {
33            let loc_b = unsafe { lb.index_uncheck(&[i, j]) } as usize;
34            let loc_a = unsafe { la.index_uncheck(&[idx_a]) } as usize;
35            a[loc_a].write(b[loc_b].clone());
36            idx_a += 1;
37        }
38    }
39}
40
/// Pack the upper triangle (diagonal included) of a C-contiguous `n x n` matrix.
///
/// The matrix starts at `b[offset_b]` with row stride `n`. Its upper triangle is written
/// row by row — `(0, 0), (0, 1), ..., (0, n-1), (1, 1), ...` — into
/// `a[offset_a..offset_a + n * (n + 1) / 2]`, cloning each element exactly once.
///
/// # Panics
///
/// Panics if either slice is too short for the requested matrix.
#[inline]
pub fn inner_pack_triu_c_contig<T>(a: &mut [MaybeUninit<T>], offset_a: usize, b: &[T], offset_b: usize, n: usize)
where
    T: Clone,
{
    let dst = &mut a[offset_a..];
    let src = &b[offset_b..];
    let mut packed = 0;
    for row in 0..n {
        // Row `row` contributes its trailing `n - row` entries (columns row..n).
        let len = n - row;
        let src_row = &src[row * n + row..(row + 1) * n];
        for (slot, value) in dst[packed..packed + len].iter_mut().zip(src_row) {
            slot.write(value.clone());
        }
        packed += len;
    }
}
56
57#[inline]
58pub fn inner_pack_triu_general<T>(a: &mut [MaybeUninit<T>], la: &Layout<Ix1>, b: &[T], lb: &Layout<Ix2>, n: usize)
59where
60    T: Clone,
61{
62    let mut idx_a = 0;
63    for i in 0..n {
64        for j in i..n {
65            let loc_b = unsafe { lb.index_uncheck(&[i, j]) } as usize;
66            let loc_a = unsafe { la.index_uncheck(&[idx_a]) } as usize;
67            a[loc_a].write(b[loc_b].clone());
68            idx_a += 1;
69        }
70    }
71}
72
73pub fn pack_tri_cpu_serial<T>(
74    a: &mut [MaybeUninit<T>],
75    la: &Layout<IxD>,
76    b: &[T],
77    lb: &Layout<IxD>,
78    uplo: FlagUpLo,
79) -> Result<()>
80where
81    T: Clone,
82{
83    // we assume dimension checks have been performed, and do not check them here
84    // - ndim_a = ndim_b + 1
85    // - shape_a (..., n, n) and shape_b (..., n * (n + 1) / 2)
86    // - rest shape are the same
87
88    // split dimensions
89    let la_split = la.dim_split_at(-1)?;
90    let lb_split = lb.dim_split_at(-2)?;
91    let (la_rest, la_inner) = la_split;
92    let (lb_rest, lb_inner) = lb_split;
93
94    // rest dimensions handling
95    let broad_rest = translate_to_col_major(&[&la_rest, &lb_rest], TensorIterOrder::K)?;
96    let la_rest = &broad_rest[0];
97    let lb_rest = &broad_rest[1];
98    let la_rest_iter = IterLayoutColMajor::new(la_rest)?;
99    let lb_rest_iter = IterLayoutColMajor::new(lb_rest)?;
100
101    // inner dimensions handling
102    let n = lb_inner.shape()[0];
103
104    // contiguous flags
105    let c_contig = la_inner.c_contig() && lb_inner.c_contig();
106
107    match uplo {
108        FlagUpLo::U => match c_contig {
109            true => {
110                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
111                    inner_pack_triu_c_contig(a, offset_a, b, offset_b, n);
112                }
113            },
114            false => {
115                let mut la_inner = la_inner.to_dim::<Ix1>()?;
116                let mut lb_inner = lb_inner.to_dim::<Ix2>()?;
117                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
118                    unsafe {
119                        la_inner.set_offset(offset_a);
120                        lb_inner.set_offset(offset_b);
121                    }
122                    inner_pack_triu_general(a, &la_inner, b, &lb_inner, n);
123                }
124            },
125        },
126        FlagUpLo::L => match c_contig {
127            true => {
128                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
129                    inner_pack_tril_c_contig(a, offset_a, b, offset_b, n);
130                }
131            },
132            false => {
133                let mut la_inner = la_inner.to_dim::<Ix1>()?;
134                let mut lb_inner = lb_inner.to_dim::<Ix2>()?;
135                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
136                    unsafe {
137                        la_inner.set_offset(offset_a);
138                        lb_inner.set_offset(offset_b);
139                    }
140                    inner_pack_tril_general(a, &la_inner, b, &lb_inner, n);
141                }
142            },
143        },
144    }
145    Ok(())
146}
147
/* #endregion */
149
/* #region unpack tri */
151
152#[inline]
153pub fn inner_unpack_tril_c_contig<T>(
154    a: &mut [MaybeUninit<T>],
155    offset_a: usize,
156    b: &[T],
157    offset_b: usize,
158    n: usize,
159    symm: FlagSymm,
160) where
161    T: ComplexFloat,
162{
163    let a = &mut a[offset_a..];
164    let b = &b[offset_b..];
165    let mut idx_b = 0;
166    match symm {
167        FlagSymm::Sy => {
168            for i in 0..n {
169                for j in 0..=i {
170                    a[i * n + j].write(b[idx_b]);
171                    a[j * n + i].write(b[idx_b]);
172                    idx_b += 1;
173                }
174            }
175        },
176        FlagSymm::He => {
177            for i in 0..n {
178                for j in 0..=i {
179                    a[i * n + j].write(b[idx_b]);
180                    a[j * n + i].write(b[idx_b].conj());
181                    idx_b += 1;
182                }
183            }
184        },
185        FlagSymm::Ay => {
186            for i in 0..n {
187                for j in 0..i {
188                    a[i * n + j].write(b[idx_b]);
189                    a[j * n + i].write(-b[idx_b]);
190                    idx_b += 1;
191                }
192                a[i * n + i].write(T::zero());
193                idx_b += 1;
194            }
195        },
196        FlagSymm::Ah => {
197            for i in 0..n {
198                for j in 0..i {
199                    a[i * n + j].write(b[idx_b]);
200                    a[j * n + i].write(-b[idx_b].conj());
201                    idx_b += 1;
202                }
203                a[i * n + i].write(T::zero());
204                idx_b += 1;
205            }
206        },
207        FlagSymm::N => {
208            for i in 0..n {
209                for j in 0..=i {
210                    a[i * n + j].write(b[idx_b]);
211                    idx_b += 1;
212                }
213            }
214        },
215    }
216}
217
218#[inline]
219pub fn inner_unpack_tril_general<T>(
220    a: &mut [MaybeUninit<T>],
221    la: &Layout<Ix2>,
222    b: &[T],
223    lb: &Layout<Ix1>,
224    n: usize,
225    symm: FlagSymm,
226) where
227    T: ComplexFloat,
228{
229    let mut idx_b = 0;
230    match symm {
231        FlagSymm::Sy => {
232            for i in 0..n {
233                for j in 0..=i {
234                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
235                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
236                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
237                    a[loc_a_ij].write(b[loc_b]);
238                    a[loc_a_ji].write(b[loc_b]);
239                    idx_b += 1;
240                }
241            }
242        },
243        FlagSymm::He => {
244            for i in 0..n {
245                for j in 0..=i {
246                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
247                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
248                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
249                    a[loc_a_ij].write(b[loc_b]);
250                    a[loc_a_ji].write(b[loc_b].conj());
251                    idx_b += 1;
252                }
253            }
254        },
255        FlagSymm::Ay => {
256            for i in 0..n {
257                for j in 0..i {
258                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
259                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
260                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
261                    a[loc_a_ij].write(b[loc_b]);
262                    a[loc_a_ji].write(-b[loc_b]);
263                    idx_b += 1;
264                }
265                let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
266                a[loc_a_ii].write(T::zero());
267                idx_b += 1;
268            }
269        },
270        FlagSymm::Ah => {
271            for i in 0..n {
272                for j in 0..i {
273                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
274                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
275                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
276                    a[loc_a_ij].write(b[loc_b]);
277                    a[loc_a_ji].write(-b[loc_b].conj());
278                    idx_b += 1;
279                }
280                let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
281                a[loc_a_ii].write(T::zero());
282                idx_b += 1;
283            }
284        },
285        FlagSymm::N => {
286            for i in 0..n {
287                for j in 0..=i {
288                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
289                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
290                    a[loc_a_ij].write(b[loc_b]);
291                    idx_b += 1;
292                }
293            }
294        },
295    }
296}
297
298#[inline]
299pub fn inner_unpack_triu_c_contig<T>(
300    a: &mut [MaybeUninit<T>],
301    offset_a: usize,
302    b: &[T],
303    offset_b: usize,
304    n: usize,
305    symm: FlagSymm,
306) where
307    T: ComplexFloat,
308{
309    let a = &mut a[offset_a..];
310    let b = &b[offset_b..];
311    let mut idx_b = 0;
312    match symm {
313        FlagSymm::Sy => {
314            for i in 0..n {
315                for j in i..n {
316                    a[i * n + j].write(b[idx_b]);
317                    a[j * n + i].write(b[idx_b]);
318                    idx_b += 1;
319                }
320            }
321        },
322        FlagSymm::He => {
323            for i in 0..n {
324                for j in i..n {
325                    a[i * n + j].write(b[idx_b]);
326                    a[j * n + i].write(b[idx_b].conj());
327                    idx_b += 1;
328                }
329            }
330        },
331        FlagSymm::Ay => {
332            for i in 0..n {
333                a[i * n + i].write(T::zero());
334                idx_b += 1;
335                for j in (i + 1)..n {
336                    a[i * n + j].write(b[idx_b]);
337                    a[j * n + i].write(-b[idx_b]);
338                    idx_b += 1;
339                }
340            }
341        },
342        FlagSymm::Ah => {
343            for i in 0..n {
344                a[i * n + i].write(T::zero());
345                idx_b += 1;
346                for j in (i + 1)..n {
347                    a[i * n + j].write(b[idx_b]);
348                    a[j * n + i].write(-b[idx_b].conj());
349                    idx_b += 1;
350                }
351            }
352        },
353        FlagSymm::N => {
354            for i in 0..n {
355                for j in i..n {
356                    a[i * n + j].write(b[idx_b]);
357                    idx_b += 1;
358                }
359            }
360        },
361    }
362}
363
364#[inline]
365pub fn inner_unpack_triu_general<T>(
366    a: &mut [MaybeUninit<T>],
367    la: &Layout<Ix2>,
368    b: &[T],
369    lb: &Layout<Ix1>,
370    n: usize,
371    symm: FlagSymm,
372) where
373    T: ComplexFloat,
374{
375    let mut idx_b = 0;
376    match symm {
377        FlagSymm::Sy => {
378            for i in 0..n {
379                for j in i..n {
380                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
381                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
382                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
383                    a[loc_a_ij].write(b[loc_b]);
384                    a[loc_a_ji].write(b[loc_b]);
385                    idx_b += 1;
386                }
387            }
388        },
389        FlagSymm::He => {
390            for i in 0..n {
391                for j in i..n {
392                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
393                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
394                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
395                    a[loc_a_ij].write(b[loc_b]);
396                    a[loc_a_ji].write(b[loc_b].conj());
397                    idx_b += 1;
398                }
399            }
400        },
401        FlagSymm::Ay => {
402            for i in 0..n {
403                let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
404                a[loc_a_ii].write(T::zero());
405                idx_b += 1;
406                for j in (i + 1)..n {
407                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
408                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
409                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
410                    a[loc_a_ij].write(b[loc_b]);
411                    a[loc_a_ji].write(-b[loc_b]);
412                    idx_b += 1;
413                }
414            }
415        },
416        FlagSymm::Ah => {
417            for i in 0..n {
418                let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
419                a[loc_a_ii].write(T::zero());
420                idx_b += 1;
421                for j in (i + 1)..n {
422                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
423                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
424                    let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
425                    a[loc_a_ij].write(b[loc_b]);
426                    a[loc_a_ji].write(-b[loc_b].conj());
427                    idx_b += 1;
428                }
429            }
430        },
431        FlagSymm::N => {
432            for i in 0..n {
433                for j in i..n {
434                    let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
435                    let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
436                    a[loc_a_ij].write(b[loc_b]);
437                    idx_b += 1;
438                }
439            }
440        },
441    }
442}
443
444pub fn unpack_tri_cpu_serial<T>(
445    a: &mut [MaybeUninit<T>],
446    la: &Layout<IxD>,
447    b: &[T],
448    lb: &Layout<IxD>,
449    uplo: FlagUpLo,
450    symm: FlagSymm,
451) -> Result<()>
452where
453    T: ComplexFloat,
454{
455    // we assume dimension checks have been performed, and do not check them here
456    // - ndim_a + 1 = ndim_b
457    // - shape_a (..., n * (n + 1) / 2) and shape_b (..., n, n)
458    // - rest shape are the same
459
460    // split dimensions
461    let la_split = la.dim_split_at(-2)?;
462    let lb_split = lb.dim_split_at(-1)?;
463    let (la_rest, la_inner) = la_split;
464    let (lb_rest, lb_inner) = lb_split;
465
466    // rest dimensions handling
467    let broad_rest = translate_to_col_major(&[&la_rest, &lb_rest], TensorIterOrder::K)?;
468    let la_rest = &broad_rest[0];
469    let lb_rest = &broad_rest[1];
470    let la_rest_iter = IterLayoutColMajor::new(la_rest)?;
471    let lb_rest_iter = IterLayoutColMajor::new(lb_rest)?;
472
473    // inner dimensions handling
474    let n = la_inner.shape()[0];
475
476    // contiguous flags
477    let c_contig = la_inner.c_contig() && lb_inner.c_contig();
478
479    match uplo {
480        FlagUpLo::U => match c_contig {
481            true => {
482                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
483                    inner_unpack_triu_c_contig(a, offset_a, b, offset_b, n, symm);
484                }
485            },
486            false => {
487                let mut la_inner = la_inner.to_dim::<Ix2>()?;
488                let mut lb_inner = lb_inner.to_dim::<Ix1>()?;
489                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
490                    unsafe {
491                        la_inner.set_offset(offset_a);
492                        lb_inner.set_offset(offset_b);
493                    }
494                    inner_unpack_triu_general(a, &la_inner, b, &lb_inner, n, symm);
495                }
496            },
497        },
498        FlagUpLo::L => match c_contig {
499            true => {
500                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
501                    inner_unpack_tril_c_contig(a, offset_a, b, offset_b, n, symm);
502                }
503            },
504            false => {
505                let mut la_inner = la_inner.to_dim::<Ix2>()?;
506                let mut lb_inner = lb_inner.to_dim::<Ix1>()?;
507                for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
508                    unsafe {
509                        la_inner.set_offset(offset_a);
510                        lb_inner.set_offset(offset_b);
511                    }
512                    inner_unpack_tril_general(a, &la_inner, b, &lb_inner, n, symm);
513                }
514            },
515        },
516    }
517    Ok(())
518}
519
/* #endregion */
521
/* #region tril */
523
524pub fn tril_cpu_serial<T, D>(raw: &mut [T], layout: &Layout<D>, k: isize) -> Result<()>
525where
526    T: Num + Clone,
527    D: DimAPI,
528{
529    let (la_rest, la_ix2) = layout.dim_split_at(-2)?;
530    let mut la_ix2 = la_ix2.into_dim::<Ix2>()?;
531    for offset in IterLayoutColMajor::new(&la_rest)? {
532        unsafe { la_ix2.set_offset(offset) };
533        tril_ix2_cpu_serial(raw, &la_ix2, k)?;
534    }
535    Ok(())
536}
537
538pub fn tril_ix2_cpu_serial<T>(raw: &mut [T], layout: &Layout<Ix2>, k: isize) -> Result<()>
539where
540    T: Num + Clone,
541{
542    let [nrow, ncol] = *layout.shape();
543    for i in 0..nrow {
544        let j_start = (i as isize + k + 1).max(0) as usize;
545        for j in j_start..ncol {
546            unsafe {
547                raw[layout.index_uncheck(&[i, j]) as usize] = T::zero();
548            }
549        }
550    }
551    Ok(())
552}
553
/* #endregion */
555
/* #region triu */
557
558pub fn triu_cpu_serial<T, D>(raw: &mut [T], layout: &Layout<D>, k: isize) -> Result<()>
559where
560    T: Num + Clone,
561    D: DimAPI,
562{
563    let (la_rest, la_ix2) = layout.dim_split_at(-2)?;
564    let mut la_ix2 = la_ix2.into_dim::<Ix2>()?;
565    for offset in IterLayoutColMajor::new(&la_rest)? {
566        unsafe { la_ix2.set_offset(offset) };
567        triu_ix2_cpu_serial(raw, &la_ix2, k)?;
568    }
569    Ok(())
570}
571
572pub fn triu_ix2_cpu_serial<T>(raw: &mut [T], layout: &Layout<Ix2>, k: isize) -> Result<()>
573where
574    T: Num + Clone,
575{
576    let [nrow, _] = *layout.shape();
577    for i in 0..nrow {
578        let j_end = (i as isize + k).max(0) as usize;
579        for j in 0..j_end {
580            unsafe {
581                raw[layout.index_uncheck(&[i, j]) as usize] = T::zero();
582            }
583        }
584    }
585    Ok(())
586}
587
/* #endregion */