kunquant_rs/batch.rs
1use crate::buffer::BufferNameMap;
2use crate::error::{KunQuantError, Result};
3use crate::executor::Executor;
4use crate::ffi;
5use crate::library::Module;
6
/// Dimensions and time window for one batch factor computation.
///
/// A `BatchParams` value records how many stocks are processed, how many time
/// points the data arrays hold in total, and which consecutive run of time
/// points should be computed.
///
/// All fields are public, so a value can also be built with a struct literal;
/// doing so bypasses the stock-count check performed by [`BatchParams::new`].
///
/// The type is plain data (four `usize` fields), so it is `Copy` and can be
/// compared with `==`.
///
/// # Data Layout
///
/// KunQuant expects data in time-series format where:
/// - Rows represent time points (e.g., trading days)
/// - Columns represent stocks
/// - Data is stored in row-major order: `[t0_s0, t0_s1, ..., t0_sN, t1_s0, ...]`
///
/// # SIMD Requirements
///
/// `num_stocks` is expected to be a multiple of 8 so that the computation can
/// use SIMD (Single Instruction, Multiple Data) vectorization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchParams {
    /// Number of stocks to process (expected to be a multiple of 8).
    pub num_stocks: usize,
    /// Total number of time points in the input data arrays.
    pub total_time: usize,
    /// Starting time index for computation (0-based).
    pub cur_time: usize,
    /// Number of consecutive time points to compute, starting at `cur_time`.
    pub length: usize,
}
35
36impl BatchParams {
37 /// Creates new batch computation parameters with validation.
38 ///
39 /// This constructor validates that the stock count meets SIMD requirements
40 /// and that the time window parameters are consistent.
41 ///
42 /// # Arguments
43 ///
44 /// * `num_stocks` - Number of stocks to process (must be multiple of 8)
45 /// * `total_time` - Total number of time points in input data
46 /// * `cur_time` - Starting time index for computation (0-based)
47 /// * `length` - Number of consecutive time points to compute
48 ///
49 /// # Returns
50 ///
51 /// Returns `Ok(BatchParams)` on success, or an error if:
52 /// - `num_stocks` is not a multiple of 8
53 /// - `cur_time + length > total_time` (time window exceeds data bounds)
54 /// - Any parameter is invalid
55 ///
56 /// # Examples
57 ///
58 /// ```rust,no_run
59 /// use kunquant_rs::BatchParams;
60 ///
61 /// # fn main() -> kunquant_rs::Result<()> {
62 /// // Process 16 stocks over 100 time points, computing all points
63 /// let params = BatchParams::new(16, 100, 0, 100)?;
64 ///
65 /// // Process 8 stocks, compute only the last 20 time points
66 /// let params = BatchParams::new(8, 252, 232, 20)?;
67 ///
68 /// // This would fail - stock count not multiple of 8
69 /// // let invalid = BatchParams::new(10, 100, 0, 100)?; // Error!
70 /// # Ok(())
71 /// # }
72 /// ```
73 ///
74 /// # Performance Notes
75 ///
76 /// - Larger `num_stocks` values (multiples of 8) enable better SIMD utilization
77 /// - Smaller `length` values reduce memory usage and computation time
78 /// - `cur_time` and `length` allow processing data in chunks for memory efficiency
79 pub fn new(
80 num_stocks: usize,
81 total_time: usize,
82 cur_time: usize,
83 length: usize,
84 ) -> Result<Self> {
85 if num_stocks % 8 != 0 {
86 return Err(KunQuantError::InvalidStockCount { num_stocks });
87 }
88
89 Ok(BatchParams {
90 num_stocks,
91 total_time,
92 cur_time,
93 length,
94 })
95 }
96
97 /// Creates parameters for computing the entire time range.
98 ///
99 /// This is a convenience method that creates batch parameters to process
100 /// all time points in the dataset, equivalent to calling
101 /// `BatchParams::new(num_stocks, total_time, 0, total_time)`.
102 ///
103 /// # Arguments
104 ///
105 /// * `num_stocks` - Number of stocks to process (must be multiple of 8)
106 /// * `total_time` - Total number of time points in the data
107 ///
108 /// # Returns
109 ///
110 /// Returns `Ok(BatchParams)` configured to process the entire time range,
111 /// or an error if `num_stocks` is not a multiple of 8.
112 ///
113 /// # Examples
114 ///
115 /// ```rust,no_run
116 /// use kunquant_rs::BatchParams;
117 ///
118 /// # fn main() -> kunquant_rs::Result<()> {
119 /// // Process all 252 trading days for 16 stocks
120 /// let params = BatchParams::full_range(16, 252)?;
121 ///
122 /// // Equivalent to:
123 /// // let params = BatchParams::new(16, 252, 0, 252)?;
124 /// # Ok(())
125 /// # }
126 /// ```
127 ///
128 /// # Use Cases
129 ///
130 /// - Historical backtesting over complete datasets
131 /// - Factor computation for entire time series
132 /// - Initial factor validation and testing
133 /// - Scenarios where memory constraints are not a concern
134 pub fn full_range(num_stocks: usize, total_time: usize) -> Result<Self> {
135 Self::new(num_stocks, total_time, 0, total_time)
136 }
137}
138
139/// Executes batch factor computation on historical time series data.
140///
141/// This function runs a complete factor computation over a specified time window
142/// using the provided executor, module, and data buffers. It's the primary interface
143/// for batch processing of historical market data.
144///
145/// # Arguments
146///
147/// * `executor` - The KunQuant executor to use for computation
148/// * `module` - The compiled factor module containing the computation graph
149/// * `buffers` - Buffer map containing input data and output storage
150/// * `params` - Batch parameters defining the computation window and dimensions
151///
152/// # Returns
153///
154/// Returns `Ok(())` on successful computation, or an error if:
155/// - Input buffers don't contain required data
156/// - Buffer dimensions don't match the parameters
157/// - The computation encounters runtime errors
158/// - Memory allocation fails during execution
159///
160/// # Examples
161///
162/// ```rust,no_run
163/// use kunquant_rs::{Executor, Library, BufferNameMap, BatchParams, run_graph};
164///
165/// # fn main() -> kunquant_rs::Result<()> {
166/// // Set up computation components
167/// let executor = Executor::single_thread()?;
168/// let library = Library::load("factors.so")?;
169/// let module = library.get_module("alpha001")?;
170///
171/// // Prepare data buffers
172/// let mut buffers = BufferNameMap::new()?;
173/// let mut input_data = vec![1.0f32; 16 * 100]; // 16 stocks, 100 time points
174/// let mut output_data = vec![0.0f32; 16 * 100];
175///
176/// buffers.set_buffer_slice("close", &mut input_data)?;
177/// buffers.set_buffer_slice("alpha001", &mut output_data)?;
178///
179/// // Execute computation
180/// let params = BatchParams::full_range(16, 100)?;
181/// run_graph(&executor, &module, &buffers, ¶ms)?;
182///
183/// // Results are now available in output_data
184/// # Ok(())
185/// # }
186/// ```
187///
188/// # Data Requirements
189///
190/// - All input buffers must be populated with data before calling
191/// - Buffer sizes must match `num_stocks * total_time`
192/// - Data should be in row-major order (time-first layout)
193/// - Output buffers must be pre-allocated with sufficient space
194///
195/// # Performance Notes
196///
197/// - Computation is CPU-intensive and benefits from multi-threading
198/// - Memory usage scales with `num_stocks * total_time * sizeof(f32)`
199/// - SIMD optimizations require `num_stocks` to be a multiple of 8
200/// - Consider processing data in chunks for very large datasets
201pub fn run_graph(
202 executor: &Executor,
203 module: &Module,
204 buffers: &BufferNameMap,
205 params: &BatchParams,
206) -> Result<()> {
207 unsafe {
208 ffi::kunRunGraph(
209 executor.handle(),
210 module.handle(),
211 buffers.handle(),
212 params.num_stocks,
213 params.total_time,
214 params.cur_time,
215 params.length,
216 );
217 }
218 Ok(())
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` accepts stock counts divisible by 8 and rejects the others.
    #[test]
    fn test_batch_params_validation() {
        // Stock counts that are multiples of 8 are accepted.
        let accepted = [(8, 100, 0, 100), (16, 100, 10, 50)];
        for (num_stocks, total_time, cur_time, length) in accepted {
            assert!(BatchParams::new(num_stocks, total_time, cur_time, length).is_ok());
        }

        // Stock counts that are not multiples of 8 are rejected.
        for num_stocks in [7, 15] {
            assert!(BatchParams::new(num_stocks, 100, 0, 100).is_err());
        }
    }

    /// `full_range` starts at index 0 and spans the whole series.
    #[test]
    fn test_full_range_params() {
        let BatchParams {
            num_stocks,
            total_time,
            cur_time,
            length,
        } = BatchParams::full_range(24, 500).unwrap();

        assert_eq!(
            (num_stocks, total_time, cur_time, length),
            (24, 500, 0, 500)
        );
    }
}