1use crate::config::Config;
2use crate::error::{Error, Result};
3use crate::executor::CpuPool;
4use crate::scheduler::SchedulerCoordinator;
5use parking_lot::RwLock;
6use std::sync::Arc;
7use std::sync::{Mutex, OnceLock};
8use std::collections::HashMap;
9use std::thread::ThreadId;
10
11pub struct Runtime {
12 pub(crate) pool: Arc<CpuPool>,
13 pub(crate) scheduler: Arc<SchedulerCoordinator>,
14 config: Config,
15}
16
17impl Runtime {
18 pub fn new(config: Config) -> Result<Self> {
19 config.validate()?;
20
21 let pool = CpuPool::new(&config)?;
22 let scheduler = SchedulerCoordinator::new(&config)?;
23
24 Ok(Self {
25 pool: Arc::new(pool),
26 scheduler: Arc::new(scheduler),
27 config,
28 })
29 }
30
31 pub fn config(&self) -> &Config {
32 &self.config
33 }
34}
35
36static GLOBAL_RUNTIME: OnceLock<RwLock<Option<Arc<Runtime>>>> = OnceLock::new();
38
39fn get_global_runtime() -> &'static RwLock<Option<Arc<Runtime>>> {
40 GLOBAL_RUNTIME.get_or_init(|| RwLock::new(None))
41}
42
43thread_local! {
45 static THREAD_RUNTIME: std::cell::RefCell<Option<Arc<Runtime>>> = std::cell::RefCell::new(None);
46}
47
/// Records which threads have opted into a thread-local runtime.
///
/// A `true` entry for a thread ID means that thread resolves its runtime
/// through `THREAD_RUNTIME`. A missing entry means it uses the global runtime.
/// NOTE(review): within this file, entries are only removed by `shutdown`, so
/// a thread that exits without calling it leaves a stale entry behind.
static THREAD_RUNTIME_MAP: OnceLock<Mutex<HashMap<ThreadId, bool>>> = OnceLock::new();

/// Returns the opt-in map, creating it empty on first access.
fn get_thread_runtime_map() -> &'static Mutex<HashMap<ThreadId, bool>> {
    THREAD_RUNTIME_MAP.get_or_init(Default::default)
}
54
/// Whether a missing runtime may be created on demand with the default
/// configuration. This is enabled at startup.
static LAZY_INIT_ENABLED: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(true);

/// Enables or disables on-demand runtime creation.
///
/// When disabled, `current_runtime` no longer creates a default runtime for a
/// thread whose runtime slot is empty. It panics instead, until an explicit
/// `init*` call succeeds.
pub fn set_lazy_init(enabled: bool) {
    use std::sync::atomic::Ordering;

    LAZY_INIT_ENABLED.store(enabled, Ordering::Release);
}
62
63fn ensure_runtime_initialized() {
65 if !LAZY_INIT_ENABLED.load(std::sync::atomic::Ordering::Acquire) {
66 return;
67 }
68
69 let thread_id = std::thread::current().id();
70 let has_thread_local = get_thread_runtime_map().lock().unwrap()
71 .get(&thread_id)
72 .copied()
73 .unwrap_or(false);
74
75 if has_thread_local {
76 let has_runtime = THREAD_RUNTIME.with(|rt| rt.borrow().is_some());
78 if !has_runtime {
79 let _ = init_thread_local();
80 }
81 } else {
82 let runtime = get_global_runtime().read();
84 if runtime.is_none() {
85 drop(runtime); let _ = init(); }
88 }
89}
90
91pub fn init() -> Result<()> {
92 init_with_config(Config::default())
93}
94
95pub fn init_with_config(config: Config) -> Result<()> {
96 let thread_id = std::thread::current().id();
97
98 let has_thread_local = get_thread_runtime_map().lock().unwrap()
100 .get(&thread_id)
101 .copied()
102 .unwrap_or(false);
103
104 if has_thread_local {
105 let has_existing = THREAD_RUNTIME.with(|rt| rt.borrow().is_some());
107 if has_existing {
108 return Err(Error::AlreadyInitialized);
109 }
110
111 let rt = Runtime::new(config)?;
112 THREAD_RUNTIME.with(|rt_cell| {
113 *rt_cell.borrow_mut() = Some(Arc::new(rt));
114 });
115
116 Ok(())
117 } else {
118 let mut runtime = get_global_runtime().write();
120
121 if runtime.is_some() {
122 return Err(Error::AlreadyInitialized);
123 }
124
125 let rt = Runtime::new(config)?;
126 *runtime = Some(Arc::new(rt));
127
128 Ok(())
129 }
130}
131
132pub fn init_thread_local() -> Result<()> {
134 init_thread_local_with_config(Config::default())
135}
136
137pub fn init_thread_local_with_config(config: Config) -> Result<()> {
139 let thread_id = std::thread::current().id();
140 get_thread_runtime_map().lock().unwrap().insert(thread_id, true);
141
142 let has_existing = THREAD_RUNTIME.with(|rt| rt.borrow().is_some());
143 if has_existing {
144 return Err(Error::AlreadyInitialized);
145 }
146
147 let rt = Runtime::new(config)?;
148 THREAD_RUNTIME.with(|rt_cell| {
149 *rt_cell.borrow_mut() = Some(Arc::new(rt));
150 });
151
152 Ok(())
153}
154
155pub(crate) fn current_runtime() -> Arc<Runtime> {
156 ensure_runtime_initialized();
158
159 let thread_id = std::thread::current().id();
160 let has_thread_local = get_thread_runtime_map().lock().unwrap()
161 .get(&thread_id)
162 .copied()
163 .unwrap_or(false);
164
165 if has_thread_local {
166 THREAD_RUNTIME.with(|rt| {
167 rt.borrow()
168 .as_ref()
169 .expect("VEDA runtime not initialized - call veda::init() first")
170 .clone()
171 })
172 } else {
173 get_global_runtime()
174 .read()
175 .as_ref()
176 .expect("VEDA runtime not initialized - call veda::init() first")
177 .clone()
178 }
179}
180
181pub(crate) fn with_current_runtime<F, R>(f: F) -> R
182where
183 F: FnOnce(&Runtime) -> R,
184{
185 let rt = current_runtime();
186 f(&rt)
187}
188
189pub fn shutdown() {
190 let thread_id = std::thread::current().id();
191 let has_thread_local = get_thread_runtime_map().lock().unwrap()
192 .get(&thread_id)
193 .copied()
194 .unwrap_or(false);
195
196 if has_thread_local {
197 THREAD_RUNTIME.with(|rt_cell| {
198 *rt_cell.borrow_mut() = None;
199 });
200 get_thread_runtime_map().lock().unwrap().remove(&thread_id);
201 } else {
202 let mut runtime = get_global_runtime().write();
203 *runtime = None;
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that mutate the process-wide runtime.
    ///
    /// The default test harness runs tests on parallel threads. None of these
    /// test threads opt into a thread-local runtime, so they all share the
    /// global slot. Without this lock, one test's `shutdown()` can land between
    /// another test's `init()` calls and make the assertions flaky.
    static GLOBAL_RUNTIME_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `GLOBAL_RUNTIME_TEST_LOCK`, recovering from poisoning caused by
    /// an earlier failed test. The guarded data is `()`, so nothing can be left
    /// in an inconsistent state.
    fn serialize_global_runtime() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_RUNTIME_TEST_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn test_runtime_init() {
        let _guard = serialize_global_runtime();
        shutdown();

        let result = init();
        assert!(result.is_ok());

        let result2 = init();
        assert!(result2.is_err());

        shutdown();
    }

    #[test]
    fn test_custom_config() {
        let _guard = serialize_global_runtime();
        shutdown();

        let config = Config::builder().num_threads(2).build().unwrap();

        init_with_config(config).unwrap();

        let rt = current_runtime();
        assert_eq!(rt.pool.num_threads(), 2);

        shutdown();
    }
}