1use crate::{
2 hw::config,
3 hw::traits::{HwMidiHub, HwWorkerDriver},
4 message::{HwMidiEvent, Message},
5 mutex::UnsafeMutex,
6};
7#[cfg(unix)]
8use nix::libc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Condvar, Mutex};
11use std::thread::JoinHandle;
12use std::time::{Duration, Instant};
13use tokio::sync::mpsc::{Receiver, Sender};
14use tracing::error;
15
16pub trait Backend: Send + Sync + 'static {
17 type Driver: HwWorkerDriver + Send + 'static;
18 type MidiHub: HwMidiHub + Send + 'static;
19
20 const LABEL: &'static str;
21 const WORKER_THREAD_NAME: &'static str;
22 const ASSIST_THREAD_NAME: &'static str;
23 const ASSIST_AUTONOMOUS_ENV: &'static str;
24 const ASSIST_AUTONOMOUS_DEFAULT: bool = false;
25 const CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS: bool = false;
26 const ASSIST_STEP_REQUIRES_REQUEST_CYCLE: bool = false;
27}
28
29#[derive(Debug)]
30pub struct HwWorker<B: Backend> {
31 driver: Arc<UnsafeMutex<B::Driver>>,
32 midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
33 rx: Receiver<Message>,
34 tx: Sender<Message>,
35 cycle_frames: u32,
36 pending_midi_out_events: Vec<HwMidiEvent>,
37 pending_midi_out_sorted: bool,
38}
39
/// State shared between the worker and its assist thread; it is kept behind
/// a `Mutex` and paired with a `Condvar` for signalling.
#[derive(Debug, Default)]
struct AssistState {
    /// Set to ask the assist thread to exit its loop.
    shutdown: bool,
    /// Sequence number of the most recently requested driver cycle.
    request_seq: u64,
    /// Sequence number of the most recently completed driver cycle.
    done_seq: u64,
    /// Becomes `true` after the first successful cycle on the assist thread,
    /// or as soon as the worker starts running cycles itself.
    init_complete: bool,
    /// Error text of the latest failed cycle or assist step, if any.
    last_error: Option<String>,
}
48
49#[cfg(unix)]
50const RT_POLICY: i32 = libc::SCHED_FIFO;
51const RT_PRIORITY_WORKER: i32 = 18;
52const RT_PRIORITY_ASSIST: i32 = 12;
53const PROFILE_INTERVAL: Duration = Duration::from_secs(1);
54
/// Counters gathered by the assist thread between two profile reports.
#[derive(Debug)]
struct AssistProfiler {
    /// Earliest instant at which the next report may be emitted.
    report_at: Instant,
    /// Driver cycles run since the last report.
    cycle_count: u64,
    /// Driver cycles that returned an error since the last report.
    cycle_err_count: u64,
    /// Accumulated duration of those cycles, in nanoseconds.
    cycle_time_ns: u128,
    /// Assist steps attempted since the last report.
    step_count: u64,
    /// Assist steps that reported having done work.
    step_work_count: u64,
    /// Assist steps that returned an error.
    step_err_count: u64,
    /// Accumulated duration of the assist steps, in nanoseconds.
    step_time_ns: u128,
    /// Condition-variable waits performed since the last report.
    wait_count: u64,
    /// Accumulated duration of those waits, in nanoseconds.
    wait_time_ns: u128,
}
68
69impl AssistProfiler {
70 fn new() -> Self {
71 Self {
72 report_at: Instant::now() + PROFILE_INTERVAL,
73 cycle_count: 0,
74 cycle_err_count: 0,
75 cycle_time_ns: 0,
76 step_count: 0,
77 step_work_count: 0,
78 step_err_count: 0,
79 step_time_ns: 0,
80 wait_count: 0,
81 wait_time_ns: 0,
82 }
83 }
84
85 fn maybe_report(&mut self, cycle_samples: usize, sample_rate: i32, label: &str) {
86 let now = Instant::now();
87 if now < self.report_at {
88 return;
89 }
90 let cycle_avg_us = if self.cycle_count > 0 {
91 (self.cycle_time_ns / self.cycle_count as u128) as f64 / 1_000.0
92 } else {
93 0.0
94 };
95 let step_avg_us = if self.step_count > 0 {
96 (self.step_time_ns / self.step_count as u128) as f64 / 1_000.0
97 } else {
98 0.0
99 };
100 let wait_avg_us = if self.wait_count > 0 {
101 (self.wait_time_ns / self.wait_count as u128) as f64 / 1_000.0
102 } else {
103 0.0
104 };
105 let expected_cycles_per_sec = if cycle_samples > 0 && sample_rate > 0 {
106 sample_rate as f64 / cycle_samples as f64
107 } else {
108 0.0
109 };
110 error!(
111 "{} profile: expected_cps={:.1} cycles={} cycle_err={} cycle_avg_us={:.1} steps={} steps_work={} step_err={} step_avg_us={:.1} waits={} wait_avg_us={:.1}",
112 label,
113 expected_cycles_per_sec,
114 self.cycle_count,
115 self.cycle_err_count,
116 cycle_avg_us,
117 self.step_count,
118 self.step_work_count,
119 self.step_err_count,
120 step_avg_us,
121 self.wait_count,
122 wait_avg_us
123 );
124 self.report_at = now + PROFILE_INTERVAL;
125 self.cycle_count = 0;
126 self.cycle_err_count = 0;
127 self.cycle_time_ns = 0;
128 self.step_count = 0;
129 self.step_work_count = 0;
130 self.step_err_count = 0;
131 self.step_time_ns = 0;
132 self.wait_count = 0;
133 self.wait_time_ns = 0;
134 }
135}
136
137impl<B: Backend> HwWorker<B> {
138 fn profile_enabled() -> bool {
139 config::env_flag(config::HW_PROFILE_ENV)
140 }
141
142 fn assist_autonomous_enabled() -> bool {
143 B::ASSIST_AUTONOMOUS_DEFAULT || config::env_flag(B::ASSIST_AUTONOMOUS_ENV)
144 }
145
146 fn configure_rt_thread(name: &str, priority: i32) -> Result<(), String> {
147 #[cfg(unix)]
148 {
149 let thread = unsafe { libc::pthread_self() };
150 #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "openbsd"))]
151 let c_name = std::ffi::CString::new(name).map_err(|e| e.to_string())?;
152 #[cfg(target_os = "linux")]
153 unsafe {
154 let _ = libc::pthread_setname_np(thread, c_name.as_ptr());
155 }
156 #[cfg(any(target_os = "freebsd", target_os = "openbsd"))]
157 unsafe {
158 libc::pthread_set_name_np(thread, c_name.as_ptr());
159 }
160
161 let param = unsafe {
162 let mut p = std::mem::zeroed::<libc::sched_param>();
163 p.sched_priority = priority;
164 p
165 };
166 let rc = unsafe { libc::pthread_setschedparam(thread, RT_POLICY, ¶m) };
167 if rc != 0 {
168 return Err(format!(
169 "pthread_setschedparam({}, prio {}) failed with errno {}",
170 name, priority, rc
171 ));
172 }
173
174 let mut actual_policy = 0_i32;
175 let mut actual_param = unsafe { std::mem::zeroed::<libc::sched_param>() };
176 let rc = unsafe {
177 libc::pthread_getschedparam(thread, &mut actual_policy, &mut actual_param)
178 };
179 if rc != 0 {
180 return Err(format!(
181 "pthread_getschedparam({}) failed with errno {}",
182 name, rc
183 ));
184 }
185 if actual_policy != RT_POLICY || actual_param.sched_priority != priority {
186 return Err(format!(
187 "realtime verification failed for {}: policy {}, prio {}",
188 name, actual_policy, actual_param.sched_priority
189 ));
190 }
191 Ok(())
192 }
193 #[cfg(not(unix))]
194 {
195 let _ = name;
196 let _ = priority;
197 Err("Realtime thread priority is not supported on this platform".to_string())
198 }
199 }
200
201 fn lock_memory_pages() -> Result<(), String> {
202 #[cfg(unix)]
203 {
204 let rc = unsafe { libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) };
205 if rc == 0 {
206 Ok(())
207 } else {
208 Err(format!(
209 "mlockall(MCL_CURRENT|MCL_FUTURE) failed: {}",
210 std::io::Error::last_os_error()
211 ))
212 }
213 }
214 #[cfg(not(unix))]
215 {
216 Err("mlockall is not supported on this platform".to_string())
217 }
218 }
219
220 pub fn new(
221 driver: Arc<UnsafeMutex<B::Driver>>,
222 midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
223 rx: Receiver<Message>,
224 tx: Sender<Message>,
225 ) -> Self {
226 let cycle_frames = {
227 let d = driver.lock();
228 d.cycle_samples() as u32
229 };
230 Self {
231 driver,
232 midi_hub,
233 rx,
234 tx,
235 cycle_frames,
236 pending_midi_out_events: vec![],
237 pending_midi_out_sorted: true,
238 }
239 }
240
241 pub async fn work(mut self) {
242 crate::enable_flush_denormals_to_zero();
243 if let Err(e) = Self::lock_memory_pages() {
244 error!("{} worker memory lock not enabled: {}", B::LABEL, e);
245 }
246 if let Err(e) = Self::configure_rt_thread(B::WORKER_THREAD_NAME, RT_PRIORITY_WORKER) {
247 error!("{} worker realtime priority not enabled: {}", B::LABEL, e);
248 }
249 #[cfg(target_os = "macos")]
250 unsafe {
251 libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, 0);
252 }
253 let assist_state = Arc::new((Mutex::new(AssistState::default()), Condvar::new()));
254 let assist_handle = Self::start_assist_thread(self.driver.clone(), assist_state.clone());
255 let midi_stop = Arc::new(AtomicBool::new(false));
256 let midi_handle = Self::start_midi_input_thread(
257 self.midi_hub.clone(),
258 self.tx.clone(),
259 self.cycle_frames,
260 midi_stop.clone(),
261 );
262 loop {
263 match self.rx.recv().await {
264 Some(msg) => match msg {
265 Message::Request(crate::message::Action::Quit) => {
266 self.driver.lock().request_stop();
267 if !self.pending_midi_out_events.is_empty() {
268 if !self.pending_midi_out_sorted {
269 self.pending_midi_out_events.sort_by(|a, b| {
270 a.event
271 .frame
272 .cmp(&b.event.frame)
273 .then_with(|| a.device.cmp(&b.device))
274 });
275 self.pending_midi_out_sorted = true;
276 }
277 let midi_hub = self.midi_hub.lock();
278 midi_hub.write_events(&self.pending_midi_out_events);
279 self.pending_midi_out_events.clear();
280 }
281 midi_stop.store(true, Ordering::Release);
282 {
283 let midi_hub = self.midi_hub.lock();
284 midi_hub.wake_input_waiter();
285 }
286 let _ = midi_handle.join();
287 Self::stop_assist_thread(&assist_state, assist_handle);
288 return;
289 }
290 Message::TracksFinished => {
291 {
292 if !self.pending_midi_out_events.is_empty() {
293 if !self.pending_midi_out_sorted {
294 self.pending_midi_out_events.sort_by(|a, b| {
295 a.event
296 .frame
297 .cmp(&b.event.frame)
298 .then_with(|| a.device.cmp(&b.device))
299 });
300 self.pending_midi_out_sorted = true;
301 }
302 let midi_hub = self.midi_hub.lock();
303 midi_hub.write_events(&self.pending_midi_out_events);
304 self.pending_midi_out_events.clear();
305 }
306 }
307 if let Err(e) = Self::run_assist_cycle(&self.driver, &assist_state) {
308 error!("{} assist cycle error: {}", B::LABEL, e);
309 let _ = self
310 .tx
311 .send(Message::Response(Err(format!(
312 "{} assist cycle error: {}",
313 B::LABEL,
314 e
315 ))))
316 .await;
317 }
318 if let Err(e) = self.tx.send(Message::HWFinished).await {
319 error!(
320 "{} worker failed to send HWFinished to engine: {}",
321 B::LABEL,
322 e
323 );
324 }
325 }
326 Message::HWMidiOutEvents(mut events) => {
327 self.pending_midi_out_events.append(&mut events);
328 self.pending_midi_out_sorted = false;
329 }
330 Message::ClearHWMidiOutEvents => {
331 self.pending_midi_out_events.clear();
332 self.pending_midi_out_sorted = true;
333 }
334 _ => {}
335 },
336 None => {
337 self.driver.lock().request_stop();
338 midi_stop.store(true, Ordering::Release);
339 {
340 let midi_hub = self.midi_hub.lock();
341 midi_hub.wake_input_waiter();
342 }
343 let _ = midi_handle.join();
344 Self::stop_assist_thread(&assist_state, assist_handle);
345 return;
346 }
347 }
348 }
349 }
350
351 fn start_midi_input_thread(
352 midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
353 tx: Sender<Message>,
354 cycle_frames: u32,
355 stop: Arc<AtomicBool>,
356 ) -> JoinHandle<()> {
357 std::thread::spawn(move || {
358 crate::enable_flush_denormals_to_zero();
359 let mut midi_in_events = Vec::with_capacity(64);
360 while !stop.load(Ordering::Acquire) {
361 {
362 let hub = midi_hub.lock();
363 hub.read_events_blocking_into(&mut midi_in_events);
364 }
365 if midi_in_events.is_empty() {
366 continue;
367 }
368 spread_hw_event_frames(&mut midi_in_events, cycle_frames);
369 let cap = midi_in_events.capacity();
370 let out = std::mem::replace(&mut midi_in_events, Vec::with_capacity(cap.max(64)));
371 if tx.blocking_send(Message::HWMidiEvents(out)).is_err() {
372 break;
373 }
374 }
375 })
376 }
377
378 fn start_assist_thread(
379 driver: Arc<UnsafeMutex<B::Driver>>,
380 assist_state: Arc<(Mutex<AssistState>, Condvar)>,
381 ) -> JoinHandle<()> {
382 let profile = Self::profile_enabled();
383 let autonomous = Self::assist_autonomous_enabled();
384 std::thread::spawn(move || {
385 crate::enable_flush_denormals_to_zero();
386 if let Err(e) = Self::configure_rt_thread(B::ASSIST_THREAD_NAME, RT_PRIORITY_ASSIST) {
387 error!("{} assist realtime priority not enabled: {}", B::LABEL, e);
388 }
389 #[cfg(target_os = "macos")]
390 unsafe {
391 libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INITIATED, 0);
392 }
393 let mut profiler = if profile {
394 let (cycle_samples, sample_rate) = {
395 let d = driver.lock();
396 (d.cycle_samples(), d.sample_rate())
397 };
398 error!(
399 "{} profile enabled: cycle_samples={} sample_rate={} expected_cps={:.1}",
400 B::LABEL,
401 cycle_samples,
402 sample_rate,
403 if cycle_samples > 0 {
404 sample_rate as f64 / cycle_samples as f64
405 } else {
406 0.0
407 }
408 );
409 Some(AssistProfiler::new())
410 } else {
411 None
412 };
413 let (lock, cvar) = &*assist_state;
414 loop {
415 let (shutdown, has_request, target, init_complete) = {
416 let st = lock.lock().expect("assist mutex poisoned");
417 (
418 st.shutdown,
419 st.request_seq > st.done_seq,
420 st.request_seq,
421 st.init_complete,
422 )
423 };
424 if shutdown {
425 break;
426 }
427 if has_request {
428 let started = Instant::now();
429 let run_error = {
430 let d = driver.lock();
431 d.run_cycle_for_worker().err().map(|e| e.to_string())
432 };
433 if let Some(p) = profiler.as_mut() {
434 p.cycle_count += 1;
435 if run_error.is_some() {
436 p.cycle_err_count += 1;
437 }
438 p.cycle_time_ns += started.elapsed().as_nanos();
439 let (cycle_samples, sample_rate) = {
440 let d = driver.lock();
441 (d.cycle_samples(), d.sample_rate())
442 };
443 p.maybe_report(cycle_samples, sample_rate, B::LABEL);
444 }
445 let mut st = lock.lock().expect("assist mutex poisoned");
446 st.done_seq = st.done_seq.max(target);
447 if run_error.is_none() {
448 st.init_complete = true;
449 }
450 st.last_error = run_error;
451 cvar.notify_all();
452 continue;
453 }
454
455 if B::ASSIST_STEP_REQUIRES_REQUEST_CYCLE && !init_complete {
456 let st = lock.lock().expect("assist mutex poisoned");
457 if st.shutdown {
458 break;
459 }
460 let wait_started = Instant::now();
461 let _guard = cvar.wait(st).expect("assist condvar failed");
462 if let Some(p) = profiler.as_mut() {
463 p.wait_count += 1;
464 p.wait_time_ns += wait_started.elapsed().as_nanos();
465 }
466 continue;
467 }
468
469 if !autonomous {
470 let st = lock.lock().expect("assist mutex poisoned");
471 if st.shutdown {
472 break;
473 }
474 let wait_started = Instant::now();
475 let _guard = cvar.wait(st).expect("assist condvar failed");
476 if let Some(p) = profiler.as_mut() {
477 p.wait_count += 1;
478 p.wait_time_ns += wait_started.elapsed().as_nanos();
479 }
480 continue;
481 }
482
483 let started = Instant::now();
484 let did_work = {
485 let d = driver.lock();
486 match d.run_assist_step_for_worker() {
487 Ok(v) => v,
488 Err(e) => {
489 if let Some(p) = profiler.as_mut() {
490 p.step_err_count += 1;
491 }
492 let mut st = lock.lock().expect("assist mutex poisoned");
493 st.last_error = Some(e.to_string());
494 cvar.notify_all();
495 false
496 }
497 }
498 };
499 if let Some(p) = profiler.as_mut() {
500 p.step_count += 1;
501 if did_work {
502 p.step_work_count += 1;
503 }
504 p.step_time_ns += started.elapsed().as_nanos();
505 let (cycle_samples, sample_rate) = {
506 let d = driver.lock();
507 (d.cycle_samples(), d.sample_rate())
508 };
509 p.maybe_report(cycle_samples, sample_rate, B::LABEL);
510 }
511 if !did_work {
512 let st = lock.lock().expect("assist mutex poisoned");
513 if st.shutdown {
514 break;
515 }
516 let wait_started = Instant::now();
517 let _guard = if autonomous {
518 cvar.wait_timeout(st, Duration::from_micros(100))
519 .expect("assist condvar failed")
520 .0
521 } else {
522 cvar.wait(st).expect("assist condvar failed")
523 };
524 if let Some(p) = profiler.as_mut() {
525 p.wait_count += 1;
526 p.wait_time_ns += wait_started.elapsed().as_nanos();
527 }
528 }
529 }
530 })
531 }
532
533 fn run_assist_cycle(
534 driver: &Arc<UnsafeMutex<B::Driver>>,
535 assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
536 ) -> Result<(), String> {
537 if Self::assist_autonomous_enabled() && B::CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS {
538 let (lock, cvar) = &**assist_state;
539 {
540 let mut st = lock
541 .lock()
542 .map_err(|_| "assist mutex poisoned".to_string())?;
543 st.init_complete = true;
544 cvar.notify_one();
545 }
546 let result = driver.lock().run_cycle_for_worker();
547 {
548 let mut st = lock
549 .lock()
550 .map_err(|_| "assist mutex poisoned".to_string())?;
551 st.last_error = result.as_ref().err().map(|e| e.to_string());
552 cvar.notify_one();
553 }
554 return result;
555 }
556
557 let (lock, cvar) = &**assist_state;
558 let mut st = lock
559 .lock()
560 .map_err(|_| "assist mutex poisoned".to_string())?;
561 st.request_seq = st.request_seq.saturating_add(1);
562 let target = st.request_seq;
563 cvar.notify_one();
564 while st.done_seq < target && !st.shutdown {
565 st = cvar
566 .wait(st)
567 .map_err(|_| "assist condvar wait failed".to_string())?;
568 }
569 if let Some(err) = st.last_error.take() {
570 return Err(err);
571 }
572 Ok(())
573 }
574
575 fn stop_assist_thread(
576 assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
577 assist_handle: JoinHandle<()>,
578 ) {
579 let (lock, cvar) = &**assist_state;
580 if let Ok(mut st) = lock.lock() {
581 st.shutdown = true;
582 cvar.notify_all();
583 }
584 let _ = assist_handle.join();
585 }
586}
587
588fn spread_hw_event_frames(events: &mut [HwMidiEvent], frames: u32) {
589 if events.len() <= 1 || frames <= 1 {
590 return;
591 }
592 let n = events.len() as u32;
593 for (idx, event) in events.iter_mut().enumerate() {
594 let pos = idx as u32;
595 event.event.frame = ((pos as u64 * (frames - 1) as u64) / n as u64) as u32;
596 }
597}