tenrso_exec/
lib.rs

1//! # tenrso-exec
2//!
3//! Unified execution API for TenRSo.
4//!
5//! **Version:** 0.1.0-alpha.2
6//! **Tests:** 244 passing (100%)
7//! **Status:** M4 Complete - Unified execution with optimization
8//!
9//! This crate provides:
10//! - `einsum_ex` - the main public API for tensor contractions
11//! - `TenrsoExecutor` trait for different backends
12//! - CPU executor with automatic representation selection
13//! - Integration with planner and all backend operations
14
15#![deny(warnings)]
16
17pub mod executor;
18pub mod hints;
19pub mod ops;
20
21// Re-exports
22pub use executor::*;
23pub use hints::*;
24
25use scirs2_core::numeric::{Float, FromPrimitive, Num};
26
27/// Execute an einsum contraction with hints
28///
29/// # Example
30/// ```ignore
31/// use tenrso_exec::{einsum_ex, ExecHints};
32///
33/// let y = einsum_ex::<f32>("bij,bjk->bik")
34///     .inputs(&[A, B])
35///     .hints(&ExecHints::default())
36///     .run()?;
37/// ```
38pub fn einsum_ex<T>(spec: &str) -> EinsumBuilder<'_, T>
39where
40    T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive + 'static,
41{
42    EinsumBuilder::new(spec)
43}
44
45/// Builder for einsum operations
46pub struct EinsumBuilder<'a, T>
47where
48    T: Clone + Num + Float + FromPrimitive + 'static,
49{
50    spec: String,
51    inputs: Option<&'a [tenrso_core::TensorHandle<T>]>,
52    hints: ExecHints,
53}
54
55impl<'a, T> EinsumBuilder<'a, T>
56where
57    T: Clone + Num + std::ops::AddAssign + std::default::Default + Float + FromPrimitive + 'static,
58{
59    /// Create a new einsum builder
60    pub fn new(spec: impl Into<String>) -> Self {
61        Self {
62            spec: spec.into(),
63            inputs: None,
64            hints: ExecHints::default(),
65        }
66    }
67
68    /// Set input tensors
69    ///
70    /// # Arguments
71    ///
72    /// * `inputs` - Slice of tensor handles to use as inputs
73    ///
74    /// # Example
75    ///
76    /// ```ignore
77    /// let result = einsum_ex::<f32>("ij,jk->ik")
78    ///     .inputs(&[tensor_a, tensor_b])
79    ///     .run()?;
80    /// ```
81    pub fn inputs(mut self, inputs: &'a [tenrso_core::TensorHandle<T>]) -> Self {
82        self.inputs = Some(inputs);
83        self
84    }
85
86    /// Set execution hints
87    ///
88    /// # Arguments
89    ///
90    /// * `hints` - Execution hints for optimization
91    ///
92    /// # Example
93    ///
94    /// ```ignore
95    /// let hints = ExecHints {
96    ///     prefer_sparse: true,
97    ///     ..Default::default()
98    /// };
99    /// let result = einsum_ex::<f32>("ij,jk->ik")
100    ///     .inputs(&[tensor_a, tensor_b])
101    ///     .hints(&hints)
102    ///     .run()?;
103    /// ```
104    pub fn hints(mut self, hints: &ExecHints) -> Self {
105        self.hints = hints.clone();
106        self
107    }
108
109    /// Execute the einsum operation
110    ///
111    /// # Returns
112    ///
113    /// Result containing the output tensor handle
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if:
118    /// - No inputs were provided
119    /// - Input count doesn't match the einsum spec
120    /// - Execution fails
121    pub fn run(self) -> anyhow::Result<tenrso_core::TensorHandle<T>> {
122        let inputs = self
123            .inputs
124            .ok_or_else(|| anyhow::anyhow!("No inputs provided to einsum_ex"))?;
125
126        // Use CpuExecutor for execution
127        let mut executor = CpuExecutor::new();
128        executor.einsum(&self.spec, inputs, &self.hints)
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use tenrso_core::{DenseND, TensorHandle};

    /// Absolute tolerance used when comparing floating-point results.
    const TOL: f64 = 1e-10;

    /// Panics with a descriptive message unless `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_einsum_ex_builder_matmul() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let handle_a = TensorHandle::from_dense_auto(a);
        let handle_b = TensorHandle::from_dense_auto(b);

        let result = einsum_ex::<f64>("ij,jk->ik")
            .inputs(&[handle_a, handle_b])
            .run()
            .unwrap();

        let result_dense = result.as_dense().unwrap();
        assert_eq!(result_dense.shape(), &[2, 2]);

        // [[1, 2], [3, 4]] @ [[5, 6], [7, 8]]:
        //   [1*5+2*7, 1*6+2*8] = [19, 22]
        //   [3*5+4*7, 3*6+4*8] = [43, 50]
        let result_view = result_dense.view();
        assert_close(result_view[[0, 0]], 19.0);
        assert_close(result_view[[0, 1]], 22.0);
        assert_close(result_view[[1, 0]], 43.0);
        assert_close(result_view[[1, 1]], 50.0);
    }

    #[test]
    fn test_einsum_ex_builder_with_hints() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = DenseND::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let handle_a = TensorHandle::from_dense_auto(a);
        let handle_b = TensorHandle::from_dense_auto(b);

        let hints = ExecHints::default();

        let result = einsum_ex::<f64>("ij,jk->ik")
            .inputs(&[handle_a, handle_b])
            .hints(&hints)
            .run()
            .unwrap();

        let result_dense = result.as_dense().unwrap();
        assert_eq!(result_dense.shape(), &[2, 2]);

        // Supplying hints must not change the numerical result.
        let result_view = result_dense.view();
        assert_close(result_view[[0, 0]], 19.0);
        assert_close(result_view[[0, 1]], 22.0);
        assert_close(result_view[[1, 0]], 43.0);
        assert_close(result_view[[1, 1]], 50.0);
    }

    #[test]
    fn test_einsum_ex_builder_three_tensors() {
        let a = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let b = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let c = DenseND::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();

        let handle_a = TensorHandle::from_dense_auto(a);
        let handle_b = TensorHandle::from_dense_auto(b);
        let handle_c = TensorHandle::from_dense_auto(c);

        let result = einsum_ex::<f64>("ij,jk,kl->il")
            .inputs(&[handle_a, handle_b, handle_c])
            .run()
            .unwrap();

        let result_dense = result.as_dense().unwrap();
        assert_eq!(result_dense.shape(), &[2, 2]);

        // A @ B = [[1,2,3],[4,5,6]] @ [[1,2],[3,4],[5,6]] = [[22, 28], [49, 64]]
        // (A @ B) @ C = [[22, 28], [49, 64]] @ [[1, 2], [3, 4]]
        //             = [[106, 156], [241, 354]]
        let result_view = result_dense.view();
        assert_close(result_view[[0, 0]], 106.0);
        assert_close(result_view[[0, 1]], 156.0);
        assert_close(result_view[[1, 0]], 241.0);
        assert_close(result_view[[1, 1]], 354.0);
    }

    #[test]
    fn test_einsum_ex_builder_no_inputs() {
        let result = einsum_ex::<f64>("ij,jk->ik").run();

        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("No inputs provided"));
    }
}