ort_parallel/
async_pool.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::Arc,
4    time::{Duration, SystemTime, SystemTimeError},
5};
6
7use ort::session::{
8    RunOptions, SelectedOutputMarker, Session, SessionInputs, SessionOutputs,
9    builder::{PrepackedWeights, SessionBuilder},
10};
11
12use tokio::sync::{Mutex, Semaphore};
13
14pub struct AsyncSessionPool {
15    sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
16    available_sessions: Arc<Mutex<Vec<usize>>>,
17    sem: Arc<Semaphore>,
18    max: usize,
19    builder: SessionBuilder,
20    file: PathBuf,
21}
22
23impl AsyncSessionPool {
24    pub fn commit_from_file(
25        builder: SessionBuilder,
26        path: &Path,
27        max_sessions: usize,
28    ) -> ort::Result<Self> {
29        assert!(max_sessions > 0);
30        let prepacked_weights = PrepackedWeights::new();
31        let builder = builder.with_prepacked_weights(&prepacked_weights)?;
32        Ok(Self {
33            sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(
34                builder.clone().commit_from_file(path)?,
35            ))])),
36            sem: Arc::new(Semaphore::new(max_sessions)),
37            available_sessions: Arc::new(Mutex::new(vec![0])),
38            max: max_sessions,
39            builder,
40            file: path.to_path_buf(),
41        })
42    }
43
44    pub async fn load_all(&self) -> ort::Result<()> {
45        let count = self.max - self.sessions.lock().await.len();
46        let mut sessions = self.sessions.lock().await;
47        let mut avail = self.available_sessions.lock().await;
48        if count != 0 {
49            for _ in 0..count {
50                let session = self.create_new()?;
51                sessions.push(session.clone());
52                avail.push(sessions.len() - 1);
53            }
54        }
55        Ok(())
56    }
57
58    fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
59        Ok(Arc::new(Mutex::new(
60            self.builder.clone().commit_from_file(&self.file)?,
61        )))
62    }
63
64    async fn release_session(&self, idx: usize) {
65        self.available_sessions.lock().await.push(idx);
66        self.sem.add_permits(1);
67    }
68
69    async fn get_session(&self) -> Result<(Arc<Mutex<Session>>, usize), ort::Error> {
70        let _permit = self.sem.acquire().await.unwrap();
71
72        if let Some(idx) = self.available_sessions.lock().await.pop() {
73            let sessions = self.sessions.lock().await;
74            return Ok((sessions[idx].clone(), idx));
75        }
76
77        if self.sessions.lock().await.len() < self.max {
78            let session = match self.create_new() {
79                Ok(v) => v,
80                Err(e) => {
81                    self.sem.add_permits(1);
82                    return Err(e);
83                }
84            };
85            let mut sessions = self.sessions.lock().await;
86            sessions.push(session.clone());
87            return Ok((session, sessions.len() - 1));
88        }
89
90        unreachable!()
91    }
92
93    pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 'r, O: SelectedOutputMarker, const N: usize>(
94        &'s self,
95        input_values: impl Into<SessionInputs<'i, 'v, N>>,
96        run_options: &'r RunOptions<O>,
97    ) -> Result<SessionOutputs<'r>, ort::Error> {
98        let (ses, idx) = self.get_session().await?;
99        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock().await as *mut Session) };
100        let out = match ses.run_async(input_values, run_options) {
101            Ok(v) => v.await,
102            Err(e) => Err(e),
103        };
104        self.release_session(idx).await;
105        out
106    }
107
108    pub async fn run_async_profiled<
109        'r,
110        's: 'r,
111        'i,
112        'v: 'i + 'r,
113        O: SelectedOutputMarker,
114        const N: usize,
115    >(
116        &'s self,
117        input_values: impl Into<SessionInputs<'i, 'v, N>>,
118        run_options: &'r RunOptions<O>,
119    ) -> Result<
120        (
121            SessionOutputs<'r>,
122            Result<Duration, SystemTimeError>,
123            String,
124        ),
125        ort::Error,
126    > {
127        let (ses, idx) = self.get_session().await?;
128        let ses_ptr = &mut *ses.lock().await as *mut Session;
129        let ses1: &'s mut Session = unsafe { &mut *(ses_ptr) };
130        let ses2: &'s mut Session = unsafe { &mut *(ses_ptr) };
131
132        let earlier = SystemTime::now();
133        let _ = ses1.profiling_start_ns()?;
134        let out = match ses1.run_async(input_values, run_options) {
135            Ok(v) => v.await,
136            Err(e) => Err(e),
137        };
138        let end = ses2.end_profiling()?;
139        let now = SystemTime::now().duration_since(earlier);
140        self.release_session(idx).await;
141        out.map(|v| (v, now, end))
142    }
143}