// maolan_engine/workers/hw_worker.rs

1use crate::{
2    hw::config,
3    hw::traits::{HwMidiHub, HwWorkerDriver},
4    message::{HwMidiEvent, Message},
5    mutex::UnsafeMutex,
6};
7#[cfg(unix)]
8use nix::libc;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Condvar, Mutex};
11use std::thread::JoinHandle;
12use std::time::{Duration, Instant};
13use tokio::sync::mpsc::{Receiver, Sender};
14use tracing::error;
15
/// Compile-time description of one hardware backend served by [`HwWorker`].
///
/// An implementation binds the concrete driver and MIDI hub types and supplies
/// the labels, thread names and environment-flag name the worker uses.
pub trait Backend: Send + Sync + 'static {
    /// Driver type; the assist thread calls `run_cycle_for_worker` and
    /// `run_assist_step_for_worker` on it.
    type Driver: HwWorkerDriver + Send + 'static;
    /// MIDI hub type used to read incoming and write outgoing MIDI events.
    type MidiHub: HwMidiHub + Send + 'static;

    /// Prefix used in this worker's log messages.
    const LABEL: &'static str;
    /// Name given to the worker thread when realtime scheduling is configured.
    const WORKER_THREAD_NAME: &'static str;
    /// Name given to the assist thread when realtime scheduling is configured.
    const ASSIST_THREAD_NAME: &'static str;
    /// Name of the environment flag that enables autonomous assist steps.
    const ASSIST_AUTONOMOUS_ENV: &'static str;
    /// When `true`, the assist thread runs no autonomous step until one
    /// requested cycle has completed without error. Defaults to `false`.
    const ASSIST_STEP_REQUIRES_REQUEST_CYCLE: bool = false;
}
26
/// Worker that bridges engine messages to a hardware driver and MIDI hub.
#[derive(Debug)]
pub struct HwWorker<B: Backend> {
    /// Shared hardware driver; also handed to the assist thread.
    driver: Arc<UnsafeMutex<B::Driver>>,
    /// Shared MIDI hub; also handed to the MIDI input thread.
    midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
    /// Channel the worker receives `Message`s from.
    rx: Receiver<Message>,
    /// Channel used to send `HWFinished` and, from the MIDI input thread,
    /// `HWMidiEvents`.
    tx: Sender<Message>,
    /// Value of `driver.cycle_samples()` captured at construction.
    cycle_frames: u32,
    /// MIDI out events received but not yet written to the hub.
    pending_midi_out_events: Vec<HwMidiEvent>,
    /// `false` after events were appended and not yet re-sorted.
    pending_midi_out_sorted: bool,
}
37
/// State shared between the worker and its assist thread, guarded by a mutex
/// and signalled through a condition variable.
#[derive(Debug, Default)]
struct AssistState {
    /// Set by `stop_assist_thread`; the assist thread exits once it sees it.
    shutdown: bool,
    /// Incremented by `run_assist_cycle` for every requested cycle.
    request_seq: u64,
    /// Highest request sequence number the assist thread has finished.
    done_seq: u64,
    /// Set once a requested cycle has completed without error.
    init_complete: bool,
    /// Error text of the most recent cycle or failed assist step; taken by
    /// `run_assist_cycle`.
    last_error: Option<String>,
}
46
/// Scheduling policy requested for the worker and assist threads.
#[cfg(unix)]
const RT_POLICY: i32 = libc::SCHED_FIFO;
/// Scheduling priority requested for the worker thread.
const RT_PRIORITY_WORKER: i32 = 18;
/// Scheduling priority requested for the assist thread.
const RT_PRIORITY_ASSIST: i32 = 12;
/// Minimum time between two profiler reports.
const PROFILE_INTERVAL: Duration = Duration::from_secs(1);
52
53#[derive(Debug)]
54struct AssistProfiler {
55    report_at: Instant,
56    cycle_count: u64,
57    cycle_err_count: u64,
58    cycle_time_ns: u128,
59    step_count: u64,
60    step_work_count: u64,
61    step_err_count: u64,
62    step_time_ns: u128,
63    wait_count: u64,
64    wait_time_ns: u128,
65}
66
67impl AssistProfiler {
68    fn new() -> Self {
69        Self {
70            report_at: Instant::now() + PROFILE_INTERVAL,
71            cycle_count: 0,
72            cycle_err_count: 0,
73            cycle_time_ns: 0,
74            step_count: 0,
75            step_work_count: 0,
76            step_err_count: 0,
77            step_time_ns: 0,
78            wait_count: 0,
79            wait_time_ns: 0,
80        }
81    }
82
83    fn maybe_report(&mut self, cycle_samples: usize, sample_rate: i32, label: &str) {
84        let now = Instant::now();
85        if now < self.report_at {
86            return;
87        }
88        let cycle_avg_us = if self.cycle_count > 0 {
89            (self.cycle_time_ns / self.cycle_count as u128) as f64 / 1_000.0
90        } else {
91            0.0
92        };
93        let step_avg_us = if self.step_count > 0 {
94            (self.step_time_ns / self.step_count as u128) as f64 / 1_000.0
95        } else {
96            0.0
97        };
98        let wait_avg_us = if self.wait_count > 0 {
99            (self.wait_time_ns / self.wait_count as u128) as f64 / 1_000.0
100        } else {
101            0.0
102        };
103        let expected_cycles_per_sec = if cycle_samples > 0 && sample_rate > 0 {
104            sample_rate as f64 / cycle_samples as f64
105        } else {
106            0.0
107        };
108        error!(
109            "{} profile: expected_cps={:.1} cycles={} cycle_err={} cycle_avg_us={:.1} steps={} steps_work={} step_err={} step_avg_us={:.1} waits={} wait_avg_us={:.1}",
110            label,
111            expected_cycles_per_sec,
112            self.cycle_count,
113            self.cycle_err_count,
114            cycle_avg_us,
115            self.step_count,
116            self.step_work_count,
117            self.step_err_count,
118            step_avg_us,
119            self.wait_count,
120            wait_avg_us
121        );
122        self.report_at = now + PROFILE_INTERVAL;
123        self.cycle_count = 0;
124        self.cycle_err_count = 0;
125        self.cycle_time_ns = 0;
126        self.step_count = 0;
127        self.step_work_count = 0;
128        self.step_err_count = 0;
129        self.step_time_ns = 0;
130        self.wait_count = 0;
131        self.wait_time_ns = 0;
132    }
133}
134
135impl<B: Backend> HwWorker<B> {
    /// Returns whether the `config::HW_PROFILE_ENV` flag is set, as reported by
    /// `config::env_flag`.
    fn profile_enabled() -> bool {
        config::env_flag(config::HW_PROFILE_ENV)
    }
139
    /// Returns whether the backend's `ASSIST_AUTONOMOUS_ENV` flag is set, as
    /// reported by `config::env_flag`.
    fn assist_autonomous_enabled() -> bool {
        config::env_flag(B::ASSIST_AUTONOMOUS_ENV)
    }
143
144    fn configure_rt_thread(name: &str, priority: i32) -> Result<(), String> {
145        #[cfg(unix)]
146        {
147            let thread = unsafe { libc::pthread_self() };
148            #[cfg(any(target_os = "linux", target_os = "freebsd", target_os = "openbsd"))]
149            let c_name = std::ffi::CString::new(name).map_err(|e| e.to_string())?;
150            #[cfg(target_os = "linux")]
151            unsafe {
152                let _ = libc::pthread_setname_np(thread, c_name.as_ptr());
153            }
154            #[cfg(any(target_os = "freebsd", target_os = "openbsd"))]
155            unsafe {
156                libc::pthread_set_name_np(thread, c_name.as_ptr());
157            }
158
159            let param = unsafe {
160                let mut p = std::mem::zeroed::<libc::sched_param>();
161                p.sched_priority = priority;
162                p
163            };
164            let rc = unsafe { libc::pthread_setschedparam(thread, RT_POLICY, &param) };
165            if rc != 0 {
166                return Err(format!(
167                    "pthread_setschedparam({}, prio {}) failed with errno {}",
168                    name, priority, rc
169                ));
170            }
171
172            let mut actual_policy = 0_i32;
173            let mut actual_param = unsafe { std::mem::zeroed::<libc::sched_param>() };
174            let rc = unsafe {
175                libc::pthread_getschedparam(thread, &mut actual_policy, &mut actual_param)
176            };
177            if rc != 0 {
178                return Err(format!(
179                    "pthread_getschedparam({}) failed with errno {}",
180                    name, rc
181                ));
182            }
183            if actual_policy != RT_POLICY || actual_param.sched_priority != priority {
184                return Err(format!(
185                    "realtime verification failed for {}: policy {}, prio {}",
186                    name, actual_policy, actual_param.sched_priority
187                ));
188            }
189            Ok(())
190        }
191        #[cfg(not(unix))]
192        {
193            let _ = name;
194            let _ = priority;
195            Err("Realtime thread priority is not supported on this platform".to_string())
196        }
197    }
198
199    fn lock_memory_pages() -> Result<(), String> {
200        #[cfg(unix)]
201        {
202            let rc = unsafe { libc::mlockall(libc::MCL_CURRENT | libc::MCL_FUTURE) };
203            if rc == 0 {
204                Ok(())
205            } else {
206                Err(format!(
207                    "mlockall(MCL_CURRENT|MCL_FUTURE) failed: {}",
208                    std::io::Error::last_os_error()
209                ))
210            }
211        }
212        #[cfg(not(unix))]
213        {
214            Err("mlockall is not supported on this platform".to_string())
215        }
216    }
217
218    pub fn new(
219        driver: Arc<UnsafeMutex<B::Driver>>,
220        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
221        rx: Receiver<Message>,
222        tx: Sender<Message>,
223    ) -> Self {
224        let cycle_frames = {
225            let d = driver.lock();
226            d.cycle_samples() as u32
227        };
228        Self {
229            driver,
230            midi_hub,
231            rx,
232            tx,
233            cycle_frames,
234            pending_midi_out_events: vec![],
235            pending_midi_out_sorted: true,
236        }
237    }
238
239    pub async fn work(mut self) {
240        crate::enable_flush_denormals_to_zero();
241        if let Err(e) = Self::lock_memory_pages() {
242            error!("{} worker memory lock not enabled: {}", B::LABEL, e);
243        }
244        if let Err(e) = Self::configure_rt_thread(B::WORKER_THREAD_NAME, RT_PRIORITY_WORKER) {
245            error!("{} worker realtime priority not enabled: {}", B::LABEL, e);
246        }
247        #[cfg(target_os = "macos")]
248        unsafe {
249            libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, 0);
250        }
251        let assist_state = Arc::new((Mutex::new(AssistState::default()), Condvar::new()));
252        let assist_handle = Self::start_assist_thread(self.driver.clone(), assist_state.clone());
253        let midi_stop = Arc::new(AtomicBool::new(false));
254        let midi_handle = Self::start_midi_input_thread(
255            self.midi_hub.clone(),
256            self.tx.clone(),
257            self.cycle_frames,
258            midi_stop.clone(),
259        );
260        loop {
261            match self.rx.recv().await {
262                Some(msg) => match msg {
263                    Message::Request(crate::message::Action::Quit) => {
264                        if !self.pending_midi_out_events.is_empty() {
265                            if !self.pending_midi_out_sorted {
266                                self.pending_midi_out_events.sort_by(|a, b| {
267                                    a.event
268                                        .frame
269                                        .cmp(&b.event.frame)
270                                        .then_with(|| a.device.cmp(&b.device))
271                                });
272                                self.pending_midi_out_sorted = true;
273                            }
274                            let midi_hub = self.midi_hub.lock();
275                            midi_hub.write_events(&self.pending_midi_out_events);
276                            self.pending_midi_out_events.clear();
277                        }
278                        midi_stop.store(true, Ordering::Release);
279                        {
280                            let midi_hub = self.midi_hub.lock();
281                            midi_hub.wake_input_waiter();
282                        }
283                        let _ = midi_handle.join();
284                        Self::stop_assist_thread(&assist_state, assist_handle);
285                        return;
286                    }
287                    Message::TracksFinished => {
288                        {
289                            if !self.pending_midi_out_events.is_empty() {
290                                if !self.pending_midi_out_sorted {
291                                    self.pending_midi_out_events.sort_by(|a, b| {
292                                        a.event
293                                            .frame
294                                            .cmp(&b.event.frame)
295                                            .then_with(|| a.device.cmp(&b.device))
296                                    });
297                                    self.pending_midi_out_sorted = true;
298                                }
299                                let midi_hub = self.midi_hub.lock();
300                                midi_hub.write_events(&self.pending_midi_out_events);
301                                self.pending_midi_out_events.clear();
302                            }
303                        }
304                        if let Err(e) = Self::run_assist_cycle(&assist_state) {
305                            error!("{} assist cycle error: {}", B::LABEL, e);
306                        }
307                        if let Err(e) = self.tx.send(Message::HWFinished).await {
308                            error!(
309                                "{} worker failed to send HWFinished to engine: {}",
310                                B::LABEL,
311                                e
312                            );
313                        }
314                    }
315                    Message::HWMidiOutEvents(mut events) => {
316                        self.pending_midi_out_events.append(&mut events);
317                        self.pending_midi_out_sorted = false;
318                    }
319                    Message::ClearHWMidiOutEvents => {
320                        self.pending_midi_out_events.clear();
321                        self.pending_midi_out_sorted = true;
322                    }
323                    _ => {}
324                },
325                None => {
326                    midi_stop.store(true, Ordering::Release);
327                    {
328                        let midi_hub = self.midi_hub.lock();
329                        midi_hub.wake_input_waiter();
330                    }
331                    let _ = midi_handle.join();
332                    Self::stop_assist_thread(&assist_state, assist_handle);
333                    return;
334                }
335            }
336        }
337    }
338
339    fn start_midi_input_thread(
340        midi_hub: Arc<UnsafeMutex<B::MidiHub>>,
341        tx: Sender<Message>,
342        cycle_frames: u32,
343        stop: Arc<AtomicBool>,
344    ) -> JoinHandle<()> {
345        std::thread::spawn(move || {
346            crate::enable_flush_denormals_to_zero();
347            let mut midi_in_events = Vec::with_capacity(64);
348            while !stop.load(Ordering::Acquire) {
349                {
350                    let hub = midi_hub.lock();
351                    hub.read_events_blocking_into(&mut midi_in_events);
352                }
353                if midi_in_events.is_empty() {
354                    continue;
355                }
356                spread_hw_event_frames(&mut midi_in_events, cycle_frames);
357                let cap = midi_in_events.capacity();
358                let out = std::mem::replace(&mut midi_in_events, Vec::with_capacity(cap.max(64)));
359                if tx.blocking_send(Message::HWMidiEvents(out)).is_err() {
360                    break;
361                }
362            }
363        })
364    }
365
366    fn start_assist_thread(
367        driver: Arc<UnsafeMutex<B::Driver>>,
368        assist_state: Arc<(Mutex<AssistState>, Condvar)>,
369    ) -> JoinHandle<()> {
370        let profile = Self::profile_enabled();
371        let autonomous = Self::assist_autonomous_enabled();
372        std::thread::spawn(move || {
373            crate::enable_flush_denormals_to_zero();
374            if let Err(e) = Self::configure_rt_thread(B::ASSIST_THREAD_NAME, RT_PRIORITY_ASSIST) {
375                error!("{} assist realtime priority not enabled: {}", B::LABEL, e);
376            }
377            #[cfg(target_os = "macos")]
378            unsafe {
379                libc::pthread_set_qos_class_self_np(libc::qos_class_t::QOS_CLASS_USER_INITIATED, 0);
380            }
381            let mut profiler = if profile {
382                let (cycle_samples, sample_rate) = {
383                    let d = driver.lock();
384                    (d.cycle_samples(), d.sample_rate())
385                };
386                error!(
387                    "{} profile enabled: cycle_samples={} sample_rate={} expected_cps={:.1}",
388                    B::LABEL,
389                    cycle_samples,
390                    sample_rate,
391                    if cycle_samples > 0 {
392                        sample_rate as f64 / cycle_samples as f64
393                    } else {
394                        0.0
395                    }
396                );
397                Some(AssistProfiler::new())
398            } else {
399                None
400            };
401            let (lock, cvar) = &*assist_state;
402            loop {
403                let (shutdown, has_request, target, init_complete) = {
404                    let st = lock.lock().expect("assist mutex poisoned");
405                    (
406                        st.shutdown,
407                        st.request_seq > st.done_seq,
408                        st.request_seq,
409                        st.init_complete,
410                    )
411                };
412                if shutdown {
413                    break;
414                }
415                if has_request {
416                    let started = Instant::now();
417                    let run_error = {
418                        let d = driver.lock();
419                        d.run_cycle_for_worker().err().map(|e| e.to_string())
420                    };
421                    if let Some(p) = profiler.as_mut() {
422                        p.cycle_count += 1;
423                        if run_error.is_some() {
424                            p.cycle_err_count += 1;
425                        }
426                        p.cycle_time_ns += started.elapsed().as_nanos();
427                        let (cycle_samples, sample_rate) = {
428                            let d = driver.lock();
429                            (d.cycle_samples(), d.sample_rate())
430                        };
431                        p.maybe_report(cycle_samples, sample_rate, B::LABEL);
432                    }
433                    let mut st = lock.lock().expect("assist mutex poisoned");
434                    st.done_seq = st.done_seq.max(target);
435                    if run_error.is_none() {
436                        st.init_complete = true;
437                    }
438                    st.last_error = run_error;
439                    cvar.notify_all();
440                    continue;
441                }
442
443                if B::ASSIST_STEP_REQUIRES_REQUEST_CYCLE && !init_complete {
444                    let st = lock.lock().expect("assist mutex poisoned");
445                    if st.shutdown {
446                        break;
447                    }
448                    let wait_started = Instant::now();
449                    let _guard = cvar.wait(st).expect("assist condvar failed");
450                    if let Some(p) = profiler.as_mut() {
451                        p.wait_count += 1;
452                        p.wait_time_ns += wait_started.elapsed().as_nanos();
453                    }
454                    continue;
455                }
456
457                if !autonomous {
458                    let st = lock.lock().expect("assist mutex poisoned");
459                    if st.shutdown {
460                        break;
461                    }
462                    let wait_started = Instant::now();
463                    let _guard = cvar.wait(st).expect("assist condvar failed");
464                    if let Some(p) = profiler.as_mut() {
465                        p.wait_count += 1;
466                        p.wait_time_ns += wait_started.elapsed().as_nanos();
467                    }
468                    continue;
469                }
470
471                let started = Instant::now();
472                let did_work = {
473                    let d = driver.lock();
474                    match d.run_assist_step_for_worker() {
475                        Ok(v) => v,
476                        Err(e) => {
477                            if let Some(p) = profiler.as_mut() {
478                                p.step_err_count += 1;
479                            }
480                            let mut st = lock.lock().expect("assist mutex poisoned");
481                            st.last_error = Some(e.to_string());
482                            cvar.notify_all();
483                            false
484                        }
485                    }
486                };
487                if let Some(p) = profiler.as_mut() {
488                    p.step_count += 1;
489                    if did_work {
490                        p.step_work_count += 1;
491                    }
492                    p.step_time_ns += started.elapsed().as_nanos();
493                    let (cycle_samples, sample_rate) = {
494                        let d = driver.lock();
495                        (d.cycle_samples(), d.sample_rate())
496                    };
497                    p.maybe_report(cycle_samples, sample_rate, B::LABEL);
498                }
499                if !did_work {
500                    let st = lock.lock().expect("assist mutex poisoned");
501                    if st.shutdown {
502                        break;
503                    }
504                    let wait_started = Instant::now();
505                    let _guard = cvar.wait(st).expect("assist condvar failed");
506                    if let Some(p) = profiler.as_mut() {
507                        p.wait_count += 1;
508                        p.wait_time_ns += wait_started.elapsed().as_nanos();
509                    }
510                }
511            }
512        })
513    }
514
515    fn run_assist_cycle(assist_state: &Arc<(Mutex<AssistState>, Condvar)>) -> Result<(), String> {
516        let (lock, cvar) = &**assist_state;
517        let mut st = lock
518            .lock()
519            .map_err(|_| "assist mutex poisoned".to_string())?;
520        st.request_seq = st.request_seq.saturating_add(1);
521        let target = st.request_seq;
522        cvar.notify_one();
523        while st.done_seq < target && !st.shutdown {
524            st = cvar
525                .wait(st)
526                .map_err(|_| "assist condvar wait failed".to_string())?;
527        }
528        if let Some(err) = st.last_error.take() {
529            return Err(err);
530        }
531        Ok(())
532    }
533
534    fn stop_assist_thread(
535        assist_state: &Arc<(Mutex<AssistState>, Condvar)>,
536        assist_handle: JoinHandle<()>,
537    ) {
538        let (lock, cvar) = &**assist_state;
539        if let Ok(mut st) = lock.lock() {
540            st.shutdown = true;
541            cvar.notify_all();
542        }
543        let _ = assist_handle.join();
544    }
545}
546
547fn spread_hw_event_frames(events: &mut [HwMidiEvent], frames: u32) {
548    if events.len() <= 1 || frames <= 1 {
549        return;
550    }
551    let n = events.len() as u32;
552    for (idx, event) in events.iter_mut().enumerate() {
553        let pos = idx as u32;
554        event.event.frame = ((pos as u64 * (frames - 1) as u64) / n as u64) as u32;
555    }
556}