kizzasi_inference/
pipeline.rs1use crate::engine::{EngineConfig, InferenceEngine};
16use crate::error::{InferenceError, InferenceResult};
17use kizzasi_logic::{ConstrainedInference, GuardrailSet};
18use kizzasi_model::AutoregressiveModel;
19use kizzasi_tokenizer::SignalTokenizer;
20use scirs2_core::ndarray::Array1;
21use std::sync::Arc;
22
23pub type PreprocessHook = Arc<dyn Fn(&Array1<f32>) -> InferenceResult<Array1<f32>> + Send + Sync>;
25
26pub type PostprocessHook = Arc<dyn Fn(&Array1<f32>) -> InferenceResult<Array1<f32>> + Send + Sync>;
28
29pub struct Pipeline {
31 engine: InferenceEngine,
33 tokenizer: Option<Box<dyn SignalTokenizer>>,
35 use_tokenizer: bool,
37 constraints_enabled: bool,
39 guardrails: Option<GuardrailSet>,
41 preprocess_hooks: Vec<PreprocessHook>,
43 postprocess_hooks: Vec<PostprocessHook>,
45}
46
47impl Pipeline {
48 pub fn forward(&mut self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
58 let mut preprocessed = input.clone();
60 for hook in &self.preprocess_hooks {
61 preprocessed = hook(&preprocessed)?;
62 }
63
64 let tokenized = if self.use_tokenizer {
66 if let Some(tokenizer) = &self.tokenizer {
67 tokenizer
68 .encode(&preprocessed)
69 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?
70 } else {
71 return Err(InferenceError::TokenizationError(
72 "Tokenizer enabled but not provided".to_string(),
73 ));
74 }
75 } else {
76 preprocessed
77 };
78
79 let output = self.engine.step(&tokenized)?;
81
82 let constrained = if self.constraints_enabled {
84 self.apply_constraints(&output)?
85 } else {
86 output
87 };
88
89 let decoded = if self.use_tokenizer {
91 if let Some(tokenizer) = &self.tokenizer {
92 tokenizer
93 .decode(&constrained)
94 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?
95 } else {
96 return Err(InferenceError::TokenizationError(
97 "Tokenizer enabled but not provided".to_string(),
98 ));
99 }
100 } else {
101 constrained
102 };
103
104 let mut postprocessed = decoded;
106 for hook in &self.postprocess_hooks {
107 postprocessed = hook(&postprocessed)?;
108 }
109
110 Ok(postprocessed)
111 }
112
113 fn apply_constraints(&self, output: &Array1<f32>) -> InferenceResult<Array1<f32>> {
118 if let Some(ref guardrails) = self.guardrails {
119 guardrails
121 .constrain(output)
122 .map_err(|e| InferenceError::ConstraintError(e.to_string()))
123 } else {
124 Ok(output.clone())
126 }
127 }
128
129 pub fn set_guardrails(&mut self, guardrails: GuardrailSet) {
131 self.guardrails = Some(guardrails);
132 self.constraints_enabled = true;
133 }
134
135 pub fn clear_guardrails(&mut self) {
137 self.guardrails = None;
138 self.constraints_enabled = false;
139 }
140
141 pub fn guardrails(&self) -> Option<&GuardrailSet> {
143 self.guardrails.as_ref()
144 }
145
146 pub fn rollout(
148 &mut self,
149 initial: &Array1<f32>,
150 steps: usize,
151 ) -> InferenceResult<Vec<Array1<f32>>> {
152 let mut outputs = Vec::with_capacity(steps);
153 let mut current = initial.clone();
154
155 for _ in 0..steps {
156 let output = self.forward(¤t)?;
157 outputs.push(output.clone());
158 current = output;
159 }
160
161 Ok(outputs)
162 }
163
164 pub fn reset(&mut self) {
166 self.engine.reset();
167 }
168
169 pub fn engine(&self) -> &InferenceEngine {
171 &self.engine
172 }
173
174 pub fn engine_mut(&mut self) -> &mut InferenceEngine {
176 &mut self.engine
177 }
178
179 pub fn has_constraints(&self) -> bool {
181 self.constraints_enabled
182 }
183
184 pub fn has_tokenizer(&self) -> bool {
186 self.use_tokenizer && self.tokenizer.is_some()
187 }
188
189 pub fn add_preprocess_hook(&mut self, hook: PreprocessHook) {
191 self.preprocess_hooks.push(hook);
192 }
193
194 pub fn add_postprocess_hook(&mut self, hook: PostprocessHook) {
196 self.postprocess_hooks.push(hook);
197 }
198
199 pub fn num_preprocess_hooks(&self) -> usize {
201 self.preprocess_hooks.len()
202 }
203
204 pub fn num_postprocess_hooks(&self) -> usize {
206 self.postprocess_hooks.len()
207 }
208
209 pub fn clear_preprocess_hooks(&mut self) {
211 self.preprocess_hooks.clear();
212 }
213
214 pub fn clear_postprocess_hooks(&mut self) {
216 self.postprocess_hooks.clear();
217 }
218}
219
220pub struct PipelineBuilder {
222 engine_config: Option<EngineConfig>,
223 model: Option<Box<dyn AutoregressiveModel>>,
224 tokenizer: Option<Box<dyn SignalTokenizer>>,
225 use_tokenizer: bool,
226 constraints_enabled: bool,
227 guardrails: Option<GuardrailSet>,
228 preprocess_hooks: Vec<PreprocessHook>,
229 postprocess_hooks: Vec<PostprocessHook>,
230}
231
232impl PipelineBuilder {
233 pub fn new() -> Self {
235 Self {
236 engine_config: None,
237 model: None,
238 tokenizer: None,
239 use_tokenizer: false,
240 constraints_enabled: false,
241 guardrails: None,
242 preprocess_hooks: Vec::new(),
243 postprocess_hooks: Vec::new(),
244 }
245 }
246
247 pub fn engine_config(mut self, config: EngineConfig) -> Self {
249 self.engine_config = Some(config);
250 self
251 }
252
253 pub fn model(mut self, model: Box<dyn AutoregressiveModel>) -> Self {
255 self.model = Some(model);
256 self
257 }
258
259 pub fn tokenizer(mut self, tokenizer: Box<dyn SignalTokenizer>) -> Self {
261 self.tokenizer = Some(tokenizer);
262 self.use_tokenizer = true;
263 self
264 }
265
266 pub fn use_tokenizer(mut self, use_tok: bool) -> Self {
268 self.use_tokenizer = use_tok;
269 self
270 }
271
272 pub fn with_constraints(mut self) -> Self {
274 self.constraints_enabled = true;
275 self
276 }
277
278 pub fn guardrails(mut self, guardrails: GuardrailSet) -> Self {
280 self.guardrails = Some(guardrails);
281 self.constraints_enabled = true;
282 self
283 }
284
285 pub fn add_preprocess_hook(mut self, hook: PreprocessHook) -> Self {
287 self.preprocess_hooks.push(hook);
288 self
289 }
290
291 pub fn add_postprocess_hook(mut self, hook: PostprocessHook) -> Self {
293 self.postprocess_hooks.push(hook);
294 self
295 }
296
297 pub fn build(self) -> InferenceResult<Pipeline> {
299 let engine_config = self
300 .engine_config
301 .ok_or_else(|| InferenceError::PipelineConfig("engine_config not set".into()))?;
302
303 let engine = if let Some(model) = self.model {
304 InferenceEngine::with_model(engine_config, model)
305 } else {
306 InferenceEngine::new(engine_config)
307 };
308
309 Ok(Pipeline {
310 engine,
311 tokenizer: self.tokenizer,
312 use_tokenizer: self.use_tokenizer,
313 constraints_enabled: self.constraints_enabled,
314 guardrails: self.guardrails,
315 preprocess_hooks: self.preprocess_hooks,
316 postprocess_hooks: self.postprocess_hooks,
317 })
318 }
319}
320
321impl Default for PipelineBuilder {
322 fn default() -> Self {
323 Self::new()
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingConfig;

    /// Small diagonal S4 model with one input channel, shared by several tests.
    fn small_s4d() -> Box<dyn AutoregressiveModel> {
        use kizzasi_model::s4::{S4Config, S4D};

        let model_config = S4Config::new()
            .input_dim(1)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2)
            .diagonal(true);
        Box::new(S4D::new(model_config).unwrap())
    }

    #[test]
    fn test_pipeline_builder_basic() {
        let engine_config = EngineConfig::new(3, 3);
        let pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .with_constraints()
            .build();

        assert!(pipeline.is_ok());
        let p = pipeline.unwrap();
        assert!(p.has_constraints());
        assert!(!p.has_tokenizer());
        // `with_constraints` enables constraints without installing guardrails.
        assert!(p.guardrails().is_none());
    }

    #[test]
    fn test_pipeline_missing_config() {
        let result = PipelineBuilder::new().build();
        assert!(result.is_err());
    }

    #[test]
    fn test_pipeline_with_model() {
        let engine_config = EngineConfig::new(1, 10);
        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .model(small_s4d())
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = pipeline.forward(&input);

        assert!(output.is_ok());
    }

    #[test]
    fn test_pipeline_rollout() {
        use kizzasi_model::rwkv::{Rwkv, RwkvConfig};

        let model_config = RwkvConfig::new()
            .input_dim(1)
            .hidden_dim(64)
            .intermediate_dim(256)
            .num_layers(2);
        let model = Rwkv::new(model_config).unwrap();

        let engine_config = EngineConfig::new(1, 10);
        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .model(Box::new(model))
            .build()
            .unwrap();

        let initial = Array1::from_vec(vec![0.5]);
        let outputs = pipeline.rollout(&initial, 5);

        assert!(outputs.is_ok());
        assert_eq!(outputs.unwrap().len(), 5);
    }

    #[test]
    fn test_rollout_zero_steps_is_empty() {
        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 1))
            .build()
            .unwrap();

        let outputs = pipeline
            .rollout(&Array1::from_vec(vec![0.5]), 0)
            .unwrap();
        assert!(outputs.is_empty());
    }

    #[test]
    fn test_pipeline_reset() {
        let engine_config = EngineConfig::new(1, 1);
        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .build()
            .unwrap();

        pipeline.reset();
        assert_eq!(pipeline.engine().step_count(), 0);
    }

    #[test]
    fn test_forward_errors_when_tokenizer_enabled_but_missing() {
        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 1))
            .use_tokenizer(true)
            .build()
            .unwrap();

        assert!(!pipeline.has_tokenizer());
        let err = pipeline
            .forward(&Array1::from_vec(vec![0.5]))
            .unwrap_err();
        assert!(matches!(err, InferenceError::TokenizationError(_)));
    }

    #[test]
    fn test_hooks_are_applied_and_clearable() {
        let identity: PreprocessHook =
            Arc::new(|x: &Array1<f32>| -> InferenceResult<Array1<f32>> { Ok(x.clone()) });
        let zero: PostprocessHook = Arc::new(|x: &Array1<f32>| -> InferenceResult<Array1<f32>> {
            Ok(Array1::zeros(x.len()))
        });

        let mut pipeline = PipelineBuilder::new()
            .engine_config(EngineConfig::new(1, 10))
            .model(small_s4d())
            .add_preprocess_hook(identity)
            .add_postprocess_hook(zero)
            .build()
            .unwrap();

        assert_eq!(pipeline.num_preprocess_hooks(), 1);
        assert_eq!(pipeline.num_postprocess_hooks(), 1);

        // The postprocess hook runs last, so every output element is zero.
        let output = pipeline.forward(&Array1::from_vec(vec![0.5])).unwrap();
        assert!(output.iter().all(|&v| v == 0.0));

        pipeline.clear_preprocess_hooks();
        pipeline.clear_postprocess_hooks();
        assert_eq!(pipeline.num_preprocess_hooks(), 0);
        assert_eq!(pipeline.num_postprocess_hooks(), 0);
    }

    #[test]
    fn test_pipeline_with_sampling() {
        use crate::sampling::SamplingStrategy;

        let sampling = SamplingConfig::new()
            .strategy(SamplingStrategy::TopK)
            .top_k(5);

        let engine_config = EngineConfig::new(1, 10)
            .sampling(sampling)
            .use_embeddings(true);

        let mut pipeline = PipelineBuilder::new()
            .engine_config(engine_config)
            .model(small_s4d())
            .build()
            .unwrap();

        let input = Array1::from_vec(vec![0.5]);
        let output = pipeline.forward(&input);

        assert!(output.is_ok());
    }
}