// kizzasi_inference/pipeline.rs

//! Inference pipeline construction.
//!
//! Provides a builder pattern for constructing complete inference pipelines
//! from an engine configuration, an optional model, an optional tokenizer and
//! optional constraints.
//!
//! # Pipeline Hooks
//!
//! The pipeline supports preprocessing and postprocessing hooks that allow
//! custom transformations at different stages:
//!
//! ```text
//! Input → [Preprocess] → Tokenize → Model → Constrain → Detokenize → [Postprocess] → Output
//! ```
//!
//! Tokenization, constraint enforcement and detokenization are each optional.

15use crate::engine::{EngineConfig, InferenceEngine};
16use crate::error::{InferenceError, InferenceResult};
17use kizzasi_logic::{ConstrainedInference, GuardrailSet};
18use kizzasi_model::AutoregressiveModel;
19use kizzasi_tokenizer::SignalTokenizer;
20use scirs2_core::ndarray::Array1;
21use std::sync::Arc;
22
23/// Preprocessing hook that transforms input before model forward pass
24pub type PreprocessHook = Arc<dyn Fn(&Array1<f32>) -> InferenceResult<Array1<f32>> + Send + Sync>;
25
26/// Postprocessing hook that transforms output after model forward pass
27pub type PostprocessHook = Arc<dyn Fn(&Array1<f32>) -> InferenceResult<Array1<f32>> + Send + Sync>;
28
29/// A complete inference pipeline
30pub struct Pipeline {
31    /// The inference engine
32    engine: InferenceEngine,
33    /// Optional tokenizer for preprocessing
34    tokenizer: Option<Box<dyn SignalTokenizer>>,
35    /// Whether to apply tokenization
36    use_tokenizer: bool,
37    /// Whether constraints are enabled
38    constraints_enabled: bool,
39    /// Guardrail set for constraint enforcement
40    guardrails: Option<GuardrailSet>,
41    /// Preprocessing hooks (applied before tokenization)
42    preprocess_hooks: Vec<PreprocessHook>,
43    /// Postprocessing hooks (applied after detokenization)
44    postprocess_hooks: Vec<PostprocessHook>,
45}
46
47impl Pipeline {
48    /// Create a single prediction step through the pipeline
49    ///
50    /// The complete flow is:
51    /// 1. Apply preprocessing hooks
52    /// 2. Optional tokenization of input signal
53    /// 3. Model forward pass
54    /// 4. Optional constraint enforcement
55    /// 5. Optional detokenization
56    /// 6. Apply postprocessing hooks
57    pub fn forward(&mut self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
58        // Step 1: Preprocessing hooks
59        let mut preprocessed = input.clone();
60        for hook in &self.preprocess_hooks {
61            preprocessed = hook(&preprocessed)?;
62        }
63
64        // Step 2: Tokenization (if enabled)
65        let tokenized = if self.use_tokenizer {
66            if let Some(tokenizer) = &self.tokenizer {
67                tokenizer
68                    .encode(&preprocessed)
69                    .map_err(|e| InferenceError::TokenizationError(e.to_string()))?
70            } else {
71                return Err(InferenceError::TokenizationError(
72                    "Tokenizer enabled but not provided".to_string(),
73                ));
74            }
75        } else {
76            preprocessed
77        };
78
79        // Step 3: Model forward pass
80        let output = self.engine.step(&tokenized)?;
81
82        // Step 4: Constraint enforcement (if enabled)
83        let constrained = if self.constraints_enabled {
84            self.apply_constraints(&output)?
85        } else {
86            output
87        };
88
89        // Step 5: Detokenization (if enabled)
90        let decoded = if self.use_tokenizer {
91            if let Some(tokenizer) = &self.tokenizer {
92                tokenizer
93                    .decode(&constrained)
94                    .map_err(|e| InferenceError::TokenizationError(e.to_string()))?
95            } else {
96                return Err(InferenceError::TokenizationError(
97                    "Tokenizer enabled but not provided".to_string(),
98                ));
99            }
100        } else {
101            constrained
102        };
103
104        // Step 6: Postprocessing hooks
105        let mut postprocessed = decoded;
106        for hook in &self.postprocess_hooks {
107            postprocessed = hook(&postprocessed)?;
108        }
109
110        Ok(postprocessed)
111    }
112
113    /// Apply constraints to the output
114    ///
115    /// If guardrails are configured, enforces constraints by projecting
116    /// the output onto the constraint-satisfying manifold
117    fn apply_constraints(&self, output: &Array1<f32>) -> InferenceResult<Array1<f32>> {
118        if let Some(ref guardrails) = self.guardrails {
119            // Use guardrails to constrain the output
120            guardrails
121                .constrain(output)
122                .map_err(|e| InferenceError::ConstraintError(e.to_string()))
123        } else {
124            // No guardrails configured, return unchanged
125            Ok(output.clone())
126        }
127    }
128
129    /// Set the guardrails for constraint enforcement
130    pub fn set_guardrails(&mut self, guardrails: GuardrailSet) {
131        self.guardrails = Some(guardrails);
132        self.constraints_enabled = true;
133    }
134
135    /// Remove guardrails
136    pub fn clear_guardrails(&mut self) {
137        self.guardrails = None;
138        self.constraints_enabled = false;
139    }
140
141    /// Get a reference to the guardrails
142    pub fn guardrails(&self) -> Option<&GuardrailSet> {
143        self.guardrails.as_ref()
144    }
145
146    /// Perform multi-step prediction through the pipeline
147    pub fn rollout(
148        &mut self,
149        initial: &Array1<f32>,
150        steps: usize,
151    ) -> InferenceResult<Vec<Array1<f32>>> {
152        let mut outputs = Vec::with_capacity(steps);
153        let mut current = initial.clone();
154
155        for _ in 0..steps {
156            let output = self.forward(&current)?;
157            outputs.push(output.clone());
158            current = output;
159        }
160
161        Ok(outputs)
162    }
163
164    /// Reset the pipeline state
165    pub fn reset(&mut self) {
166        self.engine.reset();
167    }
168
169    /// Get the underlying engine
170    pub fn engine(&self) -> &InferenceEngine {
171        &self.engine
172    }
173
174    /// Get mutable access to the engine
175    pub fn engine_mut(&mut self) -> &mut InferenceEngine {
176        &mut self.engine
177    }
178
179    /// Check if constraints are enabled
180    pub fn has_constraints(&self) -> bool {
181        self.constraints_enabled
182    }
183
184    /// Check if tokenizer is enabled
185    pub fn has_tokenizer(&self) -> bool {
186        self.use_tokenizer && self.tokenizer.is_some()
187    }
188
189    /// Add a preprocessing hook
190    pub fn add_preprocess_hook(&mut self, hook: PreprocessHook) {
191        self.preprocess_hooks.push(hook);
192    }
193
194    /// Add a postprocessing hook
195    pub fn add_postprocess_hook(&mut self, hook: PostprocessHook) {
196        self.postprocess_hooks.push(hook);
197    }
198
199    /// Get number of preprocessing hooks
200    pub fn num_preprocess_hooks(&self) -> usize {
201        self.preprocess_hooks.len()
202    }
203
204    /// Get number of postprocessing hooks
205    pub fn num_postprocess_hooks(&self) -> usize {
206        self.postprocess_hooks.len()
207    }
208
209    /// Clear all preprocessing hooks
210    pub fn clear_preprocess_hooks(&mut self) {
211        self.preprocess_hooks.clear();
212    }
213
214    /// Clear all postprocessing hooks
215    pub fn clear_postprocess_hooks(&mut self) {
216        self.postprocess_hooks.clear();
217    }
218}
219
220/// Builder for constructing inference pipelines
221pub struct PipelineBuilder {
222    engine_config: Option<EngineConfig>,
223    model: Option<Box<dyn AutoregressiveModel>>,
224    tokenizer: Option<Box<dyn SignalTokenizer>>,
225    use_tokenizer: bool,
226    constraints_enabled: bool,
227    guardrails: Option<GuardrailSet>,
228    preprocess_hooks: Vec<PreprocessHook>,
229    postprocess_hooks: Vec<PostprocessHook>,
230}
231
232impl PipelineBuilder {
233    /// Create a new pipeline builder
234    pub fn new() -> Self {
235        Self {
236            engine_config: None,
237            model: None,
238            tokenizer: None,
239            use_tokenizer: false,
240            constraints_enabled: false,
241            guardrails: None,
242            preprocess_hooks: Vec::new(),
243            postprocess_hooks: Vec::new(),
244        }
245    }
246
247    /// Set the engine configuration
248    pub fn engine_config(mut self, config: EngineConfig) -> Self {
249        self.engine_config = Some(config);
250        self
251    }
252
253    /// Set the model
254    pub fn model(mut self, model: Box<dyn AutoregressiveModel>) -> Self {
255        self.model = Some(model);
256        self
257    }
258
259    /// Set the tokenizer
260    pub fn tokenizer(mut self, tokenizer: Box<dyn SignalTokenizer>) -> Self {
261        self.tokenizer = Some(tokenizer);
262        self.use_tokenizer = true;
263        self
264    }
265
266    /// Enable/disable tokenizer usage
267    pub fn use_tokenizer(mut self, use_tok: bool) -> Self {
268        self.use_tokenizer = use_tok;
269        self
270    }
271
272    /// Enable constraint enforcement
273    pub fn with_constraints(mut self) -> Self {
274        self.constraints_enabled = true;
275        self
276    }
277
278    /// Set guardrails for constraint enforcement
279    pub fn guardrails(mut self, guardrails: GuardrailSet) -> Self {
280        self.guardrails = Some(guardrails);
281        self.constraints_enabled = true;
282        self
283    }
284
285    /// Add a preprocessing hook
286    pub fn add_preprocess_hook(mut self, hook: PreprocessHook) -> Self {
287        self.preprocess_hooks.push(hook);
288        self
289    }
290
291    /// Add a postprocessing hook
292    pub fn add_postprocess_hook(mut self, hook: PostprocessHook) -> Self {
293        self.postprocess_hooks.push(hook);
294        self
295    }
296
297    /// Build the pipeline
298    pub fn build(self) -> InferenceResult<Pipeline> {
299        let engine_config = self
300            .engine_config
301            .ok_or_else(|| InferenceError::PipelineConfig("engine_config not set".into()))?;
302
303        let engine = if let Some(model) = self.model {
304            InferenceEngine::with_model(engine_config, model)
305        } else {
306            InferenceEngine::new(engine_config)
307        };
308
309        Ok(Pipeline {
310            engine,
311            tokenizer: self.tokenizer,
312            use_tokenizer: self.use_tokenizer,
313            constraints_enabled: self.constraints_enabled,
314            guardrails: self.guardrails,
315            preprocess_hooks: self.preprocess_hooks,
316            postprocess_hooks: self.postprocess_hooks,
317        })
318    }
319}
320
321impl Default for PipelineBuilder {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingConfig;

    #[test]
    fn test_pipeline_builder_basic() {
        let engine_config = EngineConfig::new(3, 3);
        let pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .with_constraints()
            .build();

        assert!(pipeline.is_ok());
        let p = pipeline.unwrap();
        assert!(p.has_constraints());
        assert!(!p.has_tokenizer());
    }

    #[test]
    fn test_pipeline_missing_config() {
        let result = PipelineBuilder::new().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_pipeline_with_model() {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        let engine_config = EngineConfig::new(1, 10);
        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .model(Box::new(model))
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = pipeline.forward(&input);

        assert!(output.is_ok());
    }

    #[test]
    fn test_pipeline_rollout() {
        use kizzasi_model::rwkv::{Rwkv, RwkvConfig};

        let model_config = RwkvConfig::new()
            .input_dim(1)
            .hidden_dim(64)
            .intermediate_dim(256)
            .num_layers(2);
        let model = Rwkv::new(model_config).unwrap();

        let engine_config = EngineConfig::new(1, 10);
        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .model(Box::new(model))
            .build()
            .unwrap();

        let initial = Array1::from_vec(vec![0.5]);
        let outputs = pipeline.rollout(&initial, 5);

        assert!(outputs.is_ok());
        assert_eq!(outputs.unwrap().len(), 5);
    }

    #[test]
    fn test_pipeline_reset() {
        let engine_config = EngineConfig::new(1, 1);
        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .build()
            .unwrap();

        pipeline.reset();
        assert_eq!(pipeline.engine().step_count(), 0);
    }

    #[test]
    fn test_pipeline_with_sampling() {
        use crate::sampling::SamplingStrategy;
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        let sampling = SamplingConfig::new()
            .strategy(SamplingStrategy::TopK)
            .top_k(5);

        let engine_config = EngineConfig::new(1, 10)
            .sampling(sampling)
            .use_embeddings(true);

        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .model(Box::new(model))
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = pipeline.forward(&input);

        assert!(output.is_ok());
    }

    #[test]
    fn test_hook_registration_and_clearing() {
        let identity: PreprocessHook =
            Arc::new(|x: &Array1<f32>| -> InferenceResult<Array1<f32>> { Ok(x.clone()) });

        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 1))
            .add_preprocess_hook(Arc::clone(&identity))
            .add_postprocess_hook(Arc::clone(&identity))
            .build()
            .unwrap();
        assert_eq!(pipeline.num_preprocess_hooks(), 1);
        assert_eq!(pipeline.num_postprocess_hooks(), 1);

        pipeline.add_preprocess_hook(Arc::clone(&identity));
        pipeline.add_postprocess_hook(Arc::clone(&identity));
        assert_eq!(pipeline.num_preprocess_hooks(), 2);
        assert_eq!(pipeline.num_postprocess_hooks(), 2);

        pipeline.clear_preprocess_hooks();
        pipeline.clear_postprocess_hooks();
        assert_eq!(pipeline.num_preprocess_hooks(), 0);
        assert_eq!(pipeline.num_postprocess_hooks(), 0);
    }

    #[test]
    fn test_postprocess_hook_transforms_output() {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        let model = S4D::new(model_config).unwrap();

        // Replace whatever the model produced with a known constant.
        let constant: PostprocessHook =
            Arc::new(|_: &Array1<f32>| -> InferenceResult<Array1<f32>> {
                Ok(Array1::from_vec(vec![42.0]))
            });

        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 10))
            .model(Box::new(model))
            .add_postprocess_hook(constant)
            .build()
            .unwrap();

        let output = pipeline.forward(&Array1::from_vec(vec![0.5])).unwrap();
        assert_eq!(output, Array1::from_vec(vec![42.0]));
    }

    #[test]
    fn test_preprocess_hook_error_aborts_step() {
        let reject: PreprocessHook =
            Arc::new(|_: &Array1<f32>| -> InferenceResult<Array1<f32>> {
                Err(InferenceError::PipelineConfig("rejected".into()))
            });

        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 1))
            .add_preprocess_hook(reject)
            .build()
            .unwrap();

        let steps_before = pipeline.engine().step_count();
        assert!(pipeline.forward(&Array1::from_vec(vec![0.5])).is_err());
        // The hook fails before the engine is reached, so no step is taken.
        assert_eq!(pipeline.engine().step_count(), steps_before);
    }

    #[test]
    fn test_clear_guardrails_disables_constraints() {
        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 1))
            .with_constraints()
            .build()
            .unwrap();
        assert!(pipeline.has_constraints());
        assert!(pipeline.guardrails().is_none());

        pipeline.clear_guardrails();
        assert!(!pipeline.has_constraints());
        assert!(pipeline.guardrails().is_none());
    }
}