Skip to main content

ronn_core/
session.rs

1//! Session lifecycle management for inference contexts.
2//!
3//! This module provides thread-safe session management with resource isolation,
4//! configuration, and graceful cleanup.
5
6use crate::tensor::Tensor;
7use crate::types::{ModelGraph, OptimizationLevel, ProviderId, SessionId};
8use anyhow::{Result, anyhow};
9use dashmap::DashMap;
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
15/// Configuration for inference sessions.
16#[derive(Debug, Clone)]
17pub struct SessionConfig {
18    /// Number of worker threads for this session.
19    pub thread_count: Option<usize>,
20    /// Memory limit in bytes for this session.
21    pub memory_limit: Option<usize>,
22    /// Optimization level for model execution.
23    pub optimization_level: OptimizationLevel,
24    /// Preferred execution providers in priority order.
25    pub preferred_providers: Vec<ProviderId>,
26    /// Session timeout in seconds.
27    pub timeout_seconds: Option<u64>,
28    /// Maximum number of concurrent inferences.
29    pub max_concurrent_inferences: Option<usize>,
30    /// Enable performance metrics collection.
31    pub enable_metrics: bool,
32    /// Custom configuration options.
33    pub custom_options: HashMap<String, String>,
34}
35
36impl Default for SessionConfig {
37    fn default() -> Self {
38        Self {
39            thread_count: None,
40            memory_limit: None,
41            optimization_level: OptimizationLevel::Basic,
42            preferred_providers: vec![ProviderId::CPU],
43            timeout_seconds: Some(30),
44            max_concurrent_inferences: Some(10),
45            enable_metrics: true,
46            custom_options: HashMap::new(),
47        }
48    }
49}
50
/// Counters and timings accumulated over the lifetime of one session.
#[derive(Debug, Clone)]
pub struct SessionStatistics {
    /// Total number of inferences performed.
    pub total_inferences: u64,
    /// Sum of all inference durations, in milliseconds.
    pub total_inference_time_ms: u64,
    /// Mean inference duration, in milliseconds.
    pub average_inference_time_ms: f64,
    /// Shortest inference duration seen so far, in milliseconds.
    pub min_inference_time_ms: Option<u64>,
    /// Longest inference duration seen so far, in milliseconds.
    pub max_inference_time_ms: Option<u64>,
    /// Peak memory usage in bytes.
    pub peak_memory_bytes: usize,
    /// Current memory usage in bytes.
    pub current_memory_bytes: usize,
    /// Number of errors encountered.
    pub error_count: u64,
    /// When the session was created.
    pub created_at: Instant,
    /// When the most recent inference was recorded, if any.
    pub last_inference_at: Option<Instant>,
}

impl Default for SessionStatistics {
    /// Empty statistics: every counter is zero, the optional fields are
    /// `None`, and `created_at` is stamped with the current instant.
    fn default() -> Self {
        let now = Instant::now();

        SessionStatistics {
            created_at: now,
            last_inference_at: None,
            total_inferences: 0,
            error_count: 0,
            total_inference_time_ms: 0,
            average_inference_time_ms: 0.0,
            min_inference_time_ms: None,
            max_inference_time_ms: None,
            peak_memory_bytes: 0,
            current_memory_bytes: 0,
        }
    }
}
92
/// Resource usage tracking for sessions.
///
/// `Default` is derived: every counter starts at zero.
#[derive(Debug, Clone, Default)]
struct ResourceUsage {
    /// Current memory usage in bytes.
    current_memory: usize,
    /// Peak memory usage in bytes.
    // Not read by any code visible in this file, hence the lint exemption.
    #[allow(dead_code)]
    peak_memory: usize,
    /// Number of active inferences.
    active_inferences: usize,
}
114
115/// An active inference session.
116#[derive(Debug)]
117pub struct InferenceSession {
118    /// Unique session identifier.
119    pub id: SessionId,
120    /// The model graph for this session.
121    pub model: Arc<ModelGraph>,
122    /// Session configuration.
123    pub config: SessionConfig,
124    /// Runtime statistics.
125    pub statistics: Arc<RwLock<SessionStatistics>>,
126    /// Resource usage tracking.
127    resource_usage: Arc<RwLock<ResourceUsage>>,
128    /// Session creation time.
129    created_at: Instant,
130    /// Whether the session is marked for deletion.
131    marked_for_deletion: bool,
132}
133
134impl InferenceSession {
135    /// Create a new inference session.
136    pub fn new(model: ModelGraph, config: SessionConfig) -> Self {
137        let id = SessionId::new_v4();
138        let created_at = Instant::now();
139
140        let mut statistics = SessionStatistics::default();
141        statistics.created_at = created_at;
142
143        Self {
144            id,
145            model: Arc::new(model),
146            config,
147            statistics: Arc::new(RwLock::new(statistics)),
148            resource_usage: Arc::new(RwLock::new(ResourceUsage::default())),
149            created_at,
150            marked_for_deletion: false,
151        }
152    }
153
154    /// Run inference on the session.
155    pub async fn run_inference(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
156        let start_time = Instant::now();
157
158        // Check resource limits
159        self.check_resource_limits().await?;
160
161        // Increment active inference count
162        {
163            let mut usage = self.resource_usage.write().await;
164            usage.active_inferences += 1;
165
166            if let Some(max_concurrent) = self.config.max_concurrent_inferences {
167                if usage.active_inferences > max_concurrent {
168                    usage.active_inferences -= 1;
169                    return Err(anyhow!("Max concurrent inferences exceeded"));
170                }
171            }
172        }
173
174        // Simulate inference (in real implementation, this would use execution providers)
175        let result = self.execute_inference(inputs).await;
176
177        // Update statistics
178        let inference_time = start_time.elapsed();
179        self.update_statistics(inference_time, result.is_ok()).await;
180
181        // Decrement active inference count
182        {
183            let mut usage = self.resource_usage.write().await;
184            usage.active_inferences = usage.active_inferences.saturating_sub(1);
185        }
186
187        result
188    }
189
190    /// Check if resource limits are exceeded.
191    async fn check_resource_limits(&self) -> Result<()> {
192        let usage = self.resource_usage.read().await;
193
194        if let Some(memory_limit) = self.config.memory_limit {
195            if usage.current_memory > memory_limit {
196                return Err(anyhow!(
197                    "Memory limit exceeded: {} > {}",
198                    usage.current_memory,
199                    memory_limit
200                ));
201            }
202        }
203
204        // Check timeout
205        if let Some(timeout_seconds) = self.config.timeout_seconds {
206            let timeout = Duration::from_secs(timeout_seconds);
207            if self.created_at.elapsed() > timeout {
208                return Err(anyhow!("Session timeout exceeded"));
209            }
210        }
211
212        Ok(())
213    }
214
215    /// Execute the actual inference (placeholder implementation).
216    async fn execute_inference(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
217        // This is a placeholder implementation
218        // In the real implementation, this would:
219        // 1. Select appropriate execution provider
220        // 2. Compile the model graph into executable kernels
221        // 3. Execute the kernels with the given inputs
222        // 4. Return the results
223
224        // For now, just validate inputs match expected graph inputs
225        if inputs.len() != self.model.inputs.len() {
226            return Err(anyhow!(
227                "Input tensor count mismatch: expected {}, got {}",
228                self.model.inputs.len(),
229                inputs.len()
230            ));
231        }
232
233        // Simulate some work
234        tokio::time::sleep(Duration::from_millis(1)).await;
235
236        // Create dummy outputs based on graph outputs
237        let outputs: Result<Vec<Tensor>> = self
238            .model
239            .outputs
240            .iter()
241            .enumerate()
242            .map(|(i, _output_name)| {
243                // Create a small output tensor as placeholder
244                Tensor::ones(
245                    vec![1, 10],
246                    crate::types::DataType::F32,
247                    crate::types::TensorLayout::RowMajor,
248                )
249                .map_err(|e| anyhow!("Failed to create output tensor {}: {}", i, e))
250            })
251            .collect();
252
253        outputs
254    }
255
256    /// Update session statistics.
257    async fn update_statistics(&self, inference_time: Duration, success: bool) {
258        let mut stats = self.statistics.write().await;
259
260        let inference_time_ms = inference_time.as_millis() as u64;
261
262        if success {
263            stats.total_inferences += 1;
264            stats.total_inference_time_ms += inference_time_ms;
265            stats.average_inference_time_ms =
266                stats.total_inference_time_ms as f64 / stats.total_inferences as f64;
267
268            stats.min_inference_time_ms = Some(
269                stats
270                    .min_inference_time_ms
271                    .map_or(inference_time_ms, |min| min.min(inference_time_ms)),
272            );
273
274            stats.max_inference_time_ms = Some(
275                stats
276                    .max_inference_time_ms
277                    .map_or(inference_time_ms, |max| max.max(inference_time_ms)),
278            );
279        } else {
280            stats.error_count += 1;
281        }
282
283        stats.last_inference_at = Some(Instant::now());
284    }
285
286    /// Get session statistics.
287    pub async fn get_statistics(&self) -> SessionStatistics {
288        self.statistics.read().await.clone()
289    }
290
291    /// Get current resource usage.
292    pub async fn get_resource_usage(&self) -> ResourceUsage {
293        self.resource_usage.read().await.clone()
294    }
295
296    /// Mark session for deletion.
297    pub fn mark_for_deletion(&mut self) {
298        self.marked_for_deletion = true;
299    }
300
301    /// Check if session is marked for deletion.
302    pub fn is_marked_for_deletion(&self) -> bool {
303        self.marked_for_deletion
304    }
305
306    /// Get session age.
307    pub fn age(&self) -> Duration {
308        self.created_at.elapsed()
309    }
310}
311
312/// Thread-safe session manager for managing inference sessions.
313#[derive(Debug)]
314pub struct SessionManager {
315    /// Active sessions storage.
316    sessions: DashMap<SessionId, Arc<InferenceSession>>,
317    /// Global resource limits.
318    #[allow(dead_code)]
319    global_memory_limit: Option<usize>,
320    /// Maximum number of concurrent sessions.
321    max_sessions: Option<usize>,
322    /// Default session configuration.
323    default_config: SessionConfig,
324}
325
326impl SessionManager {
327    /// Create a new session manager.
328    pub fn new() -> Self {
329        Self {
330            sessions: DashMap::new(),
331            global_memory_limit: None,
332            max_sessions: Some(100),
333            default_config: SessionConfig::default(),
334        }
335    }
336
337    /// Create a new session manager with configuration.
338    pub fn with_config(
339        global_memory_limit: Option<usize>,
340        max_sessions: Option<usize>,
341        default_config: SessionConfig,
342    ) -> Self {
343        Self {
344            sessions: DashMap::new(),
345            global_memory_limit,
346            max_sessions,
347            default_config,
348        }
349    }
350
351    /// Create a new inference session.
352    pub async fn create_session(&self, model: ModelGraph) -> Result<SessionId> {
353        self.create_session_with_config(model, None).await
354    }
355
356    /// Create a new inference session with custom configuration.
357    pub async fn create_session_with_config(
358        &self,
359        model: ModelGraph,
360        config: Option<SessionConfig>,
361    ) -> Result<SessionId> {
362        // Check session limits
363        if let Some(max_sessions) = self.max_sessions {
364            if self.sessions.len() >= max_sessions {
365                // Try cleanup first
366                self.cleanup_expired_sessions().await;
367
368                if self.sessions.len() >= max_sessions {
369                    return Err(anyhow!(
370                        "Maximum number of sessions reached: {}",
371                        max_sessions
372                    ));
373                }
374            }
375        }
376
377        // Validate model
378        model
379            .validate()
380            .map_err(|e| anyhow!("Invalid model graph: {}", e))?;
381
382        let session_config = config.unwrap_or_else(|| self.default_config.clone());
383        let session = Arc::new(InferenceSession::new(model, session_config));
384        let session_id = session.id;
385
386        self.sessions.insert(session_id, session);
387
388        tracing::info!(
389            "Created session {} with {} nodes",
390            session_id,
391            self.sessions.get(&session_id).unwrap().model.nodes.len()
392        );
393
394        Ok(session_id)
395    }
396
397    /// Get a session by ID.
398    pub fn get_session(&self, session_id: SessionId) -> Option<Arc<InferenceSession>> {
399        self.sessions
400            .get(&session_id)
401            .map(|entry| entry.value().clone())
402    }
403
404    /// Run inference on a session.
405    pub async fn run_inference(
406        &self,
407        session_id: SessionId,
408        inputs: Vec<Tensor>,
409    ) -> Result<Vec<Tensor>> {
410        let session = self
411            .get_session(session_id)
412            .ok_or_else(|| anyhow!("Session not found: {}", session_id))?;
413
414        if session.is_marked_for_deletion() {
415            return Err(anyhow!("Session is marked for deletion: {}", session_id));
416        }
417
418        session.run_inference(&inputs).await
419    }
420
421    /// Destroy a session.
422    pub async fn destroy_session(&self, session_id: SessionId) -> Result<()> {
423        if let Some((_, session)) = self.sessions.remove(&session_id) {
424            // Wait for any ongoing inferences to complete
425            let timeout = Duration::from_secs(5);
426            let start = Instant::now();
427
428            while start.elapsed() < timeout {
429                let usage = session.get_resource_usage().await;
430                if usage.active_inferences == 0 {
431                    break;
432                }
433                tokio::time::sleep(Duration::from_millis(100)).await;
434            }
435
436            tracing::info!("Destroyed session {}", session_id);
437            Ok(())
438        } else {
439            Err(anyhow!("Session not found: {}", session_id))
440        }
441    }
442
443    /// Get session statistics.
444    pub async fn get_session_statistics(&self, session_id: SessionId) -> Result<SessionStatistics> {
445        let session = self
446            .get_session(session_id)
447            .ok_or_else(|| anyhow!("Session not found: {}", session_id))?;
448
449        Ok(session.get_statistics().await)
450    }
451
452    /// List all active session IDs.
453    pub fn list_sessions(&self) -> Vec<SessionId> {
454        self.sessions.iter().map(|entry| *entry.key()).collect()
455    }
456
457    /// Get the number of active sessions.
458    pub fn session_count(&self) -> usize {
459        self.sessions.len()
460    }
461
462    /// Cleanup expired sessions.
463    pub async fn cleanup_expired_sessions(&self) -> usize {
464        let mut removed_count = 0;
465        let max_age = Duration::from_secs(3600); // 1 hour
466
467        let expired_sessions: Vec<SessionId> = self
468            .sessions
469            .iter()
470            .filter_map(|entry| {
471                let session = entry.value();
472                if session.age() > max_age || session.is_marked_for_deletion() {
473                    Some(*entry.key())
474                } else {
475                    None
476                }
477            })
478            .collect();
479
480        for session_id in expired_sessions {
481            if self.destroy_session(session_id).await.is_ok() {
482                removed_count += 1;
483            }
484        }
485
486        if removed_count > 0 {
487            tracing::info!("Cleaned up {} expired sessions", removed_count);
488        }
489
490        removed_count
491    }
492
493    /// Get global statistics across all sessions.
494    pub async fn get_global_statistics(&self) -> GlobalStatistics {
495        let mut global_stats = GlobalStatistics::default();
496
497        for entry in self.sessions.iter() {
498            let session = entry.value();
499            let stats = session.get_statistics().await;
500            let usage = session.get_resource_usage().await;
501
502            global_stats.total_sessions += 1;
503            global_stats.total_inferences += stats.total_inferences;
504            global_stats.total_errors += stats.error_count;
505            global_stats.total_memory_bytes += usage.current_memory;
506            global_stats.active_inferences += usage.active_inferences as u64;
507        }
508
509        global_stats
510    }
511
512    /// Shutdown the session manager and cleanup all sessions.
513    pub async fn shutdown(&self) -> Result<()> {
514        let session_ids: Vec<SessionId> = self.list_sessions();
515
516        tracing::info!(
517            "Shutting down session manager with {} active sessions",
518            session_ids.len()
519        );
520
521        for session_id in session_ids {
522            if let Err(e) = self.destroy_session(session_id).await {
523                tracing::warn!("Failed to destroy session {}: {}", session_id, e);
524            }
525        }
526
527        Ok(())
528    }
529}
530
531impl Default for SessionManager {
532    fn default() -> Self {
533        Self::new()
534    }
535}
536
/// Global statistics across all sessions.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GlobalStatistics {
    /// Total number of active sessions.
    pub total_sessions: usize,
    /// Total inferences across all sessions.
    pub total_inferences: u64,
    /// Total errors across all sessions.
    pub total_errors: u64,
    /// Total memory usage across all sessions, in bytes.
    pub total_memory_bytes: usize,
    /// Total active inferences across all sessions.
    pub active_inferences: u64,
}
551
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use crate::types::{DataType, TensorLayout};

    /// Builds a small graph: an `Input` op feeding a `Conv` op, with one
    /// graph input (`input_tensor`) and one graph output (`conv_output`).
    fn create_test_graph() -> ModelGraph {
        let mut builder = GraphBuilder::new();

        let input_id = builder.add_op("Input", Some(String::from("input_layer")));
        builder.add_output(input_id, "input_tensor");

        let conv_id = builder.add_op("Conv", Some(String::from("conv_layer")));
        builder
            .add_input(conv_id, "input_tensor")
            .add_output(conv_id, "conv_output");

        builder.connect(input_id, conv_id, "input_tensor").unwrap();
        builder
            .set_inputs(vec![String::from("input_tensor")])
            .set_outputs(vec![String::from("conv_output")]);

        builder.build().unwrap()
    }

    /// A `[1, 3, 224, 224]` F32 tensor of ones used as inference input.
    fn sample_input() -> Result<Tensor> {
        let tensor = Tensor::ones(vec![1, 3, 224, 224], DataType::F32, TensorLayout::RowMajor)?;
        Ok(tensor)
    }

    // A freshly created session is retrievable and counted.
    #[tokio::test]
    async fn test_session_creation() -> Result<()> {
        let manager = SessionManager::new();

        let session_id = manager.create_session(create_test_graph()).await?;

        assert!(manager.get_session(session_id).is_some());
        assert_eq!(manager.session_count(), 1);
        Ok(())
    }

    // One inference yields one output and shows up in the statistics.
    #[tokio::test]
    async fn test_session_inference() -> Result<()> {
        let manager = SessionManager::new();
        let session_id = manager.create_session(create_test_graph()).await?;

        let inputs = vec![sample_input()?];
        let outputs = manager.run_inference(session_id, inputs).await?;
        assert_eq!(outputs.len(), 1);

        let stats = manager.get_session_statistics(session_id).await?;
        assert_eq!(stats.total_inferences, 1);
        assert!(stats.average_inference_time_ms > 0.0);
        Ok(())
    }

    // Destroying a session removes it from the manager.
    #[tokio::test]
    async fn test_session_destruction() -> Result<()> {
        let manager = SessionManager::new();
        let session_id = manager.create_session(create_test_graph()).await?;
        assert_eq!(manager.session_count(), 1);

        manager.destroy_session(session_id).await?;

        assert_eq!(manager.session_count(), 0);
        assert!(manager.get_session(session_id).is_none());
        Ok(())
    }

    // With a cap of one session, the second creation is rejected.
    #[tokio::test]
    async fn test_session_limits() -> Result<()> {
        let manager = SessionManager::with_config(None, Some(1), SessionConfig::default());

        // First session should succeed
        let _first = manager.create_session(create_test_graph()).await?;

        // Second session should fail due to limit
        let second = manager.create_session(create_test_graph()).await;
        assert!(second.is_err());
        Ok(())
    }

    // Five simultaneous inferences against a cap of two: a mix of outcomes.
    #[tokio::test]
    async fn test_concurrent_inferences() -> Result<()> {
        let config = SessionConfig {
            max_concurrent_inferences: Some(2),
            ..SessionConfig::default()
        };

        let manager = Arc::new(SessionManager::with_config(None, None, config.clone()));
        let session_id = manager
            .create_session_with_config(create_test_graph(), Some(config))
            .await?;

        let input = sample_input()?;

        // Launch multiple concurrent inferences
        let mut handles = Vec::new();
        for _ in 0..5 {
            let manager = Arc::clone(&manager);
            let input = input.clone();
            handles.push(tokio::spawn(async move {
                manager.run_inference(session_id, vec![input]).await
            }));
        }

        let results: Vec<_> = futures::future::join_all(handles).await;

        // Some should succeed, some should fail due to concurrency limit
        let outcomes: Vec<bool> = results
            .iter()
            .map(|joined| joined.as_ref().unwrap().is_ok())
            .collect();
        let successes = outcomes.iter().filter(|ok| **ok).count();
        let failures = outcomes.len() - successes;

        assert!(successes > 0);
        assert!(failures > 0);
        Ok(())
    }

    // Global statistics add up sessions and inferences across the manager.
    #[tokio::test]
    async fn test_global_statistics() -> Result<()> {
        let manager = SessionManager::new();
        let graph = create_test_graph();

        let session_id1 = manager.create_session(graph.clone()).await?;
        let session_id2 = manager.create_session(graph).await?;

        let input = sample_input()?;

        // Run inference on both sessions
        manager
            .run_inference(session_id1, vec![input.clone()])
            .await?;
        manager.run_inference(session_id2, vec![input]).await?;

        let global_stats = manager.get_global_statistics().await;
        assert_eq!(global_stats.total_sessions, 2);
        assert_eq!(global_stats.total_inferences, 2);
        Ok(())
    }
}