1use crate::prelude_dev::*;
2use num::complex::ComplexFloat;
3use num::Num;
4
/// Packs the lower triangle (diagonal included) of one row-major `n x n` matrix.
///
/// Row `i` of the matrix is taken from `b[offset_b + i * n..]`; its first `i + 1`
/// entries are cloned into consecutive slots of `a`, beginning at `a[offset_a]`.
/// In total `n * (n + 1) / 2` slots of `a` are initialised.
///
/// # Panics
///
/// Panics if either offset is past the end of its slice, or if the slices are too
/// short to hold the packed rows / the full matrix rows.
#[inline]
pub fn inner_pack_tril_c_contig<T>(a: &mut [MaybeUninit<T>], offset_a: usize, b: &[T], offset_b: usize, n: usize)
where
    T: Clone,
{
    let dst = &mut a[offset_a..];
    let src = &b[offset_b..];
    // Position in `dst` where the packed entries of the current row begin.
    let mut packed_start = 0;
    for row in 0..n {
        let width = row + 1;
        let dst_row = &mut dst[packed_start..packed_start + width];
        let src_row = &src[row * n..(row + 1) * n];
        // `zip` stops after `width` items, i.e. right after the diagonal entry.
        for (slot, value) in dst_row.iter_mut().zip(src_row) {
            slot.write(value.clone());
        }
        packed_start += width;
    }
}
24
25#[inline]
26pub fn inner_pack_tril_general<T>(a: &mut [MaybeUninit<T>], la: &Layout<Ix1>, b: &[T], lb: &Layout<Ix2>, n: usize)
27where
28 T: Clone,
29{
30 let mut idx_a = 0;
31 for i in 0..n {
32 for j in 0..=i {
33 let loc_b = unsafe { lb.index_uncheck(&[i, j]) } as usize;
34 let loc_a = unsafe { la.index_uncheck(&[idx_a]) } as usize;
35 a[loc_a].write(b[loc_b].clone());
36 idx_a += 1;
37 }
38 }
39}
40
/// Packs the upper triangle (diagonal included) of one row-major `n x n` matrix.
///
/// Entries `(i, j)` with `j >= i` are read from `b[offset_b + i * n + j]` and cloned,
/// row by row, into consecutive slots of `a` beginning at `a[offset_a]`.
///
/// # Panics
///
/// Panics if either offset is past the end of its slice, or if an accessed element
/// lies outside `a` or `b`.
#[inline]
pub fn inner_pack_triu_c_contig<T>(a: &mut [MaybeUninit<T>], offset_a: usize, b: &[T], offset_b: usize, n: usize)
where
    T: Clone,
{
    let dst = &mut a[offset_a..];
    let src = &b[offset_b..];
    // Packed index of the diagonal entry of the current row.
    let mut row_start = 0;
    for row in 0..n {
        let src_row = row * n;
        for col in row..n {
            dst[row_start + (col - row)].write(src[src_row + col].clone());
        }
        // Row `row` contributes `n - row` packed entries.
        row_start += n - row;
    }
}
56
57#[inline]
58pub fn inner_pack_triu_general<T>(a: &mut [MaybeUninit<T>], la: &Layout<Ix1>, b: &[T], lb: &Layout<Ix2>, n: usize)
59where
60 T: Clone,
61{
62 let mut idx_a = 0;
63 for i in 0..n {
64 for j in i..n {
65 let loc_b = unsafe { lb.index_uncheck(&[i, j]) } as usize;
66 let loc_a = unsafe { la.index_uncheck(&[idx_a]) } as usize;
67 a[loc_a].write(b[loc_b].clone());
68 idx_a += 1;
69 }
70 }
71}
72
73pub fn pack_tri_cpu_serial<T>(
74 a: &mut [MaybeUninit<T>],
75 la: &Layout<IxD>,
76 b: &[T],
77 lb: &Layout<IxD>,
78 uplo: FlagUpLo,
79) -> Result<()>
80where
81 T: Clone,
82{
83 let la_split = la.dim_split_at(-1)?;
90 let lb_split = lb.dim_split_at(-2)?;
91 let (la_rest, la_inner) = la_split;
92 let (lb_rest, lb_inner) = lb_split;
93
94 let broad_rest = translate_to_col_major(&[&la_rest, &lb_rest], TensorIterOrder::K)?;
96 let la_rest = &broad_rest[0];
97 let lb_rest = &broad_rest[1];
98 let la_rest_iter = IterLayoutColMajor::new(la_rest)?;
99 let lb_rest_iter = IterLayoutColMajor::new(lb_rest)?;
100
101 let n = lb_inner.shape()[0];
103
104 let c_contig = la_inner.c_contig() && lb_inner.c_contig();
106
107 match uplo {
108 FlagUpLo::U => match c_contig {
109 true => {
110 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
111 inner_pack_triu_c_contig(a, offset_a, b, offset_b, n);
112 }
113 },
114 false => {
115 let mut la_inner = la_inner.to_dim::<Ix1>()?;
116 let mut lb_inner = lb_inner.to_dim::<Ix2>()?;
117 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
118 unsafe {
119 la_inner.set_offset(offset_a);
120 lb_inner.set_offset(offset_b);
121 }
122 inner_pack_triu_general(a, &la_inner, b, &lb_inner, n);
123 }
124 },
125 },
126 FlagUpLo::L => match c_contig {
127 true => {
128 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
129 inner_pack_tril_c_contig(a, offset_a, b, offset_b, n);
130 }
131 },
132 false => {
133 let mut la_inner = la_inner.to_dim::<Ix1>()?;
134 let mut lb_inner = lb_inner.to_dim::<Ix2>()?;
135 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
136 unsafe {
137 la_inner.set_offset(offset_a);
138 lb_inner.set_offset(offset_b);
139 }
140 inner_pack_tril_general(a, &la_inner, b, &lb_inner, n);
141 }
142 },
143 },
144 }
145 Ok(())
146}
147
148#[inline]
153pub fn inner_unpack_tril_c_contig<T>(
154 a: &mut [MaybeUninit<T>],
155 offset_a: usize,
156 b: &[T],
157 offset_b: usize,
158 n: usize,
159 symm: FlagSymm,
160) where
161 T: ComplexFloat,
162{
163 let a = &mut a[offset_a..];
164 let b = &b[offset_b..];
165 let mut idx_b = 0;
166 match symm {
167 FlagSymm::Sy => {
168 for i in 0..n {
169 for j in 0..=i {
170 a[i * n + j].write(b[idx_b]);
171 a[j * n + i].write(b[idx_b]);
172 idx_b += 1;
173 }
174 }
175 },
176 FlagSymm::He => {
177 for i in 0..n {
178 for j in 0..=i {
179 a[i * n + j].write(b[idx_b]);
180 a[j * n + i].write(b[idx_b].conj());
181 idx_b += 1;
182 }
183 }
184 },
185 FlagSymm::Ay => {
186 for i in 0..n {
187 for j in 0..i {
188 a[i * n + j].write(b[idx_b]);
189 a[j * n + i].write(-b[idx_b]);
190 idx_b += 1;
191 }
192 a[i * n + i].write(T::zero());
193 idx_b += 1;
194 }
195 },
196 FlagSymm::Ah => {
197 for i in 0..n {
198 for j in 0..i {
199 a[i * n + j].write(b[idx_b]);
200 a[j * n + i].write(-b[idx_b].conj());
201 idx_b += 1;
202 }
203 a[i * n + i].write(T::zero());
204 idx_b += 1;
205 }
206 },
207 FlagSymm::N => {
208 for i in 0..n {
209 for j in 0..=i {
210 a[i * n + j].write(b[idx_b]);
211 idx_b += 1;
212 }
213 }
214 },
215 }
216}
217
218#[inline]
219pub fn inner_unpack_tril_general<T>(
220 a: &mut [MaybeUninit<T>],
221 la: &Layout<Ix2>,
222 b: &[T],
223 lb: &Layout<Ix1>,
224 n: usize,
225 symm: FlagSymm,
226) where
227 T: ComplexFloat,
228{
229 let mut idx_b = 0;
230 match symm {
231 FlagSymm::Sy => {
232 for i in 0..n {
233 for j in 0..=i {
234 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
235 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
236 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
237 a[loc_a_ij].write(b[loc_b]);
238 a[loc_a_ji].write(b[loc_b]);
239 idx_b += 1;
240 }
241 }
242 },
243 FlagSymm::He => {
244 for i in 0..n {
245 for j in 0..=i {
246 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
247 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
248 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
249 a[loc_a_ij].write(b[loc_b]);
250 a[loc_a_ji].write(b[loc_b].conj());
251 idx_b += 1;
252 }
253 }
254 },
255 FlagSymm::Ay => {
256 for i in 0..n {
257 for j in 0..i {
258 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
259 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
260 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
261 a[loc_a_ij].write(b[loc_b]);
262 a[loc_a_ji].write(-b[loc_b]);
263 idx_b += 1;
264 }
265 let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
266 a[loc_a_ii].write(T::zero());
267 idx_b += 1;
268 }
269 },
270 FlagSymm::Ah => {
271 for i in 0..n {
272 for j in 0..i {
273 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
274 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
275 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
276 a[loc_a_ij].write(b[loc_b]);
277 a[loc_a_ji].write(-b[loc_b].conj());
278 idx_b += 1;
279 }
280 let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
281 a[loc_a_ii].write(T::zero());
282 idx_b += 1;
283 }
284 },
285 FlagSymm::N => {
286 for i in 0..n {
287 for j in 0..=i {
288 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
289 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
290 a[loc_a_ij].write(b[loc_b]);
291 idx_b += 1;
292 }
293 }
294 },
295 }
296}
297
298#[inline]
299pub fn inner_unpack_triu_c_contig<T>(
300 a: &mut [MaybeUninit<T>],
301 offset_a: usize,
302 b: &[T],
303 offset_b: usize,
304 n: usize,
305 symm: FlagSymm,
306) where
307 T: ComplexFloat,
308{
309 let a = &mut a[offset_a..];
310 let b = &b[offset_b..];
311 let mut idx_b = 0;
312 match symm {
313 FlagSymm::Sy => {
314 for i in 0..n {
315 for j in i..n {
316 a[i * n + j].write(b[idx_b]);
317 a[j * n + i].write(b[idx_b]);
318 idx_b += 1;
319 }
320 }
321 },
322 FlagSymm::He => {
323 for i in 0..n {
324 for j in i..n {
325 a[i * n + j].write(b[idx_b]);
326 a[j * n + i].write(b[idx_b].conj());
327 idx_b += 1;
328 }
329 }
330 },
331 FlagSymm::Ay => {
332 for i in 0..n {
333 a[i * n + i].write(T::zero());
334 idx_b += 1;
335 for j in (i + 1)..n {
336 a[i * n + j].write(b[idx_b]);
337 a[j * n + i].write(-b[idx_b]);
338 idx_b += 1;
339 }
340 }
341 },
342 FlagSymm::Ah => {
343 for i in 0..n {
344 a[i * n + i].write(T::zero());
345 idx_b += 1;
346 for j in (i + 1)..n {
347 a[i * n + j].write(b[idx_b]);
348 a[j * n + i].write(-b[idx_b].conj());
349 idx_b += 1;
350 }
351 }
352 },
353 FlagSymm::N => {
354 for i in 0..n {
355 for j in i..n {
356 a[i * n + j].write(b[idx_b]);
357 idx_b += 1;
358 }
359 }
360 },
361 }
362}
363
364#[inline]
365pub fn inner_unpack_triu_general<T>(
366 a: &mut [MaybeUninit<T>],
367 la: &Layout<Ix2>,
368 b: &[T],
369 lb: &Layout<Ix1>,
370 n: usize,
371 symm: FlagSymm,
372) where
373 T: ComplexFloat,
374{
375 let mut idx_b = 0;
376 match symm {
377 FlagSymm::Sy => {
378 for i in 0..n {
379 for j in i..n {
380 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
381 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
382 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
383 a[loc_a_ij].write(b[loc_b]);
384 a[loc_a_ji].write(b[loc_b]);
385 idx_b += 1;
386 }
387 }
388 },
389 FlagSymm::He => {
390 for i in 0..n {
391 for j in i..n {
392 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
393 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
394 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
395 a[loc_a_ij].write(b[loc_b]);
396 a[loc_a_ji].write(b[loc_b].conj());
397 idx_b += 1;
398 }
399 }
400 },
401 FlagSymm::Ay => {
402 for i in 0..n {
403 let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
404 a[loc_a_ii].write(T::zero());
405 idx_b += 1;
406 for j in (i + 1)..n {
407 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
408 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
409 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
410 a[loc_a_ij].write(b[loc_b]);
411 a[loc_a_ji].write(-b[loc_b]);
412 idx_b += 1;
413 }
414 }
415 },
416 FlagSymm::Ah => {
417 for i in 0..n {
418 let loc_a_ii = unsafe { la.index_uncheck(&[i, i]) } as usize;
419 a[loc_a_ii].write(T::zero());
420 idx_b += 1;
421 for j in (i + 1)..n {
422 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
423 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
424 let loc_a_ji = unsafe { la.index_uncheck(&[j, i]) } as usize;
425 a[loc_a_ij].write(b[loc_b]);
426 a[loc_a_ji].write(-b[loc_b].conj());
427 idx_b += 1;
428 }
429 }
430 },
431 FlagSymm::N => {
432 for i in 0..n {
433 for j in i..n {
434 let loc_b = unsafe { lb.index_uncheck(&[idx_b]) } as usize;
435 let loc_a_ij = unsafe { la.index_uncheck(&[i, j]) } as usize;
436 a[loc_a_ij].write(b[loc_b]);
437 idx_b += 1;
438 }
439 }
440 },
441 }
442}
443
444pub fn unpack_tri_cpu_serial<T>(
445 a: &mut [MaybeUninit<T>],
446 la: &Layout<IxD>,
447 b: &[T],
448 lb: &Layout<IxD>,
449 uplo: FlagUpLo,
450 symm: FlagSymm,
451) -> Result<()>
452where
453 T: ComplexFloat,
454{
455 let la_split = la.dim_split_at(-2)?;
462 let lb_split = lb.dim_split_at(-1)?;
463 let (la_rest, la_inner) = la_split;
464 let (lb_rest, lb_inner) = lb_split;
465
466 let broad_rest = translate_to_col_major(&[&la_rest, &lb_rest], TensorIterOrder::K)?;
468 let la_rest = &broad_rest[0];
469 let lb_rest = &broad_rest[1];
470 let la_rest_iter = IterLayoutColMajor::new(la_rest)?;
471 let lb_rest_iter = IterLayoutColMajor::new(lb_rest)?;
472
473 let n = la_inner.shape()[0];
475
476 let c_contig = la_inner.c_contig() && lb_inner.c_contig();
478
479 match uplo {
480 FlagUpLo::U => match c_contig {
481 true => {
482 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
483 inner_unpack_triu_c_contig(a, offset_a, b, offset_b, n, symm);
484 }
485 },
486 false => {
487 let mut la_inner = la_inner.to_dim::<Ix2>()?;
488 let mut lb_inner = lb_inner.to_dim::<Ix1>()?;
489 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
490 unsafe {
491 la_inner.set_offset(offset_a);
492 lb_inner.set_offset(offset_b);
493 }
494 inner_unpack_triu_general(a, &la_inner, b, &lb_inner, n, symm);
495 }
496 },
497 },
498 FlagUpLo::L => match c_contig {
499 true => {
500 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
501 inner_unpack_tril_c_contig(a, offset_a, b, offset_b, n, symm);
502 }
503 },
504 false => {
505 let mut la_inner = la_inner.to_dim::<Ix2>()?;
506 let mut lb_inner = lb_inner.to_dim::<Ix1>()?;
507 for (offset_a, offset_b) in izip!(la_rest_iter, lb_rest_iter) {
508 unsafe {
509 la_inner.set_offset(offset_a);
510 lb_inner.set_offset(offset_b);
511 }
512 inner_unpack_tril_general(a, &la_inner, b, &lb_inner, n, symm);
513 }
514 },
515 },
516 }
517 Ok(())
518}
519
520pub fn tril_cpu_serial<T, D>(raw: &mut [T], layout: &Layout<D>, k: isize) -> Result<()>
525where
526 T: Num + Clone,
527 D: DimAPI,
528{
529 let (la_rest, la_ix2) = layout.dim_split_at(-2)?;
530 let mut la_ix2 = la_ix2.into_dim::<Ix2>()?;
531 for offset in IterLayoutColMajor::new(&la_rest)? {
532 unsafe { la_ix2.set_offset(offset) };
533 tril_ix2_cpu_serial(raw, &la_ix2, k)?;
534 }
535 Ok(())
536}
537
538pub fn tril_ix2_cpu_serial<T>(raw: &mut [T], layout: &Layout<Ix2>, k: isize) -> Result<()>
539where
540 T: Num + Clone,
541{
542 let [nrow, ncol] = *layout.shape();
543 for i in 0..nrow {
544 let j_start = (i as isize + k + 1).max(0) as usize;
545 for j in j_start..ncol {
546 unsafe {
547 raw[layout.index_uncheck(&[i, j]) as usize] = T::zero();
548 }
549 }
550 }
551 Ok(())
552}
553
554pub fn triu_cpu_serial<T, D>(raw: &mut [T], layout: &Layout<D>, k: isize) -> Result<()>
559where
560 T: Num + Clone,
561 D: DimAPI,
562{
563 let (la_rest, la_ix2) = layout.dim_split_at(-2)?;
564 let mut la_ix2 = la_ix2.into_dim::<Ix2>()?;
565 for offset in IterLayoutColMajor::new(&la_rest)? {
566 unsafe { la_ix2.set_offset(offset) };
567 triu_ix2_cpu_serial(raw, &la_ix2, k)?;
568 }
569 Ok(())
570}
571
572pub fn triu_ix2_cpu_serial<T>(raw: &mut [T], layout: &Layout<Ix2>, k: isize) -> Result<()>
573where
574 T: Num + Clone,
575{
576 let [nrow, _] = *layout.shape();
577 for i in 0..nrow {
578 let j_end = (i as isize + k).max(0) as usize;
579 for j in 0..j_end {
580 unsafe {
581 raw[layout.index_uncheck(&[i, j]) as usize] = T::zero();
582 }
583 }
584 }
585 Ok(())
586}
587
588