Skip to main content

kunquant_rs/
batch.rs

1use crate::buffer::BufferNameMap;
2use crate::error::{KunQuantError, Result};
3use crate::executor::Executor;
4use crate::ffi;
5use crate::library::Module;
6
/// Dimensions and time window for a batch factor computation.
///
/// A `BatchParams` value records how many stocks are processed, how many time
/// points the data arrays hold in total, and which contiguous run of those
/// time points should be computed.
///
/// # Data Layout
///
/// KunQuant expects data in time-series format where:
/// - Rows represent time points (e.g., trading days)
/// - Columns represent stocks
/// - Data is stored in row-major order: `[t0_s0, t0_s1, ..., t0_sN, t1_s0, ...]`
///
/// # SIMD Requirements
///
/// `num_stocks` is expected to be a multiple of 8 so that the computation can
/// be vectorized. [`BatchParams::new`] enforces this; because the fields are
/// public, a value built with a struct literal bypasses that check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchParams {
    /// Number of stocks to process (expected to be a multiple of 8).
    pub num_stocks: usize,
    /// Total number of time points in the input data arrays.
    pub total_time: usize,
    /// Starting time index for computation (0-based).
    pub cur_time: usize,
    /// Number of consecutive time points to compute.
    pub length: usize,
}
35
36impl BatchParams {
37    /// Creates new batch computation parameters with validation.
38    ///
39    /// This constructor validates that the stock count meets SIMD requirements
40    /// and that the time window parameters are consistent.
41    ///
42    /// # Arguments
43    ///
44    /// * `num_stocks` - Number of stocks to process (must be multiple of 8)
45    /// * `total_time` - Total number of time points in input data
46    /// * `cur_time` - Starting time index for computation (0-based)
47    /// * `length` - Number of consecutive time points to compute
48    ///
49    /// # Returns
50    ///
51    /// Returns `Ok(BatchParams)` on success, or an error if:
52    /// - `num_stocks` is not a multiple of 8
53    /// - `cur_time + length > total_time` (time window exceeds data bounds)
54    /// - Any parameter is invalid
55    ///
56    /// # Examples
57    ///
58    /// ```rust,no_run
59    /// use kunquant_rs::BatchParams;
60    ///
61    /// # fn main() -> kunquant_rs::Result<()> {
62    /// // Process 16 stocks over 100 time points, computing all points
63    /// let params = BatchParams::new(16, 100, 0, 100)?;
64    ///
65    /// // Process 8 stocks, compute only the last 20 time points
66    /// let params = BatchParams::new(8, 252, 232, 20)?;
67    ///
68    /// // This would fail - stock count not multiple of 8
69    /// // let invalid = BatchParams::new(10, 100, 0, 100)?; // Error!
70    /// # Ok(())
71    /// # }
72    /// ```
73    ///
74    /// # Performance Notes
75    ///
76    /// - Larger `num_stocks` values (multiples of 8) enable better SIMD utilization
77    /// - Smaller `length` values reduce memory usage and computation time
78    /// - `cur_time` and `length` allow processing data in chunks for memory efficiency
79    pub fn new(
80        num_stocks: usize,
81        total_time: usize,
82        cur_time: usize,
83        length: usize,
84    ) -> Result<Self> {
85        if num_stocks % 8 != 0 {
86            return Err(KunQuantError::InvalidStockCount { num_stocks });
87        }
88
89        Ok(BatchParams {
90            num_stocks,
91            total_time,
92            cur_time,
93            length,
94        })
95    }
96
97    /// Creates parameters for computing the entire time range.
98    ///
99    /// This is a convenience method that creates batch parameters to process
100    /// all time points in the dataset, equivalent to calling
101    /// `BatchParams::new(num_stocks, total_time, 0, total_time)`.
102    ///
103    /// # Arguments
104    ///
105    /// * `num_stocks` - Number of stocks to process (must be multiple of 8)
106    /// * `total_time` - Total number of time points in the data
107    ///
108    /// # Returns
109    ///
110    /// Returns `Ok(BatchParams)` configured to process the entire time range,
111    /// or an error if `num_stocks` is not a multiple of 8.
112    ///
113    /// # Examples
114    ///
115    /// ```rust,no_run
116    /// use kunquant_rs::BatchParams;
117    ///
118    /// # fn main() -> kunquant_rs::Result<()> {
119    /// // Process all 252 trading days for 16 stocks
120    /// let params = BatchParams::full_range(16, 252)?;
121    ///
122    /// // Equivalent to:
123    /// // let params = BatchParams::new(16, 252, 0, 252)?;
124    /// # Ok(())
125    /// # }
126    /// ```
127    ///
128    /// # Use Cases
129    ///
130    /// - Historical backtesting over complete datasets
131    /// - Factor computation for entire time series
132    /// - Initial factor validation and testing
133    /// - Scenarios where memory constraints are not a concern
134    pub fn full_range(num_stocks: usize, total_time: usize) -> Result<Self> {
135        Self::new(num_stocks, total_time, 0, total_time)
136    }
137}
138
139/// Executes batch factor computation on historical time series data.
140///
141/// This function runs a complete factor computation over a specified time window
142/// using the provided executor, module, and data buffers. It's the primary interface
143/// for batch processing of historical market data.
144///
145/// # Arguments
146///
147/// * `executor` - The KunQuant executor to use for computation
148/// * `module` - The compiled factor module containing the computation graph
149/// * `buffers` - Buffer map containing input data and output storage
150/// * `params` - Batch parameters defining the computation window and dimensions
151///
152/// # Returns
153///
154/// Returns `Ok(())` on successful computation, or an error if:
155/// - Input buffers don't contain required data
156/// - Buffer dimensions don't match the parameters
157/// - The computation encounters runtime errors
158/// - Memory allocation fails during execution
159///
160/// # Examples
161///
162/// ```rust,no_run
163/// use kunquant_rs::{Executor, Library, BufferNameMap, BatchParams, run_graph};
164///
165/// # fn main() -> kunquant_rs::Result<()> {
166/// // Set up computation components
167/// let executor = Executor::single_thread()?;
168/// let library = Library::load("factors.so")?;
169/// let module = library.get_module("alpha001")?;
170///
171/// // Prepare data buffers
172/// let mut buffers = BufferNameMap::new()?;
173/// let mut input_data = vec![1.0f32; 16 * 100]; // 16 stocks, 100 time points
174/// let mut output_data = vec![0.0f32; 16 * 100];
175///
176/// buffers.set_buffer_slice("close", &mut input_data)?;
177/// buffers.set_buffer_slice("alpha001", &mut output_data)?;
178///
179/// // Execute computation
180/// let params = BatchParams::full_range(16, 100)?;
181/// run_graph(&executor, &module, &buffers, &params)?;
182///
183/// // Results are now available in output_data
184/// # Ok(())
185/// # }
186/// ```
187///
188/// # Data Requirements
189///
190/// - All input buffers must be populated with data before calling
191/// - Buffer sizes must match `num_stocks * total_time`
192/// - Data should be in row-major order (time-first layout)
193/// - Output buffers must be pre-allocated with sufficient space
194///
195/// # Performance Notes
196///
197/// - Computation is CPU-intensive and benefits from multi-threading
198/// - Memory usage scales with `num_stocks * total_time * sizeof(f32)`
199/// - SIMD optimizations require `num_stocks` to be a multiple of 8
200/// - Consider processing data in chunks for very large datasets
201pub fn run_graph(
202    executor: &Executor,
203    module: &Module,
204    buffers: &BufferNameMap,
205    params: &BatchParams,
206) -> Result<()> {
207    unsafe {
208        ffi::kunRunGraph(
209            executor.handle(),
210            module.handle(),
211            buffers.handle(),
212            params.num_stocks,
213            params.total_time,
214            params.cur_time,
215            params.length,
216        );
217    }
218    Ok(())
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// `BatchParams::new` accepts stock counts that are multiples of 8 and
    /// rejects every other count with `InvalidStockCount`.
    #[test]
    fn test_batch_params_validation() {
        // Valid parameters
        assert!(BatchParams::new(8, 100, 0, 100).is_ok());
        assert!(BatchParams::new(16, 100, 10, 50).is_ok());

        // Invalid stock count (not multiple of 8): check the variant and the
        // reported value, not merely that some error occurred.
        for bad in [7usize, 15] {
            assert!(matches!(
                BatchParams::new(bad, 100, 0, 100),
                Err(KunQuantError::InvalidStockCount { num_stocks }) if num_stocks == bad
            ));
        }
    }

    /// `BatchParams::full_range` starts at index 0 and spans all time points.
    #[test]
    fn test_full_range_params() {
        let params = BatchParams::full_range(24, 500).unwrap();
        assert_eq!(params.num_stocks, 24);
        assert_eq!(params.total_time, 500);
        assert_eq!(params.cur_time, 0);
        assert_eq!(params.length, 500);
    }
}