faer/col/colmut.rs

1use super::*;
2use crate::mat::matmut::SyncCell;
3use crate::utils::bound::{Array, Dim, Partition};
4use crate::{ContiguousFwd, Idx, IdxInc};
5use core::marker::PhantomData;
6use core::ptr::NonNull;
7use equator::assert;
8use generativity::Guard;
9
/// see [`super::ColMut`]
///
/// raw backing storage of a mutable column view: the pointer / row count / row stride triple
/// lives in `imp`, and `__marker` ties the view to a mutable borrow of `T` for `'a`
pub struct Mut<'a, T, Rows = usize, RStride = isize> {
	/// pointer to the first element, number of rows, and row stride of the viewed column
	pub(super) imp: ColView<T, Rows, RStride>,
	/// zero-sized marker giving the view the lifetime and variance of `&'a mut T`
	pub(super) __marker: PhantomData<&'a mut T>,
}
15
16impl<'short, T, Rows: Copy, RStride: Copy> Reborrow<'short> for Mut<'_, T, Rows, RStride> {
17	type Target = Ref<'short, T, Rows, RStride>;
18
19	#[inline]
20	fn rb(&'short self) -> Self::Target {
21		Ref {
22			imp: self.imp,
23			__marker: PhantomData,
24		}
25	}
26}
27impl<'short, T, Rows: Copy, RStride: Copy> ReborrowMut<'short> for Mut<'_, T, Rows, RStride> {
28	type Target = Mut<'short, T, Rows, RStride>;
29
30	#[inline]
31	fn rb_mut(&'short mut self) -> Self::Target {
32		Mut {
33			imp: self.imp,
34			__marker: PhantomData,
35		}
36	}
37}
38impl<'a, T, Rows: Copy, RStride: Copy> IntoConst for Mut<'a, T, Rows, RStride> {
39	type Target = Ref<'a, T, Rows, RStride>;
40
41	#[inline]
42	fn into_const(self) -> Self::Target {
43		Ref {
44			imp: self.imp,
45			__marker: PhantomData,
46		}
47	}
48}
49
// SAFETY: `Mut` stores a `NonNull` pointer (which is neither `Send` nor `Sync`, hence the manual
// impls) and models `&'a mut T` via its marker. `&mut T` is `Sync` exactly when `T: Sync`, and the
// shape/stride values stored alongside the pointer must themselves be `Sync`.
unsafe impl<T: Sync, Rows: Sync, RStride: Sync> Sync for Mut<'_, T, Rows, RStride> {}
// SAFETY: mirrors `&mut T`, which is `Send` exactly when `T: Send`; the stored shape/stride values
// must themselves be `Send`.
unsafe impl<T: Send, Rows: Send, RStride: Send> Send for Mut<'_, T, Rows, RStride> {}
52
53impl<'a, T> ColMut<'a, T> {
54	/// creates a column view over the given element
55	#[inline]
56	pub fn from_mut(value: &'a mut T) -> Self {
57		unsafe { ColMut::from_raw_parts_mut(value as *mut T, 1, 1) }
58	}
59
60	/// creates a `ColMut` from slice views over the column vector data, the result has the same
61	/// number of rows as the length of the input slice
62	#[inline]
63	pub fn from_slice_mut(slice: &'a mut [T]) -> Self {
64		let len = slice.len();
65		unsafe { Self::from_raw_parts_mut(slice.as_mut_ptr(), len, 1) }
66	}
67}
68
69impl<'a, T, Rows: Shape, RStride: Stride> ColMut<'a, T, Rows, RStride> {
70	/// creates a `ColMut` from pointers to the column vector data, number of rows, and row stride
71	///
72	/// # safety
73	/// this function has the same safety requirements as
74	/// [`MatMut::from_raw_parts(ptr, nrows, 1, row_stride, 0)`]
75	#[inline(always)]
76	#[track_caller]
77	pub const unsafe fn from_raw_parts_mut(ptr: *mut T, nrows: Rows, row_stride: RStride) -> Self {
78		Self {
79			0: Mut {
80				imp: ColView {
81					ptr: NonNull::new_unchecked(ptr),
82					nrows,
83					row_stride,
84				},
85				__marker: PhantomData,
86			},
87		}
88	}
89
90	/// returns a pointer to the column data
91	#[inline(always)]
92	pub fn as_ptr(&self) -> *const T {
93		self.rb().as_ptr()
94	}
95
96	/// returns the number of rows of the column
97	#[inline(always)]
98	pub fn nrows(&self) -> Rows {
99		self.imp.nrows
100	}
101
102	/// returns the number of columns of the column (always `1`)
103	#[inline(always)]
104	pub fn ncols(&self) -> usize {
105		1
106	}
107
108	/// returns the number of rows and columns of the column
109	#[inline(always)]
110	pub fn shape(&self) -> (Rows, usize) {
111		(self.nrows(), self.ncols())
112	}
113
114	/// returns the row stride of the column, specified in number of elements, not in bytes
115	#[inline(always)]
116	pub fn row_stride(&self) -> RStride {
117		self.imp.row_stride
118	}
119
120	/// returns a raw pointer to the element at the given index
121	#[inline(always)]
122	pub fn ptr_at(&self, row: IdxInc<Rows>) -> *const T {
123		self.rb().ptr_at(row)
124	}
125
126	/// returns a raw pointer to the element at the given index, assuming the provided index
127	/// is within the column bounds
128	///
129	/// # safety
130	/// the behavior is undefined if any of the following conditions are violated:
131	/// * `row < self.nrows()`
132	#[inline(always)]
133	#[track_caller]
134	pub unsafe fn ptr_inbounds_at(&self, row: Idx<Rows>) -> *const T {
135		self.rb().ptr_inbounds_at(row)
136	}
137
138	#[inline]
139	#[track_caller]
140	/// see [`ColRef::split_at_row`]
141	pub fn split_at_row(self, row: IdxInc<Rows>) -> (ColRef<'a, T, usize, RStride>, ColRef<'a, T, usize, RStride>) {
142		self.into_const().split_at_row(row)
143	}
144
145	#[inline(always)]
146	/// see [`ColRef::transpose`]
147	pub fn transpose(self) -> RowRef<'a, T, Rows, RStride> {
148		self.into_const().transpose()
149	}
150
151	#[inline(always)]
152	/// see [`ColRef::conjugate`]
153	pub fn conjugate(self) -> ColRef<'a, T::Conj, Rows, RStride>
154	where
155		T: Conjugate,
156	{
157		self.into_const().conjugate()
158	}
159
160	#[inline(always)]
161	/// see [`ColRef::canonical`]
162	pub fn canonical(self) -> ColRef<'a, T::Canonical, Rows, RStride>
163	where
164		T: Conjugate,
165	{
166		self.into_const().canonical()
167	}
168
169	#[inline(always)]
170	/// see [`ColRef::adjoint`]
171	pub fn adjoint(self) -> RowRef<'a, T::Conj, Rows, RStride>
172	where
173		T: Conjugate,
174	{
175		self.into_const().adjoint()
176	}
177
178	#[track_caller]
179	#[inline(always)]
180	/// see [`ColRef::get`]
181	pub fn get<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
182	where
183		ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
184	{
185		<ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get(self.into_const(), row)
186	}
187
188	#[track_caller]
189	#[inline(always)]
190	/// see [`ColRef::get_unchecked`]
191	pub unsafe fn get_unchecked<RowRange>(self, row: RowRange) -> <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
192	where
193		ColRef<'a, T, Rows, RStride>: ColIndex<RowRange>,
194	{
195		unsafe { <ColRef<'a, T, Rows, RStride> as ColIndex<RowRange>>::get_unchecked(self.into_const(), row) }
196	}
197
198	#[inline]
199	/// see [`ColRef::reverse_rows`]
200	pub fn reverse_rows(self) -> ColRef<'a, T, Rows, RStride::Rev> {
201		self.into_const().reverse_rows()
202	}
203
204	#[inline]
205	/// see [`ColRef::subrows`]
206	pub fn subrows<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> ColRef<'a, T, V, RStride> {
207		self.into_const().subrows(row_start, nrows)
208	}
209
210	#[inline]
211	#[track_caller]
212	/// see [`ColRef::as_row_shape`]
213	pub fn as_row_shape<V: Shape>(self, nrows: V) -> ColRef<'a, T, V, RStride> {
214		self.into_const().as_row_shape(nrows)
215	}
216
217	#[inline]
218	/// see [`ColRef::as_dyn_rows`]
219	pub fn as_dyn_rows(self) -> ColRef<'a, T, usize, RStride> {
220		self.into_const().as_dyn_rows()
221	}
222
223	#[inline]
224	/// see [`ColRef::as_dyn_stride`]
225	pub fn as_dyn_stride(self) -> ColRef<'a, T, Rows, isize> {
226		self.into_const().as_dyn_stride()
227	}
228
229	#[inline]
230	/// see [`ColRef::iter`]
231	pub fn iter(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a T>
232	where
233		Rows: 'a,
234	{
235		self.into_const().iter()
236	}
237
238	#[inline]
239	#[cfg(feature = "rayon")]
240	/// see [`ColRef::par_iter`]
241	pub fn par_iter(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a T>
242	where
243		T: Sync,
244		Rows: 'a,
245	{
246		self.into_const().par_iter()
247	}
248
249	#[inline]
250	#[track_caller]
251	#[cfg(feature = "rayon")]
252	/// see [`ColRef::par_partition`]
253	pub fn par_partition(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, T, usize, RStride>>
254	where
255		T: Sync,
256		Rows: 'a,
257	{
258		self.into_const().par_partition(count)
259	}
260
261	#[inline]
262	/// see [`ColRef::try_as_col_major`]
263	pub fn try_as_col_major(self) -> Option<ColRef<'a, T, Rows, ContiguousFwd>> {
264		self.into_const().try_as_col_major()
265	}
266
267	#[inline]
268	/// see [`ColRef::try_as_col_major`]
269	pub fn try_as_col_major_mut(self) -> Option<ColMut<'a, T, Rows, ContiguousFwd>> {
270		self.into_const().try_as_col_major().map(|x| unsafe { x.const_cast() })
271	}
272
273	#[inline(always)]
274	#[doc(hidden)]
275	pub unsafe fn const_cast(self) -> ColMut<'a, T, Rows, RStride> {
276		self
277	}
278
279	#[inline]
280	#[doc(hidden)]
281	pub fn bind_r<'N>(self, row: Guard<'N>) -> ColMut<'a, T, Dim<'N>, RStride> {
282		unsafe { ColMut::from_raw_parts_mut(self.as_ptr_mut(), self.nrows().bind(row), self.row_stride()) }
283	}
284
285	#[inline]
286	/// see [`ColRef::as_mat`]
287	pub fn as_mat(self) -> MatRef<'a, T, Rows, usize, RStride, isize> {
288		self.into_const().as_mat()
289	}
290
291	#[inline]
292	/// see [`ColRef::as_mat`]
293	pub fn as_mat_mut(self) -> MatMut<'a, T, Rows, usize, RStride, isize> {
294		unsafe { self.into_const().as_mat().const_cast() }
295	}
296
297	#[inline]
298	/// see [`ColRef::as_diagonal`]
299	pub fn as_diagonal(self) -> DiagRef<'a, T, Rows, RStride> {
300		DiagRef {
301			0: crate::diag::Ref { inner: self.into_const() },
302		}
303	}
304}
305
306impl<T, Rows: Shape, RStride: Stride, Inner: for<'short> ReborrowMut<'short, Target = Mut<'short, T, Rows, RStride>>> generic::Col<Inner> {
307	/// copies `other` into `self`
308	#[inline]
309	pub fn copy_from<RhsT: Conjugate<Canonical = T>>(&mut self, other: impl AsColRef<T = RhsT, Rows = Rows>)
310	where
311		T: ComplexField,
312	{
313		let other = other.as_col_ref();
314		let this = self.rb_mut();
315
316		assert!(all(this.nrows() == other.nrows(), this.ncols() == other.ncols(),));
317		let m = this.nrows();
318
319		with_dim!(M, m.unbound());
320		imp(
321			self.rb_mut().as_row_shape_mut(M).as_dyn_stride_mut(),
322			other.as_row_shape(M).canonical(),
323			Conj::get::<RhsT>(),
324		);
325
326		pub fn imp<'M, 'N, T: ComplexField>(this: ColMut<'_, T, Dim<'M>>, other: ColRef<'_, T, Dim<'M>>, conj_: Conj) {
327			match conj_ {
328				Conj::No => {
329					zip!(this, other).for_each(|unzip!(dst, src)| *dst = copy(&src));
330				},
331				Conj::Yes => {
332					zip!(this, other).for_each(|unzip!(dst, src)| *dst = conj(&src));
333				},
334			}
335		}
336	}
337
338	/// fills all the elements of `self` with `value`
339	#[inline]
340	pub fn fill(&mut self, value: T)
341	where
342		T: Clone,
343	{
344		fn cloner<T: Clone>(value: T) -> impl for<'a> FnMut(crate::linalg::zip::Last<&'a mut T>) {
345			#[inline(always)]
346			move |x| *x.0 = value.clone()
347		}
348		z!(self.rb_mut().as_dyn_rows_mut()).for_each(cloner::<T>(value));
349	}
350
351	#[inline]
352	/// returns a view over `self`
353	pub fn as_mut(&mut self) -> ColMut<'_, T, Rows, RStride> {
354		self.rb_mut()
355	}
356}
impl<'a, T, Rows: Shape, RStride: Stride> ColMut<'a, T, Rows, RStride> {
	#[inline(always)]
	/// see [`ColRef::as_ptr`]
	///
	/// same address as `as_ptr`, returned as `*mut T`
	pub fn as_ptr_mut(&self) -> *mut T {
		self.rb().as_ptr() as *mut T
	}

	#[inline(always)]
	/// see [`ColRef::ptr_at`]
	///
	/// same address as `ptr_at`, returned as `*mut T`
	pub fn ptr_at_mut(&self, row: IdxInc<Rows>) -> *mut T {
		self.rb().ptr_at(row) as *mut T
	}

	#[inline(always)]
	#[track_caller]
	/// see [`ColRef::ptr_inbounds_at`]
	///
	/// # safety
	/// the behavior is undefined unless `row < self.nrows()`
	pub unsafe fn ptr_inbounds_at_mut(&self, row: Idx<Rows>) -> *mut T {
		self.rb().ptr_inbounds_at(row) as *mut T
	}

	#[inline]
	#[track_caller]
	/// see [`ColRef::split_at_row`]
	pub fn split_at_row_mut(self, row: IdxInc<Rows>) -> (ColMut<'a, T, usize, RStride>, ColMut<'a, T, usize, RStride>) {
		let (a, b) = self.into_const().split_at_row(row);
		// SAFETY: `self` is consumed; relies on `ColRef::split_at_row` producing non-overlapping
		// halves, so the two mutable views never alias
		unsafe { (a.const_cast(), b.const_cast()) }
	}

	#[inline(always)]
	/// see [`ColRef::transpose`]
	pub fn transpose_mut(self) -> RowMut<'a, T, Rows, RStride> {
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().transpose().const_cast() }
	}

	#[inline(always)]
	/// see [`ColRef::conjugate`]
	pub fn conjugate_mut(self) -> ColMut<'a, T::Conj, Rows, RStride>
	where
		T: Conjugate,
	{
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().conjugate().const_cast() }
	}

	#[inline(always)]
	/// see [`ColRef::canonical`]
	pub fn canonical_mut(self) -> ColMut<'a, T::Canonical, Rows, RStride>
	where
		T: Conjugate,
	{
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().canonical().const_cast() }
	}

	#[inline(always)]
	/// see [`ColRef::adjoint`]
	pub fn adjoint_mut(self) -> RowMut<'a, T::Conj, Rows, RStride>
	where
		T: Conjugate,
	{
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().adjoint().const_cast() }
	}

	/// returns a mutable reference to the element at `row`
	///
	/// # panics
	/// panics unless `row < self.nrows()`
	#[inline(always)]
	#[track_caller]
	pub(crate) fn at_mut(self, row: Idx<Rows>) -> &'a mut T {
		assert!(all(row < self.nrows()));
		// SAFETY: the bound was checked just above
		unsafe { self.at_mut_unchecked(row) }
	}

	/// returns a mutable reference to the element at `row` without bounds checking
	///
	/// # safety
	/// the behavior is undefined unless `row < self.nrows()`
	#[inline(always)]
	#[track_caller]
	pub(crate) unsafe fn at_mut_unchecked(self, row: Idx<Rows>) -> &'a mut T {
		&mut *self.ptr_inbounds_at_mut(row)
	}

	#[track_caller]
	#[inline(always)]
	/// see [`ColRef::get`]
	pub fn get_mut<RowRange>(self, row: RowRange) -> <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
	where
		ColMut<'a, T, Rows, RStride>: ColIndex<RowRange>,
	{
		<ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::get(self, row)
	}

	#[track_caller]
	#[inline(always)]
	/// see [`ColRef::get_unchecked`]
	///
	/// # safety
	/// same requirements as the `ColIndex::get_unchecked` implementation for `ColMut`
	pub unsafe fn get_mut_unchecked<RowRange>(self, row: RowRange) -> <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::Target
	where
		ColMut<'a, T, Rows, RStride>: ColIndex<RowRange>,
	{
		unsafe { <ColMut<'a, T, Rows, RStride> as ColIndex<RowRange>>::get_unchecked(self, row) }
	}

	#[inline]
	/// see [`ColRef::reverse_rows`]
	pub fn reverse_rows_mut(self) -> ColMut<'a, T, Rows, RStride::Rev> {
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().reverse_rows().const_cast() }
	}

	#[inline]
	#[track_caller]
	/// see [`ColRef::subrows`]
	pub fn subrows_mut<V: Shape>(self, row_start: IdxInc<Rows>, nrows: V) -> ColMut<'a, T, V, RStride> {
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().subrows(row_start, nrows).const_cast() }
	}

	#[inline]
	#[track_caller]
	/// see [`ColRef::as_row_shape`]
	pub fn as_row_shape_mut<V: Shape>(self, nrows: V) -> ColMut<'a, T, V, RStride> {
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().as_row_shape(nrows).const_cast() }
	}

	#[inline]
	/// see [`ColRef::as_dyn_rows`]
	pub fn as_dyn_rows_mut(self) -> ColMut<'a, T, usize, RStride> {
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().as_dyn_rows().const_cast() }
	}

	#[inline]
	/// see [`ColRef::as_dyn_stride`]
	pub fn as_dyn_stride_mut(self) -> ColMut<'a, T, Rows, isize> {
		// SAFETY: `self` is consumed, so the returned view inherits its unique access for `'a`
		unsafe { self.into_const().as_dyn_stride().const_cast() }
	}

	#[inline]
	/// see [`ColRef::iter`]
	pub fn iter_mut(self) -> impl 'a + ExactSizeIterator + DoubleEndedIterator<Item = &'a mut T>
	where
		Rows: 'a,
	{
		// keep a `Copy` shared view in the closure and re-derive a mutable view per index, so the
		// closure can hand out `&'a mut T` repeatedly; sound assuming `Rows::indices` yields each
		// in-bounds row exactly once (NOTE(review): confirm in `Shape::indices`)
		let this = self.into_const();
		Rows::indices(Rows::start(), this.nrows().end()).map(move |j| unsafe { this.const_cast().at_mut_unchecked(j) })
	}

	#[inline]
	#[cfg(feature = "rayon")]
	/// see [`ColRef::par_iter`]
	pub fn par_iter_mut(self) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = &'a mut T>
	where
		T: Send,
		Rows: 'a,
	{
		// views the elements as `SyncCell<T>` so the shared view can be captured by the parallel
		// closure; NOTE(review): relies on `SyncCell<T>` having the same layout as `T` and being
		// `Sync` for `T: Send` — confirm in `mat::matmut`
		unsafe {
			let this = self.as_type::<SyncCell<T>>().into_const();

			use rayon::prelude::*;
			(0..this.nrows().unbound()).into_par_iter().map(move |j| {
				// each index `j` in `0..nrows` is visited once, so the mutable references are disjoint
				let ptr = this.const_cast().at_mut_unchecked(Idx::<Rows>::new_unbound(j));
				&mut *(ptr as *mut SyncCell<T> as *mut T)
			})
		}
	}

	#[inline]
	#[track_caller]
	#[cfg(feature = "rayon")]
	/// see [`ColRef::par_partition`]
	pub fn par_partition_mut(self, count: usize) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColMut<'a, T, usize, RStride>>
	where
		T: Send,
		Rows: 'a,
	{
		use rayon::prelude::*;
		// same `SyncCell<T>` round-trip as `par_iter_mut`; relies on `ColRef::par_partition`
		// producing non-overlapping chunks so each restored `ColMut` is unique
		unsafe {
			self.as_type::<SyncCell<T>>()
				.into_const()
				.par_partition(count)
				.map(|col| col.const_cast().as_type::<T>())
		}
	}

	/// reinterprets the element type as `U`, keeping the pointer, row count and row stride
	///
	/// # safety
	/// since the stride is counted in elements, `U` must be layout-compatible with `T`
	/// (NOTE(review): requirement inferred from the cast; no check is performed here)
	pub(crate) unsafe fn as_type<U>(self) -> ColMut<'a, U, Rows, RStride> {
		ColMut::from_raw_parts_mut(self.as_ptr_mut() as *mut U, self.nrows(), self.row_stride())
	}

	#[inline]
	/// see [`ColRef::as_diagonal`]
	pub fn as_diagonal_mut(self) -> DiagMut<'a, T, Rows, RStride> {
		DiagMut {
			0: crate::diag::Mut { inner: self },
		}
	}

	/// bounds-checked element access; forwards to `at_mut`
	#[inline]
	#[track_caller]
	pub(crate) fn __at_mut(self, i: Idx<Rows>) -> &'a mut T {
		self.at_mut(i)
	}
}
548
549impl<'a, T, Rows: Shape> ColMut<'a, T, Rows, ContiguousFwd> {
550	/// returns a reference over the elements as a slice
551	#[inline]
552	pub fn as_slice_mut(self) -> &'a mut [T] {
553		unsafe { core::slice::from_raw_parts_mut(self.as_ptr_mut(), self.nrows().unbound()) }
554	}
555}
556
557impl<'a, 'ROWS, T> ColMut<'a, T, Dim<'ROWS>, ContiguousFwd> {
558	/// returns a reference over the elements as a lifetime-bound slice
559	#[inline]
560	pub fn as_array_mut(self) -> &'a mut Array<'ROWS, T> {
561		unsafe { &mut *(self.as_slice_mut() as *mut [_] as *mut Array<'ROWS, T>) }
562	}
563}
564
565impl<'ROWS, 'a, T, RStride: Stride> ColMut<'a, T, Dim<'ROWS>, RStride> {
566	#[doc(hidden)]
567	#[inline]
568	pub fn split_rows_with<'TOP, 'BOT>(
569		self,
570		row: Partition<'TOP, 'BOT, 'ROWS>,
571	) -> (ColRef<'a, T, Dim<'TOP>, RStride>, ColRef<'a, T, Dim<'BOT>, RStride>) {
572		let (a, b) = self.split_at_row(row.midpoint());
573		(a.as_row_shape(row.head), b.as_row_shape(row.tail))
574	}
575}
576
577impl<'ROWS, 'a, T, RStride: Stride> ColMut<'a, T, Dim<'ROWS>, RStride> {
578	#[doc(hidden)]
579	#[inline]
580	pub fn split_rows_with_mut<'TOP, 'BOT>(
581		self,
582		row: Partition<'TOP, 'BOT, 'ROWS>,
583	) -> (ColMut<'a, T, Dim<'TOP>, RStride>, ColMut<'a, T, Dim<'BOT>, RStride>) {
584		let (a, b) = self.split_at_row_mut(row.midpoint());
585		(a.as_row_shape_mut(row.head), b.as_row_shape_mut(row.tail))
586	}
587}
588
589impl<T: core::fmt::Debug, Rows: Shape, RStride: Stride> core::fmt::Debug for Mut<'_, T, Rows, RStride> {
590	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
591		self.rb().fmt(f)
592	}
593}
594
595impl<'a, T, Rows: Shape> ColMut<'a, T, Rows>
596where
597	T: RealField,
598{
599	/// Returns the maximum element in the column, or `None` if the column is empty
600	pub fn max(&self) -> Option<T> {
601		self.rb().as_dyn_rows().as_dyn_stride().internal_max()
602	}
603
604	/// Returns the minimum element in the column, or `None` if the column is empty
605	pub fn min(&self) -> Option<T> {
606		self.rb().as_dyn_rows().as_dyn_stride().internal_min()
607	}
608}
609
#[cfg(test)]
mod tests {
	use crate::Col;

	#[test]
	fn test_col_min() {
		// column [1, 2, 3, 4, 5] and an empty column
		let mut filled: Col<f64> = Col::from_fn(5, |i| (i + 1) as f64);
		let mut empty: Col<f64> = Col::from_fn(0, |_| 0.0);

		assert_eq!(filled.as_mut().min(), Some(1.0));
		assert_eq!(empty.as_mut().min(), None);
	}

	#[test]
	fn test_col_max() {
		// column [1, 2, 3, 4, 5] and an empty column
		let mut filled: Col<f64> = Col::from_fn(5, |i| (i + 1) as f64);
		let mut empty: Col<f64> = Col::from_fn(0, |_| 0.0);

		assert_eq!(filled.as_mut().max(), Some(5.0));
		assert_eq!(empty.as_mut().max(), None);
	}
}