ort_parallel/
async_pool.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::Arc,
4    time::{Duration, SystemTime, SystemTimeError},
5};
6
7use ort::session::{
8    RunOptions, SelectedOutputMarker, Session, SessionInputs, SessionOutputs,
9    builder::{PrepackedWeights, SessionBuilder},
10};
11
12use tokio::sync::{Mutex, Semaphore};
13
14use crate::SessionBuilderFactory;
15
/// A pool of up to `max` `ort` sessions, all committed from the same model file, from which
/// async callers check out one session per run.
pub struct AsyncSessionPool {
    /// Every session created so far; a session is identified by its index in this vector.
    /// The code in this file only ever pushes to it, never removes entries.
    sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
    /// Indices into `sessions` that are currently idle and can be handed out.
    available_sessions: Arc<Mutex<Vec<usize>>>,
    /// Concurrency limiter for checkouts; constructed with `max` permits.
    sem: Arc<Semaphore>,
    /// Maximum number of sessions the pool is meant to hold.
    max: usize,
    /// Produces a fresh builder for each lazily created session.
    builder: SessionBuilderFactory,
    /// Model file every session is committed from.
    file: PathBuf,
}
24
25impl AsyncSessionPool {
26    pub fn commit_from_file(
27        builder: SessionBuilder,
28        path: &Path,
29        max_sessions: usize,
30    ) -> ort::Result<Self> {
31        assert!(max_sessions > 0);
32        let prepacked_weights = PrepackedWeights::new();
33        let builder = builder.with_prepacked_weights(&prepacked_weights)?;
34        Ok(Self {
35            sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(
36                builder.clone().commit_from_file(path)?,
37            ))])),
38            sem: Arc::new(Semaphore::new(max_sessions)),
39            available_sessions: Arc::new(Mutex::new(vec![0])),
40            max: max_sessions,
41            builder: SessionBuilderFactory(builder),
42            file: path.to_path_buf(),
43        })
44    }
45
46    pub async fn load_all(&self) -> ort::Result<()> {
47        let count = self.max - self.sessions.lock().await.len();
48        let mut sessions = self.sessions.lock().await;
49        let mut avail = self.available_sessions.lock().await;
50        if count != 0 {
51            for _ in 0..count {
52                let session = self.create_new()?;
53                sessions.push(session.clone());
54                avail.push(sessions.len() - 1);
55            }
56        }
57        Ok(())
58    }
59
60    fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
61        Ok(Arc::new(Mutex::new(
62            self.builder.generate().commit_from_file(&self.file)?,
63        )))
64    }
65
66    async fn release_session(&self, idx: usize) {
67        self.available_sessions.lock().await.push(idx);
68        self.sem.add_permits(1);
69    }
70
71    async fn get_session(&self) -> Result<(Arc<Mutex<Session>>, usize), ort::Error> {
72        let _permit = self.sem.acquire().await.unwrap();
73
74        if let Some(idx) = self.available_sessions.lock().await.pop() {
75            let sessions = self.sessions.lock().await;
76            return Ok((sessions[idx].clone(), idx));
77        }
78
79        if self.sessions.lock().await.len() < self.max {
80            let session = match self.create_new() {
81                Ok(v) => v,
82                Err(e) => {
83                    self.sem.add_permits(1);
84                    return Err(e);
85                }
86            };
87            let mut sessions = self.sessions.lock().await;
88            sessions.push(session.clone());
89            return Ok((session, sessions.len() - 1));
90        }
91
92        unreachable!()
93    }
94
95    pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 'r, O: SelectedOutputMarker, const N: usize>(
96        &'s self,
97        input_values: impl Into<SessionInputs<'i, 'v, N>>,
98        run_options: &'r RunOptions<O>,
99    ) -> Result<SessionOutputs<'r>, ort::Error> {
100        let (ses, idx) = self.get_session().await?;
101        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock().await as *mut Session) };
102        let out = match ses.run_async(input_values, run_options) {
103            Ok(v) => v.await,
104            Err(e) => Err(e),
105        };
106        self.release_session(idx).await;
107        out
108    }
109
110    pub async fn run_async_profiled<
111        'r,
112        's: 'r,
113        'i,
114        'v: 'i + 'r,
115        O: SelectedOutputMarker,
116        const N: usize,
117    >(
118        &'s self,
119        input_values: impl Into<SessionInputs<'i, 'v, N>>,
120        run_options: &'r RunOptions<O>,
121    ) -> Result<
122        (
123            SessionOutputs<'r>,
124            Result<Duration, SystemTimeError>,
125            String,
126        ),
127        ort::Error,
128    > {
129        let (ses, idx) = self.get_session().await?;
130        let ses_ptr = &mut *ses.lock().await as *mut Session;
131        let ses1: &'s mut Session = unsafe { &mut *(ses_ptr) };
132        let ses2: &'s mut Session = unsafe { &mut *(ses_ptr) };
133
134        let earlier = SystemTime::now();
135        let _ = ses1.profiling_start_ns()?;
136        let out = match ses1.run_async(input_values, run_options) {
137            Ok(v) => v.await,
138            Err(e) => Err(e),
139        };
140        let end = ses2.end_profiling()?;
141        let now = SystemTime::now().duration_since(earlier);
142        self.release_session(idx).await;
143        out.map(|v| (v, now, end))
144    }
145}