use crate::tensor::Tensor;
use crate::types::{ModelGraph, OptimizationLevel, ProviderId, SessionId};
use anyhow::{Result, anyhow};
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Per-session settings captured when an [`InferenceSession`] is created.
///
/// Only `memory_limit`, `timeout_seconds` and `max_concurrent_inferences` are
/// enforced by the session code in this module. NOTE(review): `thread_count`,
/// `optimization_level`, `preferred_providers`, `enable_metrics` and
/// `custom_options` are not read here — confirm where they are consumed.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Requested worker thread count; `None` = unspecified.
    pub thread_count: Option<usize>,
    /// Ceiling, in bytes, for the session's tracked memory; `None` disables the check.
    pub memory_limit: Option<usize>,
    /// Graph optimization level to apply.
    pub optimization_level: OptimizationLevel,
    /// Execution providers in order of preference.
    pub preferred_providers: Vec<ProviderId>,
    /// Maximum session age in seconds after which inference is refused;
    /// `None` disables the check.
    pub timeout_seconds: Option<u64>,
    /// Cap on simultaneously running inferences; `None` means unlimited.
    pub max_concurrent_inferences: Option<usize>,
    /// Whether metrics collection is requested.
    pub enable_metrics: bool,
    /// Free-form key/value options.
    pub custom_options: HashMap<String, String>,
}

impl Default for SessionConfig {
    /// Basic optimization on the CPU provider, a 30 s session timeout, at most
    /// 10 concurrent inferences, metrics enabled, and no thread or memory limits.
    fn default() -> Self {
        let preferred_providers = vec![ProviderId::CPU];
        Self {
            preferred_providers,
            optimization_level: OptimizationLevel::Basic,
            max_concurrent_inferences: Some(10),
            timeout_seconds: Some(30),
            enable_metrics: true,
            thread_count: None,
            memory_limit: None,
            custom_options: HashMap::default(),
        }
    }
}
/// Running counters describing the inference history of one session.
#[derive(Debug, Clone)]
pub struct SessionStatistics {
    /// Number of inferences that completed successfully.
    pub total_inferences: u64,
    /// Sum of successful inference durations, in whole milliseconds.
    pub total_inference_time_ms: u64,
    /// `total_inference_time_ms / total_inferences`; `0.0` before the first success.
    pub average_inference_time_ms: f64,
    /// Fastest successful inference in ms; `None` until the first success.
    pub min_inference_time_ms: Option<u64>,
    /// Slowest successful inference in ms; `None` until the first success.
    pub max_inference_time_ms: Option<u64>,
    /// Highest observed memory use in bytes. NOTE(review): not updated by the
    /// session code in this module.
    pub peak_memory_bytes: usize,
    /// Current memory use in bytes. NOTE(review): not updated by the session
    /// code in this module.
    pub current_memory_bytes: usize,
    /// Number of inferences that failed during execution.
    pub error_count: u64,
    /// When these statistics (normally: the owning session) were created.
    pub created_at: Instant,
    /// Time of the most recently recorded attempt, successful or not.
    pub last_inference_at: Option<Instant>,
}

impl Default for SessionStatistics {
    /// Zeroed counters with `created_at` stamped to the current instant.
    fn default() -> Self {
        let created_at = Instant::now();
        Self {
            created_at,
            last_inference_at: None,
            total_inferences: 0,
            total_inference_time_ms: 0,
            average_inference_time_ms: 0.0,
            min_inference_time_ms: None,
            max_inference_time_ms: None,
            current_memory_bytes: 0,
            peak_memory_bytes: 0,
            error_count: 0,
        }
    }
}
/// Per-session bookkeeping used for admission control before each inference.
///
/// Every counter starts at zero, so `Default` is derived rather than written by hand.
#[derive(Debug, Clone, Default)]
struct ResourceUsage {
    /// Tracked memory in bytes, compared against `SessionConfig::memory_limit`.
    /// NOTE(review): no code in this module increases it — confirm where it is updated.
    current_memory: usize,
    // Never read in this module; reserved for peak-memory tracking.
    #[allow(dead_code)]
    peak_memory: usize,
    /// Inferences currently admitted and not yet finished.
    active_inferences: usize,
}
/// One loaded model together with its configuration, run statistics and
/// in-flight resource bookkeeping.
///
/// Mutable state lives behind `tokio::sync::RwLock`s, so inference can be
/// driven through a shared reference (the manager in this module stores
/// sessions as `Arc<InferenceSession>`).
#[derive(Debug)]
pub struct InferenceSession {
    /// Identifier generated with `SessionId::new_v4()` at construction.
    pub id: SessionId,
    /// The model graph this session executes.
    pub model: Arc<ModelGraph>,
    /// Configuration captured when the session was created.
    pub config: SessionConfig,
    /// Run statistics, updated after every inference that passes admission.
    pub statistics: Arc<RwLock<SessionStatistics>>,
    /// Memory / concurrency counters used for admission control.
    resource_usage: Arc<RwLock<ResourceUsage>>,
    /// Construction time; the session's age is measured from here.
    created_at: Instant,
    /// Set by `mark_for_deletion`; the manager refuses to run marked sessions.
    marked_for_deletion: bool,
}

impl InferenceSession {
    /// Creates a new session for `model` using `config`.
    ///
    /// The statistics' `created_at` is set to the same instant as the session
    /// itself, so both report the same age.
    pub fn new(model: ModelGraph, config: SessionConfig) -> Self {
        let id = SessionId::new_v4();
        let created_at = Instant::now();
        let mut statistics = SessionStatistics::default();
        statistics.created_at = created_at;
        Self {
            id,
            model: Arc::new(model),
            config,
            statistics: Arc::new(RwLock::new(statistics)),
            resource_usage: Arc::new(RwLock::new(ResourceUsage::default())),
            created_at,
            marked_for_deletion: false,
        }
    }

    /// Runs one inference on `inputs`.
    ///
    /// # Errors
    /// - The session's memory limit or age timeout is exceeded (see
    ///   `check_resource_limits`).
    /// - Admitting this call would exceed `config.max_concurrent_inferences`.
    /// - The input count does not match the model, or output creation fails.
    ///
    /// Rejections from the first two checks are not recorded in the
    /// statistics. Failures during execution increment `error_count`.
    ///
    /// NOTE(review): the concurrency slot is released only after
    /// `execute_inference` returns. If this future is dropped mid-inference
    /// (e.g. by a caller-side timeout or `select!`), the slot is leaked.
    pub async fn run_inference(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        let start_time = Instant::now();
        self.check_resource_limits().await?;
        self.acquire_inference_slot().await?;
        let result = self.execute_inference(inputs).await;
        self.update_statistics(start_time.elapsed(), result.is_ok())
            .await;
        self.release_inference_slot().await;
        result
    }

    /// Reserves one concurrency slot, or fails if that would exceed
    /// `config.max_concurrent_inferences`.
    async fn acquire_inference_slot(&self) -> Result<()> {
        let mut usage = self.resource_usage.write().await;
        if let Some(max_concurrent) = self.config.max_concurrent_inferences {
            // Checked under the write lock, so check-and-increment is atomic.
            if usage.active_inferences >= max_concurrent {
                return Err(anyhow!("Max concurrent inferences exceeded"));
            }
        }
        usage.active_inferences += 1;
        Ok(())
    }

    /// Returns a slot taken by `acquire_inference_slot`.
    async fn release_inference_slot(&self) {
        let mut usage = self.resource_usage.write().await;
        usage.active_inferences = usage.active_inferences.saturating_sub(1);
    }

    /// Rejects the call if tracked memory exceeds `config.memory_limit`, or if
    /// the session's age exceeds `config.timeout_seconds`.
    ///
    /// NOTE(review): `timeout_seconds` limits the session's total lifetime
    /// (measured from `created_at`), not a single inference. With the default
    /// of 30 s, a session stops accepting work 30 s after creation. Confirm
    /// this is intended.
    async fn check_resource_limits(&self) -> Result<()> {
        let usage = self.resource_usage.read().await;
        if let Some(memory_limit) = self.config.memory_limit {
            if usage.current_memory > memory_limit {
                return Err(anyhow!(
                    "Memory limit exceeded: {} > {}",
                    usage.current_memory,
                    memory_limit
                ));
            }
        }
        if let Some(timeout_seconds) = self.config.timeout_seconds {
            let timeout = Duration::from_secs(timeout_seconds);
            if self.created_at.elapsed() > timeout {
                return Err(anyhow!("Session timeout exceeded"));
            }
        }
        Ok(())
    }

    /// Validates the input count, then produces one `[1, 10]` F32 tensor of
    /// ones per declared model output.
    ///
    /// The outputs do not depend on `inputs`; this looks like a stand-in for
    /// real execution.
    async fn execute_inference(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        if inputs.len() != self.model.inputs.len() {
            return Err(anyhow!(
                "Input tensor count mismatch: expected {}, got {}",
                self.model.inputs.len(),
                inputs.len()
            ));
        }
        // Simulated work; yields to the runtime for ~1 ms.
        tokio::time::sleep(Duration::from_millis(1)).await;
        self.model
            .outputs
            .iter()
            .enumerate()
            .map(|(i, _output_name)| {
                Tensor::ones(
                    vec![1, 10],
                    crate::types::DataType::F32,
                    crate::types::TensorLayout::RowMajor,
                )
                .map_err(|e| anyhow!("Failed to create output tensor {}: {}", i, e))
            })
            .collect()
    }

    /// Records one inference attempt in the session statistics.
    ///
    /// Successes update the count, total/average time and min/max; failures
    /// only bump `error_count`. Both stamp `last_inference_at`. Counters
    /// saturate instead of overflowing.
    async fn update_statistics(&self, inference_time: Duration, success: bool) {
        // `as_millis` yields u128: clamp rather than silently truncate.
        let inference_time_ms = inference_time.as_millis().min(u128::from(u64::MAX)) as u64;
        let mut stats = self.statistics.write().await;
        if success {
            stats.total_inferences = stats.total_inferences.saturating_add(1);
            stats.total_inference_time_ms = stats
                .total_inference_time_ms
                .saturating_add(inference_time_ms);
            stats.average_inference_time_ms =
                stats.total_inference_time_ms as f64 / stats.total_inferences as f64;
            stats.min_inference_time_ms = Some(
                stats
                    .min_inference_time_ms
                    .map_or(inference_time_ms, |min| min.min(inference_time_ms)),
            );
            stats.max_inference_time_ms = Some(
                stats
                    .max_inference_time_ms
                    .map_or(inference_time_ms, |max| max.max(inference_time_ms)),
            );
        } else {
            stats.error_count = stats.error_count.saturating_add(1);
        }
        stats.last_inference_at = Some(Instant::now());
    }

    /// Returns a snapshot copy of the session statistics.
    pub async fn get_statistics(&self) -> SessionStatistics {
        self.statistics.read().await.clone()
    }

    /// Returns a snapshot copy of the resource counters.
    ///
    /// NOTE(review): this is `pub` but returns the private `ResourceUsage`
    /// type, which triggers the `private_interfaces` lint. Consider
    /// `pub(crate)` or exposing the type.
    pub async fn get_resource_usage(&self) -> ResourceUsage {
        self.resource_usage.read().await.clone()
    }

    /// Flags the session so the manager refuses new inferences and cleanup removes it.
    pub fn mark_for_deletion(&mut self) {
        self.marked_for_deletion = true;
    }

    /// Whether `mark_for_deletion` has been called.
    pub fn is_marked_for_deletion(&self) -> bool {
        self.marked_for_deletion
    }

    /// Time elapsed since the session was constructed.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }
}
/// Registry of live [`InferenceSession`]s keyed by [`SessionId`].
///
/// Sessions are stored as `Arc`s in a `DashMap`. Lookups hand out cheap
/// clones that stay usable even after the session is removed from the map.
///
/// `DashMap` guards (from `get`/`iter`) hold a shard lock, so every method
/// below releases them before reaching an `.await`.
#[derive(Debug)]
pub struct SessionManager {
    sessions: DashMap<SessionId, Arc<InferenceSession>>,
    // Accepted by `with_config` but not enforced by any method below yet.
    #[allow(dead_code)]
    global_memory_limit: Option<usize>,
    /// Upper bound on live sessions; `None` means unlimited.
    max_sessions: Option<usize>,
    /// Configuration used when `create_session_with_config` receives `None`.
    default_config: SessionConfig,
}

impl SessionManager {
    /// Sessions older than this are removed by `cleanup_expired_sessions`.
    const MAX_SESSION_AGE: Duration = Duration::from_secs(3600);
    /// How long `destroy_session` waits for in-flight inferences to finish.
    const DRAIN_TIMEOUT: Duration = Duration::from_secs(5);
    /// Polling interval while `destroy_session` waits for inferences to drain.
    const DRAIN_POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Creates a manager with a 100-session cap, no global memory limit and
    /// the default [`SessionConfig`].
    pub fn new() -> Self {
        Self {
            sessions: DashMap::new(),
            global_memory_limit: None,
            max_sessions: Some(100),
            default_config: SessionConfig::default(),
        }
    }

    /// Creates a manager with explicit limits and a default session config.
    pub fn with_config(
        global_memory_limit: Option<usize>,
        max_sessions: Option<usize>,
        default_config: SessionConfig,
    ) -> Self {
        Self {
            sessions: DashMap::new(),
            global_memory_limit,
            max_sessions,
            default_config,
        }
    }

    /// Creates a session for `model` using the manager's default config.
    pub async fn create_session(&self, model: ModelGraph) -> Result<SessionId> {
        self.create_session_with_config(model, None).await
    }

    /// Validates `model` and registers a new session for it.
    ///
    /// If the session cap is reached, expired sessions are cleaned up first.
    /// That cleanup may wait for in-flight inferences to drain.
    ///
    /// # Errors
    /// Fails if the cap is still reached after cleanup, or if
    /// `model.validate()` fails.
    ///
    /// NOTE(review): the capacity check and the insert are not atomic, so
    /// concurrent callers can briefly exceed `max_sessions`.
    pub async fn create_session_with_config(
        &self,
        model: ModelGraph,
        config: Option<SessionConfig>,
    ) -> Result<SessionId> {
        if let Some(max_sessions) = self.max_sessions {
            if self.sessions.len() >= max_sessions {
                self.cleanup_expired_sessions().await;
                if self.sessions.len() >= max_sessions {
                    return Err(anyhow!(
                        "Maximum number of sessions reached: {}",
                        max_sessions
                    ));
                }
            }
        }
        model
            .validate()
            .map_err(|e| anyhow!("Invalid model graph: {}", e))?;
        let session_config = config.unwrap_or_else(|| self.default_config.clone());
        let session = Arc::new(InferenceSession::new(model, session_config));
        let session_id = session.id;
        // Read before inserting: once the session is in the map, another task
        // may destroy it, so re-fetching it for logging could panic.
        let node_count = session.model.nodes.len();
        self.sessions.insert(session_id, session);
        tracing::info!("Created session {} with {} nodes", session_id, node_count);
        Ok(session_id)
    }

    /// Returns a shared handle to the session, if it is registered.
    pub fn get_session(&self, session_id: SessionId) -> Option<Arc<InferenceSession>> {
        self.sessions
            .get(&session_id)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Runs an inference on the given session.
    ///
    /// # Errors
    /// Fails if the session is unknown, is marked for deletion, or the
    /// session's own `run_inference` fails.
    pub async fn run_inference(
        &self,
        session_id: SessionId,
        inputs: Vec<Tensor>,
    ) -> Result<Vec<Tensor>> {
        let session = self
            .get_session(session_id)
            .ok_or_else(|| anyhow!("Session not found: {}", session_id))?;
        if session.is_marked_for_deletion() {
            return Err(anyhow!("Session is marked for deletion: {}", session_id));
        }
        session.run_inference(&inputs).await
    }

    /// Removes the session, then waits up to `DRAIN_TIMEOUT` for its
    /// in-flight inferences to finish.
    ///
    /// Returns `Ok` even if the drain times out (a warning is logged);
    /// holders of an `Arc` to the session can still finish their work.
    ///
    /// # Errors
    /// Fails if no session with `session_id` is registered.
    pub async fn destroy_session(&self, session_id: SessionId) -> Result<()> {
        let Some((_, session)) = self.sessions.remove(&session_id) else {
            return Err(anyhow!("Session not found: {}", session_id));
        };
        let start = Instant::now();
        let mut drained = false;
        while start.elapsed() < Self::DRAIN_TIMEOUT {
            if session.get_resource_usage().await.active_inferences == 0 {
                drained = true;
                break;
            }
            tokio::time::sleep(Self::DRAIN_POLL_INTERVAL).await;
        }
        if !drained {
            tracing::warn!(
                "Session {} still had active inferences after {:?}",
                session_id,
                Self::DRAIN_TIMEOUT
            );
        }
        tracing::info!("Destroyed session {}", session_id);
        Ok(())
    }

    /// Returns a snapshot of the session's statistics.
    ///
    /// # Errors
    /// Fails if no session with `session_id` is registered.
    pub async fn get_session_statistics(&self, session_id: SessionId) -> Result<SessionStatistics> {
        let session = self
            .get_session(session_id)
            .ok_or_else(|| anyhow!("Session not found: {}", session_id))?;
        Ok(session.get_statistics().await)
    }

    /// Returns the ids of all registered sessions (in no particular order).
    pub fn list_sessions(&self) -> Vec<SessionId> {
        self.sessions.iter().map(|entry| *entry.key()).collect()
    }

    /// Number of registered sessions.
    pub fn session_count(&self) -> usize {
        self.sessions.len()
    }

    /// Destroys sessions older than `MAX_SESSION_AGE` or marked for deletion.
    ///
    /// Returns the number removed. Each removal may wait up to
    /// `DRAIN_TIMEOUT` for in-flight inferences.
    pub async fn cleanup_expired_sessions(&self) -> usize {
        // Collect ids first so no map guard is held across the awaits below.
        let expired_sessions: Vec<SessionId> = self
            .sessions
            .iter()
            .filter(|entry| {
                let session = entry.value();
                session.age() > Self::MAX_SESSION_AGE || session.is_marked_for_deletion()
            })
            .map(|entry| *entry.key())
            .collect();
        let mut removed_count = 0;
        for session_id in expired_sessions {
            if self.destroy_session(session_id).await.is_ok() {
                removed_count += 1;
            }
        }
        if removed_count > 0 {
            tracing::info!("Cleaned up {} expired sessions", removed_count);
        }
        removed_count
    }

    /// Aggregates statistics and resource usage across all sessions.
    pub async fn get_global_statistics(&self) -> GlobalStatistics {
        // Snapshot the handles first. Awaiting while iterating the DashMap
        // would hold shard read locks across `.await`. A concurrent
        // insert/remove on the same shard would then block its thread and
        // could deadlock.
        let sessions: Vec<Arc<InferenceSession>> = self
            .sessions
            .iter()
            .map(|entry| Arc::clone(entry.value()))
            .collect();
        let mut global_stats = GlobalStatistics::default();
        for session in sessions {
            let stats = session.get_statistics().await;
            let usage = session.get_resource_usage().await;
            global_stats.total_sessions += 1;
            global_stats.total_inferences += stats.total_inferences;
            global_stats.total_errors += stats.error_count;
            global_stats.total_memory_bytes += usage.current_memory;
            global_stats.active_inferences += usage.active_inferences as u64;
        }
        global_stats
    }

    /// Destroys every registered session, logging (not returning) per-session failures.
    pub async fn shutdown(&self) -> Result<()> {
        let session_ids: Vec<SessionId> = self.list_sessions();
        tracing::info!(
            "Shutting down session manager with {} active sessions",
            session_ids.len()
        );
        for session_id in session_ids {
            if let Err(e) = self.destroy_session(session_id).await {
                tracing::warn!("Failed to destroy session {}: {}", session_id, e);
            }
        }
        Ok(())
    }
}

impl Default for SessionManager {
    /// Same as [`SessionManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Totals aggregated over every session registered with a `SessionManager`.
#[derive(Default, Clone, Debug)]
pub struct GlobalStatistics {
    /// Number of sessions included in the aggregate.
    pub total_sessions: usize,
    /// Sum of each session's successful inference count.
    pub total_inferences: u64,
    /// Sum of each session's failed inference count.
    pub total_errors: u64,
    /// Sum of each session's tracked current memory, in bytes.
    pub total_memory_bytes: usize,
    /// Sum of inferences currently in flight across sessions.
    pub active_inferences: u64,
}
// Async unit tests for InferenceSession / SessionManager, run on tokio.
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::GraphBuilder;
use crate::types::{DataType, TensorLayout};
// Builds a minimal two-node graph: an "Input" op producing `input_tensor`,
// connected to a "Conv" op that consumes it and produces `conv_output`.
// The graph declares one input (`input_tensor`) and one output (`conv_output`).
fn create_test_graph() -> ModelGraph {
let mut builder = GraphBuilder::new();
let input_id = builder.add_op("Input", Some("input_layer".to_string()));
builder.add_output(input_id, "input_tensor");
let conv_id = builder.add_op("Conv", Some("conv_layer".to_string()));
builder
.add_input(conv_id, "input_tensor")
.add_output(conv_id, "conv_output");
builder.connect(input_id, conv_id, "input_tensor").unwrap();
builder
.set_inputs(vec!["input_tensor".to_string()])
.set_outputs(vec!["conv_output".to_string()]);
builder.build().unwrap()
}
// A freshly created session is retrievable and counted.
#[tokio::test]
async fn test_session_creation() -> Result<()> {
let manager = SessionManager::new();
let graph = create_test_graph();
let session_id = manager.create_session(graph).await?;
assert!(manager.get_session(session_id).is_some());
assert_eq!(manager.session_count(), 1);
Ok(())
}
// One inference yields one output (the graph declares one output) and is
// recorded in the statistics. The `average_inference_time_ms > 0.0` check
// relies on the ~1 ms sleep in `execute_inference`, because durations are
// recorded in whole milliseconds.
#[tokio::test]
async fn test_session_inference() -> Result<()> {
let manager = SessionManager::new();
let graph = create_test_graph();
let session_id = manager.create_session(graph).await?;
let input = Tensor::ones(vec![1, 3, 224, 224], DataType::F32, TensorLayout::RowMajor)?;
let inputs = vec![input];
let outputs = manager.run_inference(session_id, inputs).await?;
assert_eq!(outputs.len(), 1);
let stats = manager.get_session_statistics(session_id).await?;
assert_eq!(stats.total_inferences, 1);
assert!(stats.average_inference_time_ms > 0.0);
Ok(())
}
// Destroying a session removes it from the manager.
#[tokio::test]
async fn test_session_destruction() -> Result<()> {
let manager = SessionManager::new();
let graph = create_test_graph();
let session_id = manager.create_session(graph).await?;
assert_eq!(manager.session_count(), 1);
manager.destroy_session(session_id).await?;
assert_eq!(manager.session_count(), 0);
assert!(manager.get_session(session_id).is_none());
Ok(())
}
// With `max_sessions = Some(1)` the second creation fails: the first session
// is neither expired nor marked for deletion, so cleanup frees nothing.
#[tokio::test]
async fn test_session_limits() -> Result<()> {
let config = SessionConfig::default();
let manager = SessionManager::with_config(None, Some(1), config);
let graph1 = create_test_graph();
let graph2 = create_test_graph();
let _session_id1 = manager.create_session(graph1).await?;
let result = manager.create_session(graph2).await;
assert!(result.is_err());
Ok(())
}
// Spawns 5 concurrent inferences against a session limited to 2 in flight and
// expects both some successes and some rejections. `unwrap()` on each join
// result asserts that no spawned task panicked.
// NOTE(review): this depends on task scheduling. The spawned tasks must
// overlap while earlier ones sit in `execute_inference`'s 1 ms sleep. If they
// ran strictly one after another, `failures > 0` would not hold.
#[tokio::test]
async fn test_concurrent_inferences() -> Result<()> {
let mut config = SessionConfig::default();
config.max_concurrent_inferences = Some(2);
let manager = Arc::new(SessionManager::with_config(None, None, config.clone()));
let graph = create_test_graph();
let session_id = manager
.create_session_with_config(graph, Some(config))
.await?;
let input = Tensor::ones(vec![1, 3, 224, 224], DataType::F32, TensorLayout::RowMajor)?;
let handles: Vec<_> = (0..5)
.map(|_| {
let manager = Arc::clone(&manager);
let input = input.clone();
tokio::spawn(async move { manager.run_inference(session_id, vec![input]).await })
})
.collect();
let results: Vec<_> = futures::future::join_all(handles).await;
let successes = results
.iter()
.filter(|r| r.as_ref().unwrap().is_ok())
.count();
let failures = results
.iter()
.filter(|r| r.as_ref().unwrap().is_err())
.count();
assert!(successes > 0);
assert!(failures > 0);
Ok(())
}
// Two sessions with one inference each: global totals aggregate across both.
#[tokio::test]
async fn test_global_statistics() -> Result<()> {
let manager = SessionManager::new();
let graph = create_test_graph();
let session_id1 = manager.create_session(graph.clone()).await?;
let session_id2 = manager.create_session(graph).await?;
let input = Tensor::ones(vec![1, 3, 224, 224], DataType::F32, TensorLayout::RowMajor)?;
manager
.run_inference(session_id1, vec![input.clone()])
.await?;
manager.run_inference(session_id2, vec![input]).await?;
let global_stats = manager.get_global_statistics().await;
assert_eq!(global_stats.total_sessions, 2);
assert_eq!(global_stats.total_inferences, 2);
Ok(())
}
}