ort_parallel/
async_pool.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::Arc,
4    time::{Duration, SystemTime, SystemTimeError},
5};
6
7use ort::session::{
8    Input, RunOptions, SelectedOutputMarker, Session, SessionInputs, SessionOutputs,
9    builder::{PrepackedWeights, SessionBuilder},
10};
11
12use tokio::sync::{Mutex, Semaphore, SemaphorePermit};
13
14use crate::SessionBuilderFactory;
15
16pub struct AsyncSessionPool {
17    sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
18    available_sessions: Arc<Mutex<Vec<usize>>>,
19    sem: Arc<Semaphore>,
20    max: usize,
21    builder: SessionBuilderFactory,
22    file: PathBuf,
23    pub inputs: Vec<Input>,
24}
25
26fn clone_inputs(inp: &Input) -> Input {
27    Input {
28        name: inp.name.clone(),
29        input_type: inp.input_type.clone(),
30    }
31}
32
33impl AsyncSessionPool {
34    pub fn commit_from_file(
35        builder: SessionBuilder,
36        path: &Path,
37        max_sessions: usize,
38    ) -> ort::Result<Self> {
39        assert!(max_sessions > 0);
40        let prepacked_weights = PrepackedWeights::new();
41        let builder = builder.with_prepacked_weights(&prepacked_weights)?;
42        let fs = builder.clone().commit_from_file(path)?;
43        let inputs = fs.inputs.iter().map(clone_inputs).collect::<Vec<Input>>();
44        Ok(Self {
45            inputs,
46            sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(fs))])),
47            sem: Arc::new(Semaphore::new(max_sessions)),
48            available_sessions: Arc::new(Mutex::new(vec![0])),
49            max: max_sessions,
50            builder: SessionBuilderFactory(builder),
51            file: path.to_path_buf(),
52        })
53    }
54
55    pub async fn load_all(&self) -> ort::Result<()> {
56        let count = self.max - self.sessions.lock().await.len();
57        let mut sessions = self.sessions.lock().await;
58        let mut avail = self.available_sessions.lock().await;
59        if count != 0 {
60            for _ in 0..count {
61                let session = self.create_new()?;
62                sessions.push(session.clone());
63                avail.push(sessions.len() - 1);
64            }
65        }
66        Ok(())
67    }
68
69    fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
70        Ok(Arc::new(Mutex::new(
71            self.builder.generate().commit_from_file(&self.file)?,
72        )))
73    }
74
75    async fn release_session(&self, idx: usize, permit: SemaphorePermit<'_>) {
76        self.available_sessions.lock().await.push(idx);
77        drop(permit);
78    }
79
80    async fn get_session(
81        &self,
82    ) -> Result<(Arc<Mutex<Session>>, usize, SemaphorePermit<'_>), ort::Error> {
83        let _permit = self.sem.acquire().await.unwrap();
84
85        if let Some(idx) = self.available_sessions.lock().await.pop() {
86            let sessions = self.sessions.lock().await;
87            return Ok((sessions[idx].clone(), idx, _permit));
88        }
89
90        if self.sessions.lock().await.len() < self.max {
91            let session = match self.create_new() {
92                Ok(v) => v,
93                Err(e) => {
94                    return Err(e);
95                }
96            };
97            let mut sessions = self.sessions.lock().await;
98            sessions.push(session.clone());
99            return Ok((session, sessions.len() - 1, _permit));
100        }
101
102        unreachable!()
103    }
104
105    pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 'r, O: SelectedOutputMarker, const N: usize>(
106        &'s self,
107        input_values: impl Into<SessionInputs<'i, 'v, N>>,
108        run_options: &'r RunOptions<O>,
109    ) -> Result<SessionOutputs<'r>, ort::Error> {
110        let (ses, idx, permit) = self.get_session().await?;
111        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock().await as *mut Session) };
112        let out = match ses.run_async(input_values, run_options) {
113            Ok(v) => v.await,
114            Err(e) => Err(e),
115        };
116        self.release_session(idx, permit).await;
117        out
118    }
119
120    pub async fn run_async_profiled<
121        'r,
122        's: 'r,
123        'i,
124        'v: 'i + 'r,
125        O: SelectedOutputMarker,
126        const N: usize,
127    >(
128        &'s self,
129        input_values: impl Into<SessionInputs<'i, 'v, N>>,
130        run_options: &'r RunOptions<O>,
131    ) -> Result<
132        (
133            SessionOutputs<'r>,
134            Result<Duration, SystemTimeError>,
135            String,
136        ),
137        ort::Error,
138    > {
139        let (ses, idx, permit) = self.get_session().await?;
140        let ses_ptr = &mut *ses.lock().await as *mut Session;
141        let ses1: &'s mut Session = unsafe { &mut *(ses_ptr) };
142        let ses2: &'s mut Session = unsafe { &mut *(ses_ptr) };
143
144        let earlier = SystemTime::now();
145        let _ = ses1.profiling_start_ns()?;
146        let out = match ses1.run_async(input_values, run_options) {
147            Ok(v) => v.await,
148            Err(e) => Err(e),
149        };
150        let end = ses2.end_profiling()?;
151        let now = SystemTime::now().duration_since(earlier);
152        self.release_session(idx, permit).await;
153        out.map(|v| (v, now, end))
154    }
155}