ort_parallel/
sync_pool.rs1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4};
5
6use ort::{
7 io_binding::IoBinding,
8 session::{
9 NoSelectedOutputs, RunOptions, SelectedOutputMarker, Session, SessionInputs,
10 SessionOutputs,
11 builder::{PrepackedWeights, SessionBuilder},
12 },
13};
14use parking_lot::Mutex;
15
16use crate::{SessionBuilderFactory, semaphore::Semaphore};
17
18pub struct SessionPool {
19 sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
20 available_sessions: Arc<Mutex<Vec<usize>>>,
21 sem: Arc<Semaphore>,
22 max: usize,
23 builder: SessionBuilderFactory,
24 file: PathBuf,
25}
26
27impl SessionPool {
28 pub fn commit_from_file(
29 builder: SessionBuilder,
30 path: &Path,
31 max_sessions: usize,
32 ) -> ort::Result<Self> {
33 assert!(max_sessions > 0);
34 let prepacked_weights = PrepackedWeights::new();
35 let builder = builder.with_prepacked_weights(&prepacked_weights)?;
36 Ok(Self {
37 sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(
38 builder.clone().commit_from_file(path)?,
39 ))])),
40 sem: Arc::new(Semaphore::new(max_sessions)),
41 available_sessions: Arc::new(Mutex::new(vec![0])),
42 max: max_sessions,
43 builder: SessionBuilderFactory(builder),
44 file: path.to_path_buf(),
45 })
46 }
47
48 pub fn load_all(&self) -> ort::Result<()> {
49 let count = self.max - self.sessions.lock().len();
50 let mut sessions = self.sessions.lock();
51 let mut avail = self.available_sessions.lock();
52 if count != 0 {
53 for _ in 0..count {
54 let session = self.create_new()?;
55 sessions.push(session.clone());
56 avail.push(sessions.len() - 1);
57 }
58 }
59 Ok(())
60 }
61
62 fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
63 Ok(Arc::new(Mutex::new(
64 self.builder.generate().commit_from_file(&self.file)?,
65 )))
66 }
67
68 fn release_session(&self, idx: usize) {
69 self.available_sessions.lock().push(idx);
70 self.sem.release();
71 }
72
73 fn get_session(&self) -> Result<(Arc<Mutex<Session>>, usize), ort::Error> {
74 let _permit = self.sem.acquire();
75
76 if let Some(idx) = self.available_sessions.lock().pop() {
77 let sessions = self.sessions.lock();
78 return Ok((sessions[idx].clone(), idx));
79 }
80
81 if self.sessions.lock().len() < self.max {
82 let session = match self.create_new() {
83 Ok(v) => v,
84 Err(e) => {
85 self.sem.release();
86 return Err(e);
87 }
88 };
89 let mut sessions = self.sessions.lock();
90 sessions.push(session.clone());
91 return Ok((session, sessions.len() - 1));
92 }
93
94 unreachable!()
95 }
96
97 pub fn run_binding<'b, 's: 'b>(
98 &'s self,
99 binding: &'b IoBinding,
100 ) -> ort::Result<SessionOutputs<'b>> {
101 let (ses, idx) = self.get_session()?;
102 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
103 let out = ses.run_binding(binding);
104 self.release_session(idx);
105 out
106 }
107
108 pub fn run_binding_with_options<'r, 'b, 's: 'b>(
109 &'s self,
110 binding: &'b IoBinding,
111 run_options: &'r RunOptions<NoSelectedOutputs>,
112 ) -> ort::Result<SessionOutputs<'b>> {
113 let (ses, idx) = self.get_session()?;
114 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
115 let out = ses.run_binding_with_options(binding, run_options);
116 self.release_session(idx);
117 out
118 }
119
120 pub fn run<'s, 'i, 'v: 'i, const N: usize>(
121 &'s self,
122 input_values: impl Into<SessionInputs<'i, 'v, N>>,
123 ) -> ort::Result<SessionOutputs<'s>> {
124 let (ses, idx) = self.get_session()?;
125 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
126 let out = ses.run(input_values);
127 self.release_session(idx);
128 out
129 }
130
131 pub fn run_with_options<
132 'r,
133 's: 'r,
134 'i,
135 'v: 'i + 'r,
136 O: SelectedOutputMarker,
137 const N: usize,
138 >(
139 &'s self,
140 input_values: impl Into<SessionInputs<'i, 'v, N>>,
141 run_options: &'r RunOptions<O>,
142 ) -> Result<SessionOutputs<'r>, ort::Error> {
143 let (ses, idx) = self.get_session()?;
144 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock() as *mut Session) };
145 let out = ses.run_with_options(input_values, run_options);
146 self.release_session(idx);
147 out
148 }
149}