ort_parallel/
async_pool.rs1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::{Duration, SystemTime, SystemTimeError},
5};
6
7use ort::session::{
8 RunOptions, SelectedOutputMarker, Session, SessionInputs, SessionOutputs,
9 builder::{PrepackedWeights, SessionBuilder},
10};
11
12use tokio::sync::{Mutex, Semaphore};
13
14pub struct AsyncSessionPool {
15 sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
16 available_sessions: Arc<Mutex<Vec<usize>>>,
17 sem: Arc<Semaphore>,
18 max: usize,
19 builder: SessionBuilder,
20 file: PathBuf,
21}
22
23impl AsyncSessionPool {
24 pub fn commit_from_file(
25 builder: SessionBuilder,
26 path: &Path,
27 max_sessions: usize,
28 ) -> ort::Result<Self> {
29 assert!(max_sessions > 0);
30 let prepacked_weights = PrepackedWeights::new();
31 let builder = builder.with_prepacked_weights(&prepacked_weights)?;
32 Ok(Self {
33 sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(
34 builder.clone().commit_from_file(path)?,
35 ))])),
36 sem: Arc::new(Semaphore::new(max_sessions)),
37 available_sessions: Arc::new(Mutex::new(vec![0])),
38 max: max_sessions,
39 builder,
40 file: path.to_path_buf(),
41 })
42 }
43
44 pub async fn load_all(&self) -> ort::Result<()> {
45 let count = self.max - self.sessions.lock().await.len();
46 let mut sessions = self.sessions.lock().await;
47 let mut avail = self.available_sessions.lock().await;
48 if count != 0 {
49 for _ in 0..count {
50 let session = self.create_new()?;
51 sessions.push(session.clone());
52 avail.push(sessions.len() - 1);
53 }
54 }
55 Ok(())
56 }
57
58 fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
59 Ok(Arc::new(Mutex::new(
60 self.builder.clone().commit_from_file(&self.file)?,
61 )))
62 }
63
64 async fn release_session(&self, idx: usize) {
65 self.available_sessions.lock().await.push(idx);
66 self.sem.add_permits(1);
67 }
68
69 async fn get_session(&self) -> Result<(Arc<Mutex<Session>>, usize), ort::Error> {
70 let _permit = self.sem.acquire().await.unwrap();
71
72 if let Some(idx) = self.available_sessions.lock().await.pop() {
73 let sessions = self.sessions.lock().await;
74 return Ok((sessions[idx].clone(), idx));
75 }
76
77 if self.sessions.lock().await.len() < self.max {
78 let session = match self.create_new() {
79 Ok(v) => v,
80 Err(e) => {
81 self.sem.add_permits(1);
82 return Err(e);
83 }
84 };
85 let mut sessions = self.sessions.lock().await;
86 sessions.push(session.clone());
87 return Ok((session, sessions.len() - 1));
88 }
89
90 unreachable!()
91 }
92
93 pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 'r, O: SelectedOutputMarker, const N: usize>(
94 &'s self,
95 input_values: impl Into<SessionInputs<'i, 'v, N>>,
96 run_options: &'r RunOptions<O>,
97 ) -> Result<SessionOutputs<'r>, ort::Error> {
98 let (ses, idx) = self.get_session().await?;
99 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock().await as *mut Session) };
100 let out = match ses.run_async(input_values, run_options) {
101 Ok(v) => v.await,
102 Err(e) => Err(e),
103 };
104 self.release_session(idx).await;
105 out
106 }
107
108 pub async fn run_async_profiled<
109 'r,
110 's: 'r,
111 'i,
112 'v: 'i + 'r,
113 O: SelectedOutputMarker,
114 const N: usize,
115 >(
116 &'s self,
117 input_values: impl Into<SessionInputs<'i, 'v, N>>,
118 run_options: &'r RunOptions<O>,
119 ) -> Result<
120 (
121 SessionOutputs<'r>,
122 Result<Duration, SystemTimeError>,
123 String,
124 ),
125 ort::Error,
126 > {
127 let (ses, idx) = self.get_session().await?;
128 let ses_ptr = &mut *ses.lock().await as *mut Session;
129 let ses1: &'s mut Session = unsafe { &mut *(ses_ptr) };
130 let ses2: &'s mut Session = unsafe { &mut *(ses_ptr) };
131
132 let earlier = SystemTime::now();
133 let _ = ses1.profiling_start_ns()?;
134 let out = match ses1.run_async(input_values, run_options) {
135 Ok(v) => v.await,
136 Err(e) => Err(e),
137 };
138 let end = ses2.end_profiling()?;
139 let now = SystemTime::now().duration_since(earlier);
140 self.release_session(idx).await;
141 out.map(|v| (v, now, end))
142 }
143}