use crate::error::{Error, Result};
use crate::{InferenceSession, Model, SessionOptions};
use ronn_core::tensor::Tensor;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
pub struct AsyncSession {
inner: Arc<RwLock<InferenceSession>>,
}
impl AsyncSession {
pub async fn from_file(path: impl AsRef<std::path::Path> + Send + 'static) -> Result<Self> {
let session = tokio::task::spawn_blocking(move || {
let model = Model::load(path)?;
model.create_session_default()
})
.await
.map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))??;
Ok(Self {
inner: Arc::new(RwLock::new(session)),
})
}
pub async fn with_options(
path: impl AsRef<std::path::Path> + Send + 'static,
options: SessionOptions,
) -> Result<Self> {
let session = tokio::task::spawn_blocking(move || {
let model = Model::load(path)?;
model.create_session(options)
})
.await
.map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))??;
Ok(Self {
inner: Arc::new(RwLock::new(session)),
})
}
pub async fn run(&self, inputs: HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
let session_arc = Arc::clone(&self.inner);
tokio::task::spawn_blocking(move || {
let session = session_arc
.read()
.map_err(|e| Error::InferenceError(format!("Lock poisoned: {}", e)))?;
let inputs_ref: HashMap<&str, Tensor> = inputs
.iter()
.map(|(k, v)| (k.as_str(), v.clone()))
.collect();
session.run(inputs_ref)
})
.await
.map_err(|e| Error::InferenceError(format!("Task join error: {}", e)))?
}
pub async fn run_concurrent(
&self,
inputs: HashMap<String, Tensor>,
) -> Result<HashMap<String, Tensor>> {
self.run(inputs).await
}
pub fn clone_handle(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl Clone for AsyncSession {
fn clone(&self) -> Self {
self.clone_handle()
}
}
/// Front-end for running inference requests through an [`AsyncSession`].
///
/// NOTE: `infer` currently forwards each request straight to the session. `max_batch_size` and
/// `timeout_ms` are stored and exposed through getters, but no request grouping happens yet.
pub struct AsyncBatchProcessor {
    /// Session that executes every request.
    session: AsyncSession,
    /// Maximum batch size supplied to `new`; not currently used by `infer`.
    max_batch_size: usize,
    /// Batching timeout in milliseconds supplied to `new`; not currently used by `infer`.
    timeout_ms: u64,
}

impl AsyncBatchProcessor {
    /// Creates a processor over `session` with the given batching configuration.
    pub fn new(session: AsyncSession, max_batch_size: usize, timeout_ms: u64) -> Self {
        Self {
            session,
            max_batch_size,
            timeout_ms,
        }
    }

    /// Returns the session this processor runs requests on.
    pub fn session(&self) -> &AsyncSession {
        &self.session
    }

    /// Returns the configured maximum batch size.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Returns the configured batching timeout in milliseconds.
    pub fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }

    /// Returns the configured batching timeout as a [`std::time::Duration`].
    pub fn timeout(&self) -> std::time::Duration {
        std::time::Duration::from_millis(self.timeout_ms)
    }

    /// Runs inference for a single request by delegating to [`AsyncSession::run`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the session.
    pub async fn infer(&self, inputs: HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
        self.session.run(inputs).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loading a model path that does not exist must surface an error through the async
    /// wrapper (blocking task plus join-error handling) instead of silently succeeding.
    #[tokio::test]
    async fn test_async_session_creation() {
        let result =
            AsyncSession::from_file("nonexistent-ronn-test-dir/missing-model.onnx").await;
        // `AsyncSession` has no `Debug` impl, so `unwrap_err` is not available here.
        assert!(result.is_err(), "loading a missing model file should fail");
    }

    // TODO: add a small model fixture and issue several `run` calls on cloned handles at once.
    #[tokio::test]
    #[ignore = "requires a model fixture; not yet implemented"]
    async fn test_concurrent_inference() {}

    // TODO: add a model fixture and exercise `AsyncBatchProcessor::infer`.
    #[tokio::test]
    #[ignore = "requires a model fixture; not yet implemented"]
    async fn test_batch_processor() {}
}