1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::{Duration, SystemTime, SystemTimeError},
5};
6
7use ort::session::{
8 Input, RunOptions, SelectedOutputMarker, Session, SessionInputs, SessionOutputs,
9 builder::{PrepackedWeights, SessionBuilder},
10};
11
12use tokio::sync::{Mutex, Semaphore, SemaphorePermit};
13
14use crate::SessionBuilderFactory;
15
16pub struct AsyncSessionPool {
17 sessions: Arc<Mutex<Vec<Arc<Mutex<Session>>>>>,
18 available_sessions: Arc<Mutex<Vec<usize>>>,
19 sem: Arc<Semaphore>,
20 max: usize,
21 builder: SessionBuilderFactory,
22 file: PathBuf,
23 pub inputs: Vec<Input>,
24}
25
26fn clone_inputs(inp: &Input) -> Input {
27 Input {
28 name: inp.name.clone(),
29 input_type: inp.input_type.clone(),
30 }
31}
32
33impl AsyncSessionPool {
34 pub fn commit_from_file(
35 builder: SessionBuilder,
36 path: &Path,
37 max_sessions: usize,
38 ) -> ort::Result<Self> {
39 assert!(max_sessions > 0);
40 let prepacked_weights = PrepackedWeights::new();
41 let builder = builder.with_prepacked_weights(&prepacked_weights)?;
42 let fs = builder.clone().commit_from_file(path)?;
43 let inputs = fs.inputs.iter().map(clone_inputs).collect::<Vec<Input>>();
44 Ok(Self {
45 inputs,
46 sessions: Arc::new(Mutex::new(vec![Arc::new(Mutex::new(fs))])),
47 sem: Arc::new(Semaphore::new(max_sessions)),
48 available_sessions: Arc::new(Mutex::new(vec![0])),
49 max: max_sessions,
50 builder: SessionBuilderFactory(builder),
51 file: path.to_path_buf(),
52 })
53 }
54
55 pub async fn load_all(&self) -> ort::Result<()> {
56 let count = self.max - self.sessions.lock().await.len();
57 let mut sessions = self.sessions.lock().await;
58 let mut avail = self.available_sessions.lock().await;
59 if count != 0 {
60 for _ in 0..count {
61 let session = self.create_new()?;
62 sessions.push(session.clone());
63 avail.push(sessions.len() - 1);
64 }
65 }
66 Ok(())
67 }
68
69 fn create_new(&self) -> Result<Arc<Mutex<Session>>, ort::Error> {
70 Ok(Arc::new(Mutex::new(
71 self.builder.generate().commit_from_file(&self.file)?,
72 )))
73 }
74
75 async fn release_session(&self, idx: usize, permit: SemaphorePermit<'_>) {
76 self.available_sessions.lock().await.push(idx);
77 drop(permit);
78 }
79
80 async fn get_session(
81 &self,
82 ) -> Result<(Arc<Mutex<Session>>, usize, SemaphorePermit<'_>), ort::Error> {
83 let _permit = self.sem.acquire().await.unwrap();
84
85 if let Some(idx) = self.available_sessions.lock().await.pop() {
86 let sessions = self.sessions.lock().await;
87 return Ok((sessions[idx].clone(), idx, _permit));
88 }
89
90 if self.sessions.lock().await.len() < self.max {
91 let session = match self.create_new() {
92 Ok(v) => v,
93 Err(e) => {
94 return Err(e);
95 }
96 };
97 let mut sessions = self.sessions.lock().await;
98 sessions.push(session.clone());
99 return Ok((session, sessions.len() - 1, _permit));
100 }
101
102 unreachable!()
103 }
104
105 pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 'r, O: SelectedOutputMarker, const N: usize>(
106 &'s self,
107 input_values: impl Into<SessionInputs<'i, 'v, N>>,
108 run_options: &'r RunOptions<O>,
109 ) -> Result<SessionOutputs<'r>, ort::Error> {
110 let (ses, idx, permit) = self.get_session().await?;
111 let ses: &'s mut Session = unsafe { &mut *(&mut *ses.lock().await as *mut Session) };
112 let out = match ses.run_async(input_values, run_options) {
113 Ok(v) => v.await,
114 Err(e) => Err(e),
115 };
116 self.release_session(idx, permit).await;
117 out
118 }
119
120 pub async fn run_async_profiled<
121 'r,
122 's: 'r,
123 'i,
124 'v: 'i + 'r,
125 O: SelectedOutputMarker,
126 const N: usize,
127 >(
128 &'s self,
129 input_values: impl Into<SessionInputs<'i, 'v, N>>,
130 run_options: &'r RunOptions<O>,
131 ) -> Result<
132 (
133 SessionOutputs<'r>,
134 Result<Duration, SystemTimeError>,
135 String,
136 ),
137 ort::Error,
138 > {
139 let (ses, idx, permit) = self.get_session().await?;
140 let ses_ptr = &mut *ses.lock().await as *mut Session;
141 let ses1: &'s mut Session = unsafe { &mut *(ses_ptr) };
142 let ses2: &'s mut Session = unsafe { &mut *(ses_ptr) };
143
144 let earlier = SystemTime::now();
145 let _ = ses1.profiling_start_ns()?;
146 let out = match ses1.run_async(input_values, run_options) {
147 Ok(v) => v.await,
148 Err(e) => Err(e),
149 };
150 let end = ses2.end_profiling()?;
151 let now = SystemTime::now().duration_since(earlier);
152 self.release_session(idx, permit).await;
153 out.map(|v| (v, now, end))
154 }
155}