1mod impl_network;
7
8use crate::model::{SurfaceModel, SurfaceModelConfig};
9use crate::points::Point;
10use cnc::nn::ModelFeatures;
11use cnc::params::Params;
12use ndarray::{Array1, Array2};
13use rstmt::nrt::Triad;
14
15#[derive(Clone, Debug)]
24pub struct SurfaceNetwork<T = f32> {
25 pub(crate) headspace: Triad,
26 pub(crate) model: SurfaceModel<T>,
27 pub(crate) points: Vec<Point<T>>,
28}
29
30impl<T> SurfaceNetwork<T> {
31 pub fn new(headspace: Triad) -> Self
33 where
34 T: num_traits::Float + num_traits::FromPrimitive,
35 {
36 let config = SurfaceModelConfig::default();
37 let features = ModelFeatures::new(3, 64, 3, 1);
38 let model = SurfaceModel::zeros(config, features).with_attention();
39 Self {
40 headspace,
41 model,
42 points: Vec::new(),
43 }
44 }
45 pub const fn headspace(&self) -> Triad {
47 self.headspace
48 }
49 pub fn headspace_mut(&mut self) -> &mut Triad {
51 &mut self.headspace
52 }
53 pub const fn model(&self) -> &SurfaceModel<T> {
55 &self.model
56 }
57 pub fn model_mut(&mut self) -> &mut SurfaceModel<T> {
59 &mut self.model
60 }
61 pub const fn points(&self) -> &Vec<Point<T>> {
63 &self.points
64 }
65 pub fn points_mut(&mut self) -> &mut Vec<Point<T>> {
67 &mut self.points
68 }
69 pub const fn input(&self) -> &Params<T> {
71 self.model().input()
72 }
73 pub fn input_mut(&mut self) -> &mut Params<T> {
75 self.model_mut().input_mut()
76 }
77 pub const fn input_bias(&self) -> &Array1<T> {
79 self.model().input().bias()
80 }
81 pub fn input_bias_mut(&mut self) -> &mut Array1<T> {
83 self.model_mut().input_mut().bias_mut()
84 }
85 pub const fn input_weights(&self) -> &Array2<T> {
87 self.model().input().weights()
88 }
89 pub fn input_weights_mut(&mut self) -> &mut Array2<T> {
91 self.model_mut().input_mut().weights_mut()
92 }
93 pub const fn hidden(&self) -> &Vec<Params<T>> {
95 self.model().hidden()
96 }
97 #[inline]
99 pub fn hidden_mut(&mut self) -> &mut Vec<Params<T>> {
100 self.model_mut().hidden_mut()
101 }
102 pub const fn output(&self) -> &Params<T> {
104 self.model().output()
105 }
106 #[inline]
108 pub fn output_mut(&mut self) -> &mut Params<T> {
109 self.model_mut().output_mut()
110 }
111 pub const fn output_bias(&self) -> &Array1<T> {
113 self.model().output().bias()
114 }
115 #[inline]
117 pub fn output_bias_mut(&mut self) -> &mut Array1<T> {
118 self.model_mut().output_mut().bias_mut()
119 }
120 pub const fn output_weights(&self) -> &Array2<T> {
122 self.model().output().weights()
123 }
124 #[inline]
126 pub fn output_weights_mut(&mut self) -> &mut Array2<T> {
127 self.model_mut().output_mut().weights_mut()
128 }
129
130 pub fn extend_surface<I>(&mut self, iter: I)
131 where
132 I: IntoIterator<Item = Point<T>>,
133 {
134 self.points.extend(iter);
135 }
136 pub fn set_points<I>(&mut self, iter: I)
138 where
139 I: IntoIterator<Item = Point<T>>,
140 {
141 self.points = Vec::from_iter(iter);
142 }
143 pub fn with_attention(self) -> Self
145 where
146 T: num_traits::FromPrimitive,
147 {
148 Self {
149 model: self.model.with_attention(),
150 ..self
151 }
152 }
153 pub fn with_points<I>(self, iter: I) -> Self
155 where
156 I: IntoIterator<Item = Point<T>>,
157 {
158 Self {
159 points: Vec::from_iter(iter),
160 ..self
161 }
162 }
163}
164
165impl<T> PartialEq for SurfaceNetwork<T>
166where
167 T: PartialEq,
168 SurfaceModel<T>: PartialEq,
169{
170 fn eq(&self, other: &Self) -> bool {
171 self.headspace == other.headspace
172 && self.model == other.model
173 && self.points == other.points
174 }
175}