1use crate::tensor::Tensor;
7use crate::types::{ModelGraph, OptimizationLevel, ProviderId, SessionId};
8use anyhow::{Result, anyhow};
9use dashmap::DashMap;
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
15#[derive(Debug, Clone)]
17pub struct SessionConfig {
18 pub thread_count: Option<usize>,
20 pub memory_limit: Option<usize>,
22 pub optimization_level: OptimizationLevel,
24 pub preferred_providers: Vec<ProviderId>,
26 pub timeout_seconds: Option<u64>,
28 pub max_concurrent_inferences: Option<usize>,
30 pub enable_metrics: bool,
32 pub custom_options: HashMap<String, String>,
34}
35
36impl Default for SessionConfig {
37 fn default() -> Self {
38 Self {
39 thread_count: None,
40 memory_limit: None,
41 optimization_level: OptimizationLevel::Basic,
42 preferred_providers: vec![ProviderId::CPU],
43 timeout_seconds: Some(30),
44 max_concurrent_inferences: Some(10),
45 enable_metrics: true,
46 custom_options: HashMap::new(),
47 }
48 }
49}
50
/// Running counters describing the inferences executed by one session.
#[derive(Clone, Debug)]
pub struct SessionStatistics {
    /// Number of inferences that completed successfully.
    pub total_inferences: u64,
    /// Sum of the durations of all successful inferences, in milliseconds.
    pub total_inference_time_ms: u64,
    /// `total_inference_time_ms / total_inferences`; `0.0` until the first success.
    pub average_inference_time_ms: f64,
    /// Shortest successful inference so far, if any.
    pub min_inference_time_ms: Option<u64>,
    /// Longest successful inference so far, if any.
    pub max_inference_time_ms: Option<u64>,
    /// Peak memory figure. No code visible in this file updates it.
    pub peak_memory_bytes: usize,
    /// Current memory figure. No code visible in this file updates it.
    pub current_memory_bytes: usize,
    /// Number of inferences that returned an error.
    pub error_count: u64,
    /// Instant at which the owning session was created.
    pub created_at: Instant,
    /// Instant at which the most recent inference (successful or not) was recorded.
    pub last_inference_at: Option<Instant>,
}
75
76impl Default for SessionStatistics {
77 fn default() -> Self {
78 Self {
79 total_inferences: 0,
80 total_inference_time_ms: 0,
81 average_inference_time_ms: 0.0,
82 min_inference_time_ms: None,
83 max_inference_time_ms: None,
84 peak_memory_bytes: 0,
85 current_memory_bytes: 0,
86 error_count: 0,
87 created_at: Instant::now(),
88 last_inference_at: None,
89 }
90 }
91}
92
/// Snapshot of the resources one session is using.
///
/// The type is returned by value from the public
/// `InferenceSession::get_resource_usage`, so it is public with public fields,
/// like the sibling `SessionStatistics` and `GlobalStatistics` structs. A private
/// type in that position trips the private-in-public check.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default)]
pub struct ResourceUsage {
    /// Memory currently attributed to the session; compared against
    /// `SessionConfig::memory_limit` before each inference.
    pub current_memory: usize,
    /// Presumably the high-water mark of `current_memory`; no code visible in this
    /// file writes or reads it.
    // Kept for future accounting; exempted from the dead-code lint until it is used.
    #[allow(dead_code)]
    pub peak_memory: usize,
    /// Number of inferences currently in flight for the session.
    pub active_inferences: usize,
}
114
115#[derive(Debug)]
117pub struct InferenceSession {
118 pub id: SessionId,
120 pub model: Arc<ModelGraph>,
122 pub config: SessionConfig,
124 pub statistics: Arc<RwLock<SessionStatistics>>,
126 resource_usage: Arc<RwLock<ResourceUsage>>,
128 created_at: Instant,
130 marked_for_deletion: bool,
132}
133
134impl InferenceSession {
135 pub fn new(model: ModelGraph, config: SessionConfig) -> Self {
137 let id = SessionId::new_v4();
138 let created_at = Instant::now();
139
140 let mut statistics = SessionStatistics::default();
141 statistics.created_at = created_at;
142
143 Self {
144 id,
145 model: Arc::new(model),
146 config,
147 statistics: Arc::new(RwLock::new(statistics)),
148 resource_usage: Arc::new(RwLock::new(ResourceUsage::default())),
149 created_at,
150 marked_for_deletion: false,
151 }
152 }
153
154 pub async fn run_inference(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
156 let start_time = Instant::now();
157
158 self.check_resource_limits().await?;
160
161 {
163 let mut usage = self.resource_usage.write().await;
164 usage.active_inferences += 1;
165
166 if let Some(max_concurrent) = self.config.max_concurrent_inferences {
167 if usage.active_inferences > max_concurrent {
168 usage.active_inferences -= 1;
169 return Err(anyhow!("Max concurrent inferences exceeded"));
170 }
171 }
172 }
173
174 let result = self.execute_inference(inputs).await;
176
177 let inference_time = start_time.elapsed();
179 self.update_statistics(inference_time, result.is_ok()).await;
180
181 {
183 let mut usage = self.resource_usage.write().await;
184 usage.active_inferences = usage.active_inferences.saturating_sub(1);
185 }
186
187 result
188 }
189
190 async fn check_resource_limits(&self) -> Result<()> {
192 let usage = self.resource_usage.read().await;
193
194 if let Some(memory_limit) = self.config.memory_limit {
195 if usage.current_memory > memory_limit {
196 return Err(anyhow!(
197 "Memory limit exceeded: {} > {}",
198 usage.current_memory,
199 memory_limit
200 ));
201 }
202 }
203
204 if let Some(timeout_seconds) = self.config.timeout_seconds {
206 let timeout = Duration::from_secs(timeout_seconds);
207 if self.created_at.elapsed() > timeout {
208 return Err(anyhow!("Session timeout exceeded"));
209 }
210 }
211
212 Ok(())
213 }
214
215 async fn execute_inference(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
217 if inputs.len() != self.model.inputs.len() {
226 return Err(anyhow!(
227 "Input tensor count mismatch: expected {}, got {}",
228 self.model.inputs.len(),
229 inputs.len()
230 ));
231 }
232
233 tokio::time::sleep(Duration::from_millis(1)).await;
235
236 let outputs: Result<Vec<Tensor>> = self
238 .model
239 .outputs
240 .iter()
241 .enumerate()
242 .map(|(i, _output_name)| {
243 Tensor::ones(
245 vec![1, 10],
246 crate::types::DataType::F32,
247 crate::types::TensorLayout::RowMajor,
248 )
249 .map_err(|e| anyhow!("Failed to create output tensor {}: {}", i, e))
250 })
251 .collect();
252
253 outputs
254 }
255
256 async fn update_statistics(&self, inference_time: Duration, success: bool) {
258 let mut stats = self.statistics.write().await;
259
260 let inference_time_ms = inference_time.as_millis() as u64;
261
262 if success {
263 stats.total_inferences += 1;
264 stats.total_inference_time_ms += inference_time_ms;
265 stats.average_inference_time_ms =
266 stats.total_inference_time_ms as f64 / stats.total_inferences as f64;
267
268 stats.min_inference_time_ms = Some(
269 stats
270 .min_inference_time_ms
271 .map_or(inference_time_ms, |min| min.min(inference_time_ms)),
272 );
273
274 stats.max_inference_time_ms = Some(
275 stats
276 .max_inference_time_ms
277 .map_or(inference_time_ms, |max| max.max(inference_time_ms)),
278 );
279 } else {
280 stats.error_count += 1;
281 }
282
283 stats.last_inference_at = Some(Instant::now());
284 }
285
286 pub async fn get_statistics(&self) -> SessionStatistics {
288 self.statistics.read().await.clone()
289 }
290
291 pub async fn get_resource_usage(&self) -> ResourceUsage {
293 self.resource_usage.read().await.clone()
294 }
295
296 pub fn mark_for_deletion(&mut self) {
298 self.marked_for_deletion = true;
299 }
300
301 pub fn is_marked_for_deletion(&self) -> bool {
303 self.marked_for_deletion
304 }
305
306 pub fn age(&self) -> Duration {
308 self.created_at.elapsed()
309 }
310}
311
312#[derive(Debug)]
314pub struct SessionManager {
315 sessions: DashMap<SessionId, Arc<InferenceSession>>,
317 #[allow(dead_code)]
319 global_memory_limit: Option<usize>,
320 max_sessions: Option<usize>,
322 default_config: SessionConfig,
324}
325
326impl SessionManager {
327 pub fn new() -> Self {
329 Self {
330 sessions: DashMap::new(),
331 global_memory_limit: None,
332 max_sessions: Some(100),
333 default_config: SessionConfig::default(),
334 }
335 }
336
337 pub fn with_config(
339 global_memory_limit: Option<usize>,
340 max_sessions: Option<usize>,
341 default_config: SessionConfig,
342 ) -> Self {
343 Self {
344 sessions: DashMap::new(),
345 global_memory_limit,
346 max_sessions,
347 default_config,
348 }
349 }
350
351 pub async fn create_session(&self, model: ModelGraph) -> Result<SessionId> {
353 self.create_session_with_config(model, None).await
354 }
355
356 pub async fn create_session_with_config(
358 &self,
359 model: ModelGraph,
360 config: Option<SessionConfig>,
361 ) -> Result<SessionId> {
362 if let Some(max_sessions) = self.max_sessions {
364 if self.sessions.len() >= max_sessions {
365 self.cleanup_expired_sessions().await;
367
368 if self.sessions.len() >= max_sessions {
369 return Err(anyhow!(
370 "Maximum number of sessions reached: {}",
371 max_sessions
372 ));
373 }
374 }
375 }
376
377 model
379 .validate()
380 .map_err(|e| anyhow!("Invalid model graph: {}", e))?;
381
382 let session_config = config.unwrap_or_else(|| self.default_config.clone());
383 let session = Arc::new(InferenceSession::new(model, session_config));
384 let session_id = session.id;
385
386 self.sessions.insert(session_id, session);
387
388 tracing::info!(
389 "Created session {} with {} nodes",
390 session_id,
391 self.sessions.get(&session_id).unwrap().model.nodes.len()
392 );
393
394 Ok(session_id)
395 }
396
397 pub fn get_session(&self, session_id: SessionId) -> Option<Arc<InferenceSession>> {
399 self.sessions
400 .get(&session_id)
401 .map(|entry| entry.value().clone())
402 }
403
404 pub async fn run_inference(
406 &self,
407 session_id: SessionId,
408 inputs: Vec<Tensor>,
409 ) -> Result<Vec<Tensor>> {
410 let session = self
411 .get_session(session_id)
412 .ok_or_else(|| anyhow!("Session not found: {}", session_id))?;
413
414 if session.is_marked_for_deletion() {
415 return Err(anyhow!("Session is marked for deletion: {}", session_id));
416 }
417
418 session.run_inference(&inputs).await
419 }
420
421 pub async fn destroy_session(&self, session_id: SessionId) -> Result<()> {
423 if let Some((_, session)) = self.sessions.remove(&session_id) {
424 let timeout = Duration::from_secs(5);
426 let start = Instant::now();
427
428 while start.elapsed() < timeout {
429 let usage = session.get_resource_usage().await;
430 if usage.active_inferences == 0 {
431 break;
432 }
433 tokio::time::sleep(Duration::from_millis(100)).await;
434 }
435
436 tracing::info!("Destroyed session {}", session_id);
437 Ok(())
438 } else {
439 Err(anyhow!("Session not found: {}", session_id))
440 }
441 }
442
443 pub async fn get_session_statistics(&self, session_id: SessionId) -> Result<SessionStatistics> {
445 let session = self
446 .get_session(session_id)
447 .ok_or_else(|| anyhow!("Session not found: {}", session_id))?;
448
449 Ok(session.get_statistics().await)
450 }
451
452 pub fn list_sessions(&self) -> Vec<SessionId> {
454 self.sessions.iter().map(|entry| *entry.key()).collect()
455 }
456
457 pub fn session_count(&self) -> usize {
459 self.sessions.len()
460 }
461
462 pub async fn cleanup_expired_sessions(&self) -> usize {
464 let mut removed_count = 0;
465 let max_age = Duration::from_secs(3600); let expired_sessions: Vec<SessionId> = self
468 .sessions
469 .iter()
470 .filter_map(|entry| {
471 let session = entry.value();
472 if session.age() > max_age || session.is_marked_for_deletion() {
473 Some(*entry.key())
474 } else {
475 None
476 }
477 })
478 .collect();
479
480 for session_id in expired_sessions {
481 if self.destroy_session(session_id).await.is_ok() {
482 removed_count += 1;
483 }
484 }
485
486 if removed_count > 0 {
487 tracing::info!("Cleaned up {} expired sessions", removed_count);
488 }
489
490 removed_count
491 }
492
493 pub async fn get_global_statistics(&self) -> GlobalStatistics {
495 let mut global_stats = GlobalStatistics::default();
496
497 for entry in self.sessions.iter() {
498 let session = entry.value();
499 let stats = session.get_statistics().await;
500 let usage = session.get_resource_usage().await;
501
502 global_stats.total_sessions += 1;
503 global_stats.total_inferences += stats.total_inferences;
504 global_stats.total_errors += stats.error_count;
505 global_stats.total_memory_bytes += usage.current_memory;
506 global_stats.active_inferences += usage.active_inferences as u64;
507 }
508
509 global_stats
510 }
511
512 pub async fn shutdown(&self) -> Result<()> {
514 let session_ids: Vec<SessionId> = self.list_sessions();
515
516 tracing::info!(
517 "Shutting down session manager with {} active sessions",
518 session_ids.len()
519 );
520
521 for session_id in session_ids {
522 if let Err(e) = self.destroy_session(session_id).await {
523 tracing::warn!("Failed to destroy session {}: {}", session_id, e);
524 }
525 }
526
527 Ok(())
528 }
529}
530
531impl Default for SessionManager {
532 fn default() -> Self {
533 Self::new()
534 }
535}
536
/// Totals aggregated over every session registered with a `SessionManager`.
///
/// `Default` yields all-zero totals.
#[derive(Clone, Debug, Default)]
pub struct GlobalStatistics {
    /// Number of sessions included in the aggregation.
    pub total_sessions: usize,
    /// Sum of successful inferences across sessions.
    pub total_inferences: u64,
    /// Sum of failed inferences across sessions.
    pub total_errors: u64,
    /// Sum of each session's tracked current memory.
    pub total_memory_bytes: usize,
    /// Sum of inferences in flight across sessions.
    pub active_inferences: u64,
}
551
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use crate::types::{DataType, TensorLayout};

    /// Builds a two-node graph (`Input` feeding `Conv`) with one declared graph
    /// input and one declared graph output.
    fn create_test_graph() -> ModelGraph {
        let mut graph = GraphBuilder::new();

        let source = graph.add_op("Input", Some("input_layer".to_string()));
        graph.add_output(source, "input_tensor");

        let conv = graph.add_op("Conv", Some("conv_layer".to_string()));
        graph
            .add_input(conv, "input_tensor")
            .add_output(conv, "conv_output");

        graph.connect(source, conv, "input_tensor").unwrap();
        graph
            .set_inputs(vec!["input_tensor".to_string()])
            .set_outputs(vec!["conv_output".to_string()]);

        graph.build().unwrap()
    }

    /// A `[1, 3, 224, 224]` F32 tensor of ones, used as the single model input.
    fn sample_input() -> Result<Tensor> {
        Ok(Tensor::ones(
            vec![1, 3, 224, 224],
            DataType::F32,
            TensorLayout::RowMajor,
        )?)
    }

    /// A created session is retrievable and counted.
    #[tokio::test]
    async fn test_session_creation() -> Result<()> {
        let manager = SessionManager::new();

        let session_id = manager.create_session(create_test_graph()).await?;

        assert!(manager.get_session(session_id).is_some());
        assert_eq!(manager.session_count(), 1);

        Ok(())
    }

    /// One inference yields one output and shows up in the statistics.
    #[tokio::test]
    async fn test_session_inference() -> Result<()> {
        let manager = SessionManager::new();
        let session_id = manager.create_session(create_test_graph()).await?;

        let inputs = vec![sample_input()?];
        let outputs = manager.run_inference(session_id, inputs).await?;
        assert_eq!(outputs.len(), 1);

        let stats = manager.get_session_statistics(session_id).await?;
        assert_eq!(stats.total_inferences, 1);
        assert!(stats.average_inference_time_ms > 0.0);

        Ok(())
    }

    /// Destroying a session removes it from the manager.
    #[tokio::test]
    async fn test_session_destruction() -> Result<()> {
        let manager = SessionManager::new();

        let session_id = manager.create_session(create_test_graph()).await?;
        assert_eq!(manager.session_count(), 1);

        manager.destroy_session(session_id).await?;
        assert_eq!(manager.session_count(), 0);
        assert!(manager.get_session(session_id).is_none());

        Ok(())
    }

    /// With a cap of one session, the second creation is refused.
    #[tokio::test]
    async fn test_session_limits() -> Result<()> {
        let manager = SessionManager::with_config(None, Some(1), SessionConfig::default());

        let first_graph = create_test_graph();
        let second_graph = create_test_graph();

        let _first_session = manager.create_session(first_graph).await?;

        let second_attempt = manager.create_session(second_graph).await;
        assert!(second_attempt.is_err());

        Ok(())
    }

    /// Five simultaneous inferences against a limit of two: some succeed and some
    /// are rejected.
    #[tokio::test]
    async fn test_concurrent_inferences() -> Result<()> {
        let config = SessionConfig {
            max_concurrent_inferences: Some(2),
            ..SessionConfig::default()
        };

        let manager = Arc::new(SessionManager::with_config(None, None, config.clone()));
        let session_id = manager
            .create_session_with_config(create_test_graph(), Some(config))
            .await?;

        let input = sample_input()?;

        let mut handles = Vec::new();
        for _ in 0..5 {
            let manager = Arc::clone(&manager);
            let input = input.clone();
            handles.push(tokio::spawn(async move {
                manager.run_inference(session_id, vec![input]).await
            }));
        }

        let results: Vec<_> = futures::future::join_all(handles).await;

        // A failed join (panicked task) still aborts the test via `unwrap`.
        let outcomes: Vec<bool> = results
            .iter()
            .map(|joined| joined.as_ref().unwrap().is_ok())
            .collect();
        let successes = outcomes.iter().filter(|ok| **ok).count();
        let failures = outcomes.len() - successes;

        assert!(successes > 0);
        assert!(failures > 0);

        Ok(())
    }

    /// Global statistics add up sessions and inferences across the manager.
    #[tokio::test]
    async fn test_global_statistics() -> Result<()> {
        let manager = SessionManager::new();
        let graph = create_test_graph();

        let first_session = manager.create_session(graph.clone()).await?;
        let second_session = manager.create_session(graph).await?;

        let input = sample_input()?;

        manager
            .run_inference(first_session, vec![input.clone()])
            .await?;
        manager.run_inference(second_session, vec![input]).await?;

        let totals = manager.get_global_statistics().await;
        assert_eq!(totals.total_sessions, 2);
        assert_eq!(totals.total_inferences, 2);

        Ok(())
    }
}