1use crate::graph::Context;
79use crate::{uuid::Uuid, Float, FxHashMap, Graph, NdArray, NdArrayView, NdArrayViewMut, Tensor};
80use serde::{Deserialize, Serialize};
81use serde_json;
82use smallvec::alloc::fmt::{Display, Formatter};
83use std::cell::RefCell;
84use std::collections::HashMap;
85use std::sync::{Arc, RwLock};
86
87use std::error::Error;
88use std::fs::File;
89use std::ops::Deref;
90use std::path::Path;
91
92#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug, Serialize, Deserialize)]
93pub struct VariableID(pub(crate) usize);
97
98impl From<usize> for VariableID {
99 fn from(a: usize) -> VariableID {
100 VariableID(a)
101 }
102}
103
104impl From<VariableID> for usize {
105 fn from(a: VariableID) -> usize {
106 a.0
107 }
108}
109
110impl std::fmt::Display for VariableID {
111 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
112 write!(f, "{}", self.0)
113 }
114}
115
116const DEFAULT_NAMESPACE_ID: &str = "";
117
118pub type Variable<F> = RefCell<NdArray<F>>;
119
120pub trait GetVariableTensor<'g, F: Float, Arg> {
122 fn variable(&'g self, id: Arg) -> Tensor<'g, F>;
123}
124
125impl<'g, 'e: 'g, F: Float> GetVariableTensor<'g, F, &'static str> for Context<'e, F> {
126 fn variable(&'g self, name: &str) -> Tensor<'g, F> {
128 self.graph
129 .variable_by_name(name, &self.var_env_ref.default_namespace())
130 }
131}
132
133impl<'g, 'e: 'g, F: Float> GetVariableTensor<'g, F, VariableID> for Context<'e, F> {
134 fn variable(&'g self, id: VariableID) -> Tensor<'g, F> {
136 self.graph.variable_by_id(id)
137 }
138}
139
140impl<'g, 'e: 'g, F: Float> GetVariableTensor<'g, F, (&'static str, &'static str)>
141 for Context<'e, F>
142{
143 fn variable(&'g self, id: (&'static str, &'static str)) -> Tensor<'g, F> {
145 self.graph
146 .variable_by_name(id.1, &self.var_env_ref.namespace(id.0))
147 }
148}
149
/// Owns every variable array plus the name index used to find them.
///
/// Arrays are addressed by [`VariableID`], which is an index into `array_list`; names are
/// resolved through `name_to_id`.
#[derive(Clone)]
pub struct VariableEnvironment<F> {
    /// Variable arrays, indexed by `VariableID.0`.
    pub(crate) array_list: Vec<Variable<F>>,
    /// Maps a (namespace, name) key to the id of its array in `array_list`.
    pub(crate) name_to_id: FxHashMap<FullName, VariableID>,
}

/// Key identifying a variable by namespace id and variable name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct FullName {
    pub(crate) namespace_id: String,
    pub(crate) variable_name: String,
}

/// Builder for registering a variable in a specific (mutable) namespace.
///
/// `set` auto-generates a name; call `name` first to choose one.
pub struct VariableSlot<'ns, 'env, F: Float> {
    namespace: &'ns mut VariableNamespaceMut<'env, F>,
}

/// Builder for registering a variable under an explicit name in a specific namespace.
pub struct NamedVariableSlot<'ns, 'env, F: Float, S: Into<String>> {
    namespace: &'ns mut VariableNamespaceMut<'env, F>,
    name: S,
}

/// Builder for registering a variable in the default namespace.
pub struct DefaultVariableSlot<'env, F: Float> {
    env: &'env mut VariableEnvironment<F>,
}

/// Builder for registering a variable under an explicit name in the default namespace.
pub struct NamedDefaultVariableSlot<'env, F: Float, S: Into<String>> {
    env: &'env mut VariableEnvironment<F>,
    name: S,
}

/// Read-only view of one namespace of a [`VariableEnvironment`].
pub struct VariableNamespace<'env, F: Float> {
    pub(crate) env: &'env VariableEnvironment<F>,
    pub(crate) namespace_id: &'static str,
}

/// Mutable view of one namespace of a [`VariableEnvironment`]; used to register variables.
pub struct VariableNamespaceMut<'env, F: Float> {
    pub(crate) env: &'env mut VariableEnvironment<F>,
    pub(crate) namespace_id: &'static str,
}
218
219impl FullName {
220 fn new(_namespace_id: &'static str, variablename: String) -> Self {
221 FullName {
222 namespace_id: _namespace_id.to_string(),
223 variable_name: variablename,
224 }
225 }
226}
227
228impl Display for FullName {
229 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230 let ns = self.namespace_id.deref();
231 let name = self.variable_name.deref();
232 write!(f, "{ns}\u{00001}{name}")
233 }
234}
235
236pub trait NamespaceTrait<F: Float> {
237 fn name(&self) -> &'static str;
239
240 fn env(&self) -> &VariableEnvironment<F>;
242
243 #[inline]
245 fn get_array_by_id(&self, vid: VariableID) -> &RefCell<NdArray<F>> {
246 &self.env().array_list[vid.0]
247 }
248
249 #[inline]
253 fn get_array_by_name<S: AsRef<str>>(&self, name: S) -> Option<&RefCell<NdArray<F>>> {
254 let name = &FullName::new(self.name(), name.as_ref().to_string());
255 self.env()
256 .name_to_id
257 .get(name)
258 .map(|vid| &self.env().array_list[vid.0])
259 }
260
261 fn current_var_ids(&self) -> Vec<VariableID> {
263 self.env()
264 .name_to_id
265 .iter()
266 .filter_map(|(v_name, &vid)| {
267 if v_name.namespace_id == self.name() {
268 Some(vid)
269 } else {
270 None
271 }
272 })
273 .collect()
274 }
275
276 fn current_var_names(&self) -> Vec<&str> {
278 self.env()
279 .name_to_id
280 .iter()
281 .filter_map(|(v_name, _v_id)| {
282 if v_name.namespace_id == self.name() {
283 Some(v_name.variable_name.deref())
284 } else {
285 None
286 }
287 })
288 .collect()
289 }
290}
291
292#[allow(clippy::needless_lifetimes)]
293impl<'ns, 'env, F: Float, S: Into<String>> NamedVariableSlot<'ns, 'env, F, S> {
294 pub fn set<D: scirs2_core::ndarray::Dimension>(
296 self,
297 v: scirs2_core::ndarray::Array<F, D>,
298 ) -> VariableID {
299 register_variable(
300 v,
301 self.namespace.namespace_id,
302 self.name.into(),
303 self.namespace.env,
304 )
305 }
306}
307
308impl<'env, F: Float> DefaultVariableSlot<'env, F> {
309 pub fn set<D: scirs2_core::ndarray::Dimension>(
311 self,
312 v: scirs2_core::ndarray::Array<F, D>,
313 ) -> VariableID {
314 register_variable(
315 v,
316 DEFAULT_NAMESPACE_ID,
317 Uuid::new_v4().to_string(),
318 self.env,
319 )
320 }
321
322 pub fn name<S: Into<String>>(self, name: S) -> NamedDefaultVariableSlot<'env, F, S> {
324 NamedDefaultVariableSlot {
325 env: self.env,
326 name,
327 }
328 }
329}
330
331#[allow(clippy::needless_lifetimes)]
332impl<'env, F: Float, S: Into<String>> NamedDefaultVariableSlot<'env, F, S> {
333 pub fn set<D: scirs2_core::ndarray::Dimension>(
335 self,
336 v: scirs2_core::ndarray::Array<F, D>,
337 ) -> VariableID {
338 register_variable(v, DEFAULT_NAMESPACE_ID, self.name.into(), self.env)
339 }
340}
341
342impl<'ns, 'env, F: Float> VariableSlot<'ns, 'env, F> {
343 pub fn set<D: scirs2_core::ndarray::Dimension>(
345 self,
346 v: scirs2_core::ndarray::Array<F, D>,
347 ) -> VariableID {
348 register_variable(
349 v,
350 self.namespace.namespace_id,
351 Uuid::new_v4().to_string(),
352 self.namespace.env,
353 )
354 }
355
356 pub fn name<S: Into<String>>(self, name: S) -> NamedVariableSlot<'ns, 'env, F, S> {
358 NamedVariableSlot {
359 namespace: self.namespace,
360 name,
361 }
362 }
363}
364
365#[allow(dead_code)]
366fn register_variable<F: Float, D: scirs2_core::ndarray::Dimension, S: Into<String>>(
367 v: scirs2_core::ndarray::Array<F, D>,
368 namespace_id: &'static str,
369 variable_name: S,
370 env: &mut VariableEnvironment<F>,
371) -> VariableID {
372 let vid = FullName::new(namespace_id, variable_name.into());
373 let next_id = env.array_list.len().into();
374 env.name_to_id.insert(vid, next_id);
375 env.array_list.push(RefCell::new(v.into_dyn()));
376 next_id
377}
378
379#[allow(clippy::needless_lifetimes)]
380impl<'env, F: Float> NamespaceTrait<F> for VariableNamespace<'env, F> {
381 #[inline]
382 fn name(&self) -> &'static str {
383 self.namespace_id
384 }
385 #[inline]
386 fn env(&self) -> &VariableEnvironment<F> {
387 self.env
388 }
389}
390
391impl<F: Float> NamespaceTrait<F> for VariableNamespaceMut<'_, F> {
392 #[inline]
393 fn name(&self) -> &'static str {
394 self.namespace_id
395 }
396 #[inline]
397 fn env(&self) -> &VariableEnvironment<F> {
398 self.env
399 }
400}
401
402impl<F: Float> VariableNamespace<'_, F> {
403 #[allow(unused)]
405 pub fn iter(&self) -> impl Iterator<Item = (&str, &RefCell<NdArray<F>>)> {
406 iter(self)
407 }
408}
409
410impl<F: Float> VariableNamespaceMut<'_, F> {
411 #[allow(unused)]
413 pub fn iter(&self) -> impl Iterator<Item = (&str, &RefCell<NdArray<F>>)> {
414 iter(self)
415 }
416}
417
418#[allow(dead_code)]
419fn iter<F: Float>(
420 ns: &impl NamespaceTrait<F>,
421) -> impl Iterator<Item = (&str, &RefCell<NdArray<F>>)> {
422 ns.env().name_to_id.iter().filter_map(move |ent| {
423 if ent.0.namespace_id == ns.name() {
425 Some((
426 ent.0.variable_name.deref(),
427 ns.get_array_by_name(ent.0.variable_name.deref())
428 .expect("Operation failed"),
429 ))
430 } else {
431 None
432 }
433 })
434}
435impl<'ns, 'env, F: Float> VariableNamespaceMut<'env, F> {
436 pub fn slot(&'ns mut self) -> VariableSlot<'ns, 'env, F> {
438 VariableSlot { namespace: self }
439 }
440}
441
#[test]
#[allow(dead_code)]
fn test_env_iter() {
    use crate::ndarray_ext;

    let mut env = VariableEnvironment::<f32>::new();
    let v1 = env.slot().set(ndarray_ext::zeros(&[3, 2]));
    let v2 = env.slot().set(ndarray_ext::zeros(&[2, 3]));

    // Collect first so an empty or short iterator fails instead of passing vacuously.
    let entries: Vec<_> = env.iter().collect();
    assert_eq!(entries.len(), 2, "env.iter() must yield every registered variable");

    let (vid, arr) = entries[0];
    assert_eq!(vid, v1);
    assert_eq!(arr.borrow().shape(), &[3, 2]);

    let (vid, arr) = entries[1];
    assert_eq!(vid, v2);
    assert_eq!(arr.borrow().shape(), &[2, 3]);
}
461
#[test]
#[allow(dead_code)]
fn test_namespace_iter() {
    use crate::ndarray_ext;

    /// Walks one namespace iterator, checking each array's shape.
    ///
    /// Returns whether `v1` and `v2` were seen; panics on any other name.
    fn scan<'a>(
        entries: impl Iterator<Item = (&'a str, &'a RefCell<NdArray<f32>>)>,
    ) -> (bool, bool) {
        let (mut saw_v1, mut saw_v2) = (false, false);
        for (name, arr) in entries {
            let expected: &[usize] = match name {
                "v1" => {
                    saw_v1 = true;
                    &[3, 2]
                }
                "v2" => {
                    saw_v2 = true;
                    &[2, 3]
                }
                _ => panic!("Unexpected variable name: {}", name),
            };
            assert_eq!(arr.borrow().shape(), expected);
        }
        (saw_v1, saw_v2)
    }

    let mut env = VariableEnvironment::<f32>::new();
    env.slot().name("v1").set(ndarray_ext::zeros(&[3, 2]));
    env.slot().name("v2").set(ndarray_ext::zeros(&[2, 3]));

    let (found_v1, found_v2) = scan(env.default_namespace().iter());
    assert!(found_v1, "Variable v1 not found");
    assert!(found_v2, "Variable v2 not found");

    let (found_v1_mut, found_v2_mut) = scan(env.default_namespace_mut().iter());
    assert!(found_v1_mut, "Variable v1 not found in mutable iterator");
    assert!(found_v2_mut, "Variable v2 not found in mutable iterator");
}
507
/// On-disk (serialization) form of a [`VariableEnvironment`].
///
/// Names are flattened to their `FullName` `Display` string, which joins namespace and name
/// with U+0001. The arrays are borrowed rather than copied.
#[derive(Serialize)]
struct SerializableVariableEnvironment<'a, F> {
    array_list: &'a Vec<Variable<F>>,
    name_to_id: FxHashMap<String, VariableID>,
}

/// Owned counterpart of `SerializableVariableEnvironment`, read back from disk before the
/// flattened name keys are split into `FullName`s again.
#[derive(Deserialize)]
struct DeserializedVariableEnvironment<F> {
    array_list: Vec<Variable<F>>,
    name_to_id: FxHashMap<String, VariableID>,
}
519
520impl VariableEnvironment<f32> {
522 pub fn load<P: AsRef<Path>>(path: P) -> Result<VariableEnvironment<f32>, Box<dyn Error>> {
526 let raw: DeserializedVariableEnvironment<f32> = Self::deserialize(path)?;
527 Self::load_internal(raw)
528 }
529
530 pub fn initialize<P: AsRef<Path>>(&mut self, path: P) -> Result<(), Box<dyn Error>> {
532 let raw: DeserializedVariableEnvironment<f32> = Self::deserialize(path)?;
533 let VariableEnvironment {
534 array_list,
535 name_to_id,
536 } = Self::load_internal(raw)?;
537 self.array_list = array_list;
538 self.name_to_id = name_to_id;
539 Ok(())
540 }
541}
542
543impl VariableEnvironment<f64> {
545 pub fn load<P: AsRef<Path>>(path: P) -> Result<VariableEnvironment<f64>, Box<dyn Error>> {
549 let raw: DeserializedVariableEnvironment<f64> = Self::deserialize(path)?;
550 Self::load_internal(raw)
551 }
552
553 pub fn initialize<P: AsRef<Path>>(&mut self, path: P) -> Result<(), Box<dyn Error>> {
555 let raw: DeserializedVariableEnvironment<f64> = Self::deserialize(path)?;
556 let VariableEnvironment {
557 array_list,
558 name_to_id,
559 } = Self::load_internal(raw)?;
560 self.array_list = array_list;
561 self.name_to_id = name_to_id;
562 Ok(())
563 }
564}
565
566impl<F: Float> VariableEnvironment<F> {
567 pub fn new() -> VariableEnvironment<F> {
569 Self {
570 name_to_id: FxHashMap::default(),
571 array_list: Vec::new(),
572 }
573 }
574}
575
576impl<F: Float> Default for VariableEnvironment<F> {
577 fn default() -> Self {
578 Self::new()
579 }
580}
581
582impl<'env, F: Float> VariableEnvironment<F> {
583 #[allow(unused)]
585 pub fn iter(&self) -> impl Iterator<Item = (VariableID, &RefCell<NdArray<F>>)> {
586 self.array_list
587 .iter()
588 .enumerate()
589 .map(|(i, v)| (VariableID::from(i), v))
590 }
591
592 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Box<dyn Error>> {
596 let f = File::create(path.as_ref())?;
597 serde_json::to_writer(f, &self.prepare_for_serde())?;
598 Ok(())
599 }
600
601 fn deserialize<T, P: AsRef<Path>>(path: P) -> Result<T, Box<dyn Error>>
602 where
603 T: for<'de> Deserialize<'de>,
604 {
605 let f = File::open(path.as_ref())?;
606 let ret = serde_json::from_reader(f)?;
607 Ok(ret)
608 }
609
610 fn load_internal<T>(
611 env: DeserializedVariableEnvironment<T>,
612 ) -> Result<VariableEnvironment<T>, Box<dyn Error>> {
613 let name_to_id: FxHashMap<FullName, VariableID> = env
614 .name_to_id
615 .iter()
616 .map(|(fullname, &vid)| {
617 let mut split = fullname.split("\u{0001}");
618 let namespace_id = split.next().expect("Operation failed").to_owned();
619 let var_name = split.next().expect("Operation failed").to_owned();
620 let fullname = FullName {
621 namespace_id,
622 variable_name: var_name,
623 };
624 (fullname, vid)
625 })
626 .collect();
627
628 Ok(VariableEnvironment {
629 array_list: env.array_list,
630 name_to_id,
631 })
632 }
633
634 fn prepare_for_serde(&self) -> SerializableVariableEnvironment<F> {
635 let name_to_id: FxHashMap<String, VariableID> = self
636 .name_to_id
637 .iter()
638 .map(|(fullname, vid)| (fullname.to_string(), *vid))
639 .collect();
640 SerializableVariableEnvironment {
641 array_list: &self.array_list,
642 name_to_id,
643 }
644 }
645
646 pub fn slot(&'env mut self) -> DefaultVariableSlot<'env, F> {
648 DefaultVariableSlot { env: self }
649 }
650
651 pub fn set<D: scirs2_core::ndarray::Dimension>(
653 &'env mut self,
654 v: scirs2_core::ndarray::Array<F, D>,
655 ) -> VariableID {
656 register_variable(v, DEFAULT_NAMESPACE_ID, Uuid::new_v4().to_string(), self)
657 }
658
659 pub fn name<S: Into<String>>(&'env mut self, name: S) -> NamedDefaultVariableSlot<'env, F, S> {
661 NamedDefaultVariableSlot { env: self, name }
662 }
663
664 #[inline]
669 pub fn namespace(&'env self, namespaceid: &'static str) -> VariableNamespace<'env, F> {
670 VariableNamespace {
671 namespace_id: namespaceid,
672 env: self,
673 }
674 }
675
676 #[inline]
681 pub fn namespace_mut(
682 &'env mut self,
683 namespace_id: &'static str,
684 ) -> VariableNamespaceMut<'env, F> {
685 VariableNamespaceMut {
686 namespace_id,
687 env: self,
688 }
689 }
690
691 #[inline]
696 pub fn default_namespace(&'env self) -> VariableNamespace<'env, F> {
697 self.namespace(DEFAULT_NAMESPACE_ID)
698 }
699
700 #[inline]
704 pub fn default_namespace_mut(&'env mut self) -> VariableNamespaceMut<'env, F> {
705 self.namespace_mut(DEFAULT_NAMESPACE_ID)
706 }
707
708 #[inline]
712 pub fn get_array_by_id(&self, vid: VariableID) -> Option<&RefCell<NdArray<F>>> {
713 self.array_list.get(vid.0)
714 }
715
716 pub fn run<FN, R>(&'env self, f: FN) -> R
720 where
721 FN: FnOnce(&mut Context<'env, F>) -> R,
722 {
723 let g = Graph {
724 node_set: RefCell::new(Vec::with_capacity(256)),
725 variable2node: RefCell::new(HashMap::new()),
726 };
727 let mut c = Context {
728 var_env_ref: self,
729 graph: g,
730 };
731 f(&mut c)
732 }
733
734 #[allow(dead_code)]
735 pub(crate) fn as_view(&self, vid: VariableID) -> NdArrayView<F> {
736 unsafe {
737 self.array_list[vid.0]
738 .borrow()
739 .raw_view()
740 .clone()
741 .deref_into_view()
742 }
743 }
744
745 #[allow(dead_code)]
746 pub(crate) fn as_view_mut(&self, vid: VariableID) -> NdArrayViewMut<F> {
747 unsafe {
748 self.array_list[vid.0]
749 .borrow_mut()
750 .raw_view_mut()
751 .clone()
752 .deref_into_view_mut()
753 }
754 }
755}
756
757impl<'g, F: Float> Graph<F> {
758 pub fn variable_by_name<S: AsRef<str>>(
760 &self,
761 name: S,
762 namespace: &impl NamespaceTrait<F>,
763 ) -> Tensor<F> {
764 let full_name = &FullName::new(namespace.name(), name.as_ref().to_string());
765 if let Some(&vid) = namespace.env().name_to_id.get(full_name) {
766 self.variable_by_id(vid)
768 } else {
769 let ns = namespace.name();
770 if ns.is_empty() {
771 panic!(
772 "variable array not found in default namespace: {}",
773 name.as_ref()
774 )
775 } else {
776 panic!(
777 "variable array `{}` not found in namespace {}",
778 name.as_ref(),
779 ns
780 )
781 }
782 }
783 }
784
785 pub fn var_tensors_by_id<'e: 'g>(
789 &'g self,
790 env: &'e VariableEnvironment<F>,
791 ) -> impl Iterator<Item = (VariableID, Tensor<'g, F>)> {
792 (0..env.array_list.len()).map(move |vid| (vid.into(), self.variable_by_id(vid.into())))
793 }
794
795 pub fn var_tensors_by_name<'ns, 'e: 'g>(
799 &'g self,
800 ns: &'ns VariableNamespace<'e, F>,
801 ) -> impl Iterator<Item = (&'ns str, Tensor<'g, F>)> {
802 ns.env().name_to_id.iter().filter_map(move |ent| {
803 if ent.0.namespace_id == ns.name() {
805 Some((ent.0.variable_name.deref(), self.variable_by_id(*ent.1)))
806 } else {
807 None
808 }
809 })
810 }
811}
812
/// Compile-time smoke check for the common variable-lookup call forms.
///
/// This function is never called. Running it would panic, because `env` is empty and
/// `variable_by_name` panics on unknown names. Its only job is to make sure each
/// `g.variable(..)` key type (name, id, `(namespace, name)`) keeps type-checking.
#[allow(unused)]
#[allow(dead_code)]
fn compile_common_usages() {
    use crate::prelude::*;
    use crate::tensor_ops as T;

    let mut env = VariableEnvironment::<f32>::new();
    env.run(|g| {
        // Namespace-object lookup, then the three `GetVariableTensor` key forms.
        let ns = g.env().default_namespace();

        let _v3_ = g.variable_by_name("a", &ns);
        let v = g.variable("a");
        let v2 = g.variable(VariableID(0));
        let v3 = g.variable(("my_ns", "a"));
        let ones = T::zeros(&[1], g) + v + v2 + v3;
        let _ = ones.eval(g);
    });

    env.run(|g| {
        let ns = g.env().default_namespace();
        let v = g.variable("a");
        let _ = v.eval(g);
    })
}
839
#[test]
#[allow(dead_code)]
fn save_and_load() {
    use crate::ndarray_ext;
    use std::collections::HashMap;
    use std::fs;

    // Use the platform temp dir rather than a hard-coded `/tmp` so the test runs everywhere.
    let dir = std::env::temp_dir()
        .join("rust-autograd")
        .join("test")
        .join("save_and_load");
    fs::create_dir_all(&dir).expect("Operation failed");
    let path = dir.join("model.json");
    let mut rng = ndarray_ext::ArrayRng::<f64>::default();

    let mut env = VariableEnvironment::new();
    env.slot().name("a").set(rng.standard_normal(&[2, 3]));
    env.slot().name("b").set(rng.standard_normal(&[2, 3]));

    env.save(&path).expect("Operation failed");

    let loaded_env = VariableEnvironment::<f64>::load(&path).expect("Operation failed");

    assert_eq!(env.name_to_id, loaded_env.name_to_id);

    // Build the id -> array lookup once instead of once per variable.
    let loaded_env_map: HashMap<_, _> = loaded_env.iter().collect();
    assert_eq!(loaded_env_map.len(), env.array_list.len());

    let epsilon = 1e-6;
    for (vid, array) in env.iter() {
        let loaded_array = loaded_env_map.get(&vid).expect("Operation failed");

        let arr1 = array.borrow();
        let arr2 = loaded_array.borrow();

        assert_eq!(arr1.shape(), arr2.shape());

        // JSON round-trips may perturb the last bits of a float; compare with a tolerance.
        for (a, b) in arr1.iter().zip(arr2.iter()) {
            assert!(
                (a - b).abs() < epsilon,
                "Arrays differ: {} vs {} exceeds epsilon {}",
                a,
                b,
                epsilon
            );
        }
    }
}
892
#[test]
#[allow(dead_code)]
fn save_and_init() {
    use crate::ndarray_ext;
    use std::fs;

    // Use the platform temp dir rather than a hard-coded `/tmp` so the test runs everywhere.
    let dir = std::env::temp_dir()
        .join("rust-autograd")
        .join("test")
        .join("save_and_init");
    fs::create_dir_all(&dir).expect("Operation failed");
    let path = dir.join("model.json");
    let mut rng = ndarray_ext::ArrayRng::<f64>::default();

    let mut env = VariableEnvironment::new();
    let a = env.name("a").set(rng.standard_normal(&[2, 3]));
    let b = env.name("b").set(rng.standard_normal(&[2, 3]));

    // Saving repeatedly from inside a graph run must keep working.
    for _ in 0..10 {
        env.run(|g| {
            let _a_ = g.variable(a);
            let _b_ = g.variable(b);
            g.env().save(&path).expect("Operation failed");
        });
    }

    env.initialize(&path).expect("Operation failed");

    // The reloaded environment must still resolve the ids handed out before saving.
    for vid in [a, b] {
        let arr = env
            .get_array_by_id(vid)
            .expect("variable id lost after initialize");
        assert_eq!(arr.borrow().shape(), &[2, 3]);
    }
}
919
/// Thread-shareable wrapper around a [`VariableEnvironment`].
///
/// Cloning is cheap and shares state: `Clone` only clones the `Arc`s, so every clone refers
/// to the same underlying environment.
#[derive(Clone)]
pub struct SafeVariableEnvironment<F: Float + Send + Sync> {
    /// The wrapped environment, guarded by a reader/writer lock.
    inner: Arc<RwLock<VariableEnvironment<F>>>,
    /// CPU capabilities detected once at construction; only present with the `simd` feature.
    #[cfg(feature = "simd")]
    platform_caps: Arc<scirs2_core::simd_ops::PlatformCapabilities>,
}
954
955impl<F: Float + Send + Sync> SafeVariableEnvironment<F> {
956 pub fn new() -> Self {
958 Self {
959 inner: Arc::new(RwLock::new(VariableEnvironment::new())),
960 #[cfg(feature = "simd")]
961 platform_caps: Arc::new(scirs2_core::simd_ops::PlatformCapabilities::detect()),
962 }
963 }
964
965 pub fn set_variable(
967 &self,
968 array: NdArray<F>,
969 ) -> Result<VariableID, Box<dyn Error + Send + Sync>> {
970 let mut env = self
971 .inner
972 .write()
973 .map_err(|e| format!("Failed to acquire write lock: {}", e))?;
974
975 let var_id = env.set(array);
977 Ok(var_id)
978 }
979
980 pub fn name_variable<S: AsRef<str>>(
982 &self,
983 name: S,
984 array: NdArray<F>,
985 ) -> Result<VariableID, Box<dyn Error + Send + Sync>> {
986 let mut env = self
987 .inner
988 .write()
989 .map_err(|e| format!("Failed to acquire write lock: {}", e))?;
990
991 let var_id = env.name(name.as_ref()).set(array);
992 Ok(var_id)
993 }
994
995 pub fn get_variable(
997 &self,
998 var_id: VariableID,
999 ) -> Result<NdArray<F>, Box<dyn Error + Send + Sync>> {
1000 let env = self
1001 .inner
1002 .read()
1003 .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
1004
1005 if let Some(var) = env.array_list.get(var_id.0) {
1006 Ok(var.borrow().clone())
1007 } else {
1008 Err(format!("Variable ID {:?} not found", var_id).into())
1009 }
1010 }
1011
1012 pub fn backward(&self, output_var: VariableID) -> Result<(), Box<dyn Error + Send + Sync>> {
1017 #[cfg(feature = "simd")]
1021 {
1022 self.simd_backward_pass(output_var)
1024 }
1025 #[cfg(not(feature = "simd"))]
1026 {
1027 self.scalar_backward_pass(output_var)
1028 }
1029 }
1030
1031 #[cfg(feature = "simd")]
1033 fn simd_backward_pass(
1034 &self,
1035 _output_var: VariableID,
1036 ) -> Result<(), Box<dyn Error + Send + Sync>> {
1037 Ok(())
1047 }
1048
1049 fn scalar_backward_pass(
1051 &self,
1052 _output_var: VariableID,
1053 ) -> Result<(), Box<dyn Error + Send + Sync>> {
1054 Ok(())
1056 }
1057
1058 pub fn parallel_backward_pass(
1063 &self,
1064 outputs: &[VariableID],
1065 _inputs: &[VariableID],
1066 ) -> Result<Vec<Option<NdArray<F>>>, Box<dyn Error + Send + Sync>> {
1067 #[cfg(feature = "simd")]
1068 {
1069 if self.platform_caps.num_cores() >= 4 && outputs.len() >= 4 {
1070 return self.parallel_simd_backward_pass(outputs);
1071 }
1072 }
1073
1074 let mut gradients = Vec::with_capacity(outputs.len());
1076 for &output_var in outputs {
1077 self.backward(output_var)?;
1078 gradients.push(None);
1080 }
1081 Ok(gradients)
1082 }
1083
1084 #[cfg(feature = "simd")]
1086 fn parallel_simd_backward_pass(
1087 &self,
1088 _outputs: &[VariableID],
1089 ) -> Result<Vec<Option<NdArray<F>>>, Box<dyn Error + Send + Sync>> {
1090 use scirs2_core::parallel_ops::*;
1091
1092 Ok(Vec::new()) }
1103
1104 pub fn run<R>(
1106 &self,
1107 func: impl FnOnce(&VariableEnvironment<F>) -> R,
1108 ) -> Result<R, Box<dyn Error + Send + Sync>> {
1109 let env = self
1110 .inner
1111 .read()
1112 .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
1113 Ok(func(&*env))
1114 }
1115
1116 pub fn len(&self) -> Result<usize, Box<dyn Error + Send + Sync>> {
1118 let env = self
1119 .inner
1120 .read()
1121 .map_err(|e| format!("Failed to acquire read lock: {}", e))?;
1122 Ok(env.array_list.len())
1123 }
1124
1125 pub fn is_empty(&self) -> Result<bool, Box<dyn Error + Send + Sync>> {
1127 Ok(self.len()? == 0)
1128 }
1129}
1130
// SAFETY(review): `VariableEnvironment` stores its arrays in `RefCell`s, which are not `Sync`.
// That means `Arc<RwLock<VariableEnvironment<F>>>` is not automatically `Send`/`Sync`, and
// these impls override the compiler. They are only sound if every method that touches those
// `RefCell`s (even via a shared `borrow()`) does so while holding the exclusive write side of
// `inner`. A shared read guard lets several threads hit the non-atomic borrow counter at once.
// Whether `PlatformCapabilities` (simd feature) is itself thread-safe is not visible here;
// confirm it. Re-check all of this whenever a method is added or changed.
unsafe impl<F: Float + Send + Sync> Send for SafeVariableEnvironment<F> {}
unsafe impl<F: Float + Send + Sync> Sync for SafeVariableEnvironment<F> {}

/// Same as [`SafeVariableEnvironment::new`].
impl<F: Float + Send + Sync> Default for SafeVariableEnvironment<F> {
    fn default() -> Self {
        Self::new()
    }
}
1140
/// Handle to one variable inside a shared [`SafeVariableEnvironment`].
///
/// Cloning copies the id and flag and shares the environment through its `Arc`.
#[derive(Clone)]
pub struct SafeVariable<F: Float + Send + Sync> {
    /// Id of the variable within `env`.
    pub id: VariableID,
    /// Environment that owns the variable's data.
    pub env: Arc<SafeVariableEnvironment<F>>,
    /// When `false`, `backward` returns immediately without touching `env`.
    pub requires_grad: bool,
}
1154
1155impl<F: Float + Send + Sync> SafeVariable<F> {
1156 pub fn new(
1158 data: NdArray<F>,
1159 env: Arc<SafeVariableEnvironment<F>>,
1160 requires_grad: bool,
1161 ) -> Result<Self, Box<dyn Error + Send + Sync>> {
1162 let id = env.set_variable(data)?;
1163 Ok(Self {
1164 id,
1165 env,
1166 requires_grad,
1167 })
1168 }
1169
1170 pub fn backward(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
1172 if !self.requires_grad {
1173 return Ok(()); }
1175 self.env.backward(self.id)
1176 }
1177
1178 pub fn data(&self) -> Result<NdArray<F>, Box<dyn Error + Send + Sync>> {
1180 self.env.get_variable(self.id)
1181 }
1182
1183 pub fn requires_grad(&self) -> bool {
1185 self.requires_grad
1186 }
1187
1188 pub fn set_requires_grad(&mut self, requires_grad: bool) {
1190 self.requires_grad = requires_grad;
1191 }
1192}
1193
// SAFETY(review): every field of `SafeVariable` (`VariableID`, `Arc<SafeVariableEnvironment<F>>`,
// `bool`) is already `Send + Sync` given the impls on `SafeVariableEnvironment`. These manual
// impls therefore look redundant; confirm before relying on them for anything stronger.
unsafe impl<F: Float + Send + Sync> Send for SafeVariable<F> {}
unsafe impl<F: Float + Send + Sync> Sync for SafeVariable<F> {}

/// Minimal autograd-facing interface for tensor-like handles.
pub trait AutogradTensor<F: Float> {
    /// Triggers the backward pass for this tensor.
    fn backward(&self) -> Result<(), Box<dyn Error + Send + Sync>>;
    /// Returns the stored gradient, if the implementor keeps one.
    fn grad(&self) -> Option<&NdArray<F>>;
    /// Whether gradients are tracked for this tensor.
    fn requires_grad(&self) -> bool;
    /// Enables or disables gradient tracking for this tensor.
    fn set_requires_grad(&mut self, requires_grad: bool);
}
1205
1206impl<F: Float + Send + Sync> AutogradTensor<F> for SafeVariable<F> {
1207 fn backward(&self) -> Result<(), Box<dyn Error + Send + Sync>> {
1208 SafeVariable::backward(self)
1209 }
1210
1211 fn grad(&self) -> Option<&NdArray<F>> {
1212 None
1215 }
1216
1217 fn requires_grad(&self) -> bool {
1218 SafeVariable::requires_grad(self)
1219 }
1220
1221 fn set_requires_grad(&mut self, requires_grad: bool) {
1222 SafeVariable::set_requires_grad(self, requires_grad)
1223 }
1224}