// maolan_engine/workers/hw_worker.rs

1use crate::{
2    hw::config,
3    hw::traits::{HwMidiHub, HwWorkerDriver},
4    message::{HwMidiEvent, Message},
5    mutex::UnsafeMutex,
6};
7#[cfg(unix)]
8use nix::libc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Condvar, Mutex};
11use std::thread::JoinHandle;
12use std::time::{Duration, Instant};
13use tokio::sync::mpsc::{Receiver, Sender};
14use tracing::error;
15
16pub trait Backend: Send + Sync + 'static {
17    type Driver: HwWorkerDriver + Send + 'static;
18    type MidiHub: HwMidiHub + Send + 'static;
19
20    const LABEL: &'static str;
21    const WORKER_THREAD_NAME: &'static str;
22    const ASSIST_THREAD_NAME: &'static str;
23    const ASSIST_AUTONOMOUS_ENV: &'static str;
24    const ASSIST_AUTONOMOUS_DEFAULT: bool = false;
25    const CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS: bool = false;
26    const ASSIST_STEP_REQUIRES_REQUEST_CYCLE: bool = false;
27}
28
29#[derive(Debug)]
30pub struct HwWorker<B: Backend> {
31    driver: Arc<UnsafeMutex<B::Driver>>,
32    midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
33    rx: Receiver<Message>,
34    tx: Sender<Message>,
35    cycle_frames: u32,
36    pending_midi_out_events: Vec<HwMidiEvent>,
37    pending_midi_out_sorted: bool,
38}
39
/// State shared between the worker and the assist thread. It lives behind a
/// `Mutex` and is paired with a `Condvar` used for wakeups.
#[derive(Debug, Default)]
struct AssistState {
    /// Asks the assist thread to leave its loop.
    shutdown: bool,
    /// Sequence number of the most recent cycle requested by the worker.
    request_seq: u64,
    /// Highest request sequence number the assist thread has completed.
    done_seq: u64,
    /// Set once a requested cycle succeeds, or when the worker runs cycles itself.
    init_complete: bool,
    /// Error text from the most recent cycle or assist step, if any.
    last_error: Option<String>,
}
48
49#[cfg(unix)]
50const RT_POLICY: i32 = libc::SCHED_FIFO;
51const RT_PRIORITY_WORKER: i32 = 18;
52const RT_PRIORITY_ASSIST: i32 = 12;
53const PROFILE_INTERVAL: Duration = Duration::from_secs(1);
54
/// Counters gathered by the assist thread between two profile reports.
#[derive(Debug)]
struct AssistProfiler {
    /// Earliest instant at which the next report may be emitted.
    report_at: Instant,
    /// Requested cycles run, how many failed, and their total duration.
    cycle_count: u64,
    cycle_err_count: u64,
    cycle_time_ns: u128,
    /// Autonomous steps run, how many did work, how many failed, and their total duration.
    step_count: u64,
    step_work_count: u64,
    step_err_count: u64,
    step_time_ns: u128,
    /// Idle condvar waits and their total duration.
    wait_count: u64,
    wait_time_ns: u128,
}
68
69impl AssistProfiler {
70    fn new() -> Self {
71        Self {
72            report_at: Instant::now() + PROFILE_INTERVAL,
73            cycle_count: 0,
74            cycle_err_count: 0,
75            cycle_time_ns: 0,
76            step_count: 0,
77            step_work_count: 0,
78            step_err_count: 0,
79            step_time_ns: 0,
80            wait_count: 0,
81            wait_time_ns: 0,
82        }
83    }
84
85    fn maybe_report(&mut self, cycle_samples: usize, sample_rate: i32, label: &str) {
86        let now = Instant::now();
87        if now < self.report_at {
88            return;
89        }
90        let cycle_avg_us = if self.cycle_count > 0 {
91            (self.cycle_time_ns / self.cycle_count as u128) as f64 / 1_000.0
92        } else {
93            0.0
94        };
95        let step_avg_us = if self.step_count > 0 {
96            (self.step_time_ns / self.step_count as u128) as f64 / 1_000.0
97        } else {
98            0.0
99        };
100        let wait_avg_us = if self.wait_count > 0 {
101            (self.wait_time_ns / self.wait_count as u128) as f64 / 1_000.0
102        } else {
103            0.0
104        };
105        let expected_cycles_per_sec = if cycle_samples > 0 && sample_rate > 0 {
106            sample_rate as f64 / cycle_samples as f64
107        } else {
108            0.0
109        };
110        error!(
111            "{} profile: expected_cps={:.1} cycles={} cycle_err={} cycle_avg_us={:.1} steps={} steps_work={} step_err={} step_avg_us={:.1} waits={} wait_avg_us={:.1}",
112            label,
113            expected_cycles_per_sec,
114            self.cycle_count,
115            self.cycle_err_count,
116            cycle_avg_us,
117            self.step_count,
118            self.step_work_count,
119            self.step_err_count,
120            step_avg_us,
121            self.wait_count,
122            wait_avg_us
123        );
124        self.report_at = now + PROFILE_INTERVAL;
125        self.cycle_count = 0;
126        self.cycle_err_count = 0;
127        self.cycle_time_ns = 0;
128        self.step_count = 0;
129        self.step_work_count = 0;
130        self.step_err_count = 0;
131        self.step_time_ns = 0;
132        self.wait_count = 0;
133        self.wait_time_ns = 0;
134    }
135}
136
137impl<B: Backend> HwWorker<B> {
138    fn profile_enabled() -> bool {
139        config::env_flag(config::HW_PROFILE_ENV)
140    }
141
142    fn assist_autonomous_enabled() -> bool {
143        B::ASSIST_AUTONOMOUS_DEFAULT || config::env_flag(B::ASSIST_AUTONOMOUS_ENV)
144    }
145
146    fn configure_rt_thread(name: &str, priority: i32) -> Result<(), String> {
147        #[cfg(unix)]
148        {
149            let thread = unsafe { libc::pthread_self() };
150            #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "openbsd"))]
151            let c_name = std::ffi::CString::new(name).map_err(|e| e.to_string())?;
152            #[cfg(target_os = "linux")]
153            unsafe {
154                let _ = libc::pthread_setname_np(thread, c_name.as_ptr());
155            }
156            #[cfg(any(target_os = "freebsd", target_os = "openbsd"))]
157            unsafe {
158                libc::pthread_set_name_np(thread, c_name.as_ptr());
159            }
160
161            let param = unsafe {
162                let mut p = std::mem::zeroed::<libc::sched_param>();
163                p.sched_priority = priority;
164                p
165            };
166            let rc = unsafe { libc::pthread_setschedparam(thread, RT_POLICY, &param) };
167            if rc != 0 {
168                return Err(format!(
169                    "pthread_setschedparam({}, prio {}) failed with errno {}",
170                    name, priority, rc
171                ));
172            }
173
174            let mut actual_policy = 0_i32;
175            let mut actual_param = unsafe { std::mem::zeroed::<libc::sched_param>() };
176            let rc = unsafe {
177                libc::pthread_getschedparam(thread, &mut actual_policy, &mut actual_param)
178            };
179            if rc != 0 {
180                return Err(format!(
181                    "pthread_getschedparam({}) failed with errno {}",
182                    name, rc
183                ));
184            }
185            if actual_policy != RT_POLICY || actual_param.sched_priority != priority {
186                return Err(format!(
187                    "realtime verification failed for {}: policy {}, prio {}",
188                    name, actual_policy, actual_param.sched_priority
189                ));
190            }
191            Ok(())
192        }
193        #[cfg(not(unix))]
194        {
195            let _ = name;
196            let _ = priority;
197            Err("Realtime thread priority is not supported on this platform".to_string())
198        }
199    }
200
201    fn lock_memory_pages() -> Result<(), String> {
202        #[cfg(unix)]
203        {
204            let rc = unsafe { libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) };
205            if rc == 0 {
206                Ok(())
207            } else {
208                Err(format!(
209                    "mlockall(MCL_CURRENT|MCL_FUTURE) failed: {}",
210                    std::io::Error::last_os_error()
211                ))
212            }
213        }
214        #[cfg(not(unix))]
215        {
216            Err("mlockall is not supported on this platform".to_string())
217        }
218    }
219
220    pub fn new(
221        driver: Arc<UnsafeMutex<B::Driver>>,
222        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
223        rx: Receiver<Message>,
224        tx: Sender<Message>,
225    ) -> Self {
226        let cycle_frames = {
227            let d = driver.lock();
228            d.cycle_samples() as u32
229        };
230        Self {
231            driver,
232            midi_hub,
233            rx,
234            tx,
235            cycle_frames,
236            pending_midi_out_events: vec![],
237            pending_midi_out_sorted: true,
238        }
239    }
240
241    pub async fn work(mut self) {
242        crate::enable_flush_denormals_to_zero();
243        if let Err(e) = Self::lock_memory_pages() {
244            error!("{} worker memory lock not enabled: {}", B::LABEL, e);
245        }
246        if let Err(e) = Self::configure_rt_thread(B::WORKER_THREAD_NAME, RT_PRIORITY_WORKER) {
247            error!("{} worker realtime priority not enabled: {}", B::LABEL, e);
248        }
249        #[cfg(target_os = "macos")]
250        unsafe {
251            libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, 0);
252        }
253        let assist_state = Arc::new((Mutex::new(AssistState::default()), Condvar::new()));
254        let assist_handle = Self::start_assist_thread(self.driver.clone(), assist_state.clone());
255        let midi_stop = Arc::new(AtomicBool::new(false));
256        let midi_handle = Self::start_midi_input_thread(
257            self.midi_hub.clone(),
258            self.tx.clone(),
259            self.cycle_frames,
260            midi_stop.clone(),
261        );
262        loop {
263            match self.rx.recv().await {
264                Some(msg) => match msg {
265                    Message::Request(crate::message::Action::Quit) => {
266                        self.driver.lock().request_stop();
267                        if !self.pending_midi_out_events.is_empty() {
268                            if !self.pending_midi_out_sorted {
269                                self.pending_midi_out_events.sort_by(|a, b| {
270                                    a.event
271                                        .frame
272                                        .cmp(&b.event.frame)
273                                        .then_with(|| a.device.cmp(&b.device))
274                                });
275                                self.pending_midi_out_sorted = true;
276                            }
277                            let midi_hub = self.midi_hub.lock();
278                            midi_hub.write_events(&self.pending_midi_out_events);
279                            self.pending_midi_out_events.clear();
280                        }
281                        midi_stop.store(true, Ordering::Release);
282                        {
283                            let midi_hub = self.midi_hub.lock();
284                            midi_hub.wake_input_waiter();
285                        }
286                        let _ = midi_handle.join();
287                        Self::stop_assist_thread(&assist_state, assist_handle);
288                        return;
289                    }
290                    Message::TracksFinished => {
291                        {
292                            if !self.pending_midi_out_events.is_empty() {
293                                if !self.pending_midi_out_sorted {
294                                    self.pending_midi_out_events.sort_by(|a, b| {
295                                        a.event
296                                            .frame
297                                            .cmp(&b.event.frame)
298                                            .then_with(|| a.device.cmp(&b.device))
299                                    });
300                                    self.pending_midi_out_sorted = true;
301                                }
302                                let midi_hub = self.midi_hub.lock();
303                                midi_hub.write_events(&self.pending_midi_out_events);
304                                self.pending_midi_out_events.clear();
305                            }
306                        }
307                        if let Err(e) = Self::run_assist_cycle(&self.driver, &assist_state) {
308                            error!("{} assist cycle error: {}", B::LABEL, e);
309                            let _ = self
310                                .tx
311                                .send(Message::Response(Err(format!(
312                                    "{} assist cycle error: {}",
313                                    B::LABEL,
314                                    e
315                                ))))
316                                .await;
317                        }
318                        if let Err(e) = self.tx.send(Message::HWFinished).await {
319                            error!(
320                                "{} worker failed to send HWFinished to engine: {}",
321                                B::LABEL,
322                                e
323                            );
324                        }
325                    }
326                    Message::HWMidiOutEvents(mut events) => {
327                        self.pending_midi_out_events.append(&mut events);
328                        self.pending_midi_out_sorted = false;
329                    }
330                    Message::ClearHWMidiOutEvents => {
331                        self.pending_midi_out_events.clear();
332                        self.pending_midi_out_sorted = true;
333                    }
334                    _ => {}
335                },
336                None => {
337                    self.driver.lock().request_stop();
338                    midi_stop.store(true, Ordering::Release);
339                    {
340                        let midi_hub = self.midi_hub.lock();
341                        midi_hub.wake_input_waiter();
342                    }
343                    let _ = midi_handle.join();
344                    Self::stop_assist_thread(&assist_state, assist_handle);
345                    return;
346                }
347            }
348        }
349    }
350
351    fn start_midi_input_thread(
352        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
353        tx: Sender<Message>,
354        cycle_frames: u32,
355        stop: Arc<AtomicBool>,
356    ) -> JoinHandle<()> {
357        std::thread::spawn(move || {
358            crate::enable_flush_denormals_to_zero();
359            let mut midi_in_events = Vec::with_capacity(64);
360            while !stop.load(Ordering::Acquire) {
361                {
362                    let hub = midi_hub.lock();
363                    hub.read_events_blocking_into(&mut midi_in_events);
364                }
365                if midi_in_events.is_empty() {
366                    continue;
367                }
368                spread_hw_event_frames(&mut midi_in_events, cycle_frames);
369                let cap = midi_in_events.capacity();
370                let out = std::mem::replace(&mut midi_in_events, Vec::with_capacity(cap.max(64)));
371                if tx.blocking_send(Message::HWMidiEvents(out)).is_err() {
372                    break;
373                }
374            }
375        })
376    }
377
378    fn start_assist_thread(
379        driver: Arc<UnsafeMutex<B::Driver>>,
380        assist_state: Arc<(Mutex<AssistState>, Condvar)>,
381    ) -> JoinHandle<()> {
382        let profile = Self::profile_enabled();
383        let autonomous = Self::assist_autonomous_enabled();
384        std::thread::spawn(move || {
385            crate::enable_flush_denormals_to_zero();
386            if let Err(e) = Self::configure_rt_thread(B::ASSIST_THREAD_NAME, RT_PRIORITY_ASSIST) {
387                error!("{} assist realtime priority not enabled: {}", B::LABEL, e);
388            }
389            #[cfg(target_os = "macos")]
390            unsafe {
391                libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INITIATED, 0);
392            }
393            let mut profiler = if profile {
394                let (cycle_samples, sample_rate) = {
395                    let d = driver.lock();
396                    (d.cycle_samples(), d.sample_rate())
397                };
398                error!(
399                    "{} profile enabled: cycle_samples={} sample_rate={} expected_cps={:.1}",
400                    B::LABEL,
401                    cycle_samples,
402                    sample_rate,
403                    if cycle_samples > 0 {
404                        sample_rate as f64 / cycle_samples as f64
405                    } else {
406                        0.0
407                    }
408                );
409                Some(AssistProfiler::new())
410            } else {
411                None
412            };
413            let (lock, cvar) = &*assist_state;
414            loop {
415                let (shutdown, has_request, target, init_complete) = {
416                    let st = lock.lock().expect("assist mutex poisoned");
417                    (
418                        st.shutdown,
419                        st.request_seq > st.done_seq,
420                        st.request_seq,
421                        st.init_complete,
422                    )
423                };
424                if shutdown {
425                    break;
426                }
427                if has_request {
428                    let started = Instant::now();
429                    let run_error = {
430                        let d = driver.lock();
431                        d.run_cycle_for_worker().err().map(|e| e.to_string())
432                    };
433                    if let Some(p) = profiler.as_mut() {
434                        p.cycle_count += 1;
435                        if run_error.is_some() {
436                            p.cycle_err_count += 1;
437                        }
438                        p.cycle_time_ns += started.elapsed().as_nanos();
439                        let (cycle_samples, sample_rate) = {
440                            let d = driver.lock();
441                            (d.cycle_samples(), d.sample_rate())
442                        };
443                        p.maybe_report(cycle_samples, sample_rate, B::LABEL);
444                    }
445                    let mut st = lock.lock().expect("assist mutex poisoned");
446                    st.done_seq = st.done_seq.max(target);
447                    if run_error.is_none() {
448                        st.init_complete = true;
449                    }
450                    st.last_error = run_error;
451                    cvar.notify_all();
452                    continue;
453                }
454
455                if B::ASSIST_STEP_REQUIRES_REQUEST_CYCLE && !init_complete {
456                    let st = lock.lock().expect("assist mutex poisoned");
457                    if st.shutdown {
458                        break;
459                    }
460                    let wait_started = Instant::now();
461                    let _guard = cvar.wait(st).expect("assist condvar failed");
462                    if let Some(p) = profiler.as_mut() {
463                        p.wait_count += 1;
464                        p.wait_time_ns += wait_started.elapsed().as_nanos();
465                    }
466                    continue;
467                }
468
469                if !autonomous {
470                    let st = lock.lock().expect("assist mutex poisoned");
471                    if st.shutdown {
472                        break;
473                    }
474                    let wait_started = Instant::now();
475                    let _guard = cvar.wait(st).expect("assist condvar failed");
476                    if let Some(p) = profiler.as_mut() {
477                        p.wait_count += 1;
478                        p.wait_time_ns += wait_started.elapsed().as_nanos();
479                    }
480                    continue;
481                }
482
483                let started = Instant::now();
484                let did_work = {
485                    let d = driver.lock();
486                    match d.run_assist_step_for_worker() {
487                        Ok(v) => v,
488                        Err(e) => {
489                            if let Some(p) = profiler.as_mut() {
490                                p.step_err_count += 1;
491                            }
492                            let mut st = lock.lock().expect("assist mutex poisoned");
493                            st.last_error = Some(e.to_string());
494                            cvar.notify_all();
495                            false
496                        }
497                    }
498                };
499                if let Some(p) = profiler.as_mut() {
500                    p.step_count += 1;
501                    if did_work {
502                        p.step_work_count += 1;
503                    }
504                    p.step_time_ns += started.elapsed().as_nanos();
505                    let (cycle_samples, sample_rate) = {
506                        let d = driver.lock();
507                        (d.cycle_samples(), d.sample_rate())
508                    };
509                    p.maybe_report(cycle_samples, sample_rate, B::LABEL);
510                }
511                if !did_work {
512                    let st = lock.lock().expect("assist mutex poisoned");
513                    if st.shutdown {
514                        break;
515                    }
516                    let wait_started = Instant::now();
517                    let _guard = if autonomous {
518                        cvar.wait_timeout(st, Duration::from_micros(100))
519                            .expect("assist condvar failed")
520                            .0
521                    } else {
522                        cvar.wait(st).expect("assist condvar failed")
523                    };
524                    if let Some(p) = profiler.as_mut() {
525                        p.wait_count += 1;
526                        p.wait_time_ns += wait_started.elapsed().as_nanos();
527                    }
528                }
529            }
530        })
531    }
532
533    fn run_assist_cycle(
534        driver: &Arc<UnsafeMutex<B::Driver>>,
535        assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
536    ) -> Result<(), String> {
537        if Self::assist_autonomous_enabled() && B::CYCLE_ON_WORKER_WHEN_ASSIST_AUTONOMOUS {
538            let (lock, cvar) = &**assist_state;
539            {
540                let mut st = lock
541                    .lock()
542                    .map_err(|_| "assist mutex poisoned".to_string())?;
543                st.init_complete = true;
544                cvar.notify_one();
545            }
546            let result = driver.lock().run_cycle_for_worker();
547            {
548                let mut st = lock
549                    .lock()
550                    .map_err(|_| "assist mutex poisoned".to_string())?;
551                st.last_error = result.as_ref().err().map(|e| e.to_string());
552                cvar.notify_one();
553            }
554            return result;
555        }
556
557        let (lock, cvar) = &**assist_state;
558        let mut st = lock
559            .lock()
560            .map_err(|_| "assist mutex poisoned".to_string())?;
561        st.request_seq = st.request_seq.saturating_add(1);
562        let target = st.request_seq;
563        cvar.notify_one();
564        while st.done_seq < target && !st.shutdown {
565            st = cvar
566                .wait(st)
567                .map_err(|_| "assist condvar wait failed".to_string())?;
568        }
569        if let Some(err) = st.last_error.take() {
570            return Err(err);
571        }
572        Ok(())
573    }
574
575    fn stop_assist_thread(
576        assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
577        assist_handle: JoinHandle<()>,
578    ) {
579        let (lock, cvar) = &**assist_state;
580        if let Ok(mut st) = lock.lock() {
581            st.shutdown = true;
582            cvar.notify_all();
583        }
584        let _ = assist_handle.join();
585    }
586}
587
588fn spread_hw_event_frames(events: &mut [HwMidiEvent], frames: u32) {
589    if events.len() <= 1 || frames <= 1 {
590        return;
591    }
592    let n = events.len() as u32;
593    for (idx, event) in events.iter_mut().enumerate() {
594        let pos = idx as u32;
595        event.event.frame = ((pos as u64 * (frames - 1) as u64) / n as u64) as u32;
596    }
597}