1use crate::{
2 hw::config,
3 hw::traits::{HwMidiHub, HwWorkerDriver},
4 message::{HwMidiEvent, Message},
5 mutex::UnsafeMutex,
6};
7#[cfg(unix)]
8use nix::libc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Condvar, Mutex};
11use std::thread::JoinHandle;
12use std::time::{Duration, Instant};
13use tokio::sync::mpsc::{Receiver, Sender};
14use tracing::error;
15
/// Compile-time description of a concrete audio backend driven by [`HwWorker`].
///
/// Implementors supply the driver and MIDI hub types plus constants that name the worker's
/// threads and select how the assist thread (legacy, non-fd loop) cooperates with the worker.
pub trait Backend: Send + Sync + 'static {
    /// Audio driver whose cycles the worker runs.
    type Driver: HwWorkerDriver + Send + 'static;
    /// Hub used for hardware MIDI input (read on a dedicated thread) and MIDI output writes.
    type MidiHub: HwMidiHub + Send + 'static;

    /// Backend name used as the prefix of log and error messages.
    const LABEL: &'static str;
    /// Name given (on Linux/FreeBSD/OpenBSD) to the thread that runs `HwWorker::work`.
    const WORKER_THREAD_NAME: &'static str;
    /// Name given (on Linux/FreeBSD/OpenBSD) to the assist thread of the legacy loop.
    const ASSIST_THREAD_NAME: &'static str;
    /// Environment variable checked with `config::env_flag` to enable autonomous assist mode.
    const ASSIST_AUTONOMOUS_ENV: &'static str;
    /// Enables autonomous assist mode regardless of `ASSIST_AUTONOMOUS_ENV`.
    const ASSIST_AUTONOMOUS_DEFAULT: bool = false;
    /// In autonomous assist mode, run each driver cycle on the worker thread itself instead of
    /// requesting it from the assist thread.
    const CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS: bool = false;
    /// Hold back autonomous assist steps until `init_complete` is set (after the first
    /// successful requested cycle, or by the worker in autonomous-cycle mode).
    const ASSIST_STEP_REQUIRES_REQUEST_CYCLE: bool = false;
}
28
/// Hardware worker that runs a [`Backend`]'s audio driver cycle-by-cycle in response to engine
/// messages and bridges hardware MIDI input and output.
///
/// Created with [`HwWorker::new`] and consumed by [`HwWorker::work`]. Dropping it requests a
/// driver stop and signals the MIDI input and assist threads to shut down.
#[derive(Debug)]
pub struct HwWorker<B: Backend> {
    /// Shared audio driver (also used by the assist thread / blocking cycle tasks).
    driver: Arc<UnsafeMutex<B::Driver>>,
    /// Shared MIDI hub (also used by the MIDI input thread).
    midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
    /// Messages coming from the engine.
    rx: Receiver<Message>,
    /// Messages going to the engine: responses, `HWFinished`, hardware MIDI input.
    tx: Sender<Message>,
    /// Driver cycle length in frames, read once at construction.
    cycle_frames: u32,
    /// MIDI output queued by `HWMidiOutEvents` until the next flush.
    pending_midi_out_events: Vec<HwMidiEvent>,
    /// Whether `pending_midi_out_events` is already ordered by (frame, device).
    pending_midi_out_sorted: bool,
    /// Stop flag polled by the MIDI input thread.
    midi_stop: Arc<AtomicBool>,
    /// Handshake state and condition variable shared with the assist thread.
    assist_state: Arc<(Mutex<AssistState>, Condvar)>,
}
41
42impl<B: Backend> Drop for HwWorker<B> {
43 fn drop(&mut self) {
44 self.driver.lock().request_stop();
45 self.midi_stop.store(true, Ordering::Release);
46 {
47 let midi_hub = self.midi_hub.lock();
48 midi_hub.wake_input_waiter();
49 midi_hub.close_input_waiter();
50 }
51 {
52 let (lock, cvar) = &*self.assist_state;
53 if let Ok(mut st) = lock.lock() {
54 st.shutdown = true;
55 cvar.notify_one();
56 }
57 }
58 }
59}
60
/// Handshake state shared between the worker and the assist thread, guarded by a `Mutex` and
/// paired with a `Condvar`.
#[derive(Debug, Default)]
struct AssistState {
    /// Set to make the assist thread leave its loop.
    shutdown: bool,
    /// Sequence number of the most recently requested cycle; bumped by `run_assist_cycle`.
    request_seq: u64,
    /// Highest request sequence the assist thread has completed; a request is pending while
    /// `request_seq > done_seq`.
    done_seq: u64,
    /// Set after the first successful requested cycle (or by the worker in autonomous-cycle
    /// mode); gates assist steps when `Backend::ASSIST_STEP_REQUIRES_REQUEST_CYCLE` is set.
    init_complete: bool,
    /// Error from the most recent cycle or assist step, if any; taken by the request path of
    /// `run_assist_cycle`.
    last_error: Option<String>,
}
69
/// Scheduling policy requested for the worker and assist threads.
#[cfg(unix)]
const RT_POLICY: i32 = libc::SCHED_FIFO;
/// Realtime priority for the thread that runs `HwWorker::work`.
const RT_PRIORITY_WORKER: i32 = 18;
/// Realtime priority for the assist thread (lower than the worker's).
const RT_PRIORITY_ASSIST: i32 = 12;
/// Minimum time between two `AssistProfiler` reports.
const PROFILE_INTERVAL: Duration = Duration::from_secs(1);
75
/// Counters collected by the assist thread when profiling is enabled (`config::HW_PROFILE_ENV`).
/// All counters are reset after each report.
#[derive(Debug)]
struct AssistProfiler {
    /// Instant at or after which the next report is emitted.
    report_at: Instant,
    /// Requested driver cycles run by the assist thread.
    cycle_count: u64,
    /// Requested cycles that returned an error.
    cycle_err_count: u64,
    /// Total wall time spent in requested cycles, in nanoseconds.
    cycle_time_ns: u128,
    /// Autonomous assist steps attempted.
    step_count: u64,
    /// Steps that reported doing work.
    step_work_count: u64,
    /// Steps that returned an error.
    step_err_count: u64,
    /// Total wall time spent in steps, in nanoseconds.
    step_time_ns: u128,
    /// Condvar waits performed.
    wait_count: u64,
    /// Total time spent in condvar waits, in nanoseconds.
    wait_time_ns: u128,
}
89
90impl AssistProfiler {
91 fn new() -> Self {
92 Self {
93 report_at: Instant::now() + PROFILE_INTERVAL,
94 cycle_count: 0,
95 cycle_err_count: 0,
96 cycle_time_ns: 0,
97 step_count: 0,
98 step_work_count: 0,
99 step_err_count: 0,
100 step_time_ns: 0,
101 wait_count: 0,
102 wait_time_ns: 0,
103 }
104 }
105
106 fn maybe_report(&mut self, cycle_samples: usize, sample_rate: i32, label: &str) {
107 let now = Instant::now();
108 if now < self.report_at {
109 return;
110 }
111 let cycle_avg_us = if self.cycle_count > 0 {
112 (self.cycle_time_ns / self.cycle_count as u128) as f64 / 1_000.0
113 } else {
114 0.0
115 };
116 let step_avg_us = if self.step_count > 0 {
117 (self.step_time_ns / self.step_count as u128) as f64 / 1_000.0
118 } else {
119 0.0
120 };
121 let wait_avg_us = if self.wait_count > 0 {
122 (self.wait_time_ns / self.wait_count as u128) as f64 / 1_000.0
123 } else {
124 0.0
125 };
126 let expected_cycles_per_sec = if cycle_samples > 0 && sample_rate > 0 {
127 sample_rate as f64 / cycle_samples as f64
128 } else {
129 0.0
130 };
131 error!(
132 "{} profile: expected_cps={:.1} cycles={} cycle_err={} cycle_avg_us={:.1} steps={} steps_work={} step_err={} step_avg_us={:.1} waits={} wait_avg_us={:.1}",
133 label,
134 expected_cycles_per_sec,
135 self.cycle_count,
136 self.cycle_err_count,
137 cycle_avg_us,
138 self.step_count,
139 self.step_work_count,
140 self.step_err_count,
141 step_avg_us,
142 self.wait_count,
143 wait_avg_us
144 );
145 self.report_at = now + PROFILE_INTERVAL;
146 self.cycle_count = 0;
147 self.cycle_err_count = 0;
148 self.cycle_time_ns = 0;
149 self.step_count = 0;
150 self.step_work_count = 0;
151 self.step_err_count = 0;
152 self.step_time_ns = 0;
153 self.wait_count = 0;
154 self.wait_time_ns = 0;
155 }
156}
157
158impl<B: Backend> HwWorker<B> {
159 fn profile_enabled() -> bool {
160 config::env_flag(config::HW_PROFILE_ENV)
161 }
162
163 fn assist_autonomous_enabled() -> bool {
164 B::ASSIST_AUTONOMOUS_DEFAULT || config::env_flag(B::ASSIST_AUTONOMOUS_ENV)
165 }
166
167 fn configure_rt_thread(name: &str, priority: i32) -> Result<(), String> {
168 #[cfg(unix)]
169 {
170 let thread = unsafe { libc::pthread_self() };
171 #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "openbsd"))]
172 let c_name = std::ffi::CString::new(name).map_err(|e| e.to_string())?;
173 #[cfg(target_os = "linux")]
174 unsafe {
175 let _ = libc::pthread_setname_np(thread, c_name.as_ptr());
176 }
177 #[cfg(any(target_os = "freebsd", target_os = "openbsd"))]
178 unsafe {
179 libc::pthread_set_name_np(thread, c_name.as_ptr());
180 }
181
182 let param = unsafe {
183 let mut p = std::mem::zeroed::<libc::sched_param>();
184 p.sched_priority = priority;
185 p
186 };
187 let rc = unsafe { libc::pthread_setschedparam(thread, RT_POLICY, ¶m) };
188 if rc != 0 {
189 return Err(format!(
190 "pthread_setschedparam({}, prio {}) failed with errno {}",
191 name, priority, rc
192 ));
193 }
194
195 let mut actual_policy = 0_i32;
196 let mut actual_param = unsafe { std::mem::zeroed::<libc::sched_param>() };
197 let rc = unsafe {
198 libc::pthread_getschedparam(thread, &mut actual_policy, &mut actual_param)
199 };
200 if rc != 0 {
201 return Err(format!(
202 "pthread_getschedparam({}) failed with errno {}",
203 name, rc
204 ));
205 }
206 if actual_policy != RT_POLICY || actual_param.sched_priority != priority {
207 return Err(format!(
208 "realtime verification failed for {}: policy {}, prio {}",
209 name, actual_policy, actual_param.sched_priority
210 ));
211 }
212 Ok(())
213 }
214 #[cfg(not(unix))]
215 {
216 let _ = name;
217 let _ = priority;
218 Err("Realtime thread priority is not supported on this platform".to_string())
219 }
220 }
221
222 fn lock_memory_pages() -> Result<(), String> {
223 #[cfg(unix)]
224 {
225 let rc = unsafe { libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) };
226 if rc == 0 {
227 Ok(())
228 } else {
229 Err(format!(
230 "mlockall(MCL_CURRENT|MCL_FUTURE) failed: {}",
231 std::io::Error::last_os_error()
232 ))
233 }
234 }
235 #[cfg(not(unix))]
236 {
237 Err("mlockall is not supported on this platform".to_string())
238 }
239 }
240
241 pub fn new(
242 driver: Arc<UnsafeMutex<B::Driver>>,
243 midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
244 rx: Receiver<Message>,
245 tx: Sender<Message>,
246 ) -> Self {
247 let cycle_frames = {
248 let d = driver.lock();
249 d.cycle_samples() as u32
250 };
251 Self {
252 driver,
253 midi_hub,
254 rx,
255 tx,
256 cycle_frames,
257 pending_midi_out_events: vec![],
258 pending_midi_out_sorted: true,
259 midi_stop: Arc::new(AtomicBool::new(false)),
260 assist_state: Arc::new((Mutex::new(AssistState::default()), Condvar::new())),
261 }
262 }
263
264 pub async fn work(mut self) {
265 crate::enable_flush_denormals_to_zero();
266 if let Err(e) = Self::lock_memory_pages() {
267 error!("{} worker memory lock not enabled: {}", B::LABEL, e);
268 }
269 if let Err(e) = Self::configure_rt_thread(B::WORKER_THREAD_NAME, RT_PRIORITY_WORKER) {
270 error!("{} worker realtime priority not enabled: {}", B::LABEL, e);
271 }
272 #[cfg(target_os = "macos")]
273 unsafe {
274 libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, 0);
275 }
276
277 #[cfg(unix)]
278 {
279 let has_fds = self.driver.lock().capture_fd().is_some()
280 && self.driver.lock().playback_fd().is_some();
281 if has_fds {
282 self.work_async().await;
283 return;
284 }
285 }
286
287 self.work_legacy().await;
288 }
289
/// Select-based worker loop, used on Unix when the driver exposes capture/playback fds.
///
/// Message handling:
/// - `TracksFinished`: flushes pending MIDI output. If no cycle is in flight, it starts one
///   driver cycle on tokio's blocking pool. A `TracksFinished` that arrives while a cycle is
///   running starts nothing new.
/// - Cycle completion: the result returns over a capacity-1 channel. On error, a
///   `Response(Err(..))` is sent first; `HWFinished` is then always sent.
/// - `Quit` or a closed command channel: requests a driver stop, awaits any in-flight cycle and
///   shuts down MIDI input. `Quit` also flushes pending MIDI output.
///
/// MIDI input runs on a separate OS thread started here.
///
/// NOTE(review): if the blocking task panics, `blocking_send` never runs. Because `cycle_tx` is
/// still alive in this function, `cycle_rx.recv()` would then never complete, and the loop (and
/// its shutdown paths) would wait forever. Consider observing the `spawn_blocking` JoinHandle.
/// NOTE(review): the cycle runs on a blocking-pool thread that does not receive the realtime
/// priority set in `work` — confirm this is intended.
#[cfg(unix)]
async fn work_async(&mut self) {
    let midi_handle = Self::start_midi_input_thread(
        self.midi_hub.clone(),
        self.tx.clone(),
        self.cycle_frames,
        self.midi_stop.clone(),
    );
    // At most one driver cycle is in flight; its result arrives on `cycle_rx`.
    let mut cycle_running = false;
    let (cycle_tx, mut cycle_rx) = tokio::sync::mpsc::channel::<Result<(), String>>(1);
    loop {
        tokio::select! {
            msg = self.rx.recv() => {
                let msg = match msg {
                    Some(m) => m,
                    None => {
                        // Engine dropped its sender: stop the driver, let an in-flight cycle
                        // finish, then tear down MIDI input (without flushing output).
                        self.driver.lock().request_stop();
                        if cycle_running {
                            let _ = cycle_rx.recv().await;
                        }
                        self.shutdown_channel_closed(midi_handle);
                        return;
                    }
                };
                match msg {
                    Message::Request(crate::message::Action::Quit) => {
                        self.driver.lock().request_stop();
                        if cycle_running {
                            let _ = cycle_rx.recv().await;
                        }
                        self.shutdown_quit(midi_handle);
                        return;
                    }
                    Message::TracksFinished => {
                        self.flush_pending_midi_out();
                        if !cycle_running {
                            cycle_running = true;
                            let driver = self.driver.clone();
                            let tx = cycle_tx.clone();
                            // Run the cycle on tokio's blocking pool so this loop keeps
                            // servicing messages meanwhile.
                            tokio::task::spawn_blocking(move || {
                                let result = driver.lock().run_cycle_for_worker();
                                let _ = tx.blocking_send(result);
                            });
                        }
                    }
                    Message::HWMidiOutEvents(mut events) => {
                        self.pending_midi_out_events.append(&mut events);
                        self.pending_midi_out_sorted = false;
                    }
                    Message::ClearHWMidiOutEvents => {
                        self.pending_midi_out_events.clear();
                        self.pending_midi_out_sorted = true;
                    }
                    _ => {}
                }
            }
            // Only polled while a cycle is in flight.
            result = cycle_rx.recv(), if cycle_running => {
                cycle_running = false;
                if let Some(Err(e)) = result {
                    error!("{} cycle error: {}", B::LABEL, e);
                    let _ = self.tx.send(Message::Response(Err(format!(
                        "{} cycle error: {}", B::LABEL, e
                    )))).await;
                }
                if let Err(e) = self.tx.send(Message::HWFinished).await {
                    error!("{} worker failed to send HWFinished: {}", B::LABEL, e);
                }
            }
        }
    }
}
364
365 async fn work_legacy(&mut self) {
366 let assist_handle =
367 Self::start_assist_thread(self.driver.clone(), self.assist_state.clone());
368 let midi_handle = Self::start_midi_input_thread(
369 self.midi_hub.clone(),
370 self.tx.clone(),
371 self.cycle_frames,
372 self.midi_stop.clone(),
373 );
374 loop {
375 let msg = match self.rx.recv().await {
376 Some(msg) => msg,
377 None => {
378 self.driver.lock().request_stop();
379 self.shutdown_midi(midi_handle);
380 Self::stop_assist_thread(&self.assist_state, assist_handle);
381 self.driver.lock().request_stop();
382 return;
383 }
384 };
385 match msg {
386 Message::Request(crate::message::Action::Quit) => {
387 self.driver.lock().request_stop();
388 self.flush_pending_midi_out();
389 self.shutdown_midi(midi_handle);
390 Self::stop_assist_thread(&self.assist_state, assist_handle);
391 self.driver.lock().request_stop();
392 return;
393 }
394 Message::TracksFinished => {
395 self.flush_pending_midi_out();
396 if let Err(e) = Self::run_assist_cycle(&self.driver, &self.assist_state) {
397 error!("{} assist cycle error: {}", B::LABEL, e);
398 let _ = self
399 .tx
400 .send(Message::Response(Err(format!(
401 "{} assist cycle error: {}",
402 B::LABEL,
403 e
404 ))))
405 .await;
406 }
407 if let Err(e) = self.tx.send(Message::HWFinished).await {
408 error!(
409 "{} worker failed to send HWFinished to engine: {}",
410 B::LABEL,
411 e
412 );
413 }
414 }
415 Message::HWMidiOutEvents(mut events) => {
416 self.pending_midi_out_events.append(&mut events);
417 self.pending_midi_out_sorted = false;
418 }
419 Message::ClearHWMidiOutEvents => {
420 self.pending_midi_out_events.clear();
421 self.pending_midi_out_sorted = true;
422 }
423 _ => {}
424 }
425 }
426 }
427
428 fn flush_pending_midi_out(&mut self) {
429 if self.pending_midi_out_events.is_empty() {
430 return;
431 }
432 if !self.pending_midi_out_sorted {
433 self.pending_midi_out_events.sort_by(|a, b| {
434 a.event
435 .frame
436 .cmp(&b.event.frame)
437 .then_with(|| a.device.cmp(&b.device))
438 });
439 self.pending_midi_out_sorted = true;
440 }
441 let midi_hub = self.midi_hub.lock();
442 midi_hub.write_events(&self.pending_midi_out_events);
443 self.pending_midi_out_events.clear();
444 }
445
446 fn shutdown_midi(&mut self, midi_handle: JoinHandle<()>) {
447 self.midi_stop.store(true, Ordering::Release);
448 {
449 let midi_hub = self.midi_hub.lock();
450 midi_hub.wake_input_waiter();
451 }
452 let _ = midi_handle.join();
453 {
454 let midi_hub = self.midi_hub.lock();
455 midi_hub.close_input_waiter();
456 }
457 }
458
459 fn shutdown_quit(&mut self, midi_handle: JoinHandle<()>) {
460 self.driver.lock().request_stop();
461 self.flush_pending_midi_out();
462 self.shutdown_midi(midi_handle);
463 self.driver.lock().request_stop();
464 }
465
466 fn shutdown_channel_closed(&mut self, midi_handle: JoinHandle<()>) {
467 self.driver.lock().request_stop();
468 self.shutdown_midi(midi_handle);
469 self.driver.lock().request_stop();
470 }
471
472 fn start_midi_input_thread(
473 midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
474 tx: Sender<Message>,
475 cycle_frames: u32,
476 stop: Arc<AtomicBool>,
477 ) -> JoinHandle<()> {
478 std::thread::spawn(move || {
479 crate::enable_flush_denormals_to_zero();
480 let mut midi_in_events = Vec::with_capacity(64);
481 while !stop.load(Ordering::Acquire) {
482 let ready_fds = {
483 let hub = midi_hub.lock();
484 hub.wait_ready_blocking()
485 };
486 if stop.load(Ordering::Acquire) {
487 break;
488 }
489 {
490 let hub = midi_hub.lock();
491 hub.read_events_for_fds(
492 ready_fds.as_deref().unwrap_or(&[]),
493 &mut midi_in_events,
494 );
495 }
496 if midi_in_events.is_empty() {
497 continue;
498 }
499 spread_hw_event_frames(&mut midi_in_events, cycle_frames);
500 let cap = midi_in_events.capacity();
501 let out = std::mem::replace(&mut midi_in_events, Vec::with_capacity(cap.max(64)));
502 if tx.blocking_send(Message::HWMidiEvents(out)).is_err() {
503 break;
504 }
505 }
506 })
507 }
508
509 fn start_assist_thread(
510 driver: Arc<UnsafeMutex<B::Driver>>,
511 assist_state: Arc<(Mutex<AssistState>, Condvar)>,
512 ) -> JoinHandle<()> {
513 let profile = Self::profile_enabled();
514 let autonomous = Self::assist_autonomous_enabled();
515 std::thread::spawn(move || {
516 crate::enable_flush_denormals_to_zero();
517 if let Err(e) = Self::configure_rt_thread(B::ASSIST_THREAD_NAME, RT_PRIORITY_ASSIST) {
518 error!("{} assist realtime priority not enabled: {}", B::LABEL, e);
519 }
520 #[cfg(target_os = "macos")]
521 unsafe {
522 libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INITIATED, 0);
523 }
524 let mut profiler = if profile {
525 let (cycle_samples, sample_rate) = {
526 let d = driver.lock();
527 (d.cycle_samples(), d.sample_rate())
528 };
529 error!(
530 "{} profile enabled: cycle_samples={} sample_rate={} expected_cps={:.1}",
531 B::LABEL,
532 cycle_samples,
533 sample_rate,
534 if cycle_samples > 0 {
535 sample_rate as f64 / cycle_samples as f64
536 } else {
537 0.0
538 }
539 );
540 Some(AssistProfiler::new())
541 } else {
542 None
543 };
544 let (lock, cvar) = &*assist_state;
545 loop {
546 let (shutdown, has_request, target, init_complete) = {
547 let st = lock.lock().expect("assist mutex poisoned");
548 (
549 st.shutdown,
550 st.request_seq > st.done_seq,
551 st.request_seq,
552 st.init_complete,
553 )
554 };
555 if shutdown {
556 break;
557 }
558 if has_request {
559 let started = Instant::now();
560 let run_error = {
561 let d = driver.lock();
562 d.run_cycle_for_worker().err().map(|e| e.to_string())
563 };
564 if let Some(p) = profiler.as_mut() {
565 p.cycle_count += 1;
566 if run_error.is_some() {
567 p.cycle_err_count += 1;
568 }
569 p.cycle_time_ns += started.elapsed().as_nanos();
570 let (cycle_samples, sample_rate) = {
571 let d = driver.lock();
572 (d.cycle_samples(), d.sample_rate())
573 };
574 p.maybe_report(cycle_samples, sample_rate, B::LABEL);
575 }
576 let mut st = lock.lock().expect("assist mutex poisoned");
577 st.done_seq = st.done_seq.max(target);
578 if run_error.is_none() {
579 st.init_complete = true;
580 }
581 st.last_error = run_error;
582 cvar.notify_all();
583 continue;
584 }
585
586 if B::ASSIST_STEP_REQUIRES_REQUEST_CYCLE && !init_complete {
587 let st = lock.lock().expect("assist mutex poisoned");
588 if st.shutdown {
589 break;
590 }
591 let wait_started = Instant::now();
592 let _guard = cvar.wait(st).expect("assist condvar failed");
593 if let Some(p) = profiler.as_mut() {
594 p.wait_count += 1;
595 p.wait_time_ns += wait_started.elapsed().as_nanos();
596 }
597 continue;
598 }
599
600 if !autonomous {
601 let st = lock.lock().expect("assist mutex poisoned");
602 if st.shutdown {
603 break;
604 }
605 let wait_started = Instant::now();
606 let _guard = cvar.wait(st).expect("assist condvar failed");
607 if let Some(p) = profiler.as_mut() {
608 p.wait_count += 1;
609 p.wait_time_ns += wait_started.elapsed().as_nanos();
610 }
611 continue;
612 }
613
614 let started = Instant::now();
615 let did_work = {
616 let d = driver.lock();
617 match d.run_assist_step_for_worker() {
618 Ok(v) => v,
619 Err(e) => {
620 if let Some(p) = profiler.as_mut() {
621 p.step_err_count += 1;
622 }
623 let mut st = lock.lock().expect("assist mutex poisoned");
624 st.last_error = Some(e.to_string());
625 cvar.notify_all();
626 false
627 }
628 }
629 };
630 if let Some(p) = profiler.as_mut() {
631 p.step_count += 1;
632 if did_work {
633 p.step_work_count += 1;
634 }
635 p.step_time_ns += started.elapsed().as_nanos();
636 let (cycle_samples, sample_rate) = {
637 let d = driver.lock();
638 (d.cycle_samples(), d.sample_rate())
639 };
640 p.maybe_report(cycle_samples, sample_rate, B::LABEL);
641 }
642 if !did_work {
643 let st = lock.lock().expect("assist mutex poisoned");
644 if st.shutdown {
645 break;
646 }
647 let wait_started = Instant::now();
648 let _guard = if autonomous {
649 cvar.wait_timeout(st, Duration::from_micros(100))
650 .expect("assist condvar failed")
651 .0
652 } else {
653 cvar.wait(st).expect("assist condvar failed")
654 };
655 if let Some(p) = profiler.as_mut() {
656 p.wait_count += 1;
657 p.wait_time_ns += wait_started.elapsed().as_nanos();
658 }
659 }
660 }
661 })
662 }
663
664 fn run_assist_cycle(
665 driver: &Arc<UnsafeMutex<B::Driver>>,
666 assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
667 ) -> Result<(), String> {
668 let autonomous =
669 Self::assist_autonomous_enabled() && B::CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS;
670 if autonomous {
671 let (lock, cvar) = &**assist_state;
672 {
673 let mut st = lock
674 .lock()
675 .map_err(|_| "assist mutex poisoned".to_string())?;
676 st.init_complete = true;
677 cvar.notify_one();
678 }
679 let result = driver.lock().run_cycle_for_worker();
680 {
681 let mut st = lock
682 .lock()
683 .map_err(|_| "assist mutex poisoned".to_string())?;
684 st.last_error = result.as_ref().err().map(|e| e.to_string());
685 cvar.notify_one();
686 }
687 return result;
688 }
689
690 let (lock, cvar) = &**assist_state;
691 let mut st = lock
692 .lock()
693 .map_err(|_| "assist mutex poisoned".to_string())?;
694 st.request_seq = st.request_seq.saturating_add(1);
695 let target = st.request_seq;
696 cvar.notify_one();
697 while st.done_seq < target && !st.shutdown {
698 st = cvar
699 .wait(st)
700 .map_err(|_| "assist condvar wait failed".to_string())?;
701 }
702 if let Some(err) = st.last_error.take() {
703 return Err(err);
704 }
705 Ok(())
706 }
707
708 fn stop_assist_thread(
709 assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
710 assist_handle: JoinHandle<()>,
711 ) {
712 let (lock, cvar) = &**assist_state;
713 if let Ok(mut st) = lock.lock() {
714 st.shutdown = true;
715 cvar.notify_all();
716 }
717 let _ = assist_handle.join();
718 }
719}
720
721fn spread_hw_event_frames(events: &mut [HwMidiEvent], frames: u32) {
722 if events.len() <= 1 || frames <= 1 {
723 return;
724 }
725 let n = events.len() as u32;
726 for (idx, event) in events.iter_mut().enumerate() {
727 let pos = idx as u32;
728 event.event.frame = ((pos as u64 * (frames - 1) as u64) / n as u64) as u32;
729 }
730}