1use super::*;
2use crate::mat::matmut::SyncCell;
3use crate::utils::bound::{Array, Dim, Partition};
4use crate::{ContiguousFwd, Idx, IdxInc};
5use core::marker::PhantomData;
6use core::ptr::NonNull;
7use equator::assert;
8use generativity::Guard;
9
/// Inner representation of a mutable column view; `ColMut` wraps this type.
///
/// It stores the raw column description (`ColView`: base pointer, row count and row
/// stride) together with a marker that ties the view to a unique borrow for `'a`.
pub struct Mut<'a, T, Rows = usize, RStride = isize> {
    /// Base pointer, row count and row stride of the viewed column.
    pub(super) imp: ColView<T, Rows, RStride>,
    /// Makes the view behave like `&'a mut T` for lifetime and variance purposes.
    pub(super) __marker: PhantomData<&'a mut T>,
}
15
16impl<'short, T, Rows: Copy, RStride: Copy> Reborrow<'short> for Mut<'_, T, Rows, RStride> {
17 type Target = Ref<'short, T, Rows, RStride>;
18
19 #[inline]
20 fn rb(&'short self) -> Self::Target {
21 Ref {
22 imp: self.imp,
23 __marker: PhantomData,
24 }
25 }
26}
27impl<'short, T, Rows: Copy, RStride: Copy> ReborrowMut<'short> for Mut<'_, T, Rows, RStride> {
28 type Target = Mut<'short, T, Rows, RStride>;
29
30 #[inline]
31 fn rb_mut(&'short mut self) -> Self::Target {
32 Mut {
33 imp: self.imp,
34 __marker: PhantomData,
35 }
36 }
37}
38impl<'a, T, Rows: Copy, RStride: Copy> IntoConst for Mut<'a, T, Rows, RStride> {
39 type Target = Ref<'a, T, Rows, RStride>;
40
41 #[inline]
42 fn into_const(self) -> Self::Target {
43 Ref {
44 imp: self.imp,
45 __marker: PhantomData,
46 }
47 }
48}
49
// SAFETY (review note): `Mut` stores its column through a `NonNull`-based `ColView` (see
// `from_raw_parts_mut`), and `NonNull` is neither `Send` nor `Sync`, so the auto traits are
// implemented by hand. The bounds mirror `&'a mut T` (the type named by `__marker`):
// `&mut T` is `Sync` iff `T: Sync` and `Send` iff `T: Send`. `Rows`/`RStride` are stored
// by value, so they carry their own bounds.
unsafe impl<T: Sync, Rows: Sync, RStride: Sync> Sync for Mut<'_, T, Rows, RStride> {}
unsafe impl<T: Send, Rows: Send, RStride: Send> Send for Mut<'_, T, Rows, RStride> {}
52
53impl<'a, T> ColMut<'a, T> {
54 #[inline]
56 pub fn from_mut(value: &'a mut T) -> Self {
57 unsafe { ColMut::from_raw_parts_mut(value as *mut T, 1, 1) }
58 }
59
60 #[inline]
63 pub fn from_slice_mut(slice: &'a mut [T]) -> Self {
64 let len = slice.len();
65 unsafe { Self::from_raw_parts_mut(slice.as_mut_ptr(), len, 1) }
66 }
67}
68
69impl<'a, T, Rows: Shape, RStride: Stride> ColMut<'a, T, Rows, RStride> {
70 #[inline(always)]
76 #[track_caller]
77 pub const unsafe fn from_raw_parts_mut(ptr: *mut T, nrows: Rows, row_stride: RStride) -> Self {
78 Self {
79 0: Mut {
80 imp: ColView {
81 ptr: NonNull::new_unchecked(ptr),
82 nrows,
83 row_stride,
84 },
85 __marker: PhantomData,
86 },
87 }
88 }
89
90 #[inline(always)]
92 pub fn as_ptr(&self) -> *const T {
93 self.rb().as_ptr()
94 }
95
96 #[inline(always)]
98 pub fn nrows(&self) -> Rows {
99 self.imp.nrows
100 }
101
102 #[inline(always)]
104 pub fn ncols(&self) -> usize {
105 1
106 }
107
108 #[inline(always)]
110 pub fn shape(&self) -> (Rows, usize) {
111 (self.nrows(), self.ncols())
112 }
113
114 #[inline(always)]
116 pub fn row_stride(&self) -> RStride {
117 self.imp.row_stride
118 }
119
120 #[inline(always)]
122 pub fn ptr_at(&self, row: IdxInc<Rows>) -> *const T {
123 self.rb().ptr_at(row)
124 }
125
126 #[inline(always)]
133 #[track_caller]
134 pub unsafe fn ptr_inbounds_at(&self, row: Idx<Rows>) -> *const T {
135 self.rb().ptr_inbounds_at(row)
136 }
137
138 #[inline]
139 #[track_caller]
140 pub fn split_at_row(self, row: IdxInc<Rows>) -> (ColRef<'a, T, usize, RStride>, ColRef<'a, T, usize, RStride>) {
142 self.into_const().split_at_row(row)
143 }
144
145 #[inline(always)]
146 pub fn transpose(self) -> RowRef<'a, T, Rows, RStride> {
148 self.into_const().transpose()
149 }
150
151 #[inline(always)]
152 pub fn conjugate(self) -> ColRef<'a, T::Conj, Rows, RStride>
154 where
155 T: Conjugate,
156 {
157 self.into_const().conjugate()
158 }
159
160 #[inline(always)]
161 pub fn canonical(self) -> ColRef<'a, T::Canonical, Rows, RStride>
163 where
164 T: Conjugate,
165 {
166 self.into_const().canonical()
167 }
168
169 #[inline(always)]
170 pub fn adjoint(self) -> RowRef<'a, T::Conj, Rows, RStride>
172 where
173 T: Conjugate,
174 {
175 self.into_const().adjoint()
176 }
177
178 #[track_caller]
179 #[inline(always)]
180 pub fn get<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
182 where
183 ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
184 {
185 <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get(self.into_const(), row)
186 }
187
188 #[track_caller]
189 #[inline(always)]
190 pub unsafe fn get_unchecked<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
192 where
193 ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
194 {
195 unsafe { <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get_unchecked(self.into_const(), row) }
196 }
197
198 #[inline]
199 pub fn reverse_rows(self) -> ColRef<'a, T, Rows, RStride::Rev> {
201 self.into_const().reverse_rows()
202 }
203
204 #[inline]
205 pub fn subrows<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> ColRef<'a, T, V, RStride> {
207 self.into_const().subrows(row_start, nrows)
208 }
209
210 #[inline]
211 #[track_caller]
212 pub fn as_row_shape<V: Shape>(self, nrows: V) -> ColRef<'a, T, V, RStride> {
214 self.into_const().as_row_shape(nrows)
215 }
216
217 #[inline]
218 pub fn as_dyn_rows(self) -> ColRef<'a, T, usize, RStride> {
220 self.into_const().as_dyn_rows()
221 }
222
223 #[inline]
224 pub fn as_dyn_stride(self) -> ColRef<'a, T, Rows, isize> {
226 self.into_const().as_dyn_stride()
227 }
228
229 #[inline]
230 pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
232 where
233 Rows: 'a,
234 {
235 self.into_const().iter()
236 }
237
238 #[inline]
239 #[cfg(feature = "rayon")]
240 pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
242 where
243 T: Sync,
244 Rows: 'a,
245 {
246 self.into_const().par_iter()
247 }
248
249 #[inline]
250 #[track_caller]
251 #[cfg(feature = "rayon")]
252 pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, T, usize, RStride>>
254 where
255 T: Sync,
256 Rows: 'a,
257 {
258 self.into_const().par_partition(count)
259 }
260
261 #[inline]
262 pub fn try_as_col_major(self) -> Option<ColRef<'a, T, Rows, ContiguousFwd>> {
264 self.into_const().try_as_col_major()
265 }
266
267 #[inline]
268 pub fn try_as_col_major_mut(self) -> Option<ColMut<'a, T, Rows, ContiguousFwd>> {
270 self.into_const().try_as_col_major().map(|x| unsafe { x.const_cast() })
271 }
272
273 #[inline(always)]
274 #[doc(hidden)]
275 pub unsafe fn const_cast(self) -> ColMut<'a, T, Rows, RStride> {
276 self
277 }
278
279 #[inline]
280 #[doc(hidden)]
281 pub fn bind_r<'N>(self, row: Guard<'N>) -> ColMut<'a, T, Dim<'N>, RStride> {
282 unsafe { ColMut::from_raw_parts_mut(self.as_ptr_mut(), self.nrows().bind(row), self.row_stride()) }
283 }
284
285 #[inline]
286 pub fn as_mat(self) -> MatRef<'a, T, Rows, usize, RStride, isize> {
288 self.into_const().as_mat()
289 }
290
291 #[inline]
292 pub fn as_mat_mut(self) -> MatMut<'a, T, Rows, usize, RStride, isize> {
294 unsafe { self.into_const().as_mat().const_cast() }
295 }
296
297 #[inline]
298 pub fn as_diagonal(self) -> DiagRef<'a, T, Rows, RStride> {
300 DiagRef {
301 0: crate::diag::Ref { inner: self.into_const() },
302 }
303 }
304}
305
306impl<T, Rows: Shape, RStride: Stride, Inner: for<'short> ReborrowMut<'short, Target = Mut<'short, T, Rows, RStride>>> generic::Col<Inner> {
307 #[inline]
309 pub fn copy_from<RhsT: Conjugate<Canonical = T>>(&mut self, other: impl AsColRef<T = RhsT, Rows = Rows>)
310 where
311 T: ComplexField,
312 {
313 let other = other.as_col_ref();
314 let this = self.rb_mut();
315
316 assert!(all(this.nrows() == other.nrows(), this.ncols() == other.ncols(),));
317 let m = this.nrows();
318
319 with_dim!(M, m.unbound());
320 imp(
321 self.rb_mut().as_row_shape_mut(M).as_dyn_stride_mut(),
322 other.as_row_shape(M).canonical(),
323 Conj::get::<RhsT>(),
324 );
325
326 pub fn imp<'M, 'N, T: ComplexField>(this: ColMut<'_, T, Dim<'M>>, other: ColRef<'_, T, Dim<'M>>, conj_: Conj) {
327 match conj_ {
328 Conj::No => {
329 zip!(this, other).for_each(|unzip!(dst, src)| *dst = copy(&src));
330 },
331 Conj::Yes => {
332 zip!(this, other).for_each(|unzip!(dst, src)| *dst = conj(&src));
333 },
334 }
335 }
336 }
337
338 #[inline]
340 pub fn fill(&mut self, value: T)
341 where
342 T: Clone,
343 {
344 fn cloner<T: Clone>(value: T) -> impl for<'a> FnMut(crate::linalg::zip::Last<&'a mut T>) {
345 #[inline(always)]
346 move |x| *x.0 = value.clone()
347 }
348 z!(self.rb_mut().as_dyn_rows_mut()).for_each(cloner::<T>(value));
349 }
350
351 #[inline]
352 pub fn as_mut(&mut self) -> ColMut<'_, T, Rows, RStride> {
354 self.rb_mut()
355 }
356}
impl<'a, T, Rows: Shape, RStride: Stride> ColMut<'a, T, Rows, RStride> {
    /// Returns the view's base pointer as a `*mut T`.
    #[inline(always)]
    pub fn as_ptr_mut(&self) -> *mut T {
        self.rb().as_ptr() as *mut T
    }

    /// Mutable counterpart of `ptr_at`: returns a `*mut T` to row `row`.
    #[inline(always)]
    pub fn ptr_at_mut(&self, row: IdxInc<Rows>) -> *mut T {
        self.rb().ptr_at(row) as *mut T
    }

    /// Mutable counterpart of `ptr_inbounds_at`.
    ///
    /// # Safety
    ///
    /// Same contract as `ColRef::ptr_inbounds_at` (presumably `row < self.nrows()`).
    #[inline(always)]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at_mut(&self, row: Idx<Rows>) -> *mut T {
        self.rb().ptr_inbounds_at(row) as *mut T
    }

    /// Consumes the view and splits it at `row` into two mutable views with `usize` row
    /// counts.
    #[inline]
    #[track_caller]
    pub fn split_at_row_mut(self, row: IdxInc<Rows>) -> (ColMut<'a, T, usize, RStride>, ColMut<'a, T, usize, RStride>) {
        let (a, b) = self.into_const().split_at_row(row);
        // SAFETY: `a` and `b` come from consuming `self`, and the shared split hands them
        // out as two separate parts, so each may be made mutable.
        unsafe { (a.const_cast(), b.const_cast()) }
    }

    /// Consumes the view and returns it as a mutable row view over the same elements.
    #[inline(always)]
    pub fn transpose_mut(self) -> RowMut<'a, T, Rows, RStride> {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().transpose().const_cast() }
    }

    /// Consumes the view and returns a mutable view with element type `T::Conj`.
    #[inline(always)]
    pub fn conjugate_mut(self) -> ColMut<'a, T::Conj, Rows, RStride>
    where
        T: Conjugate,
    {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().conjugate().const_cast() }
    }

    /// Consumes the view and returns a mutable view with element type `T::Canonical`.
    #[inline(always)]
    pub fn canonical_mut(self) -> ColMut<'a, T::Canonical, Rows, RStride>
    where
        T: Conjugate,
    {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().canonical().const_cast() }
    }

    /// Consumes the view and returns its adjoint as a mutable row view (element type
    /// `T::Conj`).
    #[inline(always)]
    pub fn adjoint_mut(self) -> RowMut<'a, T::Conj, Rows, RStride>
    where
        T: Conjugate,
    {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().adjoint().const_cast() }
    }

    /// Returns a mutable reference to the element at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()`.
    #[inline(always)]
    #[track_caller]
    pub(crate) fn at_mut(self, row: Idx<Rows>) -> &'a mut T {
        assert!(all(row < self.nrows()));
        // SAFETY: bounds were checked just above.
        unsafe { self.at_mut_unchecked(row) }
    }

    /// Returns a mutable reference to the element at `row` without bounds checking.
    ///
    /// # Safety
    ///
    /// Same contract as `ptr_inbounds_at_mut` (presumably `row < self.nrows()`).
    #[inline(always)]
    #[track_caller]
    pub(crate) unsafe fn at_mut_unchecked(self, row: Idx<Rows>) -> &'a mut T {
        &mut *self.ptr_inbounds_at_mut(row)
    }

    /// Indexes the column with `row`, returning mutable access.
    ///
    /// `row` may be any type for which `ColMut` implements `ColIndex`.
    #[track_caller]
    #[inline(always)]
    pub fn get_mut<RowRange>(self, row: RowRange) -> <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
    where
        ColMut<'a, T, Rows, RStride>: ColIndex<RowRange>,
    {
        <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::get(self, row)
    }

    /// Unchecked variant of `get_mut`.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `ColIndex::get_unchecked` for `ColMut`
    /// (presumably: `row` is in bounds).
    #[track_caller]
    #[inline(always)]
    pub unsafe fn get_mut_unchecked<RowRange>(self, row: RowRange) -> <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
    where
        ColMut<'a, T, Rows, RStride>: ColIndex<RowRange>,
    {
        unsafe { <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::get_unchecked(self, row) }
    }

    /// Consumes the view and returns a mutable view with the row order reversed.
    #[inline]
    pub fn reverse_rows_mut(self) -> ColMut<'a, T, Rows, RStride::Rev> {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().reverse_rows().const_cast() }
    }

    /// Returns a mutable view of `nrows` rows starting at `row_start`.
    #[inline]
    #[track_caller]
    pub fn subrows_mut<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> ColMut<'a, T, V, RStride> {
        // SAFETY: `self` is consumed, so the mutable sub-view is the only live view.
        unsafe { self.into_const().subrows(row_start, nrows).const_cast() }
    }

    /// Returns a mutable view whose row count has type `V` (see `ColRef::as_row_shape`).
    #[inline]
    #[track_caller]
    pub fn as_row_shape_mut<V: Shape>(self, nrows: V) -> ColMut<'a, T, V, RStride> {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().as_row_shape(nrows).const_cast() }
    }

    /// Returns a mutable view with the row count erased to `usize`.
    #[inline]
    pub fn as_dyn_rows_mut(self) -> ColMut<'a, T, usize, RStride> {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().as_dyn_rows().const_cast() }
    }

    /// Returns a mutable view with the row stride erased to `isize`.
    #[inline]
    pub fn as_dyn_stride_mut(self) -> ColMut<'a, T, Rows, isize> {
        // SAFETY: `self` is consumed, so the mutable result is the only live view.
        unsafe { self.into_const().as_dyn_stride().const_cast() }
    }

    /// Returns an iterator over mutable references to the elements.
    #[inline]
    pub fn iter_mut(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a mut T>
    where
        Rows: 'a,
    {
        // A `Copy` shared view is captured so each step can rebuild a mutable view.
        let this = self.into_const();
        // SAFETY: each index from `Rows::indices` is in bounds and yielded once, so the
        // returned `&mut T` never alias.
        Rows::indices(Rows::start(), this.nrows().end()).map(move |j| unsafe { this.const_cast().at_mut_unchecked(j) })
    }

    /// Returns a rayon parallel iterator over mutable references to the elements.
    #[inline]
    #[cfg(feature = "rayon")]
    pub fn par_iter_mut(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a mut T>
    where
        T: Send,
        Rows: 'a,
    {
        unsafe {
            // NOTE(review): the view is reinterpreted as `SyncCell<T>`, presumably a `Sync`
            // wrapper, so the shared view can be captured by rayon workers when `T` is only
            // `Send` — confirm against `mat::matmut::SyncCell`.
            let this = self.as_type::<SyncCell<T>>().into_const();

            use rayon::prelude::*;
            // Each row index `j` is yielded exactly once, so the `&mut T` do not alias.
            (0..this.nrows().unbound()).into_par_iter().map(move |j| {
                let ptr = this.const_cast().at_mut_unchecked(Idx::<Rows>::new_unbound(j));
                &mut *(ptr as *mut SyncCell<T> as *mut T)
            })
        }
    }

    /// Returns a rayon parallel iterator over `count` mutable sub-columns (see
    /// `ColRef::par_partition`).
    #[inline]
    #[track_caller]
    #[cfg(feature = "rayon")]
    pub fn par_partition_mut(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColMut<'a, T, usize, RStride>>
    where
        T: Send,
        Rows: 'a,
    {
        use rayon::prelude::*;
        // Same `SyncCell` round-trip as `par_iter_mut`; the parts yielded by
        // `par_partition` are made mutable and cast back to element type `T`.
        unsafe {
            self.as_type::<SyncCell<T>>()
                .into_const()
                .par_partition(count)
                .map(|col| col.const_cast().as_type::<T>())
        }
    }

    /// Reinterprets the element type as `U`, keeping pointer, row count and stride.
    ///
    /// # Safety
    ///
    /// `U` must be layout-compatible with `T` for this reinterpretation to be sound.
    pub(crate) unsafe fn as_type<U>(self) -> ColMut<'a, U, Rows, RStride> {
        ColMut::from_raw_parts_mut(self.as_ptr_mut() as *mut U, self.nrows(), self.row_stride())
    }

    /// Views the column as a mutable diagonal.
    #[inline]
    pub fn as_diagonal_mut(self) -> DiagMut<'a, T, Rows, RStride> {
        DiagMut {
            0: crate::diag::Mut { inner: self },
        }
    }

    /// Crate-internal alias for `at_mut`.
    #[inline]
    #[track_caller]
    pub(crate) fn __at_mut(self, i: Idx<Rows>) -> &'a mut T {
        self.at_mut(i)
    }
}
548
549impl<'a, T, Rows: Shape> ColMut<'a, T, Rows, ContiguousFwd> {
550 #[inline]
552 pub fn as_slice_mut(self) -> &'a mut [T] {
553 unsafe { core::slice::from_raw_parts_mut(self.as_ptr_mut(), self.nrows().unbound()) }
554 }
555}
556
557impl<'a, 'ROWS, T> ColMut<'a, T, Dim<'ROWS>, ContiguousFwd> {
558 #[inline]
560 pub fn as_array_mut(self) -> &'a mut Array<'ROWS, T> {
561 unsafe { &mut *(self.as_slice_mut() as *mut [_] as *mut Array<'ROWS, T>) }
562 }
563}
564
565impl<'ROWS, 'a, T, RStride: Stride> ColMut<'a, T, Dim<'ROWS>, RStride> {
566 #[doc(hidden)]
567 #[inline]
568 pub fn split_rows_with<'TOP, 'BOT>(
569 self,
570 row: Partition<'TOP, 'BOT, 'ROWS>,
571 ) -> (ColRef<'a, T, Dim<'TOP>, RStride>, ColRef<'a, T, Dim<'BOT>, RStride>) {
572 let (a, b) = self.split_at_row(row.midpoint());
573 (a.as_row_shape(row.head), b.as_row_shape(row.tail))
574 }
575}
576
577impl<'ROWS, 'a, T, RStride: Stride> ColMut<'a, T, Dim<'ROWS>, RStride> {
578 #[doc(hidden)]
579 #[inline]
580 pub fn split_rows_with_mut<'TOP, 'BOT>(
581 self,
582 row: Partition<'TOP, 'BOT, 'ROWS>,
583 ) -> (ColMut<'a, T, Dim<'TOP>, RStride>, ColMut<'a, T, Dim<'BOT>, RStride>) {
584 let (a, b) = self.split_at_row_mut(row.midpoint());
585 (a.as_row_shape_mut(row.head), b.as_row_shape_mut(row.tail))
586 }
587}
588
589impl<T: core::fmt::Debug, Rows: Shape, RStride: Stride> core::fmt::Debug for Mut<'_, T, Rows, RStride> {
590 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
591 self.rb().fmt(f)
592 }
593}
594
595impl<'a, T, Rows: Shape> ColMut<'a, T, Rows>
596where
597 T: RealField,
598{
599 pub fn max(&self) -> Option<T> {
601 self.rb().as_dyn_rows().as_dyn_stride().internal_max()
602 }
603
604 pub fn min(&self) -> Option<T> {
606 self.rb().as_dyn_rows().as_dyn_stride().internal_min()
607 }
608}
609
#[cfg(test)]
mod tests {
    use crate::Col;

    /// Builds the column `[1.0, 2.0, ..., n as f64]`; `n == 0` yields an empty column.
    fn ascending(n: usize) -> Col<f64> {
        Col::from_fn(n, |i| (i + 1) as f64)
    }

    #[test]
    fn test_col_min() {
        let mut filled = ascending(5);
        assert_eq!(filled.as_mut().min(), Some(1.0));

        let mut empty = ascending(0);
        assert_eq!(empty.as_mut().min(), None);
    }

    #[test]
    fn test_col_max() {
        let mut filled = ascending(5);
        assert_eq!(filled.as_mut().max(), Some(5.0));

        let mut empty = ascending(0);
        assert_eq!(empty.as_mut().max(), None);
    }
}
635}