Skip to main content

maolan_engine/workers/
hw_worker.rs

1use crate::{
2    hw::config,
3    hw::traits::{HwMidiHub, HwWorkerDriver},
4    message::{HwMidiEvent, Message},
5    mutex::UnsafeMutex,
6};
7#[cfg(unix)]
8use nix::libc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Condvar, Mutex};
11use std::thread::JoinHandle;
12use std::time::{Duration, Instant};
13use tokio::sync::mpsc::{Receiver, Sender};
14use tracing::error;
15
/// Compile-time description of a hardware backend driven by [`HwWorker`].
///
/// Supplies the concrete driver and MIDI hub types together with the labels, thread
/// names and scheduling flags that `HwWorker` uses when it spawns and runs its threads.
pub trait Backend: Send + Sync + 'static {
    /// Audio driver; shared between the worker and its assist thread.
    type Driver: HwWorkerDriver + Send + 'static;
    /// MIDI hub; shared between the worker and its MIDI input thread.
    type MidiHub: HwMidiHub + Send + 'static;

    /// Prefix used in this backend's log lines and error messages.
    const LABEL: &'static str;
    /// Name applied to the worker thread (on platforms where `configure_rt_thread` names threads).
    const WORKER_THREAD_NAME: &'static str;
    /// Name applied to the assist thread (on platforms where `configure_rt_thread` names threads).
    const ASSIST_THREAD_NAME: &'static str;
    /// Environment variable whose flag value (read via `config::env_flag`) enables
    /// autonomous assist stepping.
    const ASSIST_AUTONOMOUS_ENV: &'static str;
    /// When `true`, autonomous assist stepping is enabled regardless of the environment.
    const ASSIST_AUTONOMOUS_DEFAULT: bool = false;
    /// When autonomous stepping is enabled, run each driver cycle directly on the worker
    /// thread instead of handing it to the assist thread.
    const CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS: bool = false;
    /// When `true`, the assist thread takes no autonomous steps until `init_complete`
    /// has been set (after a successful assist-thread cycle, or by the worker before it
    /// runs a cycle itself in autonomous mode).
    const ASSIST_STEP_REQUIRES_REQUEST_CYCLE: bool = false;
}
28
/// Drives a hardware backend `B` from the engine's message channel.
///
/// `work` consumes the worker and spawns two helper threads: an assist thread that
/// runs driver cycles/steps, and a MIDI input thread that forwards hardware MIDI input.
#[derive(Debug)]
pub struct HwWorker<B: Backend> {
    /// Backend audio driver, shared with the assist thread.
    driver: Arc<UnsafeMutex<B::Driver>>,
    /// Backend MIDI hub, shared with the MIDI input thread.
    midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
    /// Incoming messages (cycle completion, MIDI output, quit requests).
    rx: Receiver<Message>,
    /// Outgoing messages (`HWFinished`, `HWMidiEvents`, error `Response`s).
    tx: Sender<Message>,
    /// Driver cycle length in frames, read once in `new`; used to spread MIDI input events.
    cycle_frames: u32,
    /// MIDI output queued by `HWMidiOutEvents`, written to the hub on `TracksFinished`/`Quit`.
    pending_midi_out_events: Vec<HwMidiEvent>,
    /// Whether `pending_midi_out_events` is already ordered by (frame, device).
    pending_midi_out_sorted: bool,
    /// Set to ask the MIDI input thread to exit its loop.
    midi_stop: Arc<AtomicBool>,
    /// Handshake state and condvar shared with the assist thread.
    assist_state: Arc<(Mutex<AssistState>, Condvar)>,
}
41
42impl<B: Backend> Drop for HwWorker<B> {
43    fn drop(&mut self) {
44        self.driver.lock().request_stop();
45        self.midi_stop.store(true, Ordering::Release);
46        {
47            let midi_hub = self.midi_hub.lock();
48            midi_hub.wake_input_waiter();
49            midi_hub.close_input_waiter();
50        }
51        {
52            let (lock, cvar) = &*self.assist_state;
53            if let Ok(mut st) = lock.lock() {
54                st.shutdown = true;
55                cvar.notify_one();
56            }
57        }
58    }
59}
60
/// State shared between the worker and its assist thread.
///
/// Guarded by the mutex in `HwWorker::assist_state`; changes are signalled on the
/// condvar stored next to it.
#[derive(Debug, Default)]
struct AssistState {
    /// Set to make the assist thread leave its loop.
    shutdown: bool,
    /// Sequence number of the most recent cycle requested by the worker.
    request_seq: u64,
    /// Highest request sequence number the assist thread has finished serving.
    done_seq: u64,
    /// Set after a successful assist-thread cycle, or by the worker before it runs a
    /// cycle itself in autonomous mode; gates steps when
    /// `Backend::ASSIST_STEP_REQUIRES_REQUEST_CYCLE` is set.
    init_complete: bool,
    /// Error text from the most recent cycle or assist step, if any.
    last_error: Option<String>,
}
69
/// Scheduling policy requested for the worker and assist threads.
#[cfg(unix)]
const RT_POLICY: i32 = libc::SCHED_FIFO;
/// Realtime priority requested for the worker thread (numerically above the assist thread's).
const RT_PRIORITY_WORKER: i32 = 18;
/// Realtime priority requested for the assist thread.
const RT_PRIORITY_ASSIST: i32 = 12;
/// How often `AssistProfiler` logs a report when profiling is enabled.
const PROFILE_INTERVAL: Duration = Duration::from_secs(1);
75
/// Per-window counters collected by the assist thread when profiling is enabled.
///
/// `maybe_report` logs and resets them once `report_at` passes. Time totals are in
/// nanoseconds.
#[derive(Debug)]
struct AssistProfiler {
    /// When the next report is due.
    report_at: Instant,
    /// Driver cycles run by the assist thread in this window.
    cycle_count: u64,
    /// Of those cycles, how many returned an error.
    cycle_err_count: u64,
    /// Total time spent running driver cycles.
    cycle_time_ns: u128,
    /// Autonomous assist steps taken.
    step_count: u64,
    /// Steps that reported doing work.
    step_work_count: u64,
    /// Steps that returned an error.
    step_err_count: u64,
    /// Total time spent in assist steps.
    step_time_ns: u128,
    /// Condvar waits performed.
    wait_count: u64,
    /// Total time spent in those waits.
    wait_time_ns: u128,
}
89
90impl AssistProfiler {
91    fn new() -> Self {
92        Self {
93            report_at: Instant::now() + PROFILE_INTERVAL,
94            cycle_count: 0,
95            cycle_err_count: 0,
96            cycle_time_ns: 0,
97            step_count: 0,
98            step_work_count: 0,
99            step_err_count: 0,
100            step_time_ns: 0,
101            wait_count: 0,
102            wait_time_ns: 0,
103        }
104    }
105
106    fn maybe_report(&mut self, cycle_samples: usize, sample_rate: i32, label: &str) {
107        let now = Instant::now();
108        if now < self.report_at {
109            return;
110        }
111        let cycle_avg_us = if self.cycle_count > 0 {
112            (self.cycle_time_ns / self.cycle_count as u128) as f64 / 1_000.0
113        } else {
114            0.0
115        };
116        let step_avg_us = if self.step_count > 0 {
117            (self.step_time_ns / self.step_count as u128) as f64 / 1_000.0
118        } else {
119            0.0
120        };
121        let wait_avg_us = if self.wait_count > 0 {
122            (self.wait_time_ns / self.wait_count as u128) as f64 / 1_000.0
123        } else {
124            0.0
125        };
126        let expected_cycles_per_sec = if cycle_samples > 0 && sample_rate > 0 {
127            sample_rate as f64 / cycle_samples as f64
128        } else {
129            0.0
130        };
131        error!(
132            "{} profile: expected_cps={:.1} cycles={} cycle_err={} cycle_avg_us={:.1} steps={} steps_work={} step_err={} step_avg_us={:.1} waits={} wait_avg_us={:.1}",
133            label,
134            expected_cycles_per_sec,
135            self.cycle_count,
136            self.cycle_err_count,
137            cycle_avg_us,
138            self.step_count,
139            self.step_work_count,
140            self.step_err_count,
141            step_avg_us,
142            self.wait_count,
143            wait_avg_us
144        );
145        self.report_at = now + PROFILE_INTERVAL;
146        self.cycle_count = 0;
147        self.cycle_err_count = 0;
148        self.cycle_time_ns = 0;
149        self.step_count = 0;
150        self.step_work_count = 0;
151        self.step_err_count = 0;
152        self.step_time_ns = 0;
153        self.wait_count = 0;
154        self.wait_time_ns = 0;
155    }
156}
157
158impl<B: Backend> HwWorker<B> {
159    fn profile_enabled() -> bool {
160        config::env_flag(config::HW_PROFILE_ENV)
161    }
162
163    fn assist_autonomous_enabled() -> bool {
164        B::ASSIST_AUTONOMOUS_DEFAULT || config::env_flag(B::ASSIST_AUTONOMOUS_ENV)
165    }
166
167    fn configure_rt_thread(name: &str, priority: i32) -> Result<(), String> {
168        #[cfg(unix)]
169        {
170            let thread = unsafe { libc::pthread_self() };
171            #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "openbsd"))]
172            let c_name = std::ffi::CString::new(name).map_err(|e| e.to_string())?;
173            #[cfg(target_os = "linux")]
174            unsafe {
175                let _ = libc::pthread_setname_np(thread, c_name.as_ptr());
176            }
177            #[cfg(any(target_os = "freebsd", target_os = "openbsd"))]
178            unsafe {
179                libc::pthread_set_name_np(thread, c_name.as_ptr());
180            }
181
182            let param = unsafe {
183                let mut p = std::mem::zeroed::<libc::sched_param>();
184                p.sched_priority = priority;
185                p
186            };
187            let rc = unsafe { libc::pthread_setschedparam(thread, RT_POLICY, &param) };
188            if rc != 0 {
189                return Err(format!(
190                    "pthread_setschedparam({}, prio {}) failed with errno {}",
191                    name, priority, rc
192                ));
193            }
194
195            let mut actual_policy = 0_i32;
196            let mut actual_param = unsafe { std::mem::zeroed::<libc::sched_param>() };
197            let rc = unsafe {
198                libc::pthread_getschedparam(thread, &mut actual_policy, &mut actual_param)
199            };
200            if rc != 0 {
201                return Err(format!(
202                    "pthread_getschedparam({}) failed with errno {}",
203                    name, rc
204                ));
205            }
206            if actual_policy != RT_POLICY || actual_param.sched_priority != priority {
207                return Err(format!(
208                    "realtime verification failed for {}: policy {}, prio {}",
209                    name, actual_policy, actual_param.sched_priority
210                ));
211            }
212            Ok(())
213        }
214        #[cfg(not(unix))]
215        {
216            let _ = name;
217            let _ = priority;
218            Err("Realtime thread priority is not supported on this platform".to_string())
219        }
220    }
221
222    fn lock_memory_pages() -> Result<(), String> {
223        #[cfg(unix)]
224        {
225            let rc = unsafe { libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) };
226            if rc == 0 {
227                Ok(())
228            } else {
229                Err(format!(
230                    "mlockall(MCL_CURRENT|MCL_FUTURE) failed: {}",
231                    std::io::Error::last_os_error()
232                ))
233            }
234        }
235        #[cfg(not(unix))]
236        {
237            Err("mlockall is not supported on this platform".to_string())
238        }
239    }
240
241    pub fn new(
242        driver: Arc<UnsafeMutex<B::Driver>>,
243        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
244        rx: Receiver<Message>,
245        tx: Sender<Message>,
246    ) -> Self {
247        let cycle_frames = {
248            let d = driver.lock();
249            d.cycle_samples() as u32
250        };
251        Self {
252            driver,
253            midi_hub,
254            rx,
255            tx,
256            cycle_frames,
257            pending_midi_out_events: vec![],
258            pending_midi_out_sorted: true,
259            midi_stop: Arc::new(AtomicBool::new(false)),
260            assist_state: Arc::new((Mutex::new(AssistState::default()), Condvar::new())),
261        }
262    }
263
264    pub async fn work(mut self) {
265        crate::enable_flush_denormals_to_zero();
266        if let Err(e) = Self::lock_memory_pages() {
267            error!("{} worker memory lock not enabled: {}", B::LABEL, e);
268        }
269        if let Err(e) = Self::configure_rt_thread(B::WORKER_THREAD_NAME, RT_PRIORITY_WORKER) {
270            error!("{} worker realtime priority not enabled: {}", B::LABEL, e);
271        }
272        #[cfg(target_os = "macos")]
273        unsafe {
274            libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, 0);
275        }
276        let assist_handle =
277            Self::start_assist_thread(self.driver.clone(), self.assist_state.clone());
278        let midi_handle = Self::start_midi_input_thread(
279            self.midi_hub.clone(),
280            self.tx.clone(),
281            self.cycle_frames,
282            self.midi_stop.clone(),
283        );
284        loop {
285            let msg = match self.rx.recv().await {
286                Some(msg) => msg,
287                None => {
288                    self.driver.lock().request_stop();
289                    self.midi_stop.store(true, Ordering::Release);
290                    {
291                        let midi_hub = self.midi_hub.lock();
292                        midi_hub.wake_input_waiter();
293                    }
294                    let _ = midi_handle.join();
295                    {
296                        let midi_hub = self.midi_hub.lock();
297                        midi_hub.close_input_waiter();
298                    }
299                    Self::stop_assist_thread(&self.assist_state, assist_handle);
300                    return;
301                }
302            };
303            match msg {
304                Message::Request(crate::message::Action::Quit) => {
305                    self.driver.lock().request_stop();
306                    if !self.pending_midi_out_events.is_empty() {
307                        if !self.pending_midi_out_sorted {
308                            self.pending_midi_out_events.sort_by(|a, b| {
309                                a.event
310                                    .frame
311                                    .cmp(&b.event.frame)
312                                    .then_with(|| a.device.cmp(&b.device))
313                            });
314                            self.pending_midi_out_sorted = true;
315                        }
316                        let midi_hub = self.midi_hub.lock();
317                        midi_hub.write_events(&self.pending_midi_out_events);
318                        self.pending_midi_out_events.clear();
319                    }
320                    self.midi_stop.store(true, Ordering::Release);
321                    {
322                        let midi_hub = self.midi_hub.lock();
323                        midi_hub.wake_input_waiter();
324                    }
325                    let _ = midi_handle.join();
326                    {
327                        let midi_hub = self.midi_hub.lock();
328                        midi_hub.close_input_waiter();
329                    }
330                    Self::stop_assist_thread(&self.assist_state, assist_handle);
331                    return;
332                }
333                Message::TracksFinished => {
334                    {
335                        if !self.pending_midi_out_events.is_empty() {
336                            if !self.pending_midi_out_sorted {
337                                self.pending_midi_out_events.sort_by(|a, b| {
338                                    a.event
339                                        .frame
340                                        .cmp(&b.event.frame)
341                                        .then_with(|| a.device.cmp(&b.device))
342                                });
343                                self.pending_midi_out_sorted = true;
344                            }
345                            let midi_hub = self.midi_hub.lock();
346                            midi_hub.write_events(&self.pending_midi_out_events);
347                            self.pending_midi_out_events.clear();
348                        }
349                    }
350                    if let Err(e) = Self::run_assist_cycle(&self.driver, &self.assist_state) {
351                        error!("{} assist cycle error: {}", B::LABEL, e);
352                        let _ = self
353                            .tx
354                            .send(Message::Response(Err(format!(
355                                "{} assist cycle error: {}",
356                                B::LABEL,
357                                e
358                            ))))
359                            .await;
360                    }
361                    if let Err(e) = self.tx.send(Message::HWFinished).await {
362                        error!(
363                            "{} worker failed to send HWFinished to engine: {}",
364                            B::LABEL,
365                            e
366                        );
367                    }
368                }
369                Message::HWMidiOutEvents(mut events) => {
370                    self.pending_midi_out_events.append(&mut events);
371                    self.pending_midi_out_sorted = false;
372                }
373                Message::ClearHWMidiOutEvents => {
374                    self.pending_midi_out_events.clear();
375                    self.pending_midi_out_sorted = true;
376                }
377                _ => {}
378            }
379        }
380    }
381
382    fn start_midi_input_thread(
383        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
384        tx: Sender<Message>,
385        cycle_frames: u32,
386        stop: Arc<AtomicBool>,
387    ) -> JoinHandle<()> {
388        std::thread::spawn(move || {
389            crate::enable_flush_denormals_to_zero();
390            let mut midi_in_events = Vec::with_capacity(64);
391            while !stop.load(Ordering::Acquire) {
392                let ready_fds = {
393                    let hub = midi_hub.lock();
394                    hub.wait_ready_blocking()
395                };
396                if stop.load(Ordering::Acquire) {
397                    break;
398                }
399                {
400                    let hub = midi_hub.lock();
401                    hub.read_events_for_fds(
402                        ready_fds.as_deref().unwrap_or(&[]),
403                        &mut midi_in_events,
404                    );
405                }
406                if midi_in_events.is_empty() {
407                    continue;
408                }
409                spread_hw_event_frames(&mut midi_in_events, cycle_frames);
410                let cap = midi_in_events.capacity();
411                let out = std::mem::replace(&mut midi_in_events, Vec::with_capacity(cap.max(64)));
412                if tx.blocking_send(Message::HWMidiEvents(out)).is_err() {
413                    break;
414                }
415            }
416        })
417    }
418
419    fn start_assist_thread(
420        driver: Arc<UnsafeMutex<B::Driver>>,
421        assist_state: Arc<(Mutex<AssistState>, Condvar)>,
422    ) -> JoinHandle<()> {
423        let profile = Self::profile_enabled();
424        let autonomous = Self::assist_autonomous_enabled();
425        std::thread::spawn(move || {
426            crate::enable_flush_denormals_to_zero();
427            if let Err(e) = Self::configure_rt_thread(B::ASSIST_THREAD_NAME, RT_PRIORITY_ASSIST) {
428                error!("{} assist realtime priority not enabled: {}", B::LABEL, e);
429            }
430            #[cfg(target_os = "macos")]
431            unsafe {
432                libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INITIATED, 0);
433            }
434            let mut profiler = if profile {
435                let (cycle_samples, sample_rate) = {
436                    let d = driver.lock();
437                    (d.cycle_samples(), d.sample_rate())
438                };
439                error!(
440                    "{} profile enabled: cycle_samples={} sample_rate={} expected_cps={:.1}",
441                    B::LABEL,
442                    cycle_samples,
443                    sample_rate,
444                    if cycle_samples > 0 {
445                        sample_rate as f64 / cycle_samples as f64
446                    } else {
447                        0.0
448                    }
449                );
450                Some(AssistProfiler::new())
451            } else {
452                None
453            };
454            let (lock, cvar) = &*assist_state;
455            loop {
456                let (shutdown, has_request, target, init_complete) = {
457                    let st = lock.lock().expect("assist mutex poisoned");
458                    (
459                        st.shutdown,
460                        st.request_seq > st.done_seq,
461                        st.request_seq,
462                        st.init_complete,
463                    )
464                };
465                if shutdown {
466                    break;
467                }
468                if has_request {
469                    let started = Instant::now();
470                    let run_error = {
471                        let d = driver.lock();
472                        d.run_cycle_for_worker().err().map(|e| e.to_string())
473                    };
474                    if let Some(p) = profiler.as_mut() {
475                        p.cycle_count += 1;
476                        if run_error.is_some() {
477                            p.cycle_err_count += 1;
478                        }
479                        p.cycle_time_ns += started.elapsed().as_nanos();
480                        let (cycle_samples, sample_rate) = {
481                            let d = driver.lock();
482                            (d.cycle_samples(), d.sample_rate())
483                        };
484                        p.maybe_report(cycle_samples, sample_rate, B::LABEL);
485                    }
486                    let mut st = lock.lock().expect("assist mutex poisoned");
487                    st.done_seq = st.done_seq.max(target);
488                    if run_error.is_none() {
489                        st.init_complete = true;
490                    }
491                    st.last_error = run_error;
492                    cvar.notify_all();
493                    continue;
494                }
495
496                if B::ASSIST_STEP_REQUIRES_REQUEST_CYCLE && !init_complete {
497                    let st = lock.lock().expect("assist mutex poisoned");
498                    if st.shutdown {
499                        break;
500                    }
501                    let wait_started = Instant::now();
502                    let _guard = cvar.wait(st).expect("assist condvar failed");
503                    if let Some(p) = profiler.as_mut() {
504                        p.wait_count += 1;
505                        p.wait_time_ns += wait_started.elapsed().as_nanos();
506                    }
507                    continue;
508                }
509
510                if !autonomous {
511                    let st = lock.lock().expect("assist mutex poisoned");
512                    if st.shutdown {
513                        break;
514                    }
515                    let wait_started = Instant::now();
516                    let _guard = cvar.wait(st).expect("assist condvar failed");
517                    if let Some(p) = profiler.as_mut() {
518                        p.wait_count += 1;
519                        p.wait_time_ns += wait_started.elapsed().as_nanos();
520                    }
521                    continue;
522                }
523
524                let started = Instant::now();
525                let did_work = {
526                    let d = driver.lock();
527                    match d.run_assist_step_for_worker() {
528                        Ok(v) => v,
529                        Err(e) => {
530                            if let Some(p) = profiler.as_mut() {
531                                p.step_err_count += 1;
532                            }
533                            let mut st = lock.lock().expect("assist mutex poisoned");
534                            st.last_error = Some(e.to_string());
535                            cvar.notify_all();
536                            false
537                        }
538                    }
539                };
540                if let Some(p) = profiler.as_mut() {
541                    p.step_count += 1;
542                    if did_work {
543                        p.step_work_count += 1;
544                    }
545                    p.step_time_ns += started.elapsed().as_nanos();
546                    let (cycle_samples, sample_rate) = {
547                        let d = driver.lock();
548                        (d.cycle_samples(), d.sample_rate())
549                    };
550                    p.maybe_report(cycle_samples, sample_rate, B::LABEL);
551                }
552                if !did_work {
553                    let st = lock.lock().expect("assist mutex poisoned");
554                    if st.shutdown {
555                        break;
556                    }
557                    let wait_started = Instant::now();
558                    let _guard = if autonomous {
559                        cvar.wait_timeout(st, Duration::from_micros(100))
560                            .expect("assist condvar failed")
561                            .0
562                    } else {
563                        cvar.wait(st).expect("assist condvar failed")
564                    };
565                    if let Some(p) = profiler.as_mut() {
566                        p.wait_count += 1;
567                        p.wait_time_ns += wait_started.elapsed().as_nanos();
568                    }
569                }
570            }
571        })
572    }
573
574    fn run_assist_cycle(
575        driver: &Arc<UnsafeMutex<B::Driver>>,
576        assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
577    ) -> Result<(), String> {
578        let autonomous =
579            Self::assist_autonomous_enabled() && B::CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS;
580        if autonomous {
581            let (lock, cvar) = &**assist_state;
582            {
583                let mut st = lock
584                    .lock()
585                    .map_err(|_| "assist mutex poisoned".to_string())?;
586                st.init_complete = true;
587                cvar.notify_one();
588            }
589            let result = driver.lock().run_cycle_for_worker();
590            {
591                let mut st = lock
592                    .lock()
593                    .map_err(|_| "assist mutex poisoned".to_string())?;
594                st.last_error = result.as_ref().err().map(|e| e.to_string());
595                cvar.notify_one();
596            }
597            return result;
598        }
599
600        let (lock, cvar) = &**assist_state;
601        let mut st = lock
602            .lock()
603            .map_err(|_| "assist mutex poisoned".to_string())?;
604        st.request_seq = st.request_seq.saturating_add(1);
605        let target = st.request_seq;
606        cvar.notify_one();
607        while st.done_seq < target && !st.shutdown {
608            st = cvar
609                .wait(st)
610                .map_err(|_| "assist condvar wait failed".to_string())?;
611        }
612        if let Some(err) = st.last_error.take() {
613            return Err(err);
614        }
615        Ok(())
616    }
617
618    fn stop_assist_thread(
619        assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
620        assist_handle: JoinHandle<()>,
621    ) {
622        let (lock, cvar) = &**assist_state;
623        if let Ok(mut st) = lock.lock() {
624            st.shutdown = true;
625            cvar.notify_all();
626        }
627        let _ = assist_handle.join();
628    }
629}
630
631fn spread_hw_event_frames(events: &mut [HwMidiEvent], frames: u32) {
632    if events.len() <= 1 || frames <= 1 {
633        return;
634    }
635    let n = events.len() as u32;
636    for (idx, event) in events.iter_mut().enumerate() {
637        let pos = idx as u32;
638        event.event.frame = ((pos as u64 * (frames - 1) as u64) / n as u64) as u32;
639    }
640}