//! Model interface for training with Tensorlogic.
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn};
use std::collections::HashMap;
/// Trait for trainable models.
///
/// This trait defines the interface for models that can be trained with the
/// Tensorlogic training infrastructure. Models must implement forward and
/// backward passes, parameter management, and optional save/load functionality.
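///
/// # Examples
///
/// A sketch of one training step against this interface, using the
/// `LinearModel` defined later in this module (the plain SGD update loop is
/// illustrative, not part of the trait):
///
/// ```ignore
/// use scirs2_core::ndarray::arr2;
///
/// let mut model = LinearModel::new(3, 2);
/// let input = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
///
/// // Forward pass, then backpropagate an all-ones output gradient.
/// let _output = model.forward(&input.view()).expect("forward should succeed");
/// let grad_output = arr2(&[[1.0, 1.0], [1.0, 1.0]]);
/// let grads = model
///     .backward(&input.view(), &grad_output.view())
///     .expect("backward should succeed");
///
/// // SGD update: p <- p - lr * g for every parameter with a gradient.
/// let lr = 0.01;
/// for (name, param) in model.parameters_mut() {
///     if let Some(grad) = grads.get(name) {
///         param.zip_mut_with(grad, |p, g| *p -= lr * g);
///     }
/// }
/// ```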
pub trait Model {
    /// Perform a forward pass through the model.
    ///
    /// # Arguments
    /// * `input` - Input tensor
    ///
    /// # Returns
    /// Output tensor from the model
    fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>>;

    /// Perform a backward pass to compute gradients.
    ///
    /// # Arguments
    /// * `input` - Input tensor used in the forward pass
    /// * `grad_output` - Gradient of the loss with respect to the model output
    ///
    /// # Returns
    /// Gradients for each model parameter
    fn backward(
        &self,
        input: &ArrayView<f64, Ix2>,
        grad_output: &ArrayView<f64, Ix2>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;

    /// Get a reference to the model's parameters.
    fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>>;

    /// Get a mutable reference to the model's parameters.
    fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>>;

    /// Set the model's parameters.
    fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>);

    /// Get the total number of parameter elements in the model.
    fn num_parameters(&self) -> usize {
        self.parameters().values().map(|p| p.len()).sum()
    }

    /// Save model state to a dictionary.
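    ///
    /// Flattening each parameter into a `Vec<f64>` keeps checkpoints easy to
    /// serialize; a round trip through `load_state_dict` restores the exact
    /// values. A minimal sketch:
    ///
    /// ```ignore
    /// let mut model = LinearModel::new(2, 2);
    /// let checkpoint = model.state_dict();
    /// model.reset_parameters(); // may overwrite the parameters
    /// model.load_state_dict(checkpoint).expect("shapes match the model");
    /// ```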
    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        self.parameters()
            .iter()
            .map(|(name, param)| (name.clone(), param.iter().copied().collect()))
            .collect()
    }

    /// Load model state from a dictionary.
    ///
    /// Returns an error if `state` contains a parameter name the model does
    /// not have, or a value buffer whose length does not match. Model
    /// parameters absent from `state` are left unchanged.
    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) -> TrainResult<()> {
        let parameters = self.parameters_mut();
        for (name, values) in state {
            if let Some(param) = parameters.get_mut(&name) {
                if param.len() != values.len() {
                    return Err(TrainError::InvalidParameter(format!(
                        "Parameter '{}' size mismatch: expected {}, got {}",
                        name,
                        param.len(),
                        values.len()
                    )));
                }
                for (p, v) in param.iter_mut().zip(values.iter()) {
                    *p = *v;
                }
            } else {
                return Err(TrainError::InvalidParameter(format!(
                    "Parameter '{}' not found in model",
                    name
                )));
            }
        }
        Ok(())
    }

    /// Reset model parameters (optional, for retraining).
    fn reset_parameters(&mut self) {
        // Default implementation does nothing; models can override this
        // to implement custom initialization.
    }
}
/// Trait for models that support automatic differentiation via scirs2-autograd.
///
/// This trait extends the base Model trait with support for training using
/// SciRS2's automatic differentiation system.
///
/// Note: This trait is currently a placeholder for future scirs2-autograd integration.
/// The actual Variable type will be specified once scirs2-autograd is fully integrated.
pub trait AutodiffModel: Model {
    /// Forward pass with autodiff tracking (placeholder).
    ///
    /// # Arguments
    /// * `input` - Input data array
    ///
    /// # Returns
    /// `Ok(())` for now; the final implementation will return an autodiff `Variable`
    fn forward_autodiff(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<()> {
        // Placeholder implementation
        let _ = input;
        Ok(())
    }

    /// Compute gradients automatically using the backward pass (placeholder).
    ///
    /// # Returns
    /// Gradients for all parameters (currently an empty map)
    fn compute_gradients(&self) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        // Placeholder implementation
        Ok(HashMap::new())
    }
}
/// Trait for models with dynamic computation graphs.
///
/// This extends the model interface to support variable-sized inputs
/// and dynamic graph construction (e.g., RNNs over variable-length sequences).
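///
/// # Examples
///
/// A minimal sketch of an implementor (the `Identity` type is illustrative,
/// not part of this crate): it echoes its input at any rank and, having no
/// trainable parameters, returns an empty gradient map.
///
/// ```ignore
/// struct Identity;
///
/// impl DynamicModel for Identity {
///     fn forward_dynamic(&self, input: &ArrayView<f64, IxDyn>) -> TrainResult<Array<f64, IxDyn>> {
///         // Pass the input through unchanged, whatever its shape.
///         Ok(input.to_owned())
///     }
///
///     fn backward_dynamic(
///         &self,
///         _input: &ArrayView<f64, IxDyn>,
///         _grad_output: &ArrayView<f64, IxDyn>,
///     ) -> TrainResult<HashMap<String, Array<f64, IxDyn>>> {
///         // No trainable parameters, hence no parameter gradients.
///         Ok(HashMap::new())
///     }
/// }
/// ```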
pub trait DynamicModel {
    /// Forward pass with dynamic input dimensions.
    fn forward_dynamic(&self, input: &ArrayView<f64, IxDyn>) -> TrainResult<Array<f64, IxDyn>>;

    /// Backward pass with dynamic input dimensions.
    fn backward_dynamic(
        &self,
        input: &ArrayView<f64, IxDyn>,
        grad_output: &ArrayView<f64, IxDyn>,
    ) -> TrainResult<HashMap<String, Array<f64, IxDyn>>>;
}
/// A simple linear model for testing and demonstration.
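///
/// Computes `Y = X @ W + b`, with weights of shape `(input_dim, output_dim)`
/// and a bias row of shape `(1, output_dim)`. A short sketch of typical use:
///
/// ```ignore
/// let mut model = LinearModel::new(3, 2);
/// model.reset_parameters(); // Xavier/Glorot initialization
/// assert_eq!(model.num_parameters(), 3 * 2 + 2);
/// ```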
#[derive(Debug, Clone)]
pub struct LinearModel {
    /// Model parameters (weights and biases).
    parameters: HashMap<String, Array<f64, Ix2>>,
    /// Input dimension.
    input_dim: usize,
    /// Output dimension.
    output_dim: usize,
}
impl LinearModel {
    /// Create a new linear model.
    ///
    /// # Arguments
    /// * `input_dim` - Input dimension
    /// * `output_dim` - Output dimension
    pub fn new(input_dim: usize, output_dim: usize) -> Self {
        let mut parameters = HashMap::new();
        // Initialize weights and biases to zero; call `reset_parameters`
        // (Xavier/Glorot) to randomize the weights before training.
        let weights = Array::zeros((input_dim, output_dim));
        let biases = Array::zeros((1, output_dim));
        parameters.insert("weight".to_string(), weights);
        parameters.insert("bias".to_string(), biases);
        Self {
            parameters,
            input_dim,
            output_dim,
        }
    }
    /// Initialize parameters with Xavier/Glorot uniform initialization.
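    ///
    /// The bound is `limit = sqrt(6 / (fan_in + fan_out))`; weights are drawn
    /// uniformly from `[-limit, limit)`. A quick sanity check (sketch):
    ///
    /// ```ignore
    /// let mut model = LinearModel::new(4, 2);
    /// model.xavier_init();
    /// let limit = (6.0_f64 / 6.0).sqrt();
    /// assert!(model.parameters()["weight"].iter().all(|w| w.abs() <= limit));
    /// ```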
    pub fn xavier_init(&mut self) {
        let limit = (6.0 / (self.input_dim + self.output_dim) as f64).sqrt();
        if let Some(weights) = self.parameters.get_mut("weight") {
            // Simple LCG so this stays dependency-free; use a proper RNG in production.
            let mut state: u64 = 0x9E37_79B9_7F4A_7C15;
            weights.mapv_inplace(|_| {
                state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
                // Map the top 53 bits to [0, 1), then stretch to [-limit, limit).
                let unit = (state >> 11) as f64 / (1u64 << 53) as f64;
                limit * (2.0 * unit - 1.0)
            });
        }
    }
    /// Get input dimension.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Get output dimension.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }
}
impl Model for LinearModel {
    fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>> {
        let weights = self
            .parameters
            .get("weight")
            .ok_or_else(|| TrainError::InvalidParameter("weight not found".to_string()))?;
        let biases = self
            .parameters
            .get("bias")
            .ok_or_else(|| TrainError::InvalidParameter("bias not found".to_string()))?;
        // Linear transformation: Y = X @ W + b
        let output = input.dot(weights) + biases;
        Ok(output)
    }

    fn backward(
        &self,
        input: &ArrayView<f64, Ix2>,
        grad_output: &ArrayView<f64, Ix2>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();
        // Gradient w.r.t. weights: dL/dW = X^T @ dL/dY
        let grad_weights = input.t().dot(grad_output);
        gradients.insert("weight".to_string(), grad_weights);
        // Gradient w.r.t. biases: dL/db = sum(dL/dY, axis=0)
        let grad_biases = grad_output
            .sum_axis(scirs2_core::ndarray::Axis(0))
            .insert_axis(scirs2_core::ndarray::Axis(0));
        gradients.insert("bias".to_string(), grad_biases);
        Ok(gradients)
    }

    fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>> {
        &self.parameters
    }

    fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>> {
        &mut self.parameters
    }

    fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>) {
        self.parameters = parameters;
    }

    fn reset_parameters(&mut self) {
        self.xavier_init();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_linear_model_creation() {
        let model = LinearModel::new(10, 5);
        assert_eq!(model.input_dim(), 10);
        assert_eq!(model.output_dim(), 5);
        assert_eq!(model.parameters().len(), 2);
    }

    #[test]
    fn test_linear_model_forward() {
        let model = LinearModel::new(3, 2);
        let input = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let output = model.forward(&input.view()).expect("forward pass should succeed");
        assert_eq!(output.shape(), &[2, 2]);
    }

    #[test]
    fn test_linear_model_backward() {
        let model = LinearModel::new(3, 2);
        let input = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let grad_output = arr2(&[[1.0, 1.0], [1.0, 1.0]]);
        let gradients = model
            .backward(&input.view(), &grad_output.view())
            .expect("backward pass should succeed");
        assert!(gradients.contains_key("weight"));
        assert!(gradients.contains_key("bias"));
        assert_eq!(gradients["weight"].shape(), &[3, 2]);
        assert_eq!(gradients["bias"].shape(), &[1, 2]);
    }

    #[test]
    fn test_model_state_dict() {
        let model = LinearModel::new(2, 2);
        let state = model.state_dict();
        assert_eq!(state.len(), 2);
        assert!(state.contains_key("weight"));
        assert!(state.contains_key("bias"));
    }

    #[test]
    fn test_model_load_state() {
        let mut model = LinearModel::new(2, 2);
        let state = model.state_dict();
        // Modify parameters
        model.parameters_mut().get_mut("weight").expect("weight should exist")[[0, 0]] = 99.0;
        // Load original state
        model.load_state_dict(state).expect("state dict should match the model");
        // Verify the zero-initialized value was restored
        assert_eq!(
            model.parameters().get("weight").expect("weight should exist")[[0, 0]],
            0.0
        );
    }

    #[test]
    fn test_num_parameters() {
        let model = LinearModel::new(10, 5);
        // 10*5 weights + 5 biases = 55
        assert_eq!(model.num_parameters(), 55);
    }
}