ghostflow_nn/inference.rs

use ghostflow_core::{Result, Tensor, GhostError};
use std::collections::HashMap;

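/// Configuration for an inference session and its graph-level optimizations.
///
/// # Example
///
/// A minimal usage sketch (not part of the original source; the field values
/// are illustrative only):
///
/// ```ignore
/// let config = InferenceConfig {
///     batch_size: 8,
///     use_mixed_precision: true,
///     ..InferenceConfig::default()
/// };
/// assert_eq!(config.batch_size, 8);
/// ```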
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Fuse adjacent operators (e.g. Conv2d + BatchNorm + ReLU) into single ops.
    pub enable_fusion: bool,
    /// Pre-compute parts of the graph whose inputs are all constants.
    pub enable_constant_folding: bool,
    /// Number of samples grouped into one forward pass.
    pub batch_size: usize,
    /// Run eligible operators in reduced precision where supported.
    pub use_mixed_precision: bool,
    /// Number of worker threads used during inference.
    pub num_threads: usize,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            enable_fusion: true,
            enable_constant_folding: true,
            batch_size: 1,
            use_mixed_precision: false,
            num_threads: num_cpus::get(),
        }
    }
}

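/// Registers graph-level optimizations (operator fusion, constant folding)
/// according to an [`InferenceConfig`].
///
/// # Example
///
/// A minimal usage sketch (not part of the original source; items are assumed
/// to be in scope):
///
/// ```ignore
/// let mut optimizer = InferenceOptimizer::new(InferenceConfig::default());
/// optimizer.optimize().unwrap();
/// for fused in optimizer.get_fused_ops() {
///     println!("{} fuses {:?}", fused.name, fused.ops);
/// }
/// ```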
pub struct InferenceOptimizer {
    config: InferenceConfig,
    fused_ops: Vec<FusedOp>,
}

/// A named fusion pattern: the fused operator and the source ops it replaces.
#[derive(Debug, Clone)]
pub struct FusedOp {
    pub name: String,
    pub ops: Vec<String>,
}

impl InferenceOptimizer {
    pub fn new(config: InferenceConfig) -> Self {
        Self {
            config,
            fused_ops: Vec::new(),
        }
    }

    /// Runs the optimization passes enabled in the configuration.
    pub fn optimize(&mut self) -> Result<()> {
        if self.config.enable_fusion {
            self.fuse_operators()?;
        }

        if self.config.enable_constant_folding {
            self.fold_constants()?;
        }

        Ok(())
    }

    /// Registers the supported fusion patterns.
    fn fuse_operators(&mut self) -> Result<()> {
        self.fused_ops.push(FusedOp {
            name: "ConvBNReLU".to_string(),
            ops: vec!["Conv2d".to_string(), "BatchNorm".to_string(), "ReLU".to_string()],
        });

        self.fused_ops.push(FusedOp {
            name: "LinearReLU".to_string(),
            ops: vec!["Linear".to_string(), "ReLU".to_string()],
        });

        self.fused_ops.push(FusedOp {
            name: "GEMM".to_string(),
            ops: vec!["MatMul".to_string(), "Add".to_string()],
        });

        Ok(())
    }

    /// Constant folding is currently a no-op placeholder.
    fn fold_constants(&mut self) -> Result<()> {
        Ok(())
    }

    pub fn get_fused_ops(&self) -> &[FusedOp] {
        &self.fused_ops
    }
}

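/// Accumulates single samples and stacks them into batched tensors.
///
/// # Example
///
/// A minimal usage sketch (not part of the original source; items are assumed
/// to be in scope):
///
/// ```ignore
/// let mut batcher = BatchInference::new(2);
/// batcher.add(Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap());
/// batcher.add(Tensor::from_slice(&[3.0f32, 4.0], &[2]).unwrap());
/// if let Some(batch) = batcher.get_batch().unwrap() {
///     assert_eq!(batch.dims(), &[2, 2]); // leading batch dimension added
/// }
/// // Any leftover samples can be forced out with `flush()`.
/// ```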
pub struct BatchInference {
    batch_size: usize,
    buffer: Vec<Tensor>,
}

impl BatchInference {
    pub fn new(batch_size: usize) -> Self {
        Self {
            batch_size,
            buffer: Vec::new(),
        }
    }

    /// Adds a single sample to the pending batch.
    pub fn add(&mut self, sample: Tensor) {
        self.buffer.push(sample);
    }

    /// Returns true once enough samples have been buffered for a full batch.
    pub fn is_ready(&self) -> bool {
        self.buffer.len() >= self.batch_size
    }

    /// Returns a stacked batch if enough samples are buffered, clearing the buffer.
    pub fn get_batch(&mut self) -> Result<Option<Tensor>> {
        if !self.is_ready() {
            return Ok(None);
        }

        let batch = self.stack_tensors()?;
        self.buffer.clear();
        Ok(Some(batch))
    }

    /// Stacks whatever is buffered, even if it is smaller than a full batch.
    pub fn flush(&mut self) -> Result<Option<Tensor>> {
        if self.buffer.is_empty() {
            return Ok(None);
        }

        let batch = self.stack_tensors()?;
        self.buffer.clear();
        Ok(Some(batch))
    }

    /// Concatenates the buffered samples along a new leading batch dimension,
    /// so N samples of shape [d...] become one tensor of shape [N, d...].
    fn stack_tensors(&self) -> Result<Tensor> {
        if self.buffer.is_empty() {
            return Err(GhostError::InvalidShape("Empty buffer".to_string()));
        }

        let first_shape = self.buffer[0].dims();
        let batch_size = self.buffer.len();

        let mut new_shape = vec![batch_size];
        new_shape.extend_from_slice(first_shape);

        let mut all_data = Vec::new();
        for tensor in &self.buffer {
            all_data.extend(tensor.data_f32());
        }

        Tensor::from_slice(&all_data, &new_shape)
    }
}

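/// Holds the configuration, optimizer, and a tensor cache for a single
/// inference session.
///
/// # Example
///
/// A minimal usage sketch (not part of the original source; items are assumed
/// to be in scope, the cache key is illustrative, and `run` still returns
/// `NotImplemented` at this stage):
///
/// ```ignore
/// let mut session = InferenceSession::new(InferenceConfig::default());
/// session.initialize().unwrap();
///
/// let weights = Tensor::from_slice(&[0.5f32, 0.25], &[2]).unwrap();
/// session.cache_tensor("fc1.weight".to_string(), weights);
/// assert!(session.get_cached("fc1.weight").is_some());
/// ```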
pub struct InferenceSession {
    config: InferenceConfig,
    optimizer: InferenceOptimizer,
    cache: HashMap<String, Tensor>,
}

impl InferenceSession {
    pub fn new(config: InferenceConfig) -> Self {
        let optimizer = InferenceOptimizer::new(config.clone());
        Self {
            config,
            optimizer,
            cache: HashMap::new(),
        }
    }

    /// Runs the configured optimization passes before the first inference call.
    pub fn initialize(&mut self) -> Result<()> {
        self.optimizer.optimize()?;
        Ok(())
    }

    /// Runs inference on a single input. Not yet implemented.
    pub fn run(&mut self, _input: &Tensor) -> Result<Tensor> {
        Err(GhostError::NotImplemented("Inference execution not yet implemented".to_string()))
    }

    /// Runs inference on a slice of inputs. Not yet implemented.
    pub fn run_batch(&mut self, _inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        Err(GhostError::NotImplemented("Batch inference not yet implemented".to_string()))
    }

    /// Stores a tensor under a name for later reuse.
    pub fn cache_tensor(&mut self, name: String, tensor: Tensor) {
        self.cache.insert(name, tensor);
    }

    pub fn get_cached(&self, name: &str) -> Option<&Tensor> {
        self.cache.get(name)
    }

    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }

    pub fn config(&self) -> &InferenceConfig {
        &self.config
    }
}

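/// Runs `inference_fn` on a zero-filled dummy input: three untimed warm-up
/// iterations followed by `num_iterations` timed ones, returning the average
/// latency per iteration in milliseconds.
///
/// # Example
///
/// A minimal usage sketch (not part of the original source; the identity-style
/// closure below stands in for a real model):
///
/// ```ignore
/// let avg_ms = warmup_model(
///     |input: &Tensor| Tensor::from_slice(&input.data_f32(), input.dims()),
///     &[1, 3, 224, 224],
///     10,
/// ).unwrap();
/// println!("average latency: {avg_ms:.3} ms");
/// ```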
pub fn warmup_model<F>(mut inference_fn: F, input_shape: &[usize], num_iterations: usize) -> Result<f64>
where
    F: FnMut(&Tensor) -> Result<Tensor>,
{
    use std::time::Instant;

    let numel: usize = input_shape.iter().product();
    let dummy_data = vec![0.0f32; numel];
    let dummy_input = Tensor::from_slice(&dummy_data, input_shape)?;

    // Untimed warm-up passes.
    for _ in 0..3 {
        let _ = inference_fn(&dummy_input)?;
    }

    // Timed passes.
    let start = Instant::now();
    for _ in 0..num_iterations {
        let _ = inference_fn(&dummy_input)?;
    }
    let elapsed = start.elapsed();

    // Average latency per iteration, in milliseconds.
    Ok(elapsed.as_secs_f64() * 1000.0 / num_iterations as f64)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inference_config() {
        let config = InferenceConfig::default();
        assert!(config.enable_fusion);
        assert!(config.enable_constant_folding);
        assert_eq!(config.batch_size, 1);
    }

    #[test]
    fn test_inference_optimizer() {
        let config = InferenceConfig::default();
        let mut optimizer = InferenceOptimizer::new(config);

        optimizer.optimize().unwrap();

        let fused_ops = optimizer.get_fused_ops();
        assert!(!fused_ops.is_empty());
    }

    #[test]
    fn test_batch_inference() {
        let mut batch = BatchInference::new(2);

        let t1 = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_slice(&[3.0f32, 4.0], &[2]).unwrap();

        batch.add(t1);
        assert!(!batch.is_ready());

        batch.add(t2);
        assert!(batch.is_ready());

        let batched = batch.get_batch().unwrap().unwrap();
        assert_eq!(batched.dims(), &[2, 2]);
    }

    #[test]
    fn test_batch_flush() {
        let mut batch = BatchInference::new(3);

        let t1 = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        batch.add(t1);

        let flushed = batch.flush().unwrap().unwrap();
        assert_eq!(flushed.dims(), &[1, 2]);
    }

    #[test]
    fn test_inference_session() {
        let config = InferenceConfig::default();
        let mut session = InferenceSession::new(config);

        session.initialize().unwrap();

        let tensor = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        session.cache_tensor("test".to_string(), tensor);

        assert!(session.get_cached("test").is_some());

        session.clear_cache();
        assert!(session.get_cached("test").is_none());
    }
}