ort_parallel/
async_pool.rs1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::{Duration, SystemTime, SystemTimeError},
5};
6
7use ort::session::{
8 RunOptions, SelectedOutputMarker, Session, SessionInputs, SessionOutputs,
9 builder::{PrepackedWeights, SessionBuilder},
10};
11
12use tokio::sync::{Mutex, Semaphore};
13
14use crate::SessionBuilderFactory;
15
16pub struct AsyncSessionPool {
17 sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
18 available_sessions: Arc<Mutex<Vec<usize>>>,
19 sem: Arc<Semaphore>,
20 max: usize,
21 builder: SessionBuilderFactory,
22 file: PathBuf,
23}
24
25impl AsyncSessionPool {
26 pub fn commit_from_file(
27 builder: SessionBuilder,
28 path: &Path,
29 max_sessions: usize,
30 ) -> ort::Result<Self> {
31 assert!(max_sessions > 0);
32 let prepacked_weights = PrepackedWeights::new();
33 let builder = builder.with_prepacked_weights(&prepacked_weights)?;
34 Ok(Self {
35 sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(
36 builder.clone().commit_from_file(path)?,
37 ))])),
38 sem: Arc::new(Semaphore::new(max_sessions)),
39 available_sessions: Arc::new(Mutex::new(vec![0])),
40 max: max_sessions,
41 builder: SessionBuilderFactory(builder),
42 file: path.to_path_buf(),
43 })
44 }
45
46 pub async fn load_all(&self) -> ort::Result<()> {
47 let count = self.max - self.sessions.lock().await.len();
48 let mut sessions = self.sessions.lock().await;
49 let mut avail = self.available_sessions.lock().await;
50 if count != 0 {
51 for _ in 0..count {
52 let session = self.create_new()?;
53 sessions.push(session.clone());
54 avail.push(sessions.len() - 1);
55 }
56 }
57 Ok(())
58 }
59
60 fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
61 Ok(Arc::new(Mutex::new(
62 self.builder.generate().commit_from_file(&self.file)?,
63 )))
64 }
65
66 async fn release_session(&self, idx: usize) {
67 self.available_sessions.lock().await.push(idx);
68 self.sem.add_permits(1);
69 }
70
71 async fn get_session(&self) -> Result<(Arc<Mutex<Session>>, usize), ort::Error> {
72 let _permit = self.sem.acquire().await.unwrap();
73
74 if let Some(idx) = self.available_sessions.lock().await.pop() {
75 let sessions = self.sessions.lock().await;
76 return Ok((sessions[idx].clone(), idx));
77 }
78
79 if self.sessions.lock().await.len() < self.max {
80 let session = match self.create_new() {
81 Ok(v) => v,
82 Err(e) => {
83 self.sem.add_permits(1);
84 return Err(e);
85 }
86 };
87 let mut sessions = self.sessions.lock().await;
88 sessions.push(session.clone());
89 return Ok((session, sessions.len() - 1));
90 }
91
92 unreachable!()
93 }
94
95 pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 'r, O: SelectedOutputMarker, const N: usize>(
96 &'s self,
97 input_values: impl Into<SessionInputs<'i, 'v, N>>,
98 run_options: &'r RunOptions<O>,
99 ) -> Result<SessionOutputs<'r>, ort::Error> {
100 let (ses, idx) = self.get_session().await?;
101 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock().await as *mut Session) };
102 let out = match ses.run_async(input_values, run_options) {
103 Ok(v) => v.await,
104 Err(e) => Err(e),
105 };
106 self.release_session(idx).await;
107 out
108 }
109
110 pub async fn run_async_profiled<
111 'r,
112 's: 'r,
113 'i,
114 'v: 'i + 'r,
115 O: SelectedOutputMarker,
116 const N: usize,
117 >(
118 &'s self,
119 input_values: impl Into<SessionInputs<'i, 'v, N>>,
120 run_options: &'r RunOptions<O>,
121 ) -> Result<
122 (
123 SessionOutputs<'r>,
124 Result<Duration, SystemTimeError>,
125 String,
126 ),
127 ort::Error,
128 > {
129 let (ses, idx) = self.get_session().await?;
130 let ses_ptr = &mut *ses.lock().await as *mut Session;
131 let ses1: &'s mut Session = unsafe { &mut *(ses_ptr) };
132 let ses2: &'s mut Session = unsafe { &mut *(ses_ptr) };
133
134 let earlier = SystemTime::now();
135 let _ = ses1.profiling_start_ns()?;
136 let out = match ses1.run_async(input_values, run_options) {
137 Ok(v) => v.await,
138 Err(e) => Err(e),
139 };
140 let end = ses2.end_profiling()?;
141 let now = SystemTime::now().duration_since(earlier);
142 self.release_session(idx).await;
143 out.map(|v| (v, now, end))
144 }
145}