ort_parallel/
sync_pool.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::Arc,
4};
5
6use ort::{
7    io_binding::IoBinding,
8    session::{
9        NoSelectedOutputs, RunOptions, SelectedOutputMarker, Session, SessionInputs,
10        SessionOutputs,
11        builder::{PrepackedWeights, SessionBuilder},
12    },
13};
14use parking_lot::Mutex;
15
16use crate::semaphore::Semaphore;
17
18pub struct SessionPool {
19    sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
20    available_sessions: Arc<Mutex<Vec<usize>>>,
21    sem: Arc<Semaphore>,
22    max: usize,
23    builder: SessionBuilder,
24    file: PathBuf,
25}
26
27impl SessionPool {
28    pub fn commit_from_file(
29        builder: SessionBuilder,
30        path: &Path,
31        max_sessions: usize,
32    ) -> ort::Result<Self> {
33        assert!(max_sessions > 0);
34        let prepacked_weights = PrepackedWeights::new();
35        let builder = builder.with_prepacked_weights(&prepacked_weights)?;
36        Ok(Self {
37            sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(
38                builder.clone().commit_from_file(path)?,
39            ))])),
40            sem: Arc::new(Semaphore::new(max_sessions)),
41            available_sessions: Arc::new(Mutex::new(vec![0])),
42            max: max_sessions,
43            builder,
44            file: path.to_path_buf(),
45        })
46    }
47
48    pub fn load_all(&self) -> ort::Result<()> {
49        let count = self.max - self.sessions.lock().len();
50        let mut sessions = self.sessions.lock();
51        let mut avail = self.available_sessions.lock();
52        if count != 0 {
53            for _ in 0..count {
54                let session = self.create_new()?;
55                sessions.push(session.clone());
56                avail.push(sessions.len() - 1);
57            }
58        }
59        Ok(())
60    }
61
62    fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
63        Ok(Arc::new(Mutex::new(
64            self.builder.clone().commit_from_file(&self.file)?,
65        )))
66    }
67
68    fn release_session(&self, idx: usize) {
69        self.available_sessions.lock().push(idx);
70        self.sem.release();
71    }
72
73    fn get_session(&self) -> Result<(Arc<Mutex<Session>>, usize), ort::Error> {
74        let _permit = self.sem.acquire();
75
76        if let Some(idx) = self.available_sessions.lock().pop() {
77            let sessions = self.sessions.lock();
78            return Ok((sessions[idx].clone(), idx));
79        }
80
81        if self.sessions.lock().len() < self.max {
82            let session = match self.create_new() {
83                Ok(v) => v,
84                Err(e) => {
85                    self.sem.release();
86                    return Err(e);
87                }
88            };
89            let mut sessions = self.sessions.lock();
90            sessions.push(session.clone());
91            return Ok((session, sessions.len() - 1));
92        }
93
94        unreachable!()
95    }
96
97    pub fn run_binding<'b, 's: 'b>(
98        &'s self,
99        binding: &'b IoBinding,
100    ) -> ort::Result<SessionOutputs<'b>> {
101        let (ses, idx) = self.get_session()?;
102        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
103        let out = ses.run_binding(binding);
104        self.release_session(idx);
105        out
106    }
107
108    pub fn run_binding_with_options<'r, 'b, 's: 'b>(
109        &'s self,
110        binding: &'b IoBinding,
111        run_options: &'r RunOptions<NoSelectedOutputs>,
112    ) -> ort::Result<SessionOutputs<'b>> {
113        let (ses, idx) = self.get_session()?;
114        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
115        let out = ses.run_binding_with_options(binding, run_options);
116        self.release_session(idx);
117        out
118    }
119
120    pub fn run<'s, 'i, 'v: 'i, const N: usize>(
121        &'s self,
122        input_values: impl Into<SessionInputs<'i, 'v, N>>,
123    ) -> ort::Result<SessionOutputs<'s>> {
124        let (ses, idx) = self.get_session()?;
125        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
126        let out = ses.run(input_values);
127        self.release_session(idx);
128        out
129    }
130
131    pub fn run_with_options<
132        'r,
133        's: 'r,
134        'i,
135        'v: 'i + 'r,
136        O: SelectedOutputMarker,
137        const N: usize,
138    >(
139        &'s self,
140        input_values: impl Into<SessionInputs<'i, 'v, N>>,
141        run_options: &'r RunOptions<O>,
142    ) -> Result<SessionOutputs<'r>, ort::Error> {
143        let (ses, idx) = self.get_session()?;
144        let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
145        let out = ses.run_with_options(input_values, run_options);
146        self.release_session(idx);
147        out
148    }
149}