use crate::{Module, Parameter};
use parking_lot::Mutex;
use torsh_core::{
device::DeviceType,
error::{Result, TorshError},
shape::Shape,
};
use torsh_tensor::Tensor;
use scirs2_core::random::Random;
use scirs2_core::RngExt;
#[cfg(feature = "std")]
use std::{boxed::Box, collections::HashMap, string::String, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// A [`Module`] whose parameters are only created once a concrete input shape
/// is known.
///
/// Implementors defer allocation until [`LazyModule::initialize`] is called.
pub trait LazyModule: Module {
    /// Returns `true` once the module has been initialized.
    fn is_initialized(&self) -> bool;
    /// Initializes the module for inputs of shape `input_shape`.
    fn initialize(&mut self, input_shape: &Shape) -> Result<()>;
    /// The input shape recorded at initialization, or `None` before then.
    fn input_shape(&self) -> Option<Shape>;
    /// Predicts the output shape produced for `input_shape`, or `None` if it
    /// cannot be determined.
    fn output_shape(&self, input_shape: &Shape) -> Option<Shape>;
}
/// Initialization progress of a lazy module.
#[derive(Debug, Clone)]
pub enum LazyState {
    /// No parameters have been created yet.
    Uninitialized,
    /// Initialization is in progress; observing this state on entry to an
    /// initializer is reported as a circular dependency.
    Initializing,
    /// Initialization finished for inputs of `input_shape`.
    Initialized { input_shape: Shape },
}
/// Wraps a module that is built (or adopted) lazily from the first input shape.
///
/// The wrapper records the requested training mode and device so that they can
/// be applied to a module the factory builds later.
pub struct LazyWrapper<M> {
    /// The wrapped module; `None` for factory-backed wrappers until
    /// initialization creates it.
    module: Option<M>,
    /// Initialization state, guarded by a mutex.
    state: Arc<Mutex<LazyState>>,
    /// Builds the module from an input shape; `None` when wrapping an
    /// existing module.
    // Only read on the lazy-initialization path; the allow keeps the build
    // warning-free when that path is unused.
    #[allow(dead_code)]
    factory: Option<Box<dyn Fn(&Shape) -> Result<M> + Send + Sync>>,
    /// Training-mode flag, applied to a factory-built module on creation.
    training: bool,
    /// Target device, applied to a factory-built module on creation.
    device: DeviceType,
}
impl<M: Module + Send + Sync> LazyWrapper<M> {
    /// Creates an uninitialized wrapper whose module is built by `factory`.
    ///
    /// `factory` receives the input shape passed to
    /// [`LazyWrapper::initialize`]. The module it returns is switched to the
    /// wrapper's training mode and moved to the wrapper's device. The wrapper
    /// starts in training mode on the CPU.
    pub fn new<F>(factory: F) -> Self
    where
        F: Fn(&Shape) -> Result<M> + Send + Sync + 'static,
    {
        Self {
            module: None,
            state: Arc::new(Mutex::new(LazyState::Uninitialized)),
            factory: Some(Box::new(factory)),
            training: true,
            device: DeviceType::Cpu,
        }
    }

    /// Wraps an existing [`LazyModule`].
    ///
    /// The wrapper starts `Initialized` only if the module reports itself
    /// initialized *and* exposes its input shape; otherwise it starts
    /// `Uninitialized`.
    pub fn from_lazy_module(module: M) -> Self
    where
        M: LazyModule,
    {
        let state = if module.is_initialized() {
            module
                .input_shape()
                .map_or(LazyState::Uninitialized, |input_shape| {
                    LazyState::Initialized { input_shape }
                })
        } else {
            LazyState::Uninitialized
        };
        Self {
            module: Some(module),
            state: Arc::new(Mutex::new(state)),
            factory: None,
            training: true,
            device: DeviceType::Cpu,
        }
    }

    /// Initializes the wrapper for inputs of shape `input_shape`.
    ///
    /// `Module::forward` only has shared access and therefore cannot
    /// initialize lazily; call this first. Once initialization has succeeded,
    /// further calls are no-ops.
    ///
    /// # Errors
    ///
    /// Returns an error if the factory fails, if moving the new module to the
    /// wrapper's device fails, if neither a module nor a factory is available,
    /// or if initialization is already in progress. After any failure the
    /// wrapper is `Uninitialized` again, so the call can be retried.
    pub fn initialize(&mut self, input_shape: &Shape) -> Result<()> {
        self.ensure_initialized(input_shape)
    }

    /// Runs initialization once, tracking progress in `self.state`.
    fn ensure_initialized(&mut self, input_shape: &Shape) -> Result<()> {
        {
            let mut state = self.state.lock();
            match &*state {
                LazyState::Initialized { .. } => return Ok(()),
                LazyState::Initializing => {
                    return Err(TorshError::RuntimeError(
                        "Circular dependency in lazy module initialization".to_string(),
                    ));
                }
                LazyState::Uninitialized => {}
            }
            *state = LazyState::Initializing;
        } // Release the lock before running the user-supplied factory.

        // Every failure inside `build_module` must reach the rollback below;
        // an early `?` here would leave the state stuck at `Initializing`.
        let result = self.build_module(input_shape);

        *self.state.lock() = match &result {
            Ok(()) => LazyState::Initialized {
                input_shape: input_shape.clone(),
            },
            Err(_) => LazyState::Uninitialized,
        };
        result
    }

    /// Creates the inner module via the factory, or accepts the module already
    /// held when there is no factory.
    ///
    /// NOTE(review): when wrapping an existing module (no factory), that
    /// module's own initializer is not invoked here because `M` is not bounded
    /// by `LazyModule` in this impl — confirm this is intended.
    fn build_module(&mut self, input_shape: &Shape) -> Result<()> {
        if let Some(factory) = self.factory.as_ref() {
            let mut module = factory(input_shape)?;
            if self.training {
                module.train();
            } else {
                module.eval();
            }
            module.to_device(self.device)?;
            self.module = Some(module);
            Ok(())
        } else if self.module.is_some() {
            Ok(())
        } else {
            Err(TorshError::RuntimeError(
                "No module or factory available for initialization".to_string(),
            ))
        }
    }

    /// The wrapped module, if one exists.
    pub fn module(&self) -> Option<&M> {
        self.module.as_ref()
    }

    /// Mutable access to the wrapped module, if one exists.
    pub fn module_mut(&mut self) -> Option<&mut M> {
        self.module.as_mut()
    }

    /// Returns `true` once the wrapper is in the `Initialized` state.
    pub fn is_initialized(&self) -> bool {
        matches!(&*self.state.lock(), LazyState::Initialized { .. })
    }

    /// The input shape recorded at initialization, or `None` before then.
    pub fn input_shape(&self) -> Option<Shape> {
        if let LazyState::Initialized { input_shape } = &*self.state.lock() {
            Some(input_shape.clone())
        } else {
            None
        }
    }
}
impl<M: Module + Send + Sync> Module for LazyWrapper<M> {
    /// Runs the wrapped module on `input`.
    ///
    /// `forward` only has `&self`, so it cannot perform lazy initialization;
    /// the wrapper must already be initialized.
    ///
    /// # Errors
    ///
    /// Returns a `RuntimeError` if the wrapper is not initialized or holds no
    /// module; otherwise returns the inner module's result.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        if !self.is_initialized() {
            return Err(TorshError::RuntimeError(
                "LazyWrapper requires mutable access for initialization during forward pass. Consider initializing explicitly.".to_string()
            ));
        }
        match self.module.as_ref() {
            Some(module) => module.forward(input),
            None => Err(TorshError::RuntimeError(
                "Module not initialized".to_string(),
            )),
        }
    }

    /// Parameters of the inner module; empty while no module exists.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.module
            .as_ref()
            .map(|module| module.parameters())
            .unwrap_or_default()
    }

    /// Named parameters of the inner module; empty while no module exists.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.module
            .as_ref()
            .map(|module| module.named_parameters())
            .unwrap_or_default()
    }

    fn training(&self) -> bool {
        self.training
    }

    /// Records training mode and forwards it to the inner module, if any.
    fn train(&mut self) {
        self.training = true;
        if let Some(module) = self.module.as_mut() {
            module.train();
        }
    }

    /// Records evaluation mode and forwards it to the inner module, if any.
    fn eval(&mut self) {
        self.training = false;
        if let Some(module) = self.module.as_mut() {
            module.eval();
        }
    }

    /// Records the training flag and forwards it to the inner module, if any.
    fn set_training(&mut self, training: bool) {
        self.training = training;
        if let Some(module) = self.module.as_mut() {
            module.set_training(training);
        }
    }

    /// Records `device` (applied to a module built later) and moves the
    /// current inner module, if any.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.device = device;
        self.module
            .as_mut()
            .map_or(Ok(()), |module| module.to_device(device))
    }

    /// The inner module's name, or `"LazyWrapper"` while no module exists.
    fn name(&self) -> Option<&str> {
        match self.module.as_ref() {
            Some(module) => module.name(),
            None => Some("LazyWrapper"),
        }
    }

    /// The inner module as the only child, or no children while none exists.
    fn children(&self) -> Vec<&dyn Module> {
        match self.module.as_ref() {
            Some(module) => vec![module as &dyn Module],
            None => Vec::new(),
        }
    }

    fn zero_grad(&mut self) {
        if let Some(module) = self.module.as_mut() {
            module.zero_grad();
        }
    }

    /// Describes the initialization state followed by the inner module's repr.
    fn extra_repr(&self) -> String {
        let state_str = match &*self.state.lock() {
            LazyState::Uninitialized => "uninitialized".to_string(),
            LazyState::Initializing => "initializing...".to_string(),
            LazyState::Initialized { input_shape } => {
                format!("initialized(input_shape={:?})", input_shape.dims())
            }
        };
        match self.module.as_ref() {
            Some(module) => format!("LazyWrapper({}): {}", state_str, module.extra_repr()),
            None => format!("LazyWrapper({})", state_str),
        }
    }
}
/// A fully connected layer whose input feature count is taken from the last
/// dimension of the shape it is initialized with.
#[derive(Debug)]
pub struct LazyLinear {
    /// Number of output features.
    out_features: usize,
    /// Whether a bias vector is allocated on initialization.
    bias: bool,
    /// Weight of shape `[out_features, in_features]`; `None` until initialized.
    weight: Option<Parameter>,
    /// Bias of shape `[out_features]`; `None` until initialized, or always
    /// when `bias` is false.
    bias_param: Option<Parameter>,
    /// Lazy initialization state.
    state: Arc<Mutex<LazyState>>,
    /// Training-mode flag reported by `Module::training`.
    training: bool,
    /// Device passed to tensor creation and updated by `to_device`.
    device: DeviceType,
}
impl LazyLinear {
    /// Creates an uninitialized linear layer with `out_features` outputs.
    ///
    /// The input feature count stays unknown until the layer is initialized
    /// with an input shape; `bias` selects whether a bias vector is allocated
    /// at that point. The layer starts in training mode on the CPU.
    pub fn new(out_features: usize, bias: bool) -> Self {
        let state = Arc::new(Mutex::new(LazyState::Uninitialized));
        Self {
            state,
            out_features,
            bias,
            training: true,
            device: DeviceType::Cpu,
            weight: None,
            bias_param: None,
        }
    }

    /// Shorthand for [`LazyLinear::new`] with a bias term enabled.
    pub fn with_features(out_features: usize) -> Self {
        let with_bias = true;
        Self::new(out_features, with_bias)
    }
}
impl LazyModule for LazyLinear {
fn is_initialized(&self) -> bool {
matches!(&*self.state.lock(), LazyState::Initialized { .. })
}
fn initialize(&mut self, input_shape: &Shape) -> Result<()> {
let mut state = self.state.lock();
if matches!(&*state, LazyState::Initialized { .. }) {
return Ok(());
}
if matches!(&*state, LazyState::Initializing) {
return Err(TorshError::RuntimeError(
"Circular dependency in LazyLinear initialization".to_string(),
));
}
*state = LazyState::Initializing;
drop(state);
let dims = input_shape.dims();
if dims.is_empty() {
return Err(TorshError::InvalidShape(
"Input tensor must have at least 1 dimension for LazyLinear".to_string(),
));
}
let in_features = dims[dims.len() - 1];
let bound = (6.0 / (in_features + self.out_features) as f32).sqrt();
let mut rng = Random::seed(0);
let weight_data: Vec<f32> = (0..self.out_features * in_features)
.map(|_| rng.random::<f32>() * 2.0 * bound - bound)
.collect();
let weight_tensor = Tensor::from_data(
weight_data,
vec![self.out_features, in_features],
self.device,
)
.expect("weight tensor creation should succeed");
self.weight = Some(Parameter::new(weight_tensor));
if self.bias {
let bias_data = vec![0.0f32; self.out_features];
let bias_tensor = Tensor::from_data(bias_data, vec![self.out_features], self.device)
.expect("tensor creation should succeed");
self.bias_param = Some(Parameter::new(bias_tensor));
}
let mut state = self.state.lock();
*state = LazyState::Initialized {
input_shape: input_shape.clone(),
};
Ok(())
}
fn input_shape(&self) -> Option<Shape> {
if let LazyState::Initialized { input_shape } = &*self.state.lock() {
Some(input_shape.clone())
} else {
None
}
}
fn output_shape(&self, input_shape: &Shape) -> Option<Shape> {
let dims = input_shape.dims();
if dims.is_empty() {
return None;
}
let mut output_dims = dims.to_vec();
let last_idx = output_dims.len() - 1;
output_dims[last_idx] = self.out_features;
Some(Shape::new(output_dims))
}
}
impl Module for LazyLinear {
    /// Computes `input × weightᵀ`, adding the bias when one exists.
    ///
    /// # Errors
    ///
    /// Returns a `RuntimeError` if the layer has not been initialized.
    /// Propagates any error from the transpose, matmul or bias addition.
    ///
    /// # Panics
    ///
    /// Panics if the state says initialized but no weight is stored.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        if !self.is_initialized() {
            return Err(TorshError::RuntimeError(
                "LazyLinear not initialized. Call initialize() or use LazyWrapper.".to_string(),
            ));
        }

        let weight_param = self
            .weight
            .as_ref()
            .expect("weight should be initialized before forward pass");
        let weight_snapshot = weight_param.tensor().read().clone();
        let projected = input.matmul(&weight_snapshot.transpose(0, 1)?)?;

        let output = match &self.bias_param {
            Some(bias_param) => {
                let bias_snapshot = bias_param.tensor().read().clone();
                projected.add_op(&bias_snapshot)?
            }
            None => projected,
        };
        Ok(output)
    }

    /// `"weight"` and `"bias"` entries for whichever parameters exist.
    fn parameters(&self) -> HashMap<String, Parameter> {
        [("weight", self.weight.as_ref()), ("bias", self.bias_param.as_ref())]
            .iter()
            .filter_map(|&(key, param)| param.map(|p| (key.to_string(), p.clone())))
            .collect()
    }

    /// Same map as [`Module::parameters`]; the layer has no nested modules.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.parameters()
    }

    fn training(&self) -> bool {
        self.training
    }

    fn train(&mut self) {
        self.set_training(true);
    }

    fn eval(&mut self) {
        self.set_training(false);
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
    }

    /// Records `device` and moves the existing weight, then bias, onto it.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.device = device;
        for param in [self.weight.as_ref(), self.bias_param.as_ref()]
            .iter()
            .flatten()
        {
            let relocated = param.tensor().read().clone().to(device)?;
            *param.tensor().write() = relocated;
        }
        Ok(())
    }

    fn name(&self) -> Option<&str> {
        Some("LazyLinear")
    }

    /// Describes the output size and bias flag. Prefixed by `in_features`
    /// once initialized, or by `uninitialized` before then.
    fn extra_repr(&self) -> String {
        let prefix = if !self.is_initialized() {
            "uninitialized, ".to_string()
        } else {
            self.input_shape()
                .map(|shape| {
                    let dims = shape.dims();
                    format!("in_features={}, ", dims[dims.len() - 1])
                })
                .unwrap_or_default()
        };
        format!(
            "{}out_features={}, bias={}",
            prefix, self.out_features, self.bias
        )
    }
}
/// Builds a [`LazyLinear`] with `out_features` outputs and a bias term.
pub fn lazy_linear(out_features: usize) -> LazyLinear {
    LazyLinear::with_features(out_features)
}

/// Builds a [`LazyLinear`] with `out_features` outputs and no bias term.
pub fn lazy_linear_no_bias(out_features: usize) -> LazyLinear {
    let bias = false;
    LazyLinear::new(out_features, bias)
}
/// Builds a `LazyWrapper` from a factory closure.
///
/// `lazy_module!(f)` expands to `$crate::lazy::LazyWrapper::new(f)`.
#[macro_export]
macro_rules! lazy_module {
    ($factory:expr) => {
        $crate::lazy::LazyWrapper::new($factory)
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::Linear;
    use torsh_tensor::Tensor;

    /// Initialization infers `in_features` and allocates correctly shaped
    /// parameters.
    #[test]
    fn test_lazy_linear_initialization() {
        let mut layer = LazyLinear::new(10, true);
        assert!(!layer.is_initialized());
        assert!(layer.input_shape().is_none());

        let shape = Shape::new(vec![32, 20]);
        layer.initialize(&shape).unwrap();
        assert!(layer.is_initialized());
        assert_eq!(layer.input_shape().unwrap().dims(), &[32, 20]);

        let params = layer.parameters();
        for key in ["weight", "bias"] {
            assert!(params.contains_key(key));
        }

        let weight_lock = params.get("weight").unwrap().tensor();
        let weight = weight_lock.read();
        assert_eq!(weight.shape().dims(), &[10, 20]);

        let bias_lock = params.get("bias").unwrap().tensor();
        let bias = bias_lock.read();
        assert_eq!(bias.shape().dims(), &[10]);
    }

    /// A forward pass after initialization maps `[2, 10]` to `[2, 5]`.
    #[test]
    fn test_lazy_linear_forward() {
        let input = Tensor::ones(&[2, 10], DeviceType::Cpu).unwrap();
        let mut layer = LazyLinear::new(5, true);
        layer.initialize(&input.shape()).unwrap();

        let output = layer.forward(&input).unwrap();
        assert_eq!(output.shape().dims(), &[2, 5]);
    }

    /// A factory-backed wrapper starts uninitialized.
    #[test]
    fn test_lazy_wrapper_with_factory() {
        let wrapper = LazyWrapper::new(|input_shape: &Shape| {
            let dims = input_shape.dims();
            let in_features = dims[dims.len() - 1];
            Ok(Linear::new(in_features, 8, true))
        });
        assert!(!wrapper.is_initialized());
    }

    /// Only the last dimension changes in the predicted output shape.
    #[test]
    fn test_output_shape_prediction() {
        let layer = LazyLinear::new(7, false);
        let predicted = layer
            .output_shape(&Shape::new(vec![16, 32, 15]))
            .unwrap();
        assert_eq!(predicted.dims(), &[16, 32, 7]);
    }

    /// `extra_repr` switches from "uninitialized" to reporting `in_features`.
    #[test]
    fn test_extra_repr() {
        let mut layer = LazyLinear::new(12, true);

        let before = layer.extra_repr();
        for needle in ["uninitialized", "out_features=12", "bias=true"] {
            assert!(before.contains(needle));
        }

        layer.initialize(&Shape::new(vec![8, 24])).unwrap();

        let after = layer.extra_repr();
        assert!(after.contains("in_features=24"));
        assert!(after.contains("out_features=12"));
        assert!(!after.contains("uninitialized"));
    }
}